Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Generates family field (#398)
Browse files Browse the repository at this point in the history
  • Loading branch information
jeanluciano committed Mar 28, 2024
1 parent 1b25b0a commit 219808c
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 7 deletions.
28 changes: 21 additions & 7 deletions prefect_aws/workers/ecs_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,7 +673,7 @@ def _create_task_and_wait_for_start(

if not task_definition_arn:
task_definition = self._prepare_task_definition(
configuration, region=ecs_client.meta.region_name
configuration, region=ecs_client.meta.region_name, flow_run=flow_run
)
(
task_definition_arn,
Expand Down Expand Up @@ -1205,10 +1205,28 @@ def _watch_task_run(
)
time.sleep(configuration.task_watch_poll_interval)

def _get_or_generate_family(self, task_definition: dict, flow_run: FlowRun) -> str:
"""
Gets or generate a family for the task definition.
"""
family = task_definition.get("family")
if not family:
assert self._work_pool_name and flow_run.deployment_id
family = (
f"{ECS_DEFAULT_FAMILY}_{self._work_pool_name}_{flow_run.deployment_id}"
)
slugify(
family,
max_length=255,
regex_pattern=r"[^a-zA-Z0-9-_]+",
)
return family

def _prepare_task_definition(
self,
configuration: ECSJobConfiguration,
region: str,
flow_run: FlowRun,
) -> dict:
"""
Prepare a task definition by inferring any defaults and merging overrides.
Expand Down Expand Up @@ -1269,13 +1287,9 @@ def _prepare_task_definition(
},
}

family = task_definition.get("family") or ECS_DEFAULT_FAMILY
task_definition["family"] = slugify(
family,
max_length=255,
regex_pattern=r"[^a-zA-Z0-9-_]+",
task_definition["family"] = self._get_or_generate_family(
task_definition, flow_run
)

# CPU and memory are required in some cases, retrieve the value to use
cpu = task_definition.get("cpu") or ECS_DEFAULT_CPU
memory = task_definition.get("memory") or ECS_DEFAULT_MEMORY
Expand Down
26 changes: 26 additions & 0 deletions tests/workers/test_ecs_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
_TASK_DEFINITION_CACHE,
ECS_DEFAULT_CONTAINER_NAME,
ECS_DEFAULT_CPU,
ECS_DEFAULT_FAMILY,
ECS_DEFAULT_MEMORY,
AwsCredentials,
ECSJobConfiguration,
Expand Down Expand Up @@ -648,6 +649,7 @@ async def test_task_definition_arn(aws_credentials: AwsCredentials, flow_run: Fl
_, task_arn = parse_identifier(result.identifier)

task = describe_task(ecs_client, task_arn)
print(task)
assert (
task["taskDefinitionArn"] == task_definition_arn
), "The task definition should be used without registering a new one"
Expand Down Expand Up @@ -2316,3 +2318,27 @@ async def test_mask_sensitive_env_values():
res["overrides"]["containerOverrides"][0]["environment"][1]["value"]
== "NORMAL_VALUE"
)


@pytest.mark.usefixtures("ecs_mocks")
async def test_get_or_generate_family(
aws_credentials: AwsCredentials, flow_run: FlowRun
):
configuration = await construct_configuration(
aws_credentials=aws_credentials,
)

work_pool_name = "test"
session = aws_credentials.get_boto3_session()
ecs_client = session.client("ecs")
family = f"{ECS_DEFAULT_FAMILY}_{work_pool_name}_{flow_run.deployment_id}"

async with ECSWorker(work_pool_name=work_pool_name) as worker:
result = await run_then_stop_task(worker, configuration, flow_run)

assert result.status_code == 0
_, task_arn = parse_identifier(result.identifier)

task = describe_task(ecs_client, task_arn)
task_definition = describe_task_definition(ecs_client, task)
assert task_definition["family"] == family

0 comments on commit 219808c

Please sign in to comment.