-
Notifications
You must be signed in to change notification settings - Fork 4
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement saving and loading of models #19
Merged
Merged
Changes from 5 commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
01bcf05
First attempt
frostedoyster 81871a0
Feed hypers of the model and training to trainer
frostedoyster 64bc443
Implement loader, tests, documentation
frostedoyster 98e9436
Fix two small bugs
frostedoyster 5f8008d
Fix final bug
frostedoyster e64c66e
Apply suggestions
frostedoyster 07d1f23
Unify hypers in BPNN
PicoCentauri b22aee6
move hyper consts to unified places
PicoCentauri File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
Model IO | ||
######## | ||
|
||
.. automodule:: metatensor.models.utils.model_io | ||
:members: | ||
:undoc-members: | ||
:show-inheritance: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
import importlib | ||
from typing import Dict, List | ||
|
||
import torch | ||
|
||
|
||
def save_model( | ||
arch_name: str, | ||
model: torch.nn.Module, | ||
hypers: Dict, | ||
all_species: List[int], | ||
path: str, | ||
) -> None: | ||
"""Saves a model to a file, along with all the metadata needed to load it. | ||
|
||
Parameters | ||
---------- | ||
arch_name (str): The name of the architecture. | ||
|
||
model (torch.nn.Module): The model to save. | ||
|
||
hypers (Dict): The hyperparameters used to train the model. | ||
|
||
all_species (List[int]): The list of all species that the model can handle. | ||
frostedoyster marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
path (str): The path to the file. | ||
""" | ||
torch.save( | ||
{ | ||
"name": arch_name, | ||
"model": model.state_dict(), | ||
"hypers": hypers, | ||
frostedoyster marked this conversation as resolved.
Show resolved
Hide resolved
|
||
"all_species": all_species, | ||
}, | ||
path, | ||
) | ||
|
||
|
||
def load_model(path: str) -> torch.nn.Module: | ||
"""Loads a model from a file. | ||
|
||
Parameters | ||
---------- | ||
path (str): The path to the file. | ||
|
||
Returns | ||
------- | ||
torch.nn.Module: The loaded model. | ||
""" | ||
|
||
# Load the model and the metadata | ||
model_dict = torch.load(path) | ||
|
||
# Get the architecture | ||
architecture = importlib.import_module(f"metatensor.models.{model_dict['name']}") | ||
|
||
# Create the model | ||
model = architecture.Model( | ||
all_species=model_dict["all_species"], hypers=model_dict["hypers"] | ||
) | ||
|
||
# Load the model weights | ||
model.load_state_dict(model_dict["model"]) | ||
|
||
return model |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
from pathlib import Path | ||
|
||
import metatensor.torch | ||
import rascaline.torch | ||
|
||
from metatensor.models import soap_bpnn | ||
from metatensor.models.soap_bpnn.model import DEFAULT_MODEL_HYPERS | ||
from metatensor.models.utils.data import read_structures | ||
from metatensor.models.utils.model_io import load_model, save_model | ||
|
||
|
||
RESOURCES_PATH = Path(__file__).parent.resolve() / ".." / "resources" | ||
|
||
|
||
def test_save_load_model(): | ||
"""Test that saving and loading a model works and preserves its internal state.""" | ||
|
||
model = soap_bpnn.Model(all_species=[1, 6, 7, 8]) | ||
structures = read_structures(RESOURCES_PATH / "qm9_reduced_100.xyz") | ||
|
||
output_before_save = model(rascaline.torch.systems_to_torch(structures)) | ||
|
||
save_model( | ||
"soap_bpnn", model, DEFAULT_MODEL_HYPERS, model.all_species, "test_model.pt" | ||
) | ||
loaded_model = load_model("test_model.pt") | ||
|
||
output_after_load = loaded_model(rascaline.torch.systems_to_torch(structures)) | ||
|
||
assert metatensor.torch.allclose( | ||
output_before_save["energy"], output_after_load["energy"] | ||
) |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we don't have to set hypers to a default value or is there a specific reason?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In this way, it's much easier to use the model from Python. For example, the test I've written for the model_io module would have been much more difficult. In general, I feel like, if we want to support model use from Python as well, we need to make the hypers available from Python as well
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay I get your point but then maybe it makes sense to provide a general parser and store the hypers as dictionary in the init of each architecture.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I feel like this is short enough for the moment... Perhaps we can change it once we have more models?