-
Notifications
You must be signed in to change notification settings - Fork 163
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[FEAT] Add `DataFrame.to_torch_map_dataset` and `.to_torch_iter_dataset`. (#1086)
Adds two new top-level APIs to DataFrame, `to_torch_map_dataset` and `to_torch_iter_dataset`, which return the respective PyTorch datasets (TODO: make these DataPipes).
- `to_torch_map_dataset` will execute the whole dataframe before returning, since it needs random access.
- `to_torch_iter_dataset` will return immediately; results are returned via streaming execution.
Both are only meant for use in a single-node setting, with Ray Datasets being the recommended data loading abstraction for distributed training. --------- Co-authored-by: Xiayue Charles Lin <[email protected]>
- Loading branch information
1 parent
978db0f
commit e101ed6
Showing
3 changed files
with
97 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
from __future__ import annotations | ||
|
||
from typing import Any, Iterable, Iterator | ||
|
||
from loguru import logger | ||
|
||
# Select the base classes used by the torch dataset wrappers in this module.
# When available, subclass from the newer torchdata DataPipes instead of torch Datasets.
try:
    import torchdata

    MAP_DATASET_CLASS = torchdata.datapipes.map.MapDataPipe
    ITER_DATASET_CLASS = torchdata.datapipes.iter.IterDataPipe
except (ImportError, AttributeError):
    # ImportError: torchdata is not installed.
    # AttributeError: torchdata imported but does not expose `datapipes`
    # (e.g. a release that no longer ships DataPipes). In both cases fall back
    # to the plain torch Dataset classes instead of failing the module import.
    try:
        import torch

        MAP_DATASET_CLASS = torch.utils.data.Dataset
        ITER_DATASET_CLASS = torch.utils.data.IterableDataset
    except ImportError:
        # Log a hint for the user, then re-raise the original ImportError unchanged.
        logger.error("Error when importing Torch. To use PyTorch features, please install torch.")
        raise
|
||
|
||
class DaftTorchDataset(MAP_DATASET_CLASS):  # type: ignore
    """Map-style torch dataset backed by a Daft pydict.

    The pydict maps each column name to a list of values. Indexing the
    dataset yields one row as a dict from column name to that column's value
    at the requested position.
    """

    def __init__(self, data: dict[str, list[Any]], length: int):
        """Store the columnar data and the row count reported by ``len()``."""
        self.length = length
        self.data = data

    def __len__(self):
        """Return the row count supplied at construction time."""
        return self.length

    def __getitem__(self, i):
        """Return row ``i`` as a ``{column name: value}`` dict."""
        row = {}
        for column, values in self.data.items():
            row[column] = values[i]
        return row
|
||
|
||
class DaftTorchIterableDataset(ITER_DATASET_CLASS):  # type: ignore
    """Iterable-style torch dataset that delegates iteration to a wrapped iterable.

    NOTE(review): no per-worker sharding happens in this class; presumably it
    is intended for single-process loading only -- confirm against callers.
    """

    def __init__(self, iterable: Iterable[dict[str, Any]]):
        """Keep a reference to ``iterable``; nothing is consumed here."""
        self.iterable = iterable

    def __iter__(self) -> Iterator[dict[str, Any]]:
        """Return ``iter()`` of the wrapped iterable."""
        row_iterator = iter(self.iterable)
        return row_iterator
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters