Skip to content

Commit

Permalink
Merge pull request #27 from aai-institute/feature/callbacks
Browse files Browse the repository at this point in the history
Add callbacks.
  • Loading branch information
samuelburbulla committed Feb 6, 2024
2 parents ec3a761 + 9a23765 commit dd39745
Show file tree
Hide file tree
Showing 6 changed files with 300 additions and 84 deletions.
202 changes: 142 additions & 60 deletions notebooks/superresolution.ipynb

Large diffs are not rendered by default.

113 changes: 113 additions & 0 deletions src/continuity/callbacks/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
"""
`continuity.callbacks`
Callbacks for training in Continuity.
"""

from abc import ABC, abstractmethod
from typing import Optional, List, Dict
import matplotlib.pyplot as plt


class Callback(ABC):
    """Callback base class for the `fit` method of `Operator`.

    Concrete callbacks must implement `__call__`, `on_train_begin` and
    `on_train_end`; the class cannot be instantiated otherwise.
    """

    @abstractmethod
    def __call__(self, epoch: int, logs: Dict[str, float]) -> None:
        """Callback function.

        Called at the end of each epoch.

        Args:
            epoch: Current epoch.
            logs: Dictionary of logs.

        Raises:
            NotImplementedError: If a subclass delegates to this base
                implementation via `super()`.
        """
        raise NotImplementedError

    @abstractmethod
    def on_train_begin(self) -> None:
        """Called at the beginning of training."""

    @abstractmethod
    def on_train_end(self) -> None:
        """Called at the end of training."""


class PrintTrainingLoss(Callback):
    """Callback that reports the training loss on standard output."""

    def __init__(self):
        super().__init__()

    def __call__(self, epoch: int, logs: Dict[str, float]):
        """Print a one-line progress report for the finished epoch.

        Called at the end of each epoch.

        Args:
            epoch: Current epoch.
            logs: Dictionary of logs. The entries "loss/train" and
                "seconds_per_epoch" are read; a missing entry raises
                `KeyError`.
        """
        loss_train = logs["loss/train"]
        seconds_per_epoch = logs["seconds_per_epoch"]

        # The leading carriage return together with the suppressed newline
        # makes each report overwrite the previous one on the same line.
        message = (
            f"\rEpoch {epoch}: loss/train = {loss_train:.4e} "
            f"({seconds_per_epoch:.2f} s/epoch)"
        )
        print(message, end="")

    def on_train_begin(self):
        """Called at the beginning of training; nothing to set up."""

    def on_train_end(self):
        """Called at the end of training.

        Emits a newline so the last progress report is not overwritten by
        subsequent output.
        """
        print("")


class LearningCurve(Callback):
    """Callback that records logged values and plots a learning curve.

    Args:
        keys: List of keys to plot. Default is ["loss/train"].
    """

    def __init__(self, keys: Optional[List[str]] = None):
        self.keys = ["loss/train"] if keys is None else keys
        # Create the (empty) history up front so the callback is usable even
        # if `on_train_begin` is never invoked by the caller.
        self.on_train_begin()
        super().__init__()

    def __call__(self, epoch: int, logs: Dict[str, float]):
        """Store the current value of every tracked key.

        Called at the end of each epoch.

        Args:
            epoch: Current epoch.
            logs: Dictionary of logs. Tracked keys that are absent from
                `logs` are skipped for this epoch.
        """
        for name in self.keys:
            if name not in logs:
                continue
            self.losses[name].append(logs[name])

    def on_train_begin(self):
        """Called at the beginning of training; resets the recorded values."""
        self.losses = {name: [] for name in self.keys}

    def on_train_end(self):
        """Called at the end of training; plots one curve per tracked key."""
        for name in self.keys:
            history = self.losses[name]
            # Epochs are numbered from 1 on the x-axis.
            x_axis = list(range(1, 1 + len(history)))
            plt.plot(x_axis, history)

        plt.yscale("log")
        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        plt.legend(self.keys)
        plt.show()
2 changes: 1 addition & 1 deletion src/continuity/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def get_device() -> torch.device:
Device.
"""
device = torch.device("cpu")
use_mps_backend = os.environ.get("USE_MPS_BACKEND", True).lower() in ("true", "1")
use_mps_backend = os.environ.get("USE_MPS_BACKEND", "True").lower() in ("true", "1")

if use_mps_backend and torch.backends.mps.is_available():
device = torch.device("mps")
Expand Down
53 changes: 34 additions & 19 deletions src/continuity/operators/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
import torch
from abc import abstractmethod
from time import time
from typing import Optional
from typing import Optional, List
from torch import Tensor
from torch.utils.tensorboard import SummaryWriter
from continuity.data import device, DataSet
from continuity.callbacks import Callback, PrintTrainingLoss
from continuity.operators.losses import Loss, MSELoss


Expand Down Expand Up @@ -47,20 +47,32 @@ def compile(self, optimizer: torch.optim.Optimizer, loss_fn: Optional[Loss] = No

# Print number of model parameters
num_params = sum(p.numel() for p in self.parameters())
print(f"Model parameters: {num_params}")
print(f"Model parameters: {num_params} Device: {device}")

def fit(
self, dataset: DataSet, epochs: int, writer: Optional[SummaryWriter] = None
self,
dataset: DataSet,
epochs: int,
callbacks: Optional[List[Callback]] = None,
):
"""Fit operator to data set.
Args:
dataset: Data set.
epochs: Number of epochs.
writer: Tensorboard-like writer for loss visualization.
callbacks: List of callbacks.
"""
# Default callback
if callbacks is None:
callbacks = [PrintTrainingLoss()]

# Call on_train_begin
for callback in callbacks:
callback.on_train_begin()

# Train
for epoch in range(epochs + 1):
mean_loss = 0
loss_train = 0

start = time()
for i in range(len(dataset)):
Expand All @@ -76,21 +88,24 @@ def closure(x=x, u=u, y=y, v=v):
self.optimizer.param_groups[0]["lr"] *= 0.999

# Compute mean loss
mean_loss += self.loss_fn(self, x, u, y, v).detach().item()
loss_train += self.loss_fn(self, x, u, y, v).detach().item()

end = time()
mean_loss /= len(dataset)

if writer is not None:
writer.add_scalar("Loss/train", mean_loss, epoch)

iter_per_second = len(dataset) / (end - start)
print(
f"\rEpoch {epoch}: loss = {mean_loss:.4e} "
f"({iter_per_second:.2f} it/s)",
end="",
)
print("")
seconds_per_epoch = end - start
loss_train /= len(dataset)

# Callbacks
logs = {
"loss/train": loss_train,
"seconds_per_epoch": seconds_per_epoch,
}

for callback in callbacks:
callback(epoch, logs)

# Call on_train_end
for callback in callbacks:
callback.on_train_end()

def loss(self, x: Tensor, u: Tensor, y: Tensor, v: Tensor) -> Tensor:
"""Evaluate loss function.
Expand Down
6 changes: 5 additions & 1 deletion src/continuity/plotting/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,12 @@ def plot(x: Tensor, u: Tensor, ax: Optional[Axis] = None):
dim = x.shape[-1]
assert dim in [1, 2], "Only supports `d = 1,2`"

# Move to cpu
x = x.cpu().detach().numpy()
u = u.cpu().detach().numpy()

if dim == 1:
ax.plot(x, u, "k.")
ax.plot(x, u, ".")

if dim == 2:
xx, yy = x[:, 0], x[:, 1]
Expand Down
8 changes: 5 additions & 3 deletions tests/test_convolution.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import torch
import matplotlib.pyplot as plt
from continuity.data import device
from continuity.data.datasets import Sine
from continuity.operators import ContinuousConvolution
from continuity.plotting import plot
Expand All @@ -20,7 +21,8 @@ def test_convolution():
# Kernel
def dirac(x, y):
dist = ((x - y) ** 2).sum(dim=-1)
return torch.isclose(dist, torch.zeros(1)).to(torch.float32)
zero = torch.zeros(1, device=device)
return torch.isclose(dist, zero).to(torch.float32)

# Operator
operator = ContinuousConvolution(
Expand All @@ -30,7 +32,7 @@ def dirac(x, y):
)

# Create tensors
y = torch.linspace(-1, 1, num_evals).unsqueeze(-1)
y = torch.linspace(-1, 1, num_evals).unsqueeze(-1).to(device)

# Apply operator
v = operator(x, u, y)
Expand All @@ -41,7 +43,7 @@ def dirac(x, y):
# Plotting
fig, ax = plt.subplots(1, 1)
plot(x, u, ax=ax)
plt.plot(x, v, "o")
plot(x, v, ax=ax)
fig.savefig(f"test_convolution.png")

# For num_sensors == num_evals, we get v = u / num_sensors.
Expand Down

0 comments on commit dd39745

Please sign in to comment.