[FEAT] Enable feature-flagged native downloader in daft.read_parquet (#1190)

Changes `daft.read_parquet` so that when `use_native_downloader=True` is specified, it uses our new Rust-based native Parquet downloading and parsing for:

1. Schema inference
2. Retrieving per-file metadata (currently only the number of rows)
3. Reading the actual `Table`
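
As a quick illustration, a minimal usage sketch of the new flag (the bucket path is hypothetical; the anonymous S3 config mirrors the integration test added below):

import daft
from daft.io import IOConfig, S3Config

# Hypothetical public bucket; anonymous access needs no credentials.
df = daft.read_parquet(
    "s3://example-bucket/data.parquet",
    io_config=IOConfig(s3=S3Config(anonymous=True)),
    use_native_downloader=True,
)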

---------

Co-authored-by: Jay Chia <[email protected]@users.noreply.github.com>
jaychia and Jay Chia committed Jul 31, 2023
1 parent bacd70e commit f43174d
Showing 15 changed files with 115 additions and 18 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

8 changes: 8 additions & 0 deletions daft/datasources.py
@@ -3,6 +3,10 @@
import sys
from dataclasses import dataclass
from enum import Enum
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from daft.io import IOConfig

if sys.version_info < (3, 8):
    from typing_extensions import Protocol
@@ -41,5 +45,9 @@ def scan_type(self):

@dataclass(frozen=True)
class ParquetSourceInfo(SourceInfo):

    use_native_downloader: bool
    io_config: IOConfig | None

    def scan_type(self):
        return StorageType.PARQUET
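
For illustration, a sketch of how the new fields travel with the source info (assuming StorageType is importable from the same daft.datasources module, as the scan_type body suggests):

from daft.datasources import ParquetSourceInfo, StorageType

# frozen=True keeps the source info immutable as it flows through the plan.
info = ParquetSourceInfo(use_native_downloader=True, io_config=None)
assert info.scan_type() == StorageType.PARQUET
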
2 changes: 2 additions & 0 deletions daft/execution/execution_step.py
@@ -400,6 +400,8 @@ def _handle_tabular_files_scan(
                schema=schema,
                fs=fs,
                read_options=read_options,
                io_config=scan._source_info.io_config,
                use_native_downloader=scan._source_info.use_native_downloader,
            )
            for fp in filepaths
        ]
14 changes: 11 additions & 3 deletions daft/filesystem.py
@@ -27,6 +27,7 @@
)

from daft.datasources import ParquetSourceInfo, SourceInfo
from daft.table import Table

_CACHED_FSES: dict[str, FileSystem] = {}

@@ -328,9 +329,16 @@ def glob_path_with_stats(

    # Set number of rows if available.
    if isinstance(source_info, ParquetSourceInfo):
        parquet_metadatas = ThreadPoolExecutor().map(_get_parquet_metadata_single, filepaths_to_infos.keys())
        for path, parquet_metadata in zip(filepaths_to_infos.keys(), parquet_metadatas):
            filepaths_to_infos[path]["rows"] = parquet_metadata.num_rows
        if source_info.use_native_downloader:
            parquet_statistics = Table.read_parquet_statistics(
                list(filepaths_to_infos.keys()), source_info.io_config
            ).to_pydict()
            for path, num_rows in zip(parquet_statistics["uris"], parquet_statistics["row_count"]):
                filepaths_to_infos[path]["rows"] = num_rows
        else:
            parquet_metadatas = ThreadPoolExecutor().map(_get_parquet_metadata_single, filepaths_to_infos.keys())
            for path, parquet_metadata in zip(filepaths_to_infos.keys(), parquet_metadatas):
                filepaths_to_infos[path]["rows"] = parquet_metadata.num_rows

    return [
        ListingInfo(path=_ensure_path_protocol(protocol, path), **infos) for path, infos in filepaths_to_infos.items()
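
A sketch of the native statistics call used above; the "uris" and "row_count" column names follow from the zip in this hunk (the file path is hypothetical):

from daft.table import Table

# One row per input file; to_pydict() yields {"uris": [...], "row_count": [...]}.
stats = Table.read_parquet_statistics(["data.parquet"], io_config=None).to_pydict()
num_rows = stats["row_count"][0]
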
3 changes: 1 addition & 2 deletions daft/io/__init__.py
@@ -1,9 +1,8 @@
from __future__ import annotations

from daft.daft import PyIOConfig as IOConfig
from daft.daft import PyS3Config as S3Config
from daft.io._csv import read_csv
from daft.io._json import read_json
from daft.io.config import IOConfig, S3Config
from daft.io.file_path import from_glob_path
from daft.io.parquet import read_parquet

13 changes: 13 additions & 0 deletions daft/io/config.py
@@ -0,0 +1,13 @@
from __future__ import annotations

import json

from daft.daft import PyIOConfig as IOConfig
from daft.daft import PyS3Config as S3Config


def _io_config_from_json(io_config_json: str) -> IOConfig:
"""Used when deserializing a serialized IOConfig object"""
data = json.loads(io_config_json)
s3_config = S3Config(**data["s3"]) if "s3" in data else None
return IOConfig(s3=s3_config)
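
A sketch of the JSON payload this expects, assuming the keys under "s3" match S3Config's constructor keywords (the tests below use anonymous):

from daft.io.config import _io_config_from_json

# Keys under "s3" are splatted into S3Config(**data["s3"]).
config = _io_config_from_json('{"s3": {"anonymous": true}}')
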
15 changes: 13 additions & 2 deletions daft/io/parquet.py
@@ -1,6 +1,6 @@
# isort: dont-add-import: from __future__ import annotations

from typing import Dict, List, Optional, Union
from typing import TYPE_CHECKING, Dict, List, Optional, Union

import fsspec

@@ -10,12 +10,17 @@
from daft.datatype import DataType
from daft.io.common import _get_tabular_files_scan

if TYPE_CHECKING:
    from daft.io import IOConfig


@PublicAPI
def read_parquet(
    path: Union[str, List[str]],
    schema_hints: Optional[Dict[str, DataType]] = None,
    fs: Optional[fsspec.AbstractFileSystem] = None,
    io_config: Optional["IOConfig"] = None,
    use_native_downloader: bool = False,
) -> DataFrame:
    """Creates a DataFrame from Parquet file(s)
@@ -31,6 +36,9 @@ def read_parquet(
            disable all schema inference on data being read, and throw an error if data being read is incompatible.
        fs (fsspec.AbstractFileSystem): fsspec FileSystem to use for reading data.
            By default, Daft will automatically construct a FileSystem instance internally.
        io_config (IOConfig): Config to be used with the native downloader
        use_native_downloader: Whether to use the native downloader instead of PyArrow for reading Parquet. This
            is currently experimental.
    returns:
        DataFrame: parsed DataFrame
@@ -41,7 +49,10 @@
    plan = _get_tabular_files_scan(
        path,
        schema_hints,
        ParquetSourceInfo(),
        ParquetSourceInfo(
            io_config=io_config,
            use_native_downloader=use_native_downloader,
        ),
        fs,
    )
    return DataFrame(plan)
2 changes: 2 additions & 0 deletions daft/runners/runner_io.py
@@ -100,6 +100,8 @@ def sample_schema(
        return schema_inference.from_parquet(
            file=filepath,
            fs=fs,
            io_config=source_info.io_config,
            use_native_downloader=source_info.use_native_downloader,
        )
    else:
        raise NotImplementedError(f"Schema inference for {source_info} not implemented")
8 changes: 8 additions & 0 deletions daft/table/schema_inference.py
@@ -17,6 +17,8 @@
if TYPE_CHECKING:
    import fsspec

    from daft.io import IOConfig


def from_csv(
    file: FileInput,
@@ -73,8 +75,14 @@ def from_json(
def from_parquet(
    file: FileInput,
    fs: fsspec.AbstractFileSystem | None = None,
    io_config: IOConfig | None = None,
    use_native_downloader: bool = False,
) -> Schema:
    """Infers a Schema from a Parquet file"""
    if use_native_downloader:
        assert isinstance(file, (str, pathlib.Path))
        return Schema.from_parquet(str(file), io_config=io_config)

    if not isinstance(file, (str, pathlib.Path)):
        # BytesIO path.
        f = file
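
A sketch of the native schema-inference path (the local file path is hypothetical); with the flag set, the call reduces to Schema.from_parquet and bypasses fsspec entirely:

from daft.table.schema_inference import from_parquet

# Reads only the Parquet footer via the Rust reader to build the Schema.
schema = from_parquet("data.parquet", use_native_downloader=True)
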
4 changes: 2 additions & 2 deletions daft/table/table.py
@@ -364,6 +364,6 @@ def read_parquet_statistics(
    io_config: IOConfig | None = None,
) -> Table:
    if not isinstance(paths, Series):
        paths = Series.from_pylist(paths)

        paths = Series.from_pylist(paths, name="uris")
    assert paths.name() == "uris", f"Expected input series to have name 'uris', but found: {paths.name()}"
    return Table._from_pytable(_read_parquet_statistics(uris=paths._series, io_config=io_config))
17 changes: 16 additions & 1 deletion daft/table/table_io.py
@@ -3,7 +3,7 @@
import contextlib
import pathlib
from collections.abc import Generator
from typing import IO, Union
from typing import IO, TYPE_CHECKING, Union
from uuid import uuid4

import fsspec
@@ -20,6 +20,9 @@
from daft.runners.partitioning import TableParseCSVOptions, TableReadOptions
from daft.table import Table

if TYPE_CHECKING:
    from daft.io import IOConfig

FileInput = Union[pathlib.Path, str, IO[bytes]]


@@ -94,6 +97,8 @@ def read_parquet(
    schema: Schema,
    fs: fsspec.AbstractFileSystem | None = None,
    read_options: TableReadOptions = TableReadOptions(),
    io_config: IOConfig | None = None,
    use_native_downloader: bool = False,
) -> Table:
    """Reads a Table from a Parquet file
@@ -106,6 +111,16 @@
    Returns:
        Table: Parsed Table from Parquet
    """
    if use_native_downloader:
        assert isinstance(file, (str, pathlib.Path)), "Native downloader only works on string inputs to read_parquet"
        tbl = Table.read_parquet(
            str(file),
            columns=read_options.column_names,
            num_rows=read_options.num_rows,
            io_config=io_config,
        )
        return _cast_table_to_schema(tbl, read_options=read_options, schema=schema)

    f: IO
    if not isinstance(file, (str, pathlib.Path)):
        f = file
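
A sketch of the native read used above (the path and column names are hypothetical); column projection and row limits from TableReadOptions are pushed down into the Rust reader:

from daft.table import Table

# Mirrors the call above: project two columns and cap the row count.
tbl = Table.read_parquet(
    "data.parquet",
    columns=["a", "b"],
    num_rows=100,
    io_config=None,
)
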
1 change: 1 addition & 0 deletions src/daft-io/Cargo.toml
@@ -15,6 +15,7 @@ log = {workspace = true}
pyo3 = {workspace = true, optional = true}
pyo3-log = {workspace = true, optional = true}
serde = {workspace = true}
serde_json = {workspace = true}
snafu = {workspace = true}
tokio = {workspace = true}
url = "2.4.0"
12 changes: 12 additions & 0 deletions src/daft-io/src/python.rs
@@ -1,4 +1,5 @@
use crate::config::{IOConfig, S3Config};
use common_error::DaftError;
use pyo3::prelude::*;

#[derive(Clone, Default)]
@@ -34,6 +35,17 @@ impl PyIOConfig {
            config: self.config.s3.clone(),
        })
    }

    pub fn __reduce__(&self, py: Python) -> PyResult<(PyObject, (String,))> {
        let io_config_module = py.import("daft.io.config")?;
        let json_string = serde_json::to_string(&self.config).map_err(DaftError::from)?;
        Ok((
            io_config_module
                .getattr("_io_config_from_json")?
                .to_object(py),
            (json_string,),
        ))
    }
}

#[pymethods]
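
Together with _io_config_from_json above, this makes the config picklable, presumably so it can travel with serialized plans. A sketch of the round-trip (the anonymous keyword mirrors the tests below):

import pickle

from daft.io import IOConfig, S3Config

# __reduce__ hands pickle a JSON string plus the _io_config_from_json
# callable that rebuilds the object on the other side.
config = IOConfig(s3=S3Config(anonymous=True))
restored = pickle.loads(pickle.dumps(config))
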
19 changes: 12 additions & 7 deletions tests/dataframe/test_creation.py
@@ -638,28 +638,30 @@ def test_create_dataframe_json_specify_schema(valid_data: list[dict[str, float]]
###


def test_create_dataframe_parquet(valid_data: list[dict[str, float]]) -> None:
@pytest.mark.parametrize("use_native_downloader", [True, False])
def test_create_dataframe_parquet(valid_data: list[dict[str, float]], use_native_downloader) -> None:
    with tempfile.NamedTemporaryFile("w") as f:
        table = pa.Table.from_pydict({col: [d[col] for d in valid_data] for col in COL_NAMES})
        papq.write_table(table, f.name)
        f.flush()

        df = daft.read_parquet(f.name)
        df = daft.read_parquet(f.name, use_native_downloader=use_native_downloader)
        assert df.column_names == COL_NAMES

        pd_df = df.to_pandas()
        assert list(pd_df.columns) == COL_NAMES
        assert len(pd_df) == len(valid_data)


def test_create_dataframe_multiple_parquets(valid_data: list[dict[str, float]]) -> None:
@pytest.mark.parametrize("use_native_downloader", [True, False])
def test_create_dataframe_multiple_parquets(valid_data: list[dict[str, float]], use_native_downloader) -> None:
    with tempfile.NamedTemporaryFile("w") as f1, tempfile.NamedTemporaryFile("w") as f2:
        for f in (f1, f2):
            table = pa.Table.from_pydict({col: [d[col] for d in valid_data] for col in COL_NAMES})
            papq.write_table(table, f.name)
            f.flush()

        df = daft.read_parquet([f1.name, f2.name])
        df = daft.read_parquet([f1.name, f2.name], use_native_downloader=use_native_downloader)
        assert df.column_names == COL_NAMES

        pd_df = df.to_pandas()
@@ -697,15 +699,16 @@ def test_create_dataframe_parquet_custom_fs(valid_data: list[dict[str, float]])
        assert len(pd_df) == len(valid_data)


def test_create_dataframe_parquet_column_projection(valid_data: list[dict[str, float]]) -> None:
@pytest.mark.parametrize("use_native_downloader", [True, False])
def test_create_dataframe_parquet_column_projection(valid_data: list[dict[str, float]], use_native_downloader) -> None:
    with tempfile.NamedTemporaryFile("w") as f:
        table = pa.Table.from_pydict({col: [d[col] for d in valid_data] for col in COL_NAMES})
        papq.write_table(table, f.name)
        f.flush()

        col_subset = COL_NAMES[:3]

        df = daft.read_parquet(f.name)
        df = daft.read_parquet(f.name, use_native_downloader=use_native_downloader)
        df = df.select(*col_subset)
        assert df.column_names == col_subset

@@ -714,7 +717,8 @@ def test_create_dataframe_parquet_column_projection(valid_data: list[dict[str, float]]
        assert len(pd_df) == len(valid_data)


def test_create_dataframe_parquet_specify_schema(valid_data: list[dict[str, float]]) -> None:
@pytest.mark.parametrize("use_native_downloader", [True, False])
def test_create_dataframe_parquet_specify_schema(valid_data: list[dict[str, float]], use_native_downloader) -> None:
    with tempfile.NamedTemporaryFile("w") as f:
        table = pa.Table.from_pydict({col: [d[col] for d in valid_data] for col in COL_NAMES})
        papq.write_table(table, f.name)
@@ -729,6 +733,7 @@ def test_create_dataframe_parquet_specify_schema(valid_data: list[dict[str, float]]
"petal_width": DataType.float32(),
"variety": DataType.string(),
},
use_native_downloader=use_native_downloader,
)
assert df.column_names == COL_NAMES

14 changes: 13 additions & 1 deletion tests/integration/io/parquet/test_remote_reads.py
@@ -5,6 +5,7 @@
import pytest
from pyarrow import parquet as pq

import daft
from daft.filesystem import get_filesystem_from_path, get_protocol_from_path
from daft.io import IOConfig, S3Config
from daft.table import Table
@@ -177,9 +178,20 @@ def read_parquet_with_pyarrow(path) -> pa.Table:


@pytest.mark.integration()
def test_parquet_read(parquet_file):
def test_parquet_read_table(parquet_file):
    _, url = parquet_file
    daft_native_read = Table.read_parquet(url, io_config=IOConfig(s3=S3Config(anonymous=True)))
    pa_read = Table.from_arrow(read_parquet_with_pyarrow(url))
    assert daft_native_read.schema() == pa_read.schema()
    pd.testing.assert_frame_equal(daft_native_read.to_pandas(), pa_read.to_pandas())


@pytest.mark.integration()
def test_parquet_read_df(parquet_file):
    _, url = parquet_file
    daft_native_read = daft.read_parquet(
        url, io_config=IOConfig(s3=S3Config(anonymous=True)), use_native_downloader=True
    )
    pa_read = Table.from_arrow(read_parquet_with_pyarrow(url))
    assert daft_native_read.schema() == pa_read.schema()
    pd.testing.assert_frame_equal(daft_native_read.to_pandas(), pa_read.to_pandas())
