-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement saving and loading of models (#19)
Co-authored-by: Philip Loche <[email protected]>
- Loading branch information
1 parent
09c58cd
commit 44e4f11
Showing
14 changed files
with
146 additions
and
67 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
Model IO | ||
######## | ||
|
||
.. automodule:: metatensor.models.utils.model_io | ||
:members: | ||
:undoc-members: | ||
:show-inheritance: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,2 @@ | ||
from .model import Model # noqa: F401 | ||
from .train import train # noqa: F401 | ||
from .model import Model, DEFAULT_MODEL_HYPERS # noqa: F401 | ||
from .train import train, DEFAULT_TRAINING_HYPERS # noqa: F401 |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,13 +1,11 @@ | ||
import torch | ||
|
||
from metatensor.models.soap_bpnn import Model | ||
|
||
from . import DEAFAULT_HYPERS | ||
from metatensor.models.soap_bpnn import DEFAULT_MODEL_HYPERS, Model | ||
|
||
|
||
def test_torchscript(): | ||
"""Tests that the model can be jitted.""" | ||
|
||
all_species = [1, 6, 7, 8] | ||
soap_bpnn = Model(all_species, DEAFAULT_HYPERS["model"]).to(torch.float64) | ||
soap_bpnn = Model(all_species, DEFAULT_MODEL_HYPERS).to(torch.float64) | ||
torch.jit.script(soap_bpnn) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
import importlib | ||
|
||
import torch | ||
|
||
|
||
def save_model( | ||
model: torch.nn.Module, | ||
path: str, | ||
) -> None: | ||
"""Saves a model to a file, along with all the metadata needed to load it. | ||
Parameters | ||
---------- | ||
:param model: The model to save. | ||
:param path: The path to the file where the model should be saved. | ||
""" | ||
torch.save( | ||
{ | ||
"architecture_name": model.name, | ||
"model_state_dict": model.state_dict(), | ||
"model_hypers": model.hypers, | ||
"all_species": model.all_species, | ||
}, | ||
path, | ||
) | ||
|
||
|
||
def load_model(path: str) -> torch.nn.Module: | ||
"""Loads a model from a file. | ||
Parameters | ||
---------- | ||
:param path: The path to the file containing the model. | ||
Returns | ||
------- | ||
:return: The loaded model. | ||
""" | ||
|
||
# Load the model and the metadata | ||
model_dict = torch.load(path) | ||
|
||
# Get the architecture | ||
architecture = importlib.import_module( | ||
f"metatensor.models.{model_dict['architecture_name']}" | ||
) | ||
|
||
# Create the model | ||
model = architecture.Model( | ||
all_species=model_dict["all_species"], hypers=model_dict["model_hypers"] | ||
) | ||
|
||
# Load the model weights | ||
model.load_state_dict(model_dict["model_state_dict"]) | ||
|
||
return model |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
from pathlib import Path | ||
|
||
import metatensor.torch | ||
import rascaline.torch | ||
|
||
from metatensor.models import soap_bpnn | ||
from metatensor.models.utils.data import read_structures | ||
from metatensor.models.utils.model_io import load_model, save_model | ||
|
||
|
||
RESOURCES_PATH = Path(__file__).parent.resolve() / ".." / "resources" | ||
|
||
|
||
def test_save_load_model(): | ||
"""Test that saving and loading a model works and preserves its internal state.""" | ||
|
||
model = soap_bpnn.Model(all_species=[1, 6, 7, 8]) | ||
structures = read_structures(RESOURCES_PATH / "qm9_reduced_100.xyz") | ||
|
||
output_before_save = model(rascaline.torch.systems_to_torch(structures)) | ||
|
||
save_model(model, "test_model.pt") | ||
loaded_model = load_model("test_model.pt") | ||
|
||
output_after_load = loaded_model(rascaline.torch.systems_to_torch(structures)) | ||
|
||
assert metatensor.torch.allclose( | ||
output_before_save["energy"], output_after_load["energy"] | ||
) |