Skip to content

Commit

Permalink
Make train function self-contained
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster committed Jan 3, 2024
1 parent dc2299b commit 0372828
Show file tree
Hide file tree
Showing 8 changed files with 32 additions and 36 deletions.
7 changes: 1 addition & 6 deletions src/metatensor/models/cli/train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,18 +102,13 @@ def train_model(config: DictConfig) -> None:
logger.info("Setting up model")
architetcure_name = config["architecture"]["name"]
architecture = importlib.import_module(f"metatensor.models.{architetcure_name}")
model = architecture.Model(
all_species=dataset.all_species,
hypers=OmegaConf.to_container(config["architecture"]["model"]),
)

logger.info("Run training")
output_dir = hydra.core.hydra_config.HydraConfig.get().runtime.output_dir

model = architecture.train(
model=model,
train_dataset=dataset,
hypers=OmegaConf.to_container(config["architecture"]["training"]),
hypers=OmegaConf.to_container(config["architecture"]),
output_dir=output_dir,
)

Expand Down
4 changes: 2 additions & 2 deletions src/metatensor/models/soap_bpnn/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .model import Model, DEFAULT_MODEL_HYPERS # noqa: F401
from .train import train, DEFAULT_TRAINING_HYPERS # noqa: F401
from .model import Model, DEFAULT_HYPERS # noqa: F401
from .train import train # noqa: F401
4 changes: 2 additions & 2 deletions src/metatensor/models/soap_bpnn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@
from ..utils.composition import apply_composition_contribution


DEAFAULT_HYPERS = OmegaConf.to_container(
DEFAULT_HYPERS = OmegaConf.to_container(
OmegaConf.load(ARCHITECTURE_CONFIG_PATH / "soap_bpnn.yaml")
)

DEFAULT_MODEL_HYPERS = DEAFAULT_HYPERS["model"]
DEFAULT_MODEL_HYPERS = DEFAULT_HYPERS["model"]

ARCHITECTURE_NAME = "soap_bpnn"

Expand Down
4 changes: 2 additions & 2 deletions src/metatensor/models/soap_bpnn/tests/test_functionality.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@
import rascaline.torch
import torch

from metatensor.models.soap_bpnn import DEFAULT_MODEL_HYPERS, Model
from metatensor.models.soap_bpnn import DEFAULT_HYPERS, Model


def test_prediction_subset():
    """Tests that the model can predict on a subset
    of the elements it was trained on."""

    # Diff-residue fix: the scrape kept both the pre-change line
    # (DEFAULT_MODEL_HYPERS) and the post-change line; only the
    # post-commit construction using DEFAULT_HYPERS["model"] is kept.
    all_species = [1, 6, 7, 8]
    soap_bpnn = Model(all_species, DEFAULT_HYPERS["model"]).to(torch.float64)

    # A two-atom O2 structure exercises prediction on a strict subset
    # (only oxygen) of the species the model was built for.
    structure = ase.Atoms("O2", positions=[[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]])
    soap_bpnn([rascaline.torch.systems_to_torch(structure)])
4 changes: 2 additions & 2 deletions src/metatensor/models/soap_bpnn/tests/test_invariance.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import rascaline.torch
import torch

from metatensor.models.soap_bpnn import DEFAULT_MODEL_HYPERS, Model
from metatensor.models.soap_bpnn import DEFAULT_HYPERS, Model

from . import DATASET_PATH

Expand All @@ -13,7 +13,7 @@ def test_rotational_invariance():
"""Tests that the model is rotationally invariant."""

all_species = [1, 6, 7, 8]
soap_bpnn = Model(all_species, DEFAULT_MODEL_HYPERS).to(torch.float64)
soap_bpnn = Model(all_species, DEFAULT_HYPERS["model"]).to(torch.float64)

structure = ase.io.read(DATASET_PATH)
original_structure = copy.deepcopy(structure)
Expand Down
16 changes: 5 additions & 11 deletions src/metatensor/models/soap_bpnn/tests/test_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,7 @@
import rascaline.torch
import torch

from metatensor.models.soap_bpnn import (
DEFAULT_MODEL_HYPERS,
DEFAULT_TRAINING_HYPERS,
Model,
train,
)
from metatensor.models.soap_bpnn import DEFAULT_HYPERS, Model, train
from metatensor.models.utils.data import Dataset
from metatensor.models.utils.data.readers import read_structures, read_targets

Expand All @@ -21,7 +16,7 @@ def test_regression_init():
"""Perform a regression test on the model at initialization"""

all_species = [1, 6, 7, 8]
soap_bpnn = Model(all_species, DEFAULT_MODEL_HYPERS).to(torch.float64)
soap_bpnn = Model(all_species, DEFAULT_HYPERS["model"]).to(torch.float64)

# Predict on the first five structures
structures = ase.io.read(DATASET_PATH, ":5")
Expand All @@ -47,11 +42,10 @@ def test_regression_train():
targets = read_targets(DATASET_PATH, "U0")

dataset = Dataset(structures, targets)
soap_bpnn = Model(dataset.all_species, DEFAULT_MODEL_HYPERS).to(torch.float64)

hypers = DEFAULT_TRAINING_HYPERS.copy()
hypers["num_epochs"] = 2
train(soap_bpnn, dataset, hypers)
hypers = DEFAULT_HYPERS.copy()
hypers["training"]["num_epochs"] = 2
soap_bpnn = train(dataset, hypers)

# Predict on the first five structures
output = soap_bpnn(structures[:5])
Expand Down
4 changes: 2 additions & 2 deletions src/metatensor/models/soap_bpnn/tests/test_torchscript.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import torch

from metatensor.models.soap_bpnn import DEFAULT_MODEL_HYPERS, Model
from metatensor.models.soap_bpnn import DEFAULT_HYPERS, Model


def test_torchscript():
    """Tests that the model can be jitted."""

    # Diff-residue fix: the scrape kept both the pre-change line
    # (DEFAULT_MODEL_HYPERS) and the post-change line; only the
    # post-commit construction using DEFAULT_HYPERS["model"] is kept.
    all_species = [1, 6, 7, 8]
    soap_bpnn = Model(all_species, DEFAULT_HYPERS["model"]).to(torch.float64)
    # Raises if any construct in the model is not TorchScript-compatible.
    torch.jit.script(soap_bpnn)
25 changes: 16 additions & 9 deletions src/metatensor/models/soap_bpnn/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,24 @@
from ..utils.composition import calculate_composition_weights
from ..utils.data import collate_fn
from ..utils.model_io import save_model
from .model import DEAFAULT_HYPERS
from .model import DEFAULT_HYPERS, Model


DEFAULT_TRAINING_HYPERS = DEAFAULT_HYPERS["training"]

logger = logging.getLogger(__name__)


def loss_function(predicted, target):
    """Sum of squared errors between *predicted* and *target*.

    Both arguments are expected to expose a ``block()`` method whose
    result carries a ``values`` tensor (TensorMap-like); only that
    single block is compared.
    """
    difference = predicted.block().values - target.block().values
    return torch.sum(difference ** 2)


def train(model, train_dataset, hypers=DEFAULT_TRAINING_HYPERS, output_dir="."):
def train(train_dataset, hypers=DEFAULT_HYPERS, output_dir="."):
# Calculate and set the composition weights:

model = Model(
all_species=train_dataset.all_species,
hypers=hypers["model"],
)

if len(train_dataset.targets) > 1:
raise ValueError(
f"`train_dataset` contains {len(train_dataset.targets)} targets but we "
Expand All @@ -32,22 +35,26 @@ def train(model, train_dataset, hypers=DEFAULT_TRAINING_HYPERS, output_dir="."):
composition_weights = calculate_composition_weights(train_dataset, target)
model.set_composition_weights(composition_weights)

hypers_training = hypers["training"]

# Create a dataloader for the training dataset:
train_dataloader = torch.utils.data.DataLoader(
dataset=train_dataset,
batch_size=hypers["batch_size"],
batch_size=hypers_training["batch_size"],
shuffle=True,
collate_fn=collate_fn,
)

# Create an optimizer:
optimizer = torch.optim.Adam(model.parameters(), lr=hypers["learning_rate"])
optimizer = torch.optim.Adam(
model.parameters(), lr=hypers_training["learning_rate"]
)

# Train the model:
for epoch in range(hypers["num_epochs"]):
if epoch % hypers["log_interval"] == 0:
for epoch in range(hypers_training["num_epochs"]):
if epoch % hypers_training["log_interval"] == 0:
logger.info(f"Epoch {epoch}")
if epoch % hypers["checkpoint_interval"] == 0:
if epoch % hypers_training["checkpoint_interval"] == 0:
save_model(
model,
Path(output_dir) / f"model_{epoch}.pt",
Expand Down

0 comments on commit 0372828

Please sign in to comment.