Skip to content

Commit

Permalink
Refactor model submodule.
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelburbulla committed Dec 6, 2023
1 parent 54a2581 commit 7948c87
Show file tree
Hide file tree
Showing 6 changed files with 137 additions and 143 deletions.
2 changes: 1 addition & 1 deletion notebooks/sine.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
"import matplotlib.pyplot as plt\n",
"from continuity.data.sine import SineWaves\n",
"from continuity.plotting.plotting import *\n",
"from continuity.model.neuraloperator import NeuralOperator"
"from continuity.model.operators import NeuralOperator"
]
},
{
Expand Down
74 changes: 0 additions & 74 deletions src/continuity/model/deeponet.py

This file was deleted.

63 changes: 0 additions & 63 deletions src/continuity/model/fnn.py

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from continuity.model import device


class TorchModel(torch.nn.Module):
"""Torch model."""
class BaseModel(torch.nn.Module):
"""Common pyTorch model base."""

def compile(self, optimizer, criterion):
"""Compile model."""
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,6 +1,137 @@
import torch
from continuity.model.residual import DeepResidualNetwork
from continuity.model.torchmodel import TorchModel
from continuity.model import device
from continuity.model.networks import DeepResidualNetwork
from continuity.model.model import BaseModel


class FullyConnected(BaseModel):
"""Fully connected architecture."""

def __init__(
self,
coordinate_dim: int,
num_channels: int,
num_sensors: int,
width: int,
depth: int,
):
"""Maps observations and positions to evaluations.
Args:
coordinate_dim: Dimension of coordinate space
num_channels: Number of channels
num_sensors: Number of input sensors
width: Width of network
depth: Depth of network
"""
super().__init__()

self.coordinate_dim = coordinate_dim
self.num_channels = num_channels
self.num_sensors = num_sensors
self.width = width
self.depth = depth

self.input_size = num_sensors * (num_channels + coordinate_dim) + coordinate_dim
output_size = num_channels
self.drn = DeepResidualNetwork(
self.input_size,
output_size,
self.width,
self.depth,
)

def forward(self, u, x):
"""Forward pass."""
batch_size = u.shape[0]
assert batch_size == x.shape[0]
num_positions = x.shape[1]
d = self.coordinate_dim

ux = torch.empty(
(batch_size, num_positions, self.input_size),
device=device,
)

for i, u_tensor in enumerate(u):
ux[i, :, :-d] = u_tensor.flatten()
ux[:, :, -d:] = x

ux = ux.reshape((batch_size * num_positions, -1))
v = self.drn(ux)
v = v.reshape((batch_size, num_positions, self.num_channels))
return v


class DeepONet(BaseModel):
"""DeepONet architecture."""

def __init__(
self,
coordinate_dim: int,
num_channels: int,
num_sensors: int,
branch_width: int,
branch_depth: int,
trunk_width: int,
trunk_depth: int,
basis_functions: int,
):
"""A model maps observations to evaluations.
Args:
coordinate_dim: Dimension of coordinate space
num_channels: Number of channels
num_sensors: Number of input sensors
branch_width: Width of branch network
branch_depth: Depth of branch network
trunk_width: Width of trunk network
trunk_depth: Depth of trunk network
basis_functions: Number of basis functions
"""
super().__init__()

self.coordinate_dim = coordinate_dim
self.num_channels = num_channels
self.num_sensors = num_sensors
self.basis_functions = basis_functions

branch_input = num_sensors * (num_channels + coordinate_dim)
trunk_input = coordinate_dim

self.branch = DeepResidualNetwork(
branch_input,
self.num_channels * basis_functions,
branch_width,
branch_depth,
)

self.trunk = DeepResidualNetwork(
trunk_input,
self.num_channels * basis_functions,
trunk_width,
trunk_depth,
)

def forward(self, u, x):
"""Forward pass."""
batch_size_u = u.shape[0]
batch_size_x = x.shape[1]

u = u.reshape((batch_size_u, -1))
x = x.reshape((-1, self.coordinate_dim))

b = self.branch(u)
t = self.trunk(x)

b = b.reshape((batch_size_u, self.basis_functions, self.num_channels))
t = t.reshape(
(batch_size_u, batch_size_x, self.basis_functions, self.num_channels)
)

sum = torch.einsum("ubc,uxbc->uxc", b, t)
assert sum.shape == (batch_size_u, batch_size_x, self.num_channels)
return sum


class ContinuousConvolutionLayer(torch.nn.Module):
Expand Down Expand Up @@ -64,7 +195,7 @@ def forward(self, yu, x):
return integral


class NeuralOperator(TorchModel):
class NeuralOperator(BaseModel):
"""Neural operator architecture."""

def __init__(
Expand Down

0 comments on commit 7948c87

Please sign in to comment.