diff --git a/CHANGELOG.md b/CHANGELOG.md index 64454b60..b40563f9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed +- Handle `boto3` clients more efficiently with `lru_cache` - [#361](https://github.com/PrefectHQ/prefect-aws/pull/361) + ### Fixed ### Deprecated @@ -105,6 +107,7 @@ Released August 31st, 2023. Released July 20th, 2023. ### Changed + - Promoted workers to GA, removed beta disclaimers ## 0.3.5 @@ -293,6 +296,7 @@ Released on October 28th, 2022. - `ECSTask` is no longer experimental — [#137](https://github.com/PrefectHQ/prefect-aws/pull/137) ### Fixed + - Fix ignore_file option in `S3Bucket` skipping files which should be included — [#139](https://github.com/PrefectHQ/prefect-aws/pull/139) - Fixed bug where `basepath` is used twice in the path when using `S3Bucket.put_directory` - [#143](https://github.com/PrefectHQ/prefect-aws/pull/143) diff --git a/prefect_aws/client_parameters.py b/prefect_aws/client_parameters.py index bf030590..eb3be09b 100644 --- a/prefect_aws/client_parameters.py +++ b/prefect_aws/client_parameters.py @@ -70,6 +70,18 @@ class AwsClientParameters(BaseModel): title="Botocore Config", ) + def __hash__(self): + return hash( + ( + self.api_version, + self.use_ssl, + self.verify, + self.verify_cert_path, + self.endpoint_url, + self.config, + ) + ) + @validator("config", pre=True) def instantiate_config(cls, value: Union[Config, Dict[str, Any]]) -> Dict[str, Any]: """ diff --git a/prefect_aws/credentials.py b/prefect_aws/credentials.py index 64f49efe..5aeddaa6 100644 --- a/prefect_aws/credentials.py +++ b/prefect_aws/credentials.py @@ -1,6 +1,8 @@ """Module handling AWS credentials""" from enum import Enum +from functools import lru_cache +from threading import Lock from typing import Any, Optional, Union import boto3 @@ -16,14 +18,43 @@ from prefect_aws.client_parameters import AwsClientParameters +_LOCK = Lock() + class ClientType(Enum): + """The supported boto3 clients.""" + S3 = "s3" ECS = "ecs" BATCH = "batch" SECRETS_MANAGER = "secretsmanager" +@lru_cache(maxsize=8, typed=True) +def _get_client_cached(ctx, client_type: Union[str, ClientType]) -> Any: + """ + Helper method to cache and dynamically get a client type. + + Args: + client_type: The client's service name. + + Returns: + An authenticated client. + + Raises: + ValueError: if the client is not supported. + """ + with _LOCK: + if isinstance(client_type, ClientType): + client_type = client_type.value + + client = ctx.get_boto3_session().client( + service_name=client_type, + **ctx.aws_client_parameters.get_params_override(), + ) + return client + + class AwsCredentials(CredentialsBlock): """ Block used to manage authentication with AWS. AWS authentication is @@ -75,6 +106,22 @@ class AwsCredentials(CredentialsBlock): title="AWS Client Parameters", ) + class Config: + """Config class for pydantic model.""" + + arbitrary_types_allowed = True + + def __hash__(self): + field_hashes = ( + hash(self.aws_access_key_id), + hash(self.aws_secret_access_key), + hash(self.aws_session_token), + hash(self.profile_name), + hash(self.region_name), + hash(frozenset(self.aws_client_parameters.dict().items())), + ) + return hash(field_hashes) + def get_boto3_session(self) -> boto3.Session: """ Returns an authenticated boto3 session that can be used to create clients @@ -104,7 +151,7 @@ def get_boto3_session(self) -> boto3.Session: region_name=self.region_name, ) - def get_client(self, client_type: Union[str, ClientType]) -> Any: + def get_client(self, client_type: Union[str, ClientType]): """ Helper method to dynamically get a client type. @@ -120,10 +167,7 @@ def get_client(self, client_type: Union[str, ClientType]) -> Any: if isinstance(client_type, ClientType): client_type = client_type.value - client = self.get_boto3_session().client( - service_name=client_type, **self.aws_client_parameters.get_params_override() - ) - return client + return _get_client_cached(ctx=self, client_type=client_type) def get_s3_client(self) -> S3Client: """ @@ -186,6 +230,21 @@ class MinIOCredentials(CredentialsBlock): description="Extra parameters to initialize the Client.", ) + class Config: + """Config class for pydantic model.""" + + arbitrary_types_allowed = True + + def __hash__(self): + return hash( + ( + hash(self.minio_root_user), + hash(self.minio_root_password), + hash(self.region_name), + hash(frozenset(self.aws_client_parameters.dict().items())), + ) + ) + def get_boto3_session(self) -> boto3.Session: """ Returns an authenticated boto3 session that can be used to create clients @@ -218,7 +277,7 @@ def get_boto3_session(self) -> boto3.Session: region_name=self.region_name, ) - def get_client(self, client_type: Union[str, ClientType]) -> Any: + def get_client(self, client_type: Union[str, ClientType]): """ Helper method to dynamically get a client type. @@ -234,10 +293,7 @@ def get_client(self, client_type: Union[str, ClientType]) -> Any: if isinstance(client_type, ClientType): client_type = client_type.value - client = self.get_boto3_session().client( - service_name=client_type, **self.aws_client_parameters.get_params_override() - ) - return client + return _get_client_cached(ctx=self, client_type=client_type) def get_s3_client(self) -> S3Client: """ diff --git a/prefect_aws/s3.py b/prefect_aws/s3.py index 643d78ac..a10e2171 100644 --- a/prefect_aws/s3.py +++ b/prefect_aws/s3.py @@ -466,7 +466,7 @@ def _get_s3_client(self) -> boto3.client: Authenticate MinIO credentials or AWS credentials and return an S3 client. This is a helper function called by read_path() or write_path(). """ - return self.credentials.get_s3_client() + return self.credentials.get_client("s3") def _get_bucket_resource(self) -> boto3.resource: """ diff --git a/tests/test_credentials.py b/tests/test_credentials.py index 6e0a1ff8..96ecbd22 100644 --- a/tests/test_credentials.py +++ b/tests/test_credentials.py @@ -3,7 +3,12 @@ from botocore.client import BaseClient from moto import mock_s3 -from prefect_aws.credentials import AwsCredentials, ClientType, MinIOCredentials +from prefect_aws.credentials import ( + AwsCredentials, + ClientType, + MinIOCredentials, + _get_client_cached, +) def test_aws_credentials_get_boto3_session(): @@ -44,3 +49,118 @@ def test_minio_credentials_get_boto3_session(): def test_credentials_get_client(credentials, client_type): with mock_s3(): assert isinstance(credentials.get_client(client_type), BaseClient) + + +@pytest.mark.parametrize( + "credentials", + [ + AwsCredentials(region_name="us-east-1"), + MinIOCredentials( + minio_root_user="root_user", + minio_root_password="root_password", + region_name="us-east-1", + ), + ], +) +@pytest.mark.parametrize("client_type", [member.value for member in ClientType]) +def test_get_client_cached(credentials, client_type): + """ + Test to ensure that _get_client_cached function returns the same instance + for multiple calls with the same parameters and properly utilizes lru_cache. + """ + + _get_client_cached.cache_clear() + + assert _get_client_cached.cache_info().hits == 0, "Initial call count should be 0" + + credentials.get_client(client_type) + credentials.get_client(client_type) + credentials.get_client(client_type) + + assert _get_client_cached.cache_info().misses == 1 + assert _get_client_cached.cache_info().hits == 2 + + +@pytest.mark.parametrize("client_type", [member.value for member in ClientType]) +def test_aws_credentials_change_causes_cache_miss(client_type): + """ + Test to ensure that changing configuration on an AwsCredentials instance + after fetching a client causes a cache miss. + """ + + _get_client_cached.cache_clear() + + credentials = AwsCredentials(region_name="us-east-1") + + initial_client = credentials.get_client(client_type) + + credentials.region_name = "us-west-2" + + new_client = credentials.get_client(client_type) + + assert ( + initial_client is not new_client + ), "Client should be different after configuration change" + + assert _get_client_cached.cache_info().misses == 2, "Cache should miss twice" + + +@pytest.mark.parametrize("client_type", [member.value for member in ClientType]) +def test_minio_credentials_change_causes_cache_miss(client_type): + """ + Test to ensure that changing configuration on an AwsCredentials instance + after fetching a client causes a cache miss. + """ + + _get_client_cached.cache_clear() + + credentials = MinIOCredentials( + minio_root_user="root_user", + minio_root_password="root_password", + region_name="us-east-1", + ) + + initial_client = credentials.get_client(client_type) + + credentials.region_name = "us-west-2" + + new_client = credentials.get_client(client_type) + + assert ( + initial_client is not new_client + ), "Client should be different after configuration change" + + assert _get_client_cached.cache_info().misses == 2, "Cache should miss twice" + + +@pytest.mark.parametrize( + "credentials_type, initial_field, new_field", + [ + ( + AwsCredentials, + {"region_name": "us-east-1"}, + {"region_name": "us-east-2"}, + ), + ( + MinIOCredentials, + { + "region_name": "us-east-1", + "minio_root_user": "root_user", + "minio_root_password": "root_password", + }, + { + "region_name": "us-east-2", + "minio_root_user": "root_user", + "minio_root_password": "root_password", + }, + ), + ], +) +def test_aws_credentials_hash_changes(credentials_type, initial_field, new_field): + credentials = credentials_type(**initial_field) + initial_hash = hash(credentials) + + setattr(credentials, list(new_field.keys())[0], list(new_field.values())[0]) + new_hash = hash(credentials) + + assert initial_hash != new_hash, "Hash should change when region_name changes"