diff --git a/docs/src/dev-docs/utils-api.rst b/docs/src/dev-docs/utils-api.rst index c45ab264e..cea1cacb0 100644 --- a/docs/src/dev-docs/utils-api.rst +++ b/docs/src/dev-docs/utils-api.rst @@ -5,3 +5,9 @@ This is the API for the ``utils`` module of ``metatensor-models``. .. automodule:: metatensor_models.utils.data :members: + +.. automodule:: metatensor_models.utils.data.dataset + :members: + +.. automodule:: metatensor_models.utils.data.readers + :members: diff --git a/src/metatensor_models/utils/data/dataset.py b/src/metatensor_models/utils/data/dataset.py index f94754a08..16cccb6a8 100644 --- a/src/metatensor_models/utils/data/dataset.py +++ b/src/metatensor_models/utils/data/dataset.py @@ -10,8 +10,11 @@ class Dataset(torch.utils.data.Dataset): def __init__( self, structures: List[rascaline.torch.System], targets: Dict[str, TensorMap] ): - """Creates a dataset from a list of `rascaline.torch.System`s and - a list of dictionaries of `TensorMap`s.""" + """ + Creates a dataset from a list of `rascaline.torch.System` objects + and a dictionary of targets where the keys are strings and the + values are `TensorMap` objects. + """ for tensor_map in targets.values(): n_structures = ( @@ -35,8 +38,10 @@ def __len__(self): def __getitem__(self, index): """ Generates one sample of data. + Args: index: The index of the item in the dataset. + Returns: A tuple containing the structure and targets for the given index. """ @@ -59,9 +64,11 @@ def __getitem__(self, index): def collate_fn(batch): """ Creates a batch from a list of samples. + Args: batch: A list of samples, where each sample is a tuple containing a structure and targets. + Returns: A tuple containing the structures and targets for the batch. """