[FEAT] Add Retry Mode, connection timeout, and read timeout to S3Config (#1293)

* Adds retry mode (supports standard or adaptive retries)
* Adds connection timeout (now defaults to 60 seconds)
* Adds read timeout (now defaults to 60 seconds)
samster25 committed Aug 24, 2023
1 parent 1c1abfb commit 32084d8
Showing 7 changed files with 124 additions and 15 deletions.
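Taken together, the new knobs surface in the Python-facing `S3Config`. A minimal sketch of how they might be combined; the parameter names come from this diff, while the region and bucket path are hypothetical placeholders:

```python
import daft

io_config = daft.io.IOConfig(
    s3=daft.io.S3Config(
        region_name="us-west-2",        # hypothetical region
        anonymous=True,
        num_tries=5,                    # connection attempts (default: 5)
        retry_mode="adaptive",          # "standard" (default) or "adaptive"
        retry_initial_backoff_ms=1000,  # initial retry backoff (default: 1000ms)
        connect_timeout_ms=30_000,      # new: connect timeout (default: 60s)
        read_timeout_ms=30_000,         # new: time-to-first-byte timeout (default: 60s)
    )
)

# Hypothetical S3 path, used only to show where the IOConfig plugs in.
df = daft.from_pydict({"urls": ["s3://example-bucket/image.jpg"]})
df = df.with_column(
    "data", df["urls"].url.download(io_config=io_config, use_native_downloader=True)
)
```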
16 changes: 14 additions & 2 deletions src/daft-io/src/config.rs
@@ -10,8 +10,11 @@ pub struct S3Config {
     pub key_id: Option<String>,
     pub session_token: Option<String>,
     pub access_key: Option<String>,
-    pub retry_initial_backoff_ms: u32,
+    pub retry_initial_backoff_ms: u64,
+    pub connect_timeout_ms: u64,
+    pub read_timeout_ms: u64,
     pub num_tries: u32,
+    pub retry_mode: Option<String>,
     pub anonymous: bool,
 }

@@ -24,7 +27,10 @@ impl Default for S3Config {
             session_token: None,
             access_key: None,
             retry_initial_backoff_ms: 1000,
+            connect_timeout_ms: 60_000,
+            read_timeout_ms: 60_000,
             num_tries: 5,
+            retry_mode: Some("standard".to_string()),
             anonymous: false,
         }
     }
@@ -40,16 +46,22 @@ impl Display for S3Config {
     key_id: {:?}
     session_token: {:?},
     access_key: {:?}
-    retry_initial_backoff_ms: {:?},
+    retry_initial_backoff_ms: {},
+    connect_timeout_ms: {},
+    read_timeout_ms: {},
     num_tries: {:?},
+    retry_mode: {:?},
     anonymous: {}",
             self.region_name,
             self.endpoint_url,
             self.key_id,
             self.session_token,
             self.access_key,
             self.retry_initial_backoff_ms,
+            self.connect_timeout_ms,
+            self.read_timeout_ms,
             self.num_tries,
+            self.retry_mode,
             self.anonymous
         )
     }
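The new defaults above are observable from Python; a quick sanity-check sketch (this assumes the getters added to `python.rs` later in this commit):

```python
import daft

cfg = daft.io.S3Config()  # every field falls back to S3Config::default()
assert cfg.retry_initial_backoff_ms == 1000
assert cfg.connect_timeout_ms == 60_000
assert cfg.read_timeout_ms == 60_000
assert cfg.num_tries == 5
assert cfg.retry_mode == "standard"
```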
5 changes: 4 additions & 1 deletion src/daft-io/src/lib.rs
@@ -49,7 +49,7 @@ pub enum Error {
     #[snafu(display("Invalid Argument: {:?}", msg))]
     InvalidArgument { msg: String },
 
-    #[snafu(display("Unable to open file {}: {}", path, source))]
+    #[snafu(display("Unable to open file {}: {:?}", path, source))]
     UnableToOpenFile { path: String, source: DynError },
 
     #[snafu(display("Unable to read data from file {}: {}", path, source))]
@@ -83,6 +83,9 @@ pub enum Error {
     #[snafu(display("Source not yet implemented: {}", store))]
     NotImplementedSource { store: String },
 
+    #[snafu(display("Unhandled Error for path: {}\nDetails:\n{}", path, msg))]
+    Unhandled { path: String, msg: String },
+
     #[snafu(display("Error joining spawned task: {}", source), context(false))]
     JoinError { source: tokio::task::JoinError },
 }
31 changes: 29 additions & 2 deletions src/daft-io/src/python.rs
@@ -12,7 +12,10 @@ use pyo3::prelude::*;
 /// access_key: AWS Secret Access Key, defaults to auto-detection from the current environment
 /// session_token: AWS Session Token, required only if `key_id` and `access_key` are temporary credentials
 /// retry_initial_backoff_ms: Initial backoff duration in milliseconds for an S3 retry, defaults to 1000ms
+/// connect_timeout_ms: Timeout duration to wait to make a connection to S3 in milliseconds, defaults to 60 seconds
+/// read_timeout_ms: Timeout duration to wait to read the first byte from S3 in milliseconds, defaults to 60 seconds
 /// num_tries: Number of attempts to make a connection, defaults to 5
+/// retry_mode: Retry Mode when a request fails, current supported values are `standard` and `adaptive`
 /// anonymous: Whether or not to use "anonymous mode", which will access S3 without any credentials
 ///
 /// Example:
@@ -138,8 +141,11 @@ impl S3Config {
         key_id: Option<String>,
         session_token: Option<String>,
         access_key: Option<String>,
-        retry_initial_backoff_ms: Option<u32>,
+        retry_initial_backoff_ms: Option<u64>,
+        connect_timeout_ms: Option<u64>,
+        read_timeout_ms: Option<u64>,
         num_tries: Option<u32>,
+        retry_mode: Option<String>,
         anonymous: Option<bool>,
     ) -> Self {
         let def = config::S3Config::default();
@@ -152,7 +158,10 @@ impl S3Config {
             access_key: access_key.or(def.access_key),
             retry_initial_backoff_ms: retry_initial_backoff_ms
                 .unwrap_or(def.retry_initial_backoff_ms),
+            connect_timeout_ms: connect_timeout_ms.unwrap_or(def.connect_timeout_ms),
+            read_timeout_ms: read_timeout_ms.unwrap_or(def.read_timeout_ms),
             num_tries: num_tries.unwrap_or(def.num_tries),
+            retry_mode: retry_mode.or(def.retry_mode),
             anonymous: anonymous.unwrap_or(def.anonymous),
         },
     }
@@ -194,15 +203,33 @@ impl S3Config {
 
     /// AWS Retry Initial Backoff Time in Milliseconds
     #[getter]
-    pub fn retry_initial_backoff_ms(&self) -> PyResult<u32> {
+    pub fn retry_initial_backoff_ms(&self) -> PyResult<u64> {
         Ok(self.config.retry_initial_backoff_ms)
     }
 
+    /// AWS Connection Timeout in Milliseconds
+    #[getter]
+    pub fn connect_timeout_ms(&self) -> PyResult<u64> {
+        Ok(self.config.connect_timeout_ms)
+    }
+
+    /// AWS Read Timeout in Milliseconds
+    #[getter]
+    pub fn read_timeout_ms(&self) -> PyResult<u64> {
+        Ok(self.config.read_timeout_ms)
+    }
+
     /// AWS Number Retries
     #[getter]
     pub fn num_tries(&self) -> PyResult<u32> {
         Ok(self.config.num_tries)
     }
+
+    /// AWS Retry Mode
+    #[getter]
+    pub fn retry_mode(&self) -> PyResult<Option<String>> {
+        Ok(self.config.retry_mode.clone())
+    }
 }
 
 #[pymethods]
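The constructor and getters round-trip explicit values and fall back to the Rust defaults for anything omitted; a small sketch:

```python
import daft

cfg = daft.io.S3Config(retry_mode="adaptive", connect_timeout_ms=30_000)
assert cfg.retry_mode == "adaptive"      # set explicitly
assert cfg.connect_timeout_ms == 30_000  # set explicitly
assert cfg.read_timeout_ms == 60_000     # omitted, falls back to the default
```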
34 changes: 30 additions & 4 deletions src/daft-io/src/s3_like.rs
@@ -1,4 +1,6 @@
 use async_trait::async_trait;
+use aws_config::retry::RetryMode;
+use aws_config::timeout::TimeoutConfig;
 use aws_smithy_async::rt::sleep::TokioSleep;
 use reqwest::StatusCode;
 use s3::operation::head_object::HeadObjectError;
@@ -12,7 +14,7 @@ use aws_sig_auth::signer::SigningRequirements;
 use futures::{StreamExt, TryStreamExt};
 use s3::client::customize::Response;
 use s3::config::{Credentials, Region};
-use s3::error::SdkError;
+use s3::error::{DisplayErrorContext, SdkError};
 use s3::operation::get_object::GetObjectError;
 use snafu::{ensure, IntoError, ResultExt, Snafu};
 use url::ParseError;
@@ -92,6 +94,10 @@ impl From<Error> for super::Error {
                     path,
                     source: no_such_key.into(),
                 },
+                GetObjectError::Unhandled(v) => super::Error::Unhandled {
+                    path,
+                    msg: DisplayErrorContext(v).to_string(),
+                },
                 err => super::Error::UnableToOpenFile {
                     path,
                     source: err.into(),
@@ -102,6 +108,10 @@ impl From<Error> for super::Error {
                     path,
                     source: no_such_key.into(),
                 },
+                HeadObjectError::Unhandled(v) => super::Error::Unhandled {
+                    path,
+                    msg: DisplayErrorContext(v).to_string(),
+                },
                 err => super::Error::UnableToOpenFile {
                     path,
                     source: err.into(),
Expand Down Expand Up @@ -160,13 +170,29 @@ async fn build_s3_client(config: &S3Config) -> super::Result<(bool, s3::Client)>
);
let retry_config = s3::config::retry::RetryConfig::standard()
.with_max_attempts(config.num_tries)
.with_initial_backoff(Duration::from_millis(
config.retry_initial_backoff_ms as u64,
));
.with_initial_backoff(Duration::from_millis(config.retry_initial_backoff_ms));

let retry_config = if let Some(retry_mode) = &config.retry_mode {
if retry_mode.trim().eq_ignore_ascii_case("adaptive") {
retry_config.with_retry_mode(RetryMode::Adaptive)
} else if retry_mode.trim().eq_ignore_ascii_case("standard") {
retry_config
} else {
return Err(crate::Error::InvalidArgument { msg: format!("Invalid Retry Mode, Daft S3 client currently only supports standard and adaptive, got {}", retry_mode) });
}
} else {
retry_config
};

let builder = builder.retry_config(retry_config);

let sleep_impl = Arc::new(TokioSleep::new());
let builder = builder.sleep_impl(sleep_impl);
let timeout_config = TimeoutConfig::builder()
.connect_timeout(Duration::from_millis(config.connect_timeout_ms))
.read_timeout(Duration::from_millis(config.read_timeout_ms))
.build();
let builder = builder.timeout_config(timeout_config);

let builder = if config.access_key.is_some() && config.key_id.is_some() {
let creds = Credentials::from_keys(
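Note that the `retry_mode` string is validated only here, when the S3 client is actually built, not when `S3Config` is constructed. A hedged sketch of the expected failure mode, assuming (as the timeout tests below do for their errors) that the Rust-side `InvalidArgument` error surfaces in Python as a `ValueError`:

```python
import daft
import pytest

bad_config = daft.io.IOConfig(
    s3=daft.io.S3Config(anonymous=True, retry_mode="exponential")  # unsupported mode
)
df = daft.from_pydict({"urls": ["s3://example-bucket/key"]})  # hypothetical path

# The invalid mode is rejected lazily, on first use of the client.
with pytest.raises(ValueError, match="Invalid Retry Mode"):
    df.with_column(
        "data", df["urls"].url.download(io_config=bad_config, use_native_downloader=True)
    ).collect()
```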
6 changes: 3 additions & 3 deletions src/daft-plan/src/source_info.rs
@@ -249,7 +249,7 @@ impl FileFormatConfig {
 #[cfg_attr(feature = "python", pyclass(module = "daft.daft"))]
 pub struct ParquetSourceConfig {
     pub use_native_downloader: bool,
-    pub io_config: Option<IOConfig>,
+    pub io_config: Box<Option<IOConfig>>,
 }
 
 #[cfg(feature = "python")]
@@ -259,7 +259,7 @@ impl ParquetSourceConfig {
     fn new(use_native_downloader: bool, io_config: Option<PyIOConfig>) -> Self {
         Self {
             use_native_downloader,
-            io_config: io_config.map(|c| c.config),
+            io_config: io_config.map(|c| c.config).into(),
         }
     }
 
@@ -270,7 +270,7 @@ impl ParquetSourceConfig {
 
     #[getter]
     fn get_io_config(&self) -> PyResult<Option<PyIOConfig>> {
-        Ok(self.io_config.as_ref().map(|c| c.clone().into()))
+        Ok(self.io_config.clone().map(|c| c.into()))
     }
 }
 
7 changes: 4 additions & 3 deletions tests/integration/io/conftest.py
@@ -64,11 +64,12 @@ def nginx_config() -> tuple[str, pathlib.Path]:
     )
 
 
-@pytest.fixture(scope="session")
-def retry_server_s3_config() -> daft.io.IOConfig:
+@pytest.fixture(scope="session", params=["standard", "adaptive"], ids=["standard", "adaptive"])
+def retry_server_s3_config(request) -> daft.io.IOConfig:
     """Returns the URL to the local retry_server fixture"""
+    retry_mode = request.param
     return daft.io.IOConfig(
-        s3=daft.io.S3Config(endpoint_url="http://127.0.0.1:8001", anonymous=True, num_tries=10),
+        s3=daft.io.S3Config(endpoint_url="http://127.0.0.1:8001", anonymous=True, num_tries=10, retry_mode=retry_mode)
     )


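With the fixture parametrized over both modes, every test that consumes it now runs twice, once per retry mode; conceptually (hypothetical test name):

```python
@pytest.mark.integration()
def test_retry_server(retry_server_s3_config):
    # Executed once with retry_mode="standard" and once with retry_mode="adaptive",
    # both pointed at the local retry server on http://127.0.0.1:8001.
    ...
```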
40 changes: 40 additions & 0 deletions tests/integration/io/test_url_download_public_aws_s3.py
@@ -48,3 +48,43 @@ def test_url_download_aws_s3_public_bucket_native_downloader(aws_public_s3_config):
     assert len(data["data"]) == 6
     for img_bytes in data["data"]:
         assert img_bytes is not None
+
+
+@pytest.mark.integration()
+def test_url_download_aws_s3_public_bucket_native_downloader_with_connect_timeout(small_images_s3_paths):
+    data = {"urls": small_images_s3_paths}
+    df = daft.from_pydict(data)
+
+    connect_timeout_config = daft.io.IOConfig(
+        s3=daft.io.S3Config(
+            # NOTE: no keys or endpoints specified for an AWS public s3 bucket
+            region_name="us-west-2",
+            anonymous=True,
+            connect_timeout_ms=1,
+        )
+    )
+
+    with pytest.raises(ValueError, match="HTTP connect timeout"):
+        df = df.with_column(
+            "data", df["urls"].url.download(io_config=connect_timeout_config, use_native_downloader=True)
+        ).collect()
+
+
+@pytest.mark.integration()
+def test_url_download_aws_s3_public_bucket_native_downloader_with_read_timeout(small_images_s3_paths):
+    data = {"urls": small_images_s3_paths}
+    df = daft.from_pydict(data)
+
+    read_timeout_config = daft.io.IOConfig(
+        s3=daft.io.S3Config(
+            # NOTE: no keys or endpoints specified for an AWS public s3 bucket
+            region_name="us-west-2",
+            anonymous=True,
+            read_timeout_ms=1,
+        )
+    )
+
+    with pytest.raises(ValueError, match="HTTP read timeout"):
+        df = df.with_column(
+            "data", df["urls"].url.download(io_config=read_timeout_config, use_native_downloader=True)
+        ).collect()
