Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

provide ability to cache boto client instances directly and on S3Bucket #369

Merged
merged 33 commits into from
Jan 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
c584ed7
move s3_client up
mattiamatrix Dec 19, 2023
70ba4b7
get s3 client
mattiamatrix Dec 19, 2023
2389cbe
add print for testing
mattiamatrix Dec 19, 2023
4a74401
revert changes
mattiamatrix Dec 19, 2023
362c97f
update client
mattiamatrix Dec 19, 2023
5feb200
remove print
mattiamatrix Dec 22, 2023
43fad43
add @lru_cache
mattiamatrix Dec 22, 2023
1df6cb3
Add docs
mattiamatrix Dec 22, 2023
e23c67b
update docs
mattiamatrix Dec 22, 2023
7428ae3
revert changes
mattiamatrix Dec 22, 2023
a1a8866
fix docs
mattiamatrix Dec 22, 2023
dd0cca4
Update CHANGELOG.md
mattiamatrix Dec 22, 2023
bcf2bed
Add maxsize and typed=True
mattiamatrix Dec 22, 2023
4ec6a1e
Merge branch 'main' into use-s3-client-more-efficiently
zzstoatzz Jan 3, 2024
ab03c45
add test
mattiamatrix Jan 3, 2024
92e5d72
Merge branch 'main' into use-s3-client-more-efficiently
mattiamatrix Jan 3, 2024
59e38d1
Test with cache_info
mattiamatrix Jan 3, 2024
3cd1eab
Update AwsCredentials
mattiamatrix Jan 3, 2024
e0a927b
Only S3
mattiamatrix Jan 3, 2024
0de61f8
Empty-Commit
mattiamatrix Jan 3, 2024
dd09cb1
Merge branch 'main' into use-s3-client-more-efficiently
zzstoatzz Jan 10, 2024
283050e
Update hash function
mattiamatrix Jan 18, 2024
3200967
Revert changes
mattiamatrix Jan 18, 2024
e8b61e0
Update hash
mattiamatrix Jan 18, 2024
56210f8
Test different hash
mattiamatrix Jan 18, 2024
865975c
avoid modifying default behavior
zzstoatzz Jan 18, 2024
9b6cd9b
run pre-commits
zzstoatzz Jan 18, 2024
1f9c3d0
no way
zzstoatzz Jan 18, 2024
7202545
test caching via s3
zzstoatzz Jan 18, 2024
5aad7ab
Update prefect_aws/s3.py
zzstoatzz Jan 18, 2024
8379d3d
caching by default, remove toggle
zzstoatzz Jan 19, 2024
95df7a0
region
zzstoatzz Jan 19, 2024
1fde31e
check hashing on both creds classes
zzstoatzz Jan 19, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed

- Handle `boto3` clients more efficiently with `lru_cache` - [#361](https://github.com/PrefectHQ/prefect-aws/pull/361)

### Fixed

### Deprecated
Expand Down Expand Up @@ -105,6 +107,7 @@ Released August 31st, 2023.
Released July 20th, 2023.

### Changed

- Promoted workers to GA, removed beta disclaimers

## 0.3.5
Expand Down Expand Up @@ -293,6 +296,7 @@ Released on October 28th, 2022.
- `ECSTask` is no longer experimental — [#137](https://github.com/PrefectHQ/prefect-aws/pull/137)

### Fixed

- Fix ignore_file option in `S3Bucket` skipping files which should be included — [#139](https://github.com/PrefectHQ/prefect-aws/pull/139)
- Fixed bug where `basepath` is used twice in the path when using `S3Bucket.put_directory` - [#143](https://github.com/PrefectHQ/prefect-aws/pull/143)

Expand Down
12 changes: 12 additions & 0 deletions prefect_aws/client_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,18 @@ class AwsClientParameters(BaseModel):
title="Botocore Config",
)

def __hash__(self):
return hash(
(
self.api_version,
self.use_ssl,
self.verify,
self.verify_cert_path,
self.endpoint_url,
self.config,
)
)

@validator("config", pre=True)
def instantiate_config(cls, value: Union[Config, Dict[str, Any]]) -> Dict[str, Any]:
"""
Expand Down
76 changes: 66 additions & 10 deletions prefect_aws/credentials.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""Module handling AWS credentials"""

from enum import Enum
from functools import lru_cache
from threading import Lock
from typing import Any, Optional, Union

import boto3
Expand All @@ -16,14 +18,43 @@

from prefect_aws.client_parameters import AwsClientParameters

_LOCK = Lock()


class ClientType(Enum):
"""The supported boto3 clients."""

S3 = "s3"
ECS = "ecs"
BATCH = "batch"
SECRETS_MANAGER = "secretsmanager"


@lru_cache(maxsize=8, typed=True)
def _get_client_cached(ctx, client_type: Union[str, ClientType]) -> Any:
"""
Helper method to cache and dynamically get a client type.

Args:
client_type: The client's service name.

Returns:
An authenticated client.

Raises:
ValueError: if the client is not supported.
"""
with _LOCK:
if isinstance(client_type, ClientType):
client_type = client_type.value

client = ctx.get_boto3_session().client(
service_name=client_type,
**ctx.aws_client_parameters.get_params_override(),
)
return client


class AwsCredentials(CredentialsBlock):
"""
Block used to manage authentication with AWS. AWS authentication is
Expand Down Expand Up @@ -75,6 +106,22 @@ class AwsCredentials(CredentialsBlock):
title="AWS Client Parameters",
)

class Config:
"""Config class for pydantic model."""

arbitrary_types_allowed = True

def __hash__(self):
field_hashes = (
hash(self.aws_access_key_id),
hash(self.aws_secret_access_key),
hash(self.aws_session_token),
hash(self.profile_name),
hash(self.region_name),
hash(frozenset(self.aws_client_parameters.dict().items())),
)
return hash(field_hashes)

def get_boto3_session(self) -> boto3.Session:
"""
Returns an authenticated boto3 session that can be used to create clients
Expand Down Expand Up @@ -104,7 +151,7 @@ def get_boto3_session(self) -> boto3.Session:
region_name=self.region_name,
)

def get_client(self, client_type: Union[str, ClientType]) -> Any:
def get_client(self, client_type: Union[str, ClientType]):
"""
Helper method to dynamically get a client type.

Expand All @@ -120,10 +167,7 @@ def get_client(self, client_type: Union[str, ClientType]) -> Any:
if isinstance(client_type, ClientType):
client_type = client_type.value

client = self.get_boto3_session().client(
service_name=client_type, **self.aws_client_parameters.get_params_override()
)
return client
return _get_client_cached(ctx=self, client_type=client_type)

def get_s3_client(self) -> S3Client:
"""
Expand Down Expand Up @@ -186,6 +230,21 @@ class MinIOCredentials(CredentialsBlock):
description="Extra parameters to initialize the Client.",
)

class Config:
"""Config class for pydantic model."""

arbitrary_types_allowed = True

def __hash__(self):
return hash(
(
hash(self.minio_root_user),
hash(self.minio_root_password),
hash(self.region_name),
hash(frozenset(self.aws_client_parameters.dict().items())),
)
)

def get_boto3_session(self) -> boto3.Session:
"""
Returns an authenticated boto3 session that can be used to create clients
Expand Down Expand Up @@ -218,7 +277,7 @@ def get_boto3_session(self) -> boto3.Session:
region_name=self.region_name,
)

def get_client(self, client_type: Union[str, ClientType]) -> Any:
def get_client(self, client_type: Union[str, ClientType]):
"""
Helper method to dynamically get a client type.

Expand All @@ -234,10 +293,7 @@ def get_client(self, client_type: Union[str, ClientType]) -> Any:
if isinstance(client_type, ClientType):
client_type = client_type.value

client = self.get_boto3_session().client(
service_name=client_type, **self.aws_client_parameters.get_params_override()
)
return client
return _get_client_cached(ctx=self, client_type=client_type)

def get_s3_client(self) -> S3Client:
"""
Expand Down
2 changes: 1 addition & 1 deletion prefect_aws/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,7 @@ def _get_s3_client(self) -> boto3.client:
Authenticate MinIO credentials or AWS credentials and return an S3 client.
This is a helper function called by read_path() or write_path().
"""
return self.credentials.get_s3_client()
return self.credentials.get_client("s3")

def _get_bucket_resource(self) -> boto3.resource:
"""
Expand Down
122 changes: 121 additions & 1 deletion tests/test_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,12 @@
from botocore.client import BaseClient
from moto import mock_s3

from prefect_aws.credentials import AwsCredentials, ClientType, MinIOCredentials
from prefect_aws.credentials import (
AwsCredentials,
ClientType,
MinIOCredentials,
_get_client_cached,
)


def test_aws_credentials_get_boto3_session():
Expand Down Expand Up @@ -44,3 +49,118 @@ def test_minio_credentials_get_boto3_session():
def test_credentials_get_client(credentials, client_type):
with mock_s3():
assert isinstance(credentials.get_client(client_type), BaseClient)


@pytest.mark.parametrize(
"credentials",
[
AwsCredentials(region_name="us-east-1"),
MinIOCredentials(
minio_root_user="root_user",
minio_root_password="root_password",
region_name="us-east-1",
),
],
)
@pytest.mark.parametrize("client_type", [member.value for member in ClientType])
def test_get_client_cached(credentials, client_type):
"""
Test to ensure that _get_client_cached function returns the same instance
for multiple calls with the same parameters and properly utilizes lru_cache.
"""

_get_client_cached.cache_clear()

assert _get_client_cached.cache_info().hits == 0, "Initial call count should be 0"

credentials.get_client(client_type)
credentials.get_client(client_type)
credentials.get_client(client_type)

assert _get_client_cached.cache_info().misses == 1
assert _get_client_cached.cache_info().hits == 2


@pytest.mark.parametrize("client_type", [member.value for member in ClientType])
def test_aws_credentials_change_causes_cache_miss(client_type):
"""
Test to ensure that changing configuration on an AwsCredentials instance
after fetching a client causes a cache miss.
"""

_get_client_cached.cache_clear()

credentials = AwsCredentials(region_name="us-east-1")

initial_client = credentials.get_client(client_type)

credentials.region_name = "us-west-2"

new_client = credentials.get_client(client_type)

assert (
initial_client is not new_client
), "Client should be different after configuration change"

assert _get_client_cached.cache_info().misses == 2, "Cache should miss twice"


@pytest.mark.parametrize("client_type", [member.value for member in ClientType])
def test_minio_credentials_change_causes_cache_miss(client_type):
"""
Test to ensure that changing configuration on an AwsCredentials instance
after fetching a client causes a cache miss.
"""

_get_client_cached.cache_clear()

credentials = MinIOCredentials(
minio_root_user="root_user",
minio_root_password="root_password",
region_name="us-east-1",
)

initial_client = credentials.get_client(client_type)

credentials.region_name = "us-west-2"

new_client = credentials.get_client(client_type)

assert (
initial_client is not new_client
), "Client should be different after configuration change"

assert _get_client_cached.cache_info().misses == 2, "Cache should miss twice"


@pytest.mark.parametrize(
"credentials_type, initial_field, new_field",
[
(
AwsCredentials,
{"region_name": "us-east-1"},
{"region_name": "us-east-2"},
),
(
MinIOCredentials,
{
"region_name": "us-east-1",
"minio_root_user": "root_user",
"minio_root_password": "root_password",
},
{
"region_name": "us-east-2",
"minio_root_user": "root_user",
"minio_root_password": "root_password",
},
),
],
)
def test_aws_credentials_hash_changes(credentials_type, initial_field, new_field):
credentials = credentials_type(**initial_field)
initial_hash = hash(credentials)

setattr(credentials, list(new_field.keys())[0], list(new_field.values())[0])
new_hash = hash(credentials)

assert initial_hash != new_hash, "Hash should change when region_name changes"