Finalize training procedure #31

Merged · 25 commits · Jan 24, 2024

Commits
3ea3e6a
Add gradient calculator
frostedoyster Jan 4, 2024
79d161e
Fix linter
frostedoyster Jan 4, 2024
e65679d
Loss draft
frostedoyster Jan 11, 2024
8591d7d
Add tests for losses
frostedoyster Jan 11, 2024
8df7e41
Wrap forces and stresses
frostedoyster Jan 11, 2024
0ffc0a9
Clarify cell convention
frostedoyster Jan 11, 2024
9253931
Merge branch 'main' into forces-virials
frostedoyster Jan 11, 2024
9b93eaa
Support multiple model outputs, use new loss
frostedoyster Jan 11, 2024
83c94e6
Fix composition calculator
frostedoyster Jan 11, 2024
4f1c569
Address review
frostedoyster Jan 12, 2024
7c32d9e
Partial draft
frostedoyster Jan 12, 2024
3081a0f
Finished trainer
frostedoyster Jan 12, 2024
ed8697f
Make linter happy
frostedoyster Jan 12, 2024
589b586
Merge branch 'main' into finalize-training
frostedoyster Jan 13, 2024
01a528b
Fix small merge issue
frostedoyster Jan 12, 2024
aa31433
Add new functions to the documentation
frostedoyster Jan 16, 2024
9061f1c
Add tutorial how to override arch params
PicoCentauri Jan 18, 2024
631d274
Adapt to most recent parser changes
frostedoyster Jan 19, 2024
dc523b7
Train with actual train/validation splits
frostedoyster Jan 19, 2024
fe0bd01
Fix some small issues
frostedoyster Jan 19, 2024
f579652
Merge branch 'arch_override_docs' into finalize-training
frostedoyster Jan 20, 2024
2e82d22
Fix docs build?
frostedoyster Jan 19, 2024
59cf4f4
Address reviewer comments
frostedoyster Jan 23, 2024
bc60b87
Merge branch 'main' into finalize-training
frostedoyster Jan 23, 2024
22237a4
Merge branch 'finalize-training' of https://github.com/lab-cosmo/meta…
frostedoyster Jan 23, 2024
7 changes: 7 additions & 0 deletions docs/src/dev-docs/utils/combine_dataloaders.rst
@@ -0,0 +1,7 @@
Combining dataloaders
#####################

.. automodule:: metatensor.models.utils.data.combine_dataloaders
:members:
:undoc-members:
:show-inheritance:
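The new helper is used by the trainer below as combine_dataloaders(train_dataloaders, shuffle=True). A minimal sketch of such a combiner, assuming it simply collects the batches from each loader and optionally shuffles the order in which they are visited (the CombinedDataLoader class name is hypothetical; the real implementation lives in metatensor.models.utils.data.combine_dataloaders):

import random
from typing import List

import torch


class CombinedDataLoader:
    """Iterate over batches from several dataloaders, optionally shuffling
    the order in which the batches are visited. Each batch still comes from
    a single underlying dataloader, so its structure stays homogeneous."""

    def __init__(self, dataloaders: List[torch.utils.data.DataLoader], shuffle: bool):
        self.dataloaders = dataloaders
        self.shuffle = shuffle

    def __iter__(self):
        batches = [batch for loader in self.dataloaders for batch in loader]
        if self.shuffle:
            random.shuffle(batches)
        return iter(batches)

    def __len__(self):
        return sum(len(loader) for loader in self.dataloaders)


def combine_dataloaders(dataloaders, shuffle: bool = True):
    return CombinedDataLoader(dataloaders, shuffle)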
1 change: 1 addition & 0 deletions docs/src/dev-docs/utils/index.rst
@@ -11,3 +11,4 @@ This is the API for the ``utils`` module of ``metatensor-models``.
writers
model-io
omegaconf
combine_dataloaders
39 changes: 26 additions & 13 deletions src/metatensor/models/cli/train_model.py
@@ -6,12 +6,14 @@

import hydra
import torch
from metatensor.torch.atomistic import ModelCapabilities, ModelOutput
from omegaconf import DictConfig, OmegaConf

from metatensor.models.utils.data import Dataset
from metatensor.models.utils.data.readers import read_structures, read_targets

from .. import CONFIG_PATH
from ..utils.data import get_all_species
from ..utils.model_io import save_model
from ..utils.omegaconf import expand_dataset_config
from .formatter import CustomHelpFormatter
@@ -174,19 +176,30 @@ def train_model(options: DictConfig) -> None:
logger.info("Run training")
output_dir = hydra.core.hydra_config.HydraConfig.get().runtime.output_dir

# HACK: Avoid passing a Subset which we can not handle yet. For now we pass
# the complete training set even though it was split before...
if isinstance(train_dataset, torch.utils.data.Subset):
model = architecture.train(
train_dataset=train_dataset.dataset,
hypers=OmegaConf.to_container(options["architecture"]),
output_dir=output_dir,
)
else:
model = architecture.train(
train_dataset=train_dataset,
hypers=OmegaConf.to_container(options["architecture"]),
output_dir=output_dir,
all_species = []
for dataset in [train_dataset]: # HACK: only a single train_dataset for now
all_species += get_all_species(dataset)
all_species = list(set(all_species))

outputs = {
key: ModelOutput(
quantity=value["quantity"],
unit=(value["unit"] if value["unit"] is not None else ""), # potential HACK
)
for key, value in options["training_set"]["targets"].items()
}
model_capabilities = ModelCapabilities(
length_unit="Angstrom",
species=all_species,
outputs=outputs,
)

model = architecture.train(
train_datasets=[train_dataset],
validation_datasets=[validation_dataset],
model_capabilities=model_capabilities,
hypers=OmegaConf.to_container(options["architecture"]),
output_dir=output_dir,
)

save_model(model, options["output_path"])
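For concreteness, here is what the capabilities block above builds for a training set with a single U0 energy target. This is a hedged sketch: the species and units are taken from the QM9-style regression test below, and the exact layout of options["training_set"]["targets"] is assumed.

from metatensor.torch.atomistic import ModelCapabilities, ModelOutput

# Hypothetical single-target example mirroring the block above, assuming
# options["training_set"]["targets"] == {"U0": {"quantity": "energy", "unit": "eV", ...}}
outputs = {"U0": ModelOutput(quantity="energy", unit="eV")}
model_capabilities = ModelCapabilities(
    length_unit="Angstrom",
    species=[1, 6, 7, 8],  # union of the species over all training datasets
    outputs=outputs,
)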
36 changes: 27 additions & 9 deletions src/metatensor/models/soap_bpnn/tests/test_regression.py
@@ -1,16 +1,22 @@
import random

import ase.io
import numpy as np
import rascaline.torch
import torch
from metatensor.torch.atomistic import ModelCapabilities, ModelOutput
from omegaconf import OmegaConf

from metatensor.models.soap_bpnn import DEFAULT_HYPERS, Model, train
from metatensor.models.utils.data import Dataset
from metatensor.models.utils.data import Dataset, get_all_species
from metatensor.models.utils.data.readers import read_structures, read_targets

from . import DATASET_PATH


# reproducibility
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)


@@ -21,7 +27,7 @@ def test_regression_init():
length_unit="Angstrom",
species=[1, 6, 7, 8],
outputs={
"energy": ModelOutput(
"U0": ModelOutput(
quantity="energy",
unit="eV",
)
@@ -33,14 +39,15 @@
structures = ase.io.read(DATASET_PATH, ":5")

output = soap_bpnn(
[rascaline.torch.systems_to_torch(structure) for structure in structures]
[rascaline.torch.systems_to_torch(structure) for structure in structures],
["U0"],
)
expected_output = torch.tensor(
[[-0.4615], [-0.4367], [-0.3004], [-0.2606], [-0.2380]],
dtype=torch.float64,
)

assert torch.allclose(output["energy"].block().values, expected_output, rtol=1e-3)
assert torch.allclose(output["U0"].block().values, expected_output, rtol=1e-3)


def test_regression_train():
@@ -50,7 +57,7 @@
structures = read_structures(DATASET_PATH)

conf = {
"energy": {
"U0": {
"quantity": "energy",
"read_from": DATASET_PATH,
"file_format": ".xyz",
@@ -66,14 +73,25 @@

hypers = DEFAULT_HYPERS.copy()
hypers["training"]["num_epochs"] = 2
soap_bpnn = train(dataset, hypers)

capabilities = ModelCapabilities(
length_unit="Angstrom",
species=get_all_species(dataset),
outputs={
"U0": ModelOutput(
quantity="energy",
unit="eV",
)
},
)
soap_bpnn = train([dataset], [dataset], capabilities, hypers)

# Predict on the first five structures
output = soap_bpnn(structures[:5])
output = soap_bpnn(structures[:5], ["U0"])

expected_output = torch.tensor(
[[-39.9658], [-56.0888], [-76.1100], [-76.9461], [-93.0914]],
[[-40.1358], [-56.1721], [-76.1576], [-77.1174], [-93.1679]],
dtype=torch.float64,
)

assert torch.allclose(output["energy"].block().values, expected_output, rtol=1e-3)
assert torch.allclose(output["U0"].block().values, expected_output, rtol=1e-3)
182 changes: 144 additions & 38 deletions src/metatensor/models/soap_bpnn/train.py
@@ -1,12 +1,19 @@
import logging
from pathlib import Path
from typing import Dict, List, Union

import torch
from metatensor.torch.atomistic import ModelCapabilities, ModelOutput
from metatensor.torch.atomistic import ModelCapabilities

from ..utils.composition import calculate_composition_weights
from ..utils.compute_loss import compute_model_loss
from ..utils.data import collate_fn
from ..utils.data import (
Dataset,
check_datasets,
collate_fn,
combine_dataloaders,
get_all_targets,
)
from ..utils.loss import TensorMapDictLoss
from ..utils.model_io import save_model
from .model import DEFAULT_HYPERS, Model
@@ -15,25 +22,19 @@
logger = logging.getLogger(__name__)


def train(train_dataset, hypers=DEFAULT_HYPERS, output_dir="."):
if len(train_dataset.targets) > 1:
raise ValueError(
f"`train_dataset` contains {len(train_dataset.targets)} targets but we "
"currently only support a single target value!"
)
else:
target_name = list(train_dataset.targets.keys())[0]

# Set the model's capabilities:
model_capabilities = ModelCapabilities(
length_unit="Angstrom",
species=train_dataset.all_species,
outputs={
target_name: ModelOutput(
quantity="energy",
unit="eV",
)
},
def train(
train_datasets: List[Union[Dataset, torch.utils.data.Subset]],
validation_datasets: List[Union[Dataset, torch.utils.data.Subset]],
model_capabilities: ModelCapabilities,
hypers: Dict = DEFAULT_HYPERS,
output_dir: str = ".",
):
# Perform canonical checks on the datasets:
logger.info("Checking datasets for consistency")
check_datasets(
train_datasets,
validation_datasets,
model_capabilities,
)

# Create the model:
@@ -42,19 +43,69 @@ def train(train_dataset, hypers=DEFAULT_HYPERS, output_dir="."):
hypers=hypers["model"],
)

# Calculate and set the composition weights:
composition_weights = calculate_composition_weights(train_dataset, target_name)
model.set_composition_weights(target_name, composition_weights)
# Calculate and set the composition weights for all targets:
for target_name in model_capabilities.outputs.keys():
# find the dataset that contains the target:
train_dataset_with_target = None
for dataset in train_datasets:
if target_name in get_all_targets(dataset):
train_dataset_with_target = dataset
break
if train_dataset_with_target is None:
raise ValueError(
f"Target {target_name} in the model's capabilities is not "
"present in any of the training datasets."
)
composition_weights = calculate_composition_weights(
train_dataset_with_target, target_name
)
model.set_composition_weights(target_name, composition_weights)

hypers_training = hypers["training"]

# Create a dataloader for the training dataset:
train_dataloader = torch.utils.data.DataLoader(
dataset=train_dataset,
batch_size=hypers_training["batch_size"],
shuffle=True,
collate_fn=collate_fn,
)
# Create dataloader for the training datasets:
train_dataloaders = []
for dataset in train_datasets:
train_dataloaders.append(
torch.utils.data.DataLoader(
dataset=dataset,
batch_size=hypers_training["batch_size"],
shuffle=True,
collate_fn=collate_fn,
)
)
train_dataloader = combine_dataloaders(train_dataloaders, shuffle=True)

# Create dataloader for the validation datasets:
validation_dataloaders = []
for dataset in validation_datasets:
validation_dataloaders.append(
torch.utils.data.DataLoader(
dataset=dataset,
batch_size=hypers_training["batch_size"],
shuffle=False,
collate_fn=collate_fn,
)
)
validation_dataloader = combine_dataloaders(validation_dataloaders, shuffle=False)

# Extract all the possible outputs and their gradients from the training set:
outputs_dict = _get_outputs_dict(train_datasets)
for output_name in outputs_dict.keys():
if output_name not in model_capabilities.outputs:
raise ValueError(
f"Output {output_name} is not in the model's capabilities."
)

# Create a loss weight dict:
loss_weights_dict = {}
for output_name, value_or_gradient_list in outputs_dict.items():
loss_weights_dict[output_name] = {
value_or_gradient: 1.0 for value_or_gradient in value_or_gradient_list
}

# Create a loss function:
loss_fn = TensorMapDictLoss(loss_weights_dict)

@@ -66,20 +117,75 @@ def train(train_dataset, hypers=DEFAULT_HYPERS, output_dir="."):
model.parameters(), lr=hypers_training["learning_rate"]
)

# counters for early stopping:
best_validation_loss = float("inf")
epochs_without_improvement = 0

# Train the model:
for epoch in range(hypers_training["num_epochs"]):
if epoch % hypers_training["log_interval"] == 0:
logger.info(f"Epoch {epoch}")
if epoch % hypers_training["checkpoint_interval"] == 0:
save_model(
model,
Path(output_dir) / f"model_{epoch}.pt",
)
train_loss = 0.0
for batch in train_dataloader:
optimizer.zero_grad()
structures, targets = batch
loss = compute_model_loss(loss_fn, model, structures, targets)
train_loss += loss.item()
loss.backward()
optimizer.step()

validation_loss = 0.0
for batch in validation_dataloader:
structures, targets = batch
# TODO: specify that the model is not training here to save some autograd
loss = compute_model_loss(loss_fn, model, structures, targets)
validation_loss += loss.item()

if epoch % hypers_training["log_interval"] == 0:
logger.info(
f"Epoch {epoch}, train loss: {train_loss:.4f}, "
f"validation loss: {validation_loss:.4f}"
)

if epoch % hypers_training["checkpoint_interval"] == 0:
save_model(
model,
Path(output_dir) / f"model_{epoch}.pt",
)

# early stopping criterion:
if validation_loss < best_validation_loss:
best_validation_loss = validation_loss
epochs_without_improvement = 0
else:
epochs_without_improvement += 1
if epochs_without_improvement >= 50:
logger.info(
f"Early stopping criterion reached after {epoch} "
"epochs without improvement."
)
break

return model


def _get_outputs_dict(datasets: List[Dataset]):
"""
This is a helper function that extracts all the possible outputs and their gradients
from a list of datasets.

:param datasets: A list of datasets.

:returns: A dictionary mapping output names to a list of "values" (always)
and possible gradients.
"""

outputs_dict = {}
for dataset in datasets:
sample_batch = next(iter(dataset))
targets = sample_batch[1] # this is a dictionary of TensorMaps
for target_name, target_tmap in targets.items():
if target_name not in outputs_dict:
outputs_dict[target_name] = [
"values"
] + target_tmap.block().gradients_list()

return outputs_dict
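As a worked example of how _get_outputs_dict feeds the loss construction above (the U0 target and its position gradients are hypothetical, but match the forces support added in this PR):

# One training dataset whose "U0" energy target carries position gradients:
outputs_dict = _get_outputs_dict(train_datasets)
# -> {"U0": ["values", "positions"]}

# Every value/gradient contribution is then weighted equally in the loss:
loss_weights_dict = {
    output_name: {key: 1.0 for key in keys}
    for output_name, keys in outputs_dict.items()
}
# -> {"U0": {"values": 1.0, "positions": 1.0}}
loss_fn = TensorMapDictLoss(loss_weights_dict)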