Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Fix s3 session creation in deployment steps push_to_s3 and pull_from_…
Browse files Browse the repository at this point in the history
…s3 (#322)

Co-authored-by: Alexander Streed <[email protected]>
  • Loading branch information
markbruning and desertaxle authored Oct 11, 2023
1 parent 9ba5424 commit ce851e0
Show file tree
Hide file tree
Showing 3 changed files with 147 additions and 18 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Changed

- Changed `push_to_s3` deployment step function to write paths `as_posix()` to allow support for deploying from windows [#314](https://github.com/PrefectHQ/prefect-aws/pull/314)
- Changed `push_to_s3` and `pull_from_s3` deployment steps to properly create a boto3 session client if the passed credentials are a referenced AwsCredentials block [#322](https://github.com/PrefectHQ/prefect-aws/pull/322)

### Fixed

Expand Down
68 changes: 51 additions & 17 deletions prefect_aws/deployments/steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,14 +91,7 @@ def push_to_s3(
```
"""
if credentials is None:
credentials = {}
if client_parameters is None:
client_parameters = {}
advanced_config = client_parameters.pop("config", {})
client = boto3.client(
"s3", **credentials, **client_parameters, config=Config(**advanced_config)
)
s3 = get_s3_client(credentials=credentials, client_parameters=client_parameters)

local_path = Path.cwd()

Expand All @@ -117,7 +110,7 @@ def push_to_s3(
continue
elif not local_file_path.is_dir():
remote_file_path = Path(folder) / local_file_path.relative_to(local_path)
client.upload_file(
s3.upload_file(
str(local_file_path), bucket, str(remote_file_path.as_posix())
)

Expand Down Expand Up @@ -174,14 +167,7 @@ def pull_from_s3(
credentials: "{{ prefect.blocks.aws-credentials.dev-credentials }}"
```
"""
if credentials is None:
credentials = {}
if client_parameters is None:
client_parameters = {}
advanced_config = client_parameters.pop("config", {})

session = boto3.Session(**credentials)
s3 = session.client("s3", **client_parameters, config=Config(**advanced_config))
s3 = get_s3_client(credentials=credentials, client_parameters=client_parameters)

local_path = Path.cwd()

Expand All @@ -206,3 +192,51 @@ def pull_from_s3(
"folder": folder,
"directory": str(local_path),
}


def get_s3_client(
credentials: Optional[Dict] = None,
client_parameters: Optional[Dict] = None,
) -> dict:
if credentials is None:
credentials = {}
if client_parameters is None:
client_parameters = {}

# Get credentials from credentials (regardless if block or not)
aws_access_key_id = credentials.get("aws_access_key_id", None)
aws_secret_access_key = credentials.get("aws_secret_access_key", None)
aws_session_token = credentials.get("aws_session_token", None)

# Get remaining session info from credentials, or client_parameters
profile_name = credentials.get(
"profile_name", client_parameters.get("profile_name", None)
)
region_name = credentials.get(
"region_name", client_parameters.get("region_name", None)
)

# Get additional info from client_parameters, otherwise credentials input (if block)
aws_client_parameters = credentials.get("aws_client_parameters", client_parameters)
api_version = aws_client_parameters.get("api_version", None)
endpoint_url = aws_client_parameters.get("endpoint_url", None)
use_ssl = aws_client_parameters.get("use_ssl", None)
verify = aws_client_parameters.get("verify", None)
config_params = aws_client_parameters.get("config", {})
config = Config(**config_params)

session = boto3.Session(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_session_token=aws_session_token,
profile_name=profile_name,
region_name=region_name,
)
return session.client(
"s3",
api_version=api_version,
endpoint_url=endpoint_url,
use_ssl=use_ssl,
verify=verify,
config=config,
)
96 changes: 95 additions & 1 deletion tests/deploments/test_steps.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import os
import sys
from pathlib import Path, PurePath, PurePosixPath
from unittest.mock import patch

import boto3
import pytest
from moto import mock_s3

from prefect_aws.deployments.steps import pull_from_s3, push_to_s3
from prefect_aws import AwsCredentials
from prefect_aws.deployments.steps import get_s3_client, pull_from_s3, push_to_s3


@pytest.fixture
Expand Down Expand Up @@ -173,6 +175,98 @@ def test_push_pull_empty_folders(s3_setup, tmp_path, mock_aws_credentials):
assert not (tmp_path / "empty2_copy").exists()


@pytest.mark.skipif(sys.version_info < (3, 8), reason="requires Python 3.8+")
def test_s3_session_with_params():
with patch("boto3.Session") as mock_session:
get_s3_client(
credentials={
"aws_access_key_id": "THE_KEY",
"aws_secret_access_key": "SHHH!",
"profile_name": "foo",
"region_name": "us-weast-1",
"aws_client_parameters": {
"api_version": "v1",
"config": {"connect_timeout": 300},
},
}
)
get_s3_client(
credentials={
"aws_access_key_id": "THE_KEY",
"aws_secret_access_key": "SHHH!",
},
client_parameters={
"region_name": "us-west-1",
"config": {"signature_version": "s3v4"},
},
)
creds_block = AwsCredentials(
aws_access_key_id="BlockKey",
aws_secret_access_key="BlockSecret",
aws_session_token="BlockToken",
profile_name="BlockProfile",
region_name="BlockRegion",
aws_client_parameters={
"api_version": "v1",
"use_ssl": True,
"verify": True,
"endpoint_url": "BlockEndpoint",
"config": {"connect_timeout": 123},
},
)
get_s3_client(credentials=creds_block.dict())
all_calls = mock_session.mock_calls
assert len(all_calls) == 6
assert all_calls[0].kwargs == {
"aws_access_key_id": "THE_KEY",
"aws_secret_access_key": "SHHH!",
"aws_session_token": None,
"profile_name": "foo",
"region_name": "us-weast-1",
}
assert all_calls[1].args[0] == "s3"
assert {
"api_version": "v1",
"endpoint_url": None,
"use_ssl": None,
"verify": None,
}.items() <= all_calls[1].kwargs.items()
assert all_calls[1].kwargs.get("config").connect_timeout == 300
assert all_calls[1].kwargs.get("config").signature_version is None
assert all_calls[2].kwargs == {
"aws_access_key_id": "THE_KEY",
"aws_secret_access_key": "SHHH!",
"aws_session_token": None,
"profile_name": None,
"region_name": "us-west-1",
}
assert all_calls[3].args[0] == "s3"
assert {
"api_version": None,
"endpoint_url": None,
"use_ssl": None,
"verify": None,
}.items() <= all_calls[3].kwargs.items()
assert all_calls[3].kwargs.get("config").connect_timeout == 60
assert all_calls[3].kwargs.get("config").signature_version == "s3v4"
assert all_calls[4].kwargs == {
"aws_access_key_id": "BlockKey",
"aws_secret_access_key": creds_block.aws_secret_access_key,
"aws_session_token": "BlockToken",
"profile_name": "BlockProfile",
"region_name": "BlockRegion",
}
assert all_calls[5].args[0] == "s3"
assert {
"api_version": "v1",
"use_ssl": True,
"verify": True,
"endpoint_url": "BlockEndpoint",
}.items() <= all_calls[5].kwargs.items()
assert all_calls[5].kwargs.get("config").connect_timeout == 123
assert all_calls[5].kwargs.get("config").signature_version is None


def test_custom_credentials_and_client_parameters(s3_setup, tmp_files):
s3, bucket_name = s3_setup
folder = "my-project"
Expand Down

0 comments on commit ce851e0

Please sign in to comment.