Skip to content

Commit

Permalink
Extract species list from dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster committed Dec 6, 2023
1 parent 99ca38b commit e89bdfe
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 1 deletion.
18 changes: 18 additions & 0 deletions src/metatensor/models/utils/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,21 @@ def collate_fn(batch):
)

return structures, targets


def get_all_species(dataset: Dataset) -> List[int]:
    """
    Collect every distinct species found in a dataset.

    Args:
        dataset: The dataset whose ``structures`` are scanned; each
            structure's ``species`` is converted with ``tolist()``.

    Returns:
        The distinct species across all structures, sorted in
        ascending order.
    """

    # A set de-duplicates species shared between structures.
    unique_species = {
        value
        for structure in dataset.structures
        for value in structure.species.tolist()
    }

    ordered = list(unique_species)
    ordered.sort()
    return ordered
16 changes: 15 additions & 1 deletion tests/data.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
import torch

from metatensor.models.utils.data import Dataset, collate_fn, read_structures, read_targets
from metatensor.models.utils.data import Dataset, collate_fn, read_structures, read_targets, get_all_species


def test_dataset():
Expand All @@ -17,3 +17,17 @@ def test_dataset():

for batch in dataloader:
assert batch[1]["U0"].block().values.shape == (10, 1)


def test_species_list():
    """Checks the species list computed for the reduced QM9 file."""

    here = os.path.dirname(__file__)
    xyz_file = os.path.join(here, "data/qm9_reduced_100.xyz")

    # Structures are read before targets, both from the same file.
    dataset = Dataset(read_structures(xyz_file), read_targets(xyz_file, "U0"))

    assert get_all_species(dataset) == [1, 6, 7, 8]

0 comments on commit e89bdfe

Please sign in to comment.