diff --git a/prefect_aws/ecs.py b/prefect_aws/ecs.py index ac5dd333..6368748f 100644 --- a/prefect_aws/ecs.py +++ b/prefect_aws/ecs.py @@ -126,6 +126,8 @@ from prefect.utilities.pydantic import JsonPatch from pydantic import VERSION as PYDANTIC_VERSION +from prefect_aws.utilities import assemble_document_for_patches + if PYDANTIC_VERSION.startswith("2."): from pydantic.v1 import Field, root_validator, validator else: @@ -739,6 +741,23 @@ async def generate_work_pool_base_job_template(self) -> dict: ) if self.task_customizations: + network_config_patches = JsonPatch( + [ + patch + for patch in self.task_customizations + if "networkConfiguration" in patch["path"] + ] + ) + minimal_network_config = assemble_document_for_patches( + network_config_patches + ) + if minimal_network_config: + minimal_network_config_with_patches = network_config_patches.apply( + minimal_network_config + ) + base_job_template["variables"]["properties"]["network_configuration"][ + "default" + ] = minimal_network_config_with_patches["networkConfiguration"] try: base_job_template["job_configuration"]["task_run_request"] = ( self.task_customizations.apply( diff --git a/prefect_aws/utilities.py b/prefect_aws/utilities.py index ad1e6ed2..33b6cdc6 100644 --- a/prefect_aws/utilities.py +++ b/prefect_aws/utilities.py @@ -1,5 +1,7 @@ """Utilities for working with AWS services.""" +from typing import Dict, List, Union + from prefect.utilities.collections import visit_collection @@ -33,3 +35,82 @@ def make_hashable(item): collection, visit_fn=make_hashable, return_data=True ) return hash(hashable_collection) + + +def ensure_path_exists(doc: Union[Dict, List], path: List[str]): + """ + Ensures the path exists in the document, creating empty dictionaries or lists as + needed. + + Args: + doc: The current level of the document or sub-document. + path: The remaining path parts to ensure exist. + """ + if not path: + return + current_path = path.pop(0) + # Check if the next path part exists and is a digit + next_path_is_digit = path and path[0].isdigit() + + # Determine if the current path is for an array or an object + if isinstance(doc, list): # Path is for an array index + current_path = int(current_path) + # Ensure the current level of the document is a list and long enough + + while len(doc) <= current_path: + doc.append({}) + next_level = doc[current_path] + else: # Path is for an object + if current_path not in doc or ( + next_path_is_digit and not isinstance(doc.get(current_path), list) + ): + doc[current_path] = [] if next_path_is_digit else {} + next_level = doc[current_path] + + ensure_path_exists(next_level, path) + + +def assemble_document_for_patches(patches): + """ + Assembles an initial document that can successfully accept the given JSON Patch + operations. + + Args: + patches: A list of JSON Patch operations. + + Returns: + An initial document structured to accept the patches. + + Example: + + ```python + patches = [ + {"op": "replace", "path": "/name", "value": "Jane"}, + {"op": "add", "path": "/contact/address", "value": "123 Main St"}, + {"op": "remove", "path": "/age"} + ] + + initial_document = assemble_document_for_patches(patches) + + #output + { + "name": {}, + "contact": {}, + "age": {} + } + ``` + """ + document = {} + + for patch in patches: + operation = patch["op"] + path = patch["path"].lstrip("/").split("/") + + if operation == "add": + # Ensure all but the last element of the path exists + ensure_path_exists(document, path[:-1]) + elif operation in ["remove", "replace"]: + # For remove adn replace, the entire path should exist + ensure_path_exists(document, path) + + return document diff --git a/tests/test_ecs.py b/tests/test_ecs.py index 2f970116..a81c446b 100644 --- a/tests/test_ecs.py +++ b/tests/test_ecs.py @@ -2128,6 +2128,15 @@ def base_job_template_with_defaults(default_base_job_template, aws_credentials): base_job_template_with_defaults["variables"]["properties"][ "auto_deregister_task_definition" ]["default"] = False + base_job_template_with_defaults["variables"]["properties"]["network_configuration"][ + "default" + ] = { + "awsvpcConfiguration": { + "subnets": ["subnet-***"], + "assignPublicIp": "DISABLED", + "securityGroups": ["sg-***"], + } + } return base_job_template_with_defaults @@ -2188,10 +2197,20 @@ async def test_generate_work_pool_base_job_template( cpu=2048, memory=4096, task_customizations=[ + { + "op": "replace", + "path": "/networkConfiguration/awsvpcConfiguration/assignPublicIp", + "value": "DISABLED", + }, + { + "op": "add", + "path": "/networkConfiguration/awsvpcConfiguration/subnets", + "value": ["subnet-***"], + }, { "op": "add", "path": "/networkConfiguration/awsvpcConfiguration/securityGroups", - "value": ["sg-d72e9599956a084f5"], + "value": ["sg-***"], }, ], family="test-family", @@ -2229,10 +2248,3 @@ async def test_generate_work_pool_base_job_template( template = await job.generate_work_pool_base_job_template() assert template == expected_template - - if job_config == "custom": - assert ( - "Unable to apply task customizations to the base job template." - "You may need to update the template manually." - in caplog.text - ) diff --git a/tests/test_utilities.py b/tests/test_utilities.py index 0e0fdc6f..cecf863f 100644 --- a/tests/test_utilities.py +++ b/tests/test_utilities.py @@ -1,6 +1,10 @@ import pytest -from prefect_aws.utilities import hash_collection +from prefect_aws.utilities import ( + assemble_document_for_patches, + ensure_path_exists, + hash_collection, +) class TestHashCollection: @@ -32,3 +36,56 @@ def test_unhashable_structure(self): assert hash_collection(typically_unhashable_structure) == hash_collection( typically_unhashable_structure ), "Unhashable structure hashing failed after transformation" + + +class TestAssembleDocumentForPatches: + def test_initial_document(self): + patches = [ + {"op": "replace", "path": "/name", "value": "Jane"}, + {"op": "add", "path": "/contact/address", "value": "123 Main St"}, + {"op": "remove", "path": "/age"}, + ] + + initial_document = assemble_document_for_patches(patches) + + expected_document = {"name": {}, "contact": {}, "age": {}} + + assert initial_document == expected_document, "Initial document assembly failed" + + +class TestEnsurePathExists: + def test_existing_path(self): + doc = {"key1": {"subkey1": "value1"}} + path = ["key1", "subkey1"] + ensure_path_exists(doc, path) + assert doc == { + "key1": {"subkey1": "value1"} + }, "Existing path modification failed" + + def test_new_path_object(self): + doc = {} + path = ["key1", "subkey1"] + ensure_path_exists(doc, path) + assert doc == {"key1": {"subkey1": {}}}, "New path creation for object failed" + + def test_new_path_array(self): + doc = {} + path = ["key1", "0"] + ensure_path_exists(doc, path) + assert doc == {"key1": [{}]}, "New path creation for array failed" + + def test_existing_path_array(self): + doc = {"key1": [{"subkey1": "value1"}]} + path = ["key1", "0", "subkey1"] + ensure_path_exists(doc, path) + assert doc == { + "key1": [{"subkey1": "value1"}] + }, "Existing path modification for array failed" + + def test_existing_path_array_index_out_of_range(self): + doc = {"key1": []} + path = ["key1", "0", "subkey1"] + ensure_path_exists(doc, path) + assert doc == { + "key1": [{"subkey1": {}}] + }, "Existing path modification for array index out of range failed"