Skip to content

Commit

Permalink
join dataset test implementations
Browse files Browse the repository at this point in the history
  • Loading branch information
JakobEliasWagner committed Feb 13, 2024
1 parent 1d766e1 commit b0b6141
Show file tree
Hide file tree
Showing 8 changed files with 35 additions and 51 deletions.
3 changes: 1 addition & 2 deletions src/continuity/benchmarks/sine.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
"""Sine benchmark."""

from continuity.benchmarks import Benchmark
from continuity.data import split
from continuity.data.datasets import Sine
from continuity.data import Sine, split
from continuity.operators.losses import Loss, MSELoss
from torch.utils.data import Dataset

Expand Down
7 changes: 4 additions & 3 deletions src/continuity/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
"DatasetShape",
"Sine",
"Flame",
"device"
"device",
"split"
]


Expand Down Expand Up @@ -79,7 +80,7 @@ def dataset_loss(dataset, operator, loss_fn):
loss = 0.0

for x, u, y, v in dataset:
batch_size = x.shape[0]
loss += loss_fn(operator, x, u, y, v) / batch_size
x, u, y, v = x.unsqueeze(0), u.unsqueeze(0), y.unsqueeze(0), v.unsqueeze(0)
loss += loss_fn(operator, x, u, y, v)

return loss
8 changes: 3 additions & 5 deletions src/continuity/plotting/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,13 @@

import torch
import numpy as np
from torch import Tensor
from typing import Optional
from matplotlib.axis import Axis
import matplotlib.pyplot as plt
from continuity.operators import Operator


def plot(x: Tensor, u: Tensor, ax: Optional[Axis] = None):
def plot(x: torch.Tensor, u: torch.Tensor, ax: Optional[Axis] = None):
"""Plots a function $u(x)$.
Currently only supports coordinate dimensions of $d = 1,2$.
Expand Down Expand Up @@ -44,7 +43,7 @@ def plot(x: Tensor, u: Tensor, ax: Optional[Axis] = None):


def plot_evaluation(
operator: Operator, x: Tensor, u: Tensor, ax: Optional[Axis] = None
operator: Operator, x: torch.Tensor, u: torch.Tensor, ax: Optional[Axis] = None
):
"""Plots the mapped function `operator(observation)` evaluated on a $[-1, 1]^d$ grid.
Expand All @@ -63,8 +62,7 @@ def plot_evaluation(
assert dim in [1, 2], "Only supports `d = 1,2`"

if dim == 1:
n = 200
y = torch.linspace(-1, 1, n).unsqueeze(-1)
y = torch.linspace(-1, 1, 200).unsqueeze(-1)
x = x.unsqueeze(0)
u = u.unsqueeze(0)
y = y.unsqueeze(0)
Expand Down
20 changes: 10 additions & 10 deletions tests/test_convolution.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,29 @@
import torch
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from continuity.data.datasets import Sine
from continuity.data import Sine, device
from continuity.operators import ContinuousConvolution
from continuity.plotting import plot

# Set random seed
torch.manual_seed(0)


def test_convolution():
torch.set_default_dtype(torch.float64)
# Parameters
num_sensors = 16
num_evals = num_sensors

# Data set
dataset = Sine(num_sensors, size=1)
x, u = dataset.x[0], dataset.u[0]
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
x, u, _, _ = next(iter(dataloader))

# Kernel
def dirac(x, y):
dist = ((x - y) ** 2).sum(dim=-1)
zero = torch.zeros(1)
return torch.isclose(dist, zero).to(torch.float32)
return torch.isclose(dist, zero).to(torch.float64)

# Operator
operator = ContinuousConvolution(
Expand All @@ -31,18 +33,16 @@ def dirac(x, y):
)

# Create tensors
y = torch.linspace(-1, 1, num_evals).unsqueeze(-1)
y = torch.linspace(-1, 1, num_evals).reshape(1, -1, 1).to(device)

# Apply operator
v = operator(x.reshape((1, -1, 1)), u.reshape((1, -1, 1)), y.reshape((1, -1, 1)))

# Extract batch
v = v.squeeze(0)

# Plotting
fig, ax = plt.subplots(1, 1)
plot(x, u, ax=ax)
plot(x, v, ax=ax)
x_plot = x[0].squeeze().detach().numpy()
ax.plot(x_plot, u[0].squeeze().detach().numpy(), "x-")
ax.plot(x_plot, v[0].squeeze().detach().numpy(), "--")
fig.savefig(f"test_convolution.png")

# For num_sensors == num_evals, we get v = u / num_sensors.
Expand Down
20 changes: 7 additions & 13 deletions tests/test_dataset.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import torch
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from continuity.data import tensor
from continuity.data.datasets import SelfSupervisedDataSet
from continuity.data import SelfSupervisedOperatorDataset
from continuity.plotting import plot
from torch.utils.data import DataLoader

# Set random seed
torch.manual_seed(0)
Expand All @@ -12,11 +11,11 @@
def test_dataset():
# Sensors
num_sensors = 4
f = lambda x: x**2
f = lambda x: x ** 2
num_channels = 1
coordinate_dim = 1

x = tensor(range(num_sensors)).reshape(-1, 1)
x = torch.Tensor(range(num_sensors)).reshape(-1, 1)
u = f(x)

# Test plotting
Expand All @@ -25,17 +24,12 @@ def test_dataset():
fig.savefig(f"test_dataset.png")

# Dataset
dataset = SelfSupervisedDataSet(
x.unsqueeze(0),
u.unsqueeze(0),
)
dataloader = DataLoader(dataset, batch_size=3, shuffle=False)
dataset = SelfSupervisedOperatorDataset(x.unsqueeze(0), u.unsqueeze(0))
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
x_target, u_target = x, u

# Test
for sample in dataloader:
x, u, y, v = sample

for x, u, y, v in dataloader:
# Every x, u must equal observation
assert x.shape[1] == num_sensors
assert u.shape[1] == num_sensors
Expand Down
9 changes: 4 additions & 5 deletions tests/test_deeponet.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import torch
import matplotlib.pyplot as plt
from continuity.data.datasets import Sine

from continuity.data import Sine
from continuity.operators import DeepONet
from continuity.plotting import plot, plot_evaluation
from torch.utils.data import DataLoader
Expand All @@ -19,9 +20,7 @@ def test_deeponet():

# Operator
operator = DeepONet(
num_sensors,
dataset.coordinate_dim,
dataset.num_channels,
dataset.shape,
branch_width=32,
branch_depth=1,
trunk_width=32,
Expand All @@ -32,7 +31,7 @@ def test_deeponet():
# Train self-supervised
optimizer = torch.optim.Adam(operator.parameters(), lr=1e-2)
operator.compile(optimizer)
operator.fit(dataloader, epochs=1000)
operator.fit(dataset, batch_size=1, epochs=1000)

# Plotting
fig, ax = plt.subplots(1, 1)
Expand Down
7 changes: 3 additions & 4 deletions tests/test_neuraloperator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch
import matplotlib.pyplot as plt
from continuity.data.datasets import Sine
from continuity.data import Sine
from continuity.operators import NeuralOperator
from continuity.plotting import plot, plot_evaluation
from torch.utils.data import DataLoader
Expand All @@ -19,8 +19,7 @@ def test_neuraloperator():

# Operator
operator = NeuralOperator(
coordinate_dim=dataset.coordinate_dim,
num_channels=dataset.num_channels,
dataset_shape=dataset.shape,
depth=1,
kernel_width=32,
kernel_depth=3,
Expand All @@ -29,7 +28,7 @@ def test_neuraloperator():
# Train self-supervised
optimizer = torch.optim.Adam(operator.parameters(), lr=1e-2)
operator.compile(optimizer)
operator.fit(dataloader, epochs=400)
operator.fit(dataset, batch_size=1, epochs=400)

# Plotting
fig, ax = plt.subplots(1, 1)
Expand Down
12 changes: 3 additions & 9 deletions tests/test_optuna.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,10 @@ def objective(trial):

# Train/val split
train_dataset, val_dataset = split(benchmark.train_dataset, 0.9)
train_dataloader = DataLoader(train_dataset)
val_dataloader = DataLoader(val_dataset)

# Operator
operator = DeepONet(
benchmark.dataset.num_sensors,
benchmark.dataset.coordinate_dim,
benchmark.dataset.num_channels,
benchmark.dataset.shape,
trunk_width=trunk_width,
trunk_depth=trunk_depth,
)
Expand All @@ -38,11 +34,9 @@ def objective(trial):
optimizer = torch.optim.Adam(operator.parameters(), lr=lr)

operator.compile(optimizer, verbose=False)
operator.fit(
train_dataloader, epochs=num_epochs, callbacks=[OptunaCallback(trial)]
)
operator.fit(train_dataset, epochs=num_epochs, callbacks=[OptunaCallback(trial)])

loss_val = dataset_loss(val_dataloader, operator, benchmark.metric())
loss_val = dataset_loss(val_dataset, operator, benchmark.metric())
print(f"loss/val: {loss_val:.4e}")

return loss_val
Expand Down

0 comments on commit b0b6141

Please sign in to comment.