[FEAT] Native Downloader add Retry Config parameters (#1244)
* allows users to specify the number of tries and the initial backoff duration:
```
    io_config=IOConfig(
        s3=S3Config(num_tries=10, retry_initial_backoff_ms=1000),
    )
```
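
For context, a minimal end-to-end sketch of how this config would flow into a read (the bucket path is a placeholder, not from this PR):

```python
import daft
from daft.io import IOConfig, S3Config

# Retry each request up to 10 times, starting with a 1-second backoff
# that grows between attempts.
io_config = IOConfig(s3=S3Config(num_tries=10, retry_initial_backoff_ms=1000))

# Hypothetical bucket path, for illustration only.
df = daft.read_parquet("s3://my-bucket/data.parquet", io_config=io_config)
```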
samster25 authored Aug 9, 2023
1 parent 864be06 commit 5e762f5
Showing 7 changed files with 79 additions and 17 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions src/daft-io/Cargo.toml

```diff
@@ -6,6 +6,7 @@ aws-credential-types = {version = "0.55.3", features = ["hardcoded-credentials"]}
 aws-sdk-s3 = "0.28.0"
 aws-sig-auth = "0.55.3"
 aws-sigv4 = "0.55.3"
+aws-smithy-async = "0.55.3"
 bytes = {workspace = true}
 common-error = {path = "../common/error", default-features = false}
 daft-core = {path = "../daft-core", default-features = false}
```
23 changes: 22 additions & 1 deletion src/daft-io/src/config.rs

```diff
@@ -3,16 +3,33 @@ use std::fmt::Formatter;
 
 use serde::Deserialize;
 use serde::Serialize;
-#[derive(Clone, Default, Debug, Serialize, Deserialize, PartialEq, Eq, Hash)]
+#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Hash)]
 pub struct S3Config {
     pub region_name: Option<String>,
     pub endpoint_url: Option<String>,
     pub key_id: Option<String>,
     pub session_token: Option<String>,
     pub access_key: Option<String>,
+    pub retry_initial_backoff_ms: u32,
+    pub num_tries: u32,
     pub anonymous: bool,
 }
 
+impl Default for S3Config {
+    fn default() -> Self {
+        S3Config {
+            region_name: None,
+            endpoint_url: None,
+            key_id: None,
+            session_token: None,
+            access_key: None,
+            retry_initial_backoff_ms: 1000,
+            num_tries: 5,
+            anonymous: false,
+        }
+    }
+}
+
 #[derive(Clone, Default, Debug, Serialize, Deserialize, PartialEq, Eq, Hash)]
 pub struct IOConfig {
     pub s3: S3Config,
@@ -28,12 +45,16 @@ impl Display for S3Config {
     key_id: {:?}
     session_token: {:?},
     access_key: {:?}
+    retry_initial_backoff_ms: {:?},
+    num_tries: {:?},
     anonymous: {}",
             self.region_name,
             self.endpoint_url,
             self.key_id,
             self.session_token,
             self.access_key,
+            self.retry_initial_backoff_ms,
+            self.num_tries,
             self.anonymous
         )
     }
```
41 changes: 34 additions & 7 deletions src/daft-io/src/python.rs

```diff
@@ -11,6 +11,8 @@ use pyo3::prelude::*;
 /// key_id: AWS Access Key ID, defaults to auto-detection from the current environment
 /// access_key: AWS Secret Access Key, defaults to auto-detection from the current environment
 /// session_token: AWS Session Token, required only if `key_id` and `access_key` are temporary credentials
+/// retry_initial_backoff_ms: Initial backoff duration in milliseconds for an S3 retry, defaults to 1000ms
+/// num_tries: Number of attempts to make a connection, defaults to 5
 /// anonymous: Whether or not to use "anonymous mode", which will access S3 without any credentials
 ///
 /// Example:
@@ -28,7 +30,7 @@ pub struct S3Config {
 /// s3: Configurations to use when accessing URLs with the `s3://` scheme
 ///
 /// Example:
-///     >>> io_config = IOConfig(s3=S3Config(key_id="xxx", access_key="xxx"))
+///     >>> io_config = IOConfig(s3=S3Config(key_id="xxx", access_key="xxx", num_tries=10))
 ///     >>> daft.read_parquet("s3://some-path", io_config=io_config)
 #[derive(Clone, Default)]
 #[pyclass]
@@ -73,23 +75,30 @@ impl IOConfig {
 
 #[pymethods]
 impl S3Config {
+    #[allow(clippy::too_many_arguments)]
     #[new]
     pub fn new(
         region_name: Option<String>,
         endpoint_url: Option<String>,
         key_id: Option<String>,
         session_token: Option<String>,
         access_key: Option<String>,
+        retry_initial_backoff_ms: Option<u32>,
+        num_tries: Option<u32>,
         anonymous: Option<bool>,
     ) -> Self {
+        let def = config::S3Config::default();
         S3Config {
             config: config::S3Config {
-                region_name,
-                endpoint_url,
-                key_id,
-                session_token,
-                access_key,
-                anonymous: anonymous.unwrap_or(false),
+                region_name: region_name.or(def.region_name),
+                endpoint_url: endpoint_url.or(def.endpoint_url),
+                key_id: key_id.or(def.key_id),
+                session_token: session_token.or(def.session_token),
+                access_key: access_key.or(def.access_key),
+                retry_initial_backoff_ms: retry_initial_backoff_ms
+                    .unwrap_or(def.retry_initial_backoff_ms),
+                num_tries: num_tries.unwrap_or(def.num_tries),
+                anonymous: anonymous.unwrap_or(def.anonymous),
             },
         }
     }
@@ -116,11 +125,29 @@ impl S3Config {
         Ok(self.config.key_id.clone())
     }
 
+    /// AWS Session Token
+    #[getter]
+    pub fn session_token(&self) -> PyResult<Option<String>> {
+        Ok(self.config.session_token.clone())
+    }
+
     /// AWS Secret Access Key
     #[getter]
     pub fn access_key(&self) -> PyResult<Option<String>> {
         Ok(self.config.access_key.clone())
     }
+
+    /// AWS Retry Initial Backoff Time in Milliseconds
+    #[getter]
+    pub fn retry_initial_backoff_ms(&self) -> PyResult<u32> {
+        Ok(self.config.retry_initial_backoff_ms)
+    }
+
+    /// AWS Number Retries
+    #[getter]
+    pub fn num_tries(&self) -> PyResult<u32> {
+        Ok(self.config.num_tries)
+    }
 }
 
 impl From<config::IOConfig> for IOConfig {
```
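The new getters make the retry settings round-trip from Python. A quick sketch of the exposed surface, assuming the constructor accepts these as optional keyword arguments (as the tests below do):

```python
from daft.io import S3Config

conf = S3Config(retry_initial_backoff_ms=2000, num_tries=8)
assert conf.retry_initial_backoff_ms == 2000  # new getter
assert conf.num_tries == 8                    # new getter

# Omitted fields fall back to the Rust-side defaults from config.rs.
assert S3Config().num_tries == 5
assert S3Config().retry_initial_backoff_ms == 1000
```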
22 changes: 20 additions & 2 deletions src/daft-io/src/s3_like.rs

```diff
@@ -1,9 +1,10 @@
 use async_trait::async_trait;
+use aws_smithy_async::rt::sleep::TokioSleep;
 use reqwest::StatusCode;
 use s3::operation::head_object::HeadObjectError;
 
 use crate::config::S3Config;
-use crate::SourceType;
+use crate::{InvalidArgumentSnafu, SourceType};
 use aws_config::SdkConfig;
 use aws_credential_types::cache::ProvideCachedCredentials;
 use aws_credential_types::provider::error::CredentialsError;
@@ -13,7 +14,7 @@ use s3::client::customize::Response;
 use s3::config::{Credentials, Region};
 use s3::error::SdkError;
 use s3::operation::get_object::GetObjectError;
-use snafu::{IntoError, ResultExt, Snafu};
+use snafu::{ensure, IntoError, ResultExt, Snafu};
 use url::ParseError;
 
 use super::object_io::{GetResult, ObjectSource};
@@ -25,6 +26,7 @@ use std::collections::HashMap;
 use std::ops::Range;
 use std::string::FromUtf8Error;
 use std::sync::Arc;
+use std::time::Duration;
 pub(crate) struct S3LikeSource {
     region_to_client_map: tokio::sync::RwLock<HashMap<Region, Arc<s3::Client>>>,
     default_region: Region,
@@ -140,6 +142,22 @@ async fn build_s3_client(config: &S3Config) -> super::Result<(bool, s3::Client)>
         builder
     };
 
+    ensure!(
+        config.num_tries > 0,
+        InvalidArgumentSnafu {
+            msg: "num_tries must be greater than zero"
+        }
+    );
+    let retry_config = s3::config::retry::RetryConfig::standard()
+        .with_max_attempts(config.num_tries)
+        .with_initial_backoff(Duration::from_millis(
+            config.retry_initial_backoff_ms as u64,
+        ));
+    let builder = builder.retry_config(retry_config);
+
+    let sleep_impl = Arc::new(TokioSleep::new());
+    let builder = builder.sleep_impl(sleep_impl);
+
     let builder = if config.access_key.is_some() && config.key_id.is_some() {
         let creds = Credentials::from_keys(
             config.key_id.clone().unwrap(),
```
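For intuition on what `RetryConfig::standard()` does with these two knobs: `num_tries` caps the total attempts, and `retry_initial_backoff_ms` seeds an exponentially growing wait between them. A rough sketch of the schedule, assuming plain doubling (the SDK's standard mode also applies jitter and caps the maximum backoff, so real waits differ):

```python
def backoff_schedule_ms(num_tries: int, initial_backoff_ms: int) -> list[int]:
    """Approximate waits between attempts: one initial try, then retries."""
    return [initial_backoff_ms * 2**i for i in range(num_tries - 1)]

print(backoff_schedule_ms(5, 1000))  # defaults: [1000, 2000, 4000, 8000]
print(backoff_schedule_ms(3, 500))   # [500, 1000]
```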
2 changes: 1 addition & 1 deletion tests/integration/io/conftest.py

```diff
@@ -58,7 +58,7 @@ def nginx_config() -> tuple[str, pathlib.Path]:
 def retry_server_s3_config() -> daft.io.IOConfig:
     """Returns the URL to the local retry_server fixture"""
     return daft.io.IOConfig(
-        s3=daft.io.S3Config(endpoint_url="http://127.0.0.1:8001"),
+        s3=daft.io.S3Config(endpoint_url="http://127.0.0.1:8001", anonymous=True, num_tries=10),
    )
 
 
```
6 changes: 0 additions & 6 deletions (file path not rendered in this view)

```diff
@@ -6,12 +6,6 @@
 
 
 @pytest.mark.integration()
-@pytest.mark.skip(
-    reason="""[IO-RETRIES] This currently fails: we need better retry policies to have this work consistently.
-    Currently, if all the retries for a given URL happens to land in the same 1-second window, the request fails.
-    We should be able to get around this with a more generous retry policy, with larger increments between backoffs.
-    """
-)
 def test_url_download_local_retry_server(retry_server_s3_config):
     bucket = "80-per-second-rate-limited-gets-bucket"
     data = {"urls": [f"s3://{bucket}/foo{i}" for i in range(100)]}
```
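Un-skipping this test lines up with the retry changes: the old failure mode was every retry landing in the same 1-second rate-limit window. Under the doubling assumption sketched above, with the fixture's `num_tries=10` and the default 1-second initial backoff, only the first retry can share a window with the original attempt:

```python
# Approximate gaps between successive attempts (10 tries, 1 s initial backoff).
waits_s = [2**i for i in range(9)]  # [1, 2, 4, ..., 256]
# From the second retry onward, attempts are at least 2 s apart, so they can
# no longer all collapse into a single 1-second rate-limit window.
print(waits_s)
```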
