From aa31433abec71c199387688633eca6b6ae865e9e Mon Sep 17 00:00:00 2001
From: frostedoyster
Date: Tue, 16 Jan 2024 12:32:47 +0100
Subject: [PATCH] Add new functions to the documentation

---
 docs/src/dev-docs/utils/combine_dataloaders.rst          | 7 +++++++
 src/metatensor/models/soap_bpnn/train.py                 | 9 ++-------
 src/metatensor/models/utils/data/__init__.py             | 2 +-
 src/metatensor/models/utils/data/combine_dataloaders.py  | 9 +++++++--
 src/metatensor/models/utils/data/dataset.py              | 2 +-
 5 files changed, 18 insertions(+), 11 deletions(-)
 create mode 100644 docs/src/dev-docs/utils/combine_dataloaders.rst

diff --git a/docs/src/dev-docs/utils/combine_dataloaders.rst b/docs/src/dev-docs/utils/combine_dataloaders.rst
new file mode 100644
index 000000000..0aff64b69
--- /dev/null
+++ b/docs/src/dev-docs/utils/combine_dataloaders.rst
@@ -0,0 +1,7 @@
+Combining dataloaders
+#####################
+
+.. automodule:: metatensor.models.utils.data.combine_dataloaders
+    :members:
+    :undoc-members:
+    :show-inheritance:
diff --git a/src/metatensor/models/soap_bpnn/train.py b/src/metatensor/models/soap_bpnn/train.py
index 30bccbed6..2c1dd9354 100644
--- a/src/metatensor/models/soap_bpnn/train.py
+++ b/src/metatensor/models/soap_bpnn/train.py
@@ -7,12 +7,7 @@
 
 from ..utils.composition import calculate_composition_weights
 from ..utils.compute_loss import compute_model_loss
-from ..utils.data import (
-    Dataset,
-    canonical_check_datasets,
-    collate_fn,
-    combine_dataloaders,
-)
+from ..utils.data import Dataset, check_datasets, collate_fn, combine_dataloaders
 from ..utils.loss import TensorMapDictLoss
 from ..utils.model_io import save_model
 from .model import DEFAULT_HYPERS, Model
@@ -29,7 +24,7 @@ def train(
     output_dir: str = ".",
 ):
     # Perform canonical checks on the datasets:
-    canonical_check_datasets(
+    check_datasets(
         train_datasets,
         validation_datasets,
         model_capabilities,
diff --git a/src/metatensor/models/utils/data/__init__.py b/src/metatensor/models/utils/data/__init__.py
index 6eed423be..ab7a34858 100644
--- a/src/metatensor/models/utils/data/__init__.py
+++ b/src/metatensor/models/utils/data/__init__.py
@@ -1,4 +1,4 @@
-from .dataset import Dataset, collate_fn, canonical_check_datasets  # noqa: F401
+from .dataset import Dataset, collate_fn, check_datasets  # noqa: F401
 from .readers import read_structures, read_targets  # noqa: F401
 from .writers import write_predictions  # noqa: F401
 from .combine_dataloaders import combine_dataloaders  # noqa: F401
diff --git a/src/metatensor/models/utils/data/combine_dataloaders.py b/src/metatensor/models/utils/data/combine_dataloaders.py
index 9fe069e92..e5b1afa95 100644
--- a/src/metatensor/models/utils/data/combine_dataloaders.py
+++ b/src/metatensor/models/utils/data/combine_dataloaders.py
@@ -8,8 +8,8 @@ class CombinedIterableDataset(torch.utils.data.IterableDataset):
     """
     Combines multiple dataloaders into a single iterable dataset.
 
-    This is useful for combining multiple datasets into a single dataloader
-    and learning from all of them simultaneously.
+    This is useful for combining multiple dataloaders into a single
+    dataloader. The new dataloader can be shuffled or not.
     """
 
     def __init__(self, dataloaders, shuffle):
@@ -44,6 +44,11 @@ def combine_dataloaders(
 ):
     """
     Combines multiple dataloaders into a single dataloader.
+
+    :param dataloaders: list of dataloaders to combine
+    :param shuffle: whether to shuffle the combined dataloader
+
+    :return: combined dataloader
     """
     combined_dataset = CombinedIterableDataset(dataloaders, shuffle)
     return torch.utils.data.DataLoader(combined_dataset, batch_size=None)
diff --git a/src/metatensor/models/utils/data/dataset.py b/src/metatensor/models/utils/data/dataset.py
index 84578922a..c915df599 100644
--- a/src/metatensor/models/utils/data/dataset.py
+++ b/src/metatensor/models/utils/data/dataset.py
@@ -95,7 +95,7 @@ def collate_fn(batch):
     return structures, targets
 
 
-def canonical_check_datasets(
+def check_datasets(
     train_datasets: List[Dataset],
     validation_datasets: List[Dataset],
     capabilities: ModelCapabilities,
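
For illustration, a minimal usage sketch of the combined dataloader documented in this patch. The two dataset objects and the batch size below are hypothetical placeholders (real Dataset objects would be built the way the rest of the library builds them); only collate_fn and combine_dataloaders are taken from the code shown in the diff.

    import torch

    from metatensor.models.utils.data import collate_fn, combine_dataloaders

    # Hypothetical placeholder datasets: substitute real Dataset objects,
    # e.g. built from structures and targets read elsewhere in the library.
    train_dataset_a = ...
    train_dataset_b = ...

    # One ordinary PyTorch dataloader per dataset; collate_fn batches the
    # samples into (structures, targets) pairs.
    dataloader_a = torch.utils.data.DataLoader(
        train_dataset_a, batch_size=8, collate_fn=collate_fn
    )
    dataloader_b = torch.utils.data.DataLoader(
        train_dataset_b, batch_size=8, collate_fn=collate_fn
    )

    # Merge the two dataloaders. Per the docstring added above, shuffle=True
    # randomizes the order of the combined batches; the returned dataloader
    # uses batch_size=None, so each original batch passes through unchanged.
    combined = combine_dataloaders([dataloader_a, dataloader_b], shuffle=True)

    for structures, targets in combined:
        ...  # each iteration yields one batch from one of the two dataloaders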