diff --git a/daft/filesystem.py b/daft/filesystem.py
index e5a397e46d..da4170e490 100644
--- a/daft/filesystem.py
+++ b/daft/filesystem.py
@@ -27,6 +27,7 @@
 )
 
 from daft.datasources import ParquetSourceInfo, SourceInfo
+from daft.table import Table
 
 _CACHED_FSES: dict[str, FileSystem] = {}
 
@@ -329,10 +330,11 @@ def glob_path_with_stats(
     # Set number of rows if available.
     if isinstance(source_info, ParquetSourceInfo):
         if source_info.use_native_downloader:
-            # TODO(sammy): [RUST-PARQUET]
-            # file_metadata = get_parquet_metadata(list(filepaths_to_infos.keys()), io_config=source_info.io_config)
-            # ... (for now we only need `file_metadata[i].num_rows` to be valid)
-            raise NotImplementedError("[RUST-PARQUET] Implement batch read of metadata")
+            parquet_statistics = Table.read_parquet_statistics(
+                list(filepaths_to_infos.keys()), source_info.io_config
+            ).to_pydict()
+            for path, num_rows in zip(parquet_statistics["uris"], parquet_statistics["row_count"]):
+                filepaths_to_infos[path]["rows"] = num_rows
         else:
             parquet_metadatas = ThreadPoolExecutor().map(_get_parquet_metadata_single, filepaths_to_infos.keys())
             for path, parquet_metadata in zip(filepaths_to_infos.keys(), parquet_metadatas):
diff --git a/daft/table/schema_inference.py b/daft/table/schema_inference.py
index 196192371c..13bbea3ce1 100644
--- a/daft/table/schema_inference.py
+++ b/daft/table/schema_inference.py
@@ -81,11 +81,7 @@ def from_parquet(
     """Infers a Schema from a Parquet file"""
     if use_native_downloader:
         assert isinstance(file, (str, pathlib.Path))
-        # TODO(sammy): [RUST-PARQUET] Implement getting a schema from a Parquet file
-        # return get_parquet_metadata([file], io_config=io_config)[0].get_daft_schema()
-        raise NotImplementedError(
-            "Not implemented: use Rust native downloader to retrieve a Daft Schema from a Parquet file"
-        )
+        return Schema.from_parquet(str(file), io_config=io_config)
 
     if not isinstance(file, (str, pathlib.Path)):
         # BytesIO path.
diff --git a/daft/table/table.py b/daft/table/table.py
index 4ceb225bbe..7e82b11d72 100644
--- a/daft/table/table.py
+++ b/daft/table/table.py
@@ -364,6 +364,6 @@ def read_parquet_statistics(
         io_config: IOConfig | None = None,
     ) -> Table:
         if not isinstance(paths, Series):
-            paths = Series.from_pylist(paths)
-
+            paths = Series.from_pylist(paths, name="uris")
+        assert paths.name() == "uris", f"Expected input series to have name 'uris', but found: {paths.name()}"
         return Table._from_pytable(_read_parquet_statistics(uris=paths._series, io_config=io_config))
diff --git a/daft/table/table_io.py b/daft/table/table_io.py
index 896c09db44..67bca130ee 100644
--- a/daft/table/table_io.py
+++ b/daft/table/table_io.py
@@ -116,8 +116,7 @@ def read_parquet(
     return Table.read_parquet(
         str(file),
         columns=read_options.column_names,
-        # TODO(sammy): [RUST-PARQUET] Add API to limit number of rows read here, instead of rowgroups
-        # num_rows=read_options.num_rows,
+        num_rows=read_options.num_rows,
         io_config=io_config,
     )
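For reference, a minimal usage sketch of the two code paths this diff wires up: Table.read_parquet_statistics (which, per the filesystem.py hunk, returns a table with "uris" and "row_count" columns) and Schema.from_parquet. The file URIs and the None io_config below are hypothetical placeholders, and the Schema import path is assumed from schema_inference.py's usage; this is not part of the diff itself.

# Sketch only: the URIs here are hypothetical; pass a real IOConfig
# instead of None when reading from a credentialed object store.
from daft.logical.schema import Schema  # import path assumed
from daft.table import Table

uris = ["s3://bucket/a.parquet", "s3://bucket/b.parquet"]  # hypothetical

# Batch metadata read: one output row per input URI.
stats = Table.read_parquet_statistics(uris, io_config=None).to_pydict()
for uri, num_rows in zip(stats["uris"], stats["row_count"]):
    print(f"{uri}: {num_rows} rows")

# Schema inference via the native downloader path.
schema = Schema.from_parquet(uris[0], io_config=None)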