Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
seperates task and container override
Browse files Browse the repository at this point in the history
  • Loading branch information
jeanluciano committed Apr 5, 2024
1 parent e6b15c2 commit 4f12864
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 23 deletions.
56 changes: 41 additions & 15 deletions prefect_aws/workers/ecs_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,9 @@
containerDefinitions:
- image: "{{ image }}"
name: "{{ container_name }}"
cpu: "{{ cpu }}"
cpu: "{{ task_cpu }}"
family: "{{ family }}"
memory: "{{ memory }}"
memory: "{{ task_memory }}"
executionRoleArn: "{{ execution_role_arn }}"
"""

Expand All @@ -119,10 +119,10 @@
- name: "{{ container_name }}"
command: "{{ command }}"
environment: "{{ env }}"
cpu: "{{ cpu }}"
memory: "{{ memory }}"
cpu: "{{ cpu }}"
memory: "{{ memory }}"
cpu: "{{ container_cpu }}"
memory: "{{ container_memory }}"
cpu: "{{ task_cpu }}"
memory: "{{ task_memory }}"
taskRoleArn: "{{ task_role_arn }}"
tags: "{{ labels }}"
taskDefinition: "{{ task_definition_arn }}"
Expand Down Expand Up @@ -267,6 +267,7 @@ class ECSJobConfiguration(BaseJobConfiguration):
auto_deregister_task_definition: bool = Field(default=False)
vpc_id: Optional[str] = Field(default=None)
container_name: Optional[str] = Field(default=None)

cluster: Optional[str] = Field(default=None)
match_latest_revision_in_family: bool = Field(default=False)

Expand Down Expand Up @@ -433,7 +434,7 @@ class ECSVariables(BaseVariables):
"defaults to a Prefect base image matching your local versions."
),
)
cpu: int = Field(
task_cpu: int = Field(
title="CPU",
default=None,
description=(
Expand All @@ -442,14 +443,33 @@ class ECSVariables(BaseVariables):
f"{ECS_DEFAULT_CPU} will be used unless present on the task definition."
),
)
memory: int = Field(
task_memory: int = Field(
default=None,
description=(
"The amount of memory to provide to the ECS task. Valid amounts are "
"specified in the AWS documentation. If not provided, a default value of "
f"{ECS_DEFAULT_MEMORY} will be used unless present on the task definition."
),
)
container_cpu: int = Field(
title="CPU",
default=None,
description=(
"The amount of CPU to provide to the Prefect container. Valid amounts are "
"specified in the AWS documentation. If not provided, a default value of "
f"{ECS_DEFAULT_CPU} will be used unless present on the task definition."
),
)
container_memory: int = Field(
default=None,
description=(
"The amount of memory to provide to the Prefect container. Valid amounts"
" are specified in the AWS documentation. If not provided, a default value"
f" of {ECS_DEFAULT_MEMORY} will be used unless present on the task"
" definition."
),
)

container_name: str = Field(
default=None,
description=(
Expand Down Expand Up @@ -1305,17 +1325,23 @@ def _prepare_task_definition(
task_definition, flow_run
)
# CPU and memory are required in some cases, retrieve the value to use
cpu = task_definition.get("cpu") or ECS_DEFAULT_CPU
memory = task_definition.get("memory") or ECS_DEFAULT_MEMORY

task_cpu = task_definition.get("cpu") or ECS_DEFAULT_CPU
task_memory = task_definition.get("memory") or ECS_DEFAULT_MEMORY
container_cpu = configuration.task_run_request["overrides"][
"containerOverrides"
][0].get("cpu")
container_memory = configuration.task_run_request["overrides"][
"containerOverrides"
][0].get("cpu")
launch_type = configuration.task_run_request.get(
"launchType", ECS_DEFAULT_LAUNCH_TYPE
)
print(task_definition)

if launch_type == "FARGATE" or launch_type == "FARGATE_SPOT":
# Task level memory and cpu are required when using fargate
task_definition["cpu"] = str(cpu)
task_definition["memory"] = str(memory)
task_definition["cpu"] = str(task_cpu)
task_definition["memory"] = str(task_memory)

# The FARGATE compatibility is required if it will be used as as launch type
requires_compatibilities = task_definition.setdefault(
Expand All @@ -1330,8 +1356,8 @@ def _prepare_task_definition(

elif launch_type == "EC2":
# Container level memory and cpu are required when using ec2
container.setdefault("cpu", cpu)
container.setdefault("memory", memory)
container.setdefault("cpu", container_cpu or ECS_DEFAULT_CPU)
container.setdefault("memory", container_memory or ECS_DEFAULT_MEMORY)

# Ensure set values are cast to integers
container["cpu"] = int(container["cpu"])
Expand Down
38 changes: 30 additions & 8 deletions tests/workers/test_ecs_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,22 @@ def ecs_mocks(
# NOTE: Even when using FARGATE, moto requires container instances to be
# registered. This differs from AWS behavior.
add_ec2_instance_to_ecs_cluster(session, "default")
ec2_client = session.client("ec2")
ec2_resource = session.resource("ec2")

images = ec2_client.describe_images()
image_id = images["Images"][0]["ImageId"]

test_instance = ec2_resource.create_instances(
ImageId=image_id, MinCount=1, MaxCount=1
)[0]

session.client("ecs").register_container_instance(
cluster="default",
instanceIdentityDocument=json.dumps(
generate_instance_identity_document(test_instance)
),
)

yield ecs

Expand Down Expand Up @@ -519,18 +535,24 @@ async def test_launch_types(

@pytest.mark.usefixtures("ecs_mocks")
@pytest.mark.parametrize("launch_type", ["EC2", "FARGATE", "FARGATE_SPOT"])
@pytest.mark.parametrize(
"cpu,memory", [(None, None), (1024, None), (None, 2048), (2048, 4096)]
)
@pytest.mark.parametrize("cpu,memory", [(2048, 4096)])
@pytest.mark.parametrize("container_cpu,container_memory", [(1024, 2048)])
async def test_cpu_and_memory(
aws_credentials: AwsCredentials,
launch_type: str,
flow_run: FlowRun,
cpu: int,
memory: int,
container_cpu: int,
container_memory: int,
):
configuration = await construct_configuration(
aws_credentials=aws_credentials, launch_type=launch_type, cpu=cpu, memory=memory
aws_credentials=aws_credentials,
launch_type=launch_type,
task_cpu=cpu,
task_memory=memory,
container_cpu=container_cpu,
container_memory=container_memory,
)

session = aws_credentials.get_boto3_session()
Expand All @@ -553,8 +575,8 @@ async def test_cpu_and_memory(

if launch_type == "EC2":
# EC2 requires CPU and memory to be defined at the container level
assert container_definition["cpu"] == cpu or ECS_DEFAULT_CPU
assert container_definition["memory"] == memory or ECS_DEFAULT_MEMORY
assert container_definition["cpu"] == container_cpu or ECS_DEFAULT_CPU
assert container_definition["memory"] == container_memory or ECS_DEFAULT_MEMORY
else:
# Fargate requires CPU and memory to be defined at the task definition level
assert task_definition["cpu"] == str(cpu or ECS_DEFAULT_CPU)
Expand All @@ -564,8 +586,8 @@ async def test_cpu_and_memory(
assert overrides.get("cpu") == (str(cpu) if cpu else None)
assert overrides.get("memory") == (str(memory) if memory else None)
# And as overrides for the Prefect container
assert container_overrides.get("cpu") == cpu
assert container_overrides.get("memory") == memory
assert container_overrides.get("cpu") == container_cpu
assert container_overrides.get("memory") == container_memory


@pytest.mark.usefixtures("ecs_mocks")
Expand Down

0 comments on commit 4f12864

Please sign in to comment.