Skip to content

Commit

Permalink
Merge pull request #238 from noqdev/bug/en-1879-limit-gather-parallel…
Browse files Browse the repository at this point in the history
…ization

EN-1879: Limit parallelization on initial asyncio.gather call
  • Loading branch information
castrapel authored Mar 14, 2023
2 parents 1b8b301 + 1f03d45 commit 3097a0f
Show file tree
Hide file tree
Showing 4 changed files with 135 additions and 5 deletions.
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,11 @@ push_manifest:

.PHONY: test
test:
python3.10 -m pytest --cov iambic --cov-report xml:cov_unit_tests.xml --cov-report html:cov_unit_tests.html . --ignore functional_tests/ -s
python3.10 -m pytest --cov iambic --cov-report lcov:cov_unit_tests.lcov --cov-report xml:cov_unit_tests.xml --cov-report html:cov_unit_tests.html . --ignore functional_tests/ -s

.PHONY: functional_test
functional_test:
pytest --cov-report html --cov iambic --cov-report xml:cov_functional_tests.xml --cov-report html:cov_functional_tests.html functional_tests --ignore functional_tests/test_github_cicd.py -s
pytest --cov-report html --cov iambic --cov-report lcov:cov_functional_tests.lcov --cov-report xml:cov_functional_tests.xml --cov-report html:cov_functional_tests.html functional_tests --ignore functional_tests/test_github_cicd.py -s
# pytest --cov-report html --cov iambic functional_tests -s
# pytest --cov-report html --cov iambic functional_tests/aws/role/test_create_template.py -s
# pytest --cov-report html --cov iambic functional_tests/aws/managed_policy/test_template_expiration.py -s
Expand Down
86 changes: 86 additions & 0 deletions iambic/core/aio_utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from __future__ import annotations

import asyncio
from typing import Any, Awaitable, TypeVar

T = TypeVar("T")


async def gather_limit(
*args: Awaitable[T],
return_exceptions: bool = False,
limit: int = -1,
) -> list[Any]:
"""
(Taken from https://github.com/omnilib/aioitertools/blob/v0.7.1/aioitertools/asyncio.py)
Like asyncio.gather but with a limit on concurrency.
Note that all results are buffered.
If gather is cancelled all tasks that were internally created and still pending
will be cancelled as well.
Example::
futures = [some_coro(i) for i in range(10)]
results = await gather(*futures, limit=2)
"""

# For detecting input duplicates and reconciling them at the end
input_map: dict[Awaitable[T], list[int]] = {}
# This is keyed on what we'll get back from asyncio.wait
pos: dict[asyncio.Future[T], int] = {}
ret: list[Any] = [None] * len(args)

pending: set[asyncio.Future[T]] = set()
done: set[asyncio.Future[T]] = set()

next_arg = 0

while True:
while next_arg < len(args) and (limit == -1 or len(pending) < limit):
# We have to defer the creation of the Task as long as possible
# because once we do, it starts executing, regardless of what we
# have in the pending set.
if args[next_arg] in input_map:
input_map[args[next_arg]].append(next_arg)
else:
# We call ensure_future directly to ensure that we have a Task
# because the return value of asyncio.wait will be an implicit
# task otherwise, and we won't be able to know which input it
# corresponds to.
task: asyncio.Future[T] = asyncio.ensure_future(args[next_arg])
pending.add(task)
pos[task] = next_arg
input_map[args[next_arg]] = [next_arg]
next_arg += 1

# pending might be empty if the last items of args were dupes;
# asyncio.wait([]) will raise an exception.
if pending:
try:
done, pending = await asyncio.wait(
pending, return_when=asyncio.FIRST_COMPLETED
)
for x in done:
if return_exceptions and x.exception():
ret[pos[x]] = x.exception()
else:
ret[pos[x]] = x.result()
except asyncio.CancelledError:
# Since we created these tasks we should cancel them
for x in pending:
x.cancel()
# we insure that all tasks are cancelled before we raise
await asyncio.gather(*pending, return_exceptions=True)
raise

if not pending and next_arg == len(args):
break

for lst in input_map.values():
for i in range(1, len(lst)):
ret[lst[i]] = ret[lst[0]]

return ret
10 changes: 7 additions & 3 deletions iambic/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,14 @@

import aiofiles
from asgiref.sync import sync_to_async
from ruamel.yaml import YAML

from iambic.core import noq_json as json
from iambic.core.aio_utils import gather_limit
from iambic.core.context import ExecutionContext
from iambic.core.exceptions import RateLimitException
from iambic.core.iambic_enum import IambicManaged
from iambic.core.logger import log
from ruamel.yaml import YAML

if TYPE_CHECKING:
from iambic.core.models import ProposedChange
Expand Down Expand Up @@ -149,8 +151,10 @@ async def gather_templates(repo_dir: str, template_type: str = None) -> list[str
file_paths += glob.glob(f"{repo_dir}*.yaml", recursive=True)
file_paths += glob.glob(f"{repo_dir}/**/*.yml", recursive=True)
file_paths += glob.glob(f"{repo_dir}*.yml", recursive=True)
file_paths = await asyncio.gather(
*[file_regex_search(fp, regex_pattern) for fp in file_paths]

file_paths = await gather_limit(
*[file_regex_search(fp, regex_pattern) for fp in file_paths],
limit=int(os.environ.get("IAMBIC_GATHER_TEMPLATES_LIMIT", 10)),
)
return [fp for fp in file_paths if fp]

Expand Down
40 changes: 40 additions & 0 deletions test/core/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,3 +184,43 @@ def test_non_date_input(self):
input_value = "not a date"
expected_result = input_value
self.assertEqual(simplify_dt(input_value), expected_result)


@pytest.mark.asyncio
async def test_gather_templates(tmpdir):
from iambic.core.utils import gather_templates

# Create a test directory structure
templates_dir = tmpdir.mkdir("templates")
sub_dir1 = templates_dir.mkdir("sub_dir1")
sub_dir2 = templates_dir.mkdir("sub_dir2")
file1 = templates_dir.join("file1.yml")
file1.write("template_type: NOQ::type1\n")
file2 = sub_dir1.join("file2.yaml")
file2.write("template_type: NOQ::type2\n")
file3 = sub_dir2.join("file3.yml")
file3.write("template_type: NOQ::type1\n")
file4 = sub_dir2.join("file4.yaml")
file4.write("template_type: not_noq\n")

# Test the function
result = await gather_templates(str(templates_dir), "type1")
assert set(result) == {
str(file1),
str(file3),
}

result = await gather_templates(str(templates_dir), "type2")
assert set(result) == {
str(file2),
}

result = await gather_templates(str(templates_dir), "type3")
assert set(result) == set()

result = await gather_templates(str(templates_dir))
assert set(result) == {
str(file1),
str(file2),
str(file3),
}

0 comments on commit 3097a0f

Please sign in to comment.