[FEAT] Add configurable io thread pool size (#1363)
* Changes S3 max connections from 25 to 1024
* Changes the default `io_thread_pool` size to 8 (capped at the machine's available parallelism)
* Adds a function, in both Rust and Python, to change the number of threads in the `io_pool`
samster25 committed Sep 11, 2023
1 parent c8ffd03 commit 42f1199
Showing 6 changed files with 122 additions and 52 deletions.
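
The new knob is re-exported from `daft.io` (see the `daft/io/__init__.py` diff below). A minimal usage sketch — the bucket paths are placeholders, and the `min(8, n_cpus)` default comes from the `lib.rs` change further down:

```python
import daft
from daft.io import set_io_pool_num_threads

# Grow the IO pool beyond the new default of min(8, n_cpus) before a
# download-heavy workload; returns True if the pool was rebuilt.
set_io_pool_num_threads(16)

# Hypothetical URLs -- any s3:// paths readable by the native downloader.
df = daft.from_pydict({"urls": ["s3://my-bucket/a.jpg", "s3://my-bucket/b.jpg"]})
df = df.with_column("data", df["urls"].url.download(use_native_downloader=True))
```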
9 changes: 8 additions & 1 deletion daft/io/__init__.py
@@ -2,7 +2,13 @@
 
 import sys
 
-from daft.daft import AzureConfig, GCSConfig, IOConfig, S3Config
+from daft.daft import (
+    AzureConfig,
+    GCSConfig,
+    IOConfig,
+    S3Config,
+    set_io_pool_num_threads,
+)
 from daft.io._csv import read_csv
 from daft.io._json import read_json
 from daft.io._parquet import read_parquet
@@ -32,4 +38,5 @@ def _set_linux_cert_paths():
     "S3Config",
     "AzureConfig",
     "GCSConfig",
+    "set_io_pool_num_threads",
 ]
2 changes: 1 addition & 1 deletion src/common/io-config/src/python.rs
@@ -11,7 +11,7 @@ use crate::config;
 /// endpoint_url: URL to the S3 endpoint, defaults to endpoints to AWS
 /// key_id: AWS Access Key ID, defaults to auto-detection from the current environment
 /// access_key: AWS Secret Access Key, defaults to auto-detection from the current environment
-/// max_connections: Maximum number of connections to S3 at any time, defaults to 25
+/// max_connections: Maximum number of connections to S3 at any time, defaults to 1024
 /// session_token: AWS Session Token, required only if `key_id` and `access_key` are temporary credentials
 /// retry_initial_backoff_ms: Initial backoff duration in milliseconds for an S3 retry, defaults to 1000ms
 /// connect_timeout_ms: Timeout duration to wait to make a connection to S3 in milliseconds, defaults to 60 seconds
2 changes: 1 addition & 1 deletion src/common/io-config/src/s3.rs
@@ -28,7 +28,7 @@ impl Default for S3Config {
             key_id: None,
             session_token: None,
             access_key: None,
-            max_connections: 25,
+            max_connections: 1024,
             retry_initial_backoff_ms: 1000,
             connect_timeout_ms: 60_000,
             read_timeout_ms: 60_000,
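
Callers that need the old, more conservative behavior can still cap connections per config. A sketch, assuming the `s3=` keyword on the `IOConfig` binding exposed alongside `S3Config` above:

```python
from daft.io import IOConfig, S3Config

# Cap S3 concurrency back down from the new 1024 default, e.g. for an
# endpoint that throttles aggressively; other fields keep their defaults.
io_config = IOConfig(s3=S3Config(max_connections=64))

# Then pass it through any IO call that accepts io_config, e.g.:
# df["urls"].url.download(io_config=io_config, use_native_downloader=True)
```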
47 changes: 39 additions & 8 deletions src/daft-io/src/lib.rs
@@ -248,14 +248,19 @@ fn parse_url(input: &str) -> Result<(SourceType, Cow<'_, str>)> {
     }
 }
 type CacheKey = (bool, Arc<IOConfig>);
 
 lazy_static! {
-    static ref THREADED_RUNTIME: Arc<tokio::runtime::Runtime> = Arc::new(
-        tokio::runtime::Builder::new_multi_thread()
-            .enable_all()
-            .build()
-            .unwrap()
-    );
+    static ref NUM_CPUS: usize = std::thread::available_parallelism().unwrap().get();
+    static ref THREADED_RUNTIME: tokio::sync::RwLock<(Arc<tokio::runtime::Runtime>, usize)> =
+        tokio::sync::RwLock::new((
+            Arc::new(
+                tokio::runtime::Builder::new_multi_thread()
+                    .worker_threads(8.min(*NUM_CPUS))
+                    .enable_all()
+                    .build()
+                    .unwrap()
+            ),
+            8.min(*NUM_CPUS)
+        ));
     static ref CLIENT_CACHE: tokio::sync::RwLock<HashMap<CacheKey, Arc<IOClient>>> =
         tokio::sync::RwLock::new(HashMap::new());
 }
@@ -286,8 +291,34 @@ pub fn get_runtime(multi_thread: bool) -> DaftResult<Arc<tokio::runtime::Runtime
                 .enable_all()
                 .build()?,
         )),
-        true => Ok(THREADED_RUNTIME.clone()),
+        true => {
+            let guard = THREADED_RUNTIME.blocking_read();
+            Ok(guard.clone().0)
+        }
     }
 }
 
+pub fn set_io_pool_num_threads(num_threads: usize) -> bool {
+    {
+        let guard = THREADED_RUNTIME.blocking_read();
+        if guard.1 == num_threads {
+            return false;
+        }
+    }
+    let mut client_guard = CLIENT_CACHE.blocking_write();
+    let mut guard = THREADED_RUNTIME.blocking_write();
+
+    client_guard.clear();
+
+    guard.1 = num_threads;
+    guard.0 = Arc::new(
+        tokio::runtime::Builder::new_multi_thread()
+            .worker_threads(num_threads)
+            .enable_all()
+            .build()
+            .unwrap(),
+    );
+    true
+}
+
 pub fn _url_download(
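
The semantics of `set_io_pool_num_threads` above are observable from Python: `false` means the requested size already matched the current pool and nothing was touched; otherwise the client cache is cleared and a fresh runtime is built. A sketch of the expected behavior:

```python
import daft

# First resize rebuilds the runtime (and clears the cached IO clients),
# so it should report True -- assuming the pool isn't already 4 threads.
assert daft.io.set_io_pool_num_threads(4) is True

# Asking for the current size again is a no-op and reports False.
assert daft.io.set_io_pool_num_threads(4) is False
```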
92 changes: 51 additions & 41 deletions src/daft-io/src/python.rs
@@ -1,47 +1,57 @@
-use crate::{get_io_client, get_runtime, object_io::LSResult, parse_url};
-use common_error::DaftResult;
-use pyo3::{
-    prelude::*,
-    types::{PyDict, PyList},
-};
-
 pub use common_io_config::python::{AzureConfig, GCSConfig, IOConfig};
+pub use py::register_modules;
+
+mod py {
+    use crate::{get_io_client, get_runtime, object_io::LSResult, parse_url};
+    use common_error::DaftResult;
+    use pyo3::{
+        prelude::*,
+        types::{PyDict, PyList},
+    };
 
-#[pyfunction]
-fn io_list(
-    py: Python,
-    path: String,
-    multithreaded_io: Option<bool>,
-    io_config: Option<common_io_config::python::IOConfig>,
-) -> PyResult<&PyList> {
-    let lsr: DaftResult<LSResult> = py.allow_threads(|| {
-        let io_client = get_io_client(
-            multithreaded_io.unwrap_or(true),
-            io_config.unwrap_or_default().config.into(),
-        )?;
-        let (scheme, path) = parse_url(&path)?;
-        let runtime_handle = get_runtime(true)?;
-        let _rt_guard = runtime_handle.enter();
+    #[pyfunction]
+    fn io_list(
+        py: Python,
+        path: String,
+        multithreaded_io: Option<bool>,
+        io_config: Option<common_io_config::python::IOConfig>,
+    ) -> PyResult<&PyList> {
+        let lsr: DaftResult<LSResult> = py.allow_threads(|| {
+            let io_client = get_io_client(
+                multithreaded_io.unwrap_or(true),
+                io_config.unwrap_or_default().config.into(),
+            )?;
+            let (scheme, path) = parse_url(&path)?;
+            let runtime_handle = get_runtime(true)?;
+            let _rt_guard = runtime_handle.enter();
 
-    runtime_handle.block_on(async move {
-        let source = io_client.get_source(&scheme).await?;
-        Ok(source.ls(&path, None, None).await?)
-    })
-    });
-    let lsr = lsr?;
-    let mut to_rtn = vec![];
-    for file in lsr.files {
-        let dict = PyDict::new(py);
-        dict.set_item("type", format!("{:?}", file.filetype))?;
-        dict.set_item("path", file.filepath)?;
-        dict.set_item("size", file.size)?;
-        to_rtn.push(dict);
-    }
-    Ok(PyList::new(py, to_rtn))
-}
+            runtime_handle.block_on(async move {
+                let source = io_client.get_source(&scheme).await?;
+                Ok(source.ls(&path, None, None).await?)
+            })
+        });
+        let lsr = lsr?;
+        let mut to_rtn = vec![];
+        for file in lsr.files {
+            let dict = PyDict::new(py);
+            dict.set_item("type", format!("{:?}", file.filetype))?;
+            dict.set_item("path", file.filepath)?;
+            dict.set_item("size", file.size)?;
+            to_rtn.push(dict);
+        }
+        Ok(PyList::new(py, to_rtn))
+    }
 
-pub fn register_modules(py: Python, parent: &PyModule) -> PyResult<()> {
-    common_io_config::python::register_modules(py, parent)?;
-    parent.add_function(wrap_pyfunction!(io_list, parent)?)?;
-    Ok(())
+    #[pyfunction]
+    fn set_io_pool_num_threads(num_threads: i64) -> PyResult<bool> {
+        Ok(crate::set_io_pool_num_threads(num_threads as usize))
+    }
+
+    pub fn register_modules(py: Python, parent: &PyModule) -> PyResult<()> {
+        common_io_config::python::register_modules(py, parent)?;
+        parent.add_function(wrap_pyfunction!(io_list, parent)?)?;
+        parent.add_function(wrap_pyfunction!(set_io_pool_num_threads, parent)?)?;
+
+        Ok(())
+    }
 }
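
One caveat in the binding above: the argument arrives as `i64` and is converted with `num_threads as usize`, so a negative value would wrap to an enormous count rather than raise. A defensive wrapper on the Python side — a hypothetical helper, not part of this commit:

```python
import daft

def set_io_threads_checked(n: int) -> bool:
    # Guard the i64 -> usize `as` cast in the pyo3 binding: a negative n
    # would wrap to a huge thread count instead of erroring.
    if n <= 0:
        raise ValueError(f"num_threads must be a positive integer, got {n}")
    return daft.io.set_io_pool_num_threads(n)
```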
22 changes: 22 additions & 0 deletions tests/integration/io/test_url_download_public_aws_s3.py
@@ -50,6 +50,28 @@ def test_url_download_aws_s3_public_bucket_native_downloader(aws_public_s3_confi
     assert img_bytes is not None
 
 
+@pytest.mark.integration()
+def test_url_download_aws_s3_public_bucket_native_downloader_io_thread_change(
+    aws_public_s3_config, small_images_s3_paths
+):
+    data = {"urls": small_images_s3_paths}
+    df = daft.from_pydict(data)
+    df = df.with_column("data", df["urls"].url.download(io_config=aws_public_s3_config, use_native_downloader=True))
+
+    data = df.to_pydict()
+    assert len(data["data"]) == 6
+    for img_bytes in data["data"]:
+        assert img_bytes is not None
+    daft.io.set_io_pool_num_threads(2)
+    df = daft.from_pydict(data)
+    df = df.with_column("data", df["urls"].url.download(io_config=aws_public_s3_config, use_native_downloader=True))
+
+    data = df.to_pydict()
+    assert len(data["data"]) == 6
+    for img_bytes in data["data"]:
+        assert img_bytes is not None
+
+
 @pytest.mark.integration()
 def test_url_download_aws_s3_public_bucket_native_downloader_with_connect_timeout(small_images_s3_paths):
     data = {"urls": small_images_s3_paths}
