Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

seperates task and container override #409

Closed
wants to merge 9 commits into from
53 changes: 41 additions & 12 deletions prefect_aws/workers/ecs_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,8 @@
- name: "{{ container_name }}"
command: "{{ command }}"
environment: "{{ env }}"
cpu: "{{ cpu }}"
memory: "{{ memory }}"
cpu: "{{ container_cpu }}"
memory: "{{ container_memory }}"
cpu: "{{ cpu }}"
memory: "{{ memory }}"
taskRoleArn: "{{ task_role_arn }}"
Expand Down Expand Up @@ -267,6 +267,8 @@ class ECSJobConfiguration(BaseJobConfiguration):
auto_deregister_task_definition: bool = Field(default=False)
vpc_id: Optional[str] = Field(default=None)
container_name: Optional[str] = Field(default=None)
container_cpu: Optional[int] = Field(default=None)
container_memory: Optional[int] = Field(default=None)
cluster: Optional[str] = Field(default=None)
match_latest_revision_in_family: bool = Field(default=False)

Expand Down Expand Up @@ -450,13 +452,38 @@ class ECSVariables(BaseVariables):
f"{ECS_DEFAULT_MEMORY} will be used unless present on the task definition."
),
)
container_cpu: int = Field(
title="CPU",
default=None,
description=(
"The amount of CPU to provide to the Prefect container. Valid amounts are"
" specified in the AWS documentation. If not provided, the value supplied"
" to `cpu` for the task will be used. If neither are provided,"
f" {ECS_DEFAULT_CPU} will be used unless present on the task definition."
),
)
container_memory: int = Field(
default=None,
description=(
"The amount of memory to provide to the Prefect container. Valid amounts"
" are specified in the AWS documentation. If not provided, a default value"
f" of {ECS_DEFAULT_MEMORY} will be used unless present on the task"
" definition. Only use these variables when your ECS tasks have multiple"
" containers. The total CPU and memory for all of a task's containers"
" cannot exceed the CPU and memory set for the task."
),
)

container_name: str = Field(
default=None,
description=(
"The name of the container flow run orchestration will occur in. If not "
f"specified, a default value of {ECS_DEFAULT_CONTAINER_NAME} will be used "
"and if that is not found in the task definition the first container will "
"be used."
"The amount of memory to provide to the Prefect container. Valid amounts"
" are specified in the AWS documentation. If not provided, the value"
" supplied to `cpu` for the task will be used. If neither are provided,"
f" {ECS_DEFAULT_CPU} will be used unless present on the task definition."
" Only use these variables when your ECS tasks have multiple containers."
" The total CPU and memory for all of a task's containers cannot exceed"
" the CPU and memory set for the task."
),
)
task_role_arn: str = Field(
Expand Down Expand Up @@ -1305,17 +1332,19 @@ def _prepare_task_definition(
task_definition, flow_run
)
# CPU and memory are required in some cases, retrieve the value to use
cpu = task_definition.get("cpu") or ECS_DEFAULT_CPU
memory = task_definition.get("memory") or ECS_DEFAULT_MEMORY
task_cpu = task_definition.get("cpu") or ECS_DEFAULT_CPU
task_memory = task_definition.get("memory") or ECS_DEFAULT_MEMORY
container_cpu = configuration.container_cpu or task_cpu
container_memory = configuration.container_memory or task_memory

launch_type = configuration.task_run_request.get(
"launchType", ECS_DEFAULT_LAUNCH_TYPE
)

if launch_type == "FARGATE" or launch_type == "FARGATE_SPOT":
# Task level memory and cpu are required when using fargate
task_definition["cpu"] = str(cpu)
task_definition["memory"] = str(memory)
task_definition["cpu"] = str(task_cpu)
task_definition["memory"] = str(task_memory)

# The FARGATE compatibility is required if it will be used as as launch type
requires_compatibilities = task_definition.setdefault(
Expand All @@ -1330,8 +1359,8 @@ def _prepare_task_definition(

elif launch_type == "EC2":
# Container level memory and cpu are required when using ec2
container.setdefault("cpu", cpu)
container.setdefault("memory", memory)
container.setdefault("cpu", container_cpu)
container.setdefault("memory", container_memory)

# Ensure set values are cast to integers
container["cpu"] = int(container["cpu"])
Expand Down
18 changes: 13 additions & 5 deletions tests/workers/test_ecs_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,15 +522,23 @@ async def test_launch_types(
@pytest.mark.parametrize(
"cpu,memory", [(None, None), (1024, None), (None, 2048), (2048, 4096)]
)
@pytest.mark.parametrize("container_cpu,container_memory", [(None, None), (1024, 2048)])
async def test_cpu_and_memory(
aws_credentials: AwsCredentials,
launch_type: str,
flow_run: FlowRun,
cpu: int,
memory: int,
container_cpu: int,
container_memory: int,
):
configuration = await construct_configuration(
aws_credentials=aws_credentials, launch_type=launch_type, cpu=cpu, memory=memory
aws_credentials=aws_credentials,
launch_type=launch_type,
cpu=cpu,
memory=memory,
container_cpu=container_cpu,
container_memory=container_memory,
)

session = aws_credentials.get_boto3_session()
Expand All @@ -553,8 +561,8 @@ async def test_cpu_and_memory(

if launch_type == "EC2":
# EC2 requires CPU and memory to be defined at the container level
assert container_definition["cpu"] == cpu or ECS_DEFAULT_CPU
assert container_definition["memory"] == memory or ECS_DEFAULT_MEMORY
assert container_definition["cpu"] == container_cpu or ECS_DEFAULT_CPU
assert container_definition["memory"] == container_memory or ECS_DEFAULT_MEMORY
else:
# Fargate requires CPU and memory to be defined at the task definition level
assert task_definition["cpu"] == str(cpu or ECS_DEFAULT_CPU)
Expand All @@ -564,8 +572,8 @@ async def test_cpu_and_memory(
assert overrides.get("cpu") == (str(cpu) if cpu else None)
assert overrides.get("memory") == (str(memory) if memory else None)
# And as overrides for the Prefect container
assert container_overrides.get("cpu") == cpu
assert container_overrides.get("memory") == memory
assert container_overrides.get("cpu") == container_cpu
assert container_overrides.get("memory") == container_memory


@pytest.mark.usefixtures("ecs_mocks")
Expand Down
Loading