Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
test updated
Browse files Browse the repository at this point in the history
  • Loading branch information
jeanluciano committed Mar 25, 2024
1 parent b982163 commit d106379
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 14 deletions.
8 changes: 2 additions & 6 deletions prefect_aws/workers/ecs_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,7 @@ class ECSVariables(BaseVariables):
" for available options. "
),
)
cloudwatch_logs_prefix: str = Field(
cloudwatch_logs_prefix: Optional[str] = Field(
default=None,
description=(
"When `configure_cloudwatch_logs` is enabled, this setting may be used to"
Expand Down Expand Up @@ -722,6 +722,7 @@ def _create_task_and_wait_for_start(
)

try:
print("ttttttttask_run_requestttttt", task_run_request)
task = self._create_task_run(ecs_client, task_run_request)
task_arn = task["taskArn"]
cluster_arn = task["clusterArn"]
Expand Down Expand Up @@ -1587,11 +1588,6 @@ def _create_task_run(self, ecs_client: _ECSClient, task_run_request: dict) -> st
Returns the task run ARN.
"""
run = ecs_client.run_task(**task_run_request)
failures = run["failures"]
if failures:
raise RuntimeError(f"Failed to run ECS task: {failures}")

return ecs_client.run_task(**task_run_request)["tasks"][0]

def _task_definitions_equal(self, taskdef_1, taskdef_2) -> bool:
Expand Down
41 changes: 40 additions & 1 deletion tests/test_ecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1232,7 +1232,6 @@ async def test_cloudwatch_log_options(aws_credentials):
configure_cloudwatch_logs=True,
execution_role_arn="test",
cloudwatch_logs_options={
"awslogs-stream-prefix": "override-prefix",
"max-buffer-size": "2m",
},
)
Expand Down Expand Up @@ -1260,6 +1259,46 @@ async def test_cloudwatch_log_options(aws_credentials):
assert "logConfiguration" not in container


@pytest.mark.usefixtures("ecs_mocks")
async def test_cloudwatch_log_options_no_new_defintion(aws_credentials):
session = aws_credentials.get_boto3_session()
ecs_client = session.client("ecs")

task = ECSTask(
aws_credentials=aws_credentials,
auto_deregister_task_definition=False,
command=["prefect", "version"],
configure_cloudwatch_logs=True,
execution_role_arn="test",
cloudwatch_logs_options={
"max-buffer-size": "2m",
},
)

task_arn = await run_then_stop_task(task)
task = describe_task(ecs_client, task_arn)
task_definition = describe_task_definition(ecs_client, task)

for container in task_definition["containerDefinitions"]:
if container["name"] == "prefect":
print(container["logConfiguration"])
# Assert that the 'prefect' container has logging configured with user
# provided options
assert container["logConfiguration"] == {
"logDriver": "awslogs",
"options": {
"awslogs-create-group": "true",
"awslogs-group": "prefect",
"awslogs-region": "us-east-1",
"awslogs-stream-prefix": "override-prefix",
"max-buffer-size": "2m",
},
}
else:
# Other containers should not be modifed
assert "logConfiguration" not in container


@pytest.mark.usefixtures("ecs_mocks")
@pytest.mark.parametrize("launch_type", ["FARGATE", "FARGATE_SPOT"])
async def test_bridge_network_mode_warns_on_fargate(aws_credentials, launch_type: str):
Expand Down
27 changes: 20 additions & 7 deletions tests/workers/test_ecs_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1322,8 +1322,20 @@ async def write_fake_log(task_arn):


@pytest.mark.usefixtures("ecs_mocks")
@pytest.mark.parametrize(
"cloudwatch_logs_options",
[
{
"awslogs-stream-prefix": "override-prefix",
"max-buffer-size": "2m",
},
{
"max-buffer-size": "2m",
},
],
)
async def test_cloudwatch_log_options(
aws_credentials: AwsCredentials, flow_run: FlowRun
aws_credentials: AwsCredentials, flow_run: FlowRun, cloudwatch_logs_options: dict
):
session = aws_credentials.get_boto3_session()
ecs_client = session.client("ecs")
Expand All @@ -1332,12 +1344,10 @@ async def test_cloudwatch_log_options(
aws_credentials=aws_credentials,
configure_cloudwatch_logs=True,
execution_role_arn="test",
cloudwatch_logs_options={
"awslogs-stream-prefix": "override-prefix",
"max-buffer-size": "2m",
},
cloudwatch_logs_options=cloudwatch_logs_options,
)
async with ECSWorker(work_pool_name="test") as worker:
work_pool_name = "test"
async with ECSWorker(work_pool_name=work_pool_name) as worker:
result = await run_then_stop_task(worker, configuration, flow_run)

assert result.status_code == 0
Expand All @@ -1347,6 +1357,9 @@ async def test_cloudwatch_log_options(
task_definition = describe_task_definition(ecs_client, task)

for container in task_definition["containerDefinitions"]:
prefix = f"prefect-logs_{work_pool_name}_{flow_run.deployment_id}"
if cloudwatch_logs_options.get("awslogs-stream-prefix"):
prefix = cloudwatch_logs_options["awslogs-stream-prefix"]
if container["name"] == ECS_DEFAULT_CONTAINER_NAME:
# Assert that the container has logging configured with user
# provided options
Expand All @@ -1356,7 +1369,7 @@ async def test_cloudwatch_log_options(
"awslogs-create-group": "true",
"awslogs-group": "prefect",
"awslogs-region": "us-east-1",
"awslogs-stream-prefix": "override-prefix",
"awslogs-stream-prefix": prefix,
"max-buffer-size": "2m",
},
}
Expand Down

0 comments on commit d106379

Please sign in to comment.