diff --git a/benchmarking/parquet/conftest.py b/benchmarking/parquet/conftest.py index f8062bf5d5..8172a524e2 100644 --- a/benchmarking/parquet/conftest.py +++ b/benchmarking/parquet/conftest.py @@ -60,3 +60,40 @@ def daft_native_read(path: str, columns: list[str] | None = None) -> pa.Table: def read_fn(request): """Fixture which returns the function to read a PyArrow table from a path""" return request.param + + +def bulk_read_adapter(func): + def fn(files: list[str]) -> list[pa.Table]: + return [func(f) for f in files] + + return fn + + +def daft_bulk_read(paths: list[str], columns: list[str] | None = None) -> list[pa.Table]: + tables = daft.table.Table.read_parquet_bulk(paths, columns=columns) + return [t.to_arrow() for t in tables] + + +def pyarrow_bulk_read(paths: list[str], columns: list[str] | None = None) -> list[pa.Table]: + return [pyarrow_read(f, columns=columns) for f in paths] + + +def boto_bulk_read(paths: list[str], columns: list[str] | None = None) -> list[pa.Table]: + return [boto3_get_object_read(f, columns=columns) for f in paths] + + +@pytest.fixture( + params=[ + daft_bulk_read, + pyarrow_bulk_read, + boto_bulk_read, + ], + ids=[ + "daft_bulk_read", + "pyarrow_bulk_read", + "boto3_bulk_read", + ], +) +def bulk_read_fn(request): + """Fixture which returns the function to read a PyArrow table from a path""" + return request.param diff --git a/benchmarking/parquet/test_bulk_reads.py b/benchmarking/parquet/test_bulk_reads.py new file mode 100644 index 0000000000..06712296e7 --- /dev/null +++ b/benchmarking/parquet/test_bulk_reads.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +import pytest + +PATH = ( + "s3://eventual-dev-benchmarking-fixtures/parquet-benchmarking/tpch/200MB-2RG/daft_200MB_lineitem_chunk.RG-2.parquet" +) + + +@pytest.mark.benchmark(group="num_files_single_column") +@pytest.mark.parametrize( + "num_files", + [1, 2, 4, 8], +) +def test_read_parquet_num_files_single_column(num_files, bulk_read_fn, benchmark): + data = benchmark(bulk_read_fn, [PATH] * num_files, columns=["L_ORDERKEY"]) + assert len(data) == num_files + # Make sure the data is correct + for i in range(num_files): + assert data[i].column_names == ["L_ORDERKEY"] + assert len(data[i]) == 5515199 + + +@pytest.mark.benchmark(group="num_rowgroups_all_columns") +@pytest.mark.parametrize( + "num_files", + [1, 2, 4], +) +def test_read_parquet_num_files_all_columns(num_files, bulk_read_fn, benchmark): + data = benchmark(bulk_read_fn, [PATH] * num_files) + assert len(data) == num_files + + # Make sure the data is correct + for i in range(num_files): + assert len(data[i].column_names) == 16 + assert len(data[i]) == 5515199 diff --git a/daft/table/table.py b/daft/table/table.py index 9022361440..63dd4171ea 100644 --- a/daft/table/table.py +++ b/daft/table/table.py @@ -7,7 +7,8 @@ from daft.arrow_utils import ensure_table from daft.daft import PyTable as _PyTable -from daft.daft import _read_parquet +from daft.daft import read_parquet as _read_parquet +from daft.daft import read_parquet_bulk as _read_parquet_bulk from daft.daft import read_parquet_statistics as _read_parquet_statistics from daft.datatype import DataType from daft.expressions import Expression, ExpressionsProjection @@ -357,6 +358,20 @@ def read_parquet( _read_parquet(uri=path, columns=columns, start_offset=start_offset, num_rows=num_rows, io_config=io_config) ) + @classmethod + def read_parquet_bulk( + cls, + paths: list[str], + columns: list[str] | None = None, + start_offset: int | None = None, + num_rows: int | None = None, + io_config: IOConfig | None = None, + ) -> list[Table]: + pytables = _read_parquet_bulk( + uris=paths, columns=columns, start_offset=start_offset, num_rows=num_rows, io_config=io_config + ) + return [Table._from_pytable(t) for t in pytables] + @classmethod def read_parquet_statistics( cls, diff --git a/src/daft-parquet/src/python.rs b/src/daft-parquet/src/python.rs index 013e28c9dd..363ed2b4b9 100644 --- a/src/daft-parquet/src/python.rs +++ b/src/daft-parquet/src/python.rs @@ -9,7 +9,7 @@ pub mod pylib { use pyo3::{pyfunction, PyResult, Python}; #[pyfunction] - pub fn _read_parquet( + pub fn read_parquet( py: Python, uri: &str, columns: Option>, @@ -30,6 +30,30 @@ pub mod pylib { }) } + #[pyfunction] + pub fn read_parquet_bulk( + py: Python, + uris: Vec<&str>, + columns: Option>, + start_offset: Option, + num_rows: Option, + io_config: Option, + ) -> PyResult> { + py.allow_threads(|| { + let io_client = get_io_client(io_config.unwrap_or_default().config.into())?; + Ok(crate::read::read_parquet_bulk( + uris.as_ref(), + columns.as_deref(), + start_offset, + num_rows, + io_client, + )? + .into_iter() + .map(|v| v.into()) + .collect()) + }) + } + #[pyfunction] pub fn read_parquet_schema( py: Python, @@ -55,7 +79,8 @@ pub mod pylib { } } pub fn register_modules(_py: Python, parent: &PyModule) -> PyResult<()> { - parent.add_wrapped(wrap_pyfunction!(pylib::_read_parquet))?; + parent.add_wrapped(wrap_pyfunction!(pylib::read_parquet))?; + parent.add_wrapped(wrap_pyfunction!(pylib::read_parquet_bulk))?; parent.add_wrapped(wrap_pyfunction!(pylib::read_parquet_schema))?; parent.add_wrapped(wrap_pyfunction!(pylib::read_parquet_statistics))?; Ok(()) diff --git a/src/daft-parquet/src/read.rs b/src/daft-parquet/src/read.rs index 6e44592ec2..8487340200 100644 --- a/src/daft-parquet/src/read.rs +++ b/src/daft-parquet/src/read.rs @@ -9,22 +9,19 @@ use daft_core::{ }; use daft_io::{get_runtime, IOClient}; use daft_table::Table; -use futures::future::join_all; +use futures::future::{join_all, try_join_all}; use snafu::ResultExt; use crate::{file::ParquetReaderBuilder, JoinSnafu}; -pub fn read_parquet( +async fn read_parquet_single( uri: &str, columns: Option<&[&str]>, start_offset: Option, num_rows: Option, io_client: Arc, ) -> DaftResult { - let runtime_handle = get_runtime(true)?; - let _rt_guard = runtime_handle.enter(); - let builder = runtime_handle - .block_on(async { ParquetReaderBuilder::from_uri(uri, io_client.clone()).await })?; + let builder = ParquetReaderBuilder::from_uri(uri, io_client.clone()).await?; let builder = if let Some(columns) = columns { builder.prune_columns(columns)? @@ -38,7 +35,7 @@ pub fn read_parquet( let parquet_reader = builder.build()?; let ranges = parquet_reader.prebuffer_ranges(io_client)?; - let table = runtime_handle.block_on(async { parquet_reader.read_from_ranges(ranges).await })?; + let table = parquet_reader.read_from_ranges(ranges).await?; match (start_offset, num_rows) { (None, None) if metadata_num_rows != table.len() => { @@ -81,6 +78,51 @@ pub fn read_parquet( Ok(table) } +pub fn read_parquet( + uri: &str, + columns: Option<&[&str]>, + start_offset: Option, + num_rows: Option, + io_client: Arc, +) -> DaftResult
{ + let runtime_handle = get_runtime(true)?; + let _rt_guard = runtime_handle.enter(); + runtime_handle.block_on(async { + read_parquet_single(uri, columns, start_offset, num_rows, io_client).await + }) +} + +pub fn read_parquet_bulk( + uris: &[&str], + columns: Option<&[&str]>, + start_offset: Option, + num_rows: Option, + io_client: Arc, +) -> DaftResult> { + let runtime_handle = get_runtime(true)?; + let _rt_guard = runtime_handle.enter(); + let owned_columns = columns.map(|s| s.iter().map(|v| String::from(*v)).collect::>()); + + let tables = runtime_handle + .block_on(async move { + try_join_all(uris.iter().map(|uri| { + let uri = uri.to_string(); + let owned_columns = owned_columns.clone(); + let io_client = io_client.clone(); + tokio::task::spawn(async move { + let columns = owned_columns + .as_ref() + .map(|s| s.iter().map(AsRef::as_ref).collect::>()); + read_parquet_single(&uri, columns.as_deref(), start_offset, num_rows, io_client) + .await + }) + })) + .await + }) + .context(JoinSnafu { path: "UNKNOWN" })?; + tables.into_iter().collect::>>() +} + pub fn read_parquet_schema(uri: &str, io_client: Arc) -> DaftResult { let runtime_handle = get_runtime(true)?; let _rt_guard = runtime_handle.enter(); diff --git a/tests/integration/io/parquet/test_reads_public_data.py b/tests/integration/io/parquet/test_reads_public_data.py index 95463abe75..f2743f349b 100644 --- a/tests/integration/io/parquet/test_reads_public_data.py +++ b/tests/integration/io/parquet/test_reads_public_data.py @@ -186,6 +186,17 @@ def test_parquet_read_table(parquet_file): pd.testing.assert_frame_equal(daft_native_read.to_pandas(), pa_read.to_pandas()) +@pytest.mark.integration() +def test_parquet_read_table_bulk(parquet_file): + _, url = parquet_file + daft_native_reads = Table.read_parquet_bulk([url] * 2, io_config=IOConfig(s3=S3Config(anonymous=True))) + pa_read = Table.from_arrow(read_parquet_with_pyarrow(url)) + + for daft_native_read in daft_native_reads: + assert daft_native_read.schema() == pa_read.schema() + pd.testing.assert_frame_equal(daft_native_read.to_pandas(), pa_read.to_pandas()) + + @pytest.mark.integration() def test_parquet_read_df(parquet_file): _, url = parquet_file