Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Improvement/error handling task creation #406

Merged
merged 9 commits into from
Apr 11, 2024
12 changes: 11 additions & 1 deletion prefect_aws/workers/ecs_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1595,14 +1595,24 @@ def _prepare_task_run_request(
CREATE_TASK_RUN_MIN_DELAY_JITTER_SECONDS,
CREATE_TASK_RUN_MAX_DELAY_JITTER_SECONDS,
),
reraise=True,
)
def _create_task_run(self, ecs_client: _ECSClient, task_run_request: dict) -> str:
"""
Create a run of a task definition.

Returns the task run ARN.
"""
return ecs_client.run_task(**task_run_request)["tasks"][0]
task = ecs_client.run_task(**task_run_request)
if task["failures"]:
raise RuntimeError(
f"Failed to run ECS task: {task['failures'][0]['reason']}"
)
elif not task["tasks"]:
raise RuntimeError(
"Failed to run ECS task: no tasks or failures were returned."
)
return task["tasks"][0]

def _task_definitions_equal(self, taskdef_1, taskdef_2) -> bool:
"""
Expand Down
40 changes: 37 additions & 3 deletions tests/workers/test_ecs_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
from functools import partial
from typing import Any, Awaitable, Callable, Dict, List, Optional
from unittest.mock import ANY, MagicMock
from unittest.mock import patch as mock_patch
from uuid import uuid4

import anyio
import botocore
import pytest
import yaml
from moto import mock_ec2, mock_ecs, mock_logs
Expand All @@ -19,8 +21,6 @@
else:
from pydantic import ValidationError

from tenacity import RetryError

from prefect_aws.credentials import _get_client_cached
from prefect_aws.workers.ecs_worker import (
_TASK_DEFINITION_CACHE,
Expand Down Expand Up @@ -1323,6 +1323,40 @@ async def write_fake_log(task_arn):
# assert "test-message-{i}" in err


orig = botocore.client.BaseClient._make_api_call


def mock_make_api_call(self, operation_name, kwarg):
if operation_name == "RunTask":
return {
"failures": [
{"arn": "string", "reason": "string", "detail": "string"},
]
}
return orig(self, operation_name, kwarg)


@pytest.mark.usefixtures("ecs_mocks")
async def test_run_task_error_handling(
aws_credentials: AwsCredentials,
flow_run: FlowRun,
capsys,
):
configuration = await construct_configuration(
aws_credentials=aws_credentials,
task_role_arn="test",
)

with mock_patch(
"botocore.client.BaseClient._make_api_call", new=mock_make_api_call
):
async with ECSWorker(work_pool_name="test") as worker:
with pytest.raises(RuntimeError, match="Failed to run ECS task") as exc:
await run_then_stop_task(worker, configuration, flow_run)

assert exc.value.args[0] == "Failed to run ECS task: string"


@pytest.mark.usefixtures("ecs_mocks")
@pytest.mark.parametrize(
"cloudwatch_logs_options",
Expand Down Expand Up @@ -2283,7 +2317,7 @@ async def test_retry_on_failed_task_start(
},
)

with pytest.raises(RetryError):
with pytest.raises(RuntimeError):
async with ECSWorker(work_pool_name="test") as worker:
await run_then_stop_task(worker, configuration, flow_run)

Expand Down
Loading