Commit 9a23765: Fix tests with mps backend.

samuelburbulla committed Feb 6, 2024 (1 parent: d59ca7a)

Showing 6 changed files with 125 additions and 9,900 deletions.
9,916 changes: 53 additions & 9,863 deletions notebooks/superresolution.ipynb

Large diffs are not rendered by default.

78 changes: 50 additions & 28 deletions src/continuity/callbacks/__init__.py
@@ -4,8 +4,9 @@
 Callbacks for training in Continuity.
 """
 
-import math
 from abc import ABC, abstractmethod
+from typing import Optional, List, Dict
+import matplotlib.pyplot as plt
 
 
 class Callback(ABC):
@@ -14,7 +15,7 @@ class Callback(ABC):
     """
 
     @abstractmethod
-    def __call__(self, epoch, logs=None):
+    def __call__(self, epoch, logs: Dict[str, float]):
         """Callback function.
 
         Called at the end of each epoch.
@@ -24,6 +25,14 @@ def __call__(self, epoch, logs=None):
         """
         raise NotImplementedError
 
+    @abstractmethod
+    def on_train_begin(self):
+        """Called at the beginning of training."""
+
+    @abstractmethod
+    def on_train_end(self):
+        """Called at the end of training."""
+
 
 class PrintTrainingLoss(Callback):
     """
@@ -33,7 +42,7 @@ class PrintTrainingLoss(Callback):
     def __init__(self):
         super().__init__()
 
-    def __call__(self, epoch, logs=None):
+    def __call__(self, epoch: int, logs: Dict[str, float]):
         """Callback function.
 
         Called at the end of each epoch.
@@ -42,50 +51,63 @@ def __call__(self, epoch, logs=None):
             logs: Dictionary of logs.
         """
         loss_train = logs["loss/train"]
-        iter_per_second = logs["iter_per_second"]
+        seconds_per_epoch = logs["seconds_per_epoch"]
 
         print(
             f"\rEpoch {epoch}: loss/train = {loss_train:.4e} "
-            f"({iter_per_second:.2f} it/s)",
+            f"({seconds_per_epoch:.2f} s/epoch)",
             end="",
         )
 
+    def on_train_begin(self):
+        """Called at the beginning of training."""
+
+    def on_train_end(self):
+        """Called at the end of training."""
+        print("")
 
 
 class LearningCurve(Callback):
     """
     Callback to plot learning curve.
+
+    Args:
+        keys: List of keys to plot. Default is ["loss/train"].
     """
 
-    def __init__(self):
-        # Try to import lrcurve
-        from lrcurve import PlotLearningCurve
-
-        self.plot = PlotLearningCurve(
-            line_config={
-                "train": {"name": "Train", "color": "#000000"},
-            },
-            facet_config={"loss": {"name": "log(loss)", "limit": [None, None]}},
-            xaxis_config={"name": "Epoch", "limit": [0, None]},
-        )
+    def __init__(self, keys: Optional[List[str]] = None):
+        if keys is None:
+            keys = ["loss/train"]
+
+        self.keys = keys
+        self.on_train_begin()
         super().__init__()
 
-    def __call__(self, epoch, logs=None):
+    def __call__(self, epoch: int, logs: Dict[str, float]):
         """Callback function.
 
         Called at the end of each epoch.
 
         Args:
             epoch: Current epoch.
             logs: Dictionary of logs.
         """
-        vals = {"loss": {}}
-
-        # Collect loss values
-        for key in ["train", "val"]:
-            loss_key = "loss/" + key
-            if loss_key in logs:
-                log_loss = math.log(logs[loss_key], 10)
-                vals["loss"][key] = log_loss
-
-        self.plot.append(epoch, vals)
-        self.plot.draw()
+        for key in self.keys:
+            if key in logs:
+                self.losses[key].append(logs[key])
+
+    def on_train_begin(self):
+        """Called at the beginning of training."""
+        self.losses = {key: [] for key in self.keys}
+
+    def on_train_end(self):
+        """Called at the end of training."""
+        for key in self.keys:
+            vals = self.losses[key]
+            epochs = list(range(1, len(vals) + 1))
+            plt.plot(epochs, vals)
+
+        plt.yscale("log")
+        plt.xlabel("Epoch")
+        plt.ylabel("Loss")
+        plt.legend(self.keys)
+        plt.show()
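
The rewritten LearningCurve drops the lrcurve dependency and simply collects logged values in memory, plotting them with matplotlib once training ends. A minimal sketch of driving the new callback lifecycle by hand (the fake decaying loss is an illustration, not repository code):

```python
import random

from continuity.callbacks import LearningCurve

curve = LearningCurve(keys=["loss/train"])

curve.on_train_begin()  # resets curve.losses to empty lists
for epoch in range(1, 51):
    # Fake a decaying loss; a real run gets this dict from Operator.fit().
    logs = {"loss/train": 1.0 / epoch + random.uniform(0.0, 0.01)}
    curve(epoch, logs)
curve.on_train_end()  # plots all recorded keys on a log-scaled y-axis
```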
2 changes: 1 addition & 1 deletion src/continuity/data/__init__.py
@@ -23,7 +23,7 @@ def get_device() -> torch.device:
         Device.
     """
     device = torch.device("cpu")
-    use_mps_backend = os.environ.get("USE_MPS_BACKEND", True).lower() in ("true", "1")
+    use_mps_backend = os.environ.get("USE_MPS_BACKEND", "True").lower() in ("true", "1")
 
     if use_mps_backend and torch.backends.mps.is_available():
         device = torch.device("mps")
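
Context for the one-character fix above: os.environ.get returns its default unchanged when the variable is unset, so the old bool default crashed on .lower() before get_device() could even pick a device. A standalone sketch of the behavior:

```python
import os

os.environ.pop("USE_MPS_BACKEND", None)  # simulate an unset variable

# Old default: the bool True comes back as-is, and bool has no .lower().
try:
    os.environ.get("USE_MPS_BACKEND", True).lower()
except AttributeError as e:
    print(e)  # 'bool' object has no attribute 'lower'

# New default: a string, so parsing works and MPS stays on by default.
enabled = os.environ.get("USE_MPS_BACKEND", "True").lower() in ("true", "1")
print(enabled)  # True; set USE_MPS_BACKEND=0 to force the CPU device
```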
15 changes: 11 additions & 4 deletions src/continuity/operators/operator.py
@@ -47,7 +47,7 @@ def compile(self, optimizer: torch.optim.Optimizer, loss_fn: Optional[Loss] = No
 
         # Print number of model parameters
         num_params = sum(p.numel() for p in self.parameters())
-        print(f"Model parameters: {num_params}")
+        print(f"Model parameters: {num_params} Device: {device}")
 
     def fit(
         self,
@@ -66,6 +66,11 @@ def fit(
         if callbacks is None:
             callbacks = [PrintTrainingLoss()]
 
+        # Call on_train_begin
+        for callback in callbacks:
+            callback.on_train_begin()
+
+        # Train
         for epoch in range(epochs + 1):
             loss_train = 0
 
@@ -86,19 +91,21 @@ def closure(x=x, u=u, y=y, v=v):
                 loss_train += self.loss_fn(self, x, u, y, v).detach().item()
 
             end = time()
-            iter_per_second = len(dataset) / (end - start)
+            seconds_per_epoch = end - start
             loss_train /= len(dataset)
 
             # Callbacks
             logs = {
                 "loss/train": loss_train,
-                "iter_per_second": iter_per_second,
+                "seconds_per_epoch": seconds_per_epoch,
             }
 
             for callback in callbacks:
                 callback(epoch, logs)
 
-        print("")
+        # Call on_train_end
+        for callback in callbacks:
+            callback.on_train_end()
 
     def loss(self, x: Tensor, u: Tensor, y: Tensor, v: Tensor) -> Tensor:
         """Evaluate loss function.
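
fit() now brackets the epoch loop with on_train_begin()/on_train_end() on every callback, with __call__(epoch, logs) invoked in between after each epoch. A sketch of a custom callback written against that lifecycle (BestLoss is illustrative, not part of this commit):

```python
from typing import Dict

from continuity.callbacks import Callback


class BestLoss(Callback):
    """Illustrative callback: track the best training loss of a run."""

    def on_train_begin(self):
        # Called once by fit() before the first epoch.
        self.best = float("inf")

    def __call__(self, epoch: int, logs: Dict[str, float]):
        # Called after every epoch with the logs dict assembled in fit().
        self.best = min(self.best, logs["loss/train"])

    def on_train_end(self):
        # Called once after the last epoch.
        print(f"Best loss/train: {self.best:.4e}")
```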
6 changes: 5 additions & 1 deletion src/continuity/plotting/__init__.py
@@ -30,8 +30,12 @@ def plot(x: Tensor, u: Tensor, ax: Optional[Axis] = None):
     dim = x.shape[-1]
     assert dim in [1, 2], "Only supports `d = 1,2`"
 
+    # Move to cpu
+    x = x.cpu().detach().numpy()
+    u = u.cpu().detach().numpy()
+
     if dim == 1:
-        ax.plot(x, u, "k.")
+        ax.plot(x, u, ".")
 
     if dim == 2:
         xx, yy = x[:, 0], x[:, 1]
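
The added .cpu().detach() calls are the MPS-relevant part of this file: NumPy cannot wrap GPU/MPS memory, so converting a non-CPU tensor directly raises a TypeError. A small sketch, falling back to the CPU where MPS is unavailable:

```python
import torch

t = torch.linspace(-1, 1, 5)
if torch.backends.mps.is_available():
    t = t.to("mps")
    # Calling t.numpy() here would raise a TypeError asking for
    # Tensor.cpu() first, because NumPy arrays live in host memory.

values = t.cpu().detach().numpy()  # safe on every device
print(values)
```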
8 changes: 5 additions & 3 deletions tests/test_convolution.py
@@ -1,5 +1,6 @@
 import torch
 import matplotlib.pyplot as plt
+from continuity.data import device
 from continuity.data.datasets import Sine
 from continuity.operators import ContinuousConvolution
 from continuity.plotting import plot
@@ -20,7 +21,8 @@ def test_convolution():
     # Kernel
     def dirac(x, y):
         dist = ((x - y) ** 2).sum(dim=-1)
-        return torch.isclose(dist, torch.zeros(1)).to(torch.float32)
+        zero = torch.zeros(1, device=device)
+        return torch.isclose(dist, zero).to(torch.float32)
 
     # Operator
     operator = ContinuousConvolution(
@@ -30,7 +32,7 @@ def dirac(x, y):
     )
 
     # Create tensors
-    y = torch.linspace(-1, 1, num_evals).unsqueeze(-1)
+    y = torch.linspace(-1, 1, num_evals).unsqueeze(-1).to(device)
 
     # Apply operator
     v = operator(x, u, y)
@@ -41,7 +43,7 @@ def dirac(x, y):
     # Plotting
     fig, ax = plt.subplots(1, 1)
     plot(x, u, ax=ax)
-    plt.plot(x, v, "o")
+    plot(x, v, ax=ax)
     fig.savefig(f"test_convolution.png")
 
     # For num_sensors == num_evals, we get v = u / num_sensors.
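
All three test changes enforce one rule: every tensor in an op must live on the same device as the Sine data. A standalone sketch of the failure mode the dirac fix avoids (guarded so it also runs on CPU-only machines):

```python
import torch

# Same kind of device choice the fixed test imports from continuity.data.
device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")

dist = torch.zeros(3, device=device)
zero = torch.zeros(1, device=device)  # matches dist's device
print(torch.isclose(dist, zero))      # works on any backend

if device.type != "cpu":
    try:
        torch.isclose(dist, torch.zeros(1))  # CPU tensor vs. device tensor
    except RuntimeError as e:
        print(e)  # expected all tensors to be on the same device
```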
