[FEAT] Enable feature-flagged native downloader in daft.read_parquet #1190

Merged: 13 commits, Jul 31, 2023
8 changes: 8 additions & 0 deletions daft/datasources.py
@@ -3,6 +3,10 @@
import sys
from dataclasses import dataclass
from enum import Enum
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from daft.io import IOConfig

if sys.version_info < (3, 8):
    from typing_extensions import Protocol
@@ -41,5 +45,9 @@

@dataclass(frozen=True)
class ParquetSourceInfo(SourceInfo):

    use_native_downloader: bool
    io_config: IOConfig | None

    def scan_type(self):
        return StorageType.PARQUET
14 changes: 11 additions & 3 deletions daft/filesystem.py
@@ -27,6 +27,7 @@
)

from daft.datasources import ParquetSourceInfo, SourceInfo
from daft.table import Table

_CACHED_FSES: dict[str, FileSystem] = {}

@@ -328,9 +329,16 @@ def glob_path_with_stats(

    # Set number of rows if available.
    if isinstance(source_info, ParquetSourceInfo):
        if source_info.use_native_downloader:
            parquet_statistics = Table.read_parquet_statistics(
                list(filepaths_to_infos.keys()), source_info.io_config
            ).to_pydict()
            for path, num_rows in zip(parquet_statistics["uris"], parquet_statistics["row_count"]):
                filepaths_to_infos[path]["rows"] = num_rows
        else:
            parquet_metadatas = ThreadPoolExecutor().map(_get_parquet_metadata_single, filepaths_to_infos.keys())
            for path, parquet_metadata in zip(filepaths_to_infos.keys(), parquet_metadatas):
                filepaths_to_infos[path]["rows"] = parquet_metadata.num_rows

    return [
        ListingInfo(path=_ensure_path_protocol(protocol, path), **infos) for path, infos in filepaths_to_infos.items()

Contributor Author (on the statistics lookup above):

NOTE: This code relies on hardcoded strings which are the names of the table's series. A little error-prone, perhaps?
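For context on that note, a minimal sketch of the shape this lookup assumes (the local file paths are hypothetical; the "uris" and "row_count" series names come from this PR's table.py change):

from daft.table import Table

# Fetch row-count statistics for a batch of Parquet files via the native reader.
stats = Table.read_parquet_statistics(
    ["data/a.parquet", "data/b.parquet"],  # hypothetical local files
    None,  # io_config; None suffices for local paths
).to_pydict()

# to_pydict() is keyed by the table's series names, so callers must know
# "uris" and "row_count" ahead of time -- hence the note about hardcoded strings.
rows_by_path = dict(zip(stats["uris"], stats["row_count"]))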
15 changes: 13 additions & 2 deletions daft/io/parquet.py
@@ -1,6 +1,6 @@
# isort: dont-add-import: from __future__ import annotations

from typing import Dict, List, Optional, Union
from typing import TYPE_CHECKING, Dict, List, Optional, Union

import fsspec

@@ -10,12 +10,17 @@
from daft.datatype import DataType
from daft.io.common import _get_tabular_files_scan

if TYPE_CHECKING:
    from daft.io import IOConfig


@PublicAPI
def read_parquet(
    path: Union[str, List[str]],
    schema_hints: Optional[Dict[str, DataType]] = None,
    fs: Optional[fsspec.AbstractFileSystem] = None,
    io_config: Optional["IOConfig"] = None,
    use_native_downloader: bool = False,
) -> DataFrame:
    """Creates a DataFrame from Parquet file(s)

@@ -31,6 +36,9 @@
            disable all schema inference on data being read, and throw an error if data being read is incompatible.
        fs (fsspec.AbstractFileSystem): fsspec FileSystem to use for reading data.
            By default, Daft will automatically construct a FileSystem instance internally.
        io_config (IOConfig): Config to be used with the native downloader.
        use_native_downloader: Whether to use the native downloader instead of PyArrow for reading Parquet. This
            is currently experimental.

    returns:
        DataFrame: parsed DataFrame
@@ -41,7 +49,10 @@
    plan = _get_tabular_files_scan(
        path,
        schema_hints,
        ParquetSourceInfo(
            io_config=io_config,
            use_native_downloader=use_native_downloader,
        ),
        fs,
    )
    return DataFrame(plan)
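A quick usage sketch of the new flag (the file path is hypothetical; io_config can be left as None for local files):

import daft

# Opt in to the experimental native Parquet downloader.
df = daft.read_parquet("data/iris.parquet", use_native_downloader=True)

# The default PyArrow-based path is unchanged.
df_default = daft.read_parquet("data/iris.parquet")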
2 changes: 2 additions & 0 deletions daft/runners/runner_io.py
@@ -100,6 +100,8 @@ def sample_schema(
            return schema_inference.from_parquet(
                file=filepath,
                fs=fs,
                io_config=source_info.io_config,
                use_native_downloader=source_info.use_native_downloader,
            )
        else:
            raise NotImplementedError(f"Schema inference for {source_info} not implemented")
8 changes: 8 additions & 0 deletions daft/table/schema_inference.py
@@ -17,6 +17,8 @@
if TYPE_CHECKING:
    import fsspec

    from daft.io import IOConfig


def from_csv(
    file: FileInput,

@@ -73,8 +75,14 @@
def from_parquet(
    file: FileInput,
    fs: fsspec.AbstractFileSystem | None = None,
    io_config: IOConfig | None = None,
    use_native_downloader: bool = False,
) -> Schema:
    """Infers a Schema from a Parquet file"""
    if use_native_downloader:
        assert isinstance(file, (str, pathlib.Path))
        return Schema.from_parquet(str(file), io_config=io_config)

    if not isinstance(file, (str, pathlib.Path)):
        # BytesIO path.
        f = file
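Since schema inference now routes through the native reader when the flag is set, a minimal sketch of calling it directly (the file path is hypothetical):

from daft.table import schema_inference

# Native schema inference; with use_native_downloader=False this falls back
# to the existing PyArrow-based path.
schema = schema_inference.from_parquet(
    file="data/iris.parquet",
    io_config=None,
    use_native_downloader=True,
)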
4 changes: 2 additions & 2 deletions daft/table/table.py
@@ -364,6 +364,6 @@ def read_parquet_statistics(
    io_config: IOConfig | None = None,
) -> Table:
    if not isinstance(paths, Series):
        paths = Series.from_pylist(paths, name="uris")
    assert paths.name() == "uris", f"Expected input series to have name 'uris', but found: {paths.name()}"
    return Table._from_pytable(_read_parquet_statistics(uris=paths._series, io_config=io_config))

Contributor Author (on the from_pylist change):

NOTE: Our Series objects default to the name "list_col". Modified here to have a better name, but it needs to be consistent across calls, so I added an assert after.
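A short sketch of the naming behavior that assert guards against (assuming Series is importable from daft.series, and that the default series name is "list_col" as the note says):

from daft.series import Series

s_default = Series.from_pylist(["s3://bucket/a.parquet"])
print(s_default.name())  # "list_col", per the note above

s_named = Series.from_pylist(["s3://bucket/a.parquet"], name="uris")
assert s_named.name() == "uris"  # the name read_parquet_statistics expects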
16 changes: 15 additions & 1 deletion daft/table/table_io.py
@@ -3,7 +3,7 @@
import contextlib
import pathlib
from collections.abc import Generator
from typing import IO, Union
from typing import IO, TYPE_CHECKING, Union
from uuid import uuid4

import fsspec
@@ -20,6 +20,9 @@
from daft.runners.partitioning import TableParseCSVOptions, TableReadOptions
from daft.table import Table

if TYPE_CHECKING:
    from daft.io import IOConfig

FileInput = Union[pathlib.Path, str, IO[bytes]]


@@ -94,6 +97,8 @@
    schema: Schema,
    fs: fsspec.AbstractFileSystem | None = None,
    read_options: TableReadOptions = TableReadOptions(),
    io_config: IOConfig | None = None,
    use_native_downloader: bool = False,
) -> Table:
    """Reads a Table from a Parquet file

@@ -106,6 +111,15 @@
    Returns:
        Table: Parsed Table from Parquet
    """
    if use_native_downloader:
        assert isinstance(file, (str, pathlib.Path)), "Native downloader only works on string inputs to read_parquet"
        return Table.read_parquet(
            str(file),
            columns=read_options.column_names,
            num_rows=read_options.num_rows,
            io_config=io_config,
        )

    f: IO
    if not isinstance(file, (str, pathlib.Path)):
        f = file
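A hedged sketch of the native fast path added above, called directly (the path and column names are hypothetical; columns and num_rows mirror the TableReadOptions fields used in the diff):

from daft.table import Table

# Native read with column projection and row-limit pushdown, as table_io.py
# now does when use_native_downloader=True.
tbl = Table.read_parquet(
    "data/iris.parquet",
    columns=["sepal_length", "variety"],
    num_rows=100,
    io_config=None,
)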
19 changes: 12 additions & 7 deletions tests/dataframe/test_creation.py
@@ -638,28 +638,30 @@ def test_create_dataframe_json_specify_schema(valid_data: list[dict[str, float]]
###


@pytest.mark.parametrize("use_native_downloader", [True, False])
def test_create_dataframe_parquet(valid_data: list[dict[str, float]], use_native_downloader) -> None:
    with tempfile.NamedTemporaryFile("w") as f:
        table = pa.Table.from_pydict({col: [d[col] for d in valid_data] for col in COL_NAMES})
        papq.write_table(table, f.name)
        f.flush()

        df = daft.read_parquet(f.name, use_native_downloader=use_native_downloader)
        assert df.column_names == COL_NAMES

        pd_df = df.to_pandas()
        assert list(pd_df.columns) == COL_NAMES
        assert len(pd_df) == len(valid_data)


@pytest.mark.parametrize("use_native_downloader", [True, False])
def test_create_dataframe_multiple_parquets(valid_data: list[dict[str, float]], use_native_downloader) -> None:
    with tempfile.NamedTemporaryFile("w") as f1, tempfile.NamedTemporaryFile("w") as f2:
        for f in (f1, f2):
            table = pa.Table.from_pydict({col: [d[col] for d in valid_data] for col in COL_NAMES})
            papq.write_table(table, f.name)
            f.flush()

        df = daft.read_parquet([f1.name, f2.name], use_native_downloader=use_native_downloader)
        assert df.column_names == COL_NAMES

        pd_df = df.to_pandas()

@@ -697,15 +699,16 @@ def test_create_dataframe_parquet_custom_fs(valid_data: list[dict[str, float]])

        assert len(pd_df) == len(valid_data)

@pytest.mark.parametrize("use_native_downloader", [True, False])
def test_create_dataframe_parquet_column_projection(valid_data: list[dict[str, float]], use_native_downloader) -> None:
    with tempfile.NamedTemporaryFile("w") as f:
        table = pa.Table.from_pydict({col: [d[col] for d in valid_data] for col in COL_NAMES})
        papq.write_table(table, f.name)
        f.flush()

        col_subset = COL_NAMES[:3]

        df = daft.read_parquet(f.name, use_native_downloader=use_native_downloader)
        df = df.select(*col_subset)
        assert df.column_names == col_subset

@@ -714,7 +717,8 @@ def test_create_dataframe_parquet_column_projection(valid_data: list[dict[str, f

        assert len(pd_df) == len(valid_data)


@pytest.mark.parametrize("use_native_downloader", [True, False])
def test_create_dataframe_parquet_specify_schema(valid_data: list[dict[str, float]], use_native_downloader) -> None:
    with tempfile.NamedTemporaryFile("w") as f:
        table = pa.Table.from_pydict({col: [d[col] for d in valid_data] for col in COL_NAMES})
        papq.write_table(table, f.name)

@@ -729,6 +733,7 @@ def test_create_dataframe_parquet_specify_schema(valid_data: list[dict[str, floa
                "petal_width": DataType.float32(),
                "variety": DataType.string(),
            },
            use_native_downloader=use_native_downloader,
        )
        assert df.column_names == COL_NAMES
