diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index cb1465e9..4b31753b 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -15,7 +15,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v4 with: - python-version: 3.7 + python-version: 3.8 - name: Install packages run: | diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index a52bfdeb..c7a7acea 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -9,10 +9,10 @@ jobs: strategy: matrix: python-version: - "3.7" - "3.8" - "3.9" - "3.10" + - "3.11" fail-fast: false steps: - uses: actions/checkout@v4 @@ -33,9 +33,7 @@ jobs: env: PREFECT_SERVER_DATABASE_CONNECTION_URL: "sqlite+aiosqlite:///./collection-tests.db" run: | - prefect server database reset -y - coverage run --branch -m pytest tests -vv - coverage report + pytest --cov=prefect_aws --no-cov-on-fail --cov-report=term-missing --cov-branch tests -n auto -vv - name: Run mkdocs build run: | diff --git a/CHANGELOG.md b/CHANGELOG.md index 2cee4940..d4e01351 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,15 +9,54 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added -- Added retries to ECS task run creation for ECS worker - [#303](https://github.com/PrefectHQ/prefect-aws/pull/303) -- Added support to `ECSWorker` for `awsvpcConfiguration` [#304](https://github.com/PrefectHQ/prefect-aws/pull/304) - ### Changed -- Added a check to AwsSecretsManager.load() for 'SecretString' as well as 'SecretBinary' - [#274](https://github.com/PrefectHQ/prefect-aws/pull/274) +- Added 'SecretBinary' support to `AwsSecret` block - [#274](https://github.com/PrefectHQ/prefect-aws/pull/274) + +### Fixed + ### Deprecated ### Removed +## 0.4.2 + +Released November 6th, 2023. + +### Fixed + +- Fixed use_ssl default for s3 client. + +## 0.4.1 + +Released October 13th, 2023. 
+ +### Added + +- AWS S3 copy and move tasks and `S3Bucket` methods - [#316](https://github.com/PrefectHQ/prefect-aws/pull/316) + +### Fixed + +- `ECSWorker` issue where defining a custom network configuration with a subnet would erroneously report it as missing from the VPC when more than one subnet exists in the VPC. [#321](https://github.com/PrefectHQ/prefect-aws/pull/321) +- Updated `push_to_s3` and `pull_from_s3` deployment steps to properly create a boto3 session client if the passed credentials are a referenced `AwsCredentials` block [#322](https://github.com/PrefectHQ/prefect-aws/pull/322) + +## 0.4.0 + +Released October 5th, 2023. + +### Changed + +- Changed `push_to_s3` deployment step function to write paths `as_posix()` to allow support for deploying from windows [#314](https://github.com/PrefectHQ/prefect-aws/pull/314) +- Conditional imports to support operating with pydantic>2 installed - [#317](https://github.com/PrefectHQ/prefect-aws/pull/317) + +## 0.3.7 + +Released August 31st, 2023. + +### Added + +- Added retries to ECS task run creation for ECS worker - [#303](https://github.com/PrefectHQ/prefect-aws/pull/303) +- Added support to `ECSWorker` for `awsvpcConfiguration` [#304](https://github.com/PrefectHQ/prefect-aws/pull/304) + ## 0.3.6 Released July 20th, 2023. diff --git a/README.md b/README.md index 596765bd..952f8fd2 100644 --- a/README.md +++ b/README.md @@ -1,26 +1,24 @@ -# Incorporate AWS into your Prefect workflows with `prefect-aws` +# `prefect-aws`
## Welcome! -The `prefect-aws` collection makes it easy to leverage the capabilities of AWS in your flows, featuring support for ECS, S3, Secrets Manager, Batch Job, and Client Waiter. +`prefect-aws` makes it easy to leverage the capabilities of AWS in your flows, featuring support for ECS, S3, Secrets Manager, and Batch. Visit the full docs [here](https://PrefectHQ.github.io/prefect-aws). diff --git a/docs/img/favicon.ico b/docs/img/favicon.ico index c4b42158..159c4152 100644 Binary files a/docs/img/favicon.ico and b/docs/img/favicon.ico differ diff --git a/docs/img/prefect-logo-mark-solid-white-500.png b/docs/img/prefect-logo-mark-solid-white-500.png deleted file mode 100644 index f83aa6ef..00000000 Binary files a/docs/img/prefect-logo-mark-solid-white-500.png and /dev/null differ diff --git a/docs/img/prefect-logo-mark.png b/docs/img/prefect-logo-mark.png new file mode 100644 index 00000000..0d696821 Binary files /dev/null and b/docs/img/prefect-logo-mark.png differ diff --git a/docs/img/prefect-logo-white.png b/docs/img/prefect-logo-white.png deleted file mode 100644 index 50ca6139..00000000 Binary files a/docs/img/prefect-logo-white.png and /dev/null differ diff --git a/docs/index.md b/docs/index.md index fdb966ca..091b54e9 100644 --- a/docs/index.md +++ b/docs/index.md @@ -1,26 +1,24 @@ -# Coordinate and incorporate AWS in your dataflow with `prefect-aws` +# `prefect-aws` ## Welcome! -The `prefect-aws` collection makes it easy to leverage the capabilities of AWS in your flows, featuring support for ECSTask, S3, Secrets Manager, Batch Job, and Client Waiter. +`prefect-aws` makes it easy to leverage the capabilities of AWS in your flows, featuring support for ECSTask, S3, Secrets Manager, Batch Job, and Client Waiter. 
## Getting Started diff --git a/docs/stylesheets/extra.css b/docs/stylesheets/extra.css index 11a02095..662cca0c 100644 --- a/docs/stylesheets/extra.css +++ b/docs/stylesheets/extra.css @@ -1,9 +1,9 @@ /* theme */ :root > * { /* theme */ - --md-primary-fg-color: #115AF4; - --md-primary-fg-color--light: #115AF4; - --md-primary-fg-color--dark: #115AF4; + --md-primary-fg-color: #26272B; + --md-primary-fg-color--light: #26272B; + --md-primary-fg-color--dark: #26272B; } /* Table formatting */ @@ -72,7 +72,7 @@ to force column width */ /* dark mode slate theme */ /* dark mode code overrides */ [data-md-color-scheme="slate"] { - --md-code-bg-color: #252a33; + --md-code-bg-color: #1c1d20; --md-code-fg-color: #eee; --md-code-hl-color: #3b3d54; --md-code-hl-name-color: #eee; @@ -100,15 +100,15 @@ to force column width */ /* dark mode collection catalog overrides */ [data-md-color-scheme="slate"] .collection-item { - background-color: #3b3d54; + background-color: #26272B; } /* dark mode recipe collection overrides */ [data-md-color-scheme="slate"] .recipe-item { - background-color: #3b3d54; + background-color: #26272B; } /* dark mode API doc overrides */ [data-md-color-scheme="slate"] .prefect-table th { - background-color: #3b3d54; + background-color: #26272B; } \ No newline at end of file diff --git a/mkdocs.yml b/mkdocs.yml index 246fc31d..465f6407 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -28,7 +28,7 @@ theme: icon: repo: fontawesome/brands/github logo: - img/prefect-logo-mark-solid-white-500.png + img/prefect-logo-mark.png font: text: Inter code: Source Code Pro diff --git a/prefect_aws/client_parameters.py b/prefect_aws/client_parameters.py index 3f02b4b2..bf030590 100644 --- a/prefect_aws/client_parameters.py +++ b/prefect_aws/client_parameters.py @@ -5,7 +5,12 @@ from botocore import UNSIGNED from botocore.client import Config -from pydantic import BaseModel, Field, FilePath, root_validator, validator +from pydantic import VERSION as PYDANTIC_VERSION + +if 
PYDANTIC_VERSION.startswith("2."): + from pydantic.v1 import BaseModel, Field, FilePath, root_validator, validator +else: + from pydantic import BaseModel, Field, FilePath, root_validator, validator class AwsClientParameters(BaseModel): diff --git a/prefect_aws/credentials.py b/prefect_aws/credentials.py index f390faef..64f49efe 100644 --- a/prefect_aws/credentials.py +++ b/prefect_aws/credentials.py @@ -7,7 +7,12 @@ from mypy_boto3_s3 import S3Client from mypy_boto3_secretsmanager import SecretsManagerClient from prefect.blocks.abstract import CredentialsBlock -from pydantic import Field, SecretStr +from pydantic import VERSION as PYDANTIC_VERSION + +if PYDANTIC_VERSION.startswith("2."): + from pydantic.v1 import Field, SecretStr +else: + from pydantic import Field, SecretStr from prefect_aws.client_parameters import AwsClientParameters @@ -35,7 +40,7 @@ class AwsCredentials(CredentialsBlock): ``` """ # noqa E501 - _logo_url = "https://images.ctfassets.net/gm98wzqotmnx/1jbV4lceHOjGgunX15lUwT/db88e184d727f721575aeb054a37e277/aws.png?h=250" # noqa + _logo_url = "https://cdn.sanity.io/images/3ugk85nk/production/d74b16fe84ce626345adf235a47008fea2869a60-225x225.png" # noqa _block_type_name = "AWS Credentials" _documentation_url = "https://prefecthq.github.io/prefect-aws/credentials/#prefect_aws.credentials.AwsCredentials" # noqa @@ -159,7 +164,7 @@ class MinIOCredentials(CredentialsBlock): ``` """ # noqa E501 - _logo_url = "https://images.ctfassets.net/gm98wzqotmnx/22vXcxsOrVeFrUwHfSoaeT/7607b876eb589a9028c8126e78f4c7b4/imageedit_7_2837870043.png?h=250" # noqa + _logo_url = "https://cdn.sanity.io/images/3ugk85nk/production/676cb17bcbdff601f97e0a02ff8bcb480e91ff40-250x250.png" # noqa _block_type_name = "MinIO Credentials" _description = ( "Block used to manage authentication with MinIO. 
Refer to the MinIO " diff --git a/prefect_aws/deployments/steps.py b/prefect_aws/deployments/steps.py index 805b4648..7525a5e2 100644 --- a/prefect_aws/deployments/steps.py +++ b/prefect_aws/deployments/steps.py @@ -91,14 +91,7 @@ def push_to_s3( ``` """ - if credentials is None: - credentials = {} - if client_parameters is None: - client_parameters = {} - advanced_config = client_parameters.pop("config", {}) - client = boto3.client( - "s3", **credentials, **client_parameters, config=Config(**advanced_config) - ) + s3 = get_s3_client(credentials=credentials, client_parameters=client_parameters) local_path = Path.cwd() @@ -117,7 +110,9 @@ def push_to_s3( continue elif not local_file_path.is_dir(): remote_file_path = Path(folder) / local_file_path.relative_to(local_path) - client.upload_file(str(local_file_path), bucket, str(remote_file_path)) + s3.upload_file( + str(local_file_path), bucket, str(remote_file_path.as_posix()) + ) return { "bucket": bucket, @@ -172,14 +167,7 @@ def pull_from_s3( credentials: "{{ prefect.blocks.aws-credentials.dev-credentials }}" ``` """ - if credentials is None: - credentials = {} - if client_parameters is None: - client_parameters = {} - advanced_config = client_parameters.pop("config", {}) - - session = boto3.Session(**credentials) - s3 = session.client("s3", **client_parameters, config=Config(**advanced_config)) + s3 = get_s3_client(credentials=credentials, client_parameters=client_parameters) local_path = Path.cwd() @@ -204,3 +192,51 @@ def pull_from_s3( "folder": folder, "directory": str(local_path), } + + +def get_s3_client( + credentials: Optional[Dict] = None, + client_parameters: Optional[Dict] = None, +) -> dict: + if credentials is None: + credentials = {} + if client_parameters is None: + client_parameters = {} + + # Get credentials from credentials (regardless if block or not) + aws_access_key_id = credentials.get("aws_access_key_id", None) + aws_secret_access_key = credentials.get("aws_secret_access_key", None) + 
aws_session_token = credentials.get("aws_session_token", None) + + # Get remaining session info from credentials, or client_parameters + profile_name = credentials.get( + "profile_name", client_parameters.get("profile_name", None) + ) + region_name = credentials.get( + "region_name", client_parameters.get("region_name", None) + ) + + # Get additional info from client_parameters, otherwise credentials input (if block) + aws_client_parameters = credentials.get("aws_client_parameters", client_parameters) + api_version = aws_client_parameters.get("api_version", None) + endpoint_url = aws_client_parameters.get("endpoint_url", None) + use_ssl = aws_client_parameters.get("use_ssl", True) + verify = aws_client_parameters.get("verify", None) + config_params = aws_client_parameters.get("config", {}) + config = Config(**config_params) + + session = boto3.Session( + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + aws_session_token=aws_session_token, + profile_name=profile_name, + region_name=region_name, + ) + return session.client( + "s3", + api_version=api_version, + endpoint_url=endpoint_url, + use_ssl=use_ssl, + verify=verify, + config=config, + ) diff --git a/prefect_aws/ecs.py b/prefect_aws/ecs.py index a9534332..a6ebe206 100644 --- a/prefect_aws/ecs.py +++ b/prefect_aws/ecs.py @@ -121,7 +121,13 @@ from prefect.utilities.asyncutils import run_sync_in_worker_thread, sync_compatible from prefect.utilities.dockerutils import get_prefect_image_name from prefect.utilities.pydantic import JsonPatch -from pydantic import Field, root_validator, validator +from pydantic import VERSION as PYDANTIC_VERSION + +if PYDANTIC_VERSION.startswith("2."): + from pydantic.v1 import Field, root_validator, validator +else: + from pydantic import Field, root_validator, validator + from slugify import slugify from typing_extensions import Literal, Self @@ -270,7 +276,7 @@ class ECSTask(Infrastructure): _block_type_slug = "ecs-task" _block_type_name = "ECS 
@task
async def s3_copy(
    source_path: str,
    target_path: str,
    source_bucket_name: str,
    aws_credentials: AwsCredentials,
    target_bucket_name: Optional[str] = None,
    **copy_kwargs,
) -> str:
    """Uses S3's internal
    [CopyObject](https://docs.aws.amazon.com/AmazonS3/latest/API/API_CopyObject.html)
    to copy objects within or between buckets. To copy objects between buckets, the
    credentials must have permission to read the source object and write to the target
    object. If the credentials do not have those permissions, try using
    `S3Bucket.stream_from`.

    Args:
        source_path: The path to the object to copy. Can be a string or `Path`.
        target_path: The path to copy the object to. Can be a string or `Path`.
        source_bucket_name: The bucket to copy the object from.
        aws_credentials: Credentials to use for authentication with AWS.
        target_bucket_name: The bucket to copy the object to. If not provided, defaults
            to `source_bucket`.
        **copy_kwargs: Additional keyword arguments to pass to `S3Client.copy_object`.

    Returns:
        The path that the object was copied to. Excludes the bucket name.

    Example:
        Copy notes.txt from s3://my-bucket/my_folder/notes.txt to
        s3://my-bucket/my_folder/notes_copy.txt.

        ```python
        from prefect import flow
        from prefect_aws import AwsCredentials
        from prefect_aws.s3 import s3_copy

        aws_credentials = AwsCredentials.load("my-creds")

        @flow
        async def example_copy_flow():
            await s3_copy(
                source_path="my_folder/notes.txt",
                target_path="my_folder/notes_copy.txt",
                source_bucket_name="my-bucket",
                aws_credentials=aws_credentials,
            )

        example_copy_flow()
        ```
    """
    logger = get_run_logger()

    s3_client = aws_credentials.get_s3_client()

    # Default to an in-bucket copy when no target bucket is given.
    target_bucket_name = target_bucket_name or source_bucket_name

    logger.info(
        "Copying object from bucket %s with key %s to bucket %s with key %s",
        source_bucket_name,
        source_path,
        target_bucket_name,
        target_path,
    )

    s3_client.copy_object(
        CopySource={"Bucket": source_bucket_name, "Key": source_path},
        Bucket=target_bucket_name,
        Key=target_path,
        **copy_kwargs,
    )

    return target_path


@task
async def s3_move(
    source_path: str,
    target_path: str,
    source_bucket_name: str,
    aws_credentials: AwsCredentials,
    target_bucket_name: Optional[str] = None,
) -> str:
    """
    Move an object from one S3 location to another. To move objects between buckets,
    the credentials must have permission to read and delete the source object and write
    to the target object. If the credentials do not have those permissions, this method
    will raise an error. If the credentials have permission to read the source object
    but not delete it, the object will be copied but not deleted.

    Args:
        source_path: The path of the object to move
        target_path: The path to move the object to
        source_bucket_name: The name of the bucket containing the source object
        aws_credentials: Credentials to use for authentication with AWS.
        target_bucket_name: The bucket to copy the object to. If not provided, defaults
            to `source_bucket`.

    Returns:
        The path that the object was moved to. Excludes the bucket name.
    """
    logger = get_run_logger()

    s3_client = aws_credentials.get_s3_client()

    # If target bucket is not provided, assume it's the same as the source bucket
    target_bucket_name = target_bucket_name or source_bucket_name

    # Fixed: the original format string read "from s3://%s/%s s3://%s/%s",
    # dropping the word "to" between source and destination.
    logger.info(
        "Moving object from s3://%s/%s to s3://%s/%s",
        source_bucket_name,
        source_path,
        target_bucket_name,
        target_path,
    )

    # Copy the object to the new location; raises before the delete below
    # if the copy is not permitted, so the source is never lost.
    s3_client.copy_object(
        Bucket=target_bucket_name,
        CopySource={"Bucket": source_bucket_name, "Key": source_path},
        Key=target_path,
    )

    # Delete the original object only after a successful copy.
    s3_client.delete_object(Bucket=source_bucket_name, Key=source_path)

    return target_path
""" - _logo_url = "https://images.ctfassets.net/gm98wzqotmnx/1jbV4lceHOjGgunX15lUwT/db88e184d727f721575aeb054a37e277/aws.png?h=250" # noqa + _logo_url = "https://cdn.sanity.io/images/3ugk85nk/production/d74b16fe84ce626345adf235a47008fea2869a60-225x225.png" # noqa _block_type_name = "S3 Bucket" _documentation_url = ( "https://prefecthq.github.io/prefect-aws/s3/#prefect_aws.s3.S3Bucket" # noqa @@ -795,7 +954,9 @@ async def stream_from( to_path: Optional[str] = None, **upload_kwargs: Dict[str, Any], ) -> str: - """Streams an object from another bucket to this bucket. + """Streams an object from another bucket to this bucket. Requires the + object to be downloaded and uploaded in chunks. If `self`'s credentials + allow for writes to the other bucket, try using `S3Bucket.copy_object`. Args: bucket: The bucket to stream from. @@ -1023,3 +1184,170 @@ async def upload_from_folder( ) return to_folder + + @sync_compatible + async def copy_object( + self, + from_path: Union[str, Path], + to_path: Union[str, Path], + to_bucket: Optional[Union["S3Bucket", str]] = None, + **copy_kwargs, + ) -> str: + """Uses S3's internal + [CopyObject](https://docs.aws.amazon.com/AmazonS3/latest/API/API_CopyObject.html) + to copy objects within or between buckets. To copy objects between buckets, + `self`'s credentials must have permission to read the source object and write + to the target object. If the credentials do not have those permissions, try + using `S3Bucket.stream_from`. + + Args: + from_path: The path of the object to copy. + to_path: The path to copy the object to. + to_bucket: The bucket to copy to. Defaults to the current bucket. + **copy_kwargs: Additional keyword arguments to pass to + `S3Client.copy_object`. + + Returns: + The path that the object was copied to. Excludes the bucket name. + + Examples: + + Copy notes.txt from my_folder/notes.txt to my_folder/notes_copy.txt. 
+ + ```python + from prefect_aws.s3 import S3Bucket + + s3_bucket = S3Bucket.load("my-bucket") + s3_bucket.copy_object("my_folder/notes.txt", "my_folder/notes_copy.txt") + ``` + + Copy notes.txt from my_folder/notes.txt to my_folder/notes_copy.txt in + another bucket. + + ```python + from prefect_aws.s3 import S3Bucket + + s3_bucket = S3Bucket.load("my-bucket") + s3_bucket.copy_object( + "my_folder/notes.txt", + "my_folder/notes_copy.txt", + to_bucket="other-bucket" + ) + ``` + """ + s3_client = self.credentials.get_s3_client() + + source_path = self._resolve_path(Path(from_path).as_posix()) + target_path = self._resolve_path(Path(to_path).as_posix()) + + source_bucket_name = self.bucket_name + target_bucket_name = self.bucket_name + if isinstance(to_bucket, S3Bucket): + target_bucket_name = to_bucket.bucket_name + target_path = to_bucket._resolve_path(target_path) + elif isinstance(to_bucket, str): + target_bucket_name = to_bucket + elif to_bucket is not None: + raise TypeError( + "to_bucket must be a string or S3Bucket, not" + f" {type(target_bucket_name)}" + ) + + self.logger.info( + "Copying object from bucket %s with key %s to bucket %s with key %s", + source_bucket_name, + source_path, + target_bucket_name, + target_path, + ) + + s3_client.copy_object( + CopySource={"Bucket": source_bucket_name, "Key": source_path}, + Bucket=target_bucket_name, + Key=target_path, + **copy_kwargs, + ) + + return target_path + + @sync_compatible + async def move_object( + self, + from_path: Union[str, Path], + to_path: Union[str, Path], + to_bucket: Optional[Union["S3Bucket", str]] = None, + ) -> str: + """Uses S3's internal CopyObject and DeleteObject to move objects within or + between buckets. To move objects between buckets, `self`'s credentials must + have permission to read and delete the source object and write to the target + object. If the credentials do not have those permissions, this method will + raise an error. 
If the credentials have permission to read the source object + but not delete it, the object will be copied but not deleted. + + Args: + from_path: The path of the object to move. + to_path: The path to move the object to. + to_bucket: The bucket to move to. Defaults to the current bucket. + + Returns: + The path that the object was moved to. Excludes the bucket name. + + Examples: + + Move notes.txt from my_folder/notes.txt to my_folder/notes_copy.txt. + + ```python + from prefect_aws.s3 import S3Bucket + + s3_bucket = S3Bucket.load("my-bucket") + s3_bucket.move_object("my_folder/notes.txt", "my_folder/notes_copy.txt") + ``` + + Move notes.txt from my_folder/notes.txt to my_folder/notes_copy.txt in + another bucket. + + ```python + from prefect_aws.s3 import S3Bucket + + s3_bucket = S3Bucket.load("my-bucket") + s3_bucket.move_object( + "my_folder/notes.txt", + "my_folder/notes_copy.txt", + to_bucket="other-bucket" + ) + ``` + """ + s3_client = self.credentials.get_s3_client() + + source_path = self._resolve_path(Path(from_path).as_posix()) + target_path = self._resolve_path(Path(to_path).as_posix()) + + source_bucket_name = self.bucket_name + target_bucket_name = self.bucket_name + if isinstance(to_bucket, S3Bucket): + target_bucket_name = to_bucket.bucket_name + target_path = to_bucket._resolve_path(target_path) + elif isinstance(to_bucket, str): + target_bucket_name = to_bucket + elif to_bucket is not None: + raise TypeError( + "to_bucket must be a string or S3Bucket, not" + f" {type(target_bucket_name)}" + ) + + self.logger.info( + "Moving object from s3://%s/%s to s3://%s/%s", + source_bucket_name, + source_path, + target_bucket_name, + target_path, + ) + + # If invalid, should error and prevent next operation + s3_client.copy( + CopySource={"Bucket": source_bucket_name, "Key": source_path}, + Bucket=target_bucket_name, + Key=target_path, + ) + s3_client.delete_object(Bucket=source_bucket_name, Key=source_path) + return target_path diff --git 
a/prefect_aws/secrets_manager.py b/prefect_aws/secrets_manager.py index 2892a0eb..c043f28c 100644 --- a/prefect_aws/secrets_manager.py +++ b/prefect_aws/secrets_manager.py @@ -5,7 +5,12 @@ from prefect import get_run_logger, task from prefect.blocks.abstract import SecretBlock from prefect.utilities.asyncutils import run_sync_in_worker_thread, sync_compatible -from pydantic import Field +from pydantic import VERSION as PYDANTIC_VERSION + +if PYDANTIC_VERSION.startswith("2."): + from pydantic.v1 import Field +else: + from pydantic import Field from prefect_aws import AwsCredentials @@ -363,7 +368,7 @@ class AwsSecret(SecretBlock): secret_name: The name of the secret. """ - _logo_url = "https://images.ctfassets.net/gm98wzqotmnx/1jbV4lceHOjGgunX15lUwT/db88e184d727f721575aeb054a37e277/aws.png?h=250" # noqa + _logo_url = "https://cdn.sanity.io/images/3ugk85nk/production/d74b16fe84ce626345adf235a47008fea2869a60-225x225.png" # noqa _block_type_name = "AWS Secret" _documentation_url = "https://prefecthq.github.io/prefect-aws/secrets_manager/#prefect_aws.secrets_manager.AwsSecret" # noqa diff --git a/prefect_aws/workers/ecs_worker.py b/prefect_aws/workers/ecs_worker.py index f4d97bef..02c4117c 100644 --- a/prefect_aws/workers/ecs_worker.py +++ b/prefect_aws/workers/ecs_worker.py @@ -68,7 +68,13 @@ BaseWorker, BaseWorkerResult, ) -from pydantic import Field, root_validator +from pydantic import VERSION as PYDANTIC_VERSION + +if PYDANTIC_VERSION.startswith("2."): + from pydantic.v1 import Field, root_validator +else: + from pydantic import Field, root_validator + from slugify import slugify from tenacity import retry, stop_after_attempt, wait_fixed, wait_random from typing_extensions import Literal @@ -542,7 +548,7 @@ class ECSWorker(BaseWorker): ) _display_name = "AWS Elastic Container Service" _documentation_url = "https://prefecthq.github.io/prefect-aws/ecs_worker/" - _logo_url = 
"https://images.ctfassets.net/gm98wzqotmnx/1jbV4lceHOjGgunX15lUwT/db88e184d727f721575aeb054a37e277/aws.png?h=250" # noqa + _logo_url = "https://cdn.sanity.io/images/3ugk85nk/production/d74b16fe84ce626345adf235a47008fea2869a60-225x225.png" # noqa async def run( self, @@ -1343,10 +1349,10 @@ def _custom_network_configuration( + "Network configuration cannot be inferred." ) + subnet_ids = [subnet["SubnetId"] for subnet in subnets] + config_subnets = network_configuration.get("subnets", []) - if not all( - [conf_sn in sn.values() for conf_sn in config_subnets for sn in subnets] - ): + if not all(conf_sn in subnet_ids for conf_sn in config_subnets): raise ValueError( f"Subnets {config_subnets} not found within {vpc_message}." + "Please check that VPC is associated with supplied subnets." diff --git a/requirements-dev.txt b/requirements-dev.txt index 9003b759..bcbdc906 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -9,12 +9,15 @@ mkdocs-gen-files mkdocs-material mkdocstrings-python-legacy mock; python_version < '3.8' -moto >= 3.1.16 +# moto 4.2.5 broke something fairly deep in our test suite +# https://github.com/PrefectHQ/prefect-aws/issues/318 +moto >= 3.1.16, < 4.2.5 mypy pillow pre-commit pytest pytest-asyncio +pytest-cov pytest-lazy-fixture pytest-xdist types-boto3 >= 1.0.2 diff --git a/requirements.txt b/requirements.txt index 40b6007d..919ce567 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,5 +2,5 @@ boto3>=1.24.53 botocore>=1.27.53 mypy_boto3_s3>=1.24.94 mypy_boto3_secretsmanager>=1.26.49 -prefect>=2.10.11 +prefect>=2.13.5 tenacity>=8.0.0 \ No newline at end of file diff --git a/tests/deploments/test_steps.py b/tests/deploments/test_steps.py index 5e2d2efd..22608bd7 100644 --- a/tests/deploments/test_steps.py +++ b/tests/deploments/test_steps.py @@ -1,11 +1,14 @@ import os +import sys from pathlib import Path, PurePath, PurePosixPath +from unittest.mock import patch import boto3 import pytest from moto import mock_s3 -from 
prefect_aws.deployments.steps import pull_from_s3, push_to_s3 +from prefect_aws import AwsCredentials +from prefect_aws.deployments.steps import get_s3_client, pull_from_s3, push_to_s3 @pytest.fixture @@ -40,6 +43,24 @@ def tmp_files(tmp_path: Path): return tmp_path +@pytest.fixture +def tmp_files_win(tmp_path: Path): + files = [ + "testfile1.txt", + "testfile2.txt", + "testfile3.txt", + r"testdir1\testfile4.txt", + r"testdir2\testfile5.txt", + ] + + for file in files: + filepath = tmp_path / file + filepath.parent.mkdir(parents=True, exist_ok=True) + filepath.write_text("Sample text") + + return tmp_path + + @pytest.fixture def mock_aws_credentials(monkeypatch): # Set mock environment variables for AWS credentials @@ -77,6 +98,29 @@ def test_push_to_s3(s3_setup, tmp_files, mock_aws_credentials): assert set(object_keys) == set(expected_keys) +@pytest.mark.skipif(sys.platform != "win32", reason="requires Windows") +def test_push_to_s3_as_posix(s3_setup, tmp_files_win, mock_aws_credentials): + s3, bucket_name = s3_setup + folder = "my-project" + + os.chdir(tmp_files_win) + + push_to_s3(bucket_name, folder) + + s3_objects = s3.list_objects_v2(Bucket=bucket_name) + object_keys = [item["Key"] for item in s3_objects["Contents"]] + + expected_keys = [ + f"{folder}/testfile1.txt", + f"{folder}/testfile2.txt", + f"{folder}/testfile3.txt", + f"{folder}/testdir1/testfile4.txt", + f"{folder}/testdir2/testfile5.txt", + ] + + assert set(object_keys) == set(expected_keys) + + def test_pull_from_s3(s3_setup, tmp_path, mock_aws_credentials): s3, bucket_name = s3_setup folder = "my-project" @@ -131,6 +175,98 @@ def test_push_pull_empty_folders(s3_setup, tmp_path, mock_aws_credentials): assert not (tmp_path / "empty2_copy").exists() +@pytest.mark.skipif(sys.version_info < (3, 8), reason="requires Python 3.8+") +def test_s3_session_with_params(): + with patch("boto3.Session") as mock_session: + get_s3_client( + credentials={ + "aws_access_key_id": "THE_KEY", + "aws_secret_access_key": 
"SHHH!", + "profile_name": "foo", + "region_name": "us-weast-1", + "aws_client_parameters": { + "api_version": "v1", + "config": {"connect_timeout": 300}, + }, + } + ) + get_s3_client( + credentials={ + "aws_access_key_id": "THE_KEY", + "aws_secret_access_key": "SHHH!", + }, + client_parameters={ + "region_name": "us-west-1", + "config": {"signature_version": "s3v4"}, + }, + ) + creds_block = AwsCredentials( + aws_access_key_id="BlockKey", + aws_secret_access_key="BlockSecret", + aws_session_token="BlockToken", + profile_name="BlockProfile", + region_name="BlockRegion", + aws_client_parameters={ + "api_version": "v1", + "use_ssl": True, + "verify": True, + "endpoint_url": "BlockEndpoint", + "config": {"connect_timeout": 123}, + }, + ) + get_s3_client(credentials=creds_block.dict()) + all_calls = mock_session.mock_calls + assert len(all_calls) == 6 + assert all_calls[0].kwargs == { + "aws_access_key_id": "THE_KEY", + "aws_secret_access_key": "SHHH!", + "aws_session_token": None, + "profile_name": "foo", + "region_name": "us-weast-1", + } + assert all_calls[1].args[0] == "s3" + assert { + "api_version": "v1", + "endpoint_url": None, + "use_ssl": True, + "verify": None, + }.items() <= all_calls[1].kwargs.items() + assert all_calls[1].kwargs.get("config").connect_timeout == 300 + assert all_calls[1].kwargs.get("config").signature_version is None + assert all_calls[2].kwargs == { + "aws_access_key_id": "THE_KEY", + "aws_secret_access_key": "SHHH!", + "aws_session_token": None, + "profile_name": None, + "region_name": "us-west-1", + } + assert all_calls[3].args[0] == "s3" + assert { + "api_version": None, + "endpoint_url": None, + "use_ssl": True, + "verify": None, + }.items() <= all_calls[3].kwargs.items() + assert all_calls[3].kwargs.get("config").connect_timeout == 60 + assert all_calls[3].kwargs.get("config").signature_version == "s3v4" + assert all_calls[4].kwargs == { + "aws_access_key_id": "BlockKey", + "aws_secret_access_key": creds_block.aws_secret_access_key, + 
"aws_session_token": "BlockToken", + "profile_name": "BlockProfile", + "region_name": "BlockRegion", + } + assert all_calls[5].args[0] == "s3" + assert { + "api_version": "v1", + "use_ssl": True, + "verify": True, + "endpoint_url": "BlockEndpoint", + }.items() <= all_calls[5].kwargs.items() + assert all_calls[5].kwargs.get("config").connect_timeout == 123 + assert all_calls[5].kwargs.get("config").signature_version is None + + def test_custom_credentials_and_client_parameters(s3_setup, tmp_files): s3, bucket_name = s3_setup folder = "my-project" diff --git a/tests/test_ecs.py b/tests/test_ecs.py index 9282a859..cf18bfe4 100644 --- a/tests/test_ecs.py +++ b/tests/test_ecs.py @@ -16,7 +16,12 @@ from prefect.server.schemas.core import Deployment, Flow, FlowRun from prefect.utilities.asyncutils import run_sync_in_worker_thread from prefect.utilities.dockerutils import get_prefect_image_name -from pydantic import ValidationError +from pydantic import VERSION as PYDANTIC_VERSION + +if PYDANTIC_VERSION.startswith("2."): + from pydantic.v1 import ValidationError +else: + from pydantic import ValidationError from prefect_aws.ecs import ( ECS_DEFAULT_CPU, @@ -1363,7 +1368,7 @@ async def test_latest_task_definition_not_used_if_inequal( # {"execution_role_arn": "test"}, # {"launch_type": "EXTERNAL"}, ], - ids=lambda item: str(set(item.keys())), + ids=lambda item: str(sorted(list(set(item.keys())))), ) async def test_latest_task_definition_with_overrides_that_do_not_require_copy( aws_credentials, overrides, launch_type @@ -1531,7 +1536,7 @@ async def test_task_definition_arn_with_overrides_requiring_copy_shows_diff( # from the base task definition {"env": {"FOO": None}}, ], - ids=lambda item: str(set(item.keys())), + ids=lambda item: str(sorted(list(set(item.keys())))), ) async def test_task_definition_arn_with_overrides_that_do_not_require_copy( aws_credentials, overrides diff --git a/tests/test_s3.py b/tests/test_s3.py index c0159e25..89a39f7d 100644 --- a/tests/test_s3.py 
+++ b/tests/test_s3.py @@ -12,7 +12,14 @@ from prefect_aws import AwsCredentials, MinIOCredentials from prefect_aws.client_parameters import AwsClientParameters -from prefect_aws.s3 import S3Bucket, s3_download, s3_list_objects, s3_upload +from prefect_aws.s3 import ( + S3Bucket, + s3_copy, + s3_download, + s3_list_objects, + s3_move, + s3_upload, +) aws_clients = [ (lazy_fixture("aws_client_parameters_custom_endpoint")), @@ -47,6 +54,18 @@ def bucket(s3_mock, request): return bucket +@pytest.fixture +def bucket_2(s3_mock, request): + s3 = boto3.resource("s3") + bucket = s3.Bucket("bucket_2") + marker = request.node.get_closest_marker("is_public", None) + if marker and marker.args[0]: + bucket.create(ACL="public-read") + else: + bucket.create() + return bucket + + @pytest.fixture def object(bucket, tmp_path): file = tmp_path / "object.txt" @@ -205,6 +224,151 @@ async def test_flow(): assert output == b"NEW OBJECT" +@pytest.mark.parametrize("client_parameters", aws_clients, indirect=True) +async def test_s3_copy(object, bucket, bucket_2, aws_credentials): + def read(bucket, key): + stream = io.BytesIO() + bucket.download_fileobj(key, stream) + stream.seek(0) + return stream.read() + + @flow + async def test_flow(): + # Test cross-bucket copy + await s3_copy( + source_path="object", + target_path="subfolder/new_object", + source_bucket_name="bucket", + aws_credentials=aws_credentials, + target_bucket_name="bucket_2", + ) + + # Test within-bucket copy + await s3_copy( + source_path="object", + target_path="subfolder/new_object", + source_bucket_name="bucket", + aws_credentials=aws_credentials, + ) + + await test_flow() + assert read(bucket_2, "subfolder/new_object") == b"TEST" + assert read(bucket, "subfolder/new_object") == b"TEST" + + +@pytest.mark.parametrize("client_parameters", aws_clients, indirect=True) +async def test_s3_move(object, bucket, bucket_2, aws_credentials): + def read(bucket, key): + stream = io.BytesIO() + bucket.download_fileobj(key, stream) + 
stream.seek(0) + return stream.read() + + @flow + async def test_flow(): + # Test within-bucket move + await s3_move( + source_path="object", + target_path="subfolder/object_copy", + source_bucket_name="bucket", + aws_credentials=aws_credentials, + ) + + # Test cross-bucket move + await s3_move( + source_path="subfolder/object_copy", + target_path="object_copy_2", + source_bucket_name="bucket", + target_bucket_name="bucket_2", + aws_credentials=aws_credentials, + ) + + await test_flow() + + assert read(bucket_2, "object_copy_2") == b"TEST" + + with pytest.raises(ClientError): + read(bucket, "object") + + with pytest.raises(ClientError): + read(bucket, "subfolder/object_copy") + + +@pytest.mark.parametrize("client_parameters", aws_clients, indirect=True) +async def test_move_object_to_nonexistent_bucket_fails( + object, + bucket, + aws_credentials, +): + def read(bucket, key): + stream = io.BytesIO() + bucket.download_fileobj(key, stream) + stream.seek(0) + return stream.read() + + @flow + async def test_flow(): + # Test cross-bucket move + await s3_move( + source_path="object", + target_path="subfolder/new_object", + source_bucket_name="bucket", + aws_credentials=aws_credentials, + target_bucket_name="nonexistent-bucket", + ) + + with pytest.raises(ClientError): + await test_flow() + + assert read(bucket, "object") == b"TEST" + + +@pytest.mark.parametrize("client_parameters", aws_clients, indirect=True) +async def test_move_object_fail_cases( + object, + bucket, + aws_credentials, +): + def read(bucket, key): + stream = io.BytesIO() + bucket.download_fileobj(key, stream) + stream.seek(0) + return stream.read() + + @flow + async def test_flow( + source_path, target_path, source_bucket_name, target_bucket_name + ): + # Test cross-bucket move + await s3_move( + source_path=source_path, + target_path=target_path, + source_bucket_name=source_bucket_name, + aws_credentials=aws_credentials, + target_bucket_name=target_bucket_name, + ) + + # Move to non-existent bucket + 
with pytest.raises(ClientError): + await test_flow( + source_path="object", + target_path="subfolder/new_object", + source_bucket_name="bucket", + target_bucket_name="nonexistent-bucket", + ) + assert read(bucket, "object") == b"TEST" + + # Move onto self + with pytest.raises(ClientError): + await test_flow( + source_path="object", + target_path="object", + source_bucket_name="bucket", + target_bucket_name="bucket", + ) + assert read(bucket, "object") == b"TEST" + + @pytest.mark.parametrize("client_parameters", aws_clients, indirect=True) async def test_s3_list_objects( object, client_parameters, object_in_folder, aws_credentials @@ -623,9 +787,9 @@ def s3_bucket_empty(self, credentials, bucket): return _s3_bucket @pytest.fixture - def s3_bucket_2_empty(self, credentials, bucket): + def s3_bucket_2_empty(self, credentials, bucket_2): _s3_bucket = S3Bucket( - bucket_name="bucket", + bucket_name="bucket_2", credentials=credentials, bucket_folder="subfolder", ) @@ -811,3 +975,71 @@ def test_upload_from_folder( break else: raise AssertionError("Files did upload") + + @pytest.mark.parametrize("client_parameters", aws_clients[-1:], indirect=True) + def test_copy_object( + self, + s3_bucket_with_object: S3Bucket, + s3_bucket_2_empty: S3Bucket, + ): + s3_bucket_with_object.copy_object("object", "object_copy_1") + assert s3_bucket_with_object.read_path("object_copy_1") == b"TEST" + + s3_bucket_with_object.copy_object("object", "folder/object_copy_2") + assert s3_bucket_with_object.read_path("folder/object_copy_2") == b"TEST" + + # S3Bucket for second bucket has a basepath + s3_bucket_with_object.copy_object( + "object", + s3_bucket_2_empty._resolve_path("object_copy_3"), + to_bucket="bucket_2", + ) + assert s3_bucket_2_empty.read_path("object_copy_3") == b"TEST" + + s3_bucket_with_object.copy_object("object", "object_copy_4", s3_bucket_2_empty) + assert s3_bucket_2_empty.read_path("object_copy_4") == b"TEST" + + @pytest.mark.parametrize("client_parameters", 
aws_clients[-1:], indirect=True) + def test_move_object_within_bucket( + self, + s3_bucket_with_object: S3Bucket, + ): + s3_bucket_with_object.move_object("object", "object_copy_1") + assert s3_bucket_with_object.read_path("object_copy_1") == b"TEST" + + with pytest.raises(ClientError): + assert s3_bucket_with_object.read_path("object") == b"TEST" + + @pytest.mark.parametrize("client_parameters", aws_clients[-1:], indirect=True) + def test_move_object_to_nonexistent_bucket_fails( + self, + s3_bucket_with_object: S3Bucket, + ): + with pytest.raises(ClientError): + s3_bucket_with_object.move_object( + "object", "object_copy_1", to_bucket="nonexistent-bucket" + ) + assert s3_bucket_with_object.read_path("object") == b"TEST" + + @pytest.mark.parametrize("client_parameters", aws_clients[-1:], indirect=True) + def test_move_object_onto_itself_fails( + self, + s3_bucket_with_object: S3Bucket, + ): + with pytest.raises(ClientError): + s3_bucket_with_object.move_object("object", "object") + assert s3_bucket_with_object.read_path("object") == b"TEST" + + @pytest.mark.parametrize("client_parameters", aws_clients[-1:], indirect=True) + def test_move_object_between_buckets( + self, + s3_bucket_with_object: S3Bucket, + s3_bucket_2_empty: S3Bucket, + ): + s3_bucket_with_object.move_object( + "object", "object_copy_1", to_bucket=s3_bucket_2_empty + ) + assert s3_bucket_2_empty.read_path("object_copy_1") == b"TEST" + + with pytest.raises(ClientError): + assert s3_bucket_with_object.read_path("object") == b"TEST" diff --git a/tests/workers/test_ecs_worker.py b/tests/workers/test_ecs_worker.py index ef57372b..b6a39b35 100644 --- a/tests/workers/test_ecs_worker.py +++ b/tests/workers/test_ecs_worker.py @@ -12,7 +12,13 @@ from moto.ec2.utils import generate_instance_identity_document from prefect.server.schemas.core import FlowRun from prefect.utilities.asyncutils import run_sync_in_worker_thread -from pydantic import ValidationError +from pydantic import VERSION as PYDANTIC_VERSION + 
+if PYDANTIC_VERSION.startswith("2."): + from pydantic.v1 import ValidationError +else: + from pydantic import ValidationError + from tenacity import RetryError from prefect_aws.workers.ecs_worker import ( @@ -886,13 +892,62 @@ async def test_network_config_from_vpc_id( @pytest.mark.usefixtures("ecs_mocks") -async def test_network_config_from_custom_settings( +async def test_network_config_1_subnet_in_custom_settings_1_in_vpc( + aws_credentials: AwsCredentials, flow_run: FlowRun +): + session = aws_credentials.get_boto3_session() + ec2_resource = session.resource("ec2") + vpc = ec2_resource.create_vpc(CidrBlock="10.0.0.0/16") + subnet = ec2_resource.create_subnet(CidrBlock="10.0.2.0/24", VpcId=vpc.id) + security_group = ec2_resource.create_security_group( + GroupName="ECSWorkerTestSG", Description="ECS Worker test SG", VpcId=vpc.id + ) + + configuration = await construct_configuration( + aws_credentials=aws_credentials, + vpc_id=vpc.id, + override_network_configuration=True, + network_configuration={ + "subnets": [subnet.id], + "assignPublicIp": "DISABLED", + "securityGroups": [security_group.id], + }, + ) + + session = aws_credentials.get_boto3_session() + + async with ECSWorker(work_pool_name="test") as worker: + # Capture the task run call because moto does not track 'networkConfiguration' + original_run_task = worker._create_task_run + mock_run_task = MagicMock(side_effect=original_run_task) + worker._create_task_run = mock_run_task + + result = await run_then_stop_task(worker, configuration, flow_run) + + assert result.status_code == 0 + network_configuration = mock_run_task.call_args[0][1].get("networkConfiguration") + + # Subnet ids are copied from the vpc + assert network_configuration == { + "awsvpcConfiguration": { + "subnets": [subnet.id], + "assignPublicIp": "DISABLED", + "securityGroups": [security_group.id], + } + } + + +@pytest.mark.usefixtures("ecs_mocks") +async def test_network_config_1_sn_in_custom_settings_many_in_vpc( aws_credentials: 
AwsCredentials, flow_run: FlowRun ): session = aws_credentials.get_boto3_session() ec2_resource = session.resource("ec2") vpc = ec2_resource.create_vpc(CidrBlock="10.0.0.0/16") subnet = ec2_resource.create_subnet(CidrBlock="10.0.2.0/24", VpcId=vpc.id) + ec2_resource.create_subnet(CidrBlock="10.0.3.0/24", VpcId=vpc.id) + ec2_resource.create_subnet(CidrBlock="10.0.4.0/24", VpcId=vpc.id) + security_group = ec2_resource.create_security_group( GroupName="ECSWorkerTestSG", Description="ECS Worker test SG", VpcId=vpc.id ) @@ -931,6 +986,58 @@ async def test_network_config_from_custom_settings( } +@pytest.mark.usefixtures("ecs_mocks") +async def test_network_config_many_subnet_in_custom_settings_many_in_vpc( + aws_credentials: AwsCredentials, flow_run: FlowRun +): + session = aws_credentials.get_boto3_session() + ec2_resource = session.resource("ec2") + vpc = ec2_resource.create_vpc(CidrBlock="10.0.0.0/16") + subnets = [ + ec2_resource.create_subnet(CidrBlock="10.0.2.0/24", VpcId=vpc.id), + ec2_resource.create_subnet(CidrBlock="10.0.33.0/24", VpcId=vpc.id), + ec2_resource.create_subnet(CidrBlock="10.0.44.0/24", VpcId=vpc.id), + ] + subnet_ids = [subnet.id for subnet in subnets] + + security_group = ec2_resource.create_security_group( + GroupName="ECSWorkerTestSG", Description="ECS Worker test SG", VpcId=vpc.id + ) + + configuration = await construct_configuration( + aws_credentials=aws_credentials, + vpc_id=vpc.id, + override_network_configuration=True, + network_configuration={ + "subnets": subnet_ids, + "assignPublicIp": "DISABLED", + "securityGroups": [security_group.id], + }, + ) + + session = aws_credentials.get_boto3_session() + + async with ECSWorker(work_pool_name="test") as worker: + # Capture the task run call because moto does not track 'networkConfiguration' + original_run_task = worker._create_task_run + mock_run_task = MagicMock(side_effect=original_run_task) + worker._create_task_run = mock_run_task + + result = await run_then_stop_task(worker, 
configuration, flow_run) + + assert result.status_code == 0 + network_configuration = mock_run_task.call_args[0][1].get("networkConfiguration") + + # Subnet ids are copied from the vpc + assert network_configuration == { + "awsvpcConfiguration": { + "subnets": subnet_ids, + "assignPublicIp": "DISABLED", + "securityGroups": [security_group.id], + } + } + + @pytest.mark.usefixtures("ecs_mocks") async def test_network_config_from_custom_settings_invalid_subnet( aws_credentials: AwsCredentials, flow_run: FlowRun @@ -972,6 +1079,48 @@ async def test_network_config_from_custom_settings_invalid_subnet( await run_then_stop_task(worker, configuration, flow_run) +@pytest.mark.usefixtures("ecs_mocks") +async def test_network_config_from_custom_settings_invalid_subnet_multiple_vpc_subnets( + aws_credentials: AwsCredentials, flow_run: FlowRun +): + session = aws_credentials.get_boto3_session() + ec2_resource = session.resource("ec2") + vpc = ec2_resource.create_vpc(CidrBlock="10.0.0.0/16") + security_group = ec2_resource.create_security_group( + GroupName="ECSWorkerTestSG", Description="ECS Worker test SG", VpcId=vpc.id + ) + subnet = ec2_resource.create_subnet(CidrBlock="10.0.2.0/24", VpcId=vpc.id) + invalid_subnet_id = "subnet-3bf19de7" + + configuration = await construct_configuration( + aws_credentials=aws_credentials, + vpc_id=vpc.id, + override_network_configuration=True, + network_configuration={ + "subnets": [invalid_subnet_id, subnet.id], + "assignPublicIp": "DISABLED", + "securityGroups": [security_group.id], + }, + ) + + session = aws_credentials.get_boto3_session() + + with pytest.raises( + ValueError, + match=( + rf"Subnets \['{invalid_subnet_id}', '{subnet.id}'\] not found within VPC" + f" with ID {vpc.id}.Please check that VPC is associated with supplied" + " subnets." 
+ ), + ): + async with ECSWorker(work_pool_name="test") as worker: + original_run_task = worker._create_task_run + mock_run_task = MagicMock(side_effect=original_run_task) + worker._create_task_run = mock_run_task + + await run_then_stop_task(worker, configuration, flow_run) + + @pytest.mark.usefixtures("ecs_mocks") async def test_network_config_configure_network_requires_vpc_id( aws_credentials: AwsCredentials, flow_run: FlowRun @@ -1438,7 +1587,7 @@ async def test_worker_task_definition_cache_miss_on_deregistered( # {"execution_role_arn": "test"}, # {"launch_type": "EXTERNAL"}, ], - ids=lambda item: str(set(item.keys())), + ids=lambda item: str(sorted(list(set(item.keys())))), ) async def test_worker_task_definition_cache_hit_on_config_changes( aws_credentials: AwsCredentials,