From 03728288287882d2e0c7128484c882e74a0d65a8 Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Wed, 3 Jan 2024 13:00:04 +0100 Subject: [PATCH] Make train function self-contained --- src/metatensor/models/cli/train_model.py | 7 +----- src/metatensor/models/soap_bpnn/__init__.py | 4 +-- src/metatensor/models/soap_bpnn/model.py | 4 +-- .../soap_bpnn/tests/test_functionality.py | 4 +-- .../models/soap_bpnn/tests/test_invariance.py | 4 +-- .../models/soap_bpnn/tests/test_regression.py | 16 ++++-------- .../soap_bpnn/tests/test_torchscript.py | 4 +-- src/metatensor/models/soap_bpnn/train.py | 25 ++++++++++++------- 8 files changed, 32 insertions(+), 36 deletions(-) diff --git a/src/metatensor/models/cli/train_model.py b/src/metatensor/models/cli/train_model.py index 810e5f330..9688b2e41 100644 --- a/src/metatensor/models/cli/train_model.py +++ b/src/metatensor/models/cli/train_model.py @@ -102,18 +102,13 @@ def train_model(config: DictConfig) -> None: logger.info("Setting up model") architetcure_name = config["architecture"]["name"] architecture = importlib.import_module(f"metatensor.models.{architetcure_name}") - model = architecture.Model( - all_species=dataset.all_species, - hypers=OmegaConf.to_container(config["architecture"]["model"]), - ) logger.info("Run training") output_dir = hydra.core.hydra_config.HydraConfig.get().runtime.output_dir model = architecture.train( - model=model, train_dataset=dataset, - hypers=OmegaConf.to_container(config["architecture"]["training"]), + hypers=OmegaConf.to_container(config["architecture"]), output_dir=output_dir, ) diff --git a/src/metatensor/models/soap_bpnn/__init__.py b/src/metatensor/models/soap_bpnn/__init__.py index 7b066fa9c..ff9a77daf 100644 --- a/src/metatensor/models/soap_bpnn/__init__.py +++ b/src/metatensor/models/soap_bpnn/__init__.py @@ -1,2 +1,2 @@ -from .model import Model, DEFAULT_MODEL_HYPERS # noqa: F401 -from .train import train, DEFAULT_TRAINING_HYPERS # noqa: F401 +from .model import Model, DEFAULT_HYPERS # noqa: F401 +from .train import train # noqa: F401 diff --git a/src/metatensor/models/soap_bpnn/model.py b/src/metatensor/models/soap_bpnn/model.py index a3bdf1af3..b257e1d4e 100644 --- a/src/metatensor/models/soap_bpnn/model.py +++ b/src/metatensor/models/soap_bpnn/model.py @@ -10,11 +10,11 @@ from ..utils.composition import apply_composition_contribution -DEAFAULT_HYPERS = OmegaConf.to_container( +DEFAULT_HYPERS = OmegaConf.to_container( OmegaConf.load(ARCHITECTURE_CONFIG_PATH / "soap_bpnn.yaml") ) -DEFAULT_MODEL_HYPERS = DEAFAULT_HYPERS["model"] +DEFAULT_MODEL_HYPERS = DEFAULT_HYPERS["model"] ARCHITECTURE_NAME = "soap_bpnn" diff --git a/src/metatensor/models/soap_bpnn/tests/test_functionality.py b/src/metatensor/models/soap_bpnn/tests/test_functionality.py index c6c6bcc0f..b94de2f46 100644 --- a/src/metatensor/models/soap_bpnn/tests/test_functionality.py +++ b/src/metatensor/models/soap_bpnn/tests/test_functionality.py @@ -2,7 +2,7 @@ import rascaline.torch import torch -from metatensor.models.soap_bpnn import DEFAULT_MODEL_HYPERS, Model +from metatensor.models.soap_bpnn import DEFAULT_HYPERS, Model def test_prediction_subset(): @@ -10,7 +10,7 @@ def test_prediction_subset(): of the elements it was trained on.""" all_species = [1, 6, 7, 8] - soap_bpnn = Model(all_species, DEFAULT_MODEL_HYPERS).to(torch.float64) + soap_bpnn = Model(all_species, DEFAULT_HYPERS["model"]).to(torch.float64) structure = ase.Atoms("O2", positions=[[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]) soap_bpnn([rascaline.torch.systems_to_torch(structure)]) diff --git a/src/metatensor/models/soap_bpnn/tests/test_invariance.py b/src/metatensor/models/soap_bpnn/tests/test_invariance.py index 87358a9ab..8b91b276e 100644 --- a/src/metatensor/models/soap_bpnn/tests/test_invariance.py +++ b/src/metatensor/models/soap_bpnn/tests/test_invariance.py @@ -4,7 +4,7 @@ import rascaline.torch import torch -from metatensor.models.soap_bpnn import DEFAULT_MODEL_HYPERS, Model +from metatensor.models.soap_bpnn import DEFAULT_HYPERS, Model from . import DATASET_PATH @@ -13,7 +13,7 @@ def test_rotational_invariance(): """Tests that the model is rotationally invariant.""" all_species = [1, 6, 7, 8] - soap_bpnn = Model(all_species, DEFAULT_MODEL_HYPERS).to(torch.float64) + soap_bpnn = Model(all_species, DEFAULT_HYPERS["model"]).to(torch.float64) structure = ase.io.read(DATASET_PATH) original_structure = copy.deepcopy(structure) diff --git a/src/metatensor/models/soap_bpnn/tests/test_regression.py b/src/metatensor/models/soap_bpnn/tests/test_regression.py index 0bf946920..56ee49e04 100644 --- a/src/metatensor/models/soap_bpnn/tests/test_regression.py +++ b/src/metatensor/models/soap_bpnn/tests/test_regression.py @@ -2,12 +2,7 @@ import rascaline.torch import torch -from metatensor.models.soap_bpnn import ( - DEFAULT_MODEL_HYPERS, - DEFAULT_TRAINING_HYPERS, - Model, - train, -) +from metatensor.models.soap_bpnn import DEFAULT_HYPERS, Model, train from metatensor.models.utils.data import Dataset from metatensor.models.utils.data.readers import read_structures, read_targets @@ -21,7 +16,7 @@ def test_regression_init(): """Perform a regression test on the model at initialization""" all_species = [1, 6, 7, 8] - soap_bpnn = Model(all_species, DEFAULT_MODEL_HYPERS).to(torch.float64) + soap_bpnn = Model(all_species, DEFAULT_HYPERS["model"]).to(torch.float64) # Predict on the first fivestructures structures = ase.io.read(DATASET_PATH, ":5") @@ -47,11 +42,10 @@ def test_regression_train(): targets = read_targets(DATASET_PATH, "U0") dataset = Dataset(structures, targets) - soap_bpnn = Model(dataset.all_species, DEFAULT_MODEL_HYPERS).to(torch.float64) - hypers = DEFAULT_TRAINING_HYPERS.copy() - hypers["num_epochs"] = 2 - train(soap_bpnn, dataset, hypers) + hypers = DEFAULT_HYPERS.copy() + hypers["training"]["num_epochs"] = 2 + soap_bpnn = train(dataset, hypers) # Predict on the first five structures output = soap_bpnn(structures[:5]) diff --git a/src/metatensor/models/soap_bpnn/tests/test_torchscript.py b/src/metatensor/models/soap_bpnn/tests/test_torchscript.py index 22efa8e81..4e453c77b 100644 --- a/src/metatensor/models/soap_bpnn/tests/test_torchscript.py +++ b/src/metatensor/models/soap_bpnn/tests/test_torchscript.py @@ -1,11 +1,11 @@ import torch -from metatensor.models.soap_bpnn import DEFAULT_MODEL_HYPERS, Model +from metatensor.models.soap_bpnn import DEFAULT_HYPERS, Model def test_torchscript(): """Tests that the model can be jitted.""" all_species = [1, 6, 7, 8] - soap_bpnn = Model(all_species, DEFAULT_MODEL_HYPERS).to(torch.float64) + soap_bpnn = Model(all_species, DEFAULT_HYPERS["model"]).to(torch.float64) torch.jit.script(soap_bpnn) diff --git a/src/metatensor/models/soap_bpnn/train.py b/src/metatensor/models/soap_bpnn/train.py index 3758ba4b6..b646ac194 100644 --- a/src/metatensor/models/soap_bpnn/train.py +++ b/src/metatensor/models/soap_bpnn/train.py @@ -6,11 +6,9 @@ from ..utils.composition import calculate_composition_weights from ..utils.data import collate_fn from ..utils.model_io import save_model -from .model import DEAFAULT_HYPERS +from .model import DEFAULT_HYPERS, Model -DEFAULT_TRAINING_HYPERS = DEAFAULT_HYPERS["training"] - logger = logging.getLogger(__name__) @@ -18,9 +16,14 @@ def loss_function(predicted, target): return torch.sum((predicted.block().values - target.block().values) ** 2) -def train(model, train_dataset, hypers=DEFAULT_TRAINING_HYPERS, output_dir="."): +def train(train_dataset, hypers=DEFAULT_HYPERS, output_dir="."): # Calculate and set the composition weights: + model = Model( + all_species=train_dataset.all_species, + hypers=hypers["model"], + ) + if len(train_dataset.targets) > 1: raise ValueError( f"`train_dataset` contains {len(train_dataset.targets)} targets but we " @@ -32,22 +35,26 @@ def train(model, train_dataset, hypers=DEFAULT_TRAINING_HYPERS, output_dir="."): composition_weights = calculate_composition_weights(train_dataset, target) model.set_composition_weights(composition_weights) + hypers_training = hypers["training"] + # Create a dataloader for the training dataset: train_dataloader = torch.utils.data.DataLoader( dataset=train_dataset, - batch_size=hypers["batch_size"], + batch_size=hypers_training["batch_size"], shuffle=True, collate_fn=collate_fn, ) # Create an optimizer: - optimizer = torch.optim.Adam(model.parameters(), lr=hypers["learning_rate"]) + optimizer = torch.optim.Adam( + model.parameters(), lr=hypers_training["learning_rate"] + ) # Train the model: - for epoch in range(hypers["num_epochs"]): - if epoch % hypers["log_interval"] == 0: + for epoch in range(hypers_training["num_epochs"]): + if epoch % hypers_training["log_interval"] == 0: logger.info(f"Epoch {epoch}") - if epoch % hypers["checkpoint_interval"] == 0: + if epoch % hypers_training["checkpoint_interval"] == 0: save_model( model, Path(output_dir) / f"model_{epoch}.pt",