Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Support for capacity provider #407

Merged
merged 12 commits into from
Apr 11, 2024
42 changes: 34 additions & 8 deletions prefect_aws/workers/ecs_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,9 @@
from pydantic import VERSION as PYDANTIC_VERSION

if PYDANTIC_VERSION.startswith("2."):
from pydantic.v1 import Field, root_validator
from pydantic.v1 import BaseModel, Field, root_validator
else:
from pydantic import Field, root_validator
from pydantic import Field, root_validator, BaseModel

from slugify import slugify
from tenacity import retry, stop_after_attempt, wait_fixed, wait_random
Expand Down Expand Up @@ -126,6 +126,7 @@
taskRoleArn: "{{ task_role_arn }}"
tags: "{{ labels }}"
taskDefinition: "{{ task_definition_arn }}"
capacityProviderStrategy: "{{ capacity_provider_strategy }}"
"""

# Create task run retry settings
Expand Down Expand Up @@ -245,6 +246,16 @@ def mask_api_key(task_run_request):
)


class CapacityProvider(BaseModel):
"""
The capacity provider strategy to use when running the task.
"""

capacityProvider: str
weight: int
base: int


class ECSJobConfiguration(BaseJobConfiguration):
"""
Job configuration for an ECS worker.
Expand All @@ -267,6 +278,7 @@ class ECSJobConfiguration(BaseJobConfiguration):
auto_deregister_task_definition: bool = Field(default=False)
vpc_id: Optional[str] = Field(default=None)
container_name: Optional[str] = Field(default=None)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove this line

cluster: Optional[str] = Field(default=None)
match_latest_revision_in_family: bool = Field(default=False)

Expand Down Expand Up @@ -425,6 +437,13 @@ class ECSVariables(BaseVariables):
),
)
)
capacity_provider_strategy: Optional[List[CapacityProvider]] = Field(
default_factory=list,
description=(
"The capacity provider strategy to use when running the task. This is only"
"If a capacityProviderStrategy is specified, we will omit the launchType"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"The capacity provider strategy to use when running the task. "
"If a capacity provider strategy is specified, the selected launch type will be ignored."

),
)
image: Optional[str] = Field(
default=None,
description=(
Expand Down Expand Up @@ -1449,17 +1468,24 @@ def _prepare_task_run_request(

task_run_request.setdefault("taskDefinition", task_definition_arn)
assert task_run_request["taskDefinition"] == task_definition_arn
capacityProviderStrategy = task_run_request.get("capacityProviderStrategy")

if capacityProviderStrategy:
# Should not be provided at all if capacityProviderStrategy is set, see https://docs.aws.amazon.com/AmazonECS/latest/APIReference/API_RunTask.html#ECS-RunTask-request-capacityProviderStrategy # noqa
self._logger.warning(
"Removing launchType from task run request. Due to finding"
" capacityProviderStrategy in the request."
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"Found capacityProviderStrategy. "
"Removing launchType from task run request."

)
task_run_request.pop("launchType", None)

if task_run_request.get("launchType") == "FARGATE_SPOT":
elif task_run_request.get("launchType") == "FARGATE_SPOT":
# Should not be provided at all for FARGATE SPOT
task_run_request.pop("launchType", None)

# A capacity provider strategy is required for FARGATE SPOT
task_run_request.setdefault(
"capacityProviderStrategy",
[{"capacityProvider": "FARGATE_SPOT", "weight": 1}],
)

task_run_request["capacityProviderStrategy"] = [
{"capacityProvider": "FARGATE_SPOT", "weight": 1}
]
overrides = task_run_request.get("overrides", {})
container_overrides = overrides.get("containerOverrides", [])

Expand Down
36 changes: 36 additions & 0 deletions tests/workers/test_ecs_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,7 @@ async def test_launch_types(
# Instead, it requires a capacity provider strategy but this is not supported
# by moto and is not present on the task even when provided so we assert on the
# mock call to ensure it is sent

assert mock_run_task.call_args[0][1].get("capacityProviderStrategy") == [
{"capacityProvider": "FARGATE_SPOT", "weight": 1}
]
Expand Down Expand Up @@ -2050,6 +2051,41 @@ async def test_user_defined_environment_variables_in_task_definition_template(
]


@pytest.mark.usefixtures("ecs_mocks")
async def test_user_defined_capacity_provider_strategy(
aws_credentials: AwsCredentials, flow_run: FlowRun
):
configuration = await construct_configuration(
aws_credentials=aws_credentials,
capacity_provider_strategy=[
{"base": 0, "weight": 1, "capacityProvider": "r6i.large"}
],
)
session = aws_credentials.get_boto3_session()
ecs_client = session.client("ecs")

async with ECSWorker(work_pool_name="test") as worker:
# Capture the task run call because moto does not track
# 'capacityProviderStrategy'
original_run_task = worker._create_task_run
mock_run_task = MagicMock(side_effect=original_run_task)
worker._create_task_run = mock_run_task

result = await run_then_stop_task(worker, configuration, flow_run)

assert result.status_code == 0
_, task_arn = parse_identifier(result.identifier)

task = describe_task(ecs_client, task_arn)
assert not task.get("launchType")
# Instead, it requires a capacity provider strategy but this is not supported
# by moto and is not present on the task even when provided so we assert on the
# mock call to ensure it is sent
assert mock_run_task.call_args[0][1].get("capacityProviderStrategy") == [
{"base": 0, "weight": 1, "capacityProvider": "r6i.large"},
]


@pytest.mark.usefixtures("ecs_mocks")
async def test_user_defined_environment_variables_in_task_run_request_template(
aws_credentials: AwsCredentials, flow_run: FlowRun
Expand Down
Loading