diff --git a/src/daft-io/src/lib.rs b/src/daft-io/src/lib.rs index 3f14834fb3..15179102e9 100644 --- a/src/daft-io/src/lib.rs +++ b/src/daft-io/src/lib.rs @@ -1,5 +1,6 @@ #![feature(async_closure)] #![feature(let_chains)] + mod azure_blob; mod google_cloud; mod http; diff --git a/src/daft-io/src/s3_like.rs b/src/daft-io/src/s3_like.rs index 4638077eb4..7681201fb6 100644 --- a/src/daft-io/src/s3_like.rs +++ b/src/daft-io/src/s3_like.rs @@ -1,4 +1,5 @@ use async_trait::async_trait; +use aws_config::meta::credentials::CredentialsProviderChain; use aws_config::retry::RetryMode; use aws_config::timeout::TimeoutConfig; use aws_smithy_async::rt::sleep::TokioSleep; @@ -11,7 +12,7 @@ use tokio::sync::{OwnedSemaphorePermit, SemaphorePermit}; use crate::object_io::{FileMetadata, FileType, LSResult}; use crate::{get_io_pool_num_threads, InvalidArgumentSnafu, SourceType}; use aws_config::SdkConfig; -use aws_credential_types::cache::ProvideCachedCredentials; +use aws_credential_types::cache::{ProvideCachedCredentials, SharedCredentialsCache}; use aws_credential_types::provider::error::CredentialsError; use aws_sig_auth::signer::SigningRequirements; use common_io_config::S3Config; @@ -211,7 +212,10 @@ fn handle_https_client_settings( Ok(builder) } -async fn build_s3_client(config: &S3Config) -> super::Result<(bool, s3::Client)> { +async fn build_s3_client( + config: &S3Config, + credentials_cache: Option, +) -> super::Result<(bool, s3::Client)> { const DEFAULT_REGION: Region = Region::from_static("us-east-1"); let mut anonymous = config.anonymous; @@ -228,8 +232,6 @@ async fn build_s3_client(config: &S3Config) -> super::Result<(bool, s3::Client)> }; let builder = if let Some(region) = &config.region_name { builder.region(Region::new(region.to_owned())) - } else if conf.region().is_none() && config.region_name.is_none() { - builder.region(DEFAULT_REGION) } else { builder }; @@ -268,7 +270,19 @@ async fn build_s3_client(config: &S3Config) -> super::Result<(bool, s3::Client)> .build(); let builder = builder.timeout_config(timeout_config); - let builder = if config.access_key.is_some() && config.key_id.is_some() { + let cached_creds = if let Some(credentials_cache) = credentials_cache { + let creds = credentials_cache.provide_cached_credentials().await; + creds.ok() + } else { + None + }; + + let builder = if let Some(cached_creds) = cached_creds { + let provider = CredentialsProviderChain::first_try("different_region_cache", cached_creds) + .or_default_provider() + .await; + builder.credentials_provider(provider) + } else if config.access_key.is_some() && config.key_id.is_some() { let creds = Credentials::from_keys( config.key_id.clone().unwrap(), config.access_key.clone().unwrap(), @@ -283,6 +297,7 @@ async fn build_s3_client(config: &S3Config) -> super::Result<(bool, s3::Client)> builder }; + let builder_copy = builder.clone(); let s3_conf = builder.build(); if !config.anonymous { use CredentialsError::*; @@ -300,11 +315,16 @@ async fn build_s3_client(config: &S3Config) -> super::Result<(bool, s3::Client)> }.with_context(|_| UnableToLoadCredentialsSnafu {})?; }; + let s3_conf = if s3_conf.region().is_none() { + builder_copy.region(DEFAULT_REGION).build() + } else { + s3_conf + }; Ok((anonymous, s3::Client::from_conf(s3_conf))) } async fn build_client(config: &S3Config) -> super::Result { - let (anonymous, client) = build_s3_client(config).await?; + let (anonymous, client) = build_s3_client(config, None).await?; let mut client_map = HashMap::new(); let default_region = client.conf().region().unwrap().clone(); client_map.insert(default_region.clone(), client.into()); @@ -343,7 +363,12 @@ impl S3LikeSource { let mut new_config = self.s3_config.clone(); new_config.region_name = Some(region.to_string()); - let (_, new_client) = build_s3_client(&new_config).await?; + + let creds_cache = w_handle + .get(&self.default_region) + .map(|current_client| current_client.conf().credentials_cache()); + + let (_, new_client) = build_s3_client(&new_config, creds_cache).await?; if w_handle.get(region).is_none() { w_handle.insert(region.clone(), new_client.into());