Merge pull request #30 from aai-institute/feature/optuna
Feature: Optuna
samuelburbulla authored Feb 9, 2024
2 parents dd39745 + 6797726 commit ef2daa0
Showing 8 changed files with 201 additions and 3 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -135,6 +135,7 @@ dmypy.json

# Temporary
runs/
*.db

# Docs
docs_build
1 change: 1 addition & 0 deletions setup.cfg
@@ -74,6 +74,7 @@ install_requires =
dadaptation==3.1
matplotlib
pandas
optuna==3.5.0


[options.packages.find]
24 changes: 24 additions & 0 deletions src/continuity/benchmarks/__init__.py
@@ -0,0 +1,24 @@
"""
Benchmarks for operator learning.
"""

from abc import ABC, abstractmethod


class Benchmark(ABC):
"""Benchmark base class."""

@abstractmethod
def train_dataset(self):
"""Return training data set."""
raise NotImplementedError

@abstractmethod
def test_dataset(self):
"""Return test data set."""
raise NotImplementedError

@abstractmethod
def metric(self):
"""Return metric."""
raise NotImplementedError
39 changes: 39 additions & 0 deletions src/continuity/benchmarks/sine.py
@@ -0,0 +1,39 @@
"""Sine benchmark."""

from continuity.benchmarks import Benchmark
from continuity.data import DataSet, split
from continuity.data.datasets import Sine
from continuity.operators.losses import Loss, MSELoss


class SineBenchmark(Benchmark):
"""Sine benchmark."""

def __init__(self):
self.num_sensors = 32
self.size = 100
self.batch_size = 1

        self.dataset = Sine(
            num_sensors=self.num_sensors,
            size=self.size,
            batch_size=self.batch_size,
        )

self.train_dataset, self.test_dataset = split(self.dataset, 0.9)

def dataset(self) -> DataSet:
"""Return data set."""
return self.dataset

def train_dataset(self) -> DataSet:
"""Return training data set."""
return self.train_dataset

def test_dataset(self) -> DataSet:
"""Return test data set."""
return self.test_dataset

def metric(self) -> Loss:
"""Return metric."""
return MSELoss()
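
For illustration, a short usage sketch of the new benchmark (not part of this commit; the 90/10 split fraction comes from the constructor above):

from continuity.benchmarks.sine import SineBenchmark

benchmark = SineBenchmark()  # builds the Sine data set and a 90/10 train/test split
train_data, test_data = benchmark.train_dataset, benchmark.test_dataset
mse = benchmark.metric()  # MSELoss instance used to score operators on this benchmark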
29 changes: 29 additions & 0 deletions src/continuity/callbacks/__init__.py
@@ -111,3 +111,32 @@ def on_train_end(self):
plt.ylabel("Loss")
plt.legend(self.keys)
plt.show()


class OptunaCallback(Callback):
"""
Callback to report intermediate values to Optuna.
Args:
trial: Optuna trial.
"""

def __init__(self, trial):
self.trial = trial
super().__init__()

def __call__(self, epoch: int, logs: Dict[str, float]):
"""Callback function.
Called at the end of each epoch.
Args:
epoch: Current epoch.
logs: Dictionary of logs.
"""
self.trial.report(logs["loss/train"], step=epoch)

def on_train_begin(self):
"""Called at the beginning of training."""

def on_train_end(self):
"""Called at the end of training."""
39 changes: 39 additions & 0 deletions src/continuity/data/__init__.py
@@ -42,6 +42,45 @@ def tensor(x):
return torch.tensor(x, device=device, dtype=torch.float32)


def split(dataset, split=0.5, seed=None):
"""
Split data set into two parts.
Args:
split: Split fraction.
"""
assert 0 < split < 1, "Split fraction must be between 0 and 1."

generator = torch.Generator()
if seed is not None:
generator.manual_seed(seed)

return torch.utils.data.random_split(
dataset,
[split, 1 - split],
generator=generator,
)


def dataset_loss(dataset, operator, loss_fn):
"""Evaluate operator performance on data set.
Args:
dataset: Data set.
operator: Operator.
loss_fn: Loss function.
"""
loss = 0.0
n = len(dataset)

for i in range(n):
x, u, y, v = dataset[i]
batch_size = x.shape[0]
loss += loss_fn(operator, x, u, y, v) / batch_size

return loss


class DataSet:
"""Data set base class.
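
A minimal sketch of the two new helpers, combining them with the Sine data set and DeepONet used elsewhere in this commit (the 80/20 fraction, the seed, and the network sizes are arbitrary choices for illustration):

from continuity.data import dataset_loss, split
from continuity.data.datasets import Sine
from continuity.operators import DeepONet
from continuity.operators.losses import MSELoss

dataset = Sine(num_sensors=32, size=100, batch_size=1)
train_dataset, test_dataset = split(dataset, split=0.8, seed=0)  # reproducible 80/20 split

operator = DeepONet(
    dataset.num_sensors,
    dataset.coordinate_dim,
    dataset.num_channels,
    trunk_width=8,
    trunk_depth=8,
)
test_loss = dataset_loss(test_dataset, operator, MSELoss())  # loss of the untrained operator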
12 changes: 9 additions & 3 deletions src/continuity/operators/operator.py
@@ -32,7 +32,12 @@ def forward(self, x: Tensor, u: Tensor, y: Tensor) -> Tensor:
Tensor of evaluations of the mapped function of shape (batch_size, y_size, output_channels)
"""

def compile(self, optimizer: torch.optim.Optimizer, loss_fn: Optional[Loss] = None):
def compile(
self,
optimizer: torch.optim.Optimizer,
loss_fn: Optional[Loss] = None,
verbose: bool = True,
):
"""Compile operator.
Args:
@@ -46,8 +51,9 @@ def compile(self, optimizer: torch.optim.Optimizer, loss_fn: Optional[Loss] = No
self.to(device)

# Print number of model parameters
num_params = sum(p.numel() for p in self.parameters())
print(f"Model parameters: {num_params} Device: {device}")
if verbose:
num_params = sum(p.numel() for p in self.parameters())
print(f"Model parameters: {num_params} Device: {device}")

def fit(
self,
59 changes: 59 additions & 0 deletions tests/test_optuna.py
@@ -0,0 +1,59 @@
import torch
from continuity.benchmarks.sine import SineBenchmark
from continuity.callbacks import OptunaCallback
from continuity.data import split, dataset_loss
from continuity.operators import DeepONet
import optuna

# Set random seed
torch.manual_seed(0)


def test_optuna():
def objective(trial):
trunk_width = trial.suggest_int("trunk_width", 4, 16)
trunk_depth = trial.suggest_int("trunk_depth", 4, 16)
num_epochs = trial.suggest_int("num_epochs", 1, 10)
lr = trial.suggest_float("lr", 1e-4, 1e-3)

# Data set
benchmark = SineBenchmark()

# Train/val split
train_dataset, val_dataset = split(benchmark.train_dataset, 0.9)

# Operator
operator = DeepONet(
benchmark.dataset.num_sensors,
benchmark.dataset.coordinate_dim,
benchmark.dataset.num_channels,
trunk_width=trunk_width,
trunk_depth=trunk_depth,
)

# Optimizer
optimizer = torch.optim.Adam(operator.parameters(), lr=lr)

operator.compile(optimizer, verbose=False)
operator.fit(
train_dataset, epochs=num_epochs, callbacks=[OptunaCallback(trial)]
)

loss_val = dataset_loss(val_dataset, operator, benchmark.metric())
print(f"loss/val: {loss_val:.4e}")

return loss_val

# Run hyperparameter optimization
name = "test_optuna"
study = optuna.create_study(
direction="minimize",
study_name=name,
storage=f"sqlite:///{name}.db",
load_if_exists=True,
)
study.optimize(objective, n_trials=10)


if __name__ == "__main__":
test_optuna()
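
Because the study is persisted to SQLite (the reason for the new *.db entry in .gitignore), its results remain available after the run; assuming the test above has been executed in the current working directory, something like the following should recover the best trial:

import optuna

study = optuna.load_study(study_name="test_optuna", storage="sqlite:///test_optuna.db")
print(study.best_params, study.best_value)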
