Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Add ECS worker option to use most recent revision in task definition …
Browse files Browse the repository at this point in the history
…family (#370)

Co-authored-by: nate nowack <[email protected]>
  • Loading branch information
kevingrismore and zzstoatzz authored Feb 28, 2024
1 parent f14daa4 commit a2070b2
Show file tree
Hide file tree
Showing 3 changed files with 177 additions and 57 deletions.
2 changes: 1 addition & 1 deletion docs/gen_examples_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def get_code_examples(obj: Union[ModuleType, Callable]) -> Set[str]:
for section in parsed_sections:
if section.kind == DocstringSectionKind.examples:
code_example = "\n".join(
(part[1] for part in section.as_dict().get("value", []))
part[1] for part in section.as_dict().get("value", [])
)
if not skip_block_load_code_example(code_example):
code_examples.add(code_example)
Expand Down
141 changes: 85 additions & 56 deletions prefect_aws/workers/ecs_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,7 @@ class ECSJobConfiguration(BaseJobConfiguration):
vpc_id: Optional[str] = Field(default=None)
container_name: Optional[str] = Field(default=None)
cluster: Optional[str] = Field(default=None)
match_latest_revision_in_family: bool = Field(default=False)

@root_validator
def task_run_request_requires_arn_if_no_task_definition_given(cls, values) -> dict:
Expand Down Expand Up @@ -550,6 +551,16 @@ class ECSVariables(BaseVariables):
"your AWS account, instead it will be marked as INACTIVE."
),
)
match_latest_revision_in_family: bool = Field(
default=False,
description=(
"If enabled, the most recent active revision in the task definition "
"family will be compared against the desired ECS task configuration. "
"If they are equal, the existing task definition will be used instead "
"of registering a new one. If no family is specified the default family "
f'"{ECS_DEFAULT_FAMILY}" will be used.'
),
)


class ECSWorkerResult(BaseWorkerResult):
Expand Down Expand Up @@ -661,55 +672,15 @@ def _create_task_and_wait_for_start(
new_task_definition_registered = False

if not task_definition_arn:
cached_task_definition_arn = _TASK_DEFINITION_CACHE.get(
flow_run.deployment_id
)
task_definition = self._prepare_task_definition(
configuration, region=ecs_client.meta.region_name
)

if cached_task_definition_arn:
# Read the task definition to see if the cached task definition is valid
try:
cached_task_definition = self._retrieve_task_definition(
logger, ecs_client, cached_task_definition_arn
)
except Exception as exc:
logger.warning(
"Failed to retrieve cached task definition"
f" {cached_task_definition_arn!r}: {exc!r}"
)
# Clear from cache
_TASK_DEFINITION_CACHE.pop(flow_run.deployment_id, None)
cached_task_definition_arn = None
else:
if not cached_task_definition["status"] == "ACTIVE":
# Cached task definition is not active
logger.warning(
"Cached task definition"
f" {cached_task_definition_arn!r} is not active"
)
_TASK_DEFINITION_CACHE.pop(flow_run.deployment_id, None)
cached_task_definition_arn = None
elif not self._task_definitions_equal(
task_definition, cached_task_definition
):
# Cached task definition is not valid
logger.warning(
"Cached task definition"
f" {cached_task_definition_arn!r} does not meet"
" requirements"
)
_TASK_DEFINITION_CACHE.pop(flow_run.deployment_id, None)
cached_task_definition_arn = None

if not cached_task_definition_arn:
task_definition_arn = self._register_task_definition(
logger, ecs_client, task_definition
)
new_task_definition_registered = True
else:
task_definition_arn = cached_task_definition_arn
(
task_definition_arn,
new_task_definition_registered,
) = self._get_or_register_task_definition(
logger, ecs_client, configuration, flow_run, task_definition
)
else:
task_definition = self._retrieve_task_definition(
logger, ecs_client, task_definition_arn
Expand All @@ -722,17 +693,13 @@ def _create_task_and_wait_for_start(

self._validate_task_definition(task_definition, configuration)

# Update the cached task definition ARN to avoid re-registering the task
# definition on this worker unless necessary; registration is aggressively
# rate limited by AWS
_TASK_DEFINITION_CACHE[flow_run.deployment_id] = task_definition_arn

logger.info(f"Using ECS task definition {task_definition_arn!r}...")
logger.debug(
f"Task definition {json.dumps(task_definition, indent=2, default=str)}"
)

# Prepare the task run request
task_run_request = self._prepare_task_run_request(
configuration,
task_definition,
Expand All @@ -753,7 +720,6 @@ def _create_task_and_wait_for_start(
self._report_task_run_creation_failure(configuration, task_run_request, exc)
raise

# Raises an exception if the task does not start
logger.info("Waiting for ECS task run to start...")
self._wait_for_task_start(
logger,
Expand All @@ -766,6 +732,65 @@ def _create_task_and_wait_for_start(

return task_arn, cluster_arn, task_definition, new_task_definition_registered

def _get_or_register_task_definition(
    self,
    logger: logging.Logger,
    ecs_client: _ECSClient,
    configuration: ECSJobConfiguration,
    flow_run: FlowRun,
    task_definition: dict,
) -> Tuple[str, bool]:
    """Get or register a task definition for the given flow run.

    Candidates are tried in this order:

    1. The ARN stored in ``_TASK_DEFINITION_CACHE`` for the flow run's
       deployment, provided it can still be retrieved, its status is
       ``"ACTIVE"``, and it compares equal to ``task_definition``.
    2. When ``configuration.match_latest_revision_in_family`` is set, the
       task definition retrieved by family name (the ``"family"`` key of
       ``task_definition``, or ``ECS_DEFAULT_FAMILY`` when absent), provided
       it compares equal to ``task_definition``.
    3. Otherwise ``task_definition`` is registered as a new task definition.

    Failures while inspecting an existing candidate are never fatal: they are
    logged and the next option is tried.

    Args:
        logger: Logger used for progress and fallback messages.
        ecs_client: ECS client used to retrieve and register task definitions.
        configuration: Job configuration; only
            ``match_latest_revision_in_family`` is read here.
        flow_run: Flow run whose ``deployment_id`` keys the ARN cache.
        task_definition: The desired task definition.

    Returns:
        A tuple of the task definition ARN and a bool indicating if the task
        definition is newly registered.
    """
    cached_task_definition_arn = _TASK_DEFINITION_CACHE.get(flow_run.deployment_id)
    new_task_definition_registered = False

    if cached_task_definition_arn:
        # The whole validation stays inside the ``try`` so that a malformed
        # response (for example a missing "status" key) falls back to
        # registration instead of failing the flow run.
        try:
            cached_task_definition = self._retrieve_task_definition(
                logger, ecs_client, cached_task_definition_arn
            )
            if cached_task_definition["status"] != "ACTIVE":
                logger.warning(
                    "Cached task definition"
                    f" {cached_task_definition_arn!r} is not active"
                )
                cached_task_definition_arn = None
            elif not self._task_definitions_equal(
                task_definition, cached_task_definition
            ):
                logger.warning(
                    "Cached task definition"
                    f" {cached_task_definition_arn!r} does not meet"
                    " requirements"
                )
                cached_task_definition_arn = None
        except Exception as exc:
            logger.warning(
                "Failed to retrieve cached task definition"
                f" {cached_task_definition_arn!r}: {exc!r}"
            )
            cached_task_definition_arn = None

    if (
        not cached_task_definition_arn
        and configuration.match_latest_revision_in_family
    ):
        family_name = task_definition.get("family", ECS_DEFAULT_FAMILY)
        try:
            # ``_retrieve_task_definition`` accepts a family name as well as
            # an ARN, so the family lookup goes through the same helper.
            task_definition_from_family = self._retrieve_task_definition(
                logger, ecs_client, family_name
            )
            if task_definition_from_family and self._task_definitions_equal(
                task_definition, task_definition_from_family
            ):
                cached_task_definition_arn = task_definition_from_family[
                    "taskDefinitionArn"
                ]
        except Exception as exc:
            # Best effort only: report why the family could not be matched
            # and fall through to registering a new task definition.
            logger.info(
                "Could not match the latest revision in task definition family"
                f" {family_name!r}: {exc!r}"
            )
            cached_task_definition_arn = None

    if not cached_task_definition_arn:
        task_definition_arn = self._register_task_definition(
            logger, ecs_client, task_definition
        )
        new_task_definition_registered = True
    else:
        task_definition_arn = cached_task_definition_arn

    return task_definition_arn, new_task_definition_registered

def _watch_task_and_get_exit_code(
self,
logger: logging.Logger,
Expand Down Expand Up @@ -928,15 +953,19 @@ def _retrieve_task_definition(
self,
logger: logging.Logger,
ecs_client: _ECSClient,
task_definition_arn: str,
task_definition: str,
):
"""
Retrieve an existing task definition from AWS.
"""
logger.info(f"Retrieving ECS task definition {task_definition_arn!r}...")
response = ecs_client.describe_task_definition(
taskDefinition=task_definition_arn
)
if task_definition.startswith("arn:aws:ecs:"):
logger.info(f"Retrieving ECS task definition {task_definition!r}...")
else:
logger.info(
"Retrieving most recent active revision from "
f"ECS task family {task_definition!r}..."
)
response = ecs_client.describe_task_definition(taskDefinition=task_definition)
return response["taskDefinition"]

def _wait_for_task_start(
Expand Down
91 changes: 91 additions & 0 deletions tests/workers/test_ecs_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1415,6 +1415,97 @@ async def test_deregister_task_definition_does_not_apply_to_linked_arn(
describe_task_definition(ecs_client, task)["status"] == "ACTIVE"


@pytest.mark.usefixtures("ecs_mocks")
async def test_match_latest_revision_in_family(
    aws_credentials: AwsCredentials, flow_run: FlowRun
):
    """A second worker with matching enabled runs on the first worker's latest ARN.

    The first worker runs two configurations (the second adds an execution
    role). A second worker then runs the role configuration with
    ``match_latest_revision_in_family=True``; its task must use the same task
    definition ARN as the first worker's last run, and that ARN must end in
    ``:2``.
    """
    ecs_client = aws_credentials.get_boto3_session().client("ecs")

    baseline_configuration = await construct_configuration(
        aws_credentials=aws_credentials,
    )
    role_configuration = await construct_configuration(
        aws_credentials=aws_credentials,
        execution_role_arn="test",
    )
    matching_configuration = await construct_configuration(
        aws_credentials=aws_credentials,
        match_latest_revision_in_family=True,
        execution_role_arn="test",
    )

    # First worker: two runs with differing configurations.
    async with ECSWorker(work_pool_name="test") as first_worker:
        await run_then_stop_task(first_worker, baseline_configuration, flow_run)
        registered_result = await run_then_stop_task(
            first_worker, role_configuration, flow_run
        )

    # Second worker instance, with matching against the family enabled.
    async with ECSWorker(work_pool_name="test") as second_worker:
        matched_result = await run_then_stop_task(
            second_worker, matching_configuration, flow_run
        )

    assert registered_result.status_code == 0
    _, registered_task_arn = parse_identifier(registered_result.identifier)

    assert matched_result.status_code == 0
    _, matched_task_arn = parse_identifier(matched_result.identifier)

    registered_task = describe_task(ecs_client, registered_task_arn)
    matched_task = describe_task(ecs_client, matched_task_arn)

    assert registered_task["taskDefinitionArn"] == matched_task["taskDefinitionArn"]
    assert matched_task["taskDefinitionArn"].endswith(":2")


@pytest.mark.usefixtures("ecs_mocks")
async def test_match_latest_revision_in_family_custom_family(
    aws_credentials: AwsCredentials, flow_run: FlowRun
):
    """Same scenario as the default-family test, using family ``"test-family"``.

    The first worker runs two configurations in the custom family (the second
    adds an execution role). A second worker then runs the role configuration
    with ``match_latest_revision_in_family=True``; its task must use the same
    task definition ARN as the first worker's last run, and that ARN must end
    in ``:2``.
    """
    ecs_client = aws_credentials.get_boto3_session().client("ecs")

    baseline_configuration = await construct_configuration(
        aws_credentials=aws_credentials,
        family="test-family",
    )
    role_configuration = await construct_configuration(
        aws_credentials=aws_credentials,
        execution_role_arn="test",
        family="test-family",
    )
    matching_configuration = await construct_configuration(
        aws_credentials=aws_credentials,
        match_latest_revision_in_family=True,
        execution_role_arn="test",
        family="test-family",
    )

    # First worker: two runs with differing configurations.
    async with ECSWorker(work_pool_name="test") as first_worker:
        await run_then_stop_task(first_worker, baseline_configuration, flow_run)
        registered_result = await run_then_stop_task(
            first_worker, role_configuration, flow_run
        )

    # Second worker instance, with matching against the family enabled.
    async with ECSWorker(work_pool_name="test") as second_worker:
        matched_result = await run_then_stop_task(
            second_worker, matching_configuration, flow_run
        )

    assert registered_result.status_code == 0
    _, registered_task_arn = parse_identifier(registered_result.identifier)

    assert matched_result.status_code == 0
    _, matched_task_arn = parse_identifier(matched_result.identifier)

    registered_task = describe_task(ecs_client, registered_task_arn)
    matched_task = describe_task(ecs_client, matched_task_arn)

    assert registered_task["taskDefinitionArn"] == matched_task["taskDefinitionArn"]
    assert matched_task["taskDefinitionArn"].endswith(":2")


@pytest.mark.usefixtures("ecs_mocks")
async def test_worker_caches_registered_task_definitions(
aws_credentials: AwsCredentials, flow_run: FlowRun
Expand Down

0 comments on commit a2070b2

Please sign in to comment.