diff --git a/prefect_aws/workers/ecs_worker.py b/prefect_aws/workers/ecs_worker.py index ec0b47af..c65e594a 100644 --- a/prefect_aws/workers/ecs_worker.py +++ b/prefect_aws/workers/ecs_worker.py @@ -246,6 +246,16 @@ def mask_api_key(task_run_request): ) +class CapacityProvider(BaseModel): + """ + The capacity provider strategy to use when running the task. + """ + + capacityProvider: str + weight: int + base: int + + class ECSJobConfiguration(BaseJobConfiguration): """ Job configuration for an ECS worker. @@ -268,6 +278,7 @@ class ECSJobConfiguration(BaseJobConfiguration): auto_deregister_task_definition: bool = Field(default=False) vpc_id: Optional[str] = Field(default=None) container_name: Optional[str] = Field(default=None) + cluster: Optional[str] = Field(default=None) match_latest_revision_in_family: bool = Field(default=False) @@ -368,16 +379,6 @@ def network_configuration_requires_vpc_id(cls, values: dict) -> dict: return values -class CapacityProvider(BaseModel): - """ - The capacity provider strategy to use when running the task. - """ - - capacityProvider: str - weight: int - base: int - - class ECSVariables(BaseVariables): """ Variables for templating an ECS job. @@ -437,7 +438,7 @@ class ECSVariables(BaseVariables): ) ) capacity_provider_strategy: Optional[List[CapacityProvider]] = Field( - default_factory=List[CapacityProvider], + default_factory=list, description=( "The capacity provider strategy to use when running the task. This is only" "If a capacityProviderStrategy is specified, we will omit the launchType" @@ -1467,10 +1468,7 @@ def _prepare_task_run_request( task_run_request.setdefault("taskDefinition", task_definition_arn) assert task_run_request["taskDefinition"] == task_definition_arn - capacityProviderStrategy = ( - task_run_request.get("capacityProviderStrategy") - or configuration.capacity_provider_strategy - ) + capacityProviderStrategy = task_run_request.get("capacityProviderStrategy") if capacityProviderStrategy: # Should not be provided at all if capacityProviderStrategy is set, see https://docs.aws.amazon.com/AmazonECS/latest/APIReference/API_RunTask.html#ECS-RunTask-request-capacityProviderStrategy # noqa @@ -1485,11 +1483,9 @@ def _prepare_task_run_request( task_run_request.pop("launchType", None) # A capacity provider strategy is required for FARGATE SPOT - task_run_request.setdefault( - "capacityProviderStrategy", - [{"capacityProvider": "FARGATE_SPOT", "weight": 1}], - ) - + task_run_request["capacityProviderStrategy"] = [ + {"capacityProvider": "FARGATE_SPOT", "weight": 1} + ] overrides = task_run_request.get("overrides", {}) container_overrides = overrides.get("containerOverrides", []) diff --git a/tests/workers/test_ecs_worker.py b/tests/workers/test_ecs_worker.py index a6112773..386ebf10 100644 --- a/tests/workers/test_ecs_worker.py +++ b/tests/workers/test_ecs_worker.py @@ -8,7 +8,7 @@ import anyio import pytest import yaml -from moto import mock_autoscaling, mock_ec2, mock_ecs, mock_logs +from moto import mock_ec2, mock_ecs, mock_logs from moto.ec2.utils import generate_instance_identity_document from prefect.server.schemas.core import FlowRun from prefect.utilities.asyncutils import run_sync_in_worker_thread @@ -275,7 +275,7 @@ def ecs_mocks( aws_credentials: AwsCredentials, flow_run: FlowRun, container_status_code ): with mock_ecs() as ecs: - with mock_ec2(), mock_autoscaling(): + with mock_ec2(): with mock_logs(): session = aws_credentials.get_boto3_session() @@ -506,6 +506,7 @@ async def test_launch_types( # Instead, it requires a capacity provider strategy but this is not supported # by moto and is not present on the task even when provided so we assert on the # mock call to ensure it is sent + assert mock_run_task.call_args[0][1].get("capacityProviderStrategy") == [ {"capacityProvider": "FARGATE_SPOT", "weight": 1} ]