From aa1ee280c12a11aabde5eda4ce51af336918c27f Mon Sep 17 00:00:00 2001 From: Sammy Sidhu Date: Fri, 8 Sep 2023 21:35:25 -0700 Subject: [PATCH] [BUG] Respect `multithreaded_io` flag when reading parquet (#1359) --- src/daft-parquet/src/python.rs | 2 ++ src/daft-parquet/src/read.rs | 19 ++++++++++++++++--- 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/src/daft-parquet/src/python.rs b/src/daft-parquet/src/python.rs index b62f8b414f..39194f2363 100644 --- a/src/daft-parquet/src/python.rs +++ b/src/daft-parquet/src/python.rs @@ -38,6 +38,7 @@ pub mod pylib { num_rows, row_groups.as_deref(), io_client, + multithreaded_io.unwrap_or(true), &schema_infer_options, )? .into()) @@ -73,6 +74,7 @@ pub mod pylib { num_rows, row_groups, io_client, + multithreaded_io.unwrap_or(true), &schema_infer_options, )? .into_iter() diff --git a/src/daft-parquet/src/read.rs b/src/daft-parquet/src/read.rs index 997b5bef5d..8ad3b3f969 100644 --- a/src/daft-parquet/src/read.rs +++ b/src/daft-parquet/src/read.rs @@ -138,6 +138,7 @@ async fn read_parquet_single( Ok(table) } +#[allow(clippy::too_many_arguments)] pub fn read_parquet( uri: &str, columns: Option<&[&str]>, @@ -145,9 +146,10 @@ pub fn read_parquet( num_rows: Option<usize>, row_groups: Option<&[i64]>, io_client: Arc<IOClient>, + multithreaded_io: bool, schema_infer_options: &ParquetSchemaInferenceOptions, ) -> DaftResult<Table> { - let runtime_handle = get_runtime(true)?; + let runtime_handle = get_runtime(multithreaded_io)?; let _rt_guard = runtime_handle.enter(); runtime_handle.block_on(async { read_parquet_single( @@ -163,6 +165,7 @@ pub fn read_parquet( }) } +#[allow(clippy::too_many_arguments)] pub fn read_parquet_bulk( uris: &[&str], columns: Option<&[&str]>, @@ -170,9 +173,10 @@ pub fn read_parquet_bulk( num_rows: Option<usize>, row_groups: Option<Vec<Vec<i64>>>, io_client: Arc<IOClient>, + multithreaded_io: bool, schema_infer_options: &ParquetSchemaInferenceOptions, ) -> DaftResult<Vec<Table>> { - let runtime_handle = get_runtime(true)?; + let runtime_handle = 
get_runtime(multithreaded_io)?; let _rt_guard = runtime_handle.enter(); let owned_columns = columns.map(|s| s.iter().map(|v| String::from(*v)).collect::<Vec<_>>()); if let Some(ref row_groups) = row_groups { @@ -321,7 +325,16 @@ mod tests { let io_client = Arc::new(IOClient::new(io_config.into())?); - let table = read_parquet(file, None, None, None, None, io_client, &Default::default())?; + let table = read_parquet( + file, + None, + None, + None, + None, + io_client, + true, + &Default::default(), + )?; assert_eq!(table.len(), 100); Ok(())