Skip to content

Commit

Permalink
Implement saving and loading of models (#19)
Browse files Browse the repository at this point in the history
Co-authored-by: Philip Loche <[email protected]>
  • Loading branch information
frostedoyster and PicoCentauri authored Dec 11, 2023
1 parent 09c58cd commit 44e4f11
Show file tree
Hide file tree
Showing 14 changed files with 146 additions and 67 deletions.
5 changes: 3 additions & 2 deletions docs/src/dev-docs/utils/index.rst
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
Utilitliy API
=============
Utility API
===========

This is the API for the ``utils`` module of ``metatensor-models``.

Expand All @@ -8,3 +8,4 @@ This is the API for the ``utils`` module of ``metatensor-models``.

dataset
readers/index
model-io
7 changes: 7 additions & 0 deletions docs/src/dev-docs/utils/model-io.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Model IO
########

.. automodule:: metatensor.models.utils.model_io
:members:
:undoc-members:
:show-inheritance:
2 changes: 1 addition & 1 deletion docs/src/models/soap-bpnn.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ Hyperparameters

The hyperparameters (and relative default values) for the SOAP-BPNN model are:

.. literalinclude:: ../../../src/metatensor/models/soap_bpnn/default.yml
.. literalinclude:: ../../../src/metatensor/models/cli/conf/architecture/soap_bpnn.yaml
:language: yaml

Any of these hyperparameters can be overridden in the training configuration file.
Expand Down
4 changes: 2 additions & 2 deletions src/metatensor/models/soap_bpnn/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .model import Model # noqa: F401
from .train import train # noqa: F401
from .model import Model, DEFAULT_MODEL_HYPERS # noqa: F401
from .train import train, DEFAULT_TRAINING_HYPERS # noqa: F401
30 changes: 0 additions & 30 deletions src/metatensor/models/soap_bpnn/default.yml

This file was deleted.

17 changes: 16 additions & 1 deletion src/metatensor/models/soap_bpnn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,21 @@
import rascaline.torch
import torch
from metatensor.torch import Labels, TensorBlock, TensorMap
from omegaconf import OmegaConf

from .. import ARCHITECTURE_CONFIG_PATH
from ..utils.composition import apply_composition_contribution


DEAFAULT_HYPERS = OmegaConf.to_container(
OmegaConf.load(ARCHITECTURE_CONFIG_PATH / "soap_bpnn.yaml")
)

DEFAULT_MODEL_HYPERS = DEAFAULT_HYPERS["model"]

ARCHITECTURE_NAME = "soap_bpnn"


class MLPMap(torch.nn.Module):
def __init__(self, all_species: List[int], hypers: dict) -> None:
super().__init__()
Expand Down Expand Up @@ -81,9 +92,13 @@ def forward(self, features: TensorMap) -> TensorMap:


class Model(torch.nn.Module):
def __init__(self, all_species: List[int], hypers: Dict) -> None:
def __init__(
self, all_species: List[int], hypers: Dict = DEFAULT_MODEL_HYPERS
) -> None:
super().__init__()
self.name = ARCHITECTURE_NAME
self.all_species = all_species
self.hypers = hypers

# creates a composition weight tensor that can be directly indexed by species,
# this can be left as a tensor of zero or set from the outside using
Expand Down
7 changes: 0 additions & 7 deletions src/metatensor/models/soap_bpnn/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,5 @@
from pathlib import Path

from metatensor.models import ARCHITECTURE_CONFIG_PATH
from omegaconf import OmegaConf


DEAFAULT_HYPERS = OmegaConf.to_container(
OmegaConf.load(ARCHITECTURE_CONFIG_PATH / "soap_bpnn.yaml")
)
DATASET_PATH = str(
Path(__file__).parent.resolve()
/ "../../../../../tests/resources/qm9_reduced_100.xyz"
Expand Down
6 changes: 2 additions & 4 deletions src/metatensor/models/soap_bpnn/tests/test_functionality.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,15 @@
import rascaline.torch
import torch

from metatensor.models.soap_bpnn import Model

from . import DEAFAULT_HYPERS
from metatensor.models.soap_bpnn import DEFAULT_MODEL_HYPERS, Model


def test_prediction_subset():
"""Tests that the model can predict on a subset
of the elements it was trained on."""

all_species = [1, 6, 7, 8]
soap_bpnn = Model(all_species, DEAFAULT_HYPERS["model"]).to(torch.float64)
soap_bpnn = Model(all_species, DEFAULT_MODEL_HYPERS).to(torch.float64)

structure = ase.Atoms("O2", positions=[[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]])
soap_bpnn([rascaline.torch.systems_to_torch(structure)])
6 changes: 3 additions & 3 deletions src/metatensor/models/soap_bpnn/tests/test_invariance.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,16 @@
import rascaline.torch
import torch

from metatensor.models.soap_bpnn import Model
from metatensor.models.soap_bpnn import DEFAULT_MODEL_HYPERS, Model

from . import DATASET_PATH, DEAFAULT_HYPERS
from . import DATASET_PATH


def test_rotational_invariance():
"""Tests that the model is rotationally invariant."""

all_species = [1, 6, 7, 8]
soap_bpnn = Model(all_species, DEAFAULT_HYPERS["model"]).to(torch.float64)
soap_bpnn = Model(all_species, DEFAULT_MODEL_HYPERS).to(torch.float64)

structure = ase.io.read(DATASET_PATH)
original_structure = copy.deepcopy(structure)
Expand Down
19 changes: 12 additions & 7 deletions src/metatensor/models/soap_bpnn/tests/test_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,16 @@
import rascaline.torch
import torch

from metatensor.models.soap_bpnn import Model, train
from metatensor.models.soap_bpnn import (
DEFAULT_MODEL_HYPERS,
DEFAULT_TRAINING_HYPERS,
Model,
train,
)
from metatensor.models.utils.data import Dataset
from metatensor.models.utils.data.readers import read_structures, read_targets

from . import DATASET_PATH, DEAFAULT_HYPERS
from . import DATASET_PATH


torch.manual_seed(0)
Expand All @@ -16,7 +21,7 @@ def test_regression_init():
"""Perform a regression test on the model at initialization"""

all_species = [1, 6, 7, 8]
soap_bpnn = Model(all_species, DEAFAULT_HYPERS["model"]).to(torch.float64)
soap_bpnn = Model(all_species, DEFAULT_MODEL_HYPERS).to(torch.float64)

# Predict on the first five structures
structures = ase.io.read(DATASET_PATH, ":5")
Expand All @@ -42,11 +47,11 @@ def test_regression_train():
targets = read_targets(DATASET_PATH, "U0")

dataset = Dataset(structures, targets)
soap_bpnn = Model(dataset.all_species, DEAFAULT_HYPERS["model"]).to(torch.float64)
soap_bpnn = Model(dataset.all_species, DEFAULT_MODEL_HYPERS).to(torch.float64)

hypers_training = DEAFAULT_HYPERS["training"].copy()
hypers_training["num_epochs"] = 2
train(soap_bpnn, dataset, hypers_training)
hypers = DEFAULT_TRAINING_HYPERS.copy()
hypers["num_epochs"] = 2
train(soap_bpnn, dataset, hypers)

# Predict on the first five structures
output = soap_bpnn(structures[:5])
Expand Down
6 changes: 2 additions & 4 deletions src/metatensor/models/soap_bpnn/tests/test_torchscript.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
import torch

from metatensor.models.soap_bpnn import Model

from . import DEAFAULT_HYPERS
from metatensor.models.soap_bpnn import DEFAULT_MODEL_HYPERS, Model


def test_torchscript():
"""Tests that the model can be jitted."""

all_species = [1, 6, 7, 8]
soap_bpnn = Model(all_species, DEAFAULT_HYPERS["model"]).to(torch.float64)
soap_bpnn = Model(all_species, DEFAULT_MODEL_HYPERS).to(torch.float64)
torch.jit.script(soap_bpnn)
19 changes: 13 additions & 6 deletions src/metatensor/models/soap_bpnn/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,20 @@

from ..utils.composition import calculate_composition_weights
from ..utils.data import collate_fn
from ..utils.model_io import save_model
from .model import DEAFAULT_HYPERS


def loss_function(predicted, target):
return torch.sum((predicted.block().values - target.block().values) ** 2)

DEFAULT_TRAINING_HYPERS = DEAFAULT_HYPERS["training"]

logger = logging.getLogger(__name__)


def train(model, train_dataset, hypers):
def loss_function(predicted, target):
return torch.sum((predicted.block().values - target.block().values) ** 2)


def train(model, train_dataset, hypers=DEFAULT_TRAINING_HYPERS):
# Calculate and set the composition weights:
composition_weights = calculate_composition_weights(train_dataset, "U0")
model.set_composition_weights(composition_weights)
Expand All @@ -34,7 +38,10 @@ def train(model, train_dataset, hypers):
if epoch % hypers["log_interval"] == 0:
logger.info(f"Epoch {epoch}")
if epoch % hypers["checkpoint_interval"] == 0:
torch.save(model.state_dict(), f"model-{epoch}.pt")
save_model(
model,
f"model_{epoch}.pt",
)
for batch in train_dataloader:
optimizer.zero_grad()
structures, targets = batch
Expand All @@ -44,4 +51,4 @@ def train(model, train_dataset, hypers):
optimizer.step()

# Save the model:
torch.save(model.state_dict(), "model_final.pt")
save_model(model, "model_final.pt")
56 changes: 56 additions & 0 deletions src/metatensor/models/utils/model_io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import importlib

import torch


def save_model(
    model: torch.nn.Module,
    path: str,
) -> None:
    """Save a model to a file, along with the metadata needed to load it.

    The file is written with :py:func:`torch.save` and contains a dictionary
    with the keys ``"architecture_name"``, ``"model_state_dict"``,
    ``"model_hypers"`` and ``"all_species"``.

    :param model: The model to save. Besides ``state_dict()`` it must expose
        the attributes ``name``, ``hypers`` and ``all_species``.
    :param path: The path to the file where the model should be saved.
    :raises AttributeError: If ``model`` lacks one of the required attributes.
    """
    # Check the metadata up front so an unsupported model fails with an
    # explicit message naming what is missing.
    required = ("name", "hypers", "all_species")
    missing = [attr for attr in required if not hasattr(model, attr)]
    if missing:
        raise AttributeError(
            f"cannot save model of type {type(model).__name__}: "
            f"missing attribute(s) {', '.join(missing)}"
        )

    torch.save(
        {
            "architecture_name": model.name,
            "model_state_dict": model.state_dict(),
            "model_hypers": model.hypers,
            "all_species": model.all_species,
        },
        path,
    )


def load_model(path: str) -> torch.nn.Module:
    """Load a model from a file written by :py:func:`save_model`.

    The architecture named in the file is imported from ``metatensor.models``,
    its ``Model`` class is instantiated with the stored species and
    hyperparameters, and the stored weights are then loaded into it.

    .. warning::

        The file is read with :py:func:`torch.load`, which relies on pickle,
        and the stored architecture name selects the module to import. Only
        load files that come from a trusted source.

    :param path: The path to the file containing the model.
    :return: The loaded model.
    """
    # Load the model and the metadata. Tensors are mapped onto the CPU so that
    # a file saved from an accelerator can still be opened on a machine
    # without one; ``load_state_dict`` below copies the values into the new
    # model's own parameters in any case.
    model_dict = torch.load(path, map_location="cpu")

    # Get the architecture
    architecture = importlib.import_module(
        f"metatensor.models.{model_dict['architecture_name']}"
    )

    # Create the model
    model = architecture.Model(
        all_species=model_dict["all_species"],
        hypers=model_dict["model_hypers"],
    )

    # Load the model weights
    model.load_state_dict(model_dict["model_state_dict"])

    return model
29 changes: 29 additions & 0 deletions tests/model_io/test_model_io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from pathlib import Path

import metatensor.torch
import rascaline.torch

from metatensor.models import soap_bpnn
from metatensor.models.utils.data import read_structures
from metatensor.models.utils.model_io import load_model, save_model


RESOURCES_PATH = Path(__file__).parent.resolve() / ".." / "resources"


def test_save_load_model():
    """Test that saving and loading a model works and preserves its internal state."""
    # Imported locally because only this test needs it.
    import tempfile

    model = soap_bpnn.Model(all_species=[1, 6, 7, 8])
    structures = read_structures(RESOURCES_PATH / "qm9_reduced_100.xyz")

    output_before_save = model(rascaline.torch.systems_to_torch(structures))

    # Write the checkpoint into a throwaway directory so that the test does
    # not leave ``test_model.pt`` behind in the current working directory.
    with tempfile.TemporaryDirectory() as tmp_dir:
        model_path = str(Path(tmp_dir) / "test_model.pt")
        save_model(model, model_path)
        loaded_model = load_model(model_path)

    output_after_load = loaded_model(rascaline.torch.systems_to_torch(structures))

    assert metatensor.torch.allclose(
        output_before_save["energy"], output_after_load["energy"]
    )

0 comments on commit 44e4f11

Please sign in to comment.