diff --git a/examples/programmatic/llpr/llpr.py b/examples/programmatic/llpr/llpr.py index df79773da..8db135c86 100644 --- a/examples/programmatic/llpr/llpr.py +++ b/examples/programmatic/llpr/llpr.py @@ -72,7 +72,7 @@ get_system_with_neighbor_lists(system, requested_neighbor_lists) for system in qm9_systems ] -dataset = Dataset({"system": qm9_systems, **targets}) +dataset = Dataset.from_dict({"system": qm9_systems, **targets}) # We also load a single ethanol molecule on which we will compute properties. # This system is loaded without targets, as we are only interested in the LPR diff --git a/src/metatrain/cli/eval.py b/src/metatrain/cli/eval.py index 4df6572a9..93adb1627 100644 --- a/src/metatrain/cli/eval.py +++ b/src/metatrain/cli/eval.py @@ -296,7 +296,7 @@ def eval_model( gradients=gradients, ) - eval_dataset = Dataset({"system": eval_systems, **eval_targets}) + eval_dataset = Dataset.from_dict({"system": eval_systems, **eval_targets}) # Evaluate the model try: diff --git a/src/metatrain/experimental/alchemical_model/tests/test_regression.py b/src/metatrain/experimental/alchemical_model/tests/test_regression.py index fe8dec96a..4dbc6ed0b 100644 --- a/src/metatrain/experimental/alchemical_model/tests/test_regression.py +++ b/src/metatrain/experimental/alchemical_model/tests/test_regression.py @@ -92,7 +92,7 @@ def test_regression_train(): } } targets, target_info_dict = read_targets(OmegaConf.create(conf)) - dataset = Dataset({"system": systems, "mtt::U0": targets["mtt::U0"]}) + dataset = Dataset.from_dict({"system": systems, "mtt::U0": targets["mtt::U0"]}) hypers = DEFAULT_HYPERS.copy() diff --git a/src/metatrain/experimental/alchemical_model/utils/normalize.py b/src/metatrain/experimental/alchemical_model/utils/normalize.py index addc85499..494aa0388 100644 --- a/src/metatrain/experimental/alchemical_model/utils/normalize.py +++ b/src/metatrain/experimental/alchemical_model/utils/normalize.py @@ -18,10 +18,10 @@ def get_average_number_of_atoms( """ average_number_of_atoms = [] for dataset in datasets: - dtype = dataset[0]["system"].positions.dtype + dtype = dataset[0].system.positions.dtype num_atoms = [] for i in range(len(dataset)): - system = dataset[i]["system"] + system = dataset[i].system num_atoms.append(len(system)) average_number_of_atoms.append(torch.mean(torch.tensor(num_atoms, dtype=dtype))) return torch.tensor(average_number_of_atoms) @@ -39,9 +39,9 @@ def get_average_number_of_neighbors( average_number_of_neighbors = [] for dataset in datasets: num_neighbor = [] - dtype = dataset[0]["system"].positions.dtype + dtype = dataset[0].system.positions.dtype for i in range(len(dataset)): - system = dataset[i]["system"] + system = dataset[i].system known_neighbor_lists = system.known_neighbor_lists() if len(known_neighbor_lists) == 0: raise ValueError(f"system {system} does not have a neighbor list") @@ -94,4 +94,4 @@ def remove_composition_from_dataset( new_systems.append(system) new_properties.append(property) - return Dataset({"system": new_systems, property_name: new_properties}) + return Dataset.from_dict({"system": new_systems, property_name: new_properties}) diff --git a/src/metatrain/experimental/gap/tests/test_errors.py b/src/metatrain/experimental/gap/tests/test_errors.py index 7cc1958fb..24e5c03b3 100644 --- a/src/metatrain/experimental/gap/tests/test_errors.py +++ b/src/metatrain/experimental/gap/tests/test_errors.py @@ -54,7 +54,9 @@ def test_ethanol_regression_train_and_invariance(): } targets, _ = read_targets(OmegaConf.create(conf)) - dataset = Dataset({"system": systems[:2], "energy": targets["energy"][:2]}) + dataset = Dataset.from_dict( + {"system": systems[:2], "energy": targets["energy"][:2]} + ) hypers = copy.deepcopy(DEFAULT_HYPERS) hypers["model"]["krr"]["num_sparse_points"] = 30 diff --git a/src/metatrain/experimental/gap/tests/test_regression.py b/src/metatrain/experimental/gap/tests/test_regression.py index e4d2dda1a..e2a2ee72c 100644 --- a/src/metatrain/experimental/gap/tests/test_regression.py +++ b/src/metatrain/experimental/gap/tests/test_regression.py @@ -55,7 +55,7 @@ def test_regression_train_and_invariance(): } } targets, _ = read_targets(OmegaConf.create(conf)) - dataset = Dataset({"system": systems, "mtt::U0": targets["mtt::U0"]}) + dataset = Dataset.from_dict({"system": systems, "mtt::U0": targets["mtt::U0"]}) target_info_dict = TargetInfoDict() target_info_dict["mtt::U0"] = TargetInfo(quantity="energy", unit="eV") @@ -132,7 +132,7 @@ def test_ethanol_regression_train_and_invariance(): } targets, _ = read_targets(OmegaConf.create(conf)) - dataset = Dataset({"system": systems, "energy": targets["energy"]}) + dataset = Dataset.from_dict({"system": systems, "energy": targets["energy"]}) hypers = copy.deepcopy(DEFAULT_HYPERS) hypers["model"]["krr"]["num_sparse_points"] = 900 diff --git a/src/metatrain/experimental/gap/tests/test_torchscript.py b/src/metatrain/experimental/gap/tests/test_torchscript.py index f0680fd40..967a83353 100644 --- a/src/metatrain/experimental/gap/tests/test_torchscript.py +++ b/src/metatrain/experimental/gap/tests/test_torchscript.py @@ -36,7 +36,7 @@ def test_torchscript(): # for system in systems: # system.types = torch.ones(len(system.types), dtype=torch.int32) - dataset = Dataset({"system": systems, "mtt::U0": targets["mtt::U0"]}) + dataset = Dataset.from_dict({"system": systems, "mtt::U0": targets["mtt::U0"]}) hypers = DEFAULT_HYPERS.copy() gap = GAP(DEFAULT_HYPERS["model"], dataset_info) diff --git a/src/metatrain/experimental/soap_bpnn/tests/test_continue.py b/src/metatrain/experimental/soap_bpnn/tests/test_continue.py index 9bd9b0e60..dac96cd94 100644 --- a/src/metatrain/experimental/soap_bpnn/tests/test_continue.py +++ b/src/metatrain/experimental/soap_bpnn/tests/test_continue.py @@ -44,7 +44,7 @@ def test_continue(monkeypatch, tmp_path): } } targets, _ = read_targets(OmegaConf.create(conf)) - dataset = Dataset({"system": systems, "mtt::U0": targets["mtt::U0"]}) + dataset = Dataset.from_dict({"system": systems, "mtt::U0": targets["mtt::U0"]}) hypers = DEFAULT_HYPERS.copy() hypers["training"]["num_epochs"] = 0 diff --git a/src/metatrain/experimental/soap_bpnn/tests/test_regression.py b/src/metatrain/experimental/soap_bpnn/tests/test_regression.py index 07e3871ae..0663da485 100644 --- a/src/metatrain/experimental/soap_bpnn/tests/test_regression.py +++ b/src/metatrain/experimental/soap_bpnn/tests/test_regression.py @@ -70,7 +70,7 @@ def test_regression_train(): } } targets, target_info_dict = read_targets(OmegaConf.create(conf)) - dataset = Dataset({"system": systems, "mtt::U0": targets["mtt::U0"]}) + dataset = Dataset.from_dict({"system": systems, "mtt::U0": targets["mtt::U0"]}) hypers = DEFAULT_HYPERS.copy() hypers["training"]["num_epochs"] = 2 diff --git a/src/metatrain/utils/data/__init__.py b/src/metatrain/utils/data/__init__.py index b00479072..7387e0024 100644 --- a/src/metatrain/utils/data/__init__.py +++ b/src/metatrain/utils/data/__init__.py @@ -7,7 +7,6 @@ get_all_targets, collate_fn, check_datasets, - group_and_join, ) from .readers import ( # noqa: F401 read_energy, diff --git a/src/metatrain/utils/data/dataset.py b/src/metatrain/utils/data/dataset.py index a4225c02c..a427d4302 100644 --- a/src/metatrain/utils/data/dataset.py +++ b/src/metatrain/utils/data/dataset.py @@ -3,9 +3,9 @@ from collections import UserDict from typing import Any, Dict, List, Optional, Tuple, Union -import metatensor.learn import numpy as np import torch +from metatensor.learn.data import Dataset, group_and_join from metatensor.torch import TensorMap from ..external_naming import to_external_name @@ -242,47 +242,6 @@ def union(self, other: "DatasetInfo") -> "DatasetInfo": return new -class Dataset: - """A version of the `metatensor.learn.Dataset` class that allows for - the use of `mtt::` prefixes in the keys of the dictionary. See - https://github.com/lab-cosmo/metatensor/issues/621. - - It is important to note that, instead of named tuples, this class - accepts and returns dictionaries. - - :param dict: A dictionary with the data to be stored in the dataset. - """ - - def __init__(self, dict: Dict): - - new_dict = {} - for key, value in dict.items(): - key = key.replace("mtt::", "mtt_") - new_dict[key] = value - - self.mts_learn_dataset = metatensor.learn.Dataset(**new_dict) - - def __getitem__(self, idx: int) -> Dict: - - mts_dataset_item = self.mts_learn_dataset[idx]._asdict() - new_dict = {} - for key, value in mts_dataset_item.items(): - key = key.replace("mtt_", "mtt::") - new_dict[key] = value - - return new_dict - - def __len__(self) -> int: - return len(self.mts_learn_dataset) - - def __iter__(self): - for i in range(len(self)): - yield self[i] - - def get_stats(self, dataset_info: DatasetInfo) -> str: - return _get_dataset_stats(self, dataset_info) - - class Subset(torch.utils.data.Subset): """ A version of `torch.utils.data.Subset` containing a `get_stats` method @@ -306,7 +265,7 @@ def _get_dataset_stats( # target_names will be used to store names of the targets, # along with their gradients target_names = [] - for key, tensor_map in dataset[0].items(): + for key, tensor_map in dataset[0]._asdict().items(): if key == "system": continue target_names.append(key) @@ -408,8 +367,8 @@ def get_all_targets(datasets: Union[Dataset, List[Dataset]]) -> List[str]: target_names = [] for dataset in datasets: for sample in dataset: - sample.pop("system") # system not needed - target_names += list(sample.keys()) + # system not needed + target_names += [key for key in sample._asdict().keys() if key != "system"] return sorted(set(target_names)) @@ -422,6 +381,7 @@ def collate_fn(batch: List[Dict[str, Any]]) -> Tuple[List, Dict[str, TensorMap]] """ collated_targets = group_and_join(batch) + collated_targets = collated_targets._asdict() systems = collated_targets.pop("system") return systems, collated_targets @@ -441,15 +401,15 @@ def check_datasets(train_datasets: List[Dataset], val_datasets: List[Dataset]): or targets that are not present in the training set """ # Check that system `dtypes` are consistent within datasets - desired_dtype = train_datasets[0][0]["system"].positions.dtype + desired_dtype = train_datasets[0][0].system.positions.dtype msg = f"`dtype` between datasets is inconsistent, found {desired_dtype} and " for train_dataset in train_datasets: - actual_dtype = train_dataset[0]["system"].positions.dtype + actual_dtype = train_dataset[0].system.positions.dtype if actual_dtype != desired_dtype: raise TypeError(f"{msg}{actual_dtype} found in `train_datasets`") for val_dataset in val_datasets: - actual_dtype = val_dataset[0]["system"].positions.dtype + actual_dtype = val_dataset[0].system.positions.dtype if actual_dtype != desired_dtype: raise TypeError(f"{msg}{actual_dtype} found in `val_datasets`") @@ -515,33 +475,3 @@ def _train_test_random_split( Subset(train_dataset, train_indices), Subset(train_dataset, test_indices), ] - - -def group_and_join( - batch: List[Dict[str, Any]], -) -> Dict[str, Any]: - """ - Same as metatenor.learn.data.group_and_join, but joins dicts and not named tuples. - - :param batch: A list of dictionaries, each containing the data for a single sample. - - :returns: A single dictionary with the data fields joined together among all - samples. - """ - data: List[Union[TensorMap, torch.Tensor]] = [] - names = batch[0].keys() - for name, f in zip(names, zip(*(item.values() for item in batch))): - if name == "sample_id": # special case, keep as is - data.append(f) - continue - - if isinstance(f[0], torch.ScriptObject) and f[0]._has_method( - "keys_to_properties" - ): # inferred metatensor.torch.TensorMap type - data.append(metatensor.torch.join(f, axis="samples")) - elif isinstance(f[0], torch.Tensor): # torch.Tensor type - data.append(torch.vstack(f)) - else: # otherwise just keep as a list - data.append(f) - - return {name: value for name, value in zip(names, data)} diff --git a/src/metatrain/utils/data/extract_targets.py b/src/metatrain/utils/data/extract_targets.py index fe39495b1..ee86b29d4 100644 --- a/src/metatrain/utils/data/extract_targets.py +++ b/src/metatrain/utils/data/extract_targets.py @@ -28,6 +28,7 @@ def get_targets_dict( targets_dict = {} for dataset in datasets: targets = next(iter(dataset)) + targets = targets._asdict() targets.pop("system") # system not needed # targets is now a dictionary of TensorMaps diff --git a/src/metatrain/utils/data/get_dataset.py b/src/metatrain/utils/data/get_dataset.py index 2094bae4c..2f95263c5 100644 --- a/src/metatrain/utils/data/get_dataset.py +++ b/src/metatrain/utils/data/get_dataset.py @@ -27,6 +27,6 @@ def get_dataset(options: DictConfig) -> Tuple[Dataset, TargetInfoDict]: reader=options["systems"]["reader"], ) targets, target_info_dictionary = read_targets(conf=options["targets"]) - dataset = Dataset({"system": systems, **targets}) + dataset = Dataset.from_dict({"system": systems, **targets}) return dataset, target_info_dictionary diff --git a/tests/utils/data/test_combine_dataloaders.py b/tests/utils/data/test_combine_dataloaders.py index 6f855059f..7a5f4797a 100644 --- a/tests/utils/data/test_combine_dataloaders.py +++ b/tests/utils/data/test_combine_dataloaders.py @@ -36,7 +36,7 @@ def test_without_shuffling(): } } targets, _ = read_targets(OmegaConf.create(conf)) - dataset = Dataset({"system": systems, "mtt::U0": targets["mtt::U0"]}) + dataset = Dataset.from_dict({"system": systems, "mtt::U0": targets["mtt::U0"]}) dataloader_qm9 = DataLoader(dataset, batch_size=10, collate_fn=collate_fn) # will yield 10 batches of 10 @@ -94,7 +94,7 @@ def test_with_shuffling(): } } targets, _ = read_targets(OmegaConf.create(conf)) - dataset = Dataset({"system": systems, "mtt::U0": targets["mtt::U0"]}) + dataset = Dataset.from_dict({"system": systems, "mtt::U0": targets["mtt::U0"]}) dataloader_qm9 = DataLoader( dataset, batch_size=10, collate_fn=collate_fn, shuffle=True ) diff --git a/tests/utils/data/test_dataset.py b/tests/utils/data/test_dataset.py index 989c98662..587532a33 100644 --- a/tests/utils/data/test_dataset.py +++ b/tests/utils/data/test_dataset.py @@ -418,7 +418,7 @@ def test_dataset(): } } targets, _ = read_targets(OmegaConf.create(conf)) - dataset = Dataset({"system": systems, "energy": targets["energy"]}) + dataset = Dataset.from_dict({"system": systems, "energy": targets["energy"]}) dataloader = torch.utils.data.DataLoader( dataset, batch_size=10, collate_fn=collate_fn ) @@ -458,8 +458,8 @@ def test_get_atomic_types(): } targets, _ = read_targets(OmegaConf.create(conf)) targets_2, _ = read_targets(OmegaConf.create(conf_2)) - dataset = Dataset({"system": systems, **targets}) - dataset_2 = Dataset({"system": systems_2, **targets_2}) + dataset = Dataset.from_dict({"system": systems, **targets}) + dataset_2 = Dataset.from_dict({"system": systems_2, **targets_2}) assert get_atomic_types(dataset) == [1, 6, 7, 8] assert get_atomic_types(dataset_2) == [1, 6, 8] @@ -497,8 +497,8 @@ def test_get_all_targets(): } targets, _ = read_targets(OmegaConf.create(conf)) targets_2, _ = read_targets(OmegaConf.create(conf_2)) - dataset = Dataset({"system": systems, **targets}) - dataset_2 = Dataset({"system": systems_2, **targets_2}) + dataset = Dataset.from_dict({"system": systems, **targets}) + dataset_2 = Dataset.from_dict({"system": systems_2, **targets_2}) assert get_all_targets(dataset) == ["mtt::U0"] assert get_all_targets(dataset_2) == ["energy"] assert get_all_targets([dataset, dataset_2]) == ["energy", "mtt::U0"] @@ -537,19 +537,19 @@ def test_check_datasets(): targets_ethanol, _ = read_targets(OmegaConf.create(conf_ethanol)) # everything ok - train_set = Dataset({"system": systems_qm9, **targets_qm9}) - val_set = Dataset({"system": systems_qm9, **targets_qm9}) + train_set = Dataset.from_dict({"system": systems_qm9, **targets_qm9}) + val_set = Dataset.from_dict({"system": systems_qm9, **targets_qm9}) check_datasets([train_set], [val_set]) # extra species in validation dataset - train_set = Dataset({"system": systems_ethanol, **targets_qm9}) - val_set = Dataset({"system": systems_qm9, **targets_qm9}) + train_set = Dataset.from_dict({"system": systems_ethanol, **targets_qm9}) + val_set = Dataset.from_dict({"system": systems_qm9, **targets_qm9}) with pytest.raises(ValueError, match="The validation dataset has a species"): check_datasets([train_set], [val_set]) # extra targets in validation dataset - train_set = Dataset({"system": systems_qm9, **targets_qm9}) - val_set = Dataset({"system": systems_qm9, **targets_ethanol}) + train_set = Dataset.from_dict({"system": systems_qm9, **targets_qm9}) + val_set = Dataset.from_dict({"system": systems_qm9, **targets_ethanol}) with pytest.raises(ValueError, match="The validation dataset has a target"): check_datasets([train_set], [val_set]) @@ -558,7 +558,9 @@ def test_check_datasets(): targets_qm9_32bit = { k: [v.to(dtype=torch.float32) for v in l] for k, l in targets_qm9.items() } - train_set_32_bit = Dataset({"system": systems_qm9_32bit, **targets_qm9_32bit}) + train_set_32_bit = Dataset.from_dict( + {"system": systems_qm9_32bit, **targets_qm9_32bit} + ) match = ( "`dtype` between datasets is inconsistent, found torch.float64 and " @@ -592,7 +594,7 @@ def test_collate_fn(): } } targets, _ = read_targets(OmegaConf.create(conf)) - dataset = Dataset({"system": systems, "mtt::U0": targets["mtt::U0"]}) + dataset = Dataset.from_dict({"system": systems, "mtt::U0": targets["mtt::U0"]}) batch = collate_fn([dataset[0], dataset[1], dataset[2]]) @@ -633,8 +635,8 @@ def test_get_stats(): } targets, _ = read_targets(OmegaConf.create(conf)) targets_2, _ = read_targets(OmegaConf.create(conf_2)) - dataset = Dataset({"system": systems, **targets}) - dataset_2 = Dataset({"system": systems_2, **targets_2}) + dataset = Dataset.from_dict({"system": systems, **targets}) + dataset_2 = Dataset.from_dict({"system": systems_2, **targets_2}) dataset_info = DatasetInfo( length_unit="angstrom", diff --git a/tests/utils/data/test_get_dataset.py b/tests/utils/data/test_get_dataset.py index 17fda35c7..765f6a62b 100644 --- a/tests/utils/data/test_get_dataset.py +++ b/tests/utils/data/test_get_dataset.py @@ -31,8 +31,8 @@ def test_get_dataset(): dataset, target_info = get_dataset(OmegaConf.create(options)) - assert "system" in dataset[0] - assert "energy" in dataset[0] + dataset[0].system + dataset[0].energy assert "energy" in target_info assert target_info["energy"].quantity == "energy" assert target_info["energy"].unit == "eV" diff --git a/tests/utils/test_composition.py b/tests/utils/test_composition.py index 1933d8d7d..63a75e646 100644 --- a/tests/utils/test_composition.py +++ b/tests/utils/test_composition.py @@ -61,7 +61,7 @@ def test_calculate_composition_weights(): ) for i, e in enumerate(energies) ] - dataset = Dataset({"system": systems, "energy": energies}) + dataset = Dataset.from_dict({"system": systems, "energy": energies}) weights, atomic_types = calculate_composition_weights(dataset, "energy") diff --git a/tests/utils/test_llpr.py b/tests/utils/test_llpr.py index 8debe8726..f1887ef5b 100644 --- a/tests/utils/test_llpr.py +++ b/tests/utils/test_llpr.py @@ -42,7 +42,7 @@ def test_llpr(tmpdir): get_system_with_neighbor_lists(system, requested_neighbor_lists) for system in qm9_systems ] - dataset = Dataset({"system": qm9_systems, **targets}) + dataset = Dataset.from_dict({"system": qm9_systems, **targets}) dataloader = torch.utils.data.DataLoader( dataset, batch_size=10,