Skip to content

Commit

Permalink
[FEAT] add session token as input to io config (#1224)
Browse files Browse the repository at this point in the history
* Adds a `session_token` field to `S3Config`, allowing users to pass temporary
session tokens (e.g. from AWS STS) to daft.
* Adds IO integration tests to ensure that we can use the IOConfig to
access s3 with a signed request.
* Closes: #1216
  • Loading branch information
samster25 committed Aug 3, 2023
1 parent c1b39ba commit ad11d44
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 9 deletions.
9 changes: 8 additions & 1 deletion src/daft-io/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ pub struct S3Config {
pub region_name: Option<String>,
pub endpoint_url: Option<String>,
pub key_id: Option<String>,
pub session_token: Option<String>,
pub access_key: Option<String>,
pub anonymous: bool,
}
Expand All @@ -25,9 +26,15 @@ impl Display for S3Config {
region_name: {:?}
endpoint_url: {:?}
key_id: {:?}
session_token: {:?},
access_key: {:?}
anonymous: {}",
self.region_name, self.endpoint_url, self.key_id, self.access_key, self.anonymous
self.region_name,
self.endpoint_url,
self.key_id,
self.session_token,
self.access_key,
self.anonymous
)
}
}
Expand Down
2 changes: 2 additions & 0 deletions src/daft-io/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ impl PyS3Config {
region_name: Option<String>,
endpoint_url: Option<String>,
key_id: Option<String>,
session_token: Option<String>,
access_key: Option<String>,
anonymous: Option<bool>,
) -> Self {
Expand All @@ -63,6 +64,7 @@ impl PyS3Config {
region_name,
endpoint_url,
key_id,
session_token,
access_key,
anonymous: anonymous.unwrap_or(false),
},
Expand Down
4 changes: 2 additions & 2 deletions src/daft-io/src/s3_like.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,9 +142,9 @@ async fn build_s3_client(config: &S3Config) -> super::Result<(bool, s3::Client)>

let builder = if config.access_key.is_some() && config.key_id.is_some() {
let creds = Credentials::from_keys(
config.access_key.clone().unwrap(),
config.key_id.clone().unwrap(),
None,
config.access_key.clone().unwrap(),
config.session_token.clone(),
);
builder.credentials_provider(creds)
} else if config.access_key.is_some() || config.key_id.is_some() {
Expand Down
6 changes: 6 additions & 0 deletions tests/integration/io/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,3 +166,9 @@ def minio_image_data_fixture(minio_io_config, image_data_folder) -> YieldFixture
"""Populates the minio session with some fake data and yields (S3Config, paths)"""
with mount_data_minio(minio_io_config, image_data_folder) as urls:
yield urls


@pytest.fixture(scope="session")
def small_images_s3_paths() -> list[str]:
    """Return the S3 URLs of the six small ``rickroll{i}.jpg`` test images.

    The objects live under ``test_fixtures/small_images/`` in the
    ``daft-public-data`` bucket; indices run from 0 through 5.
    """
    paths = []
    for i in range(6):
        paths.append(f"s3://daft-public-data/test_fixtures/small_images/rickroll{i}.jpg")
    return paths
39 changes: 39 additions & 0 deletions tests/integration/io/test_url_download_private_aws_s3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from __future__ import annotations

import pytest
from botocore import session

import daft
from daft.io import IOConfig, S3Config


@pytest.fixture(scope="session")
def io_config() -> IOConfig:
    """Build an IOConfig from the credentials botocore resolves for this environment.

    The botocore access key is passed as ``key_id``, the secret key as
    ``access_key`` and the (possibly ``None``) session token as
    ``session_token``. The region is pinned to ``us-west-2``.

    Raises:
        RuntimeError: if botocore cannot resolve any credentials, so the
            failure is reported clearly instead of as an ``AttributeError``
            on ``None``.
    """
    creds = session.Session().get_credentials()
    if creds is None:
        raise RuntimeError(
            "No AWS credentials could be resolved by botocore; "
            "configure credentials before running the private S3 integration tests"
        )

    # Take a single consistent snapshot: refreshable credentials may rotate
    # between individual attribute reads, which could pair a new key with an
    # old token.
    frozen = creds.get_frozen_credentials()

    return IOConfig(
        s3=S3Config(
            key_id=frozen.access_key,
            access_key=frozen.secret_key,
            session_token=frozen.token,
            region_name="us-west-2",
        )
    )


@pytest.mark.integration()
def test_url_download_aws_s3_public_bucket_with_creds(small_images_s3_paths, io_config):
    """Download every fixture image URL with the native downloader using explicit credentials.

    Passes when six results come back and none of them is ``None``.
    """
    urls_df = daft.from_pydict({"urls": small_images_s3_paths})
    download_expr = urls_df["urls"].url.download(use_native_downloader=True, io_config=io_config)
    downloaded = urls_df.with_column("data", download_expr).to_pydict()["data"]

    assert len(downloaded) == 6
    for payload in downloaded:
        assert payload is not None


@pytest.mark.integration()
def test_read_parquet_aws_s3_public_bucket_with_creds(io_config):
    """Read the ``mvp.parquet`` fixture from S3 with explicit credentials and check its length is 100."""
    parquet_uri = "s3://daft-public-data/test_fixtures/parquet-dev/mvp.parquet"
    frame = daft.read_parquet(parquet_uri, io_config=io_config, use_native_downloader=True)
    collected = frame.collect()
    assert len(collected) == 100
assert len(df) == 100
6 changes: 0 additions & 6 deletions tests/integration/io/test_url_download_public_aws_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,6 @@
import daft


@pytest.fixture(scope="session")
def small_images_s3_paths() -> list[str]:
    """Return the S3 URLs of the six small JPEG fixtures (indices 0-5) in the public bucket."""
    image_indices = range(6)
    return list(
        f"s3://daft-public-data/test_fixtures/small_images/rickroll{i}.jpg" for i in image_indices
    )


@pytest.mark.integration()
def test_url_download_aws_s3_public_bucket_custom_s3fs(small_images_s3_paths):
fs = s3fs.S3FileSystem(anon=True)
Expand Down

0 comments on commit ad11d44

Please sign in to comment.