diff --git a/prefect_aws/workers/ecs_worker.py b/prefect_aws/workers/ecs_worker.py index e84d5b84..a6df49c5 100644 --- a/prefect_aws/workers/ecs_worker.py +++ b/prefect_aws/workers/ecs_worker.py @@ -267,7 +267,8 @@ class ECSJobConfiguration(BaseJobConfiguration): auto_deregister_task_definition: bool = Field(default=False) vpc_id: Optional[str] = Field(default=None) container_name: Optional[str] = Field(default=None) - + container_cpu: Optional[int] = Field(default=None) + container_memory: Optional[int] = Field(default=None) cluster: Optional[str] = Field(default=None) match_latest_revision_in_family: bool = Field(default=False) @@ -1327,12 +1328,9 @@ def _prepare_task_definition( # CPU and memory are required in some cases, retrieve the value to use task_cpu = task_definition.get("cpu") or ECS_DEFAULT_CPU task_memory = task_definition.get("memory") or ECS_DEFAULT_MEMORY - container_cpu = configuration.task_run_request["overrides"][ - "containerOverrides" - ][0].get("cpu") - container_memory = configuration.task_run_request["overrides"][ - "containerOverrides" - ][0].get("cpu") + container_cpu = configuration.container_cpu or task_cpu + container_memory = configuration.container_memory or task_memory + launch_type = configuration.task_run_request.get( "launchType", ECS_DEFAULT_LAUNCH_TYPE ) diff --git a/tests/workers/test_ecs_worker.py b/tests/workers/test_ecs_worker.py index 46471502..92e8e377 100644 --- a/tests/workers/test_ecs_worker.py +++ b/tests/workers/test_ecs_worker.py @@ -344,6 +344,7 @@ async def construct_configuration_with_job_template( print(f"Using variables: {variables.json(indent=2)}") base_template = ECSWorker.get_default_base_job_template() + print(f"Using base template configuration: {json.dumps(base_template, indent=2)}") for key in template_overrides: base_template["job_configuration"][key] = template_overrides[key]