From 5e762f5199777192727a0f8dc5d42135b4205c55 Mon Sep 17 00:00:00 2001
From: Sammy Sidhu
Date: Tue, 8 Aug 2023 17:24:24 -0700
Subject: [PATCH] [FEAT] Native Downloader add Retry Config parameters (#1244)

* allows users to specify the number of tries and the initial backoff duration

```
io_config=IOConfig(
    s3=S3Config(num_tries=10, retry_initial_backoff_ms=1000),
)
```
---
 Cargo.lock                                        |  1 +
 src/daft-io/Cargo.toml                            |  1 +
 src/daft-io/src/config.rs                         | 23 ++++++++++-
 src/daft-io/src/python.rs                         | 41 +++++++++++++++----
 src/daft-io/src/s3_like.rs                        | 22 +++++++++-
 ...test_url_download_s3_local_retry_server.py    |  6 ---
 tests/integration/io/conftest.py                  |  2 +-
 7 files changed, 79 insertions(+), 17 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 8eb0d88695..ea3ebe3858 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -926,6 +926,7 @@ dependencies = [
  "aws-sdk-s3",
  "aws-sig-auth",
  "aws-sigv4",
+ "aws-smithy-async",
  "bytes",
  "common-error",
  "daft-core",
diff --git a/src/daft-io/Cargo.toml b/src/daft-io/Cargo.toml
index 8e85a1a465..aae42175d7 100644
--- a/src/daft-io/Cargo.toml
+++ b/src/daft-io/Cargo.toml
@@ -6,6 +6,7 @@ aws-credential-types = {version = "0.55.3", features = ["hardcoded-credentials"]}
 aws-sdk-s3 = "0.28.0"
 aws-sig-auth = "0.55.3"
 aws-sigv4 = "0.55.3"
+aws-smithy-async = "0.55.3"
 bytes = {workspace = true}
 common-error = {path = "../common/error", default-features = false}
 daft-core = {path = "../daft-core", default-features = false}
diff --git a/src/daft-io/src/config.rs b/src/daft-io/src/config.rs
index f5529a4165..430c430529 100644
--- a/src/daft-io/src/config.rs
+++ b/src/daft-io/src/config.rs
@@ -3,16 +3,33 @@ use std::fmt::Formatter;
 use serde::Deserialize;
 use serde::Serialize;
 
-#[derive(Clone, Default, Debug, Serialize, Deserialize, PartialEq, Eq, Hash)]
+#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Hash)]
 pub struct S3Config {
     pub region_name: Option<String>,
     pub endpoint_url: Option<String>,
     pub key_id: Option<String>,
     pub session_token: Option<String>,
     pub access_key: Option<String>,
+    pub retry_initial_backoff_ms: u32,
+    pub num_tries: u32,
     pub anonymous: bool,
 }
 
+impl Default for S3Config {
+    fn default() -> Self {
+        S3Config {
+            region_name: None,
+            endpoint_url: None,
+            key_id: None,
+            session_token: None,
+            access_key: None,
+            retry_initial_backoff_ms: 1000,
+            num_tries: 5,
+            anonymous: false,
+        }
+    }
+}
+
 #[derive(Clone, Default, Debug, Serialize, Deserialize, PartialEq, Eq, Hash)]
 pub struct IOConfig {
     pub s3: S3Config,
@@ -28,12 +45,16 @@ impl Display for S3Config {
             key_id: {:?}
             session_token: {:?},
             access_key: {:?}
+            retry_initial_backoff_ms: {:?},
+            num_tries: {:?},
             anonymous: {}",
             self.region_name,
             self.endpoint_url,
             self.key_id,
             self.session_token,
             self.access_key,
+            self.retry_initial_backoff_ms,
+            self.num_tries,
             self.anonymous
         )
     }
diff --git a/src/daft-io/src/python.rs b/src/daft-io/src/python.rs
index 37e7bff99a..aa7f3dc871 100644
--- a/src/daft-io/src/python.rs
+++ b/src/daft-io/src/python.rs
@@ -11,6 +11,8 @@ use pyo3::prelude::*;
 /// key_id: AWS Access Key ID, defaults to auto-detection from the current environment
 /// access_key: AWS Secret Access Key, defaults to auto-detection from the current environment
 /// session_token: AWS Session Token, required only if `key_id` and `access_key` are temporary credentials
+/// retry_initial_backoff_ms: Initial backoff duration in milliseconds for an S3 retry, defaults to 1000ms
+/// num_tries: Number of attempts to make a connection, defaults to 5
 /// anonymous: Whether or not to use "anonymous mode", which will access S3 without any credentials
 ///
 /// Example:
@@ -28,7 +30,7 @@ pub struct S3Config {
 /// s3: Configurations to use when accessing URLs with the `s3://` scheme
 ///
 /// Example:
-/// >>> io_config = IOConfig(s3=S3Config(key_id="xxx", access_key="xxx"))
+/// >>> io_config = IOConfig(s3=S3Config(key_id="xxx", access_key="xxx", num_tries=10))
 /// >>> daft.read_parquet("s3://some-path", io_config=io_config)
 #[derive(Clone, Default)]
 #[pyclass]
@@ -73,6 +75,7 @@ impl IOConfig {
 
 #[pymethods]
 impl S3Config {
+    #[allow(clippy::too_many_arguments)]
     #[new]
     pub fn new(
         region_name: Option<String>,
@@ -80,16 +83,22 @@ impl S3Config {
         key_id: Option<String>,
         session_token: Option<String>,
         access_key: Option<String>,
+        retry_initial_backoff_ms: Option<u32>,
+        num_tries: Option<u32>,
         anonymous: Option<bool>,
     ) -> Self {
+        let def = config::S3Config::default();
         S3Config {
             config: config::S3Config {
-                region_name,
-                endpoint_url,
-                key_id,
-                session_token,
-                access_key,
-                anonymous: anonymous.unwrap_or(false),
+                region_name: region_name.or(def.region_name),
+                endpoint_url: endpoint_url.or(def.endpoint_url),
+                key_id: key_id.or(def.key_id),
+                session_token: session_token.or(def.session_token),
+                access_key: access_key.or(def.access_key),
+                retry_initial_backoff_ms: retry_initial_backoff_ms
+                    .unwrap_or(def.retry_initial_backoff_ms),
+                num_tries: num_tries.unwrap_or(def.num_tries),
+                anonymous: anonymous.unwrap_or(def.anonymous),
             },
         }
     }
@@ -116,11 +125,29 @@ impl S3Config {
         Ok(self.config.key_id.clone())
     }
 
+    /// AWS Session Token
+    #[getter]
+    pub fn session_token(&self) -> PyResult<Option<String>> {
+        Ok(self.config.session_token.clone())
+    }
+
     /// AWS Secret Access Key
     #[getter]
     pub fn access_key(&self) -> PyResult<Option<String>> {
         Ok(self.config.access_key.clone())
     }
+
+    /// AWS Retry Initial Backoff Time in Milliseconds
+    #[getter]
+    pub fn retry_initial_backoff_ms(&self) -> PyResult<u32> {
+        Ok(self.config.retry_initial_backoff_ms)
+    }
+
+    /// AWS Number of Retries
+    #[getter]
+    pub fn num_tries(&self) -> PyResult<u32> {
+        Ok(self.config.num_tries)
+    }
 }
 
 impl From<config::IOConfig> for IOConfig {
diff --git a/src/daft-io/src/s3_like.rs b/src/daft-io/src/s3_like.rs
index df58031555..ed3da3b361 100644
--- a/src/daft-io/src/s3_like.rs
+++ b/src/daft-io/src/s3_like.rs
@@ -1,9 +1,10 @@
 use async_trait::async_trait;
+use aws_smithy_async::rt::sleep::TokioSleep;
 use reqwest::StatusCode;
 use s3::operation::head_object::HeadObjectError;
 
 use crate::config::S3Config;
-use crate::SourceType;
+use crate::{InvalidArgumentSnafu, SourceType};
 use aws_config::SdkConfig;
 use aws_credential_types::cache::ProvideCachedCredentials;
 use aws_credential_types::provider::error::CredentialsError;
@@ -13,7 +14,7 @@ use s3::client::customize::Response;
 use s3::config::{Credentials, Region};
 use s3::error::SdkError;
 use s3::operation::get_object::GetObjectError;
-use snafu::{IntoError, ResultExt, Snafu};
+use snafu::{ensure, IntoError, ResultExt, Snafu};
 use url::ParseError;
 
 use super::object_io::{GetResult, ObjectSource};
@@ -25,6 +26,7 @@ use std::collections::HashMap;
 use std::ops::Range;
 use std::string::FromUtf8Error;
 use std::sync::Arc;
+use std::time::Duration;
 
 pub(crate) struct S3LikeSource {
     region_to_client_map: tokio::sync::RwLock<HashMap<Region, Arc<s3::Client>>>,
     default_region: Region,
@@ -140,6 +142,22 @@ async fn build_s3_client(config: &S3Config) -> super::Result<(bool, s3::Client)>
         builder
     };
 
+    ensure!(
+        config.num_tries > 0,
+        InvalidArgumentSnafu {
+            msg: "num_tries must be greater than zero"
+        }
+    );
+    let retry_config = s3::config::retry::RetryConfig::standard()
+        .with_max_attempts(config.num_tries)
+        .with_initial_backoff(Duration::from_millis(
+            config.retry_initial_backoff_ms as u64,
+        ));
+    let builder = builder.retry_config(retry_config);
+
+    let sleep_impl = Arc::new(TokioSleep::new());
+    let builder = builder.sleep_impl(sleep_impl);
+
     let builder = if config.access_key.is_some() && config.key_id.is_some() {
         let creds = Credentials::from_keys(
             config.key_id.clone().unwrap(),
diff --git a/tests/integration/io/conftest.py b/tests/integration/io/conftest.py
index bd3ba50518..e8ee3b142b 100644
--- a/tests/integration/io/conftest.py
+++ b/tests/integration/io/conftest.py
@@ -58,7 +58,7 @@ def nginx_config() -> tuple[str, pathlib.Path]:
 def retry_server_s3_config() -> daft.io.IOConfig:
     """Returns the URL to the local retry_server fixture"""
     return daft.io.IOConfig(
-        s3=daft.io.S3Config(endpoint_url="http://127.0.0.1:8001"),
+        s3=daft.io.S3Config(endpoint_url="http://127.0.0.1:8001", anonymous=True, num_tries=10),
     )
diff --git a/tests/integration/io/test_url_download_s3_local_retry_server.py b/tests/integration/io/test_url_download_s3_local_retry_server.py
index 0cc8ea593f..686e1caf79 100644
--- a/tests/integration/io/test_url_download_s3_local_retry_server.py
+++ b/tests/integration/io/test_url_download_s3_local_retry_server.py
@@ -6,12 +6,6 @@
 
 @pytest.mark.integration()
-@pytest.mark.skip(
-    reason="""[IO-RETRIES] This currently fails: we need better retry policies to have this work consistently.
-    Currently, if all the retries for a given URL happens to land in the same 1-second window, the request fails.
-    We should be able to get around this with a more generous retry policy, with larger increments between backoffs.
-    """
-)
 def test_url_download_local_retry_server(retry_server_s3_config):
     bucket = "80-per-second-rate-limited-gets-bucket"
     data = {"urls": [f"s3://{bucket}/foo{i}" for i in range(100)]}
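
For reference, a minimal usage sketch of the new retry parameters, assembled from the docstring example and defaults in this patch; the `s3://some-path` URL is a placeholder:

```
import daft
from daft.io import IOConfig, S3Config

# Retry failed S3 requests up to 10 times, starting from a 1000ms backoff
# (the defaults introduced here are num_tries=5, retry_initial_backoff_ms=1000).
io_config = IOConfig(
    s3=S3Config(num_tries=10, retry_initial_backoff_ms=1000),
)

# Placeholder path; any S3-backed read accepts the io_config.
df = daft.read_parquet("s3://some-path", io_config=io_config)
```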