Skip to content

Commit

Permalink
Create a combined dataloader
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster committed Jan 4, 2024
1 parent 3882faa commit 34036f4
Show file tree
Hide file tree
Showing 4 changed files with 556 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/metatensor/models/utils/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .dataset import Dataset, collate_fn # noqa: F401
from .readers import read_structures, read_targets # noqa: F401
from .writers import write_predictions # noqa: F401
from .combine_dataloaders import combine_dataloaders # noqa: F401
44 changes: 44 additions & 0 deletions src/metatensor/models/utils/data/combine_dataloaders.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import itertools

import numpy as np
import torch


class CombinedIterableDataset(torch.utils.data.IterableDataset):
    """
    Combines multiple dataloaders into a single iterable dataset.
    This is useful for combining multiple datasets into a single dataloader
    and learning from all of them simultaneously.

    Every batch of every dataloader is yielded exactly once per pass. The
    order in which the batches are visited is drawn once, at construction
    time, and reused by every pass over the dataset.

    :param dataloaders: sequence of sized iterables (typically
        ``torch.utils.data.DataLoader`` objects) whose batches are combined
    :param shuffle: if true, the batches of all dataloaders are visited in a
        random order drawn with ``np.random.shuffle``; otherwise the
        dataloaders are exhausted one after the other
    """

    def __init__(self, dataloaders, shuffle):
        self.dataloaders = dataloaders
        self.shuffle = shuffle
        self.indices = self._create_indices()

    def _create_indices(self):
        """Return the list of ``(idx, dataloader_idx)`` pairs to visit."""
        # One (idx, dataloader_idx) tuple for each batch of each dataloader,
        # where idx is the position of the batch inside its dataloader
        indices = [
            (i, dl_idx)
            for dl_idx, dl in enumerate(self.dataloaders)
            for i in range(len(dl))
        ]

        # Shuffle the indices if requested
        if self.shuffle:
            np.random.shuffle(indices)
        return indices

    def __iter__(self):
        """Yield the batches of all dataloaders in the order of ``indices``."""
        # Each dataloader is iterated at most once per pass. Restarting a
        # dataloader for every batch would re-read all the preceding batches
        # (quadratic cost) and, for a dataloader that reshuffles whenever it
        # is iterated, would take every batch from a different permutation,
        # so that samples could be repeated or skipped within one pass.
        n_loaders = len(self.dataloaders)
        iterators = [None] * n_loaders
        n_read = [0] * n_loaders
        # Batches that were read before their turn, keyed by their position
        # in the dataloader. Entries are dropped as soon as they are yielded.
        read_ahead = [{} for _ in range(n_loaders)]

        for idx, dataloader_idx in self.indices:
            if iterators[dataloader_idx] is None:
                # Only start a dataloader when its first batch is needed
                iterators[dataloader_idx] = iter(self.dataloaders[dataloader_idx])
            iterator = iterators[dataloader_idx]
            buffered = read_ahead[dataloader_idx]

            # Advance the dataloader until the requested batch has been read
            while n_read[dataloader_idx] <= idx:
                try:
                    batch = next(iterator)
                except StopIteration:
                    # A StopIteration escaping a generator is turned into a
                    # RuntimeError anyway; raise it with a useful message
                    raise RuntimeError(
                        f"dataloader {dataloader_idx} yielded fewer batches "
                        "than its length"
                    ) from None
                buffered[n_read[dataloader_idx]] = batch
                n_read[dataloader_idx] += 1

            yield buffered.pop(idx)

    def __len__(self):
        """Return the total number of batches over all dataloaders."""
        return len(self.indices)


def combine_dataloaders(*dataloaders, shuffle=True):
    """
    Merge several dataloaders into one ``torch.utils.data.DataLoader``.

    The dataloaders are wrapped in a ``CombinedIterableDataset``, and the
    returned dataloader is created with ``batch_size=None``, which disables
    automatic batching of the items coming out of that dataset.

    :param dataloaders: the dataloaders to combine
    :param shuffle: forwarded to ``CombinedIterableDataset``
    :return: a dataloader over the combined dataset
    """
    return torch.utils.data.DataLoader(
        CombinedIterableDataset(dataloaders, shuffle),
        batch_size=None,
    )
Loading

0 comments on commit 34036f4

Please sign in to comment.