diff --git a/src/metatrain/cli/eval.py b/src/metatrain/cli/eval.py index 84e4ce2eb..f2cadaba0 100644 --- a/src/metatrain/cli/eval.py +++ b/src/metatrain/cli/eval.py @@ -269,10 +269,10 @@ def eval_model( # TODO: allow the user to specify which outputs to evaluate eval_targets = {} eval_info_dict = TargetInfoDict() - gradients = {"positions"} + gradients = ["positions"] if all(not torch.all(system.cell == 0) for system in eval_systems): # only add strain if all structures have cells - gradients.add("strain") + gradients.append("strain") for key in model.capabilities().outputs.keys(): eval_info_dict[key] = TargetInfo( quantity=model.capabilities().outputs[key].quantity, diff --git a/src/metatrain/experimental/alchemical_model/model.py b/src/metatrain/experimental/alchemical_model/model.py index 7167ee868..21b622a27 100644 --- a/src/metatrain/experimental/alchemical_model/model.py +++ b/src/metatrain/experimental/alchemical_model/model.py @@ -27,7 +27,7 @@ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo) -> None: super().__init__() self.hypers = model_hypers self.dataset_info = dataset_info - self.atomic_types = sorted(dataset_info.atomic_types) + self.atomic_types = dataset_info.atomic_types if len(dataset_info.targets) != 1: raise ValueError("The AlchemicalModel only supports a single target") diff --git a/src/metatrain/experimental/alchemical_model/tests/test_exported.py b/src/metatrain/experimental/alchemical_model/tests/test_exported.py index 434cdac12..3be002445 100644 --- a/src/metatrain/experimental/alchemical_model/tests/test_exported.py +++ b/src/metatrain/experimental/alchemical_model/tests/test_exported.py @@ -18,7 +18,7 @@ def test_to(device, dtype): dataset_info = DatasetInfo( length_unit="Angstrom", - atomic_types={1, 6, 7, 8}, + atomic_types=[1, 6, 7, 8], targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")), ) model = AlchemicalModel(MODEL_HYPERS, dataset_info).to(dtype=dtype) diff --git a/src/metatrain/experimental/alchemical_model/tests/test_functionality.py b/src/metatrain/experimental/alchemical_model/tests/test_functionality.py index e5f2f34ee..b3c42d81f 100644 --- a/src/metatrain/experimental/alchemical_model/tests/test_functionality.py +++ b/src/metatrain/experimental/alchemical_model/tests/test_functionality.py @@ -14,7 +14,7 @@ def test_prediction_subset_elements(): dataset_info = DatasetInfo( length_unit="Angstrom", - atomic_types={1, 6, 7, 8}, + atomic_types=[1, 6, 7, 8], targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")), ) diff --git a/src/metatrain/experimental/alchemical_model/tests/test_invariance.py b/src/metatrain/experimental/alchemical_model/tests/test_invariance.py index 178c69327..f64925848 100644 --- a/src/metatrain/experimental/alchemical_model/tests/test_invariance.py +++ b/src/metatrain/experimental/alchemical_model/tests/test_invariance.py @@ -16,7 +16,7 @@ def test_rotational_invariance(): dataset_info = DatasetInfo( length_unit="Angstrom", - atomic_types={1, 6, 7, 8}, + atomic_types=[1, 6, 7, 8], targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")), ) model = AlchemicalModel(MODEL_HYPERS, dataset_info) diff --git a/src/metatrain/experimental/alchemical_model/tests/test_regression.py b/src/metatrain/experimental/alchemical_model/tests/test_regression.py index fe61729cc..fe8dec96a 100644 --- a/src/metatrain/experimental/alchemical_model/tests/test_regression.py +++ b/src/metatrain/experimental/alchemical_model/tests/test_regression.py @@ -32,7 +32,7 @@ def test_regression_init(): targets["mtt::U0"] = TargetInfo(quantity="energy", unit="eV") dataset_info = DatasetInfo( - length_unit="Angstrom", atomic_types={1, 6, 7, 8}, targets=targets + length_unit="Angstrom", atomic_types=[1, 6, 7, 8], targets=targets ) model = AlchemicalModel(MODEL_HYPERS, dataset_info) @@ -97,7 +97,7 @@ def test_regression_train(): hypers = DEFAULT_HYPERS.copy() dataset_info = DatasetInfo( - length_unit="Angstrom", atomic_types={1, 6, 7, 8}, targets=target_info_dict + length_unit="Angstrom", atomic_types=[1, 6, 7, 8], targets=target_info_dict ) model = AlchemicalModel(MODEL_HYPERS, dataset_info) diff --git a/src/metatrain/experimental/alchemical_model/tests/test_torch_alchemical_compatibility.py b/src/metatrain/experimental/alchemical_model/tests/test_torch_alchemical_compatibility.py index 570671ddd..c6c186fce 100644 --- a/src/metatrain/experimental/alchemical_model/tests/test_torch_alchemical_compatibility.py +++ b/src/metatrain/experimental/alchemical_model/tests/test_torch_alchemical_compatibility.py @@ -69,7 +69,7 @@ def test_alchemical_model_inference(): dataset_info = DatasetInfo( length_unit="Angstrom", - atomic_types=set(unique_numbers), + atomic_types=unique_numbers, targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")), ) diff --git a/src/metatrain/experimental/alchemical_model/tests/test_torchscript.py b/src/metatrain/experimental/alchemical_model/tests/test_torchscript.py index e0a8904e2..33e0b9e9f 100644 --- a/src/metatrain/experimental/alchemical_model/tests/test_torchscript.py +++ b/src/metatrain/experimental/alchemical_model/tests/test_torchscript.py @@ -11,7 +11,7 @@ def test_torchscript(): dataset_info = DatasetInfo( length_unit="Angstrom", - atomic_types={1, 6, 7, 8}, + atomic_types=[1, 6, 7, 8], targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")), ) @@ -24,7 +24,7 @@ def test_torchscript_save_load(): dataset_info = DatasetInfo( length_unit="Angstrom", - atomic_types={1, 6, 7, 8}, + atomic_types=[1, 6, 7, 8], targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")), ) model = AlchemicalModel(MODEL_HYPERS, dataset_info) diff --git a/src/metatrain/experimental/gap/model.py b/src/metatrain/experimental/gap/model.py index 180d2474c..97a24186c 100644 --- a/src/metatrain/experimental/gap/model.py +++ b/src/metatrain/experimental/gap/model.py @@ -63,7 +63,7 @@ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo) -> None: for key, value in dataset_info.targets.items() } - self.atomic_types = sorted(dataset_info.atomic_types) + self.atomic_types = dataset_info.atomic_types self.hypers = model_hypers # creates a composition weight tensor that can be directly indexed by species, diff --git a/src/metatrain/experimental/gap/tests/test_errors.py b/src/metatrain/experimental/gap/tests/test_errors.py index 884674ba2..7cc1958fb 100644 --- a/src/metatrain/experimental/gap/tests/test_errors.py +++ b/src/metatrain/experimental/gap/tests/test_errors.py @@ -64,7 +64,7 @@ def test_ethanol_regression_train_and_invariance(): ) dataset_info = DatasetInfo( - length_unit="Angstrom", atomic_types={1, 6, 7, 8}, targets=target_info_dict + length_unit="Angstrom", atomic_types=[1, 6, 7, 8], targets=target_info_dict ) gap = GAP(hypers["model"], dataset_info) diff --git a/src/metatrain/experimental/gap/tests/test_regression.py b/src/metatrain/experimental/gap/tests/test_regression.py index 2b0800ab6..e4d2dda1a 100644 --- a/src/metatrain/experimental/gap/tests/test_regression.py +++ b/src/metatrain/experimental/gap/tests/test_regression.py @@ -29,7 +29,7 @@ def test_regression_init(): targets["mtt::U0"] = TargetInfo(quantity="energy", unit="eV") dataset_info = DatasetInfo( - length_unit="Angstrom", atomic_types={1, 6, 7, 8}, targets=targets + length_unit="Angstrom", atomic_types=[1, 6, 7, 8], targets=targets ) GAP(DEFAULT_HYPERS["model"], dataset_info) @@ -61,7 +61,7 @@ def test_regression_train_and_invariance(): target_info_dict["mtt::U0"] = TargetInfo(quantity="energy", unit="eV") dataset_info = DatasetInfo( - length_unit="Angstrom", atomic_types={1, 6, 7, 8}, targets=target_info_dict + length_unit="Angstrom", atomic_types=[1, 6, 7, 8], targets=target_info_dict ) gap = GAP(DEFAULT_HYPERS["model"], dataset_info) @@ -142,7 +142,7 @@ def test_ethanol_regression_train_and_invariance(): ) dataset_info = DatasetInfo( - length_unit="Angstrom", atomic_types={1, 6, 7, 8}, targets=target_info_dict + length_unit="Angstrom", atomic_types=[1, 6, 7, 8], targets=target_info_dict ) gap = GAP(hypers["model"], dataset_info) diff --git a/src/metatrain/experimental/gap/tests/test_torchscript.py b/src/metatrain/experimental/gap/tests/test_torchscript.py index 4b8e1af4a..f0680fd40 100644 --- a/src/metatrain/experimental/gap/tests/test_torchscript.py +++ b/src/metatrain/experimental/gap/tests/test_torchscript.py @@ -17,7 +17,7 @@ def test_torchscript(): target_info_dict["mtt::U0"] = TargetInfo(quantity="energy", unit="eV") dataset_info = DatasetInfo( - length_unit="Angstrom", atomic_types={1, 6, 7, 8}, targets=target_info_dict + length_unit="Angstrom", atomic_types=[1, 6, 7, 8], targets=target_info_dict ) conf = { "mtt::U0": { @@ -68,7 +68,7 @@ def test_torchscript_save(): targets["mtt::U0"] = TargetInfo(quantity="energy", unit="eV") dataset_info = DatasetInfo( - length_unit="Angstrom", atomic_types={1, 6, 7, 8}, targets=targets + length_unit="Angstrom", atomic_types=[1, 6, 7, 8], targets=targets ) gap = GAP(DEFAULT_HYPERS["model"], dataset_info) torch.jit.save( diff --git a/src/metatrain/experimental/pet/model.py b/src/metatrain/experimental/pet/model.py index b6192b9ee..0ff1ad0e3 100644 --- a/src/metatrain/experimental/pet/model.py +++ b/src/metatrain/experimental/pet/model.py @@ -43,7 +43,7 @@ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo) -> None: model_hypers["TARGET_AGGREGATION"] = "sum" self.hypers = model_hypers self.cutoff = self.hypers["R_CUT"] - self.atomic_types: List[int] = sorted(dataset_info.atomic_types) + self.atomic_types: List[int] = dataset_info.atomic_types self.dataset_info = dataset_info self.pet = None self.checkpoint_path: Optional[str] = None diff --git a/src/metatrain/experimental/pet/tests/test_exported.py b/src/metatrain/experimental/pet/tests/test_exported.py index 7230f66d5..a72eb88dd 100644 --- a/src/metatrain/experimental/pet/tests/test_exported.py +++ b/src/metatrain/experimental/pet/tests/test_exported.py @@ -28,7 +28,7 @@ def test_to(device): dtype = torch.float32 # for now dataset_info = DatasetInfo( length_unit="Angstrom", - atomic_types={1, 6, 7, 8}, + atomic_types=[1, 6, 7, 8], targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")), ) model = WrappedPET(DEFAULT_HYPERS["model"], dataset_info) diff --git a/src/metatrain/experimental/pet/tests/test_functionality.py b/src/metatrain/experimental/pet/tests/test_functionality.py index 5b1463bea..74a47b075 100644 --- a/src/metatrain/experimental/pet/tests/test_functionality.py +++ b/src/metatrain/experimental/pet/tests/test_functionality.py @@ -61,7 +61,7 @@ def test_prediction(): dataset_info = DatasetInfo( length_unit="Angstrom", - atomic_types={1, 6, 7, 8}, + atomic_types=[1, 6, 7, 8], targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")), ) model = WrappedPET(DEFAULT_HYPERS["model"], dataset_info) @@ -110,7 +110,7 @@ def test_per_atom_predictions_functionality(): dataset_info = DatasetInfo( length_unit="Angstrom", - atomic_types={1, 6, 7, 8}, + atomic_types=[1, 6, 7, 8], targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")), ) model = WrappedPET(DEFAULT_HYPERS["model"], dataset_info) @@ -160,7 +160,7 @@ def test_selected_atoms_functionality(): dataset_info = DatasetInfo( length_unit="Angstrom", - atomic_types={1, 6, 7, 8}, + atomic_types=[1, 6, 7, 8], targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")), ) model = WrappedPET(DEFAULT_HYPERS["model"], dataset_info) diff --git a/src/metatrain/experimental/pet/tests/test_pet_compatibility.py b/src/metatrain/experimental/pet/tests/test_pet_compatibility.py index 38d3a563d..16f7955a9 100644 --- a/src/metatrain/experimental/pet/tests/test_pet_compatibility.py +++ b/src/metatrain/experimental/pet/tests/test_pet_compatibility.py @@ -91,16 +91,15 @@ def test_predictions_compatibility(cutoff): are consistent with the predictions of the original PET implementation.""" structure = ase.io.read(DATASET_PATH) - atomic_types = set(structure.numbers) dataset_info = DatasetInfo( length_unit="Angstrom", - atomic_types=atomic_types, + atomic_types=structure.numbers, targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")), ) capabilities = ModelCapabilities( length_unit="Angstrom", - atomic_types=sorted(atomic_types), + atomic_types=dataset_info.atomic_types, outputs={ "energy": ModelOutput( quantity="energy", @@ -116,7 +115,7 @@ def test_predictions_compatibility(cutoff): hypers["R_CUT"] = cutoff model = WrappedPET(DEFAULT_HYPERS["model"], dataset_info) ARCHITECTURAL_HYPERS = Hypers(model.hypers) - raw_pet = PET(ARCHITECTURAL_HYPERS, 0.0, len(model.atomic_types)) + raw_pet = PET(ARCHITECTURAL_HYPERS, 0.0, len(dataset_info.atomic_types)) model.set_trained_model(raw_pet) system = systems_to_torch(structure) @@ -142,7 +141,7 @@ def test_predictions_compatibility(cutoff): ARCHITECTURAL_HYPERS = Hypers(DEFAULT_HYPERS["model"]) batch = get_pyg_graphs( [structure], - sorted(atomic_types), + dataset_info.atomic_types, cutoff, ARCHITECTURAL_HYPERS.USE_ADDITIONAL_SCALAR_ATTRIBUTES, ARCHITECTURAL_HYPERS.USE_LONG_RANGE, diff --git a/src/metatrain/experimental/pet/tests/test_torchscript.py b/src/metatrain/experimental/pet/tests/test_torchscript.py index c2c9da9e0..df0584cd3 100644 --- a/src/metatrain/experimental/pet/tests/test_torchscript.py +++ b/src/metatrain/experimental/pet/tests/test_torchscript.py @@ -15,7 +15,7 @@ def test_torchscript(): dataset_info = DatasetInfo( length_unit="Angstrom", - atomic_types={1, 6, 7, 8}, + atomic_types=[1, 6, 7, 8], targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")), ) model = WrappedPET(DEFAULT_HYPERS["model"], dataset_info) @@ -30,7 +30,7 @@ def test_torchscript_save_load(): dataset_info = DatasetInfo( length_unit="Angstrom", - atomic_types={1, 6, 7, 8}, + atomic_types=[1, 6, 7, 8], targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")), ) model = WrappedPET(DEFAULT_HYPERS["model"], dataset_info) diff --git a/src/metatrain/experimental/soap_bpnn/model.py b/src/metatrain/experimental/soap_bpnn/model.py index abb08f7a8..54980c0b6 100644 --- a/src/metatrain/experimental/soap_bpnn/model.py +++ b/src/metatrain/experimental/soap_bpnn/model.py @@ -103,7 +103,7 @@ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo) -> None: self.hypers = model_hypers self.dataset_info = dataset_info self.new_outputs = list(dataset_info.targets.keys()) - self.atomic_types = sorted(dataset_info.atomic_types) + self.atomic_types = dataset_info.atomic_types self.soap_calculator = rascaline.torch.SoapPowerSpectrum( radial_basis={"Gto": {}}, **self.hypers["soap"] @@ -198,7 +198,9 @@ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo) -> None: def restart(self, dataset_info: DatasetInfo) -> "SoapBpnn": # merge old and new dataset info merged_info = self.dataset_info.union(dataset_info) - new_atomic_types = merged_info.atomic_types - self.dataset_info.atomic_types + new_atomic_types = [ + at for at in merged_info.atomic_types if at not in self.atomic_types + ] new_targets = merged_info.targets - self.dataset_info.targets if len(new_atomic_types) > 0: @@ -212,7 +214,7 @@ def restart(self, dataset_info: DatasetInfo) -> "SoapBpnn": self.add_output(output_name) self.dataset_info = merged_info - self.atomic_types = sorted(self.dataset_info.atomic_types) + self.atomic_types = sorted(self.atomic_types) for target_name, target in new_targets.items(): self.outputs[target_name] = ModelOutput( diff --git a/src/metatrain/experimental/soap_bpnn/tests/test_continue.py b/src/metatrain/experimental/soap_bpnn/tests/test_continue.py index b4b0a429c..9bd9b0e60 100644 --- a/src/metatrain/experimental/soap_bpnn/tests/test_continue.py +++ b/src/metatrain/experimental/soap_bpnn/tests/test_continue.py @@ -26,7 +26,7 @@ def test_continue(monkeypatch, tmp_path): target_info_dict["mtt::U0"] = TargetInfo(quantity="energy", unit="eV") dataset_info = DatasetInfo( - length_unit="Angstrom", atomic_types={1, 6, 7, 8}, targets=target_info_dict + length_unit="Angstrom", atomic_types=[1, 6, 7, 8], targets=target_info_dict ) model = SoapBpnn(MODEL_HYPERS, dataset_info) output_before = model(systems[:5], {"mtt::U0": model.outputs["mtt::U0"]}) diff --git a/src/metatrain/experimental/soap_bpnn/tests/test_exported.py b/src/metatrain/experimental/soap_bpnn/tests/test_exported.py index 3df1542de..cc41a360c 100644 --- a/src/metatrain/experimental/soap_bpnn/tests/test_exported.py +++ b/src/metatrain/experimental/soap_bpnn/tests/test_exported.py @@ -18,7 +18,7 @@ def test_to(device, dtype): dataset_info = DatasetInfo( length_unit="Angstrom", - atomic_types={1, 6, 7, 8}, + atomic_types=[1, 6, 7, 8], targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")), ) model = SoapBpnn(MODEL_HYPERS, dataset_info).to(dtype=dtype) diff --git a/src/metatrain/experimental/soap_bpnn/tests/test_functionality.py b/src/metatrain/experimental/soap_bpnn/tests/test_functionality.py index 2ee592caf..45dae1a8d 100644 --- a/src/metatrain/experimental/soap_bpnn/tests/test_functionality.py +++ b/src/metatrain/experimental/soap_bpnn/tests/test_functionality.py @@ -14,7 +14,7 @@ def test_prediction_subset_elements(): dataset_info = DatasetInfo( length_unit="Angstrom", - atomic_types={1, 6, 7, 8}, + atomic_types=[1, 6, 7, 8], targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")), ) @@ -37,7 +37,7 @@ def test_prediction_subset_atoms(): dataset_info = DatasetInfo( length_unit="Angstrom", - atomic_types={1, 6, 7, 8}, + atomic_types=[1, 6, 7, 8], targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")), ) @@ -103,7 +103,7 @@ def test_output_last_layer_features(): """Tests that the model can output its last layer features.""" dataset_info = DatasetInfo( length_unit="Angstrom", - atomic_types={1, 6, 7, 8}, + atomic_types=[1, 6, 7, 8], targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")), ) @@ -174,7 +174,7 @@ def test_output_per_atom(): """Tests that the model can output per-atom quantities.""" dataset_info = DatasetInfo( length_unit="Angstrom", - atomic_types={1, 6, 7, 8}, + atomic_types=[1, 6, 7, 8], targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")), ) diff --git a/src/metatrain/experimental/soap_bpnn/tests/test_invariance.py b/src/metatrain/experimental/soap_bpnn/tests/test_invariance.py index 3128baf8c..2b5835b74 100644 --- a/src/metatrain/experimental/soap_bpnn/tests/test_invariance.py +++ b/src/metatrain/experimental/soap_bpnn/tests/test_invariance.py @@ -15,7 +15,7 @@ def test_rotational_invariance(): dataset_info = DatasetInfo( length_unit="Angstrom", - atomic_types={1, 6, 7, 8}, + atomic_types=[1, 6, 7, 8], targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")), ) model = SoapBpnn(MODEL_HYPERS, dataset_info) diff --git a/src/metatrain/experimental/soap_bpnn/tests/test_regression.py b/src/metatrain/experimental/soap_bpnn/tests/test_regression.py index 44bb84e58..07e3871ae 100644 --- a/src/metatrain/experimental/soap_bpnn/tests/test_regression.py +++ b/src/metatrain/experimental/soap_bpnn/tests/test_regression.py @@ -25,7 +25,7 @@ def test_regression_init(): targets["mtt::U0"] = TargetInfo(quantity="energy", unit="eV") dataset_info = DatasetInfo( - length_unit="Angstrom", atomic_types={1, 6, 7, 8}, targets=targets + length_unit="Angstrom", atomic_types=[1, 6, 7, 8], targets=targets ) model = SoapBpnn(MODEL_HYPERS, dataset_info) @@ -76,7 +76,7 @@ def test_regression_train(): hypers["training"]["num_epochs"] = 2 dataset_info = DatasetInfo( - length_unit="Angstrom", atomic_types={1, 6, 7, 8}, targets=target_info_dict + length_unit="Angstrom", atomic_types=[1, 6, 7, 8], targets=target_info_dict ) model = SoapBpnn(MODEL_HYPERS, dataset_info) diff --git a/src/metatrain/experimental/soap_bpnn/tests/test_torchscript.py b/src/metatrain/experimental/soap_bpnn/tests/test_torchscript.py index 73631b1a4..4d6b89898 100644 --- a/src/metatrain/experimental/soap_bpnn/tests/test_torchscript.py +++ b/src/metatrain/experimental/soap_bpnn/tests/test_torchscript.py @@ -14,7 +14,7 @@ def test_torchscript(): dataset_info = DatasetInfo( length_unit="Angstrom", - atomic_types={1, 6, 7, 8}, + atomic_types=[1, 6, 7, 8], targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")), ) model = SoapBpnn(MODEL_HYPERS, dataset_info) @@ -38,7 +38,7 @@ def test_torchscript_with_identity(): dataset_info = DatasetInfo( length_unit="Angstrom", - atomic_types={1, 6, 7, 8}, + atomic_types=[1, 6, 7, 8], targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")), ) hypers = copy.deepcopy(MODEL_HYPERS) @@ -64,7 +64,7 @@ def test_torchscript_save_load(): dataset_info = DatasetInfo( length_unit="Angstrom", - atomic_types={1, 6, 7, 8}, + atomic_types=[1, 6, 7, 8], targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")), ) model = SoapBpnn(MODEL_HYPERS, dataset_info) diff --git a/src/metatrain/utils/data/dataset.py b/src/metatrain/utils/data/dataset.py index 0aafc7891..189bbe42d 100644 --- a/src/metatrain/utils/data/dataset.py +++ b/src/metatrain/utils/data/dataset.py @@ -2,8 +2,7 @@ import math import warnings from collections import UserDict -from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional, Set, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import metatensor.learn import torch @@ -14,7 +13,6 @@ from ..units import get_gradient_units -@dataclass class TargetInfo: """A class that contains information about a target. @@ -22,21 +20,50 @@ class TargetInfo: :param unit: The unit of the target. If :py:obj:`None` the ``unit`` will be set to an empty string ``""``. :param per_atom: Whether the target is a per-atom quantity. - :param gradients: Set of gradients of the target that are defined in the current - dataset. Examples are ``"positions"`` or ``"strain"``. + :param gradients: List containing the gradient names of the target that are present + in the target. Examples are ``"positions"`` or ``"strain"``. ``gradients`` will + be stored as a sorted list of **unique** gradients. """ - quantity: str - unit: str = "" - per_atom: bool = False - gradients: Set[str] = field(default_factory=set) - - def __post_init__(self): - if self.unit is None: - self.unit = "" + def __init__( + self, + quantity: str, + unit: Union[None, str] = "", + per_atom: bool = False, + gradients: Optional[List[str]] = None, + ): + self.quantity = quantity + self.unit = unit if unit is not None else "" + self.per_atom = per_atom + self._gradients = set(gradients) if gradients is not None else set() + + @property + def gradients(self) -> List[str]: + """Sorted and unique list of gradient names.""" + return sorted(self._gradients) + + @gradients.setter + def gradients(self, value: List[str]): + self._gradients = set(value) + + def __repr__(self): + return ( + f"TargetInfo(quantity={self.quantity!r}, unit={self.unit!r}, " + f"per_atom={self.per_atom!r}, gradients={self.gradients!r})" + ) - # For compatibility with list convert to set - self.gradients = set(self.gradients) + def __eq__(self, other): + if not isinstance(other, TargetInfo): + raise NotImplementedError( + "Comparison between a TargetInfo instance and a " + f"{type(other).__name__} instance is not implemented." + ) + return ( + self.quantity == other.quantity + and self.unit == other.unit + and self.per_atom == other.per_atom + and self._gradients == other._gradients + ) def copy(self) -> "TargetInfo": """Return a shallow copy of the TargetInfo.""" @@ -70,7 +97,7 @@ def update(self, other: "TargetInfo") -> None: f"({self.per_atom} != {other.per_atom})" ) - self.gradients = self.gradients.union(other.gradients) + self.gradients = self.gradients + other.gradients def union(self, other: "TargetInfo") -> "TargetInfo": """Return the union of this instance with ``other``.""" @@ -139,33 +166,53 @@ def difference(self, other: "TargetInfoDict") -> "TargetInfoDict": return TargetInfoDict(**{key: self[key] for key in new_keys}) -@dataclass class DatasetInfo: """A class that contains information about datasets. - This dataclass is used to communicate additional dataset details to the + This class is used to communicate additional dataset details to the training functions of the individual models. - :param length_unit: Unit of length used in the dataset. - :param atomic_types: Unordered set of all atomic types present in the dataset. - - .. note:: - - ``atomic_types`` is a :py:class:`set` and **not ordered**. Use - :py:func:`sorted` for an ordered :py:class:`list`. + :param length_unit: Unit of length used in the dataset. Examples are ``"angstrom"`` + or ``"nanometer"``. + :param atomic_types: List containing all integer atomic types present in the + dataset. ``atomic_types`` will be stored as a sorted list of **unique** atomic + types. :param targets: Information about targets in the dataset. """ - length_unit: str - atomic_types: Set[int] - targets: TargetInfoDict - - def __post_init__(self): - if self.length_unit is None: - self.length_unit = "" + def __init__( + self, length_unit: str, atomic_types: List[int], targets: TargetInfoDict + ): + self.length_unit = length_unit if length_unit is not None else "" + self._atomic_types = set(atomic_types) + self.targets = targets + + @property + def atomic_types(self) -> List[int]: + """Sorted list of unique integer atomic types.""" + return sorted(self._atomic_types) + + @atomic_types.setter + def atomic_types(self, value: List[int]): + self._atomic_types = set(value) + + def __repr__(self): + return ( + f"DatasetInfo(length_unit={self.length_unit!r}, " + f"atomic_types={self.atomic_types!r}, targets={self.targets!r})" + ) - # For compatibility with list convert to set - self.atomic_types = set(self.atomic_types) + def __eq__(self, other): + if not isinstance(other, DatasetInfo): + raise NotImplementedError( + "Comparison between a DatasetInfo instance and a " + f"{type(other).__name__} instance is not implemented." + ) + return ( + self.length_unit == other.length_unit + and self._atomic_types == other._atomic_types + and self.targets == other.targets + ) def copy(self) -> "DatasetInfo": """Return a shallow copy of the DatasetInfo.""" @@ -186,8 +233,8 @@ def update(self, other: "DatasetInfo") -> None: f"({self.length_unit} != {other.length_unit})" ) - self.atomic_types = self.atomic_types.union(other.atomic_types) - self.targets = self.targets.union(other.targets) + self.atomic_types = self.atomic_types + other.atomic_types + self.targets.update(other.targets) def union(self, other: "DatasetInfo") -> "DatasetInfo": """Return the union of this instance with ``other``.""" @@ -325,7 +372,7 @@ def _get_dataset_stats( return stats -def get_atomic_types(datasets: Union[Dataset, List[Dataset]]) -> Set[int]: +def get_atomic_types(datasets: Union[Dataset, List[Dataset]]) -> List[int]: """List of all atomic types present in a dataset or list of datasets. :param datasets: the dataset, or list of datasets @@ -341,7 +388,7 @@ def get_atomic_types(datasets: Union[Dataset, List[Dataset]]) -> Set[int]: system = dataset[index]["system"] types += system.types.tolist() - return set(types) + return sorted(set(types)) def get_all_targets(datasets: Union[Dataset, List[Dataset]]) -> List[str]: diff --git a/src/metatrain/utils/data/readers/readers.py b/src/metatrain/utils/data/readers/readers.py index c09b4d427..b64ee37f5 100644 --- a/src/metatrain/utils/data/readers/readers.py +++ b/src/metatrain/utils/data/readers/readers.py @@ -189,7 +189,7 @@ def read_targets( standard_outputs_list = ["energy"] for target_key, target in conf.items(): - target_info_gradients = set() + target_info_gradients: List[str] = [] if target_key not in standard_outputs_list and not target_key.startswith( "mtt::" @@ -227,7 +227,7 @@ def read_targets( parameter="positions", gradient=position_gradient ) - target_info_gradients.add("positions") + target_info_gradients.append("positions") if target["stress"] and target["virial"]: raise ValueError("Cannot use stress and virial at the same time!") @@ -252,7 +252,7 @@ def read_targets( for block, strain_gradient in zip(blocks, strain_gradients): block.add_gradient(parameter="strain", gradient=strain_gradient) - target_info_gradients.add("strain") + target_info_gradients.append("strain") if target["virial"]: try: @@ -274,7 +274,7 @@ def read_targets( for block, strain_gradient in zip(blocks, strain_gradients): block.add_gradient(parameter="strain", gradient=strain_gradient) - target_info_gradients.add("strain") + target_info_gradients.append("strain") else: raise ValueError( f"Quantity: {target['quantity']!r} is not supported. Choose 'energy'." diff --git a/tests/utils/data/test_dataset.py b/tests/utils/data/test_dataset.py index fcca9000d..989c98662 100644 --- a/tests/utils/data/test_dataset.py +++ b/tests/utils/data/test_dataset.py @@ -27,18 +27,42 @@ def test_target_info_default(): assert target_info.quantity == "energy" assert target_info.unit == "kcal/mol" assert target_info.per_atom is False - assert target_info.gradients == set() + assert target_info.gradients == [] + + expected = ( + "TargetInfo(quantity='energy', unit='kcal/mol', per_atom=False, gradients=[])" + ) + assert target_info.__repr__() == expected def test_target_info_gradients(): target_info = TargetInfo( - quantity="energy", unit="kcal/mol", per_atom=True, gradients=["positions"] + quantity="energy", + unit="kcal/mol", + per_atom=True, + gradients=["positions", "positions"], ) assert target_info.quantity == "energy" assert target_info.unit == "kcal/mol" assert target_info.per_atom is True - assert target_info.gradients == {"positions"} + assert target_info.gradients == ["positions"] + + expected = ( + "TargetInfo(quantity='energy', unit='kcal/mol', per_atom=True, " + "gradients=['positions'])" + ) + assert target_info.__repr__() == expected + + +def test_list_gradients(): + info1 = TargetInfo(quantity="energy", unit="eV") + + info1.gradients = ["positions"] + assert info1.gradients == ["positions"] + + info1.gradients += ["strain"] + assert info1.gradients == ["positions", "strain"] def test_unit_none_conversion(): @@ -49,32 +73,51 @@ def test_unit_none_conversion(): def test_length_unit_none_conversion(): dataset_info = DatasetInfo( length_unit=None, - atomic_types={1, 2, 3}, + atomic_types=[1, 2, 3], targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="kcal/mol")), ) assert dataset_info.length_unit == "" def test_target_info_copy(): - info = TargetInfo(quantity="energy", unit="eV", gradients={"positions"}) + info = TargetInfo(quantity="energy", unit="eV", gradients=["positions"]) copy = info.copy() assert copy == info assert copy is not info +def test_target_info_eq(): + info1 = TargetInfo(quantity="energy", unit="eV", gradients=["position"]) + info2 = TargetInfo(quantity="energy", unit="eV", gradients=["strain"]) + + assert info1 == info1 + assert info1 != info2 + + +def test_target_info_eq_error(): + info = TargetInfo(quantity="energy", unit="eV", gradients=["position"]) + + match = ( + "Comparison between a TargetInfo instance and a list instance is not " + "implemented." + ) + with pytest.raises(NotImplementedError, match=match): + _ = info == [1, 2, 3] + + def test_target_info_update(): - info1 = TargetInfo(quantity="energy", unit="eV", gradients={"position"}) - info2 = TargetInfo(quantity="energy", unit="eV", gradients={"strain"}) + info1 = TargetInfo(quantity="energy", unit="eV", gradients=["strain", "aaa"]) + info2 = TargetInfo(quantity="energy", unit="eV", gradients=["positions"]) info1.update(info2) - assert set(info1.gradients) == {"position", "strain"} + assert info1.gradients == ["aaa", "positions", "strain"] def test_target_info_union(): - info1 = TargetInfo(quantity="energy", unit="eV", gradients={"position"}) - info2 = TargetInfo(quantity="energy", unit="eV", gradients={"strain"}) + info1 = TargetInfo(quantity="energy", unit="eV", gradients=["position"]) + info2 = TargetInfo(quantity="energy", unit="eV", gradients=["strain"]) info_new = info1.union(info2) assert isinstance(info_new, TargetInfo) - assert set(info_new.gradients) == {"position", "strain"} + assert info_new.gradients == ["position", "strain"] def test_target_info_update_non_matching_quantity(): @@ -103,18 +146,18 @@ def test_target_info_update_non_matching_per_atom(): def test_target_info_dict_setitem_new_entry(): tid = TargetInfoDict() - info = TargetInfo(quantity="energy", unit="eV", gradients={"position"}) + info = TargetInfo(quantity="energy", unit="eV", gradients=["position"]) tid["energy"] = info assert tid["energy"] == info def test_target_info_dict_setitem_update_entry(): tid = TargetInfoDict() - info1 = TargetInfo(quantity="energy", unit="eV", gradients={"position"}) - info2 = TargetInfo(quantity="energy", unit="eV", gradients={"strain"}) + info1 = TargetInfo(quantity="energy", unit="eV", gradients=["position"]) + info2 = TargetInfo(quantity="energy", unit="eV", gradients=["strain"]) tid["energy"] = info1 tid["energy"] = info2 - assert set(tid["energy"].gradients) == {"position", "strain"} + assert tid["energy"].gradients == ["position", "strain"] def test_target_info_dict_setitem_value_error(): @@ -125,10 +168,10 @@ def test_target_info_dict_setitem_value_error(): def test_target_info_dict_union(): tid1 = TargetInfoDict() - tid1["energy"] = TargetInfo(quantity="energy", unit="eV", gradients={"position"}) + tid1["energy"] = TargetInfo(quantity="energy", unit="eV", gradients=["position"]) tid2 = TargetInfoDict() - tid2["myenergy"] = TargetInfo(quantity="energy", unit="eV", gradients={"strain"}) + tid2["myenergy"] = TargetInfo(quantity="energy", unit="eV", gradients=["strain"]) merged = tid1.union(tid2) assert merged["energy"] == tid1["energy"] @@ -137,11 +180,11 @@ def test_target_info_dict_union(): def test_target_info_dict_merge_error(): tid1 = TargetInfoDict() - tid1["energy"] = TargetInfo(quantity="energy", unit="eV", gradients={"position"}) + tid1["energy"] = TargetInfo(quantity="energy", unit="eV", gradients=["position"]) tid2 = TargetInfoDict() tid2["energy"] = TargetInfo( - quantity="energy", unit="kcal/mol", gradients={"strain"} + quantity="energy", unit="kcal/mol", gradients=["strain"] ) match = r"Can't update TargetInfo with a different `unit`: \(eV != kcal/mol\)" @@ -151,11 +194,11 @@ def test_target_info_dict_merge_error(): def test_target_info_dict_intersection(): tid1 = TargetInfoDict() - tid1["energy"] = TargetInfo(quantity="energy", unit="eV", gradients={"position"}) - tid1["myenergy"] = TargetInfo(quantity="energy", unit="eV", gradients={"strain"}) + tid1["energy"] = TargetInfo(quantity="energy", unit="eV", gradients=["position"]) + tid1["myenergy"] = TargetInfo(quantity="energy", unit="eV", gradients=["strain"]) tid2 = TargetInfoDict() - tid2["myenergy"] = TargetInfo(quantity="energy", unit="eV", gradients={"strain"}) + tid2["myenergy"] = TargetInfo(quantity="energy", unit="eV", gradients=["strain"]) intersection = tid1.intersection(tid2) assert len(intersection) == 1 @@ -168,12 +211,12 @@ def test_target_info_dict_intersection(): def test_target_info_dict_intersection_error(): tid1 = TargetInfoDict() - tid1["energy"] = TargetInfo(quantity="energy", unit="eV", gradients={"position"}) - tid1["myenergy"] = TargetInfo(quantity="energy", unit="eV", gradients={"strain"}) + tid1["energy"] = TargetInfo(quantity="energy", unit="eV", gradients=["position"]) + tid1["myenergy"] = TargetInfo(quantity="energy", unit="eV", gradients=["strain"]) tid2 = TargetInfoDict() tid2["myenergy"] = TargetInfo( - quantity="energy", unit="kcal/mol", gradients={"strain"} + quantity="energy", unit="kcal/mol", gradients=["strain"] ) match = ( @@ -186,14 +229,13 @@ def test_target_info_dict_intersection_error(): def test_target_info_dict_difference(): - # TODO test `-` operator tid1 = TargetInfoDict() - tid1["energy"] = TargetInfo(quantity="energy", unit="eV", gradients={"position"}) - tid1["myenergy"] = TargetInfo(quantity="energy", unit="eV", gradients={"strain"}) + tid1["energy"] = TargetInfo(quantity="energy", unit="eV", gradients=["position"]) + tid1["myenergy"] = TargetInfo(quantity="energy", unit="eV", gradients=["strain"]) tid2 = TargetInfoDict() tid2["myenergy"] = TargetInfo( - quantity="energy", unit="kcal/mol", gradients={"strain"} + quantity="energy", unit="kcal/mol", gradients=["strain"] ) difference = tid1.difference(tid2) @@ -211,22 +253,43 @@ def test_dataset_info(): targets["mtt::U0"] = TargetInfo(quantity="energy", unit="kcal/mol") dataset_info = DatasetInfo( - length_unit="angstrom", atomic_types={1, 2, 3}, targets=targets + length_unit="angstrom", atomic_types=[3, 1, 2], targets=targets ) assert dataset_info.length_unit == "angstrom" - assert dataset_info.atomic_types == {1, 2, 3} + assert dataset_info.atomic_types == [1, 2, 3] assert dataset_info.targets["energy"].quantity == "energy" assert dataset_info.targets["energy"].unit == "kcal/mol" assert dataset_info.targets["mtt::U0"].quantity == "energy" assert dataset_info.targets["mtt::U0"].unit == "kcal/mol" + expected = ( + "DatasetInfo(length_unit='angstrom', atomic_types=[1, 2, 3], " + f"targets={targets})" + ) + assert dataset_info.__repr__() == expected + + +def test_set_atomic_types(): + targets = TargetInfoDict(energy=TargetInfo(quantity="energy", unit="kcal/mol")) + targets["mtt::U0"] = TargetInfo(quantity="energy", unit="kcal/mol") + + dataset_info = DatasetInfo( + length_unit="angstrom", atomic_types=[3, 1, 2], targets=targets + ) + + dataset_info.atomic_types = [5, 4, 1] + assert dataset_info.atomic_types == [1, 4, 5] + + dataset_info.atomic_types += [7, 1] + assert dataset_info.atomic_types == [1, 4, 5, 7] + def test_dataset_info_copy(): targets = TargetInfoDict() targets["energy"] = TargetInfo(quantity="energy", unit="eV") targets["forces"] = TargetInfo(quantity="mtt::forces", unit="eV/Angstrom") - info = DatasetInfo(length_unit="angstrom", atomic_types={1, 6}, targets=targets) + info = DatasetInfo(length_unit="angstrom", atomic_types=[1, 6], targets=targets) copy = info.copy() @@ -238,15 +301,15 @@ def test_dataset_info_update(): targets = TargetInfoDict() targets["energy"] = TargetInfo(quantity="energy", unit="eV") - info = DatasetInfo(length_unit="angstrom", atomic_types={1, 6}, targets=targets) + info = DatasetInfo(length_unit="angstrom", atomic_types=[1, 6], targets=targets) targets2 = targets.copy() targets2["forces"] = TargetInfo(quantity="mtt::forces", unit="eV/Angstrom") - info2 = DatasetInfo(length_unit="angstrom", atomic_types={8}, targets=targets2) + info2 = DatasetInfo(length_unit="angstrom", atomic_types=[8], targets=targets2) info.update(info2) - assert info.atomic_types == {1, 6, 8} + assert info.atomic_types == [1, 6, 8] assert info.targets["energy"] == targets["energy"] assert info.targets["forces"] == targets2["forces"] @@ -255,12 +318,12 @@ def test_dataset_info_update_non_matching_length_unit(): targets = TargetInfoDict() targets["energy"] = TargetInfo(quantity="energy", unit="eV") - info = DatasetInfo(length_unit="angstrom", atomic_types={1, 6}, targets=targets) + info = DatasetInfo(length_unit="angstrom", atomic_types=[1, 6], targets=targets) targets2 = targets.copy() targets2["forces"] = TargetInfo(quantity="mtt::forces", unit="eV/Angstrom") - info2 = DatasetInfo(length_unit="nanometer", atomic_types={8}, targets=targets2) + info2 = DatasetInfo(length_unit="nanometer", atomic_types=[8], targets=targets2) match = ( r"Can't update DatasetInfo with a different `length_unit`: " @@ -271,16 +334,44 @@ def test_dataset_info_update_non_matching_length_unit(): info.update(info2) +def test_dataset_info_eq(): + targets = TargetInfoDict() + targets["energy"] = TargetInfo(quantity="energy", unit="eV") + + info = DatasetInfo(length_unit="angstrom", atomic_types=[1, 6], targets=targets) + + targets2 = targets.copy() + targets2["forces"] = TargetInfo(quantity="mtt::forces", unit="eV/Angstrom") + info2 = DatasetInfo(length_unit="nanometer", atomic_types=[8], targets=targets2) + + assert info == info + assert info != info2 + + +def test_dataset_info_eq_error(): + targets = TargetInfoDict() + targets["energy"] = TargetInfo(quantity="energy", unit="eV") + + info = DatasetInfo(length_unit="angstrom", atomic_types=[1, 6], targets=targets) + + match = ( + "Comparison between a DatasetInfo instance and a list instance is not " + "implemented." + ) + with pytest.raises(NotImplementedError, match=match): + _ = info == [1, 2, 3] + + def test_dataset_info_update_different_target_info(): targets = TargetInfoDict() targets["energy"] = TargetInfo(quantity="energy", unit="eV") - info = DatasetInfo(length_unit="angstrom", atomic_types={1, 6}, targets=targets) + info = DatasetInfo(length_unit="angstrom", atomic_types=[1, 6], targets=targets) targets2 = TargetInfoDict() targets2["energy"] = TargetInfo(quantity="energy", unit="eV/Angstrom") - info2 = DatasetInfo(length_unit="angstrom", atomic_types={8}, targets=targets2) + info2 = DatasetInfo(length_unit="angstrom", atomic_types=[8], targets=targets2) match = r"Can't update TargetInfo with a different `unit`: \(eV != eV/Angstrom\)" with pytest.raises(ValueError, match=match): @@ -292,19 +383,19 @@ def test_dataset_info_union(): targets = TargetInfoDict() targets["energy"] = TargetInfo(quantity="energy", unit="eV") targets["forces"] = TargetInfo(quantity="mtt::forces", unit="eV/Angstrom") - info = DatasetInfo(length_unit="angstrom", atomic_types={1, 6}, targets=targets) + info = DatasetInfo(length_unit="angstrom", atomic_types=[1, 6], targets=targets) other_targets = targets.copy() other_targets["mtt::stress"] = TargetInfo(quantity="mtt::stress", unit="GPa") other_info = DatasetInfo( - length_unit="angstrom", atomic_types={1}, targets=other_targets + length_unit="angstrom", atomic_types=[1], targets=other_targets ) union = info.union(other_info) assert union.length_unit == "angstrom" - assert union.atomic_types == {1, 6} + assert union.atomic_types == [1, 6] assert union.targets == other_targets @@ -370,9 +461,9 @@ def test_get_atomic_types(): dataset = Dataset({"system": systems, **targets}) dataset_2 = Dataset({"system": systems_2, **targets_2}) - assert get_atomic_types(dataset) == {1, 6, 7, 8} - assert get_atomic_types(dataset_2) == {1, 6, 8} - assert get_atomic_types([dataset, dataset_2]) == {1, 6, 7, 8} + assert get_atomic_types(dataset) == [1, 6, 7, 8] + assert get_atomic_types(dataset_2) == [1, 6, 8] + assert get_atomic_types([dataset, dataset_2]) == [1, 6, 7, 8] def test_get_all_targets(): @@ -547,7 +638,7 @@ def test_get_stats(): dataset_info = DatasetInfo( length_unit="angstrom", - atomic_types={1, 6}, + atomic_types=[1, 6], targets={ "mtt::U0": TargetInfo(quantity="energy", unit="eV"), "energy": TargetInfo(quantity="energy", unit="eV"), diff --git a/tests/utils/data/test_readers.py b/tests/utils/data/test_readers.py index 69d41f81a..8a9a57bab 100644 --- a/tests/utils/data/test_readers.py +++ b/tests/utils/data/test_readers.py @@ -193,7 +193,7 @@ def test_read_targets(stress_dict, virial_dict, monkeypatch, tmp_path, caplog): assert target_info.quantity == target_section["quantity"] assert target_info.unit == target_section["unit"] assert target_info.per_atom is False - assert target_info.gradients == {"positions", "strain"} + assert target_info.gradients == ["positions", "strain"] assert type(target_list) is list for target in target_list: