diff --git a/prefect_aws/workers/ecs_worker.py b/prefect_aws/workers/ecs_worker.py index 02c4117c..afbe631b 100644 --- a/prefect_aws/workers/ecs_worker.py +++ b/prefect_aws/workers/ecs_worker.py @@ -221,6 +221,31 @@ def parse_identifier(identifier: str) -> ECSIdentifier: return ECSIdentifier(cluster, task) +def mask_sensitive_env_values( + task_run_request: dict, values: List[str], keep_length=3, replace_with="***" +): + for container in task_run_request.get("overrides", {}).get( + "containerOverrides", [] + ): + for env_var in container.get("environment", []): + if ( + "name" not in env_var + or "value" not in env_var + or env_var["name"] not in values + ): + continue + if len(env_var["value"]) > keep_length: + # Replace characters beyond the keep length + env_var["value"] = env_var["value"][:keep_length] + replace_with + return task_run_request + + +def mask_api_key(task_run_request): + return mask_sensitive_env_values( + task_run_request, ["PREFECT_API_KEY"], keep_length=6 + ) + + class ECSJobConfiguration(BaseJobConfiguration): """ Job configuration for an ECS worker. @@ -724,8 +749,10 @@ def _create_task_and_wait_for_start( logger.info("Creating ECS task run...") logger.debug( - f"Task run request {json.dumps(task_run_request, indent=2, default=str)}" + "Task run request" + f"{json.dumps(mask_api_key(task_run_request), indent=2, default=str)}" ) + try: task = self._create_task_run(ecs_client, task_run_request) task_arn = task["taskArn"] diff --git a/tests/workers/test_ecs_worker.py b/tests/workers/test_ecs_worker.py index b6a39b35..077a178a 100644 --- a/tests/workers/test_ecs_worker.py +++ b/tests/workers/test_ecs_worker.py @@ -34,6 +34,7 @@ InfrastructureNotFound, _get_container, get_prefect_image_name, + mask_sensitive_env_values, parse_identifier, ) @@ -2180,3 +2181,27 @@ async def test_retry_on_failed_task_start( await run_then_stop_task(worker, configuration, flow_run) assert run_task_mock.call_count == 3 + + +async def test_mask_sensitive_env_values(): + task_run_request = { + "overrides": { + "containerOverrides": [ + { + "environment": [ + {"name": "PREFECT_API_KEY", "value": "SeNsItiVe VaLuE"}, + {"name": "PREFECT_API_URL", "value": "NORMAL_VALUE"}, + ] + } + ] + } + } + + res = mask_sensitive_env_values(task_run_request, ["PREFECT_API_KEY"], 3, "***") + assert ( + res["overrides"]["containerOverrides"][0]["environment"][0]["value"] == "SeN***" + ) + assert ( + res["overrides"]["containerOverrides"][0]["environment"][1]["value"] + == "NORMAL_VALUE" + )