diff --git a/prefect_aws/workers/ecs_worker.py b/prefect_aws/workers/ecs_worker.py index b3b54495..ec0b47af 100644 --- a/prefect_aws/workers/ecs_worker.py +++ b/prefect_aws/workers/ecs_worker.py @@ -126,6 +126,7 @@ taskRoleArn: "{{ task_role_arn }}" tags: "{{ labels }}" taskDefinition: "{{ task_definition_arn }}" +capacityProviderStrategy: "{{ capacity_provider_strategy }}" """ # Create task run retry settings @@ -372,7 +373,7 @@ class CapacityProvider(BaseModel): The capacity provider strategy to use when running the task. """ - capacity_provider: str + capacityProvider: str weight: int base: int @@ -436,7 +437,7 @@ class ECSVariables(BaseVariables): ) ) capacity_provider_strategy: Optional[List[CapacityProvider]] = Field( - default=None, + default_factory=List[CapacityProvider], description=( "The capacity provider strategy to use when running the task. This is only" "If a capacityProviderStrategy is specified, we will omit the launchType" @@ -1466,8 +1467,12 @@ def _prepare_task_run_request( task_run_request.setdefault("taskDefinition", task_definition_arn) assert task_run_request["taskDefinition"] == task_definition_arn + capacityProviderStrategy = ( + task_run_request.get("capacityProviderStrategy") + or configuration.capacity_provider_strategy + ) - if "capacityProviderStrategy" in task_run_request: + if capacityProviderStrategy: # Should not be provided at all if capacityProviderStrategy is set, see https://docs.aws.amazon.com/AmazonECS/latest/APIReference/API_RunTask.html#ECS-RunTask-request-capacityProviderStrategy # noqa self._logger.warning( "Removing launchType from task run request. Due to finding" diff --git a/tests/workers/test_ecs_worker.py b/tests/workers/test_ecs_worker.py index 9db3882e..a6112773 100644 --- a/tests/workers/test_ecs_worker.py +++ b/tests/workers/test_ecs_worker.py @@ -2026,9 +2026,6 @@ async def test_user_defined_capacity_provider_strategy( {"base": 0, "weight": 1, "capacityProvider": "r6i.large"} ], ) - - assert "launchType" not in configuration.task_run_request - session = aws_credentials.get_boto3_session() ecs_client = session.client("ecs")