From 219808c95612dd14ed21f5b4705a916cf22464f3 Mon Sep 17 00:00:00 2001 From: Jean Luciano Date: Thu, 28 Mar 2024 14:00:10 -0500 Subject: [PATCH] Generates family field (#398) --- prefect_aws/workers/ecs_worker.py | 28 +++++++++++++++++++++------- tests/workers/test_ecs_worker.py | 26 ++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 7 deletions(-) diff --git a/prefect_aws/workers/ecs_worker.py b/prefect_aws/workers/ecs_worker.py index edbe843d..1a5c3d28 100644 --- a/prefect_aws/workers/ecs_worker.py +++ b/prefect_aws/workers/ecs_worker.py @@ -673,7 +673,7 @@ def _create_task_and_wait_for_start( if not task_definition_arn: task_definition = self._prepare_task_definition( - configuration, region=ecs_client.meta.region_name + configuration, region=ecs_client.meta.region_name, flow_run=flow_run ) ( task_definition_arn, @@ -1205,10 +1205,28 @@ def _watch_task_run( ) time.sleep(configuration.task_watch_poll_interval) + def _get_or_generate_family(self, task_definition: dict, flow_run: FlowRun) -> str: + """ + Gets or generate a family for the task definition. + """ + family = task_definition.get("family") + if not family: + assert self._work_pool_name and flow_run.deployment_id + family = ( + f"{ECS_DEFAULT_FAMILY}_{self._work_pool_name}_{flow_run.deployment_id}" + ) + slugify( + family, + max_length=255, + regex_pattern=r"[^a-zA-Z0-9-_]+", + ) + return family + def _prepare_task_definition( self, configuration: ECSJobConfiguration, region: str, + flow_run: FlowRun, ) -> dict: """ Prepare a task definition by inferring any defaults and merging overrides. @@ -1269,13 +1287,9 @@ def _prepare_task_definition( }, } - family = task_definition.get("family") or ECS_DEFAULT_FAMILY - task_definition["family"] = slugify( - family, - max_length=255, - regex_pattern=r"[^a-zA-Z0-9-_]+", + task_definition["family"] = self._get_or_generate_family( + task_definition, flow_run ) - # CPU and memory are required in some cases, retrieve the value to use cpu = task_definition.get("cpu") or ECS_DEFAULT_CPU memory = task_definition.get("memory") or ECS_DEFAULT_MEMORY diff --git a/tests/workers/test_ecs_worker.py b/tests/workers/test_ecs_worker.py index 702f2ddd..e4dab38c 100644 --- a/tests/workers/test_ecs_worker.py +++ b/tests/workers/test_ecs_worker.py @@ -26,6 +26,7 @@ _TASK_DEFINITION_CACHE, ECS_DEFAULT_CONTAINER_NAME, ECS_DEFAULT_CPU, + ECS_DEFAULT_FAMILY, ECS_DEFAULT_MEMORY, AwsCredentials, ECSJobConfiguration, @@ -648,6 +649,7 @@ async def test_task_definition_arn(aws_credentials: AwsCredentials, flow_run: Fl _, task_arn = parse_identifier(result.identifier) task = describe_task(ecs_client, task_arn) + print(task) assert ( task["taskDefinitionArn"] == task_definition_arn ), "The task definition should be used without registering a new one" @@ -2316,3 +2318,27 @@ async def test_mask_sensitive_env_values(): res["overrides"]["containerOverrides"][0]["environment"][1]["value"] == "NORMAL_VALUE" ) + + +@pytest.mark.usefixtures("ecs_mocks") +async def test_get_or_generate_family( + aws_credentials: AwsCredentials, flow_run: FlowRun +): + configuration = await construct_configuration( + aws_credentials=aws_credentials, + ) + + work_pool_name = "test" + session = aws_credentials.get_boto3_session() + ecs_client = session.client("ecs") + family = f"{ECS_DEFAULT_FAMILY}_{work_pool_name}_{flow_run.deployment_id}" + + async with ECSWorker(work_pool_name=work_pool_name) as worker: + result = await run_then_stop_task(worker, configuration, flow_run) + + assert result.status_code == 0 + _, task_arn = parse_identifier(result.identifier) + + task = describe_task(ecs_client, task_arn) + task_definition = describe_task_definition(ecs_client, task) + assert task_definition["family"] == family