Skip to content

Commit

Permalink
[FEAT] Add `DataFrame.to_torch_map_dataset` and `DataFrame.to_torch_iter_dataset`. (#1086)
Browse files Browse the repository at this point in the history

Adds two new top-level APIs to DataFrame, to_torch_map_dataset and
to_torch_iter_dataset, which return the respective PyTorch datasets (TODO:
make these DataPipes).

- `to_torch_map_dataset` will execute the whole dataframe before
returning, since it needs random access.
- `to_torch_iter_dataset` will return immediately; results are returned
via streaming execution.

Both are only meant for use in a single-node setting, with Ray Datasets
being the recommended data loading abstraction for distributed training.

---------

Co-authored-by: Xiayue Charles Lin <[email protected]>
  • Loading branch information
xcharleslin and Xiayue Charles Lin authored Jun 23, 2023
1 parent 978db0f commit e101ed6
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 1 deletion.
49 changes: 49 additions & 0 deletions daft/dataframe/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
if TYPE_CHECKING:
from ray.data.dataset import Dataset as RayDataset
from ray import ObjectRef as RayObjectRef
import torch.utils.data.Dataset as TorchDataset
import torch.utils.data.IterableDataset as TorchIterableDataset
import pandas as pd
import pyarrow as pa
import dask
Expand Down Expand Up @@ -1116,6 +1118,53 @@ def to_pydict(self) -> Dict[str, List[Any]]:
assert result is not None
return result.to_pydict()

@DataframePublicAPI
def to_torch_map_dataset(self) -> "TorchDataset":
    """Convert the current DataFrame into a map-style
    `Torch Dataset <https://pytorch.org/docs/stable/data.html#map-style-datasets>`__
    for use with PyTorch.

    This method will materialize the entire DataFrame and block on completion.

    Items will be returned in pydict format: a dict of `{"column name": value}` for each row in the data.

    .. NOTE::
        If you do not need random access, you may get better performance out of an IterableDataset,
        which streams data items in as soon as they are ready and does not block on full materialization.

    .. NOTE::
        This method returns results locally.
        For distributed training, you may want to use ``DataFrame.to_ray_dataset()``.
    """
    # Imported lazily so that torch is only required when this method is called.
    from daft.dataframe.to_torch import DaftTorchDataset

    # Column-oriented data first, then the row count, matching the wrapper's
    # (data, length) constructor signature.
    columns = self.to_pydict()
    num_rows = len(self)
    return DaftTorchDataset(columns, num_rows)

@DataframePublicAPI
def to_torch_iter_dataset(self) -> "TorchIterableDataset":
    """Convert the current DataFrame into a
    `Torch IterableDataset <https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset>`__
    for use with PyTorch.

    Begins execution of the DataFrame if it is not yet executed.

    Items will be returned in pydict format: a dict of `{"column name": value}` for each row in the data.

    .. NOTE::
        The produced dataset is meant to be used with the single-process DataLoader,
        and does not support data sharding hooks for multi-process data loading.

        Do keep in mind that Daft is already using multithreading or multiprocessing under the hood
        to compute the data stream that feeds this dataset.

    .. NOTE::
        This method returns results locally.
        For distributed training, you may want to use ``DataFrame.to_ray_dataset()``.
    """
    # Imported lazily so that torch is only required when this method is called.
    from daft.dataframe.to_torch import DaftTorchIterableDataset

    # The DataFrame itself is handed over as the iterable backing the dataset.
    iter_dataset = DaftTorchIterableDataset(self)
    return iter_dataset

@DataframePublicAPI
def to_ray_dataset(self) -> "RayDataset":
"""Converts the current DataFrame to a Ray Dataset which is useful for running distributed ML model training in Ray
Expand Down
45 changes: 45 additions & 0 deletions daft/dataframe/to_torch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from __future__ import annotations

from typing import Any, Iterable, Iterator

from loguru import logger

# Select the base classes for the Daft torch dataset wrappers.
#
# Preference order:
#   1. torchdata DataPipes (MapDataPipe / IterDataPipe), when importable.
#   2. Plain torch datasets (Dataset / IterableDataset).
# If neither can be imported, an error is logged and the ImportError propagates.
try:
    # Import the submodules explicitly rather than relying on attribute access
    # after a bare `import torchdata`: a torchdata build without the
    # `datapipes` subpackage would otherwise raise AttributeError, which is not
    # caught below and would skip the torch fallback entirely.
    import torchdata.datapipes.iter
    import torchdata.datapipes.map

    MAP_DATASET_CLASS = torchdata.datapipes.map.MapDataPipe
    ITER_DATASET_CLASS = torchdata.datapipes.iter.IterDataPipe
except ImportError:
    try:
        import torch.utils.data

        MAP_DATASET_CLASS = torch.utils.data.Dataset
        ITER_DATASET_CLASS = torch.utils.data.IterableDataset
    except ImportError:
        logger.error("Error when importing Torch. To use PyTorch features, please install torch.")
        raise


class DaftTorchDataset(MAP_DATASET_CLASS):  # type: ignore
    """Map-style torch dataset backed by a column-oriented pydict.

    ``data`` maps each column name to a list of values; indexing the dataset
    with ``i`` yields one row as ``{"column name": value}`` by taking element
    ``i`` of every column. The reported length is the ``length`` passed in.
    """

    def __init__(self, data: dict[str, list[Any]], length: int):
        self.length = length
        self.data = data

    def __len__(self):
        return self.length

    def __getitem__(self, i):
        # Assemble a single row by picking position `i` out of each column.
        row = {}
        for column_name, column_values in self.data.items():
            row[column_name] = column_values[i]
        return row


class DaftTorchIterableDataset(ITER_DATASET_CLASS):  # type: ignore
    """Iterable-style torch dataset that delegates iteration to a wrapped iterable.

    Each call to ``__iter__`` returns ``iter()`` of the stored iterable.
    """

    def __init__(self, iterable: Iterable[dict[str, Any]]):
        self.iterable = iterable

    def __iter__(self) -> Iterator[dict[str, Any]]:
        wrapped = self.iterable
        return iter(wrapped)
4 changes: 3 additions & 1 deletion docs/source/api_docs/dataframe.rst
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ Combining
:toctree: doc_gen/dataframe_methods

daft.DataFrame.join
daft.DataFrane.concat
daft.DataFrame.concat

.. _df-aggregations:

Expand Down Expand Up @@ -143,6 +143,8 @@ Integrations
:toctree: doc_gen/dataframe_methods

daft.DataFrame.to_pandas
daft.DataFrame.to_torch_map_dataset
daft.DataFrame.to_torch_iter_dataset
daft.DataFrame.to_ray_dataset
daft.DataFrame.to_dask_dataframe

Expand Down

0 comments on commit e101ed6

Please sign in to comment.