Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
sets fargate spot capacity provider
Browse files Browse the repository at this point in the history
  • Loading branch information
jeanluciano committed Apr 10, 2024
1 parent bbdae80 commit 9107e27
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 22 deletions.
36 changes: 16 additions & 20 deletions prefect_aws/workers/ecs_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,16 @@ def mask_api_key(task_run_request):
)


class CapacityProvider(BaseModel):
"""
The capacity provider strategy to use when running the task.
"""

capacityProvider: str
weight: int
base: int


class ECSJobConfiguration(BaseJobConfiguration):
"""
Job configuration for an ECS worker.
Expand All @@ -268,6 +278,7 @@ class ECSJobConfiguration(BaseJobConfiguration):
auto_deregister_task_definition: bool = Field(default=False)
vpc_id: Optional[str] = Field(default=None)
container_name: Optional[str] = Field(default=None)

cluster: Optional[str] = Field(default=None)
match_latest_revision_in_family: bool = Field(default=False)

Expand Down Expand Up @@ -368,16 +379,6 @@ def network_configuration_requires_vpc_id(cls, values: dict) -> dict:
return values


class CapacityProvider(BaseModel):
"""
The capacity provider strategy to use when running the task.
"""

capacityProvider: str
weight: int
base: int


class ECSVariables(BaseVariables):
"""
Variables for templating an ECS job.
Expand Down Expand Up @@ -437,7 +438,7 @@ class ECSVariables(BaseVariables):
)
)
capacity_provider_strategy: Optional[List[CapacityProvider]] = Field(
default_factory=List[CapacityProvider],
default_factory=list,
description=(
"The capacity provider strategy to use when running the task. This is only"
"If a capacityProviderStrategy is specified, we will omit the launchType"
Expand Down Expand Up @@ -1467,10 +1468,7 @@ def _prepare_task_run_request(

task_run_request.setdefault("taskDefinition", task_definition_arn)
assert task_run_request["taskDefinition"] == task_definition_arn
capacityProviderStrategy = (
task_run_request.get("capacityProviderStrategy")
or configuration.capacity_provider_strategy
)
capacityProviderStrategy = task_run_request.get("capacityProviderStrategy")

if capacityProviderStrategy:
# Should not be provided at all if capacityProviderStrategy is set, see https://docs.aws.amazon.com/AmazonECS/latest/APIReference/API_RunTask.html#ECS-RunTask-request-capacityProviderStrategy # noqa
Expand All @@ -1485,11 +1483,9 @@ def _prepare_task_run_request(
task_run_request.pop("launchType", None)

# A capacity provider strategy is required for FARGATE SPOT
task_run_request.setdefault(
"capacityProviderStrategy",
[{"capacityProvider": "FARGATE_SPOT", "weight": 1}],
)

task_run_request["capacityProviderStrategy"] = [
{"capacityProvider": "FARGATE_SPOT", "weight": 1}
]
overrides = task_run_request.get("overrides", {})
container_overrides = overrides.get("containerOverrides", [])

Expand Down
5 changes: 3 additions & 2 deletions tests/workers/test_ecs_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import anyio
import pytest
import yaml
from moto import mock_autoscaling, mock_ec2, mock_ecs, mock_logs
from moto import mock_ec2, mock_ecs, mock_logs
from moto.ec2.utils import generate_instance_identity_document
from prefect.server.schemas.core import FlowRun
from prefect.utilities.asyncutils import run_sync_in_worker_thread
Expand Down Expand Up @@ -275,7 +275,7 @@ def ecs_mocks(
aws_credentials: AwsCredentials, flow_run: FlowRun, container_status_code
):
with mock_ecs() as ecs:
with mock_ec2(), mock_autoscaling():
with mock_ec2():
with mock_logs():
session = aws_credentials.get_boto3_session()

Expand Down Expand Up @@ -506,6 +506,7 @@ async def test_launch_types(
# Instead, it requires a capacity provider strategy but this is not supported
# by moto and is not present on the task even when provided so we assert on the
# mock call to ensure it is sent

assert mock_run_task.call_args[0][1].get("capacityProviderStrategy") == [
{"capacityProvider": "FARGATE_SPOT", "weight": 1}
]
Expand Down

0 comments on commit 9107e27

Please sign in to comment.