Implement saving and loading of models #19

Merged: 8 commits, merged on Dec 11, 2023

Changes from 5 commits
5 changes: 3 additions & 2 deletions docs/src/dev-docs/utils/index.rst
@@ -1,5 +1,5 @@
Utilitliy API
=============
Utility API
===========

This is the API for the ``utils`` module of ``metatensor-models``.

@@ -8,3 +8,4 @@ This is the API for the ``utils`` module of ``metatensor-models``.

dataset
readers/index
model-io
7 changes: 7 additions & 0 deletions docs/src/dev-docs/utils/model-io.rst
@@ -0,0 +1,7 @@
Model IO
########

.. automodule:: metatensor.models.utils.model_io
:members:
:undoc-members:
:show-inheritance:
2 changes: 1 addition & 1 deletion docs/src/models/soap-bpnn.rst
@@ -22,7 +22,7 @@ Hyperparameters

The hyperparameters (and relative default values) for the SOAP-BPNN model are:

.. literalinclude:: ../../../src/metatensor/models/soap_bpnn/default.yml
.. literalinclude:: ../../../src/metatensor/models/cli/conf/architecture/soap_bpnn.yaml
:language: yaml

Any of these hyperparameters can be overridden in the training configuration file.
2 changes: 1 addition & 1 deletion src/metatensor/models/cli/train_model.py
@@ -37,5 +37,5 @@ def train_model(config: DictConfig) -> None:
architecture.train(
model=model,
train_dataset=dataset,
hypers=OmegaConf.to_container(config["architecture"]["training"]),
hypers=OmegaConf.to_container(config["architecture"]),
)
30 changes: 0 additions & 30 deletions src/metatensor/models/soap_bpnn/default.yml

This file was deleted.

14 changes: 13 additions & 1 deletion src/metatensor/models/soap_bpnn/model.py
@@ -4,10 +4,20 @@
import rascaline.torch
import torch
from metatensor.torch import Labels, TensorBlock, TensorMap
from omegaconf import OmegaConf

from metatensor.models import ARCHITECTURE_CONFIG_PATH

from ..utils.composition import apply_composition_contribution


ARCHITECTURE_NAME = "soap_bpnn"

DEFAULT_MODEL_HYPERS = OmegaConf.to_container(
OmegaConf.load(ARCHITECTURE_CONFIG_PATH / "soap_bpnn.yaml")
)["model"]
Contributor:
I think we don't have to set hypers to a default value, or is there a specific reason?

Collaborator (Author):
This way, it's much easier to use the model from Python. For example, the test I've written for the model_io module would have been much more difficult otherwise. In general, if we want to support using the model from Python as well, we need to make the hypers available from Python too.

Contributor:
Okay, I get your point, but then maybe it makes sense to provide a general parser and store the hypers as a dictionary in the init of each architecture.

Collaborator (Author):
I feel like this is short enough for the moment... Perhaps we can change it once we have more models?
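For context, the Python-side usage being discussed can be sketched as follows (a minimal sketch mirroring the new test in tests/model_io/test_model_io.py; not part of the diff itself):

    from metatensor.models import soap_bpnn

    # hypers falls back to DEFAULT_MODEL_HYPERS, loaded from soap_bpnn.yaml,
    # so no YAML configuration has to be parsed by hand
    model = soap_bpnn.Model(all_species=[1, 6, 7, 8])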



class MLPMap(torch.nn.Module):
def __init__(self, all_species: List[int], hypers: dict) -> None:
super().__init__()
@@ -81,7 +91,9 @@ def forward(self, features: TensorMap) -> TensorMap:


class Model(torch.nn.Module):
def __init__(self, all_species: List[int], hypers: Dict) -> None:
def __init__(
self, all_species: List[int], hypers: Dict = DEFAULT_MODEL_HYPERS
) -> None:
super().__init__()
self.all_species = all_species

6 changes: 3 additions & 3 deletions src/metatensor/models/soap_bpnn/tests/test_regression.py
@@ -44,9 +44,9 @@ def test_regression_train():
dataset = Dataset(structures, targets)
soap_bpnn = Model(dataset.all_species, DEAFAULT_HYPERS["model"]).to(torch.float64)

hypers_training = DEAFAULT_HYPERS["training"].copy()
hypers_training["num_epochs"] = 2
train(soap_bpnn, dataset, hypers_training)
hypers = DEAFAULT_HYPERS.copy()
hypers["training"]["num_epochs"] = 2
train(soap_bpnn, dataset, hypers)

# Predict on the first five structures
output = soap_bpnn(structures[:5])
42 changes: 32 additions & 10 deletions src/metatensor/models/soap_bpnn/train.py
@@ -1,40 +1,60 @@
import logging

import torch
from omegaconf import OmegaConf

from metatensor.models import ARCHITECTURE_CONFIG_PATH

from ..utils.composition import calculate_composition_weights
from ..utils.data import collate_fn
from ..utils.model_io import save_model
from .model import ARCHITECTURE_NAME


DEFAULT_TRAINING_HYPERS = OmegaConf.to_container(
OmegaConf.load(ARCHITECTURE_CONFIG_PATH / "soap_bpnn.yaml")
)["training"]

logger = logging.getLogger(__name__)


def loss_function(predicted, target):
return torch.sum((predicted.block().values - target.block().values) ** 2)


logger = logging.getLogger(__name__)

def train(model, train_dataset, hypers=DEFAULT_TRAINING_HYPERS):
model_hypers = hypers["model"]
training_hypers = hypers["training"]

def train(model, train_dataset, hypers):
# Calculate and set the composition weights:
composition_weights = calculate_composition_weights(train_dataset, "U0")
model.set_composition_weights(composition_weights)

# Create a dataloader for the training dataset:
train_dataloader = torch.utils.data.DataLoader(
dataset=train_dataset,
batch_size=hypers["batch_size"],
batch_size=training_hypers["batch_size"],
shuffle=True,
collate_fn=collate_fn,
)

# Create an optimizer:
optimizer = torch.optim.Adam(model.parameters(), lr=hypers["learning_rate"])
optimizer = torch.optim.Adam(
model.parameters(), lr=training_hypers["learning_rate"]
)

# Train the model:
for epoch in range(hypers["num_epochs"]):
if epoch % hypers["log_interval"] == 0:
for epoch in range(training_hypers["num_epochs"]):
if epoch % training_hypers["log_interval"] == 0:
logger.info(f"Epoch {epoch}")
if epoch % hypers["checkpoint_interval"] == 0:
torch.save(model.state_dict(), f"model-{epoch}.pt")
if epoch % training_hypers["checkpoint_interval"] == 0:
save_model(
ARCHITECTURE_NAME,
model,
model_hypers,
model.all_species,
f"model_{epoch}.pt",
)
for batch in train_dataloader:
optimizer.zero_grad()
structures, targets = batch
@@ -44,4 +64,6 @@ def train(model, train_dataset, hypers):
optimizer.step()

# Save the model:
torch.save(model.state_dict(), "model_final.pt")
save_model(
ARCHITECTURE_NAME, model, model_hypers, model.all_species, "model_final.pt"
)
65 changes: 65 additions & 0 deletions src/metatensor/models/utils/model_io.py
@@ -0,0 +1,65 @@
import importlib
from typing import Dict, List

import torch


def save_model(
arch_name: str,
model: torch.nn.Module,
hypers: Dict,
all_species: List[int],
path: str,
) -> None:
"""Saves a model to a file, along with all the metadata needed to load it.

Parameters
----------
arch_name (str): The name of the architecture.

model (torch.nn.Module): The model to save.

hypers (Dict): The hyperparameters used to train the model.

all_species (List[int]): The list of all species that the model can handle.

path (str): The path to the file.
"""
torch.save(
{
"name": arch_name,
"model": model.state_dict(),
"hypers": hypers,
"all_species": all_species,
},
path,
)


def load_model(path: str) -> torch.nn.Module:
"""Loads a model from a file.

Parameters
----------
path (str): The path to the file.

Returns
-------
torch.nn.Module: The loaded model.
"""

# Load the model and the metadata
model_dict = torch.load(path)

# Get the architecture
architecture = importlib.import_module(f"metatensor.models.{model_dict['name']}")

# Create the model
model = architecture.Model(
all_species=model_dict["all_species"], hypers=model_dict["hypers"]
)

# Load the model weights
model.load_state_dict(model_dict["model"])

return model
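Because save_model delegates to a plain torch.save call with the dictionary shown above, a checkpoint can also be inspected without going through load_model. A minimal sketch, assuming a model_final.pt produced by the training loop in train.py:

    import torch

    # keys written by save_model: "name", "model", "hypers", "all_species"
    checkpoint = torch.load("model_final.pt")
    print(checkpoint["name"])         # architecture name, e.g. "soap_bpnn"
    print(checkpoint["all_species"])  # species the saved model supports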
32 changes: 32 additions & 0 deletions tests/model_io/test_model_io.py
@@ -0,0 +1,32 @@
from pathlib import Path

import metatensor.torch
import rascaline.torch

from metatensor.models import soap_bpnn
from metatensor.models.soap_bpnn.model import DEFAULT_MODEL_HYPERS
from metatensor.models.utils.data import read_structures
from metatensor.models.utils.model_io import load_model, save_model


RESOURCES_PATH = Path(__file__).parent.resolve() / ".." / "resources"


def test_save_load_model():
"""Test that saving and loading a model works and preserves its internal state."""

model = soap_bpnn.Model(all_species=[1, 6, 7, 8])
structures = read_structures(RESOURCES_PATH / "qm9_reduced_100.xyz")

output_before_save = model(rascaline.torch.systems_to_torch(structures))

save_model(
"soap_bpnn", model, DEFAULT_MODEL_HYPERS, model.all_species, "test_model.pt"
)
loaded_model = load_model("test_model.pt")

output_after_load = loaded_model(rascaline.torch.systems_to_torch(structures))

assert metatensor.torch.allclose(
output_before_save["energy"], output_after_load["energy"]
)