Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement saving and loading of models #19

Merged
merged 8 commits into from
Dec 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions docs/src/dev-docs/utils/index.rst
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
Utilitliy API
=============
Utility API
===========

This is the API for the ``utils`` module of ``metatensor-models``.

Expand All @@ -8,3 +8,4 @@ This is the API for the ``utils`` module of ``metatensor-models``.

dataset
readers/index
model-io
7 changes: 7 additions & 0 deletions docs/src/dev-docs/utils/model-io.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Model IO
########

.. automodule:: metatensor.models.utils.model_io
:members:
:undoc-members:
:show-inheritance:
2 changes: 1 addition & 1 deletion docs/src/models/soap-bpnn.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ Hyperparameters

The hyperparameters (and relative default values) for the SOAP-BPNN model are:

.. literalinclude:: ../../../src/metatensor/models/soap_bpnn/default.yml
.. literalinclude:: ../../../src/metatensor/models/cli/conf/architecture/soap_bpnn.yaml
:language: yaml

Any of these hyperparameters can be overridden in the training configuration file.
Expand Down
4 changes: 2 additions & 2 deletions src/metatensor/models/soap_bpnn/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .model import Model # noqa: F401
from .train import train # noqa: F401
from .model import Model, DEFAULT_MODEL_HYPERS # noqa: F401
from .train import train, DEFAULT_TRAINING_HYPERS # noqa: F401
30 changes: 0 additions & 30 deletions src/metatensor/models/soap_bpnn/default.yml

This file was deleted.

17 changes: 16 additions & 1 deletion src/metatensor/models/soap_bpnn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,21 @@
import rascaline.torch
import torch
from metatensor.torch import Labels, TensorBlock, TensorMap
from omegaconf import OmegaConf

from .. import ARCHITECTURE_CONFIG_PATH
from ..utils.composition import apply_composition_contribution


DEAFAULT_HYPERS = OmegaConf.to_container(
OmegaConf.load(ARCHITECTURE_CONFIG_PATH / "soap_bpnn.yaml")
)

DEFAULT_MODEL_HYPERS = DEAFAULT_HYPERS["model"]

ARCHITECTURE_NAME = "soap_bpnn"


class MLPMap(torch.nn.Module):
def __init__(self, all_species: List[int], hypers: dict) -> None:
super().__init__()
Expand Down Expand Up @@ -81,9 +92,13 @@ def forward(self, features: TensorMap) -> TensorMap:


class Model(torch.nn.Module):
def __init__(self, all_species: List[int], hypers: Dict) -> None:
def __init__(
self, all_species: List[int], hypers: Dict = DEFAULT_MODEL_HYPERS
) -> None:
super().__init__()
self.name = ARCHITECTURE_NAME
self.all_species = all_species
self.hypers = hypers

# creates a composition weight tensor that can be directly indexed by species,
# this can be left as a tensor of zero or set from the outside using
Expand Down
7 changes: 0 additions & 7 deletions src/metatensor/models/soap_bpnn/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,5 @@
from pathlib import Path

from metatensor.models import ARCHITECTURE_CONFIG_PATH
from omegaconf import OmegaConf


DEAFAULT_HYPERS = OmegaConf.to_container(
OmegaConf.load(ARCHITECTURE_CONFIG_PATH / "soap_bpnn.yaml")
)
DATASET_PATH = str(
Path(__file__).parent.resolve()
/ "../../../../../tests/resources/qm9_reduced_100.xyz"
Expand Down
6 changes: 2 additions & 4 deletions src/metatensor/models/soap_bpnn/tests/test_functionality.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,15 @@
import rascaline.torch
import torch

from metatensor.models.soap_bpnn import Model

from . import DEAFAULT_HYPERS
from metatensor.models.soap_bpnn import DEFAULT_MODEL_HYPERS, Model


def test_prediction_subset():
"""Tests that the model can predict on a subset
of the elements it was trained on."""

all_species = [1, 6, 7, 8]
soap_bpnn = Model(all_species, DEAFAULT_HYPERS["model"]).to(torch.float64)
soap_bpnn = Model(all_species, DEFAULT_MODEL_HYPERS).to(torch.float64)

structure = ase.Atoms("O2", positions=[[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]])
soap_bpnn([rascaline.torch.systems_to_torch(structure)])
6 changes: 3 additions & 3 deletions src/metatensor/models/soap_bpnn/tests/test_invariance.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,16 @@
import rascaline.torch
import torch

from metatensor.models.soap_bpnn import Model
from metatensor.models.soap_bpnn import DEFAULT_MODEL_HYPERS, Model

from . import DATASET_PATH, DEAFAULT_HYPERS
from . import DATASET_PATH


def test_rotational_invariance():
"""Tests that the model is rotationally invariant."""

all_species = [1, 6, 7, 8]
soap_bpnn = Model(all_species, DEAFAULT_HYPERS["model"]).to(torch.float64)
soap_bpnn = Model(all_species, DEFAULT_MODEL_HYPERS).to(torch.float64)

structure = ase.io.read(DATASET_PATH)
original_structure = copy.deepcopy(structure)
Expand Down
19 changes: 12 additions & 7 deletions src/metatensor/models/soap_bpnn/tests/test_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,16 @@
import rascaline.torch
import torch

from metatensor.models.soap_bpnn import Model, train
from metatensor.models.soap_bpnn import (
DEFAULT_MODEL_HYPERS,
DEFAULT_TRAINING_HYPERS,
Model,
train,
)
from metatensor.models.utils.data import Dataset
from metatensor.models.utils.data.readers import read_structures, read_targets

from . import DATASET_PATH, DEAFAULT_HYPERS
from . import DATASET_PATH


torch.manual_seed(0)
Expand All @@ -16,7 +21,7 @@ def test_regression_init():
"""Perform a regression test on the model at initialization"""

all_species = [1, 6, 7, 8]
soap_bpnn = Model(all_species, DEAFAULT_HYPERS["model"]).to(torch.float64)
soap_bpnn = Model(all_species, DEFAULT_MODEL_HYPERS).to(torch.float64)

# Predict on the first fivestructures
structures = ase.io.read(DATASET_PATH, ":5")
Expand All @@ -42,11 +47,11 @@ def test_regression_train():
targets = read_targets(DATASET_PATH, "U0")

dataset = Dataset(structures, targets)
soap_bpnn = Model(dataset.all_species, DEAFAULT_HYPERS["model"]).to(torch.float64)
soap_bpnn = Model(dataset.all_species, DEFAULT_MODEL_HYPERS).to(torch.float64)

hypers_training = DEAFAULT_HYPERS["training"].copy()
hypers_training["num_epochs"] = 2
train(soap_bpnn, dataset, hypers_training)
hypers = DEFAULT_TRAINING_HYPERS.copy()
hypers["num_epochs"] = 2
train(soap_bpnn, dataset, hypers)

# Predict on the first five structures
output = soap_bpnn(structures[:5])
Expand Down
6 changes: 2 additions & 4 deletions src/metatensor/models/soap_bpnn/tests/test_torchscript.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
import torch

from metatensor.models.soap_bpnn import Model

from . import DEAFAULT_HYPERS
from metatensor.models.soap_bpnn import DEFAULT_MODEL_HYPERS, Model


def test_torchscript():
"""Tests that the model can be jitted."""

all_species = [1, 6, 7, 8]
soap_bpnn = Model(all_species, DEAFAULT_HYPERS["model"]).to(torch.float64)
soap_bpnn = Model(all_species, DEFAULT_MODEL_HYPERS).to(torch.float64)
torch.jit.script(soap_bpnn)
19 changes: 13 additions & 6 deletions src/metatensor/models/soap_bpnn/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,20 @@

from ..utils.composition import calculate_composition_weights
from ..utils.data import collate_fn
from ..utils.model_io import save_model
from .model import DEAFAULT_HYPERS


def loss_function(predicted, target):
return torch.sum((predicted.block().values - target.block().values) ** 2)

DEFAULT_TRAINING_HYPERS = DEAFAULT_HYPERS["training"]

logger = logging.getLogger(__name__)


def train(model, train_dataset, hypers):
def loss_function(predicted, target):
return torch.sum((predicted.block().values - target.block().values) ** 2)


def train(model, train_dataset, hypers=DEFAULT_TRAINING_HYPERS):
# Calculate and set the composition weights:
composition_weights = calculate_composition_weights(train_dataset, "U0")
model.set_composition_weights(composition_weights)
Expand All @@ -34,7 +38,10 @@ def train(model, train_dataset, hypers):
if epoch % hypers["log_interval"] == 0:
logger.info(f"Epoch {epoch}")
if epoch % hypers["checkpoint_interval"] == 0:
torch.save(model.state_dict(), f"model-{epoch}.pt")
save_model(
model,
f"model_{epoch}.pt",
)
for batch in train_dataloader:
optimizer.zero_grad()
structures, targets = batch
Expand All @@ -44,4 +51,4 @@ def train(model, train_dataset, hypers):
optimizer.step()

# Save the model:
torch.save(model.state_dict(), "model_final.pt")
save_model(model, "model_final.pt")
56 changes: 56 additions & 0 deletions src/metatensor/models/utils/model_io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import importlib

import torch


def save_model(
model: torch.nn.Module,
path: str,
) -> None:
"""Saves a model to a file, along with all the metadata needed to load it.
Parameters
----------
:param model: The model to save.
:param path: The path to the file where the model should be saved.
"""
torch.save(
{
"architecture_name": model.name,
"model_state_dict": model.state_dict(),
"model_hypers": model.hypers,
"all_species": model.all_species,
},
path,
)


def load_model(path: str) -> torch.nn.Module:
"""Loads a model from a file.
Parameters
----------
:param path: The path to the file containing the model.
Returns
-------
:return: The loaded model.
"""

# Load the model and the metadata
model_dict = torch.load(path)

# Get the architecture
architecture = importlib.import_module(
f"metatensor.models.{model_dict['architecture_name']}"
)

# Create the model
model = architecture.Model(
all_species=model_dict["all_species"], hypers=model_dict["model_hypers"]
)

# Load the model weights
model.load_state_dict(model_dict["model_state_dict"])

return model
29 changes: 29 additions & 0 deletions tests/model_io/test_model_io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from pathlib import Path

import metatensor.torch
import rascaline.torch

from metatensor.models import soap_bpnn
from metatensor.models.utils.data import read_structures
from metatensor.models.utils.model_io import load_model, save_model


RESOURCES_PATH = Path(__file__).parent.resolve() / ".." / "resources"


def test_save_load_model():
"""Test that saving and loading a model works and preserves its internal state."""

model = soap_bpnn.Model(all_species=[1, 6, 7, 8])
structures = read_structures(RESOURCES_PATH / "qm9_reduced_100.xyz")

output_before_save = model(rascaline.torch.systems_to_torch(structures))

save_model(model, "test_model.pt")
loaded_model = load_model("test_model.pt")

output_after_load = loaded_model(rascaline.torch.systems_to_torch(structures))

assert metatensor.torch.allclose(
output_before_save["energy"], output_after_load["energy"]
)