diff --git a/src/daft-io/src/config.rs b/src/daft-io/src/config.rs index 367715b935..f5529a4165 100644 --- a/src/daft-io/src/config.rs +++ b/src/daft-io/src/config.rs @@ -8,6 +8,7 @@ pub struct S3Config { pub region_name: Option, pub endpoint_url: Option, pub key_id: Option, + pub session_token: Option, pub access_key: Option, pub anonymous: bool, } @@ -25,9 +26,15 @@ impl Display for S3Config { region_name: {:?} endpoint_url: {:?} key_id: {:?} + session_token: {:?}, access_key: {:?} anonymous: {}", - self.region_name, self.endpoint_url, self.key_id, self.access_key, self.anonymous + self.region_name, + self.endpoint_url, + self.key_id, + self.session_token, + self.access_key, + self.anonymous ) } } diff --git a/src/daft-io/src/python.rs b/src/daft-io/src/python.rs index 07383dd6b1..0ad25f300a 100644 --- a/src/daft-io/src/python.rs +++ b/src/daft-io/src/python.rs @@ -55,6 +55,7 @@ impl PyS3Config { region_name: Option, endpoint_url: Option, key_id: Option, + session_token: Option, access_key: Option, anonymous: Option, ) -> Self { @@ -63,6 +64,7 @@ impl PyS3Config { region_name, endpoint_url, key_id, + session_token, access_key, anonymous: anonymous.unwrap_or(false), }, diff --git a/src/daft-io/src/s3_like.rs b/src/daft-io/src/s3_like.rs index e96bfe394e..df58031555 100644 --- a/src/daft-io/src/s3_like.rs +++ b/src/daft-io/src/s3_like.rs @@ -142,9 +142,9 @@ async fn build_s3_client(config: &S3Config) -> super::Result<(bool, s3::Client)> let builder = if config.access_key.is_some() && config.key_id.is_some() { let creds = Credentials::from_keys( - config.access_key.clone().unwrap(), config.key_id.clone().unwrap(), - None, + config.access_key.clone().unwrap(), + config.session_token.clone(), ); builder.credentials_provider(creds) } else if config.access_key.is_some() || config.key_id.is_some() { diff --git a/tests/integration/io/conftest.py b/tests/integration/io/conftest.py index ddf09ae17e..bd3ba50518 100644 --- a/tests/integration/io/conftest.py +++ b/tests/integration/io/conftest.py @@ -166,3 +166,9 @@ def minio_image_data_fixture(minio_io_config, image_data_folder) -> YieldFixture """Populates the minio session with some fake data and yields (S3Config, paths)""" with mount_data_minio(minio_io_config, image_data_folder) as urls: yield urls + + +@pytest.fixture(scope="session") +def small_images_s3_paths() -> list[str]: + """Paths to small *.jpg files in a public S3 bucket""" + return [f"s3://daft-public-data/test_fixtures/small_images/rickroll{i}.jpg" for i in range(6)] diff --git a/tests/integration/io/test_url_download_private_aws_s3.py b/tests/integration/io/test_url_download_private_aws_s3.py new file mode 100644 index 0000000000..1635ddfb16 --- /dev/null +++ b/tests/integration/io/test_url_download_private_aws_s3.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +import pytest +from botocore import session + +import daft +from daft.io import IOConfig, S3Config + + +@pytest.fixture(scope="session") +def io_config() -> IOConfig: + """Create IOConfig with botocore's current session""" + sess = session.Session() + creds = sess.get_credentials() + + return IOConfig( + s3=S3Config( + key_id=creds.access_key, access_key=creds.secret_key, session_token=creds.token, region_name="us-west-2" + ) + ) + + +@pytest.mark.integration() +def test_url_download_aws_s3_public_bucket_with_creds(small_images_s3_paths, io_config): + data = {"urls": small_images_s3_paths} + df = daft.from_pydict(data) + df = df.with_column("data", df["urls"].url.download(use_native_downloader=True, io_config=io_config)) + + data = df.to_pydict() + assert len(data["data"]) == 6 + for img_bytes in data["data"]: + assert img_bytes is not None + + +@pytest.mark.integration() +def test_read_parquet_aws_s3_public_bucket_with_creds(io_config): + filename = "s3://daft-public-data/test_fixtures/parquet-dev/mvp.parquet" + df = daft.read_parquet(filename, io_config=io_config, use_native_downloader=True).collect() + assert len(df) == 100 diff --git a/tests/integration/io/test_url_download_public_aws_s3.py b/tests/integration/io/test_url_download_public_aws_s3.py index ed84ff1a8c..8aeccb6b8b 100644 --- a/tests/integration/io/test_url_download_public_aws_s3.py +++ b/tests/integration/io/test_url_download_public_aws_s3.py @@ -6,12 +6,6 @@ import daft -@pytest.fixture(scope="session") -def small_images_s3_paths() -> list[str]: - """Paths to small *.jpg files in a public S3 bucket""" - return [f"s3://daft-public-data/test_fixtures/small_images/rickroll{i}.jpg" for i in range(6)] - - @pytest.mark.integration() def test_url_download_aws_s3_public_bucket_custom_s3fs(small_images_s3_paths): fs = s3fs.S3FileSystem(anon=True)