diff --git a/Makefile b/Makefile index b848f2841..81fc711be 100644 --- a/Makefile +++ b/Makefile @@ -46,11 +46,11 @@ push_manifest: .PHONY: test test: - python3.10 -m pytest --cov iambic --cov-report xml:cov_unit_tests.xml --cov-report html:cov_unit_tests.html . --ignore functional_tests/ -s + python3.10 -m pytest --cov iambic --cov-report lcov:cov_unit_tests.lcov --cov-report xml:cov_unit_tests.xml --cov-report html:cov_unit_tests.html . --ignore functional_tests/ -s .PHONY: functional_test functional_test: - pytest --cov-report html --cov iambic --cov-report xml:cov_functional_tests.xml --cov-report html:cov_functional_tests.html functional_tests --ignore functional_tests/test_github_cicd.py -s + pytest --cov-report html --cov iambic --cov-report lcov:cov_functional_tests.lcov --cov-report xml:cov_functional_tests.xml --cov-report html:cov_functional_tests.html functional_tests --ignore functional_tests/test_github_cicd.py -s # pytest --cov-report html --cov iambic functional_tests -s # pytest --cov-report html --cov iambic functional_tests/aws/role/test_create_template.py -s # pytest --cov-report html --cov iambic functional_tests/aws/managed_policy/test_template_expiration.py -s diff --git a/iambic/core/aio_utils/__init__.py b/iambic/core/aio_utils/__init__.py new file mode 100644 index 000000000..b2cffbfed --- /dev/null +++ b/iambic/core/aio_utils/__init__.py @@ -0,0 +1,86 @@ +from __future__ import annotations + +import asyncio +from typing import Any, Awaitable, TypeVar + +T = TypeVar("T") + + +async def gather_limit( + *args: Awaitable[T], + return_exceptions: bool = False, + limit: int = -1, +) -> list[Any]: + """ + (Taken from https://github.com/omnilib/aioitertools/blob/v0.7.1/aioitertools/asyncio.py) + Like asyncio.gather but with a limit on concurrency. + + Note that all results are buffered. + + If gather is cancelled all tasks that were internally created and still pending + will be cancelled as well. + + Example:: + + futures = [some_coro(i) for i in range(10)] + + results = await gather(*futures, limit=2) + """ + + # For detecting input duplicates and reconciling them at the end + input_map: dict[Awaitable[T], list[int]] = {} + # This is keyed on what we'll get back from asyncio.wait + pos: dict[asyncio.Future[T], int] = {} + ret: list[Any] = [None] * len(args) + + pending: set[asyncio.Future[T]] = set() + done: set[asyncio.Future[T]] = set() + + next_arg = 0 + + while True: + while next_arg < len(args) and (limit == -1 or len(pending) < limit): + # We have to defer the creation of the Task as long as possible + # because once we do, it starts executing, regardless of what we + # have in the pending set. + if args[next_arg] in input_map: + input_map[args[next_arg]].append(next_arg) + else: + # We call ensure_future directly to ensure that we have a Task + # because the return value of asyncio.wait will be an implicit + # task otherwise, and we won't be able to know which input it + # corresponds to. + task: asyncio.Future[T] = asyncio.ensure_future(args[next_arg]) + pending.add(task) + pos[task] = next_arg + input_map[args[next_arg]] = [next_arg] + next_arg += 1 + + # pending might be empty if the last items of args were dupes; + # asyncio.wait([]) will raise an exception. + if pending: + try: + done, pending = await asyncio.wait( + pending, return_when=asyncio.FIRST_COMPLETED + ) + for x in done: + if return_exceptions and x.exception(): + ret[pos[x]] = x.exception() + else: + ret[pos[x]] = x.result() + except asyncio.CancelledError: + # Since we created these tasks we should cancel them + for x in pending: + x.cancel() + # we insure that all tasks are cancelled before we raise + await asyncio.gather(*pending, return_exceptions=True) + raise + + if not pending and next_arg == len(args): + break + + for lst in input_map.values(): + for i in range(1, len(lst)): + ret[lst[i]] = ret[lst[0]] + + return ret diff --git a/iambic/core/utils.py b/iambic/core/utils.py index 9c7b65d7c..a0848d464 100644 --- a/iambic/core/utils.py +++ b/iambic/core/utils.py @@ -16,12 +16,14 @@ import aiofiles from asgiref.sync import sync_to_async +from ruamel.yaml import YAML + from iambic.core import noq_json as json +from iambic.core.aio_utils import gather_limit from iambic.core.context import ExecutionContext from iambic.core.exceptions import RateLimitException from iambic.core.iambic_enum import IambicManaged from iambic.core.logger import log -from ruamel.yaml import YAML if TYPE_CHECKING: from iambic.core.models import ProposedChange @@ -149,8 +151,10 @@ async def gather_templates(repo_dir: str, template_type: str = None) -> list[str file_paths += glob.glob(f"{repo_dir}*.yaml", recursive=True) file_paths += glob.glob(f"{repo_dir}/**/*.yml", recursive=True) file_paths += glob.glob(f"{repo_dir}*.yml", recursive=True) - file_paths = await asyncio.gather( - *[file_regex_search(fp, regex_pattern) for fp in file_paths] + + file_paths = await gather_limit( + *[file_regex_search(fp, regex_pattern) for fp in file_paths], + limit=int(os.environ.get("IAMBIC_GATHER_TEMPLATES_LIMIT", 10)), ) return [fp for fp in file_paths if fp] diff --git a/test/core/test_utils.py b/test/core/test_utils.py index 44d4759ac..9f3cfccd0 100644 --- a/test/core/test_utils.py +++ b/test/core/test_utils.py @@ -184,3 +184,43 @@ def test_non_date_input(self): input_value = "not a date" expected_result = input_value self.assertEqual(simplify_dt(input_value), expected_result) + + +@pytest.mark.asyncio +async def test_gather_templates(tmpdir): + from iambic.core.utils import gather_templates + + # Create a test directory structure + templates_dir = tmpdir.mkdir("templates") + sub_dir1 = templates_dir.mkdir("sub_dir1") + sub_dir2 = templates_dir.mkdir("sub_dir2") + file1 = templates_dir.join("file1.yml") + file1.write("template_type: NOQ::type1\n") + file2 = sub_dir1.join("file2.yaml") + file2.write("template_type: NOQ::type2\n") + file3 = sub_dir2.join("file3.yml") + file3.write("template_type: NOQ::type1\n") + file4 = sub_dir2.join("file4.yaml") + file4.write("template_type: not_noq\n") + + # Test the function + result = await gather_templates(str(templates_dir), "type1") + assert set(result) == { + str(file1), + str(file3), + } + + result = await gather_templates(str(templates_dir), "type2") + assert set(result) == { + str(file2), + } + + result = await gather_templates(str(templates_dir), "type3") + assert set(result) == set() + + result = await gather_templates(str(templates_dir)) + assert set(result) == { + str(file1), + str(file2), + str(file3), + }