Merge pull request #52 from aai-institute/feature/trainer
Feature/Cleanup: Trainer class
samuelburbulla authored Feb 22, 2024
2 parents 74f1d5b + d3acdda commit c591c18
Showing 14 changed files with 695 additions and 165 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -1,5 +1,10 @@
# CHANGELOG

## 0.1

- Add `Trainer` class to replace `operator.fit` method.


## 0.0.0 (2024-02-22)

- Set up project structure.
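For reference, the migration this entry describes is shown concretely in the test diffs below: the removed `operator.compile`/`operator.fit` calls become a `Trainer`. A minimal before/after sketch, assuming `operator` and `data_loader` are already constructed:

```python
import torch
from continuity.trainer import Trainer

# Before this PR:
#   operator.compile(optimizer)
#   operator.fit(data_loader, epochs=100)

# After this PR, training is driven by a Trainer instance:
optimizer = torch.optim.Adam(operator.parameters(), lr=1e-3)
trainer = Trainer(operator, optimizer)
trainer.fit(data_loader, epochs=100)
```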
18 changes: 9 additions & 9 deletions examples/basics.ipynb

Large diffs are not rendered by default.

471 changes: 464 additions & 7 deletions examples/meshes.ipynb

Large diffs are not rendered by default.

13 changes: 7 additions & 6 deletions examples/physicsinformed.ipynb

Large diffs are not rendered by default.

38 changes: 19 additions & 19 deletions examples/selfsupervised.ipynb

Large diffs are not rendered by default.

23 changes: 11 additions & 12 deletions examples/superresolution.ipynb

Large diffs are not rendered by default.

95 changes: 0 additions & 95 deletions src/continuity/operators/operator.py
@@ -2,10 +2,6 @@

import torch
from abc import abstractmethod
from time import time
from typing import Optional, List
from continuity.callbacks import Callback, PrintTrainingLoss
from continuity.operators.losses import Loss, MSELoss


class Operator(torch.nn.Module):
@@ -31,94 +27,3 @@ def forward(
Returns:
Evaluations of the mapped function with shape (batch_size, y_size, output_channels)
"""

def compile(
self,
optimizer: torch.optim.Optimizer,
loss_fn: Optional[Loss] = None,
verbose: bool = True,
):
"""Compile operator.
Args:
optimizer: Torch-like optimizer.
loss_fn: Loss function taking (x, u, y, v). Defaults to MSELoss.
verbose: Print number of model parameters to stdout.
"""
self.optimizer = optimizer
self.loss_fn = loss_fn or MSELoss()

# Print number of model parameters
if verbose:
num_params = sum(p.numel() for p in self.parameters())
print(f"Model parameters: {num_params}")

def fit(
self,
data_loader: torch.utils.data.DataLoader,
epochs: int,
callbacks: Optional[List[Callback]] = None,
):
"""Fit operator to data set.
Args:
dataset: Data set.
epochs: Number of epochs.
callbacks: List of callbacks.
"""
# Default callback
if callbacks is None:
callbacks = [PrintTrainingLoss()]

# Call on_train_begin
for callback in callbacks:
callback.on_train_begin()

# Train
for epoch in range(epochs + 1):
loss_train = 0

start = time()
for x, u, y, v in data_loader:

def closure(x=x, u=u, y=y, v=v):
self.optimizer.zero_grad()
loss = self.loss_fn(self, x, u, y, v)
loss.backward(retain_graph=True)
return loss

self.optimizer.step(closure)
self.optimizer.param_groups[0]["lr"] *= 0.999

# Compute mean loss
loss_train += self.loss_fn(self, x, u, y, v).detach().item()

end = time()
seconds_per_epoch = end - start
loss_train /= len(data_loader)

# Callbacks
logs = {
"loss/train": loss_train,
"seconds_per_epoch": seconds_per_epoch,
}

for callback in callbacks:
callback(epoch, logs)

# Call on_train_end
for callback in callbacks:
callback.on_train_end()

def loss(
self, x: torch.Tensor, u: torch.Tensor, y: torch.Tensor, v: torch.Tensor
) -> torch.Tensor:
"""Evaluate loss function.
Args:
x: Tensor of sensor positions of shape (batch_size, num_sensors, coordinate_dim)
u: Tensor of sensor values of shape (batch_size, num_sensors, num_channels)
y: Tensor of coordinates where the mapped function is evaluated of shape (batch_size, x_size, coordinate_dim)
v: Tensor of labels of shape (batch_size, x_size, num_channels)
"""
return self.loss_fn(self, x, u, y, v)
9 changes: 9 additions & 0 deletions src/continuity/trainer/__init__.py
@@ -0,0 +1,9 @@
"""
`continuity.trainer`
Trainer for operator learning.
"""

from .trainer import Trainer

__all__ = ["Trainer"]
src/continuity/callbacks.py → src/continuity/trainer/callbacks.py
@@ -1,8 +1,4 @@
"""
`continuity.callbacks`
Callbacks for training in Continuity.
"""
"""Callbacks for Trainer in Continuity."""

from abc import ABC, abstractmethod
from typing import Optional, List, Dict
@@ -11,7 +7,7 @@

class Callback(ABC):
"""
Callback base class for `fit` method of `Operator`.
Callback base class for `fit` method of `Trainer`.
"""

@abstractmethod
@@ -55,7 +51,7 @@ def __call__(self, epoch: int, logs: Dict[str, float]):

print(
f"\rEpoch {epoch}: loss/train = {loss_train:.4e} "
f"({seconds_per_epoch:.2f} s/epoch)",
f"({seconds_per_epoch:.3g} s/epoch)",
end="",
)

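Custom callbacks subclass `Callback` and are passed to `Trainer.fit` (see `tests/test_optuna.py` below for `OptunaCallback`). A minimal sketch of a user-defined callback, assuming the three hooks invoked by `Trainer.fit` (`on_train_begin`, `__call__`, `on_train_end`) make up the abstract interface; the `LossHistory` class is illustrative, not part of this PR:

```python
from typing import Dict
from continuity.trainer.callbacks import Callback


class LossHistory(Callback):
    """Illustrative callback that records the training loss per epoch."""

    def on_train_begin(self):
        self.history = []

    def __call__(self, epoch: int, logs: Dict[str, float]):
        # Trainer.fit logs "loss/train" and "seconds_per_epoch" each epoch.
        self.history.append(logs["loss/train"])

    def on_train_end(self):
        print(f"\nFinished {len(self.history)} epochs")
```

Used as `trainer.fit(data_loader, callbacks=[LossHistory()])`.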
121 changes: 121 additions & 0 deletions src/continuity/trainer/trainer.py
@@ -0,0 +1,121 @@
import torch
from time import time
from typing import Optional, List
from continuity.operators import Operator
from continuity.operators.losses import Loss, MSELoss
from continuity.trainer.callbacks import Callback, PrintTrainingLoss


class Trainer:
"""Trainer.
Implements a default training loop for operator learning.
Example:
```python
from continuity.trainer import Trainer
from continuity.operators.losses import MSELoss
...
optimizer = torch.optim.Adam(operator.parameters(), lr=1e-3)
loss_fn = MSELoss()
trainer = Trainer(operator, optimizer, loss_fn, device="cuda:0")
trainer.fit(data_loader, epochs=100)
```
Args:
operator: Operator to be trained.
optimizer: Torch-like optimizer. Default is Adam.
loss_fn: Loss function taking (op, x, u, y, v). Default is MSELoss.
device: Device to train on. Default is CPU.
verbose: Print number of model parameters and training progress. Default is True.
"""

def __init__(
self,
operator: Operator,
optimizer: Optional[torch.optim.Optimizer] = None,
loss_fn: Optional[Loss] = None,
device: Optional[torch.device] = None,
verbose: bool = True,
):
self.operator = operator
self.optimizer = (
optimizer
if optimizer is not None
else torch.optim.Adam(operator.parameters(), lr=1e-3)
)
self.loss_fn = loss_fn if loss_fn is not None else MSELoss()
self.device = device if device is not None else torch.device("cpu")
self.verbose = verbose

def fit(
self,
data_loader: torch.utils.data.DataLoader,
epochs: int = 100,
callbacks: Optional[List[Callback]] = None,
):
"""Fit operator to data set.
Args:
data_loader: Data loader providing (x, u, y, v) batches.
epochs: Number of epochs.
callbacks: List of callbacks.
"""
# Default callback
if callbacks is None:
if self.verbose:
callbacks = [PrintTrainingLoss()]
else:
callbacks = []

# Print number of model parameters
if self.verbose:
num_params = sum(p.numel() for p in self.operator.parameters())
print(f"Model parameters: {num_params}")

# Move operator to device
self.operator.to(self.device)

# Call on_train_begin
for callback in callbacks:
callback.on_train_begin()

# Train
self.operator.train()
for epoch in range(epochs):
loss_train = 0

start = time()
for x, u, y, v in data_loader:
x, u = x.to(self.device), u.to(self.device)
y, v = y.to(self.device), v.to(self.device)

def closure(x=x, u=u, y=y, v=v):
self.optimizer.zero_grad()
loss = self.loss_fn(self.operator, x, u, y, v)
loss.backward(retain_graph=True)
return loss

self.optimizer.step(closure)

# Compute mean loss
loss_train += self.loss_fn(self.operator, x, u, y, v).detach().item()

end = time()
seconds_per_epoch = end - start
loss_train /= len(data_loader)

# Callbacks
logs = {
"loss/train": loss_train,
"seconds_per_epoch": seconds_per_epoch,
}

for callback in callbacks:
callback(epoch + 1, logs)

# Call on_train_end
for callback in callbacks:
callback.on_train_end()

# Move operator back to CPU
self.operator.to("cpu")
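Note that `fit` steps the optimizer through a closure, so closure-requiring optimizers work unchanged; a sketch, assuming `operator` and `data_loader` as in the docstring example above (LBFGS is an illustration, not something this PR adds):

```python
import torch
from continuity.trainer import Trainer

# LBFGS re-evaluates the loss via the closure (zero_grad, loss,
# backward) that Trainer.fit passes to optimizer.step.
optimizer = torch.optim.LBFGS(operator.parameters())
trainer = Trainer(operator, optimizer)
trainer.fit(data_loader, epochs=10)
```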
13 changes: 10 additions & 3 deletions tests/operators/test_deeponet.py
@@ -6,6 +6,8 @@
from torch.utils.data import DataLoader
from continuity.operators import DeepONet
from continuity.data import OperatorDataset, Sine
from continuity.trainer import Trainer
from continuity.operators.losses import MSELoss


def test_output_shape():
@@ -62,8 +64,8 @@ def test_deeponet():

# Train self-supervised
optimizer = torch.optim.Adam(operator.parameters(), lr=1e-2)
operator.compile(optimizer)
operator.fit(data_loader, epochs=1000)
trainer = Trainer(operator, optimizer)
trainer.fit(data_loader, epochs=1000)

# Plotting
fig, ax = plt.subplots(1, 1)
@@ -75,4 +77,9 @@
# Check solution
x = x.unsqueeze(0)
u = u.unsqueeze(0)
assert operator.loss(x, u, x, u) < 1e-3
assert MSELoss()(operator, x, u, x, u) < 1e-3


if __name__ == "__main__":
test_output_shape()
test_deeponet()
8 changes: 5 additions & 3 deletions tests/operators/test_neuraloperator.py
@@ -5,6 +5,8 @@
from continuity.operators import NeuralOperator
from continuity.plotting import plot, plot_evaluation
from torch.utils.data import DataLoader
from continuity.trainer import Trainer
from continuity.operators.losses import MSELoss

# Set random seed
torch.manual_seed(0)
@@ -29,8 +31,8 @@ def test_neuraloperator():

# Train self-supervised
optimizer = torch.optim.Adam(operator.parameters(), lr=1e-2)
operator.compile(optimizer)
operator.fit(data_loader, epochs=400)
trainer = Trainer(operator, optimizer)
trainer.fit(data_loader, epochs=400)

# Plotting
fig, ax = plt.subplots(1, 1)
@@ -42,7 +44,7 @@
# Check solution
x = x.unsqueeze(0)
u = u.unsqueeze(0)
assert operator.loss(x, u, x, u) < 1e-5
assert MSELoss()(operator, x, u, x, u) < 1e-3


if __name__ == "__main__":
9 changes: 5 additions & 4 deletions tests/test_optuna.py
@@ -2,7 +2,8 @@
import torch
from torch.utils.data import DataLoader
from continuity.benchmarks.sine import SineBenchmark
from continuity.callbacks import OptunaCallback
from continuity.trainer import Trainer
from continuity.trainer.callbacks import OptunaCallback
from continuity.data import split, dataset_loss
from continuity.operators import DeepONet
import optuna
@@ -36,8 +37,8 @@ def objective(trial):
# Optimizer
optimizer = torch.optim.Adam(operator.parameters(), lr=lr)

operator.compile(optimizer, verbose=False)
operator.fit(train_loader, epochs=num_epochs, callbacks=[OptunaCallback(trial)])
trainer = Trainer(operator, optimizer, verbose=False)
trainer.fit(train_loader, epochs=num_epochs, callbacks=[OptunaCallback(trial)])

loss_val = dataset_loss(val_dataset, operator, benchmark.metric())
print(f"loss/val: {loss_val:.4e}")
@@ -52,7 +53,7 @@
storage=f"sqlite:///{name}.db",
load_if_exists=True,
)
study.optimize(objective, n_trials=10)
study.optimize(objective, n_trials=3)


if __name__ == "__main__":
27 changes: 27 additions & 0 deletions tests/trainer/test_trainer.py
@@ -0,0 +1,27 @@
import torch
from torch.utils.data import DataLoader
from continuity.operators import DeepONet
from continuity.data import device, Sine
from continuity.trainer import Trainer

torch.manual_seed(0)


def test_trainer():
dataset = Sine(num_sensors=32, size=16)
data_loader = DataLoader(dataset)
operator = DeepONet(dataset.shapes)

print(f"Using device: {device}")
trainer = Trainer(operator, device=device)
trainer.fit(data_loader, epochs=2)

# Make sure we can use operator output on cpu again
x, u, y, v = next(iter(data_loader))
v_pred = operator(x, u, y)
mse = ((v_pred - v.to("cpu")) ** 2).mean()
print(f"mse = {mse.item():.3g}")


if __name__ == "__main__":
test_trainer()
