-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
3882faa
commit 34036f4
Showing
4 changed files
with
556 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
from .dataset import Dataset, collate_fn # noqa: F401 | ||
from .readers import read_structures, read_targets # noqa: F401 | ||
from .writers import write_predictions # noqa: F401 | ||
from .combine_dataloaders import combine_dataloaders # noqa: F401 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
import itertools | ||
|
||
import numpy as np | ||
import torch | ||
|
||
|
||
class CombinedIterableDataset(torch.utils.data.IterableDataset):
    """
    Combines multiple dataloaders into a single iterable dataset.

    This is useful for combining multiple datasets into a single dataloader
    and learning from all of them simultaneously.

    Every batch of every dataloader is yielded exactly once per epoch. When
    ``shuffle`` is true, the order in which the dataloaders are visited is
    randomized once, at construction time; the order of the batches *within*
    a given dataloader is whatever that dataloader itself produces.

    :param dataloaders: sequence of sized iterables (typically
        ``torch.utils.data.DataLoader`` objects) yielding batches.
    :param shuffle: whether to randomly interleave the dataloaders.
    """

    def __init__(self, dataloaders, shuffle):
        super().__init__()
        self.dataloaders = dataloaders
        self.shuffle = shuffle
        self.indices = self._create_indices()

    def _create_indices(self):
        """Return one ``(batch_idx, dataloader_idx)`` tuple per batch."""
        indices = [
            (i, dl_idx)
            for dl_idx, dl in enumerate(self.dataloaders)
            for i in range(len(dl))
        ]

        # Shuffle the indices if requested
        if self.shuffle:
            np.random.shuffle(indices)
        return indices

    def __iter__(self):
        """Yield all batches, following the dataloader order in ``indices``."""
        # Keep ONE live iterator per dataloader for the whole epoch. Restarting
        # the dataloader and skipping ahead for every batch would load O(n^2)
        # batches and, for dataloaders that reshuffle on each restart, could
        # yield some batches twice and others never.
        # Iterators are created lazily so that a dataloader is only started
        # once its first batch is actually requested.
        iterators = {}
        for _, dataloader_idx in self.indices:
            iterator = iterators.get(dataloader_idx)
            if iterator is None:
                iterator = iter(self.dataloaders[dataloader_idx])
                iterators[dataloader_idx] = iterator
            # Each dataloader index appears exactly len(dataloader) times in
            # self.indices, so every iterator is consumed exactly once.
            yield next(iterator)

    def __len__(self):
        """Return the total number of batches over all dataloaders."""
        return len(self.indices)
|
||
|
||
def combine_dataloaders(*dataloaders, shuffle=True):
    """
    Wrap several dataloaders into one ``torch.utils.data.DataLoader``.

    :param dataloaders: the dataloaders to combine, passed positionally.
    :param shuffle: forwarded to ``CombinedIterableDataset``.
    :return: a ``DataLoader`` over a ``CombinedIterableDataset`` built from
        ``dataloaders``.
    """
    # batch_size=None: the items of the combined dataset are passed through
    # without being grouped into new batches by this outer DataLoader.
    return torch.utils.data.DataLoader(
        CombinedIterableDataset(dataloaders, shuffle),
        batch_size=None,
    )
Oops, something went wrong.