From 42f1199418cde21551cb14953928a3cd39b0f8b3 Mon Sep 17 00:00:00 2001
From: Sammy Sidhu
Date: Sun, 10 Sep 2023 18:22:38 -0700
Subject: [PATCH] [FEAT] Add configurable io thread pool size (#1363)

* Changes s3 max connections from 25 to 1024
* Changes default `io_thread_pool` size to 8
* Adds function for rust and python to change the number of threads in the `io_pool`
---
 daft/io/__init__.py                                |  9 +-
 src/common/io-config/src/python.rs                 |  2 +-
 src/common/io-config/src/s3.rs                     |  2 +-
 src/daft-io/src/lib.rs                             | 47 ++++++++--
 src/daft-io/src/python.rs                          | 92 ++++++++++---------
 .../io/test_url_download_public_aws_s3.py          | 22 +++++
 6 files changed, 122 insertions(+), 52 deletions(-)

diff --git a/daft/io/__init__.py b/daft/io/__init__.py
index 8cd278a879..5d34f1edc7 100644
--- a/daft/io/__init__.py
+++ b/daft/io/__init__.py
@@ -2,7 +2,13 @@
 
 import sys
 
-from daft.daft import AzureConfig, GCSConfig, IOConfig, S3Config
+from daft.daft import (
+    AzureConfig,
+    GCSConfig,
+    IOConfig,
+    S3Config,
+    set_io_pool_num_threads,
+)
 from daft.io._csv import read_csv
 from daft.io._json import read_json
 from daft.io._parquet import read_parquet
@@ -32,4 +38,5 @@ def _set_linux_cert_paths():
     "S3Config",
     "AzureConfig",
     "GCSConfig",
+    "set_io_pool_num_threads",
 ]
diff --git a/src/common/io-config/src/python.rs b/src/common/io-config/src/python.rs
index 342df3bf42..e6785e8137 100644
--- a/src/common/io-config/src/python.rs
+++ b/src/common/io-config/src/python.rs
@@ -11,7 +11,7 @@ use crate::config;
 /// endpoint_url: URL to the S3 endpoint, defaults to endpoints to AWS
 /// key_id: AWS Access Key ID, defaults to auto-detection from the current environment
 /// access_key: AWS Secret Access Key, defaults to auto-detection from the current environment
-/// max_connections: Maximum number of connections to S3 at any time, defaults to 25
+/// max_connections: Maximum number of connections to S3 at any time, defaults to 1024
 /// session_token: AWS Session Token, required only if `key_id` and `access_key` are temporary credentials
 /// retry_initial_backoff_ms: Initial backoff duration in milliseconds for an S3 retry, defaults to 1000ms
 /// connect_timeout_ms: Timeout duration to wait to make a connection to S3 in milliseconds, defaults to 60 seconds
diff --git a/src/common/io-config/src/s3.rs b/src/common/io-config/src/s3.rs
index 6cc309953c..05f9111f9c 100644
--- a/src/common/io-config/src/s3.rs
+++ b/src/common/io-config/src/s3.rs
@@ -28,7 +28,7 @@ impl Default for S3Config {
             key_id: None,
             session_token: None,
             access_key: None,
-            max_connections: 25,
+            max_connections: 1024,
             retry_initial_backoff_ms: 1000,
             connect_timeout_ms: 60_000,
             read_timeout_ms: 60_000,
diff --git a/src/daft-io/src/lib.rs b/src/daft-io/src/lib.rs
index df39d27cef..03ddac2f3b 100644
--- a/src/daft-io/src/lib.rs
+++ b/src/daft-io/src/lib.rs
@@ -248,14 +248,19 @@ fn parse_url(input: &str) -> Result<(SourceType, Cow<'_, str>)> {
     }
 }
 type CacheKey = (bool, Arc<IOConfig>);
 lazy_static! {
-    static ref THREADED_RUNTIME: Arc<tokio::runtime::Runtime> = Arc::new(
-        tokio::runtime::Builder::new_multi_thread()
-            .enable_all()
-            .build()
-            .unwrap()
-    );
+    static ref NUM_CPUS: usize = std::thread::available_parallelism().unwrap().get();
+    static ref THREADED_RUNTIME: tokio::sync::RwLock<(Arc<tokio::runtime::Runtime>, usize)> =
+        tokio::sync::RwLock::new((
+            Arc::new(
+                tokio::runtime::Builder::new_multi_thread()
+                    .worker_threads(8.min(*NUM_CPUS))
+                    .enable_all()
+                    .build()
+                    .unwrap()
+            ),
+            8.min(*NUM_CPUS)
+        ));
     static ref CLIENT_CACHE: tokio::sync::RwLock<HashMap<CacheKey, Arc<IOClient>>> =
         tokio::sync::RwLock::new(HashMap::new());
 }
@@ -286,8 +291,34 @@ pub fn get_runtime(multi_thread: bool) -> DaftResult<Arc<tokio::runtime::Runtime>> {
-        true => Ok(THREADED_RUNTIME.clone()),
+        true => {
+            let guard = THREADED_RUNTIME.blocking_read();
+            Ok(guard.clone().0)
+        }
+    }
+}
+
+pub fn set_io_pool_num_threads(num_threads: usize) -> bool {
+    {
+        let guard = THREADED_RUNTIME.blocking_read();
+        if guard.1 == num_threads {
+            return false;
+        }
     }
+    let mut client_guard = CLIENT_CACHE.blocking_write();
+    let mut guard = THREADED_RUNTIME.blocking_write();
+
+    client_guard.clear();
+
+    guard.1 = num_threads;
+    guard.0 = Arc::new(
+        tokio::runtime::Builder::new_multi_thread()
+            .worker_threads(num_threads)
+            .enable_all()
+            .build()
+            .unwrap(),
+    );
+    true
 }
 
 pub fn _url_download(
diff --git a/src/daft-io/src/python.rs b/src/daft-io/src/python.rs
index b1f3957650..306ecdc745 100644
--- a/src/daft-io/src/python.rs
+++ b/src/daft-io/src/python.rs
@@ -1,47 +1,57 @@
-use crate::{get_io_client, get_runtime, object_io::LSResult, parse_url};
-use common_error::DaftResult;
-use pyo3::{
-    prelude::*,
-    types::{PyDict, PyList},
-};
-
 pub use common_io_config::python::{AzureConfig, GCSConfig, IOConfig};
+pub use py::register_modules;
+
+mod py {
+    use crate::{get_io_client, get_runtime, object_io::LSResult, parse_url};
+    use common_error::DaftResult;
+    use pyo3::{
+        prelude::*,
+        types::{PyDict, PyList},
+    };
 
-#[pyfunction]
-fn io_list(
-    py: Python,
-    path: String,
-    multithreaded_io: Option<bool>,
-    io_config: Option<IOConfig>,
-) -> PyResult<&PyList> {
-    let lsr: DaftResult<LSResult> = py.allow_threads(|| {
-        let io_client = get_io_client(
-            multithreaded_io.unwrap_or(true),
-            io_config.unwrap_or_default().config.into(),
-        )?;
-        let (scheme, path) = parse_url(&path)?;
-        let runtime_handle = get_runtime(true)?;
-        let _rt_guard = runtime_handle.enter();
+    #[pyfunction]
+    fn io_list(
+        py: Python,
+        path: String,
+        multithreaded_io: Option<bool>,
+        io_config: Option<IOConfig>,
+    ) -> PyResult<&PyList> {
+        let lsr: DaftResult<LSResult> = py.allow_threads(|| {
+            let io_client = get_io_client(
+                multithreaded_io.unwrap_or(true),
+                io_config.unwrap_or_default().config.into(),
+            )?;
+            let (scheme, path) = parse_url(&path)?;
+            let runtime_handle = get_runtime(true)?;
+            let _rt_guard = runtime_handle.enter();
 
-        runtime_handle.block_on(async move {
-            let source = io_client.get_source(&scheme).await?;
-            Ok(source.ls(&path, None, None).await?)
+            runtime_handle.block_on(async move {
+                let source = io_client.get_source(&scheme).await?;
+                Ok(source.ls(&path, None, None).await?)
+            })
+        });
+        let lsr = lsr?;
+        let mut to_rtn = vec![];
+        for file in lsr.files {
+            let dict = PyDict::new(py);
+            dict.set_item("type", format!("{:?}", file.filetype))?;
+            dict.set_item("path", file.filepath)?;
+            dict.set_item("size", file.size)?;
+            to_rtn.push(dict);
+        }
+        Ok(PyList::new(py, to_rtn))
     }
-    Ok(PyList::new(py, to_rtn))
-}
 
-pub fn register_modules(py: Python, parent: &PyModule) -> PyResult<()> {
-    common_io_config::python::register_modules(py, parent)?;
-    parent.add_function(wrap_pyfunction!(io_list, parent)?)?;
-    Ok(())
+    #[pyfunction]
+    fn set_io_pool_num_threads(num_threads: i64) -> PyResult<bool> {
+        Ok(crate::set_io_pool_num_threads(num_threads as usize))
+    }
+
+    pub fn register_modules(py: Python, parent: &PyModule) -> PyResult<()> {
+        common_io_config::python::register_modules(py, parent)?;
+        parent.add_function(wrap_pyfunction!(io_list, parent)?)?;
+        parent.add_function(wrap_pyfunction!(set_io_pool_num_threads, parent)?)?;
+
+        Ok(())
+    }
 }
diff --git a/tests/integration/io/test_url_download_public_aws_s3.py b/tests/integration/io/test_url_download_public_aws_s3.py
index 85817a0ed7..8d9469754d 100644
--- a/tests/integration/io/test_url_download_public_aws_s3.py
+++ b/tests/integration/io/test_url_download_public_aws_s3.py
@@ -50,6 +50,28 @@ def test_url_download_aws_s3_public_bucket_native_downloader(aws_public_s3_confi
     assert img_bytes is not None
 
 
+@pytest.mark.integration()
+def test_url_download_aws_s3_public_bucket_native_downloader_io_thread_change(
+    aws_public_s3_config, small_images_s3_paths
+):
+    data = {"urls": small_images_s3_paths}
+    df = daft.from_pydict(data)
+    df = df.with_column("data", df["urls"].url.download(io_config=aws_public_s3_config, use_native_downloader=True))
+
+    data = df.to_pydict()
+    assert len(data["data"]) == 6
+    for img_bytes in data["data"]:
+        assert img_bytes is not None
+    daft.io.set_io_pool_num_threads(2)
+    df = daft.from_pydict(data)
+    df = df.with_column("data", df["urls"].url.download(io_config=aws_public_s3_config, use_native_downloader=True))
+
+    data = df.to_pydict()
+    assert len(data["data"]) == 6
+    for img_bytes in data["data"]:
+        assert img_bytes is not None
+
+
 @pytest.mark.integration()
 def test_url_download_aws_s3_public_bucket_native_downloader_with_connect_timeout(small_images_s3_paths):
     data = {"urls": small_images_s3_paths}
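
Usage sketch (not part of the patch; the bucket URLs and thread counts below are illustrative):
after this change, the tokio IO runtime used by `url.download` and the native readers can be
resized from Python through the newly exported `daft.io.set_io_pool_num_threads`. Per the Rust
side above, resizing rebuilds the thread pool and clears the cached IO clients, and the call
returns False when the requested size already matches the current pool.

    import daft
    from daft.io import IOConfig, set_io_pool_num_threads

    # Shrink the shared IO runtime before kicking off downloads; the default
    # size after this patch is min(8, available CPUs).
    set_io_pool_num_threads(2)

    # Placeholder URLs; IOConfig() falls back to the credential auto-detection
    # described in the S3Config docstring above.
    df = daft.from_pydict({"urls": ["s3://my-bucket/a.jpg", "s3://my-bucket/b.jpg"]})
    df = df.with_column(
        "data",
        df["urls"].url.download(io_config=IOConfig(), use_native_downloader=True),
    )
    df.collect()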