diff --git a/daft/datasources.py b/daft/datasources.py index f14b7877c7..cd79972dbe 100644 --- a/daft/datasources.py +++ b/daft/datasources.py @@ -4,6 +4,8 @@ from dataclasses import dataclass from enum import Enum +from daft.io import IOConfig + if sys.version_info < (3, 8): from typing_extensions import Protocol else: @@ -41,5 +43,9 @@ def scan_type(self): @dataclass(frozen=True) class ParquetSourceInfo(SourceInfo): + + use_native_downloader: bool + io_config: IOConfig | None + def scan_type(self): return StorageType.PARQUET diff --git a/daft/io/parquet.py b/daft/io/parquet.py index 3cd308219e..e9137bedf5 100644 --- a/daft/io/parquet.py +++ b/daft/io/parquet.py @@ -8,6 +8,7 @@ from daft.dataframe import DataFrame from daft.datasources import ParquetSourceInfo from daft.datatype import DataType +from daft.io import IOConfig from daft.io.common import _get_tabular_files_scan @@ -16,6 +17,8 @@ def read_parquet( path: Union[str, List[str]], schema_hints: Optional[Dict[str, DataType]] = None, fs: Optional[fsspec.AbstractFileSystem] = None, + io_config: Optional[IOConfig] = None, + use_native_downloader: bool = False, ) -> DataFrame: """Creates a DataFrame from Parquet file(s) @@ -31,6 +34,9 @@ def read_parquet( disable all schema inference on data being read, and throw an error if data being read is incompatible. fs (fsspec.AbstractFileSystem): fsspec FileSystem to use for reading data. By default, Daft will automatically construct a FileSystem instance internally. + io_config (IOConfig): Config to be used with the native downloader + use_native_downloader: Whether to use the native downloader instead of PyArrow for reading Parquet. This + is currently experimental. returns: DataFrame: parsed DataFrame @@ -41,7 +47,10 @@ def read_parquet( plan = _get_tabular_files_scan( path, schema_hints, - ParquetSourceInfo(), + ParquetSourceInfo( + io_config=io_config, + use_native_downloader=use_native_downloader, + ), fs, ) return DataFrame(plan) diff --git a/daft/runners/runner_io.py b/daft/runners/runner_io.py index 5f10071271..869fbf1934 100644 --- a/daft/runners/runner_io.py +++ b/daft/runners/runner_io.py @@ -100,6 +100,8 @@ def sample_schema( return schema_inference.from_parquet( file=filepath, fs=fs, + io_config=source_info.io_config, + use_native_downloader=source_info.use_native_downloader, ) else: raise NotImplementedError(f"Schema inference for {source_info} not implemented") diff --git a/daft/table/schema_inference.py b/daft/table/schema_inference.py index f8251ee71f..fb9de90a4b 100644 --- a/daft/table/schema_inference.py +++ b/daft/table/schema_inference.py @@ -9,6 +9,7 @@ from daft.datatype import DataType from daft.filesystem import _resolve_paths_and_filesystem +from daft.io import IOConfig from daft.logical.schema import Schema from daft.runners.partitioning import TableParseCSVOptions from daft.table import Table @@ -73,8 +74,18 @@ def from_json( def from_parquet( file: FileInput, fs: fsspec.AbstractFileSystem | None = None, + io_config: IOConfig | None = None, + use_native_downloader: bool = False, ) -> Schema: """Infers a Schema from a Parquet file""" + if use_native_downloader: + assert isinstance(file, (str, pathlib.Path)) + # TODO(sammy): [RUST-PARQUET] Implement getting a schema from a Parquet file + # return Schema.from_parquet(file, io_config=io_config) + raise NotImplementedError( + "Not implemented: use Rust native downloader to retrieve a Daft Schema from a Parquet file" + ) + if not isinstance(file, (str, pathlib.Path)): # BytesIO path. f = file diff --git a/daft/table/table_io.py b/daft/table/table_io.py index 31f44be0f8..2eae36872d 100644 --- a/daft/table/table_io.py +++ b/daft/table/table_io.py @@ -16,6 +16,7 @@ from daft.expressions import ExpressionsProjection from daft.filesystem import _resolve_paths_and_filesystem +from daft.io import IOConfig from daft.logical.schema import Schema from daft.runners.partitioning import TableParseCSVOptions, TableReadOptions from daft.table import Table @@ -94,6 +95,8 @@ def read_parquet( schema: Schema, fs: fsspec.AbstractFileSystem | None = None, read_options: TableReadOptions = TableReadOptions(), + io_config: IOConfig | None = None, + use_native_downloader: bool = False, ) -> Table: """Reads a Table from a Parquet file @@ -106,6 +109,16 @@ def read_parquet( Returns: Table: Parsed Table from Parquet """ + if use_native_downloader: + assert isinstance(file, (str, pathlib.Path)), "Native downloader only works on string inputs to read_parquet" + return Table.read_parquet( + str(file), + columns=read_options.column_names, + # TODO(sammy): [RUST-PARQUET] Add API to limit number of rows read here, instead of rowgroups + # num_rows=read_options.num_rows, + io_config=io_config, + ) + f: IO if not isinstance(file, (str, pathlib.Path)): f = file