diff --git a/.readthedocs.yml b/.readthedocs.yml index 2ebedc6ba..dc12989c7 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -15,6 +15,7 @@ build: pre_build: - set -e && cd examples/ase && bash train.sh - set -e && cd examples/programmatic/llpr && bash train.sh + - set -e && cd examples/zbl && bash train.sh # Build documentation in the docs/ directory with Sphinx sphinx: diff --git a/docs/generate_examples/conf.py b/docs/generate_examples/conf.py index 79715a5ea..40448f3e0 100644 --- a/docs/generate_examples/conf.py +++ b/docs/generate_examples/conf.py @@ -13,8 +13,16 @@ sphinx_gallery_conf = { "filename_pattern": "/*", "copyfile_regex": r".*\.(pt|sh|xyz|yaml)", - "examples_dirs": [os.path.join(ROOT, "examples", "ase"), os.path.join(ROOT, "examples", "programmatic", "llpr")], - "gallery_dirs": [os.path.join(ROOT, "docs", "src", "examples", "ase"), os.path.join(ROOT, "docs", "src", "examples", "programmatic", "llpr")], + "examples_dirs": [ + os.path.join(ROOT, "examples", "ase"), + os.path.join(ROOT, "examples", "programmatic", "llpr"), + os.path.join(ROOT, "examples", "zbl") + ], + "gallery_dirs": [ + os.path.join(ROOT, "docs", "src", "examples", "ase"), + os.path.join(ROOT, "docs", "src", "examples", "programmatic", "llpr"), + os.path.join(ROOT, "docs", "src", "examples", "zbl") + ], "min_reported_time": 5, "matplotlib_animations": True, } diff --git a/docs/src/conf.py b/docs/src/conf.py index c9de1502a..d367c9eb8 100644 --- a/docs/src/conf.py +++ b/docs/src/conf.py @@ -10,6 +10,7 @@ # to include the documentation os.environ["METATENSOR_IMPORT_FOR_SPHINX"] = "1" os.environ["RASCALINE_IMPORT_FOR_SPHINX"] = "1" +os.environ["PYTORCH_JIT"] = "0" import metatrain # noqa: E402 @@ -53,9 +54,11 @@ def generate_examples(): # METATENSOR_IMPORT_FOR_SPHINX=1). So instead we run it inside a small script, and # include the corresponding output later. del os.environ["METATENSOR_IMPORT_FOR_SPHINX"] + del os.environ["PYTORCH_JIT"] script = os.path.join(ROOT, "docs", "generate_examples", "generate-examples.py") subprocess.run([sys.executable, script]) os.environ["METATENSOR_IMPORT_FOR_SPHINX"] = "1" + os.environ["PYTORCH_JIT"] = "0" def setup(app): diff --git a/docs/src/dev-docs/utils/additive/composition.rst b/docs/src/dev-docs/utils/additive/composition.rst new file mode 100644 index 000000000..4499f0c97 --- /dev/null +++ b/docs/src/dev-docs/utils/additive/composition.rst @@ -0,0 +1,7 @@ +Composition model +################# + +.. automodule:: metatrain.utils.additive.composition + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/src/dev-docs/utils/additive/index.rst b/docs/src/dev-docs/utils/additive/index.rst new file mode 100644 index 000000000..f2aed2d26 --- /dev/null +++ b/docs/src/dev-docs/utils/additive/index.rst @@ -0,0 +1,12 @@ +Additive models +=============== + +API for handling additive models in ``metatrain``. These are models that +can be added to one or more architectures. + +.. toctree:: + :maxdepth: 1 + + remove_additive + composition + zbl diff --git a/docs/src/dev-docs/utils/additive/remove_additive.rst b/docs/src/dev-docs/utils/additive/remove_additive.rst new file mode 100644 index 000000000..6a115a471 --- /dev/null +++ b/docs/src/dev-docs/utils/additive/remove_additive.rst @@ -0,0 +1,7 @@ +Removing additive contributions +############################### + +.. automodule:: metatrain.utils.additive.remove + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/src/dev-docs/utils/additive/zbl.rst b/docs/src/dev-docs/utils/additive/zbl.rst new file mode 100644 index 000000000..ab0248bde --- /dev/null +++ b/docs/src/dev-docs/utils/additive/zbl.rst @@ -0,0 +1,7 @@ +ZBL short-range potential +######################### + +.. automodule:: metatrain.utils.additive.zbl + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/src/dev-docs/utils/composition.rst b/docs/src/dev-docs/utils/composition.rst deleted file mode 100644 index 0a6cb2a34..000000000 --- a/docs/src/dev-docs/utils/composition.rst +++ /dev/null @@ -1,7 +0,0 @@ -Composition -########### - -.. automodule:: metatrain.utils.composition - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/src/dev-docs/utils/index.rst b/docs/src/dev-docs/utils/index.rst index 3a2d12d07..0f5cf8abc 100644 --- a/docs/src/dev-docs/utils/index.rst +++ b/docs/src/dev-docs/utils/index.rst @@ -6,9 +6,9 @@ This is the API for the ``utils`` module of ``metatrain``. .. toctree:: :maxdepth: 1 + additive/index data/index architectures - composition devices dtype errors @@ -24,4 +24,5 @@ This is the API for the ``utils`` module of ``metatrain``. omegaconf output_gradient per_atom + transfer units diff --git a/docs/src/dev-docs/utils/transfer.rst b/docs/src/dev-docs/utils/transfer.rst new file mode 100644 index 000000000..ba9710296 --- /dev/null +++ b/docs/src/dev-docs/utils/transfer.rst @@ -0,0 +1,7 @@ +Data type and device transfers +############################## + +.. automodule:: metatrain.utils.transfer + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/src/tutorials/index.rst b/docs/src/tutorials/index.rst index 3414fda6c..c8b0f5e4c 100644 --- a/docs/src/tutorials/index.rst +++ b/docs/src/tutorials/index.rst @@ -11,3 +11,4 @@ This sections includes some more advanced tutorials on the usage of the ../examples/ase/run_ase ../examples/programmatic/llpr/llpr + ../examples/zbl/dimers diff --git a/examples/ase/run_ase.py b/examples/ase/run_ase.py index 18a172d34..030f40dca 100644 --- a/examples/ase/run_ase.py +++ b/examples/ase/run_ase.py @@ -42,11 +42,6 @@ # %% # -# .. note:: -# We have to import ``rascaline.torch`` even though it is not used explicitly in this -# tutorial. The SOAP-BPNN model contains compiled extensions and therefore the import -# is required. -# # Setting up the simulation # ------------------------- # diff --git a/examples/programmatic/llpr/llpr.py b/examples/programmatic/llpr/llpr.py index 8db135c86..10857aaaf 100644 --- a/examples/programmatic/llpr/llpr.py +++ b/examples/programmatic/llpr/llpr.py @@ -48,7 +48,10 @@ # how to create a Dataset object from them. from metatrain.utils.data import Dataset, read_systems, read_targets # noqa: E402 -from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists # noqa: E402 +from metatrain.utils.neighbor_lists import ( # noqa: E402 + get_requested_neighbor_lists, + get_system_with_neighbor_lists, +) qm9_systems = read_systems("qm9_reduced_100.xyz") @@ -67,7 +70,7 @@ } targets, _ = read_targets(target_config) -requested_neighbor_lists = model.requested_neighbor_lists() +requested_neighbor_lists = get_requested_neighbor_lists(model) qm9_systems = [ get_system_with_neighbor_lists(system, requested_neighbor_lists) for system in qm9_systems diff --git a/examples/zbl/README.rst b/examples/zbl/README.rst new file mode 100644 index 000000000..50f7bdb41 --- /dev/null +++ b/examples/zbl/README.rst @@ -0,0 +1,2 @@ +Running molecular dynamics with ASE +=================================== diff --git a/examples/zbl/dimers.py b/examples/zbl/dimers.py new file mode 100644 index 000000000..04069a5a5 --- /dev/null +++ b/examples/zbl/dimers.py @@ -0,0 +1,148 @@ +""" +Training a model with ZBL corrections +===================================== + +This tutorial demonstrates how to train a model with ZBL corrections. + +The training set for this example consists of a +subset of the ethanol moleculs from the `rMD17 dataset +`_. + +The models are trained using the following training options, respectively: + +.. literalinclude:: options_no_zbl.yaml + :language: yaml + +.. literalinclude:: options_zbl.yaml + :language: yaml + +As you can see, they are identical, except for the ``zbl`` key in the +``model`` section. +You can train the same models yourself with + +.. literalinclude:: train.sh + :language: bash + +A detailed step-by-step introduction on how to train a model is provided in +the :ref:`label_basic_usage` tutorial. +""" + +# %% +# +# First, we start by importing the necessary libraries, including the integration of ASE +# calculators for metatensor atomistic models. + +import ase +import matplotlib.pyplot as plt +import numpy as np +import torch +from metatensor.torch.atomistic.ase_calculator import MetatensorCalculator + + +# %% +# +# Setting up the dimers +# --------------------- +# +# We set up a series of dimers with different atom pairs and distances. We will +# calculate the energies of these dimers using the models trained with and without ZBL +# corrections. + +distances = np.linspace(0.5, 6.0, 200) +pairs = {} +for pair in [("H", "H"), ("H", "C"), ("C", "C"), ("C", "O"), ("O", "O"), ("H", "O")]: + structures = [] + for distance in distances: + atoms = ase.Atoms( + symbols=[pair[0], pair[1]], + positions=[[0, 0, 0], [0, 0, distance]], + ) + structures.append(atoms) + pairs[pair] = structures + +# %% +# +# We now load the two exported models, one with and one without ZBL corrections + +calc_no_zbl = MetatensorCalculator( + "model_no_zbl.pt", extensions_directory="extensions/" +) +calc_zbl = MetatensorCalculator("model_zbl.pt", extensions_directory="extensions/") + + +# %% +# +# Calculate and plot energies without ZBL +# --------------------------------------- +# +# We calculate the energies of the dimer curves for each pair of atoms and +# plot the results, using the non-ZBL-corrected model. + +for pair, structures_for_pair in pairs.items(): + energies = [] + for atoms in structures_for_pair: + atoms.set_calculator(calc_no_zbl) + with torch.jit.optimized_execution(False): + energies.append(atoms.get_potential_energy()) + energies = np.array(energies) - energies[-1] + plt.plot(distances, energies, label=f"{pair[0]}-{pair[1]}") +plt.title("Dimer curves - no ZBL") +plt.xlabel("Distance (Å)") +plt.ylabel("Energy (eV)") +plt.legend() +plt.tight_layout() +plt.show() + +# %% +# +# Calculate and plot energies from the ZBL-corrected model +# -------------------------------------------------------- +# +# We repeat the same procedure as above, but this time with the ZBL-corrected model. + +for pair, structures_for_pair in pairs.items(): + energies = [] + for atoms in structures_for_pair: + atoms.set_calculator(calc_zbl) + with torch.jit.optimized_execution(False): + energies.append(atoms.get_potential_energy()) + energies = np.array(energies) - energies[-1] + plt.plot(distances, energies, label=f"{pair[0]}-{pair[1]}") +plt.title("Dimer curves - with ZBL") +plt.xlabel("Distance (Å)") +plt.ylabel("Energy (eV)") +plt.legend() +plt.tight_layout() +plt.show() + +# %% +# +# It can be seen that all the dimer curves include a strong repulsion +# at short distances, which is due to the ZBL contribution. Even the H-H dimer, +# whose ZBL correction is very weak due to the small covalent radii of hydrogen, +# would show a strong repulsion closer to the origin (here, we only plotted +# starting from a distance of 0.5 Å). Let's zoom in on the H-H dimer to see +# this effect more clearly. + +new_distances = np.linspace(0.1, 2.0, 200) + +structures = [] +for distance in new_distances: + atoms = ase.Atoms( + symbols=["H", "H"], + positions=[[0, 0, 0], [0, 0, distance]], + ) + structures.append(atoms) + +for atoms in structures: + atoms.set_calculator(calc_zbl) +with torch.jit.optimized_execution(False): + energies = [atoms.get_potential_energy() for atoms in structures] +energies = np.array(energies) - energies[-1] +plt.plot(new_distances, energies, label="H-H") +plt.title("Dimer curve - H-H with ZBL") +plt.xlabel("Distance (Å)") +plt.ylabel("Energy (eV)") +plt.legend() +plt.tight_layout() +plt.show() diff --git a/examples/zbl/ethanol_reduced_100.xyz b/examples/zbl/ethanol_reduced_100.xyz new file mode 120000 index 000000000..f01afa4c6 --- /dev/null +++ b/examples/zbl/ethanol_reduced_100.xyz @@ -0,0 +1 @@ +../ase/ethanol_reduced_100.xyz \ No newline at end of file diff --git a/examples/zbl/options_no_zbl.yaml b/examples/zbl/options_no_zbl.yaml new file mode 100644 index 000000000..e53e218ba --- /dev/null +++ b/examples/zbl/options_no_zbl.yaml @@ -0,0 +1,21 @@ +seed: 42 + +architecture: + name: experimental.soap_bpnn + model: + zbl: false + training: + num_epochs: 10 + +# training set section +training_set: + systems: + read_from: ethanol_reduced_100.xyz + length_unit: angstrom + targets: + energy: + key: "energy" + unit: "eV" # very important to run simulations + +validation_set: 0.1 +test_set: 0.0 diff --git a/examples/zbl/options_zbl.yaml b/examples/zbl/options_zbl.yaml new file mode 100644 index 000000000..56fe80642 --- /dev/null +++ b/examples/zbl/options_zbl.yaml @@ -0,0 +1,21 @@ +seed: 42 + +architecture: + name: experimental.soap_bpnn + model: + zbl: true + training: + num_epochs: 10 + +# training set section +training_set: + systems: + read_from: ethanol_reduced_100.xyz + length_unit: angstrom + targets: + energy: + key: "energy" + unit: "eV" # very important to run simulations + +validation_set: 0.1 +test_set: 0.0 diff --git a/examples/zbl/train.sh b/examples/zbl/train.sh new file mode 100755 index 000000000..03b6baab2 --- /dev/null +++ b/examples/zbl/train.sh @@ -0,0 +1,4 @@ +#!/bin/bash + +mtt train options_no_zbl.yaml -o model_no_zbl.pt +mtt train options_zbl.yaml -o model_zbl.pt diff --git a/src/metatrain/cli/eval.py b/src/metatrain/cli/eval.py index 93adb1627..c796b83c1 100644 --- a/src/metatrain/cli/eval.py +++ b/src/metatrain/cli/eval.py @@ -1,10 +1,12 @@ import argparse import itertools import logging +import time from pathlib import Path from typing import Dict, List, Optional, Union import metatensor.torch +import numpy as np import torch from metatensor.torch import Labels, TensorBlock, TensorMap from metatensor.torch.atomistic import MetatensorAtomisticModel @@ -23,7 +25,10 @@ from ..utils.evaluate_model import evaluate_model from ..utils.logging import MetricLogger from ..utils.metrics import RMSEAccumulator -from ..utils.neighbor_lists import get_system_with_neighbor_lists +from ..utils.neighbor_lists import ( + get_requested_neighbor_lists, + get_system_with_neighbor_lists, +) from ..utils.omegaconf import expand_dataset_config from ..utils.per_atom import average_by_num_atoms from .formatter import CustomHelpFormatter @@ -161,18 +166,21 @@ def _eval_targets( """Evaluates an exported model on a dataset and prints the RMSEs for each target. Optionally, it also returns the predictions of the model. + The total and per-atom timings for the evaluation are also printed. + Wraps around metatrain.cli.evaluate_model. """ if len(dataset) == 0: logger.info("This dataset is empty. No evaluation will be performed.") + return None # Attach neighbor lists to the systems: # TODO: these might already be present... find a way to avoid recomputing # if already present (e.g. if this function is called after training) for sample in dataset: system = sample["system"] - get_system_with_neighbor_lists(system, model.requested_neighbor_lists()) + get_system_with_neighbor_lists(system, get_requested_neighbor_lists(model)) # Infer the device and dtype from the model model_tensor = next(itertools.chain(model.parameters(), model.buffers())) @@ -194,6 +202,10 @@ def _eval_targets( if return_predictions: all_predictions = [] + # Set up timings: + total_time = 0.0 + timings_per_atom = [] + # Evaluate the model for batch in dataloader: systems, batch_targets = batch @@ -202,6 +214,9 @@ def _eval_targets( key: value.to(dtype=dtype, device=device) for key, value in batch_targets.items() } + + start_time = time.time() + batch_predictions = evaluate_model( model, systems, @@ -209,6 +224,11 @@ def _eval_targets( is_training=False, check_consistency=check_consistency, ) + + if torch.cuda.is_available(): + torch.cuda.synchronize() + end_time = time.time() + batch_predictions = average_by_num_atoms( batch_predictions, systems, per_structure_keys=[] ) @@ -219,6 +239,10 @@ def _eval_targets( if return_predictions: all_predictions.append(batch_predictions) + time_taken = end_time - start_time + total_time += time_taken + timings_per_atom.append(time_taken / sum(len(system) for system in systems)) + # Finalize the RMSEs rmse_values = rmse_accumulator.finalize(not_per_atom=["positions_gradients"]) # print the RMSEs with MetricLogger @@ -229,6 +253,16 @@ def _eval_targets( ) metric_logger.log(rmse_values) + # Log timings + timings_per_atom = np.array(timings_per_atom) + mean_per_atom = np.mean(timings_per_atom) + std_per_atom = np.std(timings_per_atom) + logger.info( + f"evaluation time: {total_time:.2f} s " + f"[{1000.0*mean_per_atom:.2f} ± " + f"{1000.0*std_per_atom:.2f} ms per atom]" + ) + if return_predictions: # concatenate the TensorMaps all_predictions_joined = _concatenate_tensormaps(all_predictions) diff --git a/src/metatrain/cli/export.py b/src/metatrain/cli/export.py index d63f11f3e..378d93145 100644 --- a/src/metatrain/cli/export.py +++ b/src/metatrain/cli/export.py @@ -1,12 +1,11 @@ import argparse -import importlib import logging from pathlib import Path from typing import Any, Union import torch -from ..utils.architectures import check_architecture_name, find_all_architectures +from ..utils.architectures import find_all_architectures, import_architecture from ..utils.export import is_exported from ..utils.io import check_file_extension from .formatter import CustomHelpFormatter @@ -57,8 +56,7 @@ def _add_export_model_parser(subparser: argparse._SubParsersAction) -> None: def _prepare_export_model_args(args: argparse.Namespace) -> None: """Prepare arguments for export_model.""" architecture_name = args.__dict__.pop("architecture_name") - check_architecture_name(architecture_name) - architecture = importlib.import_module(f"metatrain.{architecture_name}") + architecture = import_architecture(architecture_name) args.model = architecture.__model__.load_checkpoint(args.__dict__.pop("path")) diff --git a/src/metatrain/cli/train.py b/src/metatrain/cli/train.py index ffc880f73..5523b6ef4 100644 --- a/src/metatrain/cli/train.py +++ b/src/metatrain/cli/train.py @@ -1,5 +1,4 @@ import argparse -import importlib import itertools import json import logging @@ -14,7 +13,11 @@ from omegaconf import DictConfig, OmegaConf from .. import PACKAGE_ROOT -from ..utils.architectures import check_architecture_options, get_default_hypers +from ..utils.architectures import ( + check_architecture_options, + get_default_hypers, + import_architecture, +) from ..utils.data import ( DatasetInfo, TargetInfoDict, @@ -135,7 +138,7 @@ def train_model( check_architecture_options( name=architecture_name, options=OmegaConf.to_container(options["architecture"]) ) - architecture = importlib.import_module(f"metatrain.{architecture_name}") + architecture = import_architecture(architecture_name) logger.info(f"Running training for {architecture_name!r} architecture") diff --git a/src/metatrain/experimental/alchemical_model/default-hypers.yaml b/src/metatrain/experimental/alchemical_model/default-hypers.yaml index d41f53afb..4d3f91eb2 100644 --- a/src/metatrain/experimental/alchemical_model/default-hypers.yaml +++ b/src/metatrain/experimental/alchemical_model/default-hypers.yaml @@ -13,6 +13,7 @@ model: bpnn: hidden_sizes: [32, 32] output_size: 1 + zbl: false training: batch_size: 8 diff --git a/src/metatrain/experimental/alchemical_model/model.py b/src/metatrain/experimental/alchemical_model/model.py index 52b9af5fb..8ffbd8dc7 100644 --- a/src/metatrain/experimental/alchemical_model/model.py +++ b/src/metatrain/experimental/alchemical_model/model.py @@ -1,6 +1,7 @@ from pathlib import Path from typing import Dict, List, Optional, Union +import metatensor.torch import torch from metatensor.torch import Labels, TensorBlock, TensorMap from metatensor.torch.atomistic import ( @@ -12,6 +13,7 @@ ) from torch_alchemical.models import AlchemicalModel as AlchemicalModelUpstream +from ...utils.additive import ZBL from ...utils.data.dataset import DatasetInfo from ...utils.dtype import dtype_to_str from ...utils.export import export @@ -54,6 +56,11 @@ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo) -> None: **self.hypers["bpnn"], ) + additive_models = [] + if self.hypers["zbl"]: + additive_models.append(ZBL(model_hypers, dataset_info)) + self.additive_models = torch.nn.ModuleList(additive_models) + self.cutoff = self.hypers["soap"]["cutoff"] self.is_restarted = False @@ -123,6 +130,18 @@ def forward( keys=keys, blocks=[block], ) + + if not self.training: + # at evaluation, we also add the additive contributions + for additive_model in self.additive_models: + additive_contributions = additive_model( + systems, outputs, selected_atoms + ) + total_energies[output_name] = metatensor.torch.add( + total_energies[output_name], + additive_contributions[output_name], + ) + return total_energies @classmethod @@ -145,10 +164,21 @@ def export(self) -> MetatensorAtomisticModel: if dtype not in self.__supported_dtypes__: raise ValueError(f"unsupported dtype {dtype} for AlchemicalModel") + # Make sure the model is all in the same dtype + # For example, after training, the additive models could still be in + # float64 + self.to(dtype) + + interaction_ranges = [self.hypers["soap"]["cutoff"]] + for additive_model in self.additive_models: + if hasattr(additive_model, "cutoff_radius"): + interaction_ranges.append(additive_model.cutoff_radius) + interaction_range = max(interaction_ranges) + capabilities = ModelCapabilities( outputs=self.outputs, atomic_types=self.atomic_types, - interaction_range=self.hypers["soap"]["cutoff"], + interaction_range=interaction_range, length_unit=self.dataset_info.length_unit, supported_devices=self.__supported_devices__, dtype=dtype_to_str(dtype), diff --git a/src/metatrain/experimental/alchemical_model/schema-hypers.json b/src/metatrain/experimental/alchemical_model/schema-hypers.json index b5901c318..4e9e141ed 100644 --- a/src/metatrain/experimental/alchemical_model/schema-hypers.json +++ b/src/metatrain/experimental/alchemical_model/schema-hypers.json @@ -53,6 +53,9 @@ } }, "additionalProperties": false + }, + "zbl": { + "type": "boolean" } }, "additionalProperties": false diff --git a/src/metatrain/experimental/alchemical_model/tests/test_exported.py b/src/metatrain/experimental/alchemical_model/tests/test_exported.py index 3be002445..891983693 100644 --- a/src/metatrain/experimental/alchemical_model/tests/test_exported.py +++ b/src/metatrain/experimental/alchemical_model/tests/test_exported.py @@ -4,7 +4,10 @@ from metatrain.experimental.alchemical_model import AlchemicalModel from metatrain.utils.data import DatasetInfo, TargetInfo, TargetInfoDict -from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists +from metatrain.utils.neighbor_lists import ( + get_requested_neighbor_lists, + get_system_with_neighbor_lists, +) from . import MODEL_HYPERS @@ -31,7 +34,8 @@ def test_to(device, dtype): positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]), cell=torch.zeros(3, 3), ) - system = get_system_with_neighbor_lists(system, exported.requested_neighbor_lists()) + requested_neighbor_lists = get_requested_neighbor_lists(exported) + system = get_system_with_neighbor_lists(system, requested_neighbor_lists) system = system.to(device=device, dtype=dtype) evaluation_options = ModelEvaluationOptions( diff --git a/src/metatrain/experimental/alchemical_model/tests/test_functionality.py b/src/metatrain/experimental/alchemical_model/tests/test_functionality.py index b3c42d81f..7ee3331af 100644 --- a/src/metatrain/experimental/alchemical_model/tests/test_functionality.py +++ b/src/metatrain/experimental/alchemical_model/tests/test_functionality.py @@ -3,7 +3,10 @@ from metatrain.experimental.alchemical_model import AlchemicalModel from metatrain.utils.data import DatasetInfo, TargetInfo, TargetInfoDict -from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists +from metatrain.utils.neighbor_lists import ( + get_requested_neighbor_lists, + get_system_with_neighbor_lists, +) from . import MODEL_HYPERS @@ -25,7 +28,8 @@ def test_prediction_subset_elements(): positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]), cell=torch.zeros(3, 3), ) - system = get_system_with_neighbor_lists(system, model.requested_neighbor_lists()) + requested_neighbor_lists = get_requested_neighbor_lists(model) + system = get_system_with_neighbor_lists(system, requested_neighbor_lists) evaluation_options = ModelEvaluationOptions( length_unit=dataset_info.length_unit, diff --git a/src/metatrain/experimental/alchemical_model/tests/test_invariance.py b/src/metatrain/experimental/alchemical_model/tests/test_invariance.py index f64925848..9d9a84dd9 100644 --- a/src/metatrain/experimental/alchemical_model/tests/test_invariance.py +++ b/src/metatrain/experimental/alchemical_model/tests/test_invariance.py @@ -6,7 +6,10 @@ from metatrain.experimental.alchemical_model import AlchemicalModel from metatrain.utils.data import DatasetInfo, TargetInfo, TargetInfoDict -from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists +from metatrain.utils.neighbor_lists import ( + get_requested_neighbor_lists, + get_system_with_neighbor_lists, +) from . import DATASET_PATH, MODEL_HYPERS @@ -24,13 +27,15 @@ def test_rotational_invariance(): system = ase.io.read(DATASET_PATH) original_system = copy.deepcopy(system) original_system = systems_to_torch(original_system) + requested_neighbor_lists = get_requested_neighbor_lists(model) original_system = get_system_with_neighbor_lists( - original_system, model.requested_neighbor_lists() + original_system, requested_neighbor_lists ) system.rotate(48, "y") system = systems_to_torch(system) - system = get_system_with_neighbor_lists(system, model.requested_neighbor_lists()) + requested_neighbor_lists = get_requested_neighbor_lists(model) + system = get_system_with_neighbor_lists(system, requested_neighbor_lists) evaluation_options = ModelEvaluationOptions( length_unit=dataset_info.length_unit, diff --git a/src/metatrain/experimental/alchemical_model/tests/test_regression.py b/src/metatrain/experimental/alchemical_model/tests/test_regression.py index 4dbc6ed0b..648c91c18 100644 --- a/src/metatrain/experimental/alchemical_model/tests/test_regression.py +++ b/src/metatrain/experimental/alchemical_model/tests/test_regression.py @@ -14,7 +14,10 @@ read_targets, ) from metatrain.utils.data.dataset import TargetInfoDict -from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists +from metatrain.utils.neighbor_lists import ( + get_requested_neighbor_lists, + get_system_with_neighbor_lists, +) from . import DATASET_PATH, DEFAULT_HYPERS, MODEL_HYPERS @@ -38,8 +41,9 @@ def test_regression_init(): # Predict on the first five systems systems = read_systems(DATASET_PATH)[:5] + requested_neighbor_lists = get_requested_neighbor_lists(model) systems = [ - get_system_with_neighbor_lists(system, model.requested_neighbor_lists()) + get_system_with_neighbor_lists(system, requested_neighbor_lists) for system in systems ] @@ -101,8 +105,9 @@ def test_regression_train(): ) model = AlchemicalModel(MODEL_HYPERS, dataset_info) + requested_neighbor_lists = get_requested_neighbor_lists(model) systems = [ - get_system_with_neighbor_lists(system, model.requested_neighbor_lists()) + get_system_with_neighbor_lists(system, requested_neighbor_lists) for system in systems ] diff --git a/src/metatrain/experimental/alchemical_model/trainer.py b/src/metatrain/experimental/alchemical_model/trainer.py index dddbc9839..c9b9b5f60 100644 --- a/src/metatrain/experimental/alchemical_model/trainer.py +++ b/src/metatrain/experimental/alchemical_model/trainer.py @@ -5,6 +5,7 @@ import torch from metatensor.learn.data import DataLoader +from ...utils.additive import remove_additive from ...utils.data import ( CombinedDataLoader, Dataset, @@ -19,8 +20,12 @@ from ...utils.logging import MetricLogger from ...utils.loss import TensorMapDictLoss from ...utils.metrics import RMSEAccumulator -from ...utils.neighbor_lists import get_system_with_neighbor_lists +from ...utils.neighbor_lists import ( + get_requested_neighbor_lists, + get_system_with_neighbor_lists, +) from ...utils.per_atom import average_by_num_atoms +from ...utils.transfer import systems_and_targets_to_dtype_and_device from . import AlchemicalModel from .utils.composition import calculate_composition_weights from .utils.normalize import ( @@ -67,7 +72,7 @@ def train( # Calculating the neighbor lists for the training and validation datasets: logger.info("Calculating neighbor lists for the datasets") - requested_neighbor_lists = model.requested_neighbor_lists() + requested_neighbor_lists = get_requested_neighbor_lists(model) for dataset in train_datasets + val_datasets: for i in range(len(dataset)): system = dataset[i]["system"] @@ -218,11 +223,13 @@ def train( systems, targets = batch assert len(systems[0].known_neighbor_lists()) > 0 - systems = [system.to(dtype=dtype, device=device) for system in systems] - targets = { - key: value.to(dtype=dtype, device=device) - for key, value in targets.items() - } + systems, targets = systems_and_targets_to_dtype_and_device( + systems, targets, dtype, device + ) + for additive_model in model.additive_models: + targets = remove_additive( + systems, targets, additive_model, model.dataset_info.targets + ) predictions = evaluate_model( model, systems, @@ -254,11 +261,13 @@ def train( for batch in val_dataloader: systems, targets = batch assert len(systems[0].known_neighbor_lists()) > 0 - systems = [system.to(dtype=dtype, device=device) for system in systems] - targets = { - key: value.to(dtype=dtype, device=device) - for key, value in targets.items() - } + systems, targets = systems_and_targets_to_dtype_and_device( + systems, targets, dtype, device + ) + for additive_model in model.additive_models: + targets = remove_additive( + systems, targets, additive_model, model.dataset_info.targets + ) predictions = evaluate_model( model, systems, diff --git a/src/metatrain/experimental/gap/default-hypers.yaml b/src/metatrain/experimental/gap/default-hypers.yaml index d3b911e83..2c7f192fe 100644 --- a/src/metatrain/experimental/gap/default-hypers.yaml +++ b/src/metatrain/experimental/gap/default-hypers.yaml @@ -17,10 +17,10 @@ model: rate: 1.0 scale: 2.0 exponent: 7.0 - krr: degree: 2 num_sparse_points: 500 + zbl: false training: regularizer: 0.001 diff --git a/src/metatrain/experimental/gap/model.py b/src/metatrain/experimental/gap/model.py index 1ce0a9cb5..833325d11 100644 --- a/src/metatrain/experimental/gap/model.py +++ b/src/metatrain/experimental/gap/model.py @@ -21,7 +21,7 @@ from metatrain.utils.data.dataset import DatasetInfo -from ...utils.composition import CompositionModel +from ...utils.additive import ZBL, CompositionModel from ...utils.export import export @@ -95,7 +95,6 @@ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo) -> None: self._soap_torch_calculator = rascaline.torch.SoapPowerSpectrum( **model_hypers["soap"] ) - self._soap_calculator = rascaline.SoapPowerSpectrum(**model_hypers["soap"]) kernel_kwargs = { "degree": model_hypers["krr"]["degree"], @@ -128,10 +127,16 @@ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo) -> None: ) self._species_labels: TorchLabels = TorchLabels.empty("_") - self.composition_model = CompositionModel( + # additive models: these are handled by the trainer at training + # time, and they are added to the output at evaluation time + composition_model = CompositionModel( model_hypers={}, dataset_info=dataset_info, ) + additive_models = [composition_model] + if self.hypers["zbl"]: + additive_models.append(ZBL(model_hypers, dataset_info)) + self.additive_models = torch.nn.ModuleList(additive_models) def restart(self, dataset_info: DatasetInfo) -> "GAP": raise ValueError("GAP does not allow restarting training") @@ -209,24 +214,34 @@ def forward( energies = self._subset_of_regressors_torch(soap_features) return_dict = {output_key: energies} - # apply composition model - composition_energies = self.composition_model( - systems, {output_key: ModelOutput("energy", per_atom=True)}, selected_atoms - ) - composition_energies[output_key] = metatensor.torch.sum_over_samples( - composition_energies[output_key], "atom" - ) - return_dict[output_key] = metatensor.torch.add( - return_dict[output_key], composition_energies[output_key] - ) + if not self.training: + # at evaluation, we also add the additive contributions + for additive_model in self.additive_models: + additive_contributions = additive_model( + systems, outputs, selected_atoms + ) + for name in return_dict: + if name.startswith("mtt::aux::"): + continue # skip auxiliary outputs (not targets) + return_dict[name] = metatensor.torch.add( + return_dict[name], + additive_contributions[name], + ) return return_dict def export(self) -> MetatensorAtomisticModel: + + interaction_ranges = [self.hypers["soap"]["cutoff"]] + for additive_model in self.additive_models: + if hasattr(additive_model, "cutoff_radius"): + interaction_ranges.append(additive_model.cutoff_radius) + interaction_range = max(interaction_ranges) + capabilities = ModelCapabilities( outputs=self.outputs, atomic_types=sorted(self.dataset_info.atomic_types), - interaction_range=self.hypers["soap"]["cutoff"], + interaction_range=interaction_range, length_unit=self.dataset_info.length_unit, supported_devices=["cuda", "cpu"], dtype="float64", diff --git a/src/metatrain/experimental/gap/schema-hypers.json b/src/metatrain/experimental/gap/schema-hypers.json index a756bc590..e793f0315 100644 --- a/src/metatrain/experimental/gap/schema-hypers.json +++ b/src/metatrain/experimental/gap/schema-hypers.json @@ -87,6 +87,9 @@ } }, "additionalProperties": false + }, + "zbl": { + "type": "boolean" } }, "additionalProperties": false diff --git a/src/metatrain/experimental/gap/tests/test_errors.py b/src/metatrain/experimental/gap/tests/test_errors.py index 24e5c03b3..b0359bd80 100644 --- a/src/metatrain/experimental/gap/tests/test_errors.py +++ b/src/metatrain/experimental/gap/tests/test_errors.py @@ -62,7 +62,7 @@ def test_ethanol_regression_train_and_invariance(): hypers["model"]["krr"]["num_sparse_points"] = 30 target_info_dict = TargetInfoDict( - energy=TargetInfo(quantity="energy", unit="kcal/mol") + energy=TargetInfo(quantity="energy", unit="kcal/mol", gradients=["positions"]) ) dataset_info = DatasetInfo( diff --git a/src/metatrain/experimental/gap/tests/test_regression.py b/src/metatrain/experimental/gap/tests/test_regression.py index e2a2ee72c..81212353c 100644 --- a/src/metatrain/experimental/gap/tests/test_regression.py +++ b/src/metatrain/experimental/gap/tests/test_regression.py @@ -74,6 +74,7 @@ def test_regression_train_and_invariance(): val_datasets=[dataset], checkpoint_dir=".", ) + gap.eval() # Predict on the first five systems output = gap(systems[:5], {"mtt::U0": gap.outputs["mtt::U0"]}) @@ -138,7 +139,7 @@ def test_ethanol_regression_train_and_invariance(): hypers["model"]["krr"]["num_sparse_points"] = 900 target_info_dict = TargetInfoDict( - energy=TargetInfo(quantity="energy", unit="kcal/mol") + energy=TargetInfo(quantity="energy", unit="kcal/mol", gradients=["positions"]) ) dataset_info = DatasetInfo( @@ -155,6 +156,7 @@ def test_ethanol_regression_train_and_invariance(): val_datasets=[dataset], checkpoint_dir=".", ) + gap.eval() # Predict on the first five systems output = gap(systems[:5], {"energy": gap.outputs["energy"]}) diff --git a/src/metatrain/experimental/gap/trainer.py b/src/metatrain/experimental/gap/trainer.py index ffeb40951..5daef8369 100644 --- a/src/metatrain/experimental/gap/trainer.py +++ b/src/metatrain/experimental/gap/trainer.py @@ -9,8 +9,12 @@ from metatrain.utils.data import Dataset -from ...utils.composition import remove_composition +from ...utils.additive import remove_additive from ...utils.data import check_datasets +from ...utils.neighbor_lists import ( + get_requested_neighbor_lists, + get_system_with_neighbor_lists, +) from . import GAP from .model import torch_tensor_map_to_core @@ -52,7 +56,8 @@ def train( # Calculate and set the composition weights: logger.info("Calculating composition weights") - model.composition_model.train_model(train_datasets) + # model.additive_models[0] is the composition model + model.additive_models[0].train_model(train_datasets) logger.info("Setting up data loaders") if len(train_datasets[0][0][output_name].keys) > 1: @@ -69,11 +74,25 @@ def train( model._keys = train_y.keys train_structures = [sample["system"] for sample in train_dataset] - logger.info("Subtracting composition energies") - # this acts in-place on train_y - remove_composition( - train_structures, {target_name: train_y}, model.composition_model - ) + logger.info("Calculating neighbor lists for the datasets") + requested_neighbor_lists = get_requested_neighbor_lists(model) + for dataset in train_datasets + val_datasets: + for i in range(len(dataset)): + system = dataset[i]["system"] + # The following line attaches the neighbors lists to the system, + # and doesn't require to reassign the system to the dataset: + _ = get_system_with_neighbor_lists(system, requested_neighbor_lists) + + logger.info("Subtracting composition energies") # and potentially ZBL + train_targets = {target_name: train_y} + for additive_model in model.additive_models: + train_targets = remove_additive( + train_structures, + train_targets, + additive_model, + model.dataset_info.targets, + ) + train_y = train_targets[target_name] logger.info("Calculating SOAP features") if len(train_y[0].gradients_list()) > 0: diff --git a/src/metatrain/experimental/pet/default-hypers.yaml b/src/metatrain/experimental/pet/default-hypers.yaml index 87191ab53..ad6befbbb 100644 --- a/src/metatrain/experimental/pet/default-hypers.yaml +++ b/src/metatrain/experimental/pet/default-hypers.yaml @@ -32,6 +32,7 @@ model: N_TARGETS: 1 TARGET_INDEX_KEY: target_index RESIDUAL_FACTOR: 0.5 + USE_ZBL: False training: INITIAL_LR: 1e-4 @@ -54,4 +55,5 @@ training: DO_GRADIENT_CLIPPING: False GRADIENT_CLIPPING_MAX_NORM: null # must be overwritten if DO_GRADIENT_CLIPPING is True USE_SHIFT_AGNOSTIC_LOSS: False # only used when fitting general target. Primary use case: EDOS - ENERGIES_LOSS: per_structure # per_structure or per_atom \ No newline at end of file + ENERGIES_LOSS: per_structure # per_structure or per_atom + CHECKPOINT_INTERVAL: 100 \ No newline at end of file diff --git a/src/metatrain/experimental/pet/model.py b/src/metatrain/experimental/pet/model.py index 0bb67b19c..bf567dee1 100644 --- a/src/metatrain/experimental/pet/model.py +++ b/src/metatrain/experimental/pet/model.py @@ -18,6 +18,7 @@ from metatrain.utils.data import DatasetInfo +from ...utils.additive import ZBL from ...utils.dtype import dtype_to_str from ...utils.export import export from .utils import systems_to_batch_dict @@ -48,6 +49,13 @@ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo) -> None: self.pet = None self.checkpoint_path: Optional[str] = None + # additive models: these are handled by the trainer at training + # time, and they are added to the output at evaluation time + additive_models = [] + if self.hypers["USE_ZBL"]: + additive_models.append(ZBL(model_hypers, dataset_info)) + self.additive_models = torch.nn.ModuleList(additive_models) + def restart(self, dataset_info: DatasetInfo) -> "PET": if dataset_info != self.dataset_info: raise ValueError( @@ -110,6 +118,21 @@ def forward( if not outputs[output_name].per_atom: output_tmap = metatensor.torch.sum_over_samples(output_tmap, "atom") output_quantities[output_name] = output_tmap + + if not self.training: + # at evaluation, we also add the additive contributions + for additive_model in self.additive_models: + additive_contributions = additive_model( + systems, outputs, selected_atoms + ) + for output_name in output_quantities: + if output_name.startswith("mtt::aux::"): + continue # skip auxiliary outputs (not targets) + output_quantities[output_name] = metatensor.torch.add( + output_quantities[output_name], + additive_contributions[output_name], + ) + return output_quantities @classmethod @@ -148,6 +171,17 @@ def export(self) -> MetatensorAtomisticModel: if dtype not in self.__supported_dtypes__: raise ValueError(f"Unsupported dtype {self.dtype} for PET") + # Make sure the model is all in the same dtype + # For example, after training, the additive models could still be in + # float64 + self.to(dtype) + + interaction_ranges = [self.hypers["N_GNN_LAYERS"] * self.cutoff] + for additive_model in self.additive_models: + if hasattr(additive_model, "cutoff_radius"): + interaction_ranges.append(additive_model.cutoff_radius) + interaction_range = max(interaction_ranges) + capabilities = ModelCapabilities( outputs={ self.target_name: ModelOutput( @@ -157,7 +191,7 @@ def export(self) -> MetatensorAtomisticModel: ) }, atomic_types=self.atomic_types, - interaction_range=self.cutoff, + interaction_range=interaction_range, length_unit=self.dataset_info.length_unit, supported_devices=["cpu", "cuda"], # and not __supported_devices__ dtype=dtype_to_str(dtype), diff --git a/src/metatrain/experimental/pet/schema-hypers.json b/src/metatrain/experimental/pet/schema-hypers.json index 19528705d..9b00a1927 100644 --- a/src/metatrain/experimental/pet/schema-hypers.json +++ b/src/metatrain/experimental/pet/schema-hypers.json @@ -123,6 +123,9 @@ }, "RESIDUAL_FACTOR": { "type": "number" + }, + "USE_ZBL": { + "type": "boolean" } }, "additionalProperties": false @@ -219,6 +222,9 @@ }, "BALANCED_DATA_LOADER": { "type": "boolean" + }, + "CHECKPOINT_INTERVAL": { + "type": "integer" } }, "additionalProperties": false, diff --git a/src/metatrain/experimental/pet/tests/test_exported.py b/src/metatrain/experimental/pet/tests/test_exported.py index a72eb88dd..f67a15e4c 100644 --- a/src/metatrain/experimental/pet/tests/test_exported.py +++ b/src/metatrain/experimental/pet/tests/test_exported.py @@ -13,7 +13,10 @@ from metatrain.utils.architectures import get_default_hypers from metatrain.utils.data import DatasetInfo, TargetInfo, TargetInfoDict from metatrain.utils.export import export -from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists +from metatrain.utils.neighbor_lists import ( + get_requested_neighbor_lists, + get_system_with_neighbor_lists, +) DEFAULT_HYPERS = get_default_hypers("experimental.pet") @@ -59,7 +62,8 @@ def test_to(device): positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]), cell=torch.zeros(3, 3), ) - system = get_system_with_neighbor_lists(system, exported.requested_neighbor_lists()) + requested_neighbor_lists = get_requested_neighbor_lists(exported) + system = get_system_with_neighbor_lists(system, requested_neighbor_lists) system = system.to(device=device, dtype=dtype) evaluation_options = ModelEvaluationOptions( diff --git a/src/metatrain/experimental/pet/tests/test_functionality.py b/src/metatrain/experimental/pet/tests/test_functionality.py index 74a47b075..ddf527603 100644 --- a/src/metatrain/experimental/pet/tests/test_functionality.py +++ b/src/metatrain/experimental/pet/tests/test_functionality.py @@ -20,7 +20,10 @@ from metatrain.utils.architectures import get_default_hypers from metatrain.utils.data import DatasetInfo, TargetInfo, TargetInfoDict from metatrain.utils.jsonschema import validate -from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists +from metatrain.utils.neighbor_lists import ( + get_requested_neighbor_lists, + get_system_with_neighbor_lists, +) DEFAULT_HYPERS = get_default_hypers("experimental.pet") @@ -74,7 +77,8 @@ def test_prediction(): positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]), cell=torch.zeros(3, 3), ) - system = get_system_with_neighbor_lists(system, model.requested_neighbor_lists()) + requested_neighbor_lists = get_requested_neighbor_lists(model) + system = get_system_with_neighbor_lists(system, requested_neighbor_lists) evaluation_options = ModelEvaluationOptions( length_unit=dataset_info.length_unit, @@ -123,7 +127,8 @@ def test_per_atom_predictions_functionality(): positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]), cell=torch.zeros(3, 3), ) - system = get_system_with_neighbor_lists(system, model.requested_neighbor_lists()) + requested_neighbor_lists = get_requested_neighbor_lists(model) + system = get_system_with_neighbor_lists(system, requested_neighbor_lists) evaluation_options = ModelEvaluationOptions( length_unit=dataset_info.length_unit, @@ -173,7 +178,8 @@ def test_selected_atoms_functionality(): positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]), cell=torch.zeros(3, 3), ) - system = get_system_with_neighbor_lists(system, model.requested_neighbor_lists()) + requested_neighbor_lists = get_requested_neighbor_lists(model) + system = get_system_with_neighbor_lists(system, requested_neighbor_lists) evaluation_options = ModelEvaluationOptions( length_unit=dataset_info.length_unit, diff --git a/src/metatrain/experimental/pet/trainer.py b/src/metatrain/experimental/pet/trainer.py index fdacbdc68..63c06afd9 100644 --- a/src/metatrain/experimental/pet/trainer.py +++ b/src/metatrain/experimental/pet/trainer.py @@ -1,19 +1,51 @@ +import datetime import logging import os +import pickle +import time import warnings from pathlib import Path from typing import List, Union import numpy as np import torch -from metatensor.learn.data import DataLoader -from pet.hypers import Hypers -from pet.pet import PET, SelfContributionsWrapper -from pet.train_model import fit_pet +from pet.analysis import adapt_hypers +from pet.data_preparation import ( + get_all_species, + get_corrected_energies, + get_forces, + get_pyg_graphs, + get_self_contributions, + update_pyg_graphs, +) +from pet.hypers import Hypers, save_hypers +from pet.pet import ( + PET, + FlagsWrapper, + PETMLIPWrapper, + PETUtilityWrapper, + SelfContributionsWrapper, +) +from pet.utilities import ( + FullLogger, + ModelKeeper, + dtype2string, + get_calc_names, + get_data_loaders, + get_loss, + get_optimizer, + get_rmse, + get_scheduler, + load_checkpoint, + log_epoch_stats, + set_reproducibility, + string2dtype, +) +from torch_geometric.nn import DataParallel -from ...utils.data import Dataset, check_datasets, collate_fn -from ...utils.data.system_to_ase import system_to_ase +from ...utils.data import Dataset, check_datasets from . import PET as WrappedPET +from .utils import dataset_to_ase, update_hypers logger = logging.getLogger(__name__) @@ -36,7 +68,8 @@ def train( ): assert dtype in WrappedPET.__supported_dtypes__ - self.pet_dir = Path(checkpoint_dir) / "pet" + name_of_calculation = "pet" + self.pet_dir = Path(checkpoint_dir) / name_of_calculation if len(train_datasets) != 1: raise ValueError("PET only supports a single training dataset") @@ -55,78 +88,20 @@ def train( train_dataset = train_datasets[0] val_dataset = val_datasets[0] - # dummy dataloaders due to https://github.com/lab-cosmo/metatensor/issues/521 - train_dataloader = DataLoader( - train_dataset, - batch_size=1, - shuffle=False, - collate_fn=collate_fn, - ) - val_dataloader = DataLoader( - val_dataset, - batch_size=1, - shuffle=False, - collate_fn=collate_fn, - ) - # are we fitting on only energies or energies and forces? target_name = model.target_name do_forces = ( next(iter(train_dataset))[target_name].block().has_gradient("positions") ) - # set model hypers - self.hypers["ARCHITECTURAL_HYPERS"] = model.hypers - self.hypers["ARCHITECTURAL_HYPERS"]["DTYPE"] = "float32" - - # set MLIP_SETTINGS - self.hypers["MLIP_SETTINGS"] = { - "ENERGY_KEY": "energy", - "FORCES_KEY": "forces", - "USE_ENERGIES": True, - "USE_FORCES": do_forces, - } - - # set PET utility flags - self.hypers["UTILITY_FLAGS"] = { - "CALCULATION_TYPE": None, - } - - ase_train_dataset = [] - for (system,), targets in train_dataloader: - ase_atoms = system_to_ase(system) - ase_atoms.info["energy"] = float( - targets[target_name].block().values.squeeze(-1).detach().cpu().numpy() - ) - if do_forces: - ase_atoms.arrays["forces"] = ( - -targets[target_name] - .block() - .gradient("positions") - .values.squeeze(-1) - .detach() - .cpu() - .numpy() - ) - ase_train_dataset.append(ase_atoms) - - ase_val_dataset = [] - for (system,), targets in val_dataloader: - ase_atoms = system_to_ase(system) - ase_atoms.info["energy"] = float( - targets[target_name].block().values.squeeze(-1).detach().cpu().numpy() - ) - if do_forces: - ase_atoms.arrays["forces"] = ( - -targets[target_name] - .block() - .gradient("positions") - .values.squeeze(-1) - .detach() - .cpu() - .numpy() - ) - ase_val_dataset.append(ase_atoms) + ase_train_dataset = dataset_to_ase( + train_dataset, model, do_forces=do_forces, target_name=target_name + ) + ase_val_dataset = dataset_to_ase( + val_dataset, model, do_forces=do_forces, target_name=target_name + ) + + self.hypers = update_hypers(self.hypers, model.hypers, do_forces) device = devices[0] # only one device, as we don't support multi-gpu for now @@ -140,16 +115,521 @@ def train( else: checkpoint_path = None - fit_pet( + ######################################## + # STARTNG THE PURE PET TRAINING SCRIPT # + ######################################## + + logging.info("Initializing PET training...") + + TIME_SCRIPT_STARTED = time.time() + value = datetime.datetime.fromtimestamp(TIME_SCRIPT_STARTED) + logging.info(f"Starting training at: {value.strftime('%Y-%m-%d %H:%M:%S')}") + logging.info("Training configuration:") + + print(f"Output directory: {checkpoint_dir}") + print(f"Training using device: {device}") + + if not os.path.exists(checkpoint_dir): + os.mkdir(checkpoint_dir) + + hypers = Hypers(self.hypers) + dtype = string2dtype(hypers.ARCHITECTURAL_HYPERS.DTYPE) + torch.set_default_dtype(dtype) + + FITTING_SCHEME = hypers.FITTING_SCHEME + MLIP_SETTINGS = hypers.MLIP_SETTINGS + ARCHITECTURAL_HYPERS = hypers.ARCHITECTURAL_HYPERS + + if FITTING_SCHEME.USE_SHIFT_AGNOSTIC_LOSS: + raise ValueError( + "shift agnostic loss is intended only for general target training" + ) + + ARCHITECTURAL_HYPERS.D_OUTPUT = 1 # energy is a single scalar + ARCHITECTURAL_HYPERS.TARGET_TYPE = "structural" # energy is structural property + ARCHITECTURAL_HYPERS.TARGET_AGGREGATION = ( + "sum" # energy is a sum of atomic energies + ) + print(f"Output dimensionality: {ARCHITECTURAL_HYPERS.D_OUTPUT}") + print(f"Target type: {ARCHITECTURAL_HYPERS.TARGET_TYPE}") + print(f"Target aggregation: {ARCHITECTURAL_HYPERS.TARGET_AGGREGATION}") + + set_reproducibility( + FITTING_SCHEME.RANDOM_SEED, FITTING_SCHEME.CUDA_DETERMINISTIC + ) + + print(f"Random seed: {FITTING_SCHEME.RANDOM_SEED}") + print(f"CUDA is deterministic: {FITTING_SCHEME.CUDA_DETERMINISTIC}") + + adapt_hypers(FITTING_SCHEME, ase_train_dataset) + dataset = ase_train_dataset + ase_val_dataset + all_species = get_all_species(dataset) + + name_to_load, NAME_OF_CALCULATION = get_calc_names( + os.listdir(checkpoint_dir), name_of_calculation + ) + + os.mkdir(f"{checkpoint_dir}/{NAME_OF_CALCULATION}") + np.save(f"{checkpoint_dir}/{NAME_OF_CALCULATION}/all_species.npy", all_species) + hypers.UTILITY_FLAGS.CALCULATION_TYPE = "mlip" + save_hypers(hypers, f"{checkpoint_dir}/{NAME_OF_CALCULATION}/hypers_used.yaml") + + logging.info("Convering structures to PyG graphs...") + + train_graphs = get_pyg_graphs( ase_train_dataset, + all_species, + ARCHITECTURAL_HYPERS.R_CUT, + ARCHITECTURAL_HYPERS.USE_ADDITIONAL_SCALAR_ATTRIBUTES, + ARCHITECTURAL_HYPERS.USE_LONG_RANGE, + ARCHITECTURAL_HYPERS.K_CUT, + ARCHITECTURAL_HYPERS.N_TARGETS > 1, + ARCHITECTURAL_HYPERS.TARGET_INDEX_KEY, + ) + val_graphs = get_pyg_graphs( ase_val_dataset, - self.hypers, - "pet", - device, - checkpoint_dir, - checkpoint_path, + all_species, + ARCHITECTURAL_HYPERS.R_CUT, + ARCHITECTURAL_HYPERS.USE_ADDITIONAL_SCALAR_ATTRIBUTES, + ARCHITECTURAL_HYPERS.USE_LONG_RANGE, + ARCHITECTURAL_HYPERS.K_CUT, + ARCHITECTURAL_HYPERS.N_TARGETS > 1, + ARCHITECTURAL_HYPERS.TARGET_INDEX_KEY, ) + logging.info("Pre-processing training data...") + if MLIP_SETTINGS.USE_ENERGIES: + self_contributions = get_self_contributions( + MLIP_SETTINGS.ENERGY_KEY, ase_train_dataset, all_species + ) + np.save( + f"{checkpoint_dir}/{NAME_OF_CALCULATION}/self_contributions.npy", + self_contributions, + ) + + train_energies = get_corrected_energies( + MLIP_SETTINGS.ENERGY_KEY, + ase_train_dataset, + all_species, + self_contributions, + ) + val_energies = get_corrected_energies( + MLIP_SETTINGS.ENERGY_KEY, + ase_val_dataset, + all_species, + self_contributions, + ) + + update_pyg_graphs(train_graphs, "y", train_energies) + update_pyg_graphs(val_graphs, "y", val_energies) + + if MLIP_SETTINGS.USE_FORCES: + train_forces = get_forces(ase_train_dataset, MLIP_SETTINGS.FORCES_KEY) + val_forces = get_forces(ase_val_dataset, MLIP_SETTINGS.FORCES_KEY) + + update_pyg_graphs(train_graphs, "forces", train_forces) + update_pyg_graphs(val_graphs, "forces", val_forces) + + train_loader, val_loader = get_data_loaders( + train_graphs, val_graphs, FITTING_SCHEME + ) + + logging.info("Initializing the model...") + pet_model = PET(ARCHITECTURAL_HYPERS, 0.0, len(all_species)).to(device) + pet_model = PETUtilityWrapper(pet_model, FITTING_SCHEME.GLOBAL_AUG) + + pet_model = PETMLIPWrapper( + pet_model, MLIP_SETTINGS.USE_ENERGIES, MLIP_SETTINGS.USE_FORCES + ) + if FITTING_SCHEME.MULTI_GPU and torch.cuda.is_available(): + logging.info( + f"Using multi-GPU training on {torch.cuda.device_count()} GPUs" + ) + pet_model = DataParallel(FlagsWrapper(pet_model)) + pet_model = pet_model.to(torch.device("cuda:0")) + + if FITTING_SCHEME.MODEL_TO_START_WITH is not None: + logging.info(f"Loading model from: {FITTING_SCHEME.MODEL_TO_START_WITH}") + pet_model.load_state_dict(torch.load(FITTING_SCHEME.MODEL_TO_START_WITH)) + pet_model = pet_model.to(dtype=dtype) + + optim = get_optimizer(pet_model, FITTING_SCHEME) + scheduler = get_scheduler(optim, FITTING_SCHEME) + + if checkpoint_path is not None: + logging.info(f"Loading model and checkpoint from: {checkpoint_path}\n") + load_checkpoint(pet_model, optim, scheduler, checkpoint_path) + elif name_to_load is not None: + path = f"{checkpoint_dir}/{name_to_load}/checkpoint" + logging.info(f"Loading model and checkpoint from: {path}\n") + load_checkpoint( + pet_model, + optim, + scheduler, + f"{checkpoint_dir}/{name_to_load}/checkpoint", + ) + + history = [] + if MLIP_SETTINGS.USE_ENERGIES: + energies_logger = FullLogger( + FITTING_SCHEME.SUPPORT_MISSING_VALUES, + FITTING_SCHEME.USE_SHIFT_AGNOSTIC_LOSS, + device, + ) + + if MLIP_SETTINGS.USE_FORCES: + forces_logger = FullLogger( + FITTING_SCHEME.SUPPORT_MISSING_VALUES, + FITTING_SCHEME.USE_SHIFT_AGNOSTIC_LOSS, + device, + ) + + if MLIP_SETTINGS.USE_FORCES: + val_forces = torch.cat(val_forces, dim=0) + + sliding_forces_rmse = get_rmse( + val_forces.data.cpu().to(dtype=torch.float32).numpy(), 0.0 + ) + + forces_rmse_model_keeper = ModelKeeper() + forces_mae_model_keeper = ModelKeeper() + + if MLIP_SETTINGS.USE_ENERGIES: + if FITTING_SCHEME.ENERGIES_LOSS == "per_structure": + sliding_energies_rmse = get_rmse(val_energies, np.mean(val_energies)) + else: + val_n_atoms = np.array( + [len(struc.positions) for struc in ase_val_dataset] + ) + val_energies_per_atom = val_energies / val_n_atoms + sliding_energies_rmse = get_rmse( + val_energies_per_atom, np.mean(val_energies_per_atom) + ) + + energies_rmse_model_keeper = ModelKeeper() + energies_mae_model_keeper = ModelKeeper() + + if MLIP_SETTINGS.USE_ENERGIES and MLIP_SETTINGS.USE_FORCES: + multiplication_rmse_model_keeper = ModelKeeper() + multiplication_mae_model_keeper = ModelKeeper() + + logging.info(f"Starting training for {FITTING_SCHEME.EPOCH_NUM} epochs") + if FITTING_SCHEME.EPOCHS_WARMUP > 0: + remaining_lr_scheduler_steps = ( + FITTING_SCHEME.EPOCHS_WARMUP - scheduler.last_epoch + ) + logging.info( + f"Performing {remaining_lr_scheduler_steps} epochs of LR warmup" + ) + TIME_TRAINING_STARTED = time.time() + last_elapsed_time = 0 + print("=" * 50) + for epoch in range(1, FITTING_SCHEME.EPOCH_NUM + 1): + pet_model.train(True) + for batch in train_loader: + if not FITTING_SCHEME.MULTI_GPU: + batch.to(device) + + if FITTING_SCHEME.MULTI_GPU: + pet_model.module.augmentation = True + pet_model.module.create_graph = True + predictions_energies, predictions_forces = pet_model(batch) + else: + predictions_energies, predictions_forces = pet_model( + batch, augmentation=True, create_graph=True + ) + + if FITTING_SCHEME.MULTI_GPU: + y_list = [el.y for el in batch] + batch_y = torch.tensor( + y_list, dtype=torch.get_default_dtype(), device=device + ) + + n_atoms_list = [el.n_atoms for el in batch] + batch_n_atoms = torch.tensor( + n_atoms_list, dtype=torch.get_default_dtype(), device=device + ) + # print('batch_y: ', batch_y.shape) + # print('batch_n_atoms: ', batch_n_atoms.shape) + + else: + batch_y = batch.y + batch_n_atoms = batch.n_atoms + + if FITTING_SCHEME.ENERGIES_LOSS == "per_atom": + predictions_energies = predictions_energies / batch_n_atoms + ground_truth_energies = batch_y / batch_n_atoms + else: + ground_truth_energies = batch_y + + if MLIP_SETTINGS.USE_ENERGIES: + energies_logger.train_logger.update( + predictions_energies, ground_truth_energies + ) + loss_energies = get_loss( + predictions_energies, + ground_truth_energies, + FITTING_SCHEME.SUPPORT_MISSING_VALUES, + FITTING_SCHEME.USE_SHIFT_AGNOSTIC_LOSS, + ) + if MLIP_SETTINGS.USE_FORCES: + + if FITTING_SCHEME.MULTI_GPU: + forces_list = [el.forces for el in batch] + batch_forces = torch.cat(forces_list, dim=0).to(device) + else: + batch_forces = batch.forces + + forces_logger.train_logger.update(predictions_forces, batch_forces) + loss_forces = get_loss( + predictions_forces, + batch_forces, + FITTING_SCHEME.SUPPORT_MISSING_VALUES, + FITTING_SCHEME.USE_SHIFT_AGNOSTIC_LOSS, + ) + + if MLIP_SETTINGS.USE_ENERGIES and MLIP_SETTINGS.USE_FORCES: + loss = FITTING_SCHEME.ENERGY_WEIGHT * loss_energies / ( + sliding_energies_rmse**2 + ) + loss_forces / (sliding_forces_rmse**2) + loss.backward() + + if MLIP_SETTINGS.USE_ENERGIES and (not MLIP_SETTINGS.USE_FORCES): + loss_energies.backward() + if MLIP_SETTINGS.USE_FORCES and (not MLIP_SETTINGS.USE_ENERGIES): + loss_forces.backward() + + if FITTING_SCHEME.DO_GRADIENT_CLIPPING: + torch.nn.utils.clip_grad_norm_( + pet_model.parameters(), + max_norm=FITTING_SCHEME.GRADIENT_CLIPPING_MAX_NORM, + ) + optim.step() + optim.zero_grad() + + pet_model.train(False) + for batch in val_loader: + if not FITTING_SCHEME.MULTI_GPU: + batch.to(device) + + if FITTING_SCHEME.MULTI_GPU: + pet_model.module.augmentation = False + pet_model.module.create_graph = False + predictions_energies, predictions_forces = pet_model(batch) + else: + predictions_energies, predictions_forces = pet_model( + batch, augmentation=False, create_graph=False + ) + + if FITTING_SCHEME.MULTI_GPU: + y_list = [el.y for el in batch] + batch_y = torch.tensor( + y_list, dtype=torch.get_default_dtype(), device=device + ) + + n_atoms_list = [el.n_atoms for el in batch] + batch_n_atoms = torch.tensor( + n_atoms_list, dtype=torch.get_default_dtype(), device=device + ) + + # print('batch_y: ', batch_y.shape) + # print('batch_n_atoms: ', batch_n_atoms.shape) + else: + batch_y = batch.y + batch_n_atoms = batch.n_atoms + + if FITTING_SCHEME.ENERGIES_LOSS == "per_atom": + predictions_energies = predictions_energies / batch_n_atoms + ground_truth_energies = batch_y / batch_n_atoms + else: + ground_truth_energies = batch_y + + if MLIP_SETTINGS.USE_ENERGIES: + energies_logger.val_logger.update( + predictions_energies, ground_truth_energies + ) + if MLIP_SETTINGS.USE_FORCES: + if FITTING_SCHEME.MULTI_GPU: + forces_list = [el.forces for el in batch] + batch_forces = torch.cat(forces_list, dim=0).to(device) + else: + batch_forces = batch.forces + forces_logger.val_logger.update(predictions_forces, batch_forces) + + now = {} + if FITTING_SCHEME.ENERGIES_LOSS == "per_structure": + energies_key = "energies per structure" + else: + energies_key = "energies per atom" + + if MLIP_SETTINGS.USE_ENERGIES: + now[energies_key] = energies_logger.flush() + + if MLIP_SETTINGS.USE_FORCES: + now["forces"] = forces_logger.flush() + now["lr"] = scheduler.get_last_lr() + now["epoch"] = epoch + + now["elapsed_time"] = time.time() - TIME_TRAINING_STARTED + now["epoch_time"] = now["elapsed_time"] - last_elapsed_time + now["estimated_remaining_time"] = (now["elapsed_time"] / epoch) * ( + FITTING_SCHEME.EPOCH_NUM - epoch + ) + last_elapsed_time = now["elapsed_time"] + + if MLIP_SETTINGS.USE_ENERGIES: + sliding_energies_rmse = ( + FITTING_SCHEME.SLIDING_FACTOR * sliding_energies_rmse + + (1.0 - FITTING_SCHEME.SLIDING_FACTOR) + * now[energies_key]["val"]["rmse"] + ) + + energies_mae_model_keeper.update( + pet_model, now[energies_key]["val"]["mae"], epoch + ) + energies_rmse_model_keeper.update( + pet_model, now[energies_key]["val"]["rmse"], epoch + ) + + if MLIP_SETTINGS.USE_FORCES: + sliding_forces_rmse = ( + FITTING_SCHEME.SLIDING_FACTOR * sliding_forces_rmse + + (1.0 - FITTING_SCHEME.SLIDING_FACTOR) + * now["forces"]["val"]["rmse"] + ) + forces_mae_model_keeper.update( + pet_model, now["forces"]["val"]["mae"], epoch + ) + forces_rmse_model_keeper.update( + pet_model, now["forces"]["val"]["rmse"], epoch + ) + + if MLIP_SETTINGS.USE_ENERGIES and MLIP_SETTINGS.USE_FORCES: + multiplication_mae_model_keeper.update( + pet_model, + now["forces"]["val"]["mae"] * now[energies_key]["val"]["mae"], + epoch, + additional_info=[ + now[energies_key]["val"]["mae"], + now["forces"]["val"]["mae"], + ], + ) + multiplication_rmse_model_keeper.update( + pet_model, + now["forces"]["val"]["rmse"] * now[energies_key]["val"]["rmse"], + epoch, + additional_info=[ + now[energies_key]["val"]["rmse"], + now["forces"]["val"]["rmse"], + ], + ) + last_lr = scheduler.get_last_lr()[0] + log_epoch_stats(epoch, FITTING_SCHEME.EPOCH_NUM, now, last_lr, energies_key) + + history.append(now) + scheduler.step() + elapsed = time.time() - TIME_SCRIPT_STARTED + if epoch > 0 and epoch % FITTING_SCHEME.CHECKPOINT_INTERVAL == 0: + checkpoint_dict = { + "model_state_dict": pet_model.state_dict(), + "optim_state_dict": optim.state_dict(), + "scheduler_state_dict": scheduler.state_dict(), + "dtype_used": dtype2string(dtype), + } + torch.save( + checkpoint_dict, + f"{checkpoint_dir}/{NAME_OF_CALCULATION}/checkpoint_{epoch}", + ) + torch.save( + { + "checkpoint": checkpoint_dict, + "hypers": self.hypers, + "dataset_info": model.dataset_info, + "self_contributions": np.load( + self.pet_dir / "self_contributions.npy" # type: ignore + ), + }, + f"{checkpoint_dir}/model.ckpt_{epoch}", + ) + + if FITTING_SCHEME.MAX_TIME is not None: + if elapsed > FITTING_SCHEME.MAX_TIME: + logging.info("Reached maximum time\n") + break + logging.info("Training is finished\n") + logging.info("Saving the model and history...") + torch.save( + { + "model_state_dict": pet_model.state_dict(), + "optim_state_dict": optim.state_dict(), + "scheduler_state_dict": scheduler.state_dict(), + "dtype_used": dtype2string(dtype), + }, + f"{checkpoint_dir}/{NAME_OF_CALCULATION}/checkpoint", + ) + with open(f"{checkpoint_dir}/{NAME_OF_CALCULATION}/history.pickle", "wb") as f: + pickle.dump(history, f) + + def save_model(model_name, model_keeper): + torch.save( + model_keeper.best_model.state_dict(), + f"{checkpoint_dir}/{NAME_OF_CALCULATION}/{model_name}_state_dict", + ) + + summary = "" + if MLIP_SETTINGS.USE_ENERGIES: + if FITTING_SCHEME.ENERGIES_LOSS == "per_structure": + postfix = "per structure" + if FITTING_SCHEME.ENERGIES_LOSS == "per_atom": + postfix = "per atom" + save_model("best_val_mae_energies_model", energies_mae_model_keeper) + summary += f"best val mae in energies {postfix}: " + summary += f"{energies_mae_model_keeper.best_error} " + summary += f"at epoch {energies_mae_model_keeper.best_epoch}\n" + + save_model("best_val_rmse_energies_model", energies_rmse_model_keeper) + summary += f"best val rmse in energies {postfix}: " + summary += f"{energies_rmse_model_keeper.best_error} " + summary += f"at epoch {energies_rmse_model_keeper.best_epoch}\n" + + if MLIP_SETTINGS.USE_FORCES: + save_model("best_val_mae_forces_model", forces_mae_model_keeper) + summary += f"best val mae in forces: {forces_mae_model_keeper.best_error} " + summary += f"at epoch {forces_mae_model_keeper.best_epoch}\n" + + save_model("best_val_rmse_forces_model", forces_rmse_model_keeper) + summary += ( + f"best val rmse in forces: {forces_rmse_model_keeper.best_error} " + ) + summary += f"at epoch {forces_rmse_model_keeper.best_epoch}\n" + + if MLIP_SETTINGS.USE_ENERGIES and MLIP_SETTINGS.USE_FORCES: + save_model("best_val_mae_both_model", multiplication_mae_model_keeper) + summary += f"best both (multiplication) mae in energies {postfix}: " + summary += ( + f"{multiplication_mae_model_keeper.additional_info[0]} in forces: " + ) + summary += f"{multiplication_mae_model_keeper.additional_info[1]} " + summary += f"at epoch {multiplication_mae_model_keeper.best_epoch}\n" + + save_model("best_val_rmse_both_model", multiplication_rmse_model_keeper) + summary += f"best both (multiplication) rmse in energies {postfix}: " + summary += ( + f"{multiplication_rmse_model_keeper.additional_info[0]} in forces: " + ) + summary += ( + f"{multiplication_rmse_model_keeper.additional_info[1]} at epoch " + ) + summary += f"{multiplication_rmse_model_keeper.best_epoch}\n" + + with open(f"{checkpoint_dir}/{NAME_OF_CALCULATION}/summary.txt", "wb") as f: + f.write(summary.encode()) + logging.info(f"Total elapsed time: {time.time() - TIME_SCRIPT_STARTED}") + + ########################################## + # FINISHING THE PURE PET TRAINING SCRIPT # + ########################################## + if self.pet_checkpoint is not None: # remove the temporary file os.remove(Path(checkpoint_dir) / "checkpoint.temp") diff --git a/src/metatrain/experimental/pet/utils/__init__.py b/src/metatrain/experimental/pet/utils/__init__.py index 62a3238c4..e219ab457 100644 --- a/src/metatrain/experimental/pet/utils/__init__.py +++ b/src/metatrain/experimental/pet/utils/__init__.py @@ -1,5 +1,9 @@ from .systems_to_batch_dict import systems_to_batch_dict +from .dataset_to_ase import dataset_to_ase +from .update_hypers import update_hypers __all__ = [ "systems_to_batch_dict", + "dataset_to_ase", + "update_hypers", ] diff --git a/src/metatrain/experimental/pet/utils/dataset_to_ase.py b/src/metatrain/experimental/pet/utils/dataset_to_ase.py new file mode 100644 index 000000000..e7111a20f --- /dev/null +++ b/src/metatrain/experimental/pet/utils/dataset_to_ase.py @@ -0,0 +1,45 @@ +from metatensor.learn.data import DataLoader + +from ....utils.additive import remove_additive +from ....utils.data import collate_fn +from ....utils.data.system_to_ase import system_to_ase +from ....utils.neighbor_lists import ( + get_requested_neighbor_lists, + get_system_with_neighbor_lists, +) + + +# dummy dataloaders due to https://github.com/lab-cosmo/metatensor/issues/521 +def dataset_to_ase(dataset, model, do_forces=True, target_name="energy"): + dataloader = DataLoader( + dataset, + batch_size=1, + shuffle=False, + collate_fn=collate_fn, + ) + ase_dataset = [] + for (system,), targets in dataloader: + # remove additive model (e.g. ZBL) contributions + requested_neighbor_lists = get_requested_neighbor_lists(model) + system = get_system_with_neighbor_lists(system, requested_neighbor_lists) + for additive_model in model.additive_models: + targets = remove_additive( + [system], targets, additive_model, model.dataset_info.targets + ) + # transform to ase atoms + ase_atoms = system_to_ase(system) + ase_atoms.info["energy"] = float( + targets[target_name].block().values.squeeze(-1).detach().cpu().numpy() + ) + if do_forces: + ase_atoms.arrays["forces"] = ( + -targets[target_name] + .block() + .gradient("positions") + .values.squeeze(-1) + .detach() + .cpu() + .numpy() + ) + ase_dataset.append(ase_atoms) + return ase_dataset diff --git a/src/metatrain/experimental/pet/utils/update_hypers.py b/src/metatrain/experimental/pet/utils/update_hypers.py new file mode 100644 index 000000000..70b36fb4e --- /dev/null +++ b/src/metatrain/experimental/pet/utils/update_hypers.py @@ -0,0 +1,29 @@ +from typing import Any, Dict + + +def update_hypers( + hypers: Dict[str, Any], model_hypers: Dict[str, Any], do_forces: bool = True +): + """ + Updates the hypers dictionary with the model hypers, the + MLIP_SETTINGS and UTILITY_FLAGS keys of the PET model. + """ + + # set model hypers + hypers = hypers.copy() + hypers["ARCHITECTURAL_HYPERS"] = model_hypers + hypers["ARCHITECTURAL_HYPERS"]["DTYPE"] = "float32" + + # set MLIP_SETTINGS + hypers["MLIP_SETTINGS"] = { + "ENERGY_KEY": "energy", + "FORCES_KEY": "forces", + "USE_ENERGIES": True, + "USE_FORCES": do_forces, + } + + # set PET utility flags + hypers["UTILITY_FLAGS"] = { + "CALCULATION_TYPE": None, + } + return hypers diff --git a/src/metatrain/experimental/soap_bpnn/default-hypers.yaml b/src/metatrain/experimental/soap_bpnn/default-hypers.yaml index a86bc9ded..1c7fe1c66 100644 --- a/src/metatrain/experimental/soap_bpnn/default-hypers.yaml +++ b/src/metatrain/experimental/soap_bpnn/default-hypers.yaml @@ -15,11 +15,11 @@ model: rate: 1.0 scale: 2.0 exponent: 7.0 - bpnn: layernorm: true num_hidden_layers: 2 num_neurons_per_layer: 32 + zbl: false training: distributed: False diff --git a/src/metatrain/experimental/soap_bpnn/model.py b/src/metatrain/experimental/soap_bpnn/model.py index 4b99f9229..556f3ef52 100644 --- a/src/metatrain/experimental/soap_bpnn/model.py +++ b/src/metatrain/experimental/soap_bpnn/model.py @@ -16,7 +16,7 @@ from metatrain.utils.data.dataset import DatasetInfo -from ...utils.composition import CompositionModel +from ...utils.additive import ZBL, CompositionModel from ...utils.dtype import dtype_to_str from ...utils.export import export @@ -187,10 +187,16 @@ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo) -> None: } ) - self.composition_model = CompositionModel( + # additive models: these are handled by the trainer at training + # time, and they are added to the output at evaluation time + composition_model = CompositionModel( model_hypers={}, dataset_info=dataset_info, ) + additive_models = [composition_model] + if self.hypers["zbl"]: + additive_models.append(ZBL(model_hypers, dataset_info)) + self.additive_models = torch.nn.ModuleList(additive_models) def restart(self, dataset_info: DatasetInfo) -> "SoapBpnn": # merge old and new dataset info @@ -274,17 +280,18 @@ def forward( ) if not self.training: - # at evaluation, we also add the composition contributions - composition_contributions = self.composition_model( - systems, outputs, selected_atoms - ) - for name in return_dict: - if name.startswith("mtt::aux::"): - continue # skip auxiliary outputs (not targets) - return_dict[name] = metatensor.torch.add( - return_dict[name], - composition_contributions[name], + # at evaluation, we also add the additive contributions + for additive_model in self.additive_models: + additive_contributions = additive_model( + systems, outputs, selected_atoms ) + for name in return_dict: + if name.startswith("mtt::aux::"): + continue # skip auxiliary outputs (not targets) + return_dict[name] = metatensor.torch.add( + return_dict[name], + additive_contributions[name], + ) return return_dict @@ -309,14 +316,20 @@ def export(self) -> MetatensorAtomisticModel: raise ValueError(f"unsupported dtype {self.dtype} for SoapBpnn") # Make sure the model is all in the same dtype - # For example, at this point, the composition model within the SOAP-BPNN is - # still float64 + # For example, after training, the additive models could still be in + # float64 self.to(dtype) + interaction_ranges = [self.hypers["soap"]["cutoff"]] + for additive_model in self.additive_models: + if hasattr(additive_model, "cutoff_radius"): + interaction_ranges.append(additive_model.cutoff_radius) + interaction_range = max(interaction_ranges) + capabilities = ModelCapabilities( outputs=self.outputs, atomic_types=self.atomic_types, - interaction_range=self.hypers["soap"]["cutoff"], + interaction_range=interaction_range, length_unit=self.dataset_info.length_unit, supported_devices=self.__supported_devices__, dtype=dtype_to_str(dtype), diff --git a/src/metatrain/experimental/soap_bpnn/schema-hypers.json b/src/metatrain/experimental/soap_bpnn/schema-hypers.json index 570931d49..b2ca893b0 100644 --- a/src/metatrain/experimental/soap_bpnn/schema-hypers.json +++ b/src/metatrain/experimental/soap_bpnn/schema-hypers.json @@ -80,6 +80,9 @@ } }, "additionalProperties": false + }, + "zbl": { + "type": "boolean" } }, "additionalProperties": false diff --git a/src/metatrain/experimental/soap_bpnn/tests/test_exported.py b/src/metatrain/experimental/soap_bpnn/tests/test_exported.py index cc41a360c..63242161e 100644 --- a/src/metatrain/experimental/soap_bpnn/tests/test_exported.py +++ b/src/metatrain/experimental/soap_bpnn/tests/test_exported.py @@ -4,7 +4,10 @@ from metatrain.experimental.soap_bpnn import SoapBpnn from metatrain.utils.data import DatasetInfo, TargetInfo, TargetInfoDict -from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists +from metatrain.utils.neighbor_lists import ( + get_requested_neighbor_lists, + get_system_with_neighbor_lists, +) from . import MODEL_HYPERS @@ -31,7 +34,8 @@ def test_to(device, dtype): positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]), cell=torch.zeros(3, 3), ) - system = get_system_with_neighbor_lists(system, exported.requested_neighbor_lists()) + requested_neighbor_lists = get_requested_neighbor_lists(exported) + system = get_system_with_neighbor_lists(system, requested_neighbor_lists) system = system.to(device=device, dtype=dtype) evaluation_options = ModelEvaluationOptions( diff --git a/src/metatrain/experimental/soap_bpnn/trainer.py b/src/metatrain/experimental/soap_bpnn/trainer.py index 47cc09174..cdfc3ba81 100644 --- a/src/metatrain/experimental/soap_bpnn/trainer.py +++ b/src/metatrain/experimental/soap_bpnn/trainer.py @@ -7,7 +7,7 @@ import torch.distributed from torch.utils.data import DataLoader, DistributedSampler -from ...utils.composition import remove_composition +from ...utils.additive import remove_additive from ...utils.data import CombinedDataLoader, Dataset, TargetInfoDict, collate_fn from ...utils.data.extract_targets import get_targets_dict from ...utils.distributed.distributed_data_parallel import DistributedDataParallel @@ -18,7 +18,12 @@ from ...utils.logging import MetricLogger from ...utils.loss import TensorMapDictLoss from ...utils.metrics import RMSEAccumulator +from ...utils.neighbor_lists import ( + get_requested_neighbor_lists, + get_system_with_neighbor_lists, +) from ...utils.per_atom import average_by_num_atoms +from ...utils.transfer import systems_and_targets_to_dtype_and_device from .model import SoapBpnn @@ -85,15 +90,28 @@ def train( else: logger.info(f"Training on device {device} with dtype {dtype}") + # Calculate the neighbor lists in advance (in particular, this + # needs to happen before the additive models are trained, as they + # might need them): + logger.info("Calculating neighbor lists for the datasets") + requested_neighbor_lists = get_requested_neighbor_lists(model) + for dataset in train_datasets + val_datasets: + for i in range(len(dataset)): + system = dataset[i]["system"] + # The following line attaches the neighbors lists to the system, + # and doesn't require to reassign the system to the dataset: + _ = get_system_with_neighbor_lists(system, requested_neighbor_lists) + # Move the model to the device and dtype: model.to(device=device, dtype=dtype) - # The composition model of the SOAP-BPNN is always on CPU (to avoid OOM + # The additive models of the SOAP-BPNN are always on CPU (to avoid OOM # errors during the linear algebra training) and in float64 (to avoid # numerical errors in the composition weights, which can be very large). - model.composition_model.to(device=torch.device("cpu"), dtype=torch.float64) + for additive_model in model.additive_models: + additive_model.to(device=torch.device("cpu"), dtype=torch.float64) logger.info("Calculating composition weights") - model.composition_model.train_model( + model.additive_models[0].train_model( # this is the composition model train_datasets, self.hypers["fixed_composition_weights"] ) @@ -230,12 +248,13 @@ def train( optimizer.zero_grad() systems, targets = batch - remove_composition(systems, targets, model.composition_model) - systems = [system.to(dtype=dtype, device=device) for system in systems] - targets = { - key: value.to(dtype=dtype, device=device) - for key, value in targets.items() - } + for additive_model in model.additive_models: + targets = remove_additive( + systems, targets, additive_model, train_targets + ) + systems, targets = systems_and_targets_to_dtype_and_device( + systems, targets, dtype, device + ) predictions = evaluate_model( model, systems, @@ -270,12 +289,13 @@ def train( val_loss = 0.0 for batch in val_dataloader: systems, targets = batch - remove_composition(systems, targets, model.composition_model) - systems = [system.to(dtype=dtype, device=device) for system in systems] - targets = { - key: value.to(dtype=dtype, device=device) - for key, value in targets.items() - } + for additive_model in model.additive_models: + targets = remove_additive( + systems, targets, additive_model, train_targets + ) + systems, targets = systems_and_targets_to_dtype_and_device( + systems, targets, dtype, device + ) predictions = evaluate_model( model, systems, diff --git a/src/metatrain/utils/additive/__init__.py b/src/metatrain/utils/additive/__init__.py new file mode 100644 index 000000000..bab1a5de3 --- /dev/null +++ b/src/metatrain/utils/additive/__init__.py @@ -0,0 +1,3 @@ +from .composition import CompositionModel # noqa: F401 +from .zbl import ZBL # noqa: F401 +from .remove import remove_additive # noqa: F401 diff --git a/src/metatrain/utils/composition.py b/src/metatrain/utils/additive/composition.py similarity index 88% rename from src/metatrain/utils/composition.py rename to src/metatrain/utils/additive/composition.py index 1f96e0549..7a79c660e 100644 --- a/src/metatrain/utils/composition.py +++ b/src/metatrain/utils/additive/composition.py @@ -6,8 +6,8 @@ from metatensor.torch import Labels, TensorBlock, TensorMap from metatensor.torch.atomistic import ModelOutput, System -from .data import Dataset, DatasetInfo, get_all_targets, get_atomic_types -from .jsonschema import validate +from ..data import Dataset, DatasetInfo, get_all_targets, get_atomic_types +from ..jsonschema import validate class CompositionModel(torch.nn.Module): @@ -42,6 +42,15 @@ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo): self.dataset_info = dataset_info self.atomic_types = sorted(dataset_info.atomic_types) + self.outputs = { + key: ModelOutput( + quantity=value.quantity, + unit=value.unit, + per_atom=True, + ) + for key, value in dataset_info.targets.items() + } + n_types = len(self.atomic_types) n_targets = len(dataset_info.targets) @@ -81,14 +90,14 @@ def train_model( raise ValueError( "Provided `datasets` contains unknown " f"atomic types {additional_types}. " - f"Known types from initilaization are {self.atomic_types}." + f"Known types from initialization are {self.atomic_types}." ) missing_types = sorted(set(self.atomic_types) - set(get_atomic_types(datasets))) if missing_types: warnings.warn( f"Provided `datasets` do not contain atomic types {missing_types}. " - f"Known types from initilaization are {self.atomic_types}.", + f"Known types from initialization are {self.atomic_types}.", stacklevel=2, ) @@ -192,15 +201,14 @@ def forward( ) -> Dict[str, TensorMap]: """Compute the targets for each system based on the composition weights. - :param systems: List of systems to calculate the energy per atom. + :param systems: List of systems to calculate the energy. :param outputs: Dictionary containing the model outputs. :param selected_atoms: Optional selection of atoms for which to compute the - targets. - :returns: A dictionary with the computed targets for each system. + predictions. + :returns: A dictionary with the computed predictions for each system. :raises ValueError: If no weights have been computed or if `outputs` keys contain unsupported keys. - :raises NotImplementedError: If `selected_atoms` is provided (not implemented). """ dtype = systems[0].positions.dtype device = systems[0].positions.device @@ -263,28 +271,3 @@ def forward( ) return targets_out - - -def remove_composition( - systems: List[System], - targets: Dict[str, TensorMap], - composition_model: torch.nn.Module, -): - """Remove the composition contribution from the training targets. - - The targets are changed in place. - - :param systems: List of systems. - :param targets: Dictionary containing the targets corresponding to the systems. - :param composition_model: The composition model used to calculate the composition - contribution. - """ - output_options = {} - for target_key in targets: - output_options[target_key] = ModelOutput(per_atom=False) - - composition_targets = composition_model(systems, output_options) - for target_key in targets: - targets[target_key].block().values[:] -= ( - composition_targets[target_key].block().values - ) diff --git a/src/metatrain/utils/additive/remove.py b/src/metatrain/utils/additive/remove.py new file mode 100644 index 000000000..4235af1fc --- /dev/null +++ b/src/metatrain/utils/additive/remove.py @@ -0,0 +1,75 @@ +import warnings +from typing import Dict, List, Union + +import metatensor.torch +import torch +from metatensor.torch import TensorMap +from metatensor.torch.atomistic import System + +from ..data import TargetInfo, TargetInfoDict +from ..evaluate_model import evaluate_model + + +def remove_additive( + systems: List[System], + targets: Dict[str, TensorMap], + additive_model: torch.nn.Module, + target_info_dict: Union[Dict[str, TargetInfo], TargetInfoDict], +): + """Remove an additive contribution from the training targets. + + :param systems: List of systems. + :param targets: Dictionary containing the targets corresponding to the systems. + :param additive_model: The model used to calculate the additive + contribution to be removed. + :param targets_dict: Dictionary containing information about the targets. + """ + warnings.filterwarnings( + "ignore", + category=RuntimeWarning, + message=( + "GRADIENT WARNING: element 0 of tensors does not " + "require grad and does not have a grad_fn" + ), + ) + additive_contribution = evaluate_model( + additive_model, + systems, + TargetInfoDict(**{key: target_info_dict[key] for key in targets.keys()}), + is_training=False, # we don't need any gradients w.r.t. any parameters + ) + + for target_key in targets: + # make the samples the same so we can use metatensor.torch.subtract + # we also need to detach the values to avoid backpropagating through the + # subtraction + block = metatensor.torch.TensorBlock( + values=additive_contribution[target_key].block().values.detach(), + samples=targets[target_key].block().samples, + components=additive_contribution[target_key].block().components, + properties=additive_contribution[target_key].block().properties, + ) + for gradient_name, gradient in ( + additive_contribution[target_key].block().gradients() + ): + block.add_gradient( + gradient_name, + metatensor.torch.TensorBlock( + values=gradient.values.detach(), + samples=targets[target_key].block().gradient(gradient_name).samples, + components=gradient.components, + properties=gradient.properties, + ), + ) + additive_contribution[target_key] = TensorMap( + keys=targets[target_key].keys, + blocks=[ + block, + ], + ) + # subtract the additive contribution from the target + targets[target_key] = metatensor.torch.subtract( + targets[target_key], additive_contribution[target_key] + ) + + return targets diff --git a/src/metatrain/utils/additive/zbl.py b/src/metatrain/utils/additive/zbl.py new file mode 100644 index 000000000..edf2ec7b2 --- /dev/null +++ b/src/metatrain/utils/additive/zbl.py @@ -0,0 +1,292 @@ +import warnings +from typing import Dict, List, Optional + +import metatensor.torch +import torch +from ase.data import covalent_radii +from metatensor.torch import Labels, TensorBlock, TensorMap +from metatensor.torch.atomistic import ModelOutput, NeighborListOptions, System + +from ..data import DatasetInfo + + +class ZBL(torch.nn.Module): + """ + A simple model for short-range repulsive interactions. + + The implementation here is equivalent to its + `LAMMPS counterpart `_, where we set the + inner cutoff to 0 and the outer cutoff to the sum of the covalent radii of the + two atoms as tabulated in ASE. Covalent radii that are not available in ASE are + set to 0.2 Å (and a warning is issued). + + :param model_hypers: A dictionary of model hyperparameters. This contains the + "inner_cutoff" and "outer_cutoff" keys, which are the inner and outer cutoffs + for the ZBL potential. + :param dataset_info: An object containing information about the dataset, including + target quantities and atomic types. + """ + + def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo): + super().__init__() + + # Check capabilities + if dataset_info.length_unit != "angstrom": + raise ValueError( + "ZBL only supports angstrom units, but a " + f"{dataset_info.length_unit} unit was provided." + ) + for target in dataset_info.targets.values(): + if target.quantity != "energy": + raise ValueError( + "ZBL only supports energy-like outputs, but a " + f"{target.quantity} output was provided." + ) + if target.unit != "eV": + raise ValueError( + "ZBL only supports eV units, but a " + f"{target.unit} output was provided." + ) + + self.dataset_info = dataset_info + self.atomic_types = sorted(dataset_info.atomic_types) + + self.outputs = { + key: ModelOutput( + quantity=value.quantity, + unit=value.unit, + per_atom=True, + ) + for key, value in dataset_info.targets.items() + } + + n_types = len(self.atomic_types) + + self.output_to_output_index = { + target: i for i, target in enumerate(sorted(dataset_info.targets.keys())) + } + + self.register_buffer( + "species_to_index", + torch.full((max(self.atomic_types) + 1,), -1, dtype=torch.int), + ) + for i, t in enumerate(self.atomic_types): + self.species_to_index[t] = i + + self.register_buffer( + "covalent_radii", torch.empty((n_types,), dtype=torch.float64) + ) + for i, t in enumerate(self.atomic_types): + ase_covalent_radius = covalent_radii[t] + if ase_covalent_radius == 0.2: + # 0.2 seems to be the default value when the covalent radius + # is not known/available + warnings.warn( + f"Covalent radius for element {t} is not available in ASE. " + "Using a default value of 0.2 Å.", + stacklevel=2, + ) + self.covalent_radii[i] = ase_covalent_radius + + largest_covalent_radius = float(torch.max(self.covalent_radii)) + self.cutoff_radius = 2.0 * largest_covalent_radius + + def restart(self, dataset_info: DatasetInfo) -> "ZBL": + """Restart the model with a new dataset info. + + :param dataset_info: New dataset information to be used. + """ + return self({}, self.dataset_info.union(dataset_info)) + + def forward( + self, + systems: List[System], + outputs: Dict[str, ModelOutput], + selected_atoms: Optional[Labels] = None, + ) -> Dict[str, TensorMap]: + """Compute the energies of a system solely based on a ZBL repulsive + potential. + + :param systems: List of systems to calculate the ZBL energy. + :param outputs: Dictionary containing the model outputs. + :param selected_atoms: Optional selection of atoms for which to compute the + predictions. + :returns: A dictionary with the computed predictions for each system. + + :raises ValueError: If the `outputs` contain unsupported keys. + """ + + # Assert only one neighbor list for all systems + neighbor_lists: List[TensorBlock] = [] + for system in systems: + nl_options = self.requested_neighbor_lists()[0] + nl = system.get_neighbor_list(nl_options) + neighbor_lists.append(nl) + + # Find the elements of all i and j atoms + zi = torch.concatenate( + [ + system.types[nl.samples.column("first_atom")] + for nl, system in zip(neighbor_lists, systems) + ] + ) + zj = torch.concatenate( + [ + system.types[nl.samples.column("second_atom")] + for nl, system in zip(neighbor_lists, systems) + ] + ) + + # Find the interatomic distances + rij = torch.concatenate( + [torch.sqrt(torch.sum(nl.values**2, dim=(1, 2))) for nl in neighbor_lists] + ) + + # Find the ZBL energies + e_zbl = self.get_pairwise_zbl(zi, zj, rij) + + # Sum over edges to get node energies + indices_for_sum_list = [] + sum = 0 + for system, nl in zip(systems, neighbor_lists): + indices_for_sum_list.append(nl.samples.column("first_atom") + sum) + sum += system.positions.shape[0] + + e_zbl_nodes = torch.zeros(sum, dtype=e_zbl.dtype, device=e_zbl.device) + e_zbl_nodes.index_add_(0, torch.cat(indices_for_sum_list), e_zbl) + + device = systems[0].positions.device + + # Set the outputs as the ZBL energies + targets_out: Dict[str, TensorMap] = {} + for target_key, target in outputs.items(): + if target_key.startswith("mtt::aux::"): + continue + sample_values: List[List[int]] = [] + + for i_system, system in enumerate(systems): + sample_values += [[i_system, i_atom] for i_atom in range(len(system))] + + block = TensorBlock( + values=e_zbl_nodes.reshape(-1, 1), + samples=Labels( + ["system", "atom"], torch.tensor(sample_values, device=device) + ), + components=[], + properties=Labels( + names=["energy"], values=torch.tensor([[0]], device=device) + ), + ) + + targets_out[target_key] = TensorMap( + keys=Labels(names=["_"], values=torch.tensor([[0]], device=device)), + blocks=[block], + ) + + # apply selected_atoms to the composition if needed + if selected_atoms is not None: + targets_out[target_key] = metatensor.torch.slice( + targets_out[target_key], "samples", selected_atoms + ) + + if not target.per_atom: + targets_out[target_key] = metatensor.torch.sum_over_samples( + targets_out[target_key], sample_names="atom" + ) + + return targets_out + + def get_pairwise_zbl(self, zi, zj, rij): + """ + Ziegler-Biersack-Littmark (ZBL) potential. + + Inputs are the atomic numbers (zi, zj) of the two atoms of interest + and their distance rij. + """ + # set cutoff from covalent radii of the elements + rc = ( + self.covalent_radii[self.species_to_index[zi]] + + self.covalent_radii[self.species_to_index[zj]] + ) + + r1 = 0.0 + p = 0.23 + # angstrom + a0 = 0.46850 + c = torch.tensor( + [0.02817, 0.28022, 0.50986, 0.18175], dtype=rij.dtype, device=rij.device + ) + d = torch.tensor( + [0.20162, 0.40290, 0.94229, 3.19980], dtype=rij.dtype, device=rij.device + ) + + a = a0 / (zi**p + zj**p) + + da = d.unsqueeze(-1) / a + + # e * e / (4 * pi * epsilon_0) / electron_volt / angstrom + factor = 14.399645478425668 * zi * zj + e = _e_zbl(factor, rij, c, da) # eV.angstrom + + # switching function + ec = _e_zbl(factor, rc, c, da) + dec = _dedr(factor, rc, c, da) + d2ec = _d2edr2(factor, rc, c, da) + + # coefficients are determined such that E(rc) = 0, E'(rc) = 0, and E''(rc) = 0 + A = (-3 * dec + (rc - r1) * d2ec) / ((rc - r1) ** 2) + B = (2 * dec - (rc - r1) * d2ec) / ((rc - r1) ** 3) + C = -ec + (rc - r1) * dec / 2 - (rc - r1) * (rc - r1) * d2ec / 12 + + e += A / 3 * ((rij - r1) ** 3) + B / 4 * ((rij - r1) ** 4) + C + e = e / 2.0 # divide by 2 to fix double counting of edges + + # set all contributions past the cutoff to zero + e[rij > rc] = 0.0 + + return e + + def requested_neighbor_lists(self) -> List[NeighborListOptions]: + return [ + NeighborListOptions( + cutoff=self.cutoff_radius, + full_list=True, + ) + ] + + +def _phi(r, c, da): + phi = torch.sum(c.unsqueeze(-1) * torch.exp(-r * da), dim=0) + return phi + + +def _dphi(r, c, da): + dphi = torch.sum(-c.unsqueeze(-1) * da * torch.exp(-r * da), dim=0) + return dphi + + +def _d2phi(r, c, da): + d2phi = torch.sum(c.unsqueeze(-1) * (da**2) * torch.exp(-r * da), dim=0) + return d2phi + + +def _e_zbl(factor, r, c, da): + phi = _phi(r, c, da) + ret = factor / r * phi + return ret + + +def _dedr(factor, r, c, da): + phi = _phi(r, c, da) + dphi = _dphi(r, c, da) + ret = factor / r * (-phi / r + dphi) + return ret + + +def _d2edr2(factor, r, c, da): + phi = _phi(r, c, da) + dphi = _dphi(r, c, da) + d2phi = _d2phi(r, c, da) + + ret = factor / r * (d2phi - 2 / r * dphi + 2 * phi / (r**2)) + return ret diff --git a/src/metatrain/utils/architectures.py b/src/metatrain/utils/architectures.py index 6d3d66966..33420f99a 100644 --- a/src/metatrain/utils/architectures.py +++ b/src/metatrain/utils/architectures.py @@ -1,4 +1,5 @@ import difflib +import importlib import json import logging from importlib.util import find_spec @@ -110,6 +111,30 @@ def get_architecture_name(path: Union[str, Path]) -> str: return name +def import_architecture(name: str): + """Import an architecture. + + :param name: name of the architecture + :raises ImportError: if the architecture dependencies are not met + """ + check_architecture_name(name) + try: + return importlib.import_module(f"metatrain.{name}") + except ImportError as err: + # consistent name with pyproject.toml's `optional-dependencies` section + name_for_deps = name + if "experimental." in name or "deprecated." in name: + name_for_deps = ".".join(name.split(".")[1:]) + + name_for_deps = name_for_deps.replace("_", "-") + + raise ImportError( + f"Trying to import '{name}' but architecture dependencies " + f"seem not be installed. \n" + f"Try to install them with `pip install .[{name_for_deps}]`" + ) from err + + def get_architecture_path(name: str) -> Path: """Return the relative path to the architeture directory. diff --git a/src/metatrain/utils/neighbor_lists.py b/src/metatrain/utils/neighbor_lists.py index b76d836f6..91f9d9641 100644 --- a/src/metatrain/utils/neighbor_lists.py +++ b/src/metatrain/utils/neighbor_lists.py @@ -15,6 +15,55 @@ from .data.system_to_ase import system_to_ase +def get_requested_neighbor_lists( + module: torch.nn.Module, +) -> List[NeighborListOptions]: + """Get the neighbor lists requested by a module and its children. + + :param module: The module for which to get the requested neighbor lists. + + :return: A list of `NeighborListOptions` objects requested by the module. + """ + requested: List[NeighborListOptions] = [] + _get_requested_neighbor_lists_in_place( + module=module, + module_name="", + requested=requested, + ) + return requested + + +def _get_requested_neighbor_lists_in_place( + module: torch.nn.Module, + module_name: str, + requested: List[NeighborListOptions], +): + # copied from + # metatensor/python/metatensor-torch/metatensor/torch/atomistic/model.py + # and just removed the length units + + if hasattr(module, "requested_neighbor_lists"): + for new_options in module.requested_neighbor_lists(): + new_options.add_requestor(module_name) + + already_requested = False + for existing in requested: + if existing == new_options: + already_requested = True + for requestor in new_options.requestors(): + existing.add_requestor(requestor) + + if not already_requested: + requested.append(new_options) + + for child_name, child in module.named_children(): + _get_requested_neighbor_lists_in_place( + module=child, + module_name=module_name + "." + child_name, + requested=requested, + ) + + def get_system_with_neighbor_lists( system: System, neighbor_lists: List[NeighborListOptions] ) -> System: diff --git a/src/metatrain/utils/omegaconf.py b/src/metatrain/utils/omegaconf.py index e156bc132..3a223df5d 100644 --- a/src/metatrain/utils/omegaconf.py +++ b/src/metatrain/utils/omegaconf.py @@ -1,4 +1,3 @@ -import importlib import json from typing import Any, Union @@ -7,13 +6,13 @@ from omegaconf.basecontainer import BaseContainer from .. import PACKAGE_ROOT, RANDOM_SEED +from .architectures import import_architecture from .devices import pick_devices from .jsonschema import validate def _get_architecture_model(conf: BaseContainer) -> Any: - architecture_name = conf["architecture"]["name"] - architecture = importlib.import_module(f"metatrain.{architecture_name}") + architecture = import_architecture(conf["architecture"]["name"]) return architecture.__model__ diff --git a/src/metatrain/utils/output_gradient.py b/src/metatrain/utils/output_gradient.py index dda6888d2..d7cc3664e 100644 --- a/src/metatrain/utils/output_gradient.py +++ b/src/metatrain/utils/output_gradient.py @@ -1,3 +1,4 @@ +import warnings from typing import List, Optional import torch @@ -14,13 +15,29 @@ def compute_gradient( """ grad_outputs: Optional[List[Optional[torch.Tensor]]] = [torch.ones_like(target)] - gradient = torch.autograd.grad( - outputs=[target], - inputs=inputs, - grad_outputs=grad_outputs, - retain_graph=is_training, - create_graph=is_training, - ) + try: + gradient = torch.autograd.grad( + outputs=[target], + inputs=inputs, + grad_outputs=grad_outputs, + retain_graph=is_training, + create_graph=is_training, + ) + except RuntimeError as e: + # Torch raises an error if the target tensor does not require grad, + # but this could just mean that the target is a constant tensor, like in + # the case of composition models. In this case, we can safely ignore the error + # and we raise a warning instead. The warning can be caught and silenced in the + # appropriate places. + if ( + "element 0 of tensors does not require grad and does not have a grad_fn" + in str(e) + ): + warnings.warn(f"GRADIENT WARNING: {e}", RuntimeWarning, stacklevel=2) + gradient = [torch.zeros_like(i) for i in inputs] + else: + # Re-raise the error if it's not the one above + raise if gradient is None: raise ValueError( "Unexpected None value for computed gradient. " diff --git a/src/metatrain/utils/transfer.py b/src/metatrain/utils/transfer.py new file mode 100644 index 000000000..5aae69296 --- /dev/null +++ b/src/metatrain/utils/transfer.py @@ -0,0 +1,28 @@ +from typing import Dict, List + +import torch +from metatensor.torch import TensorMap +from metatensor.torch.atomistic import System + + +@torch.jit.script +def systems_and_targets_to_dtype_and_device( + systems: List[System], + targets: Dict[str, TensorMap], + dtype: torch.dtype, + device: torch.device, +): + """ + Transfers the systems and targets to the specified dtype and device. + + :param systems: List of systems. + :param targets: Dictionary of targets. + :param dtype: Desired data type. + :param device: Device to transfer to. + """ + + systems = [system.to(dtype=dtype, device=device) for system in systems] + targets = { + key: value.to(dtype=dtype, device=device) for key, value in targets.items() + } + return systems, targets diff --git a/tests/cli/test_eval_model.py b/tests/cli/test_eval_model.py index 7aaa2c2b2..7eda92af0 100644 --- a/tests/cli/test_eval_model.py +++ b/tests/cli/test_eval_model.py @@ -68,6 +68,8 @@ def test_eval(monkeypatch, tmp_path, caplog, model_name, options): log = "".join([rec.message for rec in caplog.records]) assert "energy RMSE (per atom)" in log assert "dataset with index" not in log + assert "evaluation time" in log + assert "ms per atom" in log # Test file is written predictions frames = ase.io.read("foo.xyz", ":") diff --git a/tests/utils/test_composition.py b/tests/utils/test_additive.py similarity index 82% rename from tests/utils/test_composition.py rename to tests/utils/test_additive.py index 780744664..fd2179e5e 100644 --- a/tests/utils/test_composition.py +++ b/tests/utils/test_additive.py @@ -7,9 +7,13 @@ from metatensor.torch.atomistic import ModelOutput, System from omegaconf import OmegaConf -from metatrain.utils.composition import CompositionModel, remove_composition +from metatrain.utils.additive import ZBL, CompositionModel, remove_additive from metatrain.utils.data import Dataset, DatasetInfo, TargetInfo, TargetInfoDict from metatrain.utils.data.readers import read_systems, read_targets +from metatrain.utils.neighbor_lists import ( + get_requested_neighbor_lists, + get_system_with_neighbor_lists, +) RESOURCES_PATH = Path(__file__).parents[1] / "resources" @@ -224,8 +228,8 @@ def test_composition_model_torchscript(tmpdir): ) -def test_remove_composition(): - """Tests the remove_composition function.""" +def test_remove_additive(): + """Tests the remove_additive function.""" dataset_path = RESOURCES_PATH / "qm9_reduced_100.xyz" systems = read_systems(dataset_path) @@ -260,7 +264,7 @@ def test_remove_composition(): targets["mtt::U0"] = metatensor.torch.join(targets["mtt::U0"], axis="samples") std_before = targets["mtt::U0"].block().values.std().item() - remove_composition(systems, targets, composition_model) + remove_additive(systems, targets, composition_model, target_info) std_after = targets["mtt::U0"].block().values.std().item() # In QM9 the composition contribution is very large: the standard deviation @@ -393,3 +397,81 @@ def test_composition_model_wrong_target(): ), ), ) + + +def test_zbl(): + """Test the ZBL model.""" + + dataset_path = RESOURCES_PATH / "qm9_reduced_100.xyz" + + systems = read_systems(dataset_path)[:5] + + conf = { + "mtt::U0": { + "quantity": "energy", + "read_from": dataset_path, + "file_format": ".xyz", + "reader": "ase", + "key": "U0", + "unit": "eV", + "forces": False, + "stress": False, + "virial": False, + } + } + _, target_info = read_targets(OmegaConf.create(conf)) + + zbl = ZBL( + model_hypers={}, + dataset_info=DatasetInfo( + length_unit="angstrom", + atomic_types=[1, 6, 7, 8], + targets=target_info, + ), + ) + + requested_neighbor_lists = get_requested_neighbor_lists(zbl) + for system in systems: + get_system_with_neighbor_lists(system, requested_neighbor_lists) + + # per_atom = True + output = zbl( + systems, + {"mtt::U0": ModelOutput(quantity="energy", unit="", per_atom=True)}, + ) + assert "mtt::U0" in output + assert output["mtt::U0"].block().samples.names == ["system", "atom"] + assert output["mtt::U0"].block().values.shape != (5, 1) + + # with selected_atoms + selected_atoms = metatensor.torch.Labels( + names=["system", "atom"], + values=torch.tensor([[0, 0]]), + ) + + output = zbl( + systems, + {"mtt::U0": ModelOutput(quantity="energy", unit="", per_atom=True)}, + selected_atoms=selected_atoms, + ) + assert "mtt::U0" in output + assert output["mtt::U0"].block().samples.names == ["system", "atom"] + assert output["mtt::U0"].block().values.shape == (1, 1) + + # per_atom = False + output = zbl( + systems, + {"mtt::U0": ModelOutput(quantity="energy", unit="", per_atom=False)}, + ) + assert "mtt::U0" in output + assert output["mtt::U0"].block().samples.names == ["system"] + assert output["mtt::U0"].block().values.shape == (5, 1) + + # check that the result is the same without batching + expected = output["mtt::U0"].block().values[3] + system = systems[3] + output = zbl( + [system], + {"mtt::U0": ModelOutput(quantity="energy", unit="", per_atom=False)}, + ) + assert torch.allclose(output["mtt::U0"].block().values[0], expected) diff --git a/tests/utils/test_architectures.py b/tests/utils/test_architectures.py index a83db97d4..3bb788392 100644 --- a/tests/utils/test_architectures.py +++ b/tests/utils/test_architectures.py @@ -1,3 +1,4 @@ +import importlib from pathlib import Path import pytest @@ -11,9 +12,14 @@ get_architecture_name, get_architecture_path, get_default_hypers, + import_architecture, ) +def is_None(*args, **kwargs) -> None: + return None + + def test_find_all_architectures(): all_arches = find_all_architectures() assert len(all_arches) == 4 @@ -116,3 +122,27 @@ def test_check_architecture_options_error_raise(): match = r"Unrecognized options \('num_epochxxx' was unexpected\)" with pytest.raises(ValidationError, match=match): check_architecture_options(name=name, options=options) + + +def test_import_architecture(): + name = "experimental.soap_bpnn" + architecture_ref = importlib.import_module(f"metatrain.{name}") + assert import_architecture(name) == architecture_ref + + +def test_import_architecture_erro(monkeypatch): + # `check_architecture_name` is called inside `import_architecture` and we have to + # disble the check to allow passing our "unknown" fancy-model below. + monkeypatch.setattr( + "metatrain.utils.architectures.check_architecture_name", is_None + ) + + name = "experimental.fancy_model" + name_for_deps = "fancy-model" + + match = ( + rf"Trying to import '{name}' but architecture dependencies seem not be " + rf"installed. \nTry to install them with `pip install .\[{name_for_deps}\]`" + ) + with pytest.raises(ImportError, match=match): + import_architecture(name) diff --git a/tests/utils/test_evaluate_model.py b/tests/utils/test_evaluate_model.py index 72826cd7a..e2bd81eca 100644 --- a/tests/utils/test_evaluate_model.py +++ b/tests/utils/test_evaluate_model.py @@ -6,7 +6,10 @@ from metatrain.utils.data import DatasetInfo, TargetInfo, read_systems from metatrain.utils.evaluate_model import evaluate_model from metatrain.utils.export import export -from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists +from metatrain.utils.neighbor_lists import ( + get_requested_neighbor_lists, + get_system_with_neighbor_lists, +) from . import MODEL_HYPERS, RESOURCES_PATH @@ -45,8 +48,9 @@ def test_evaluate_model(training, exported): ) model = export(model, capabilities) + requested_neighbor_lists = get_requested_neighbor_lists(model) systems = [ - get_system_with_neighbor_lists(system, model.requested_neighbor_lists()) + get_system_with_neighbor_lists(system, requested_neighbor_lists) for system in systems ] diff --git a/tests/utils/test_llpr.py b/tests/utils/test_llpr.py index f340b2ba0..4c97c79b1 100644 --- a/tests/utils/test_llpr.py +++ b/tests/utils/test_llpr.py @@ -10,7 +10,10 @@ from metatrain.utils.data import Dataset, collate_fn, read_systems, read_targets from metatrain.utils.llpr import LLPRUncertaintyModel from metatrain.utils.loss import TensorMapDictLoss -from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists +from metatrain.utils.neighbor_lists import ( + get_requested_neighbor_lists, + get_system_with_neighbor_lists, +) from . import RESOURCES_PATH @@ -38,7 +41,7 @@ def test_llpr(tmpdir): }, } targets, _ = read_targets(target_config) - requested_neighbor_lists = model.requested_neighbor_lists() + requested_neighbor_lists = get_requested_neighbor_lists(model) qm9_systems = [ get_system_with_neighbor_lists(system, requested_neighbor_lists) for system in qm9_systems diff --git a/tests/utils/test_transfer.py b/tests/utils/test_transfer.py new file mode 100644 index 000000000..cceb3bf23 --- /dev/null +++ b/tests/utils/test_transfer.py @@ -0,0 +1,39 @@ +import metatensor.torch +import torch +from metatensor.torch import Labels, TensorMap +from metatensor.torch.atomistic import System + +from metatrain.utils.transfer import systems_and_targets_to_dtype_and_device + + +def test_systems_and_targets_to_dtype_and_device(): + system = System( + positions=torch.tensor([[1.0, 1.0, 1.0]]), + cell=torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]), + types=torch.tensor([1]), + ) + targets = TensorMap( + keys=Labels.single(), + blocks=[metatensor.torch.block_from_array(torch.tensor([[1.0]]))], + ) + + systems = [system] + targets = {"energy": targets} + + assert systems[0].positions.dtype == torch.float32 + assert systems[0].positions.device == torch.device("cpu") + assert systems[0].cell.dtype == torch.float32 + assert systems[0].types.device == torch.device("cpu") + assert targets["energy"].block().values.dtype == torch.float32 + assert targets["energy"].block().values.device == torch.device("cpu") + + systems, targets = systems_and_targets_to_dtype_and_device( + systems, targets, torch.float64, torch.device("meta") + ) + + assert systems[0].positions.dtype == torch.float64 + assert systems[0].positions.device == torch.device("meta") + assert systems[0].cell.dtype == torch.float64 + assert systems[0].types.device == torch.device("meta") + assert targets["energy"].block().values.dtype == torch.float64 + assert targets["energy"].block().values.device == torch.device("meta") diff --git a/tox.ini b/tox.ini index a38fe4949..a8ef4bdd3 100644 --- a/tox.ini +++ b/tox.ini @@ -143,6 +143,7 @@ commands_pre = bash -c "set -e && cd {toxinidir}/examples/basic_usage && bash usage.sh" bash -c "set -e && cd {toxinidir}/examples/ase && bash train.sh" bash -c "set -e && cd {toxinidir}/examples/programmatic/llpr && bash train.sh" + bash -c "set -e && cd {toxinidir}/examples/zbl && bash train.sh" sphinx-build \ {posargs:-E} \ --builder html \