From e101ed6f6ab11c3b30e0f8839f83a1c8eecee70c Mon Sep 17 00:00:00 2001 From: xcharleslin <4212216+xcharleslin@users.noreply.github.com> Date: Fri, 23 Jun 2023 14:00:41 -0700 Subject: [PATCH] [FEAT] Add `DataFrame.to_torch_map_dataset` and `.to_torch_iter_dataset`. (#1086) Adds two new top-level APIs to DataFrame: to_torch_map_dataset and to_torch_iter_dataset, that return the respective PyTorch datasets (TODO: make these DataPipes). - `to_torch_map_dataset` will execute the whole dataframe before returning, since it needs random access. - `to_torch_iter_dataset` will return immediately; results are returned via streaming execution. Both are only meant for use in a single-node setting, with Ray Datasets being the recommended data loading abstraction for distributed training. --------- Co-authored-by: Xiayue Charles Lin --- daft/dataframe/dataframe.py | 49 ++++++++++++++++++++++++++++++ daft/dataframe/to_torch.py | 45 +++++++++++++++++++++++++++ docs/source/api_docs/dataframe.rst | 4 ++- 3 files changed, 97 insertions(+), 1 deletion(-) create mode 100644 daft/dataframe/to_torch.py diff --git a/daft/dataframe/dataframe.py b/daft/dataframe/dataframe.py index d1e89b11ba..2167ea0069 100644 --- a/daft/dataframe/dataframe.py +++ b/daft/dataframe/dataframe.py @@ -40,6 +40,8 @@ if TYPE_CHECKING: from ray.data.dataset import Dataset as RayDataset from ray import ObjectRef as RayObjectRef + import torch.utils.data.Dataset as TorchDataset + import torch.utils.data.IterableDataset as TorchIterableDataset import pandas as pd import pyarrow as pa import dask @@ -1116,6 +1118,53 @@ def to_pydict(self) -> Dict[str, List[Any]]: assert result is not None return result.to_pydict() + @DataframePublicAPI + def to_torch_map_dataset(self) -> "TorchDataset": + """Convert the current DataFrame into a map-style + `Torch Dataset <https://pytorch.org/docs/stable/data.html#map-style-datasets>`__ + for use with PyTorch. + + This method will materialize the entire DataFrame and block on completion.
+ + Items will be returned in pydict format: a dict of `{"column name": value}` for each row in the data. + + .. NOTE:: + If you do not need random access, you may get better performance out of an IterableDataset, + which streams data items in as soon as they are ready and does not block on full materialization. + + .. NOTE:: + This method returns results locally. + For distributed training, you may want to use ``DataFrame.to_ray_dataset()``. + """ + from daft.dataframe.to_torch import DaftTorchDataset + + return DaftTorchDataset(self.to_pydict(), len(self)) + + @DataframePublicAPI + def to_torch_iter_dataset(self) -> "TorchIterableDataset": + """Convert the current DataFrame into a + `Torch IterableDataset <https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset>`__ + for use with PyTorch. + + Begins execution of the DataFrame if it is not yet executed. + + Items will be returned in pydict format: a dict of `{"column name": value}` for each row in the data. + + .. NOTE:: + The produced dataset is meant to be used with the single-process DataLoader, + and does not support data sharding hooks for multi-process data loading. + + Do keep in mind that Daft is already using multithreading or multiprocessing under the hood + to compute the data stream that feeds this dataset. + + .. NOTE:: + This method returns results locally. + For distributed training, you may want to use ``DataFrame.to_ray_dataset()``.
+ """ + from daft.dataframe.to_torch import DaftTorchIterableDataset + + return DaftTorchIterableDataset(self) + @DataframePublicAPI def to_ray_dataset(self) -> "RayDataset": """Converts the current DataFrame to a Ray Dataset which is useful for running distributed ML model training in Ray diff --git a/daft/dataframe/to_torch.py b/daft/dataframe/to_torch.py new file mode 100644 index 0000000000..f3b30b888f --- /dev/null +++ b/daft/dataframe/to_torch.py @@ -0,0 +1,45 @@ +from __future__ import annotations + +from typing import Any, Iterable, Iterator + +from loguru import logger + +try: + # When available, subclass from the newer torchdata DataPipes instead of torch Datasets. + import torchdata + + MAP_DATASET_CLASS = torchdata.datapipes.map.MapDataPipe + ITER_DATASET_CLASS = torchdata.datapipes.iter.IterDataPipe +except ImportError: + try: + import torch + + MAP_DATASET_CLASS = torch.utils.data.Dataset + ITER_DATASET_CLASS = torch.utils.data.IterableDataset + except ImportError: + logger.error(f"Error when importing Torch. 
To use PyTorch features, please install torch.") + raise + + +class DaftTorchDataset(MAP_DATASET_CLASS): # type: ignore + """A wrapper to create a torch map-style Dataset from a Daft pydict of items.""" + + def __init__(self, data: dict[str, list[Any]], length: int): + self.data = data + self.length = length + + def __len__(self): + return self.length + + def __getitem__(self, i): + return {key: vallist[i] for (key, vallist) in self.data.items()} + + +class DaftTorchIterableDataset(ITER_DATASET_CLASS): # type: ignore + """A thin wrapper to create a torch IterableDataset from an iterable.""" + + def __init__(self, iterable: Iterable[dict[str, Any]]): + self.iterable = iterable + + def __iter__(self) -> Iterator[dict[str, Any]]: + return iter(self.iterable) diff --git a/docs/source/api_docs/dataframe.rst b/docs/source/api_docs/dataframe.rst index c2f7bed9cd..832663cef0 100644 --- a/docs/source/api_docs/dataframe.rst +++ b/docs/source/api_docs/dataframe.rst @@ -69,7 +69,7 @@ Combining :toctree: doc_gen/dataframe_methods daft.DataFrame.join - daft.DataFrane.concat + daft.DataFrame.concat .. _df-aggregations: @@ -143,6 +143,8 @@ Integrations :toctree: doc_gen/dataframe_methods daft.DataFrame.to_pandas + daft.DataFrame.to_torch_map_dataset + daft.DataFrame.to_torch_iter_dataset daft.DataFrame.to_ray_dataset daft.DataFrame.to_dask_dataframe