Added option to continue training from a checkpoint #49

Merged: 8 commits, merged Feb 6, 2024. Changes shown from 5 commits.
15 changes: 13 additions & 2 deletions src/metatensor/models/cli/train_model.py
@@ -67,6 +67,14 @@ def _add_train_model_parser(subparser: argparse._SubParsersAction) -> None:
default="model.pt",
help="Path to save the final model (default: %(default)s).",
)
parser.add_argument(
"-c",
"--continue",
dest="continue_from",
type=str,
required=False,
Contributor Author: Maybe adding a default=None here changes it from a string.

help="File to continue training from.",
)
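On the default question above: a minimal standalone sketch (not part of this diff; the checkpoint filename is made up) showing that when no explicit default is given, argparse already yields a real None for an omitted option, so the value only becomes text once it is interpolated into the Hydra override below.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("-c", "--continue", dest="continue_from", type=str, required=False)

args = parser.parse_args([])
print(args.continue_from)  # None -- with no explicit default, argparse falls back to None

args = parser.parse_args(["--continue", "checkpoint.pt"])
print(args.continue_from)  # 'checkpoint.pt'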
parser.add_argument(
"-y",
"--hydra",
@@ -80,6 +88,7 @@ def _add_train_model_parser(subparser: argparse._SubParsersAction) -> None:
def train_model(
options: str,
output: str = "model.pt",
continue_from: Optional[str] = None,
hydra_parameters: Optional[List[str]] = None,
) -> None:
"""
@@ -125,6 +134,7 @@ def train_model(
argv.append(f"--config-dir={options_new.parent}")
argv.append(f"--config-name={options_new.name}")
argv.append(f"+output_path={output}")
argv.append(f"+continue_from={continue_from}")

if hydra_parameters is not None:
argv += hydra_parameters
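As an aside on why train() later compares against the string "None": a small illustration (not part of the diff) of what the f-string override produces when continue_from is a real None. Hydra's override grammar keeps the unquoted token None as a string (only the token null would become a real null), so the config value arrives in _train_model_hydra as the text "None".

continue_from = None
override = f"+continue_from={continue_from}"
print(override)  # +continue_from=None  (the text "None", not a null value)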
@@ -249,7 +259,7 @@ def _train_model_hydra(options: DictConfig) -> None:
for key, value in options["training_set"]["targets"].items()
}
length_unit = train_options["structures"]["length_unit"]
model_capabilities = ModelCapabilities(
requested_capabilities = ModelCapabilities(
length_unit=length_unit if length_unit is not None else "",
species=all_species,
outputs=outputs,
Expand All @@ -259,8 +269,9 @@ def _train_model_hydra(options: DictConfig) -> None:
model = architecture.train(
train_datasets=[train_dataset],
validation_datasets=[validation_dataset],
model_capabilities=model_capabilities,
requested_capabilities=requested_capabilities,
hypers=OmegaConf.to_container(options["architecture"]),
continue_from=options["continue_from"],
output_dir=output_dir,
)

29 changes: 27 additions & 2 deletions src/metatensor/models/soap_bpnn/model.py
@@ -305,7 +305,9 @@ def forward(
if output_name in outputs:
atomic_energies[output_name] = apply_composition_contribution(
output_layer(hidden_features),
self.composition_weights[self.output_to_index[output_name]],
self.composition_weights[ # type: ignore
self.output_to_index[output_name]
],
)

# Sum the atomic energies coming from the BPNN to get the total energy
@@ -331,6 +333,29 @@ def set_composition_weights(
) -> None:
"""Set the composition weights for a given output."""
# all species that are not present retain their weight of zero
self.composition_weights[self.output_to_index[output_name]][
self.composition_weights[self.output_to_index[output_name]][ # type: ignore
self.all_species
] = input_composition_weights
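A toy, self-contained illustration (values and sizes are made up) of the indexed assignment above: only the columns listed in all_species are overwritten, every other species keeps its zero weight.

import torch

composition_weights = torch.zeros(1, 9)            # one output, toy number of species slots
all_species = [1, 6, 8]                             # e.g. H, C, O
input_composition_weights = torch.tensor([-0.5, -38.1, -75.0])

composition_weights[0][all_species] = input_composition_weights
# columns 1, 6 and 8 now hold the new weights; all other columns stay at zero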

def add_output(self, output_name: str) -> None:
"""Add a new output to the model."""
# add a new row to the composition weights tensor
Comment on lines +341 to +342
Contributor Author: Shouldn't both be in the docstring?
Collaborator: Both?
self.composition_weights = torch.cat(
[
self.composition_weights, # type: ignore
torch.zeros(
1,
self.composition_weights.shape[1], # type: ignore
dtype=self.composition_weights.dtype, # type: ignore
device=self.composition_weights.device, # type: ignore
),
]
) # type: ignore
self.output_to_index[output_name] = len(self.output_to_index)
# add a new linear layer to the last layers
hypers_bpnn = self.hypers["bpnn"]
if hypers_bpnn["num_hidden_layers"] == 0:
n_inputs_last_layer = hypers_bpnn["input_size"]
else:
n_inputs_last_layer = hypers_bpnn["num_neurons_per_layer"]
self.last_layers[output_name] = LinearMap(self.all_species, n_inputs_last_layer)
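A generic sketch of the bookkeeping that add_output() performs, written with plain torch objects rather than this Model class (the names, sizes, and the "dipole" target are hypothetical):

import torch

composition_weights = torch.zeros(2, 119)                  # two existing outputs
output_to_index = {"energy": 0, "formation_energy": 1}

# grow the composition-weight matrix by one zero row for the new output
composition_weights = torch.cat(
    [composition_weights, torch.zeros(1, composition_weights.shape[1])]
)
output_to_index["dipole"] = len(output_to_index)           # -> index 2

# register a fresh output head for the new target
last_layers = torch.nn.ModuleDict({"dipole": torch.nn.Linear(64, 1)})
print(composition_weights.shape, output_to_index["dipole"])  # torch.Size([3, 119]) 2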
58 changes: 41 additions & 17 deletions src/metatensor/models/soap_bpnn/train.py
@@ -19,7 +19,8 @@
from ..utils.info import finalize_aggregated_info, update_aggregated_info
from ..utils.logging import MetricLogger
from ..utils.loss import TensorMapDictLoss
from ..utils.model_io import save_model
from ..utils.merge_capabilities import merge_capabilities
from ..utils.model_io import load_model, save_model
from .model import DEFAULT_HYPERS, Model


@@ -35,40 +36,63 @@
def train(
train_datasets: List[Union[Dataset, torch.utils.data.Subset]],
validation_datasets: List[Union[Dataset, torch.utils.data.Subset]],
model_capabilities: ModelCapabilities,
requested_capabilities: ModelCapabilities,
hypers: Dict = DEFAULT_HYPERS,
continue_from: str = "None",
output_dir: str = ".",
):
# Perform canonical checks on the datasets:
# Create the model:
if continue_from == "None":
model = Model(
capabilities=requested_capabilities,
hypers=hypers["model"],
)
new_capabilities = requested_capabilities
else:
model = load_model(continue_from)
filtered_new_dict = {k: v for k, v in hypers["model"].items() if k != "restart"}
filtered_old_dict = {k: v for k, v in model.hypers.items() if k != "restart"}
if filtered_new_dict != filtered_old_dict:
logger.warn(
"The hyperparameters of the model have changed since the last "
"training run. The new hyperparameters will be discarded."
)
# merge the model's capabilities with the requested capabilities
merged_capabilities, new_capabilities = merge_capabilities(
model.capabilities, requested_capabilities
)
model.capabilities = merged_capabilities
# make the new model capable of handling the new outputs
for output_name in new_capabilities.outputs.keys():
model.add_output(output_name)

model_capabilities = model.capabilities

# Perform checks on the datasets:
logger.info("Checking datasets for consistency")
check_datasets(
train_datasets,
validation_datasets,
model_capabilities,
)

# Create the model:
model = Model(
capabilities=model_capabilities,
hypers=hypers["model"],
)

# Calculate and set the composition weights for all targets:
logger.info("Calculating composition weights")
for target_name in model_capabilities.outputs.keys():
# find the dataset that contains the target:
train_dataset_with_target = None
for target_name in new_capabilities.outputs.keys():
# TODO: warn in the documentation that capabilities that are already
# present in the model won't recalculate the composition weights
# find the datasets that contain the target:
train_datasets_with_target = []
for dataset in train_datasets:
if target_name in get_all_targets(dataset):
train_dataset_with_target = dataset
break
if train_dataset_with_target is None:
train_datasets_with_target.append(dataset)
if len(train_datasets_with_target) == 0:
raise ValueError(
f"Target {target_name} in the model's capabilities is not "
f"Target {target_name} in the model's new capabilities is not "
"present in any of the training datasets."
)
composition_weights = calculate_composition_weights(
train_dataset_with_target, target_name
train_datasets_with_target, target_name
)
model.set_composition_weights(target_name, composition_weights)

12 changes: 9 additions & 3 deletions src/metatensor/models/utils/composition.py
@@ -6,7 +6,7 @@


def calculate_composition_weights(
dataset: torch.utils.data.Dataset, property: str
datasets: List[torch.utils.data.Dataset], property: str
) -> torch.Tensor:
"""Calculate the composition weights for a dataset.
For now, it assumes per-structure properties.
@@ -23,12 +23,18 @@ def calculate_composition_weights(
"""

# Get the target for each structure in the dataset
targets = torch.stack([sample[1][property].block().values for sample in dataset])
targets = torch.stack(
[
sample[1][property].block().values
for dataset in datasets
for sample in dataset
]
)

# Get the composition for each structure in the dataset
composition_calculator = rascaline.torch.AtomicComposition(per_structure=True)
composition_features = composition_calculator.compute(
[sample[0] for sample in dataset]
[sample[0] for dataset in datasets for sample in dataset]
)
composition_features = composition_features.keys_to_properties("species_center")
composition_features = composition_features.block().values
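The rest of this function is not shown in the diff; as background only, a self-contained toy example of what per-species composition weights are — a least-squares baseline mapping atom counts to a per-structure target (toy numbers, not the repository's implementation):

import torch

# rows: structures, columns: number of atoms of each species (here H and O)
composition = torch.tensor([[2.0, 1.0],   # H2O
                            [2.0, 2.0],   # H2O2
                            [4.0, 2.0]])  # two H2O
energies = torch.tensor([[-80.0], [-156.0], [-160.0]])

# solve composition @ weights ≈ energies in the least-squares sense
weights = torch.linalg.lstsq(composition, energies).solution
print(weights.squeeze())  # per-species baselines, here exactly [-2., -76.]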
61 changes: 31 additions & 30 deletions src/metatensor/models/utils/data/dataset.py
@@ -1,3 +1,4 @@
import logging
import os
from typing import Dict, List, Tuple

@@ -8,6 +9,9 @@
from metatensor.torch.atomistic import ModelCapabilities, System


logger = logging.getLogger(__name__)


if os.environ.get("METATENSOR_IMPORT_FOR_SPHINX", "0") == "1":
# This is necessary to make the Sphinx documentation build
def compiled_slice(a, b):
@@ -177,56 +181,53 @@ def check_datasets(
:param capabilities: The model's capabilities.

:raises ValueError: If the training and validation sets are not compatible
with one another or with the model's capabilities.
with the model's capabilities.
"""

# Get all targets in the training sets:
targets = []
# Get all targets in the training and validation sets:
train_targets = []
for dataset in train_datasets:
targets += get_all_targets(dataset)
train_targets += get_all_targets(dataset)
validation_targets = []
for dataset in validation_datasets:
validation_targets += get_all_targets(dataset)

# Check that they are compatible with the model's capabilities:
for target in targets:
for target in train_targets + validation_targets:
if target not in capabilities.outputs.keys():
raise ValueError(f"The target {target} is not in the model's capabilities.")

# For now, we impose no overlap between the targets in the training sets:
if len(set(targets)) != len(targets):
raise ValueError(
"The training datasets must not have overlapping targets in SOAP-BPNN. "
"This means that one target cannot be in more than one dataset."
)

# Check that the validation sets do not have targets that are not in the
# training sets:
for dataset in validation_datasets:
for target in get_all_targets(dataset):
if target not in targets:
raise ValueError(
f"The validation dataset has a target ({target}) "
"that is not in the training datasets."
)
for target in validation_targets:
if target not in train_targets:
logger.warn(
f"The validation dataset has a target ({target}) "
"that is not in the training dataset."
)

# Get all the species in the training sets:
# Get all the species in the training and validation sets:
all_training_species = []
for dataset in train_datasets:
all_training_species += get_all_species(dataset)
all_validation_species = []
for dataset in validation_datasets:
all_validation_species += get_all_species(dataset)

# Check that they are compatible with the model's capabilities:
for species in all_training_species:
for species in all_training_species + all_validation_species:
if species not in capabilities.species:
raise ValueError(
f"The species {species} is not in the model's capabilities."
)

# Check that the validation sets do not have species that are not in the
# training sets:
for dataset in validation_datasets:
for species in get_all_species(dataset):
if species not in all_training_species:
raise ValueError(
f"The validation dataset has a species ({species}) "
"that is not in the training datasets. This could be "
"a result of a random train/validation split. You can "
"avoid this by providing a validation dataset manually."
)
for species in all_validation_species:
if species not in all_training_species:
logger.warn(
f"The validation dataset has a species ({species}) "
"that is not in the training dataset. This could be "
"a result of a random train/validation split. You can "
"avoid this by providing a validation dataset manually."
)
63 changes: 63 additions & 0 deletions src/metatensor/models/utils/merge_capabilities.py
@@ -0,0 +1,63 @@
from typing import Tuple

from metatensor.torch.atomistic import ModelCapabilities


def merge_capabilities(
Contributor Author: Should there be a test for this function?
Collaborator: Yes, my bad.

old_capabilities: ModelCapabilities, requested_capabilities: ModelCapabilities
) -> Tuple[ModelCapabilities, ModelCapabilities]:
"""
Merge the capabilities of a model with the requested capabilities.

:param old_capabilities: The old capabilities of the model.
:param requested_capabilities: The requested capabilities.

:return: The merged capabilities and the new capabilities that
were not present in the old capabilities. The order will
be preserved.
"""
# Check that the length units are the same:
if old_capabilities.length_unit != requested_capabilities.length_unit:
raise ValueError(
"The length units of the old and new capabilities are not the same."
)

# Check that there are no new species:
for species in requested_capabilities.species:
if species not in old_capabilities.species:
raise ValueError(
f"The species {species} is not within "
"the capabilities of the loaded model."
)

# Merge the outputs:
outputs = {}
for key, value in old_capabilities.outputs.items():
outputs[key] = value
for key, value in requested_capabilities.outputs.items():
if key not in outputs:
outputs[key] = value
else:
assert (
outputs[key].unit == value.unit
), f"Output {key} has different units in the old and new capabilities."

# Find the new outputs:
new_outputs = {}
for key, value in requested_capabilities.outputs.items():
if key not in old_capabilities.outputs:
new_outputs[key] = value

merged_capabilities = ModelCapabilities(
length_unit=requested_capabilities.length_unit,
species=old_capabilities.species,
outputs=outputs,
)

new_capabilities = ModelCapabilities(
length_unit=requested_capabilities.length_unit,
species=old_capabilities.species,
outputs=new_outputs,
)

return merged_capabilities, new_capabilities
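Following up on the test question in the review: a minimal pytest-style sketch of what such a test could check (the ModelOutput constructor arguments and the output names are assumptions, not taken from this diff):

from metatensor.torch.atomistic import ModelCapabilities, ModelOutput

from metatensor.models.utils.merge_capabilities import merge_capabilities


def test_merge_capabilities():
    old = ModelCapabilities(
        length_unit="angstrom",
        species=[1, 6, 8],
        outputs={"energy": ModelOutput(quantity="energy", unit="eV")},
    )
    requested = ModelCapabilities(
        length_unit="angstrom",
        species=[1, 6],
        outputs={
            "energy": ModelOutput(quantity="energy", unit="eV"),
            "formation_energy": ModelOutput(quantity="energy", unit="eV"),
        },
    )

    merged, new = merge_capabilities(old, requested)

    # old outputs are kept and the requested extra output is added
    assert set(merged.outputs.keys()) == {"energy", "formation_energy"}
    # only the genuinely new output is reported back
    assert set(new.outputs.keys()) == {"formation_energy"}
    # the species list of the loaded model is preserved
    assert merged.species == [1, 6, 8]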