From e89bdfe10aa6d3663492ce4531de35e5f6dcd7c7 Mon Sep 17 00:00:00 2001
From: frostedoyster
Date: Wed, 6 Dec 2023 04:04:57 +0100
Subject: [PATCH] Extract species list from dataset

---
 src/metatensor/models/utils/data/dataset.py | 18 ++++++++++++++++++
 tests/data.py                               | 16 +++++++++++++++-
 2 files changed, 33 insertions(+), 1 deletion(-)

diff --git a/src/metatensor/models/utils/data/dataset.py b/src/metatensor/models/utils/data/dataset.py
index 4f76a9354..b8819f818 100644
--- a/src/metatensor/models/utils/data/dataset.py
+++ b/src/metatensor/models/utils/data/dataset.py
@@ -81,3 +81,21 @@ def collate_fn(batch):
     )
 
     return structures, targets
+
+
+def get_all_species(dataset: Dataset) -> List[int]:
+    """
+    Returns the list of all species present in the dataset.
+
+    Args:
+        dataset: The dataset to get the species from.
+
+    Returns:
+        The list of all species present in the dataset.
+    """
+
+    species = set()
+    for structure in dataset.structures:
+        species.update(structure.species.tolist())
+
+    return sorted(species)
diff --git a/tests/data.py b/tests/data.py
index 00d5ec542..96eee5217 100644
--- a/tests/data.py
+++ b/tests/data.py
@@ -1,7 +1,7 @@
 import os
 
 import torch
-from metatensor.models.utils.data import Dataset, collate_fn, read_structures, read_targets
+from metatensor.models.utils.data import Dataset, collate_fn, read_structures, read_targets, get_all_species
 
 
 def test_dataset():
@@ -17,3 +17,17 @@ def test_dataset():
 
     for batch in dataloader:
         assert batch[1]["U0"].block().values.shape == (10, 1)
+
+
+def test_species_list():
+    """Tests that the species list is correctly computed."""
+
+    dataset_path = os.path.join(os.path.dirname(__file__), "data/qm9_reduced_100.xyz")
+
+    structures = read_structures(dataset_path)
+    targets = read_targets(dataset_path, "U0")
+
+    dataset = Dataset(structures, targets)
+    species_list = get_all_species(dataset)
+
+    assert species_list == [1, 6, 7, 8]