Skip to content

Commit

Permalink
Add new functions to the documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster committed Jan 16, 2024
1 parent 01a528b commit aa31433
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 11 deletions.
7 changes: 7 additions & 0 deletions docs/src/dev-docs/utils/combine_dataloaders.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Combining dataloaders
#####################

.. automodule:: metatensor.models.utils.data.combine_dataloaders
    :members:
    :undoc-members:
    :show-inheritance:
9 changes: 2 additions & 7 deletions src/metatensor/models/soap_bpnn/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,7 @@

from ..utils.composition import calculate_composition_weights
from ..utils.compute_loss import compute_model_loss
from ..utils.data import (
Dataset,
canonical_check_datasets,
collate_fn,
combine_dataloaders,
)
from ..utils.data import Dataset, check_datasets, collate_fn, combine_dataloaders
from ..utils.loss import TensorMapDictLoss
from ..utils.model_io import save_model
from .model import DEFAULT_HYPERS, Model
Expand All @@ -29,7 +24,7 @@ def train(
output_dir: str = ".",
):
# Perform canonical checks on the datasets:
canonical_check_datasets(
check_datasets(
train_datasets,
validation_datasets,
model_capabilities,
Expand Down
2 changes: 1 addition & 1 deletion src/metatensor/models/utils/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .dataset import Dataset, collate_fn, canonical_check_datasets # noqa: F401
from .dataset import Dataset, collate_fn, check_datasets # noqa: F401
from .readers import read_structures, read_targets # noqa: F401
from .writers import write_predictions # noqa: F401
from .combine_dataloaders import combine_dataloaders # noqa: F401
9 changes: 7 additions & 2 deletions src/metatensor/models/utils/data/combine_dataloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
class CombinedIterableDataset(torch.utils.data.IterableDataset):
"""
Combines multiple dataloaders into a single iterable dataset.
This is useful for combining multiple datasets into a single dataloader
and learning from all of them simultaneously.
This is useful for combining multiple dataloaders into a single
dataloader. The new dataloader can be shuffled or not.
"""

def __init__(self, dataloaders, shuffle):
Expand Down Expand Up @@ -44,6 +44,11 @@ def combine_dataloaders(
):
"""
Combines multiple dataloaders into a single dataloader.
:param dataloaders: list of dataloaders to combine
:param shuffle: whether to shuffle the combined dataloader
:return: combined dataloader
"""
combined_dataset = CombinedIterableDataset(dataloaders, shuffle)
return torch.utils.data.DataLoader(combined_dataset, batch_size=None)
2 changes: 1 addition & 1 deletion src/metatensor/models/utils/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def collate_fn(batch):
return structures, targets


def canonical_check_datasets(
def check_datasets(
train_datasets: List[Dataset],
validation_datasets: List[Dataset],
capabilities: ModelCapabilities,
Expand Down

0 comments on commit aa31433

Please sign in to comment.