Skip to content

Commit

Permalink
[FEAT] Add flag to limit number of connections to S3 (#1360)
Browse files Browse the repository at this point in the history
* Adds a flag to S3Config that limits the number of connections to S3,
defaulting to 25 (as recommended in the
[aws-cpp-sdk](https://github.com/aws/aws-sdk-cpp/blob/c859aa7a1d85c5137f0416e4aa75b86ceb32e216/docs/ClientConfiguration_Parameters.md?plain=1#L46))
  • Loading branch information
samster25 committed Sep 10, 2023
1 parent 5d1e2df commit c8ffd03
Show file tree
Hide file tree
Showing 7 changed files with 77 additions and 14 deletions.
9 changes: 9 additions & 0 deletions src/common/io-config/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use crate::config;
/// endpoint_url: URL of the S3 endpoint, defaults to the standard AWS endpoints
/// key_id: AWS Access Key ID, defaults to auto-detection from the current environment
/// access_key: AWS Secret Access Key, defaults to auto-detection from the current environment
/// max_connections: Maximum number of connections to S3 at any time, defaults to 25
/// session_token: AWS Session Token, required only if `key_id` and `access_key` are temporary credentials
/// retry_initial_backoff_ms: Initial backoff duration in milliseconds for an S3 retry, defaults to 1000ms
/// connect_timeout_ms: Timeout duration to wait to make a connection to S3 in milliseconds, defaults to 60 seconds
Expand Down Expand Up @@ -142,6 +143,7 @@ impl S3Config {
key_id: Option<String>,
session_token: Option<String>,
access_key: Option<String>,
max_connections: Option<u32>,
retry_initial_backoff_ms: Option<u64>,
connect_timeout_ms: Option<u64>,
read_timeout_ms: Option<u64>,
Expand All @@ -157,6 +159,7 @@ impl S3Config {
key_id: key_id.or(def.key_id),
session_token: session_token.or(def.session_token),
access_key: access_key.or(def.access_key),
max_connections: max_connections.unwrap_or(def.max_connections),
retry_initial_backoff_ms: retry_initial_backoff_ms
.unwrap_or(def.retry_initial_backoff_ms),
connect_timeout_ms: connect_timeout_ms.unwrap_or(def.connect_timeout_ms),
Expand Down Expand Up @@ -202,6 +205,12 @@ impl S3Config {
Ok(self.config.access_key.clone())
}

/// Maximum number of simultaneous connections to S3 permitted for this
/// config (pool size of the S3 client's connection semaphore; the Rust-side
/// default is 25).
#[getter]
pub fn max_connections(&self) -> PyResult<u32> {
Ok(self.config.max_connections)
}

/// AWS Retry Initial Backoff Time in Milliseconds
#[getter]
pub fn retry_initial_backoff_ms(&self) -> PyResult<u64> {
Expand Down
4 changes: 4 additions & 0 deletions src/common/io-config/src/s3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ pub struct S3Config {
pub key_id: Option<String>,
pub session_token: Option<String>,
pub access_key: Option<String>,
pub max_connections: u32,
pub retry_initial_backoff_ms: u64,
pub connect_timeout_ms: u64,
pub read_timeout_ms: u64,
Expand All @@ -27,6 +28,7 @@ impl Default for S3Config {
key_id: None,
session_token: None,
access_key: None,
max_connections: 25,
retry_initial_backoff_ms: 1000,
connect_timeout_ms: 60_000,
read_timeout_ms: 60_000,
Expand All @@ -47,6 +49,7 @@ impl Display for S3Config {
key_id: {:?}
session_token: {:?},
access_key: {:?}
max_connections: {},
retry_initial_backoff_ms: {},
connect_timeout_ms: {},
read_timeout_ms: {},
Expand All @@ -59,6 +62,7 @@ impl Display for S3Config {
self.session_token,
self.access_key,
self.retry_initial_backoff_ms,
self.max_connections,
self.connect_timeout_ms,
self.read_timeout_ms,
self.num_tries,
Expand Down
2 changes: 1 addition & 1 deletion src/daft-io/src/azure_blob.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ impl ObjectSource for AzureBlobSource {
.into_error(e)
.into()
});
Ok(GetResult::Stream(stream.boxed(), None))
Ok(GetResult::Stream(stream.boxed(), None, None))
}

async fn get_size(&self, uri: &str) -> super::Result<usize> {
Expand Down
2 changes: 1 addition & 1 deletion src/daft-io/src/google_cloud.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ impl GCSClientWrapper {
.into_error(e)
.into()
});
Ok(GetResult::Stream(response.boxed(), size))
Ok(GetResult::Stream(response.boxed(), size, None))
}
GCSClientWrapper::S3Compat(client) => {
let uri = format!("s3://{}/{}", bucket, key);
Expand Down
2 changes: 1 addition & 1 deletion src/daft-io/src/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ impl ObjectSource for HttpSource {
.into_error(e)
.into()
});
Ok(GetResult::Stream(stream.boxed(), size_bytes))
Ok(GetResult::Stream(stream.boxed(), size_bytes, None))
}

async fn get_size(&self, uri: &str) -> super::Result<usize> {
Expand Down
9 changes: 7 additions & 2 deletions src/daft-io/src/object_io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,17 @@ use async_trait::async_trait;
use bytes::Bytes;
use futures::stream::{BoxStream, Stream};
use futures::StreamExt;
use tokio::sync::OwnedSemaphorePermit;

use crate::local::{collect_file, LocalFile};

/// Result of a `get` request against an object source.
pub enum GetResult {
/// Contents are available from a local file.
File(LocalFile),
/// Contents arrive as a stream of byte chunks, with an optional total-size
/// hint and an optional connection-pool permit. The permit (used by the S3
/// source to cap concurrent connections) stays held as long as this variant
/// is alive; dropping the `GetResult` releases the pool slot.
Stream(
BoxStream<'static, super::Result<Bytes>>,
Option<usize>,
Option<OwnedSemaphorePermit>,
),
}

async fn collect_bytes<S>(mut stream: S, size_hint: Option<usize>) -> super::Result<Bytes>
Expand Down Expand Up @@ -40,7 +45,7 @@ impl GetResult {
use GetResult::*;
match self {
File(f) => collect_file(f).await,
Stream(stream, size) => collect_bytes(stream, size).await,
Stream(stream, size, _permit) => collect_bytes(stream, size).await,
}
}
}
Expand Down
63 changes: 54 additions & 9 deletions src/daft-io/src/s3_like.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use aws_smithy_async::rt::sleep::TokioSleep;
use reqwest::StatusCode;
use s3::operation::head_object::HeadObjectError;
use s3::operation::list_objects_v2::ListObjectsV2Error;
use tokio::sync::{OwnedSemaphorePermit, SemaphorePermit};

use crate::object_io::{FileMetadata, FileType, LSResult};
use crate::{InvalidArgumentSnafu, SourceType};
Expand Down Expand Up @@ -34,6 +35,7 @@ use std::sync::Arc;
use std::time::Duration;
pub(crate) struct S3LikeSource {
region_to_client_map: tokio::sync::RwLock<HashMap<Region, Arc<s3::Client>>>,
connection_pool_sema: Arc<tokio::sync::Semaphore>,
default_region: Region,
s3_config: S3Config,
anonymous: bool,
Expand Down Expand Up @@ -91,6 +93,9 @@ enum Error {
#[snafu(display("Unable to create http client. {}", source))]
UnableToCreateClient { source: reqwest::Error },

#[snafu(display("Unable to grab semaphore. {}", source))]
UnableToGrabSemaphore { source: tokio::sync::AcquireError },

#[snafu(display(
"Unable to parse data as Utf8 while reading header for file: {path}. {source}"
))]
Expand Down Expand Up @@ -252,6 +257,9 @@ async fn build_client(config: &S3Config) -> super::Result<S3LikeSource> {
client_map.insert(default_region.clone(), client.into());
Ok(S3LikeSource {
region_to_client_map: tokio::sync::RwLock::new(client_map),
connection_pool_sema: Arc::new(tokio::sync::Semaphore::new(
config.max_connections as usize,
)),
s3_config: config.clone(),
default_region,
anonymous,
Expand Down Expand Up @@ -290,6 +298,7 @@ impl S3LikeSource {
#[async_recursion]
async fn _get_impl(
&self,
permit: OwnedSemaphorePermit,
uri: &str,
range: Option<Range<usize>>,
region: &Region,
Expand Down Expand Up @@ -357,7 +366,11 @@ impl S3LikeSource {
.into()
})
.boxed();
Ok(GetResult::Stream(stream, Some(v.content_length as usize)))
Ok(GetResult::Stream(
stream,
Some(v.content_length as usize),
Some(permit),
))
}

Err(SdkError::ServiceError(err)) => {
Expand All @@ -378,7 +391,7 @@ impl S3LikeSource {

let new_region = Region::new(region_name);
log::debug!("S3 Region of {uri} different than client {:?} vs {:?} Attempting GET in that region with new client", new_region, region);
self._get_impl(uri, range, &new_region).await
self._get_impl(permit, uri, range, &new_region).await
}
_ => Err(UnableToOpenFileSnafu { path: uri }
.into_error(SdkError::ServiceError(err))
Expand All @@ -393,7 +406,12 @@ impl S3LikeSource {
}

#[async_recursion]
async fn _head_impl(&self, uri: &str, region: &Region) -> super::Result<usize> {
async fn _head_impl(
&self,
_permit: SemaphorePermit<'async_recursion>,
uri: &str,
region: &Region,
) -> super::Result<usize> {
let parsed = url::Url::parse(uri).with_context(|_| InvalidUrlSnafu { path: uri })?;

let bucket = match parsed.host_str() {
Expand Down Expand Up @@ -456,7 +474,7 @@ impl S3LikeSource {

let new_region = Region::new(region_name);
log::debug!("S3 Region of {uri} different than client {:?} vs {:?} Attempting HEAD in that region with new client", new_region, region);
self._head_impl(uri, &new_region).await
self._head_impl(_permit, uri, &new_region).await
}
_ => Err(UnableToHeadFileSnafu { path: uri }
.into_error(SdkError::ServiceError(err))
Expand All @@ -473,6 +491,7 @@ impl S3LikeSource {
#[async_recursion]
async fn _list_impl(
&self,
_permit: SemaphorePermit<'async_recursion>,
bucket: &str,
key: &str,
delimiter: String,
Expand Down Expand Up @@ -572,6 +591,7 @@ impl S3LikeSource {
let new_region = Region::new(region_name);
log::debug!("S3 Region of {uri} different than client {:?} vs {:?} Attempting List in that region with new client", new_region, region);
self._list_impl(
_permit,
bucket,
key,
delimiter,
Expand All @@ -595,11 +615,23 @@ impl S3LikeSource {
#[async_trait]
impl ObjectSource for S3LikeSource {
async fn get(&self, uri: &str, range: Option<Range<usize>>) -> super::Result<GetResult> {
self._get_impl(uri, range, &self.default_region).await
let permit = self
.connection_pool_sema
.clone()
.acquire_owned()
.await
.context(UnableToGrabSemaphoreSnafu)?;
self._get_impl(permit, uri, range, &self.default_region)
.await
}

async fn get_size(&self, uri: &str) -> super::Result<usize> {
self._head_impl(uri, &self.default_region).await
let permit = self
.connection_pool_sema
.acquire()
.await
.context(UnableToGrabSemaphoreSnafu)?;
self._head_impl(permit, uri, &self.default_region).await
}
async fn ls(
&self,
Expand All @@ -621,21 +653,34 @@ impl ObjectSource for S3LikeSource {
if let Some(key) = key.strip_prefix('/') {
// assume its a directory first
let key = format!("{}/", key.strip_suffix('/').unwrap_or(key));
let lsr = self
._list_impl(
let lsr = {
let permit = self
.connection_pool_sema
.acquire()
.await
.context(UnableToGrabSemaphoreSnafu)?;
self._list_impl(
permit,
bucket,
&key,
delimiter.into(),
continuation_token.map(String::from),
&self.default_region,
)
.await?;
.await?
};
if lsr.files.is_empty() && key.contains('/') {
let permit = self
.connection_pool_sema
.acquire()
.await
.context(UnableToGrabSemaphoreSnafu)?;
// Might be a File
let split = key.rsplit_once('/');
let (new_key, _) = split.unwrap();
let mut lsr = self
._list_impl(
permit,
bucket,
new_key,
delimiter.into(),
Expand Down

0 comments on commit c8ffd03

Please sign in to comment.