Add callbacks.
samuelburbulla committed Feb 6, 2024
1 parent 74991d4 commit d59ca7a
Showing 3 changed files with 10,065 additions and 74 deletions.
10,010 changes: 9,951 additions & 59 deletions notebooks/superresolution.ipynb


91 changes: 91 additions & 0 deletions src/continuity/callbacks/__init__.py
@@ -0,0 +1,91 @@
"""
`continuity.callbacks`
Callbacks for training in Continuity.
"""

import math
from abc import ABC, abstractmethod


class Callback(ABC):
"""
Callback base class for `fit` method of `Operator`.
"""

@abstractmethod
def __call__(self, epoch, logs=None):
"""Callback function.
Called at the end of each epoch.
Args:
epoch: Current epoch.
logs: Dictionary of logs.
"""
raise NotImplementedError


class PrintTrainingLoss(Callback):
"""
Callback to print training loss.
"""

def __init__(self):
super().__init__()

def __call__(self, epoch, logs=None):
"""Callback function.
Called at the end of each epoch.
Args:
epoch: Current epoch.
logs: Dictionary of logs.
"""
loss_train = logs["loss/train"]
iter_per_second = logs["iter_per_second"]

print(
f"\rEpoch {epoch}: loss/train = {loss_train:.4e} "
f"({iter_per_second:.2f} it/s)",
end="",
)


class LearningCurve(Callback):
"""
Callback to plot learning curve.
"""

def __init__(self):
# Try to import lrcurve
from lrcurve import PlotLearningCurve

self.plot = PlotLearningCurve(
line_config={
"train": {"name": "Train", "color": "#000000"},
},
facet_config={"loss": {"name": "log(loss)", "limit": [None, None]}},
xaxis_config={"name": "Epoch", "limit": [0, None]},
)

super().__init__()

def __call__(self, epoch, logs=None):
"""Callback function.
Called at the end of each epoch.
Args:
epoch: Current epoch.
logs: Dictionary of logs.
"""
vals = {"loss": {}}

# Collect loss values
for key in ["train", "val"]:
loss_key = "loss/" + key
if loss_key in logs:
log_loss = math.log(logs[loss_key], 10)
vals["loss"][key] = log_loss

self.plot.append(epoch, vals)
self.plot.draw()
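For orientation, this is how the new callbacks are meant to be used together with the updated `fit` signature shown in the next file. The sketch assumes an already-constructed `operator` (any `Operator` subclass) and `dataset` (a `DataSet`); those objects are placeholders, not part of this commit. Only `PrintTrainingLoss`, `LearningCurve`, and the `callbacks` argument come from this diff.

    from continuity.callbacks import PrintTrainingLoss, LearningCurve

    # Omitting `callbacks` falls back to [PrintTrainingLoss()], so these
    # two calls behave identically:
    operator.fit(dataset, epochs=100)
    operator.fit(dataset, epochs=100, callbacks=[PrintTrainingLoss()])

    # In a notebook, plot a live learning curve instead (requires the
    # optional `lrcurve` package):
    operator.fit(dataset, epochs=100, callbacks=[LearningCurve()])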
38 changes: 23 additions & 15 deletions src/continuity/operators/operator.py
@@ -3,10 +3,10 @@
 import torch
 from abc import abstractmethod
 from time import time
-from typing import Optional
+from typing import Optional, List
 from torch import Tensor
-from torch.utils.tensorboard import SummaryWriter
 from continuity.data import device, DataSet
+from continuity.callbacks import Callback, PrintTrainingLoss
 from continuity.operators.losses import Loss, MSELoss


@@ -50,17 +50,24 @@ def compile(self, optimizer: torch.optim.Optimizer, loss_fn: Optional[Loss] = None):
         print(f"Model parameters: {num_params}")

     def fit(
-        self, dataset: DataSet, epochs: int, writer: Optional[SummaryWriter] = None
+        self,
+        dataset: DataSet,
+        epochs: int,
+        callbacks: Optional[List[Callback]] = None,
     ):
         """Fit operator to data set.

         Args:
             dataset: Data set.
             epochs: Number of epochs.
-            writer: Tensorboard-like writer for loss visualization.
+            callbacks: List of callbacks.
         """
+        # Default callback
+        if callbacks is None:
+            callbacks = [PrintTrainingLoss()]
+
         for epoch in range(epochs + 1):
-            mean_loss = 0
+            loss_train = 0

             start = time()
             for i in range(len(dataset)):
@@ -76,20 +83,21 @@ def closure(x=x, u=u, y=y, v=v):
                 self.optimizer.param_groups[0]["lr"] *= 0.999

                 # Compute mean loss
-                mean_loss += self.loss_fn(self, x, u, y, v).detach().item()
+                loss_train += self.loss_fn(self, x, u, y, v).detach().item()

             end = time()
-            mean_loss /= len(dataset)
+            iter_per_second = len(dataset) / (end - start)
+            loss_train /= len(dataset)

-            if writer is not None:
-                writer.add_scalar("Loss/train", mean_loss, epoch)
+            # Callbacks
+            logs = {
+                "loss/train": loss_train,
+                "iter_per_second": iter_per_second,
+            }
+
+            for callback in callbacks:
+                callback(epoch, logs)

-            iter_per_second = len(dataset) / (end - start)
-            print(
-                f"\rEpoch {epoch}: loss = {mean_loss:.4e} "
-                f"({iter_per_second:.2f} it/s)",
-                end="",
-            )
             print("")

     def loss(self, x: Tensor, u: Tensor, y: Tensor, v: Tensor) -> Tensor:
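With the `writer` argument removed, TensorBoard logging would now be expressed as a callback against the new interface. The class below is an illustrative sketch, not part of this commit; it relies only on the `logs` keys that `fit` populates above ("loss/train" and "iter_per_second").

    from torch.utils.tensorboard import SummaryWriter
    from continuity.callbacks import Callback

    class TensorBoardLogger(Callback):
        """Hypothetical callback forwarding the training loss to TensorBoard."""

        def __init__(self, log_dir=None):
            self.writer = SummaryWriter(log_dir)
            super().__init__()

        def __call__(self, epoch, logs=None):
            # `fit` passes "loss/train" in logs at the end of every epoch.
            self.writer.add_scalar("loss/train", logs["loss/train"], epoch)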
