From 64bc443e3e05c82b4800a3c9e7a3ff8f2e0bc547 Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Fri, 8 Dec 2023 12:11:00 +0100 Subject: [PATCH] Implement loader, tests, documentation --- docs/src/dev-docs/utils/index.rst | 5 +-- docs/src/dev-docs/utils/model-io.rst | 7 ++++ docs/src/models/soap-bpnn.rst | 2 +- src/metatensor/models/soap_bpnn/model.py | 11 ++++++- .../models/soap_bpnn/tests/test_regression.py | 6 ++-- src/metatensor/models/soap_bpnn/train.py | 15 ++++++--- src/metatensor/models/utils/model_io.py | 19 +++++++++-- tests/model_io/test_model_io.py | 33 +++++++++++++++++-- 8 files changed, 83 insertions(+), 15 deletions(-) create mode 100644 docs/src/dev-docs/utils/model-io.rst diff --git a/docs/src/dev-docs/utils/index.rst b/docs/src/dev-docs/utils/index.rst index e01d3bbc4..6e6d48370 100644 --- a/docs/src/dev-docs/utils/index.rst +++ b/docs/src/dev-docs/utils/index.rst @@ -1,5 +1,5 @@ -Utilitliy API -============= +Utility API +=========== This is the API for the ``utils`` module of ``metatensor-models``. @@ -8,3 +8,4 @@ This is the API for the ``utils`` module of ``metatensor-models``. dataset readers/index + model-io diff --git a/docs/src/dev-docs/utils/model-io.rst b/docs/src/dev-docs/utils/model-io.rst new file mode 100644 index 000000000..14b6e6d3d --- /dev/null +++ b/docs/src/dev-docs/utils/model-io.rst @@ -0,0 +1,7 @@ +Model IO +######## + +.. automodule:: metatensor.models.utils.model_io + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/src/models/soap-bpnn.rst b/docs/src/models/soap-bpnn.rst index 1438db587..06a431cfe 100644 --- a/docs/src/models/soap-bpnn.rst +++ b/docs/src/models/soap-bpnn.rst @@ -22,7 +22,7 @@ Hyperparameters The hyperparameters (and relative default values) for the SOAP-BPNN model are: -.. literalinclude:: ../../../src/metatensor/models/soap_bpnn/default.yml +.. 
literalinclude:: ../../../src/metatensor/models/cli/conf/architecture/soap_bpnn.yaml :language: yaml Any of these hyperparameters can be overridden in the training configuration file. diff --git a/src/metatensor/models/soap_bpnn/model.py b/src/metatensor/models/soap_bpnn/model.py index 4416fcd87..8c781e364 100644 --- a/src/metatensor/models/soap_bpnn/model.py +++ b/src/metatensor/models/soap_bpnn/model.py @@ -4,12 +4,19 @@ import rascaline.torch import torch from metatensor.torch import Labels, TensorBlock, TensorMap +from omegaconf import OmegaConf + +from metatensor.models import ARCHITECTURE_CONFIG_PATH from ..utils.composition import apply_composition_contribution ARCHITECTURE_NAME = "soap_bpnn" +DEFAULT_MODEL_HYPERS = OmegaConf.to_container( + OmegaConf.load(ARCHITECTURE_CONFIG_PATH / "soap_bpnn.yaml") +)["model"] + class MLPMap(torch.nn.Module): def __init__(self, all_species: List[int], hypers: dict) -> None: @@ -84,7 +91,9 @@ def forward(self, features: TensorMap) -> TensorMap: class Model(torch.nn.Module): - def __init__(self, all_species: List[int], hypers: Dict) -> None: + def __init__( + self, all_species: List[int], hypers: Dict = DEFAULT_MODEL_HYPERS + ) -> None: super().__init__() self.all_species = all_species diff --git a/src/metatensor/models/soap_bpnn/tests/test_regression.py b/src/metatensor/models/soap_bpnn/tests/test_regression.py index 7bdc31234..dca83bfca 100644 --- a/src/metatensor/models/soap_bpnn/tests/test_regression.py +++ b/src/metatensor/models/soap_bpnn/tests/test_regression.py @@ -44,9 +44,9 @@ def test_regression_train(): dataset = Dataset(structures, targets) soap_bpnn = Model(dataset.all_species, DEAFAULT_HYPERS["model"]).to(torch.float64) - hypers_training = DEAFAULT_HYPERS["training"].copy() - hypers_training["num_epochs"] = 2 - train(soap_bpnn, dataset, hypers_training) + hypers = DEAFAULT_HYPERS.copy() + hypers["training"] = {**hypers["training"], "num_epochs": 2} + train(soap_bpnn, dataset, hypers) # Predict on the first five structures output = 
soap_bpnn(structures[:5]) diff --git a/src/metatensor/models/soap_bpnn/train.py b/src/metatensor/models/soap_bpnn/train.py index bc120a137..1c944bcfe 100644 --- a/src/metatensor/models/soap_bpnn/train.py +++ b/src/metatensor/models/soap_bpnn/train.py @@ -1,6 +1,9 @@ import logging import torch +from omegaconf import OmegaConf + +from metatensor.models import ARCHITECTURE_CONFIG_PATH from ..utils.composition import calculate_composition_weights from ..utils.data import collate_fn @@ -8,14 +11,18 @@ from .model import ARCHITECTURE_NAME -def loss_function(predicted, target): - return torch.sum((predicted.block().values - target.block().values) ** 2) - +DEFAULT_HYPERS = OmegaConf.to_container( + OmegaConf.load(ARCHITECTURE_CONFIG_PATH / "soap_bpnn.yaml") +) # full config: train() reads both hypers["model"] and hypers["training"] logger = logging.getLogger(__name__) -def train(model, train_dataset, hypers): +def loss_function(predicted, target): + return torch.sum((predicted.block().values - target.block().values) ** 2) + + +def train(model, train_dataset, hypers=DEFAULT_HYPERS): model_hypers = hypers["model"] training_hypers = hypers["training"] diff --git a/src/metatensor/models/utils/model_io.py b/src/metatensor/models/utils/model_io.py index 20358c7ef..1ef6c1c1a 100644 --- a/src/metatensor/models/utils/model_io.py +++ b/src/metatensor/models/utils/model_io.py @@ -1,3 +1,4 @@ +import importlib from typing import Dict, List import torch @@ -44,5 +45,19 @@ def load_model(path: str) -> torch.nn.Module: ------- torch.nn.Module: The loaded model. """ - # TODO, possibly with hydra utilities? 
- pass + + # Load the model and the metadata (torch.load unpickles: trusted files only) + model_dict = torch.load(path) + + # Get the architecture + architecture = importlib.import_module(f"metatensor.models.{model_dict['name']}") + + # Create the model + model = architecture.Model( + all_species=model_dict["all_species"], hypers=model_dict["hypers"] + ) + + # Load the model weights + model.load_state_dict(model_dict["model"]) + + return model diff --git a/tests/model_io/test_model_io.py b/tests/model_io/test_model_io.py index 9032a90b5..9748f53f2 100644 --- a/tests/model_io/test_model_io.py +++ b/tests/model_io/test_model_io.py @@ -1,3 +1,32 @@ +from pathlib import Path + +import metatensor.torch +import rascaline.torch + +from metatensor.models import soap_bpnn +from metatensor.models.soap_bpnn.model import DEFAULT_MODEL_HYPERS +from metatensor.models.utils.data import read_structures +from metatensor.models.utils.model_io import load_model, save_model + + +RESOURCES_PATH = Path(__file__).parent.resolve() / ".." / "resources" + + def test_save_load_model(): - """Test that saving and loading a model works.""" - pass + """Test that saving and loading a model works and preserves its internal state.""" + + model = soap_bpnn.Model(all_species=[1, 6, 7, 8]) + structures = read_structures(RESOURCES_PATH / "qm9_reduced_100.xyz") + + output_before_save = model(rascaline.torch.systems_to_torch(structures)) + + save_model( + "soap_bpnn", model, DEFAULT_MODEL_HYPERS, model.all_species, "test_model.pt" + ) + loaded_model = load_model("test_model.pt") + Path("test_model.pt").unlink() # keep the working directory clean + output_after_load = loaded_model(rascaline.torch.systems_to_torch(structures)) + + assert metatensor.torch.allclose( + output_before_save["energy"], output_after_load["energy"] + )