From e284c6e572329b02552eab0b10d1ac1159134b5b Mon Sep 17 00:00:00 2001 From: Samuel Burbulla Date: Fri, 9 Feb 2024 10:20:58 +0100 Subject: [PATCH 1/8] Add benchmark abstract base class. --- src/continuity/benchmarks/__init__.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) create mode 100644 src/continuity/benchmarks/__init__.py diff --git a/src/continuity/benchmarks/__init__.py b/src/continuity/benchmarks/__init__.py new file mode 100644 index 00000000..59b27014 --- /dev/null +++ b/src/continuity/benchmarks/__init__.py @@ -0,0 +1,24 @@ +""" +Benchmarks for operator learning. +""" + +from abc import ABC, abstractmethod + + +class Benchmark(ABC): + """Benchmark base class.""" + + @abstractmethod + def train_dataset(self): + """Return training data set.""" + raise NotImplementedError + + @abstractmethod + def test_dataset(self): + """Return test data set.""" + raise NotImplementedError + + @abstractmethod + def metric(self): + """Return metric.""" + raise NotImplementedError From 58a77caaf5f93a266a836ecbe2d0ecf8b6f8a0dc Mon Sep 17 00:00:00 2001 From: Samuel Burbulla Date: Fri, 9 Feb 2024 10:21:11 +0100 Subject: [PATCH 2/8] Add split method. --- src/continuity/data/__init__.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/src/continuity/data/__init__.py b/src/continuity/data/__init__.py index 544f0900..db0e039e 100644 --- a/src/continuity/data/__init__.py +++ b/src/continuity/data/__init__.py @@ -42,6 +42,26 @@ def tensor(x): return torch.tensor(x, device=device, dtype=torch.float32) +def split(dataset, split=0.5, seed=None): + """ + Split data set into two parts. + + Args: + split: Split fraction. + """ + assert 0 < split < 1, "Split fraction must be between 0 and 1." + + generator = torch.Generator() + if seed is not None: + generator.manual_seed(seed) + + return torch.utils.data.random_split( + dataset, + [split, 1 - split], + generator=generator, + ) + + class DataSet: """Data set base class. From afc393bb22b61c127b9c8434830d776277ca810b Mon Sep 17 00:00:00 2001 From: Samuel Burbulla Date: Fri, 9 Feb 2024 10:21:20 +0100 Subject: [PATCH 3/8] Add sine benchmark. --- src/continuity/benchmarks/sine.py | 39 +++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) create mode 100644 src/continuity/benchmarks/sine.py diff --git a/src/continuity/benchmarks/sine.py b/src/continuity/benchmarks/sine.py new file mode 100644 index 00000000..4210ac12 --- /dev/null +++ b/src/continuity/benchmarks/sine.py @@ -0,0 +1,39 @@ +"""Sine benchmark.""" + +from continuity.benchmarks import Benchmark +from continuity.data import DataSet, split +from continuity.data.datasets import Sine +from continuity.operators.losses import Loss, MSELoss + + +class SineBenchmark(Benchmark): + """Sine benchmark.""" + + def __init__(self): + self.num_sensors = 32 + self.size = 100 + self.batch_size = 1 + + self.dataset = Sine( + num_sensors=32, + size=100, + batch_size=1, + ) + + self.train_dataset, self.test_dataset = split(self.dataset, 0.9) + + def dataset(self) -> DataSet: + """Return data set.""" + return self.dataset + + def train_dataset(self) -> DataSet: + """Return training data set.""" + return self.train_dataset + + def test_dataset(self) -> DataSet: + """Return test data set.""" + return self.test_dataset + + def metric(self) -> Loss: + """Return metric.""" + return MSELoss() From 2bdfb0d06d1c13134ce8b03197ee751dc3906d51 Mon Sep 17 00:00:00 2001 From: Samuel Burbulla Date: Fri, 9 Feb 2024 10:47:16 +0100 Subject: [PATCH 4/8] Add data set loss. --- src/continuity/data/__init__.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/src/continuity/data/__init__.py b/src/continuity/data/__init__.py index db0e039e..5ab4a146 100644 --- a/src/continuity/data/__init__.py +++ b/src/continuity/data/__init__.py @@ -62,6 +62,25 @@ def split(dataset, split=0.5, seed=None): ) +def dataset_loss(dataset, operator, loss_fn): + """Evaluate operator performance on data set. + + Args: + dataset: Data set. + operator: Operator. + loss_fn: Loss function. + """ + loss = 0.0 + n = len(dataset) + + for i in range(n): + x, u, y, v = dataset[i] + batch_size = x.shape[0] + loss += loss_fn(operator, x, u, y, v) / batch_size + + return loss + + class DataSet: """Data set base class. From c558cc4d2887482067e3425bf6585738cb0f5f8c Mon Sep 17 00:00:00 2001 From: Samuel Burbulla Date: Fri, 9 Feb 2024 10:47:30 +0100 Subject: [PATCH 5/8] Add verbose flag. --- src/continuity/operators/operator.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/continuity/operators/operator.py b/src/continuity/operators/operator.py index 2f0b77f8..4b965459 100644 --- a/src/continuity/operators/operator.py +++ b/src/continuity/operators/operator.py @@ -32,7 +32,12 @@ def forward(self, x: Tensor, u: Tensor, y: Tensor) -> Tensor: Tensor of evaluations of the mapped function of shape (batch_size, y_size, output_channels) """ - def compile(self, optimizer: torch.optim.Optimizer, loss_fn: Optional[Loss] = None): + def compile( + self, + optimizer: torch.optim.Optimizer, + loss_fn: Optional[Loss] = None, + verbose: bool = True, + ): """Compile operator. Args: @@ -46,8 +51,9 @@ def compile(self, optimizer: torch.optim.Optimizer, loss_fn: Optional[Loss] = No self.to(device) # Print number of model parameters - num_params = sum(p.numel() for p in self.parameters()) - print(f"Model parameters: {num_params} Device: {device}") + if verbose: + num_params = sum(p.numel() for p in self.parameters()) + print(f"Model parameters: {num_params} Device: {device}") def fit( self, From 42fb716e9f143704272b5b748c31a8ba5159448a Mon Sep 17 00:00:00 2001 From: Samuel Burbulla Date: Fri, 9 Feb 2024 10:47:51 +0100 Subject: [PATCH 6/8] Implement Optuna callback. --- src/continuity/callbacks/__init__.py | 29 ++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/src/continuity/callbacks/__init__.py b/src/continuity/callbacks/__init__.py index 2a2bb974..24e26962 100644 --- a/src/continuity/callbacks/__init__.py +++ b/src/continuity/callbacks/__init__.py @@ -111,3 +111,32 @@ def on_train_end(self): plt.ylabel("Loss") plt.legend(self.keys) plt.show() + + +class OptunaCallback(Callback): + """ + Callback to report intermediate values to Optuna. + + Args: + trial: Optuna trial. + """ + + def __init__(self, trial): + self.trial = trial + super().__init__() + + def __call__(self, epoch: int, logs: Dict[str, float]): + """Callback function. + Called at the end of each epoch. + + Args: + epoch: Current epoch. + logs: Dictionary of logs. + """ + self.trial.report(logs["loss/train"], step=epoch) + + def on_train_begin(self): + """Called at the beginning of training.""" + + def on_train_end(self): + """Called at the end of training.""" From 22aaec0abe7e1afe696caa0b464a93cdd0912f2c Mon Sep 17 00:00:00 2001 From: Samuel Burbulla Date: Fri, 9 Feb 2024 10:48:16 +0100 Subject: [PATCH 7/8] Add optuna to dependencies. --- setup.cfg | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.cfg b/setup.cfg index 12eda705..51ef5a16 100644 --- a/setup.cfg +++ b/setup.cfg @@ -74,6 +74,7 @@ install_requires = dadaptation==3.1 matplotlib pandas + optuna==3.5.0 [options.packages.find] From 67977268bfdade15cc693b03ed9f145d3dc68da4 Mon Sep 17 00:00:00 2001 From: Samuel Burbulla Date: Fri, 9 Feb 2024 10:48:35 +0100 Subject: [PATCH 8/8] Add test_optuna. --- .gitignore | 1 + tests/test_optuna.py | 59 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 60 insertions(+) create mode 100644 tests/test_optuna.py diff --git a/.gitignore b/.gitignore index 2f59bb83..4fbc0481 100644 --- a/.gitignore +++ b/.gitignore @@ -135,6 +135,7 @@ dmypy.json # Temporary runs/ +*.db # Docs docs_build diff --git a/tests/test_optuna.py b/tests/test_optuna.py new file mode 100644 index 00000000..a67917f8 --- /dev/null +++ b/tests/test_optuna.py @@ -0,0 +1,59 @@ +import torch +from continuity.benchmarks.sine import SineBenchmark +from continuity.callbacks import OptunaCallback +from continuity.data import split, dataset_loss +from continuity.operators import DeepONet +import optuna + +# Set random seed +torch.manual_seed(0) + + +def test_optuna(): + def objective(trial): + trunk_width = trial.suggest_int("trunk_width", 4, 16) + trunk_depth = trial.suggest_int("trunk_depth", 4, 16) + num_epochs = trial.suggest_int("num_epochs", 1, 10) + lr = trial.suggest_float("lr", 1e-4, 1e-3) + + # Data set + benchmark = SineBenchmark() + + # Train/val split + train_dataset, val_dataset = split(benchmark.train_dataset, 0.9) + + # Operator + operator = DeepONet( + benchmark.dataset.num_sensors, + benchmark.dataset.coordinate_dim, + benchmark.dataset.num_channels, + trunk_width=trunk_width, + trunk_depth=trunk_depth, + ) + + # Optimizer + optimizer = torch.optim.Adam(operator.parameters(), lr=lr) + + operator.compile(optimizer, verbose=False) + operator.fit( + train_dataset, epochs=num_epochs, callbacks=[OptunaCallback(trial)] + ) + + loss_val = dataset_loss(val_dataset, operator, benchmark.metric()) + print(f"loss/val: {loss_val:.4e}") + + return loss_val + + # Run hyperparameter optimization + name = "test_optuna" + study = optuna.create_study( + direction="minimize", + study_name=name, + storage=f"sqlite:///{name}.db", + load_if_exists=True, + ) + study.optimize(objective, n_trials=10) + + +if __name__ == "__main__": + test_optuna()