Skip to content

Commit

Permalink
Implement loader, tests, documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster committed Dec 8, 2023
1 parent 81871a0 commit 64bc443
Show file tree
Hide file tree
Showing 8 changed files with 83 additions and 15 deletions.
5 changes: 3 additions & 2 deletions docs/src/dev-docs/utils/index.rst
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
Utilitliy API
=============
Utility API
===========

This is the API for the ``utils`` module of ``metatensor-models``.

Expand All @@ -8,3 +8,4 @@ This is the API for the ``utils`` module of ``metatensor-models``.

dataset
readers/index
model-io
7 changes: 7 additions & 0 deletions docs/src/dev-docs/utils/model-io.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Model IO
########

.. automodule:: metatensor.models.utils.model_io
:members:
:undoc-members:
:show-inheritance:
2 changes: 1 addition & 1 deletion docs/src/models/soap-bpnn.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ Hyperparameters

The hyperparameters (and relative default values) for the SOAP-BPNN model are:

.. literalinclude:: ../../../src/metatensor/models/soap_bpnn/default.yml
.. literalinclude:: ../../../src/metatensor/models/cli/conf/architecture/soap_bpnn.yaml
:language: yaml

Any of these hyperparameters can be overridden in the training configuration file.
Expand Down
11 changes: 10 additions & 1 deletion src/metatensor/models/soap_bpnn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,19 @@
import rascaline.torch
import torch
from metatensor.torch import Labels, TensorBlock, TensorMap
from omegaconf import OmegaConf

from metatensor.models import ARCHITECTURE_CONFIG_PATH

from ..utils.composition import apply_composition_contribution


ARCHITECTURE_NAME = "soap_bpnn"

DEFAULT_MODEL_HYPERS = OmegaConf.to_container(
OmegaConf.load(ARCHITECTURE_CONFIG_PATH / "soap_bpnn.yaml")
)["model"]


class MLPMap(torch.nn.Module):
def __init__(self, all_species: List[int], hypers: dict) -> None:
Expand Down Expand Up @@ -84,7 +91,9 @@ def forward(self, features: TensorMap) -> TensorMap:


class Model(torch.nn.Module):
def __init__(self, all_species: List[int], hypers: Dict) -> None:
def __init__(
self, all_species: List[int], hypers: Dict = DEFAULT_MODEL_HYPERS
) -> None:
super().__init__()
self.all_species = all_species

Expand Down
6 changes: 3 additions & 3 deletions src/metatensor/models/soap_bpnn/tests/test_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@ def test_regression_train():
dataset = Dataset(structures, targets)
soap_bpnn = Model(dataset.all_species, DEAFAULT_HYPERS["model"]).to(torch.float64)

hypers_training = DEAFAULT_HYPERS["training"].copy()
hypers_training["num_epochs"] = 2
train(soap_bpnn, dataset, hypers_training)
hypers = DEAFAULT_HYPERS.copy()
hypers["training"]["num_epochs"] = 2
train(soap_bpnn, dataset, hypers)

# Predict on the first five structures
output = soap_bpnn(structures[:5])
Expand Down
15 changes: 11 additions & 4 deletions src/metatensor/models/soap_bpnn/train.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,28 @@
import logging

import torch
from omegaconf import OmegaConf

from metatensor.models import ARCHITECTURE_CONFIG_PATH

from ..utils.composition import calculate_composition_weights
from ..utils.data import collate_fn
from ..utils.model_io import save_model
from .model import ARCHITECTURE_NAME


def loss_function(predicted, target):
return torch.sum((predicted.block().values - target.block().values) ** 2)

DEFAULT_TRAINING_HYPERS = OmegaConf.to_container(
OmegaConf.load(ARCHITECTURE_CONFIG_PATH / "soap_bpnn.yaml")
)["training"]

logger = logging.getLogger(__name__)


def train(model, train_dataset, hypers):
def loss_function(predicted, target):
return torch.sum((predicted.block().values - target.block().values) ** 2)


def train(model, train_dataset, hypers=DEFAULT_TRAINING_HYPERS):
model_hypers = hypers["model"]
training_hypers = hypers["training"]

Expand Down
19 changes: 17 additions & 2 deletions src/metatensor/models/utils/model_io.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import importlib
from typing import Dict, List

import torch
Expand Down Expand Up @@ -44,5 +45,19 @@ def load_model(path: str) -> torch.nn.Module:
-------
torch.nn.Module: The loaded model.
"""
# TODO, possibly with hydra utilities?
pass

# Load the model and the metadata
model_dict = torch.load(path)

# Get the architecture
architecture = importlib.import_module(f"metatensor.models.{model_dict['name']}")

# Create the model
model = architecture.Model(
all_species=model_dict["all_species"], hypers=model_dict["hypers"]
)

# Load the model weights
model.load_state_dict(model_dict["model"])

return model
33 changes: 31 additions & 2 deletions tests/model_io/test_model_io.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,32 @@
from pathlib import Path

import metatensor.torch
import rascaline.torch

from metatensor.models import soap_bpnn
from metatensor.models.soap_bpnn.model import DEFAULT_MODEL_HYPERS
from metatensor.models.utils.data import read_structures
from metatensor.models.utils.model_io import load_model, save_model


RESOURCES_PATH = Path(__file__).parent.resolve() / ".." / "resources"


def test_save_load_model():
"""Test that saving and loading a model works."""
pass
"""Test that saving and loading a model works and preserves its internal state."""

model = soap_bpnn.Model(all_species=[1, 6, 7, 8])
structures = read_structures(RESOURCES_PATH / "qm9_reduced_100.xyz")

output_before_save = model(rascaline.torch.systems_to_torch(structures))

save_model(
"soap_bpnn", model, DEFAULT_MODEL_HYPERS, model.all_species, "test_model.pt"
)
loaded_model = load_model("test_model.pt")

output_after_load = loaded_model(rascaline.torch.systems_to_torch(structures))

assert metatensor.torch.allclose(
output_before_save["energy"], output_after_load["energy"]
)

0 comments on commit 64bc443

Please sign in to comment.