diff --git a/src/daft-io/src/config.rs b/src/daft-io/src/config.rs
index 98f7e5c8d0..10a0f3d971 100644
--- a/src/daft-io/src/config.rs
+++ b/src/daft-io/src/config.rs
@@ -10,8 +10,11 @@ pub struct S3Config {
     pub key_id: Option<String>,
     pub session_token: Option<String>,
     pub access_key: Option<String>,
-    pub retry_initial_backoff_ms: u32,
+    pub retry_initial_backoff_ms: u64,
+    pub connect_timeout_ms: u64,
+    pub read_timeout_ms: u64,
     pub num_tries: u32,
+    pub retry_mode: Option<String>,
     pub anonymous: bool,
 }

@@ -24,7 +27,10 @@ impl Default for S3Config {
             session_token: None,
             access_key: None,
             retry_initial_backoff_ms: 1000,
+            connect_timeout_ms: 60_000,
+            read_timeout_ms: 60_000,
             num_tries: 5,
+            retry_mode: Some("standard".to_string()),
             anonymous: false,
         }
     }
@@ -40,8 +46,11 @@ impl Display for S3Config {
     key_id: {:?}
     session_token: {:?},
     access_key: {:?}
-    retry_initial_backoff_ms: {:?},
+    retry_initial_backoff_ms: {},
+    connect_timeout_ms: {},
+    read_timeout_ms: {},
     num_tries: {:?},
+    retry_mode: {:?},
     anonymous: {}",
             self.region_name,
             self.endpoint_url,
@@ -49,7 +58,10 @@ impl Display for S3Config {
             self.session_token,
             self.access_key,
             self.retry_initial_backoff_ms,
+            self.connect_timeout_ms,
+            self.read_timeout_ms,
             self.num_tries,
+            self.retry_mode,
             self.anonymous
         )
     }
diff --git a/src/daft-io/src/lib.rs b/src/daft-io/src/lib.rs
index 5a812bcd0c..268880402b 100644
--- a/src/daft-io/src/lib.rs
+++ b/src/daft-io/src/lib.rs
@@ -49,7 +49,7 @@ pub enum Error {
     #[snafu(display("Invalid Argument: {:?}", msg))]
     InvalidArgument { msg: String },

-    #[snafu(display("Unable to open file {}: {}", path, source))]
+    #[snafu(display("Unable to open file {}: {:?}", path, source))]
     UnableToOpenFile { path: String, source: DynError },

     #[snafu(display("Unable to read data from file {}: {}", path, source))]
@@ -83,6 +83,9 @@ pub enum Error {
     #[snafu(display("Source not yet implemented: {}", store))]
     NotImplementedSource { store: String },

+    #[snafu(display("Unhandled Error for path: {}\nDetails:\n{}", path, msg))]
+    Unhandled { path: String, msg: String },
+
     #[snafu(display("Error joining spawned task: {}", source), context(false))]
     JoinError { source: tokio::task::JoinError },
 }
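The new fields and their defaults surface on the Python daft.io.S3Config wrapper through the bindings changed below in src/daft-io/src/python.rs. A minimal sketch, assuming those bindings are built, of reading the defaults back from Python:

import daft

# S3Config::default() above sets retry_initial_backoff_ms=1000, connect_timeout_ms=60_000,
# read_timeout_ms=60_000, num_tries=5 and retry_mode=Some("standard").
conf = daft.io.S3Config(anonymous=True)
assert conf.retry_initial_backoff_ms == 1000
assert conf.connect_timeout_ms == 60_000
assert conf.read_timeout_ms == 60_000
assert conf.num_tries == 5
assert conf.retry_mode == "standard"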
diff --git a/src/daft-io/src/python.rs b/src/daft-io/src/python.rs
index 9843d0ffb9..187579e31a 100644
--- a/src/daft-io/src/python.rs
+++ b/src/daft-io/src/python.rs
@@ -12,7 +12,10 @@ use pyo3::prelude::*;
 /// access_key: AWS Secret Access Key, defaults to auto-detection from the current environment
 /// session_token: AWS Session Token, required only if `key_id` and `access_key` are temporary credentials
 /// retry_initial_backoff_ms: Initial backoff duration in milliseconds for an S3 retry, defaults to 1000ms
+/// connect_timeout_ms: Timeout duration to wait to make a connection to S3 in milliseconds, defaults to 60 seconds
+/// read_timeout_ms: Timeout duration to wait to read the first byte from S3 in milliseconds, defaults to 60 seconds
 /// num_tries: Number of attempts to make a connection, defaults to 5
+/// retry_mode: Retry Mode when a request fails, current supported values are `standard` and `adaptive`
 /// anonymous: Whether or not to use "anonymous mode", which will access S3 without any credentials
 ///
 /// Example:
@@ -138,8 +141,11 @@ impl S3Config {
         key_id: Option<String>,
         session_token: Option<String>,
         access_key: Option<String>,
-        retry_initial_backoff_ms: Option<u32>,
+        retry_initial_backoff_ms: Option<u64>,
+        connect_timeout_ms: Option<u64>,
+        read_timeout_ms: Option<u64>,
         num_tries: Option<u32>,
+        retry_mode: Option<String>,
         anonymous: Option<bool>,
     ) -> Self {
         let def = config::S3Config::default();
@@ -152,7 +158,10 @@ impl S3Config {
                 access_key: access_key.or(def.access_key),
                 retry_initial_backoff_ms: retry_initial_backoff_ms
                     .unwrap_or(def.retry_initial_backoff_ms),
+                connect_timeout_ms: connect_timeout_ms.unwrap_or(def.connect_timeout_ms),
+                read_timeout_ms: read_timeout_ms.unwrap_or(def.read_timeout_ms),
                 num_tries: num_tries.unwrap_or(def.num_tries),
+                retry_mode: retry_mode.or(def.retry_mode),
                 anonymous: anonymous.unwrap_or(def.anonymous),
             },
         }
@@ -194,15 +203,33 @@ impl S3Config {

     /// AWS Retry Initial Backoff Time in Milliseconds
     #[getter]
-    pub fn retry_initial_backoff_ms(&self) -> PyResult<u32> {
+    pub fn retry_initial_backoff_ms(&self) -> PyResult<u64> {
         Ok(self.config.retry_initial_backoff_ms)
     }

+    /// AWS Connection Timeout in Milliseconds
+    #[getter]
+    pub fn connect_timeout_ms(&self) -> PyResult<u64> {
+        Ok(self.config.connect_timeout_ms)
+    }
+
+    /// AWS Read Timeout in Milliseconds
+    #[getter]
+    pub fn read_timeout_ms(&self) -> PyResult<u64> {
+        Ok(self.config.read_timeout_ms)
+    }
+
     /// AWS Number Retries
     #[getter]
     pub fn num_tries(&self) -> PyResult<u32> {
         Ok(self.config.num_tries)
     }
+
+    /// AWS Retry Mode
+    #[getter]
+    pub fn retry_mode(&self) -> PyResult<Option<String>> {
+        Ok(self.config.retry_mode.clone())
+    }
 }

 #[pymethods]
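A short usage sketch of the new keyword arguments, wiring an adaptive-retry, tighter-timeout config into a native parquet read; the bucket and object key are hypothetical placeholders, and read_parquet is assumed to accept io_config and use_native_downloader as it does elsewhere in this codebase:

import daft

io_config = daft.io.IOConfig(
    s3=daft.io.S3Config(
        region_name="us-west-2",
        anonymous=True,
        num_tries=8,
        retry_mode="adaptive",
        connect_timeout_ms=30_000,
        read_timeout_ms=30_000,
    )
)

# Hypothetical object; shown only to illustrate passing the config through.
df = daft.read_parquet(
    "s3://some-bucket/some-file.parquet",
    io_config=io_config,
    use_native_downloader=True,
)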
diff --git a/src/daft-io/src/s3_like.rs b/src/daft-io/src/s3_like.rs
index 6278546e93..fabccb1261 100644
--- a/src/daft-io/src/s3_like.rs
+++ b/src/daft-io/src/s3_like.rs
@@ -1,4 +1,6 @@
 use async_trait::async_trait;
+use aws_config::retry::RetryMode;
+use aws_config::timeout::TimeoutConfig;
 use aws_smithy_async::rt::sleep::TokioSleep;
 use reqwest::StatusCode;
 use s3::operation::head_object::HeadObjectError;
@@ -12,7 +14,7 @@ use aws_sig_auth::signer::SigningRequirements;
 use futures::{StreamExt, TryStreamExt};
 use s3::client::customize::Response;
 use s3::config::{Credentials, Region};
-use s3::error::SdkError;
+use s3::error::{DisplayErrorContext, SdkError};
 use s3::operation::get_object::GetObjectError;
 use snafu::{ensure, IntoError, ResultExt, Snafu};
 use url::ParseError;
@@ -92,6 +94,10 @@ impl From<Error> for super::Error {
                     path,
                     source: no_such_key.into(),
                 },
+                GetObjectError::Unhandled(v) => super::Error::Unhandled {
+                    path,
+                    msg: DisplayErrorContext(v).to_string(),
+                },
                 err => super::Error::UnableToOpenFile {
                     path,
                     source: err.into(),
@@ -102,6 +108,10 @@
                     path,
                     source: no_such_key.into(),
                 },
+                HeadObjectError::Unhandled(v) => super::Error::Unhandled {
+                    path,
+                    msg: DisplayErrorContext(v).to_string(),
+                },
                 err => super::Error::UnableToOpenFile {
                     path,
                     source: err.into(),
@@ -160,13 +170,29 @@ async fn build_s3_client(config: &S3Config) -> super::Result<(bool, s3::Client)>
     );
     let retry_config = s3::config::retry::RetryConfig::standard()
         .with_max_attempts(config.num_tries)
-        .with_initial_backoff(Duration::from_millis(
-            config.retry_initial_backoff_ms as u64,
-        ));
+        .with_initial_backoff(Duration::from_millis(config.retry_initial_backoff_ms));
+
+    let retry_config = if let Some(retry_mode) = &config.retry_mode {
+        if retry_mode.trim().eq_ignore_ascii_case("adaptive") {
+            retry_config.with_retry_mode(RetryMode::Adaptive)
+        } else if retry_mode.trim().eq_ignore_ascii_case("standard") {
+            retry_config
+        } else {
+            return Err(crate::Error::InvalidArgument { msg: format!("Invalid Retry Mode, Daft S3 client currently only supports standard and adaptive, got {}", retry_mode) });
+        }
+    } else {
+        retry_config
+    };
+
     let builder = builder.retry_config(retry_config);

     let sleep_impl = Arc::new(TokioSleep::new());
     let builder = builder.sleep_impl(sleep_impl);
+    let timeout_config = TimeoutConfig::builder()
+        .connect_timeout(Duration::from_millis(config.connect_timeout_ms))
+        .read_timeout(Duration::from_millis(config.read_timeout_ms))
+        .build();
+    let builder = builder.timeout_config(timeout_config);

     let builder = if config.access_key.is_some() && config.key_id.is_some() {
         let creds = Credentials::from_keys(
diff --git a/src/daft-plan/src/source_info.rs b/src/daft-plan/src/source_info.rs
index 409fdf7513..f7d106e6a5 100644
--- a/src/daft-plan/src/source_info.rs
+++ b/src/daft-plan/src/source_info.rs
@@ -249,7 +249,7 @@ impl FileFormatConfig {
 #[cfg_attr(feature = "python", pyclass(module = "daft.daft"))]
 pub struct ParquetSourceConfig {
     pub use_native_downloader: bool,
-    pub io_config: Option<IOConfig>,
+    pub io_config: Box<Option<IOConfig>>,
 }

 #[cfg(feature = "python")]
@@ -259,7 +259,7 @@ impl ParquetSourceConfig {
     fn new(use_native_downloader: bool, io_config: Option<PyIOConfig>) -> Self {
         Self {
             use_native_downloader,
-            io_config: io_config.map(|c| c.config),
+            io_config: io_config.map(|c| c.config).into(),
         }
     }

@@ -270,7 +270,7 @@
     #[getter]
     fn get_io_config(&self) -> PyResult<Option<PyIOConfig>> {
-        Ok(self.io_config.as_ref().map(|c| c.clone().into()))
+        Ok(self.io_config.clone().map(|c| c.into()))
     }
 }
diff --git a/tests/integration/io/conftest.py b/tests/integration/io/conftest.py
index 4f23e61184..ca5a923b5b 100644
--- a/tests/integration/io/conftest.py
+++ b/tests/integration/io/conftest.py
@@ -64,11 +64,12 @@ def nginx_config() -> tuple[str, pathlib.Path]:
     )


-@pytest.fixture(scope="session")
-def retry_server_s3_config() -> daft.io.IOConfig:
+@pytest.fixture(scope="session", params=["standard", "adaptive"], ids=["standard", "adaptive"])
+def retry_server_s3_config(request) -> daft.io.IOConfig:
     """Returns the URL to the local retry_server fixture"""
+    retry_mode = request.param
     return daft.io.IOConfig(
-        s3=daft.io.S3Config(endpoint_url="http://127.0.0.1:8001", anonymous=True, num_tries=10),
+        s3=daft.io.S3Config(endpoint_url="http://127.0.0.1:8001", anonymous=True, num_tries=10, retry_mode=retry_mode)
     )
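Because retry_server_s3_config is now parametrized over "standard" and "adaptive", every test that takes the fixture runs once per retry mode. A rough sketch of such a consumer, assuming the local retry server is running; the test name and object key are hypothetical:

import daft
import pytest


@pytest.mark.integration()
def test_retry_server_download_with_retry_mode(retry_server_s3_config):
    # One invocation per fixture param: retry_mode="standard" and retry_mode="adaptive".
    df = daft.from_pydict({"urls": ["s3://retry-bucket/some-key"]})
    df = df.with_column(
        "data", df["urls"].url.download(io_config=retry_server_s3_config, use_native_downloader=True)
    ).collect()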
diff --git a/tests/integration/io/test_url_download_public_aws_s3.py b/tests/integration/io/test_url_download_public_aws_s3.py
index 86e0479889..85817a0ed7 100644
--- a/tests/integration/io/test_url_download_public_aws_s3.py
+++ b/tests/integration/io/test_url_download_public_aws_s3.py
@@ -48,3 +48,43 @@ def test_url_download_aws_s3_public_bucket_native_downloader(aws_public_s3_confi
     assert len(data["data"]) == 6
     for img_bytes in data["data"]:
         assert img_bytes is not None
+
+
+@pytest.mark.integration()
+def test_url_download_aws_s3_public_bucket_native_downloader_with_connect_timeout(small_images_s3_paths):
+    data = {"urls": small_images_s3_paths}
+    df = daft.from_pydict(data)
+
+    connect_timeout_config = daft.io.IOConfig(
+        s3=daft.io.S3Config(
+            # NOTE: no keys or endpoints specified for an AWS public s3 bucket
+            region_name="us-west-2",
+            anonymous=True,
+            connect_timeout_ms=1,
+        )
+    )
+
+    with pytest.raises(ValueError, match="HTTP connect timeout"):
+        df = df.with_column(
+            "data", df["urls"].url.download(io_config=connect_timeout_config, use_native_downloader=True)
+        ).collect()
+
+
+@pytest.mark.integration()
+def test_url_download_aws_s3_public_bucket_native_downloader_with_read_timeout(small_images_s3_paths):
+    data = {"urls": small_images_s3_paths}
+    df = daft.from_pydict(data)
+
+    read_timeout_config = daft.io.IOConfig(
+        s3=daft.io.S3Config(
+            # NOTE: no keys or endpoints specified for an AWS public s3 bucket
+            region_name="us-west-2",
+            anonymous=True,
+            read_timeout_ms=1,
+        )
+    )
+
+    with pytest.raises(ValueError, match="HTTP read timeout"):
+        df = df.with_column(
+            "data", df["urls"].url.download(io_config=read_timeout_config, use_native_downloader=True)
+        ).collect()
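One behavioral note implied by the build_s3_client change above: an unsupported retry_mode is rejected with InvalidArgument only when the S3 client is actually built, i.e. at I/O time, not when S3Config is constructed. A sketch of what that could look like from Python; the object key is hypothetical and the exact Python exception type is not pinned down here, so the check matches on the message only:

import daft
import pytest

bad_config = daft.io.IOConfig(
    s3=daft.io.S3Config(region_name="us-west-2", anonymous=True, retry_mode="exponential")
)

# Hypothetical object key, shown only to force a download attempt.
df = daft.from_pydict({"urls": ["s3://daft-public-data/some-object"]})

# Constructing S3Config above does not validate retry_mode; the error surfaces
# once a download forces the S3 client to be built.
with pytest.raises(Exception, match="Invalid Retry Mode"):
    df.with_column(
        "data", df["urls"].url.download(io_config=bad_config, use_native_downloader=True)
    ).collect()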