From 7948c87bc92c2a90a7a628f24ecee77eeb103817 Mon Sep 17 00:00:00 2001 From: Samuel Burbulla Date: Wed, 6 Dec 2023 08:58:46 +0200 Subject: [PATCH] Refactor model submodule. --- notebooks/sine.ipynb | 2 +- src/continuity/model/deeponet.py | 74 ---------- src/continuity/model/fnn.py | 63 -------- .../model/{torchmodel.py => model.py} | 4 +- .../model/{residual.py => networks.py} | 0 .../model/{neuraloperator.py => operators.py} | 137 +++++++++++++++++- 6 files changed, 137 insertions(+), 143 deletions(-) delete mode 100644 src/continuity/model/deeponet.py delete mode 100644 src/continuity/model/fnn.py rename src/continuity/model/{torchmodel.py => model.py} (95%) rename src/continuity/model/{residual.py => networks.py} (100%) rename src/continuity/model/{neuraloperator.py => operators.py} (53%) diff --git a/notebooks/sine.ipynb b/notebooks/sine.ipynb index 66653dc0..9aaed8a3 100644 --- a/notebooks/sine.ipynb +++ b/notebooks/sine.ipynb @@ -29,7 +29,7 @@ "import matplotlib.pyplot as plt\n", "from continuity.data.sine import SineWaves\n", "from continuity.plotting.plotting import *\n", - "from continuity.model.neuraloperator import NeuralOperator" + "from continuity.model.operators import NeuralOperator" ] }, { diff --git a/src/continuity/model/deeponet.py b/src/continuity/model/deeponet.py deleted file mode 100644 index b6f89f2a..00000000 --- a/src/continuity/model/deeponet.py +++ /dev/null @@ -1,74 +0,0 @@ -import torch -from continuity.model.residual import DeepResidualNetwork -from continuity.model.torchmodel import TorchModel - - -class DeepONet(TorchModel): - """DeepONet architecture.""" - - def __init__( - self, - coordinate_dim: int, - num_channels: int, - num_sensors: int, - branch_width: int, - branch_depth: int, - trunk_width: int, - trunk_depth: int, - basis_functions: int, - ): - """A model maps observations to evaluations. - - Args: - coordinate_dim: Dimension of coordinate space - num_channels: Number of channels - num_sensors: Number of input sensors - branch_width: Width of branch network - branch_depth: Depth of branch network - trunk_width: Width of trunk network - trunk_depth: Depth of trunk network - basis_functions: Number of basis functions - """ - super().__init__() - - self.coordinate_dim = coordinate_dim - self.num_channels = num_channels - self.num_sensors = num_sensors - self.basis_functions = basis_functions - - branch_input = num_sensors * (num_channels + coordinate_dim) - trunk_input = coordinate_dim - - self.branch = DeepResidualNetwork( - branch_input, - self.num_channels * basis_functions, - branch_width, - branch_depth, - ) - - self.trunk = DeepResidualNetwork( - trunk_input, - self.num_channels * basis_functions, - trunk_width, - trunk_depth, - ) - - def forward(self, u, x): - """Forward pass.""" - batch_size_u = u.shape[0] - batch_size_x = x.shape[1] - - u = u.reshape((batch_size_u, -1)) - x = x.reshape((-1, self.coordinate_dim)) - - b = self.branch(u) - t = self.trunk(x) - - b = b.reshape((batch_size_u, self.basis_functions, self.num_channels)) - t = t.reshape( - (batch_size_u, batch_size_x, self.basis_functions, self.num_channels) - ) - - sum = torch.einsum("ubc,uxbc->uxc", b, t) - assert sum.shape == (batch_size_u, batch_size_x, self.num_channels) - return sum diff --git a/src/continuity/model/fnn.py b/src/continuity/model/fnn.py deleted file mode 100644 index 6fc28782..00000000 --- a/src/continuity/model/fnn.py +++ /dev/null @@ -1,63 +0,0 @@ -import torch -from continuity.model import device -from continuity.model.residual import DeepResidualNetwork -from continuity.model.torchmodel import TorchModel - - -class FullyConnected(TorchModel): - """Fully connected architecture.""" - - def __init__( - self, - coordinate_dim: int, - num_channels: int, - num_sensors: int, - width: int, - depth: int, - ): - """Maps observations and positions to evaluations. - - Args: - coordinate_dim: Dimension of coordinate space - num_channels: Number of channels - num_sensors: Number of input sensors - width: Width of network - depth: Depth of network - """ - super().__init__() - - self.coordinate_dim = coordinate_dim - self.num_channels = num_channels - self.num_sensors = num_sensors - self.width = width - self.depth = depth - - self.input_size = num_sensors * (num_channels + coordinate_dim) + coordinate_dim - output_size = num_channels - self.drn = DeepResidualNetwork( - self.input_size, - output_size, - self.width, - self.depth, - ) - - def forward(self, u, x): - """Forward pass.""" - batch_size = u.shape[0] - assert batch_size == x.shape[0] - num_positions = x.shape[1] - d = self.coordinate_dim - - ux = torch.empty( - (batch_size, num_positions, self.input_size), - device=device, - ) - - for i, u_tensor in enumerate(u): - ux[i, :, :-d] = u_tensor.flatten() - ux[:, :, -d:] = x - - ux = ux.reshape((batch_size * num_positions, -1)) - v = self.drn(ux) - v = v.reshape((batch_size, num_positions, self.num_channels)) - return v diff --git a/src/continuity/model/torchmodel.py b/src/continuity/model/model.py similarity index 95% rename from src/continuity/model/torchmodel.py rename to src/continuity/model/model.py index eba8bccf..517c400d 100644 --- a/src/continuity/model/torchmodel.py +++ b/src/continuity/model/model.py @@ -3,8 +3,8 @@ from continuity.model import device -class TorchModel(torch.nn.Module): - """Torch model.""" +class BaseModel(torch.nn.Module): + """Common pyTorch model base.""" def compile(self, optimizer, criterion): """Compile model.""" diff --git a/src/continuity/model/residual.py b/src/continuity/model/networks.py similarity index 100% rename from src/continuity/model/residual.py rename to src/continuity/model/networks.py diff --git a/src/continuity/model/neuraloperator.py b/src/continuity/model/operators.py similarity index 53% rename from src/continuity/model/neuraloperator.py rename to src/continuity/model/operators.py index cdc3b188..d2b5322d 100644 --- a/src/continuity/model/neuraloperator.py +++ b/src/continuity/model/operators.py @@ -1,6 +1,137 @@ import torch -from continuity.model.residual import DeepResidualNetwork -from continuity.model.torchmodel import TorchModel +from continuity.model import device +from continuity.model.networks import DeepResidualNetwork +from continuity.model.model import BaseModel + + +class FullyConnected(BaseModel): + """Fully connected architecture.""" + + def __init__( + self, + coordinate_dim: int, + num_channels: int, + num_sensors: int, + width: int, + depth: int, + ): + """Maps observations and positions to evaluations. + + Args: + coordinate_dim: Dimension of coordinate space + num_channels: Number of channels + num_sensors: Number of input sensors + width: Width of network + depth: Depth of network + """ + super().__init__() + + self.coordinate_dim = coordinate_dim + self.num_channels = num_channels + self.num_sensors = num_sensors + self.width = width + self.depth = depth + + self.input_size = num_sensors * (num_channels + coordinate_dim) + coordinate_dim + output_size = num_channels + self.drn = DeepResidualNetwork( + self.input_size, + output_size, + self.width, + self.depth, + ) + + def forward(self, u, x): + """Forward pass.""" + batch_size = u.shape[0] + assert batch_size == x.shape[0] + num_positions = x.shape[1] + d = self.coordinate_dim + + ux = torch.empty( + (batch_size, num_positions, self.input_size), + device=device, + ) + + for i, u_tensor in enumerate(u): + ux[i, :, :-d] = u_tensor.flatten() + ux[:, :, -d:] = x + + ux = ux.reshape((batch_size * num_positions, -1)) + v = self.drn(ux) + v = v.reshape((batch_size, num_positions, self.num_channels)) + return v + + +class DeepONet(BaseModel): + """DeepONet architecture.""" + + def __init__( + self, + coordinate_dim: int, + num_channels: int, + num_sensors: int, + branch_width: int, + branch_depth: int, + trunk_width: int, + trunk_depth: int, + basis_functions: int, + ): + """A model maps observations to evaluations. + + Args: + coordinate_dim: Dimension of coordinate space + num_channels: Number of channels + num_sensors: Number of input sensors + branch_width: Width of branch network + branch_depth: Depth of branch network + trunk_width: Width of trunk network + trunk_depth: Depth of trunk network + basis_functions: Number of basis functions + """ + super().__init__() + + self.coordinate_dim = coordinate_dim + self.num_channels = num_channels + self.num_sensors = num_sensors + self.basis_functions = basis_functions + + branch_input = num_sensors * (num_channels + coordinate_dim) + trunk_input = coordinate_dim + + self.branch = DeepResidualNetwork( + branch_input, + self.num_channels * basis_functions, + branch_width, + branch_depth, + ) + + self.trunk = DeepResidualNetwork( + trunk_input, + self.num_channels * basis_functions, + trunk_width, + trunk_depth, + ) + + def forward(self, u, x): + """Forward pass.""" + batch_size_u = u.shape[0] + batch_size_x = x.shape[1] + + u = u.reshape((batch_size_u, -1)) + x = x.reshape((-1, self.coordinate_dim)) + + b = self.branch(u) + t = self.trunk(x) + + b = b.reshape((batch_size_u, self.basis_functions, self.num_channels)) + t = t.reshape( + (batch_size_u, batch_size_x, self.basis_functions, self.num_channels) + ) + + sum = torch.einsum("ubc,uxbc->uxc", b, t) + assert sum.shape == (batch_size_u, batch_size_x, self.num_channels) + return sum class ContinuousConvolutionLayer(torch.nn.Module): @@ -64,7 +195,7 @@ def forward(self, yu, x): return integral -class NeuralOperator(TorchModel): +class NeuralOperator(BaseModel): """Neural operator architecture.""" def __init__(