diff --git a/Cargo.lock b/Cargo.lock
index 6c3af7a807..1eb62600c8 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -936,6 +936,7 @@ dependencies = [
  "pyo3-log",
  "reqwest",
  "serde",
+ "serde_json",
  "snafu",
  "tempfile",
  "tokio",
diff --git a/daft/datasources.py b/daft/datasources.py
index f14b7877c7..2d396f96e2 100644
--- a/daft/datasources.py
+++ b/daft/datasources.py
@@ -3,6 +3,10 @@
 import sys
 from dataclasses import dataclass
 from enum import Enum
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from daft.io import IOConfig

 if sys.version_info < (3, 8):
     from typing_extensions import Protocol
@@ -41,5 +45,9 @@ def scan_type(self):

 @dataclass(frozen=True)
 class ParquetSourceInfo(SourceInfo):
+
+    use_native_downloader: bool
+    io_config: IOConfig | None
+
     def scan_type(self):
         return StorageType.PARQUET
diff --git a/daft/execution/execution_step.py b/daft/execution/execution_step.py
index 247f56fb50..e7a2c3f053 100644
--- a/daft/execution/execution_step.py
+++ b/daft/execution/execution_step.py
@@ -400,6 +400,8 @@ def _handle_tabular_files_scan(
                         schema=schema,
                         fs=fs,
                         read_options=read_options,
+                        io_config=scan._source_info.io_config,
+                        use_native_downloader=scan._source_info.use_native_downloader,
                     )
                     for fp in filepaths
                 ]
diff --git a/daft/filesystem.py b/daft/filesystem.py
index c8b7380ccc..da4170e490 100644
--- a/daft/filesystem.py
+++ b/daft/filesystem.py
@@ -27,6 +27,7 @@
 )

 from daft.datasources import ParquetSourceInfo, SourceInfo
+from daft.table import Table

 _CACHED_FSES: dict[str, FileSystem] = {}

@@ -328,9 +329,16 @@ def glob_path_with_stats(

     # Set number of rows if available.
     if isinstance(source_info, ParquetSourceInfo):
-        parquet_metadatas = ThreadPoolExecutor().map(_get_parquet_metadata_single, filepaths_to_infos.keys())
-        for path, parquet_metadata in zip(filepaths_to_infos.keys(), parquet_metadatas):
-            filepaths_to_infos[path]["rows"] = parquet_metadata.num_rows
+        if source_info.use_native_downloader:
+            parquet_statistics = Table.read_parquet_statistics(
+                list(filepaths_to_infos.keys()), source_info.io_config
+            ).to_pydict()
+            for path, num_rows in zip(parquet_statistics["uris"], parquet_statistics["row_count"]):
+                filepaths_to_infos[path]["rows"] = num_rows
+        else:
+            parquet_metadatas = ThreadPoolExecutor().map(_get_parquet_metadata_single, filepaths_to_infos.keys())
+            for path, parquet_metadata in zip(filepaths_to_infos.keys(), parquet_metadatas):
+                filepaths_to_infos[path]["rows"] = parquet_metadata.num_rows

     return [
         ListingInfo(path=_ensure_path_protocol(protocol, path), **infos) for path, infos in filepaths_to_infos.items()
diff --git a/daft/io/__init__.py b/daft/io/__init__.py
index d3558b155c..22ac81e161 100644
--- a/daft/io/__init__.py
+++ b/daft/io/__init__.py
@@ -1,9 +1,8 @@
 from __future__ import annotations

-from daft.daft import PyIOConfig as IOConfig
-from daft.daft import PyS3Config as S3Config
 from daft.io._csv import read_csv
 from daft.io._json import read_json
+from daft.io.config import IOConfig, S3Config
 from daft.io.file_path import from_glob_path
 from daft.io.parquet import read_parquet

diff --git a/daft/io/config.py b/daft/io/config.py
new file mode 100644
index 0000000000..82b2f3ef2a
--- /dev/null
+++ b/daft/io/config.py
@@ -0,0 +1,13 @@
+from __future__ import annotations
+
+import json
+
+from daft.daft import PyIOConfig as IOConfig
+from daft.daft import PyS3Config as S3Config
+
+
+def _io_config_from_json(io_config_json: str) -> IOConfig:
+    """Used when deserializing a serialized IOConfig object"""
+    data = json.loads(io_config_json)
+    s3_config = S3Config(**data["s3"]) if "s3" in data else None
+    return IOConfig(s3=s3_config)
diff --git a/daft/io/parquet.py b/daft/io/parquet.py
index 3cd308219e..06e9b51c83 100644
--- a/daft/io/parquet.py
+++ b/daft/io/parquet.py
@@ -1,6 +1,6 @@
 # isort: dont-add-import: from __future__ import annotations

-from typing import Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Union

 import fsspec

@@ -10,12 +10,17 @@
 from daft.datatype import DataType
 from daft.io.common import _get_tabular_files_scan

+if TYPE_CHECKING:
+    from daft.io import IOConfig
+

 @PublicAPI
 def read_parquet(
     path: Union[str, List[str]],
     schema_hints: Optional[Dict[str, DataType]] = None,
     fs: Optional[fsspec.AbstractFileSystem] = None,
+    io_config: Optional["IOConfig"] = None,
+    use_native_downloader: bool = False,
 ) -> DataFrame:
     """Creates a DataFrame from Parquet file(s)

@@ -31,6 +36,9 @@
             disable all schema inference on data being read, and throw an error if data being read is incompatible.
         fs (fsspec.AbstractFileSystem): fsspec FileSystem to use for reading data.
             By default, Daft will automatically construct a FileSystem instance internally.
+        io_config (IOConfig): Config to be used with the native downloader
+        use_native_downloader: Whether to use the native downloader instead of PyArrow for reading Parquet. This
+            is currently experimental.

     returns:
         DataFrame: parsed DataFrame
@@ -41,7 +49,10 @@
     plan = _get_tabular_files_scan(
         path,
         schema_hints,
-        ParquetSourceInfo(),
+        ParquetSourceInfo(
+            io_config=io_config,
+            use_native_downloader=use_native_downloader,
+        ),
         fs,
     )
     return DataFrame(plan)
diff --git a/daft/runners/runner_io.py b/daft/runners/runner_io.py
index 5f10071271..869fbf1934 100644
--- a/daft/runners/runner_io.py
+++ b/daft/runners/runner_io.py
@@ -100,6 +100,8 @@ def sample_schema(
         return schema_inference.from_parquet(
             file=filepath,
             fs=fs,
+            io_config=source_info.io_config,
+            use_native_downloader=source_info.use_native_downloader,
         )
     else:
         raise NotImplementedError(f"Schema inference for {source_info} not implemented")
diff --git a/daft/table/schema_inference.py b/daft/table/schema_inference.py
index f8251ee71f..13bbea3ce1 100644
--- a/daft/table/schema_inference.py
+++ b/daft/table/schema_inference.py
@@ -17,6 +17,8 @@
 if TYPE_CHECKING:
     import fsspec

+    from daft.io import IOConfig
+

 def from_csv(
     file: FileInput,
@@ -73,8 +75,14 @@ def from_json(
 def from_parquet(
     file: FileInput,
     fs: fsspec.AbstractFileSystem | None = None,
+    io_config: IOConfig | None = None,
+    use_native_downloader: bool = False,
 ) -> Schema:
     """Infers a Schema from a Parquet file"""
+    if use_native_downloader:
+        assert isinstance(file, (str, pathlib.Path))
+        return Schema.from_parquet(str(file), io_config=io_config)
+
     if not isinstance(file, (str, pathlib.Path)):
         # BytesIO path.
         f = file
diff --git a/daft/table/table.py b/daft/table/table.py
index 4ceb225bbe..7e82b11d72 100644
--- a/daft/table/table.py
+++ b/daft/table/table.py
@@ -364,6 +364,6 @@ def read_parquet_statistics(
         io_config: IOConfig | None = None,
     ) -> Table:
         if not isinstance(paths, Series):
-            paths = Series.from_pylist(paths)
-
+            paths = Series.from_pylist(paths, name="uris")
+        assert paths.name() == "uris", f"Expected input series to have name 'uris', but found: {paths.name()}"
         return Table._from_pytable(_read_parquet_statistics(uris=paths._series, io_config=io_config))
diff --git a/daft/table/table_io.py b/daft/table/table_io.py
index 31f44be0f8..2ad10ac9a8 100644
--- a/daft/table/table_io.py
+++ b/daft/table/table_io.py
@@ -3,7 +3,7 @@
 import contextlib
 import pathlib
 from collections.abc import Generator
-from typing import IO, Union
+from typing import IO, TYPE_CHECKING, Union
 from uuid import uuid4

 import fsspec
@@ -20,6 +20,9 @@
 from daft.runners.partitioning import TableParseCSVOptions, TableReadOptions
 from daft.table import Table

+if TYPE_CHECKING:
+    from daft.io import IOConfig
+
 FileInput = Union[pathlib.Path, str, IO[bytes]]


@@ -94,6 +97,8 @@ def read_parquet(
     schema: Schema,
     fs: fsspec.AbstractFileSystem | None = None,
     read_options: TableReadOptions = TableReadOptions(),
+    io_config: IOConfig | None = None,
+    use_native_downloader: bool = False,
 ) -> Table:
     """Reads a Table from a Parquet file

@@ -106,6 +111,16 @@
     Returns:
         Table: Parsed Table from Parquet
     """
+    if use_native_downloader:
+        assert isinstance(file, (str, pathlib.Path)), "Native downloader only works on string inputs to read_parquet"
+        tbl = Table.read_parquet(
+            str(file),
+            columns=read_options.column_names,
+            num_rows=read_options.num_rows,
+            io_config=io_config,
+        )
+        return _cast_table_to_schema(tbl, read_options=read_options, schema=schema)
+
     f: IO
     if not isinstance(file, (str, pathlib.Path)):
         f = file
diff --git a/src/daft-io/Cargo.toml b/src/daft-io/Cargo.toml
index fc393cdde7..b9321c0d4a 100644
--- a/src/daft-io/Cargo.toml
+++ b/src/daft-io/Cargo.toml
@@ -15,6 +15,7 @@ log = {workspace = true}
 pyo3 = {workspace = true, optional = true}
 pyo3-log = {workspace = true, optional = true}
 serde = {workspace = true}
+serde_json = {workspace = true}
 snafu = {workspace = true}
 tokio = {workspace = true}
 url = "2.4.0"
diff --git a/src/daft-io/src/python.rs b/src/daft-io/src/python.rs
index 105a30ba3c..07383dd6b1 100644
--- a/src/daft-io/src/python.rs
+++ b/src/daft-io/src/python.rs
@@ -1,4 +1,5 @@
 use crate::config::{IOConfig, S3Config};
+use common_error::DaftError;
 use pyo3::prelude::*;

 #[derive(Clone, Default)]
@@ -34,6 +35,17 @@ impl PyIOConfig {
             config: self.config.s3.clone(),
         })
     }
+
+    pub fn __reduce__(&self, py: Python) -> PyResult<(PyObject, (String,))> {
+        let io_config_module = py.import("daft.io.config")?;
+        let json_string = serde_json::to_string(&self.config).map_err(DaftError::from)?;
+        Ok((
+            io_config_module
+                .getattr("_io_config_from_json")?
+                .to_object(py),
+            (json_string,),
+        ))
+    }
 }

 #[pymethods]
diff --git a/tests/dataframe/test_creation.py b/tests/dataframe/test_creation.py
index 1fcd915b55..6144b6170e 100644
--- a/tests/dataframe/test_creation.py
+++ b/tests/dataframe/test_creation.py
@@ -638,13 +638,14 @@ def test_create_dataframe_json_specify_schema(valid_data: list[dict[str, float]]
 ###


-def test_create_dataframe_parquet(valid_data: list[dict[str, float]]) -> None:
+@pytest.mark.parametrize("use_native_downloader", [True, False])
+def test_create_dataframe_parquet(valid_data: list[dict[str, float]], use_native_downloader) -> None:
     with tempfile.NamedTemporaryFile("w") as f:
         table = pa.Table.from_pydict({col: [d[col] for d in valid_data] for col in COL_NAMES})
         papq.write_table(table, f.name)
         f.flush()

-        df = daft.read_parquet(f.name)
+        df = daft.read_parquet(f.name, use_native_downloader=use_native_downloader)
         assert df.column_names == COL_NAMES

         pd_df = df.to_pandas()
@@ -652,14 +653,15 @@ def test_create_dataframe_parquet(valid_data: list[dict[str, float]]) -> None:
         assert len(pd_df) == len(valid_data)


-def test_create_dataframe_multiple_parquets(valid_data: list[dict[str, float]]) -> None:
+@pytest.mark.parametrize("use_native_downloader", [True, False])
+def test_create_dataframe_multiple_parquets(valid_data: list[dict[str, float]], use_native_downloader) -> None:
     with tempfile.NamedTemporaryFile("w") as f1, tempfile.NamedTemporaryFile("w") as f2:
         for f in (f1, f2):
             table = pa.Table.from_pydict({col: [d[col] for d in valid_data] for col in COL_NAMES})
             papq.write_table(table, f.name)
             f.flush()

-        df = daft.read_parquet([f1.name, f2.name])
+        df = daft.read_parquet([f1.name, f2.name], use_native_downloader=use_native_downloader)
         assert df.column_names == COL_NAMES

         pd_df = df.to_pandas()
@@ -697,7 +699,8 @@ def test_create_dataframe_parquet_custom_fs(valid_data: list[dict[str, float]])
         assert len(pd_df) == len(valid_data)


-def test_create_dataframe_parquet_column_projection(valid_data: list[dict[str, float]]) -> None:
+@pytest.mark.parametrize("use_native_downloader", [True, False])
+def test_create_dataframe_parquet_column_projection(valid_data: list[dict[str, float]], use_native_downloader) -> None:
     with tempfile.NamedTemporaryFile("w") as f:
         table = pa.Table.from_pydict({col: [d[col] for d in valid_data] for col in COL_NAMES})
         papq.write_table(table, f.name)
@@ -705,7 +708,7 @@ def test_create_dataframe_parquet_column_projection(valid_data: list[dict[str, f

         col_subset = COL_NAMES[:3]

-        df = daft.read_parquet(f.name)
+        df = daft.read_parquet(f.name, use_native_downloader=use_native_downloader)
         df = df.select(*col_subset)
         assert df.column_names == col_subset

@@ -714,7 +717,8 @@ def test_create_dataframe_parquet_column_projection(valid_data: list[dict[str, f
         assert len(pd_df) == len(valid_data)


-def test_create_dataframe_parquet_specify_schema(valid_data: list[dict[str, float]]) -> None:
+@pytest.mark.parametrize("use_native_downloader", [True, False])
+def test_create_dataframe_parquet_specify_schema(valid_data: list[dict[str, float]], use_native_downloader) -> None:
     with tempfile.NamedTemporaryFile("w") as f:
         table = pa.Table.from_pydict({col: [d[col] for d in valid_data] for col in COL_NAMES})
         papq.write_table(table, f.name)
@@ -729,6 +733,7 @@ def test_create_dataframe_parquet_specify_schema(valid_data: list[dict[str, floa
                 "petal_width": DataType.float32(),
                 "variety": DataType.string(),
             },
+            use_native_downloader=use_native_downloader,
         )
         assert df.column_names == COL_NAMES

diff --git a/tests/integration/io/parquet/test_remote_reads.py b/tests/integration/io/parquet/test_remote_reads.py
index d933992f4b..95463abe75 100644
--- a/tests/integration/io/parquet/test_remote_reads.py
+++ b/tests/integration/io/parquet/test_remote_reads.py
@@ -5,6 +5,7 @@
 import pytest
 from pyarrow import parquet as pq

+import daft
 from daft.filesystem import get_filesystem_from_path, get_protocol_from_path
 from daft.io import IOConfig, S3Config
 from daft.table import Table
@@ -177,9 +178,20 @@ def read_parquet_with_pyarrow(path) -> pa.Table:


 @pytest.mark.integration()
-def test_parquet_read(parquet_file):
+def test_parquet_read_table(parquet_file):
     _, url = parquet_file
     daft_native_read = Table.read_parquet(url, io_config=IOConfig(s3=S3Config(anonymous=True)))
     pa_read = Table.from_arrow(read_parquet_with_pyarrow(url))
     assert daft_native_read.schema() == pa_read.schema()
     pd.testing.assert_frame_equal(daft_native_read.to_pandas(), pa_read.to_pandas())
+
+
+@pytest.mark.integration()
+def test_parquet_read_df(parquet_file):
+    _, url = parquet_file
+    daft_native_read = daft.read_parquet(
+        url, io_config=IOConfig(s3=S3Config(anonymous=True)), use_native_downloader=True
+    )
+    pa_read = Table.from_arrow(read_parquet_with_pyarrow(url))
+    assert daft_native_read.schema() == pa_read.schema()
+    pd.testing.assert_frame_equal(daft_native_read.to_pandas(), pa_read.to_pandas())
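Usage sketch (illustrative, not part of the diff above): a minimal example of how the new read_parquet options introduced in this change might be exercised, assuming anonymous access to a public S3 bucket; the bucket path below is a placeholder and not taken from the patch.

import pickle

import daft
from daft.io import IOConfig, S3Config

# This change adds pickle support for IOConfig (PyIOConfig.__reduce__ plus
# daft.io.config._io_config_from_json), so a config should survive a round-trip.
io_config = IOConfig(s3=S3Config(anonymous=True))
io_config = pickle.loads(pickle.dumps(io_config))

# Opt in to the experimental native Parquet downloader (defaults to False).
df = daft.read_parquet(
    "s3://example-bucket/data.parquet",  # placeholder path
    io_config=io_config,
    use_native_downloader=True,
)
print(df.to_pandas())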