[FEAT] add flag to use multithreaded io for parquet_read_table (#1298)
* add a flag that allows toggling multithreaded_io
samster25 committed Aug 24, 2023
1 parent 32084d8 commit 7fa9e64
Showing 4 changed files with 55 additions and 15 deletions.
13 changes: 12 additions & 1 deletion daft/table/table.py
@@ -368,6 +368,7 @@ def read_parquet(
         start_offset: int | None = None,
         num_rows: int | None = None,
         io_config: IOConfig | None = None,
+        multithreaded_io: bool | None = None,
         coerce_int96_timestamp_unit: TimeUnit = TimeUnit.ns(),
     ) -> Table:
         return Table._from_pytable(
@@ -377,6 +378,7 @@ def read_parquet(
                 start_offset=start_offset,
                 num_rows=num_rows,
                 io_config=io_config,
+                multithreaded_io=multithreaded_io,
                 coerce_int96_timestamp_unit=coerce_int96_timestamp_unit._timeunit,
             )
         )
@@ -389,6 +391,7 @@ def read_parquet_bulk(
         start_offset: int | None = None,
         num_rows: int | None = None,
         io_config: IOConfig | None = None,
+        multithreaded_io: bool | None = None,
         coerce_int96_timestamp_unit: TimeUnit = TimeUnit.ns(),
     ) -> list[Table]:
         pytables = _read_parquet_bulk(
@@ -397,6 +400,7 @@ def read_parquet_bulk(
             start_offset=start_offset,
             num_rows=num_rows,
             io_config=io_config,
+            multithreaded_io=multithreaded_io,
             coerce_int96_timestamp_unit=coerce_int96_timestamp_unit._timeunit,
         )
         return [Table._from_pytable(t) for t in pytables]
@@ -406,11 +410,18 @@ def read_parquet_statistics(
         cls,
         paths: Series | list[str],
         io_config: IOConfig | None = None,
+        multithreaded_io: bool | None = None,
     ) -> Table:
         if not isinstance(paths, Series):
             paths = Series.from_pylist(paths, name="uris")
         assert paths.name() == "uris", f"Expected input series to have name 'uris', but found: {paths.name()}"
-        return Table._from_pytable(_read_parquet_statistics(uris=paths._series, io_config=io_config))
+        return Table._from_pytable(
+            _read_parquet_statistics(
+                uris=paths._series,
+                io_config=io_config,
+                multithreaded_io=multithreaded_io,
+            )
+        )
 
 
 def _trim_pyarrow_large_arrays(arr: pa.ChunkedArray) -> pa.ChunkedArray:
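
For reference, a minimal sketch of how the new keyword is used from the Python Table API after this change. The file paths are placeholders and the import is assumed from the module shown above; omitting multithreaded_io (or passing None) keeps the default, which the Rust bindings later in this commit resolve to multithreaded IO via multithreaded_io.unwrap_or(true).

from daft.table import Table  # assumed import, matching daft/table/table.py above

# Opt out of the threaded IO runtime for a single read.
tbl = Table.read_parquet("data/example.parquet", multithreaded_io=False)

# The same flag is threaded through the bulk and statistics readers.
tbls = Table.read_parquet_bulk(["data/a.parquet", "data/b.parquet"], multithreaded_io=False)
stats = Table.read_parquet_statistics(["data/example.parquet"], multithreaded_io=False)
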
14 changes: 8 additions & 6 deletions src/daft-io/src/lib.rs
@@ -249,6 +249,7 @@ fn parse_url(input: &str) -> Result<(SourceType, Cow<'_, str>)> {
         _ => Err(Error::NotImplementedSource { store: scheme }),
     }
 }
+type CacheKey = (bool, Arc<IOConfig>);
 
 lazy_static! {
     static ref THREADED_RUNTIME: Arc<tokio::runtime::Runtime> = Arc::new(
@@ -257,23 +258,24 @@ lazy_static! {
             .build()
             .unwrap()
     );
-    static ref CLIENT_CACHE: tokio::sync::RwLock<HashMap<IOConfig, Arc<IOClient>>> =
+    static ref CLIENT_CACHE: tokio::sync::RwLock<HashMap<CacheKey, Arc<IOClient>>> =
         tokio::sync::RwLock::new(HashMap::new());
 }
 
-pub fn get_io_client(config: Arc<IOConfig>) -> DaftResult<Arc<IOClient>> {
+pub fn get_io_client(multi_thread: bool, config: Arc<IOConfig>) -> DaftResult<Arc<IOClient>> {
     let read_handle = CLIENT_CACHE.blocking_read();
-    if let Some(client) = read_handle.get(&config) {
+    let key = (multi_thread, config.clone());
+    if let Some(client) = read_handle.get(&key) {
         Ok(client.clone())
     } else {
         drop(read_handle);
 
         let mut w_handle = CLIENT_CACHE.blocking_write();
-        if let Some(client) = w_handle.get(&config) {
+        if let Some(client) = w_handle.get(&key) {
             Ok(client.clone())
         } else {
             let client = Arc::new(IOClient::new(config.clone())?);
-            w_handle.insert(config.as_ref().clone(), client.clone());
+            w_handle.insert(key, client.clone());
             Ok(client)
         }
     }
@@ -312,7 +314,7 @@ pub fn _url_download(
         false => max_connections,
         true => max_connections * usize::from(std::thread::available_parallelism()?),
     };
-    let io_client = get_io_client(config)?;
+    let io_client = get_io_client(multi_thread, config)?;
 
     let fetches = futures::stream::iter(urls.enumerate().map(|(i, url)| {
         let owned_url = url.map(|s| s.to_string());
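
The change above keys the shared IOClient cache on (multi_thread, config) rather than the config alone, so single-threaded and multithreaded callers each get their own client while reusing the existing read-then-write double-checked lookup. Below is a rough sketch of that caching pattern in Python; the names and the stand-in client object are illustrative, not daft APIs.

import threading

_client_cache: dict[tuple[bool, str], object] = {}
_cache_lock = threading.Lock()

def get_io_client(multi_thread: bool, config_key: str) -> object:
    """Illustrative only: one cached client per (multi_thread, config) pair."""
    key = (multi_thread, config_key)
    client = _client_cache.get(key)  # optimistic lookup, analogous to the blocking_read() path
    if client is not None:
        return client
    with _cache_lock:
        # Re-check under the write lock in case another thread inserted the client first.
        client = _client_cache.get(key)
        if client is None:
            client = object()  # stand-in for constructing a real IO client from the config
            _client_cache[key] = client
        return client
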
25 changes: 21 additions & 4 deletions src/daft-parquet/src/python.rs
@@ -19,10 +19,14 @@ pub mod pylib {
         start_offset: Option<usize>,
         num_rows: Option<usize>,
         io_config: Option<IOConfig>,
+        multithreaded_io: Option<bool>,
         coerce_int96_timestamp_unit: Option<PyTimeUnit>,
     ) -> PyResult<PyTable> {
         py.allow_threads(|| {
-            let io_client = get_io_client(io_config.unwrap_or_default().config.into())?;
+            let io_client = get_io_client(
+                multithreaded_io.unwrap_or(true),
+                io_config.unwrap_or_default().config.into(),
+            )?;
             let schema_infer_options = ParquetSchemaInferenceOptions::new(
                 coerce_int96_timestamp_unit.map(|tu| tu.timeunit),
             );
@@ -38,6 +42,7 @@ pub mod pylib {
         })
     }
 
+    #[allow(clippy::too_many_arguments)]
     #[pyfunction]
     pub fn read_parquet_bulk(
         py: Python,
@@ -46,10 +51,14 @@ pub mod pylib {
         start_offset: Option<usize>,
         num_rows: Option<usize>,
         io_config: Option<IOConfig>,
+        multithreaded_io: Option<bool>,
         coerce_int96_timestamp_unit: Option<PyTimeUnit>,
     ) -> PyResult<Vec<PyTable>> {
         py.allow_threads(|| {
-            let io_client = get_io_client(io_config.unwrap_or_default().config.into())?;
+            let io_client = get_io_client(
+                multithreaded_io.unwrap_or(true),
+                io_config.unwrap_or_default().config.into(),
+            )?;
             let schema_infer_options = ParquetSchemaInferenceOptions::new(
                 coerce_int96_timestamp_unit.map(|tu| tu.timeunit),
             );
@@ -72,13 +81,17 @@ pub mod pylib {
         py: Python,
         uri: &str,
         io_config: Option<IOConfig>,
+        multithreaded_io: Option<bool>,
         coerce_int96_timestamp_unit: Option<PyTimeUnit>,
     ) -> PyResult<PySchema> {
         py.allow_threads(|| {
             let schema_infer_options = ParquetSchemaInferenceOptions::new(
                 coerce_int96_timestamp_unit.map(|tu| tu.timeunit),
             );
-            let io_client = get_io_client(io_config.unwrap_or_default().config.into())?;
+            let io_client = get_io_client(
+                multithreaded_io.unwrap_or(true),
+                io_config.unwrap_or_default().config.into(),
+            )?;
             Ok(Arc::new(crate::read::read_parquet_schema(
                 uri,
                 io_client,
@@ -93,9 +106,13 @@ pub mod pylib {
         py: Python,
         uris: PySeries,
         io_config: Option<IOConfig>,
+        multithreaded_io: Option<bool>,
     ) -> PyResult<PyTable> {
         py.allow_threads(|| {
-            let io_client = get_io_client(io_config.unwrap_or_default().config.into())?;
+            let io_client = get_io_client(
+                multithreaded_io.unwrap_or(true),
+                io_config.unwrap_or_default().config.into(),
+            )?;
             Ok(crate::read::read_parquet_statistics(&uris.series, io_client)?.into())
         })
     }
18 changes: 14 additions & 4 deletions tests/integration/io/parquet/test_reads_public_data.py
@@ -200,9 +200,13 @@ def read_parquet_with_pyarrow(path) -> pa.Table:
 @pytest.mark.skipif(
     daft.context.get_context().use_rust_planner, reason="Custom fsspec filesystems not supported in new query planner"
 )
-def test_parquet_read_table(parquet_file, public_storage_io_config):
+@pytest.mark.parametrize(
+    "multithreaded_io",
+    [False, True],
+)
+def test_parquet_read_table(parquet_file, public_storage_io_config, multithreaded_io):
     _, url = parquet_file
-    daft_native_read = Table.read_parquet(url, io_config=public_storage_io_config)
+    daft_native_read = Table.read_parquet(url, io_config=public_storage_io_config, multithreaded_io=multithreaded_io)
     pa_read = Table.from_arrow(read_parquet_with_pyarrow(url))
     assert daft_native_read.schema() == pa_read.schema()
     pd.testing.assert_frame_equal(daft_native_read.to_pandas(), pa_read.to_pandas())
@@ -212,9 +216,15 @@ def test_parquet_read_table(parquet_file, public_storage_io_config):
 @pytest.mark.skipif(
     daft.context.get_context().use_rust_planner, reason="Custom fsspec filesystems not supported in new query planner"
 )
-def test_parquet_read_table_bulk(parquet_file, public_storage_io_config):
+@pytest.mark.parametrize(
+    "multithreaded_io",
+    [False, True],
+)
+def test_parquet_read_table_bulk(parquet_file, public_storage_io_config, multithreaded_io):
     _, url = parquet_file
-    daft_native_reads = Table.read_parquet_bulk([url] * 2, io_config=public_storage_io_config)
+    daft_native_reads = Table.read_parquet_bulk(
+        [url] * 2, io_config=public_storage_io_config, multithreaded_io=multithreaded_io
+    )
     pa_read = Table.from_arrow(read_parquet_with_pyarrow(url))
 
     for daft_native_read in daft_native_reads:
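
With the parametrization above, pytest runs each of these integration tests once per IO mode and compares both against the same pyarrow baseline. Outside of pytest, the property being exercised boils down to something like the following sketch (the local path is a placeholder and the import is assumed; the real tests read public-bucket files):

from daft.table import Table  # assumed import, matching the Table API above

single = Table.read_parquet("data/example.parquet", multithreaded_io=False)
multi = Table.read_parquet("data/example.parquet", multithreaded_io=True)
assert single.to_pandas().equals(multi.to_pandas())
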
