diff --git a/src/common/io-config/src/python.rs b/src/common/io-config/src/python.rs
index 62492283e4..342df3bf42 100644
--- a/src/common/io-config/src/python.rs
+++ b/src/common/io-config/src/python.rs
@@ -11,6 +11,7 @@ use crate::config;
 /// endpoint_url: URL to the S3 endpoint, defaults to endpoints to AWS
 /// key_id: AWS Access Key ID, defaults to auto-detection from the current environment
 /// access_key: AWS Secret Access Key, defaults to auto-detection from the current environment
+/// max_connections: Maximum number of connections to S3 at any time, defaults to 25
 /// session_token: AWS Session Token, required only if `key_id` and `access_key` are temporary credentials
 /// retry_initial_backoff_ms: Initial backoff duration in milliseconds for an S3 retry, defaults to 1000ms
 /// connect_timeout_ms: Timeout duration to wait to make a connection to S3 in milliseconds, defaults to 60 seconds
@@ -142,6 +143,7 @@ impl S3Config {
         key_id: Option<String>,
         session_token: Option<String>,
         access_key: Option<String>,
+        max_connections: Option<u32>,
         retry_initial_backoff_ms: Option<u64>,
         connect_timeout_ms: Option<u64>,
         read_timeout_ms: Option<u64>,
@@ -157,6 +159,7 @@ impl S3Config {
                 key_id: key_id.or(def.key_id),
                 session_token: session_token.or(def.session_token),
                 access_key: access_key.or(def.access_key),
+                max_connections: max_connections.unwrap_or(def.max_connections),
                 retry_initial_backoff_ms: retry_initial_backoff_ms
                     .unwrap_or(def.retry_initial_backoff_ms),
                 connect_timeout_ms: connect_timeout_ms.unwrap_or(def.connect_timeout_ms),
@@ -202,6 +205,12 @@ impl S3Config {
         Ok(self.config.access_key.clone())
     }

+    /// AWS max connections
+    #[getter]
+    pub fn max_connections(&self) -> PyResult<u32> {
+        Ok(self.config.max_connections)
+    }
+
     /// AWS Retry Initial Backoff Time in Milliseconds
     #[getter]
     pub fn retry_initial_backoff_ms(&self) -> PyResult<u64> {
diff --git a/src/common/io-config/src/s3.rs b/src/common/io-config/src/s3.rs
index 57ba13c567..6cc309953c 100644
--- a/src/common/io-config/src/s3.rs
+++ b/src/common/io-config/src/s3.rs
@@ -11,6 +11,7 @@ pub struct S3Config {
     pub key_id: Option<String>,
     pub session_token: Option<String>,
     pub access_key: Option<String>,
+    pub max_connections: u32,
     pub retry_initial_backoff_ms: u64,
     pub connect_timeout_ms: u64,
     pub read_timeout_ms: u64,
@@ -27,6 +28,7 @@ impl Default for S3Config {
             key_id: None,
             session_token: None,
             access_key: None,
+            max_connections: 25,
             retry_initial_backoff_ms: 1000,
             connect_timeout_ms: 60_000,
             read_timeout_ms: 60_000,
@@ -47,6 +49,7 @@ impl Display for S3Config {
     key_id: {:?}
     session_token: {:?},
     access_key: {:?}
+    max_connections: {},
     retry_initial_backoff_ms: {},
     connect_timeout_ms: {},
     read_timeout_ms: {},
@@ -59,6 +62,7 @@ impl Display for S3Config {
             self.session_token,
             self.access_key,
+            self.max_connections,
             self.retry_initial_backoff_ms,
             self.connect_timeout_ms,
             self.read_timeout_ms,
             self.num_tries,
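A minimal sketch (not part of the patch) of how the new field behaves from the Rust side. The `common_io_config` import path is an assumption for illustration; the defaulting mirrors the `max_connections.unwrap_or(def.max_connections)` logic in the hunks above. Note that the `+ self.max_connections,` argument must precede `self.retry_initial_backoff_ms` to line up with the format string, as shown in the corrected `Display` hunk.

```rust
// Sketch only: `common_io_config` stands in for however the crate is imported.
use common_io_config::S3Config;

fn main() {
    // `Default` seeds the new pool size with 25 connections.
    let conf = S3Config::default();
    assert_eq!(conf.max_connections, 25);

    // Overriding only the cap keeps every other default intact.
    let conf = S3Config {
        max_connections: 64,
        ..Default::default()
    };
    println!("{conf}"); // Display now includes `max_connections: 64`
}
```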
diff --git a/src/daft-io/src/azure_blob.rs b/src/daft-io/src/azure_blob.rs
index d7ef6cd37c..ee62da1525 100644
--- a/src/daft-io/src/azure_blob.rs
+++ b/src/daft-io/src/azure_blob.rs
@@ -152,7 +152,7 @@ impl ObjectSource for AzureBlobSource {
                 .into_error(e)
                 .into()
             });
-        Ok(GetResult::Stream(stream.boxed(), None))
+        Ok(GetResult::Stream(stream.boxed(), None, None))
     }

     async fn get_size(&self, uri: &str) -> super::Result<usize> {
diff --git a/src/daft-io/src/google_cloud.rs b/src/daft-io/src/google_cloud.rs
index 9d91f32290..04a6f45d2a 100644
--- a/src/daft-io/src/google_cloud.rs
+++ b/src/daft-io/src/google_cloud.rs
@@ -146,7 +146,7 @@ impl GCSClientWrapper {
                     .into_error(e)
                     .into()
                 });
-                Ok(GetResult::Stream(response.boxed(), size))
+                Ok(GetResult::Stream(response.boxed(), size, None))
             }
             GCSClientWrapper::S3Compat(client) => {
                 let uri = format!("s3://{}/{}", bucket, key);
diff --git a/src/daft-io/src/http.rs b/src/daft-io/src/http.rs
index 5820b104ef..a98728d91b 100644
--- a/src/daft-io/src/http.rs
+++ b/src/daft-io/src/http.rs
@@ -119,7 +119,7 @@ impl ObjectSource for HttpSource {
                 .into_error(e)
                 .into()
             });
-        Ok(GetResult::Stream(stream.boxed(), size_bytes))
+        Ok(GetResult::Stream(stream.boxed(), size_bytes, None))
     }

     async fn get_size(&self, uri: &str) -> super::Result<usize> {
diff --git a/src/daft-io/src/object_io.rs b/src/daft-io/src/object_io.rs
index ae2f480182..950438a37b 100644
--- a/src/daft-io/src/object_io.rs
+++ b/src/daft-io/src/object_io.rs
@@ -4,12 +4,17 @@ use async_trait::async_trait;
 use bytes::Bytes;
 use futures::stream::{BoxStream, Stream};
 use futures::StreamExt;
+use tokio::sync::OwnedSemaphorePermit;

 use crate::local::{collect_file, LocalFile};

 pub enum GetResult {
     File(LocalFile),
-    Stream(BoxStream<'static, super::Result<Bytes>>, Option<usize>),
+    Stream(
+        BoxStream<'static, super::Result<Bytes>>,
+        Option<usize>,
+        Option<OwnedSemaphorePermit>,
+    ),
 }

 async fn collect_bytes<S>(mut stream: S, size_hint: Option<usize>) -> super::Result<Bytes>
@@ -40,7 +45,7 @@ impl GetResult {
         use GetResult::*;
         match self {
             File(f) => collect_file(f).await,
-            Stream(stream, size) => collect_bytes(stream, size).await,
+            Stream(stream, size, _permit) => collect_bytes(stream, size).await,
         }
     }
 }
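The third field added to `GetResult::Stream` lets an `OwnedSemaphorePermit` travel with the stream: the connection slot is returned by plain RAII when the consumer drops (or finishes) the stream, with no explicit bookkeeping. Non-S3 sources simply pass `None`. A self-contained sketch of that pattern with a toy type (`Throttled` is illustrative, not the crate's API):

```rust
use std::sync::Arc;

use bytes::Bytes;
use futures::stream::{self, BoxStream, StreamExt};
use tokio::sync::{OwnedSemaphorePermit, Semaphore};

// Bundling the permit with the stream ties the pool slot's lifetime
// to the stream's lifetime.
struct Throttled {
    body: BoxStream<'static, Bytes>,
    _permit: Option<OwnedSemaphorePermit>, // dropped together with the stream
}

#[tokio::main]
async fn main() {
    let pool = Arc::new(Semaphore::new(25));
    let permit = pool.clone().acquire_owned().await.unwrap();
    let throttled = Throttled {
        body: stream::iter([Bytes::from_static(b"chunk")]).boxed(),
        _permit: Some(permit),
    };
    assert_eq!(pool.available_permits(), 24); // slot held while stream lives
    drop(throttled); // stream and permit dropped together
    assert_eq!(pool.available_permits(), 25); // slot back in the pool
}
```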
{source}" ))] @@ -252,6 +257,9 @@ async fn build_client(config: &S3Config) -> super::Result { client_map.insert(default_region.clone(), client.into()); Ok(S3LikeSource { region_to_client_map: tokio::sync::RwLock::new(client_map), + connection_pool_sema: Arc::new(tokio::sync::Semaphore::new( + config.max_connections as usize, + )), s3_config: config.clone(), default_region, anonymous, @@ -290,6 +298,7 @@ impl S3LikeSource { #[async_recursion] async fn _get_impl( &self, + permit: OwnedSemaphorePermit, uri: &str, range: Option>, region: &Region, @@ -357,7 +366,11 @@ impl S3LikeSource { .into() }) .boxed(); - Ok(GetResult::Stream(stream, Some(v.content_length as usize))) + Ok(GetResult::Stream( + stream, + Some(v.content_length as usize), + Some(permit), + )) } Err(SdkError::ServiceError(err)) => { @@ -378,7 +391,7 @@ impl S3LikeSource { let new_region = Region::new(region_name); log::debug!("S3 Region of {uri} different than client {:?} vs {:?} Attempting GET in that region with new client", new_region, region); - self._get_impl(uri, range, &new_region).await + self._get_impl(permit, uri, range, &new_region).await } _ => Err(UnableToOpenFileSnafu { path: uri } .into_error(SdkError::ServiceError(err)) @@ -393,7 +406,12 @@ impl S3LikeSource { } #[async_recursion] - async fn _head_impl(&self, uri: &str, region: &Region) -> super::Result { + async fn _head_impl( + &self, + _permit: SemaphorePermit<'async_recursion>, + uri: &str, + region: &Region, + ) -> super::Result { let parsed = url::Url::parse(uri).with_context(|_| InvalidUrlSnafu { path: uri })?; let bucket = match parsed.host_str() { @@ -456,7 +474,7 @@ impl S3LikeSource { let new_region = Region::new(region_name); log::debug!("S3 Region of {uri} different than client {:?} vs {:?} Attempting HEAD in that region with new client", new_region, region); - self._head_impl(uri, &new_region).await + self._head_impl(_permit, uri, &new_region).await } _ => Err(UnableToHeadFileSnafu { path: uri } .into_error(SdkError::ServiceError(err)) @@ -473,6 +491,7 @@ impl S3LikeSource { #[async_recursion] async fn _list_impl( &self, + _permit: SemaphorePermit<'async_recursion>, bucket: &str, key: &str, delimiter: String, @@ -572,6 +591,7 @@ impl S3LikeSource { let new_region = Region::new(region_name); log::debug!("S3 Region of {uri} different than client {:?} vs {:?} Attempting List in that region with new client", new_region, region); self._list_impl( + _permit, bucket, key, delimiter, @@ -595,11 +615,23 @@ impl S3LikeSource { #[async_trait] impl ObjectSource for S3LikeSource { async fn get(&self, uri: &str, range: Option>) -> super::Result { - self._get_impl(uri, range, &self.default_region).await + let permit = self + .connection_pool_sema + .clone() + .acquire_owned() + .await + .context(UnableToGrabSemaphoreSnafu)?; + self._get_impl(permit, uri, range, &self.default_region) + .await } async fn get_size(&self, uri: &str) -> super::Result { - self._head_impl(uri, &self.default_region).await + let permit = self + .connection_pool_sema + .acquire() + .await + .context(UnableToGrabSemaphoreSnafu)?; + self._head_impl(permit, uri, &self.default_region).await } async fn ls( &self, @@ -621,21 +653,34 @@ impl ObjectSource for S3LikeSource { if let Some(key) = key.strip_prefix('/') { // assume its a directory first let key = format!("{}/", key.strip_suffix('/').unwrap_or(key)); - let lsr = self - ._list_impl( + let lsr = { + let permit = self + .connection_pool_sema + .acquire() + .await + .context(UnableToGrabSemaphoreSnafu)?; + self._list_impl( + permit, bucket, 
&key, delimiter.into(), continuation_token.map(String::from), &self.default_region, ) - .await?; + .await? + }; if lsr.files.is_empty() && key.contains('/') { + let permit = self + .connection_pool_sema + .acquire() + .await + .context(UnableToGrabSemaphoreSnafu)?; // Might be a File let split = key.rsplit_once('/'); let (new_key, _) = split.unwrap(); let mut lsr = self ._list_impl( + permit, bucket, new_key, delimiter.into(),