Skip to content

Commit

Permalink
[BUG] Respect multithreaded_io flag when reading parquet (#1359)
Browse files Browse the repository at this point in the history
  • Loading branch information
samster25 committed Sep 9, 2023
1 parent 1c0087a commit aa1ee28
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 3 deletions.
2 changes: 2 additions & 0 deletions src/daft-parquet/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ pub mod pylib {
num_rows,
row_groups.as_deref(),
io_client,
+ multithreaded_io.unwrap_or(true),
&schema_infer_options,
)?
.into())
Expand Down Expand Up @@ -73,6 +74,7 @@ pub mod pylib {
num_rows,
row_groups,
io_client,
+ multithreaded_io.unwrap_or(true),
&schema_infer_options,
)?
.into_iter()
Expand Down
19 changes: 16 additions & 3 deletions src/daft-parquet/src/read.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,16 +138,18 @@ async fn read_parquet_single(
Ok(table)
}

+ #[allow(clippy::too_many_arguments)]
pub fn read_parquet(
uri: &str,
columns: Option<&[&str]>,
start_offset: Option<usize>,
num_rows: Option<usize>,
row_groups: Option<&[i64]>,
io_client: Arc<IOClient>,
+ multithreaded_io: bool,
schema_infer_options: &ParquetSchemaInferenceOptions,
) -> DaftResult<Table> {
- let runtime_handle = get_runtime(true)?;
+ let runtime_handle = get_runtime(multithreaded_io)?;
let _rt_guard = runtime_handle.enter();
runtime_handle.block_on(async {
read_parquet_single(
Expand All @@ -163,16 +165,18 @@ pub fn read_parquet(
})
}

+ #[allow(clippy::too_many_arguments)]
pub fn read_parquet_bulk(
uris: &[&str],
columns: Option<&[&str]>,
start_offset: Option<usize>,
num_rows: Option<usize>,
row_groups: Option<Vec<Vec<i64>>>,
io_client: Arc<IOClient>,
+ multithreaded_io: bool,
schema_infer_options: &ParquetSchemaInferenceOptions,
) -> DaftResult<Vec<Table>> {
- let runtime_handle = get_runtime(true)?;
+ let runtime_handle = get_runtime(multithreaded_io)?;
let _rt_guard = runtime_handle.enter();
let owned_columns = columns.map(|s| s.iter().map(|v| String::from(*v)).collect::<Vec<_>>());
if let Some(ref row_groups) = row_groups {
Expand Down Expand Up @@ -321,7 +325,16 @@ mod tests {

let io_client = Arc::new(IOClient::new(io_config.into())?);

- let table = read_parquet(file, None, None, None, None, io_client, &Default::default())?;
+ let table = read_parquet(
+ file,
+ None,
+ None,
+ None,
+ None,
+ io_client,
+ true,
+ &Default::default(),
+ )?;
assert_eq!(table.len(), 100);

Ok(())
Expand Down

0 comments on commit aa1ee28

Please sign in to comment.