From 9db617a982ee27994bf13c805f9c4f054f05de47 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 30 Sep 2024 14:37:08 -0500 Subject: [PATCH] fix: rewrite staggered_race to be race safe (#101) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- src/aiohappyeyeballs/_staggered.py | 159 ++++++++++++++---- src/aiohappyeyeballs/impl.py | 4 +- src/aiohappyeyeballs/staggered.py | 9 - tests/conftest.py | 32 ++++ tests/test_staggered.py | 86 ++++++++++ tests/test_staggered_cpython.py | 146 ++++++++++++++++ ...st_staggered_cpython_eager_task_factory.py | 96 +++++++++++ 7 files changed, 492 insertions(+), 40 deletions(-) delete mode 100644 src/aiohappyeyeballs/staggered.py create mode 100644 tests/conftest.py create mode 100644 tests/test_staggered.py create mode 100644 tests/test_staggered_cpython.py create mode 100644 tests/test_staggered_cpython_eager_task_factory.py diff --git a/src/aiohappyeyeballs/_staggered.py b/src/aiohappyeyeballs/_staggered.py index b5c6798..dd0efb9 100644 --- a/src/aiohappyeyeballs/_staggered.py +++ b/src/aiohappyeyeballs/_staggered.py @@ -1,17 +1,54 @@ import asyncio import contextlib -from typing import Awaitable, Callable, Iterable, List, Optional, Tuple, TypeVar +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + Iterable, + List, + Optional, + Set, + Tuple, + TypeVar, + Union, +) +_T = TypeVar("_T") -class _Done(Exception): - pass +def _set_result(wait_next: "asyncio.Future[None]") -> None: + """Set the result of a future if it is not already done.""" + if not wait_next.done(): + wait_next.set_result(None) -_T = TypeVar("_T") + +async def _wait_one( + futures: "Iterable[asyncio.Future[Any]]", + loop: asyncio.AbstractEventLoop, +) -> _T: + """Wait for the first future to complete.""" + wait_next = loop.create_future() + + def _on_completion(fut: "asyncio.Future[Any]") -> None: + if not wait_next.done(): + wait_next.set_result(fut) + + for f in futures: + f.add_done_callback(_on_completion) + + try: + return await wait_next + finally: + for f in futures: + f.remove_done_callback(_on_completion) async def staggered_race( - coro_fns: Iterable[Callable[[], Awaitable[_T]]], delay: Optional[float] + coro_fns: Iterable[Callable[[], Awaitable[_T]]], + delay: Optional[float], + *, + loop: Optional[asyncio.AbstractEventLoop] = None, ) -> Tuple[Optional[_T], Optional[int], List[Optional[BaseException]]]: """ Run coroutines with staggered start times and take the first to finish. @@ -38,6 +75,7 @@ async def staggered_race( raise Args: + ---- coro_fns: an iterable of coroutine functions, i.e. callables that return a coroutine object when called. Use ``functools.partial`` or lambdas to pass arguments. @@ -45,7 +83,10 @@ async def staggered_race( delay: amount of time, in seconds, between starting coroutines. If ``None``, the coroutines will run sequentially. + loop: the event loop to use. If ``None``, the running loop is used. + Returns: + ------- tuple *(winner_result, winner_index, exceptions)* where - *winner_result*: the result of the winning coroutine, or ``None`` @@ -62,40 +103,100 @@ async def staggered_race( coroutine's entry is ``None``. """ - # TODO: when we have aiter() and anext(), allow async iterables in coro_fns. - winner_result = None - winner_index = None + loop = loop or asyncio.get_running_loop() exceptions: List[Optional[BaseException]] = [] + tasks: Set[asyncio.Task[Optional[Tuple[_T, int]]]] = set() async def run_one_coro( - this_index: int, coro_fn: Callable[[], Awaitable[_T]], - this_failed: asyncio.Event, - ) -> None: + this_index: int, + start_next: "asyncio.Future[None]", + ) -> Optional[Tuple[_T, int]]: + """ + Run a single coroutine. + + If the coroutine fails, set the exception in the exceptions list and + start the next coroutine by setting the result of the start_next. + + If the coroutine succeeds, return the result and the index of the + coroutine in the coro_fns list. + + If SystemExit or KeyboardInterrupt is raised, re-raise it. + """ try: result = await coro_fn() except (SystemExit, KeyboardInterrupt): raise except BaseException as e: exceptions[this_index] = e - this_failed.set() # Kickstart the next coroutine - else: - # Store winner's results - nonlocal winner_index, winner_result - assert winner_index is None # noqa: S101 - winner_index = this_index - winner_result = result - raise _Done + _set_result(start_next) # Kickstart the next coroutine + return None + + return result, this_index + start_next_timer: Optional[asyncio.TimerHandle] = None + start_next: Optional[asyncio.Future[None]] + task: asyncio.Task[Optional[Tuple[_T, int]]] + done: Union[asyncio.Future[None], asyncio.Task[Optional[Tuple[_T, int]]]] + coro_iter = iter(coro_fns) + this_index = -1 try: - async with asyncio.TaskGroup() as tg: - for this_index, coro_fn in enumerate(coro_fns): - this_failed = asyncio.Event() + while True: + if coro_fn := next(coro_iter, None): + this_index += 1 exceptions.append(None) - tg.create_task(run_one_coro(this_index, coro_fn, this_failed)) - with contextlib.suppress(TimeoutError): - await asyncio.wait_for(this_failed.wait(), delay) - except* _Done: - pass - - return winner_result, winner_index, exceptions + start_next = loop.create_future() + task = loop.create_task(run_one_coro(coro_fn, this_index, start_next)) + tasks.add(task) + start_next_timer = ( + loop.call_later(delay, _set_result, start_next) if delay else None + ) + elif not tasks: + # We exhausted the coro_fns list and no tasks are running + # so we have no winner and all coroutines failed. + break + + while tasks: + done = await _wait_one( + [*tasks, start_next] if start_next else tasks, loop + ) + if done is start_next: + # The current task has failed or the timer has expired + # so we need to start the next task. + start_next = None + if start_next_timer: + start_next_timer.cancel() + start_next_timer = None + + # Break out of the task waiting loop to start the next + # task. + break + + if TYPE_CHECKING: + assert isinstance(done, asyncio.Task) + + tasks.remove(done) + if winner := done.result(): + return *winner, exceptions + finally: + # We either have: + # - a winner + # - all tasks failed + # - a KeyboardInterrupt or SystemExit. + + # + # If the timer is still running, cancel it. + # + if start_next_timer: + start_next_timer.cancel() + + # + # If there are any tasks left, cancel them and than + # wait them so they fill the exceptions list. + # + for task in tasks: + task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await task + + return None, None, exceptions diff --git a/src/aiohappyeyeballs/impl.py b/src/aiohappyeyeballs/impl.py index de96da8..ff2c788 100644 --- a/src/aiohappyeyeballs/impl.py +++ b/src/aiohappyeyeballs/impl.py @@ -8,7 +8,7 @@ import sys from typing import List, Optional, Sequence, Union -from . import staggered +from . import _staggered from .types import AddrInfoType if sys.version_info < (3, 8, 2): # noqa: UP036 @@ -86,7 +86,7 @@ async def start_connection( except (RuntimeError, OSError): continue else: # using happy eyeballs - sock, _, _ = await staggered.staggered_race( + sock, _, _ = await _staggered.staggered_race( ( functools.partial( _connect_sock, current_loop, exceptions, addrinfo, local_addr_infos diff --git a/src/aiohappyeyeballs/staggered.py b/src/aiohappyeyeballs/staggered.py deleted file mode 100644 index 6a8b391..0000000 --- a/src/aiohappyeyeballs/staggered.py +++ /dev/null @@ -1,9 +0,0 @@ -import sys - -if sys.version_info > (3, 11): - # https://github.com/python/cpython/issues/124639#issuecomment-2378129834 - from ._staggered import staggered_race -else: - from asyncio.staggered import staggered_race - -__all__ = ["staggered_race"] diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..32a3c43 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,32 @@ +"""Configuration for the tests.""" + +import asyncio +import threading +from typing import Generator + +import pytest + + +@pytest.fixture(autouse=True) +def verify_threads_ended(): + """Verify that the threads are not running after the test.""" + threads_before = frozenset(threading.enumerate()) + yield + threads = frozenset(threading.enumerate()) - threads_before + assert not threads + + +@pytest.fixture(autouse=True) +def verify_no_lingering_tasks( + event_loop: asyncio.AbstractEventLoop, +) -> Generator[None, None, None]: + """Verify that all tasks are cleaned up.""" + tasks_before = asyncio.all_tasks(event_loop) + yield + + tasks = asyncio.all_tasks(event_loop) - tasks_before + for task in tasks: + pytest.fail(f"Task still running: {task!r}") + task.cancel() + if tasks: + event_loop.run_until_complete(asyncio.wait(tasks)) diff --git a/tests/test_staggered.py b/tests/test_staggered.py new file mode 100644 index 0000000..8c5f38b --- /dev/null +++ b/tests/test_staggered.py @@ -0,0 +1,86 @@ +import asyncio +import sys +from functools import partial + +import pytest + +from aiohappyeyeballs._staggered import staggered_race + + +@pytest.mark.asyncio +async def test_one_winners(): + """Test that there is only one winner when there is no await in the coro.""" + winners = [] + + async def coro(idx): + winners.append(idx) + return idx + + coros = [partial(coro, idx) for idx in range(4)] + + winner, index, excs = await staggered_race( + coros, + delay=None, + ) + assert len(winners) == 1 + assert winners == [0] + assert winner == 0 + assert index == 0 + assert excs == [None] + + +@pytest.mark.asyncio +async def test_multiple_winners(): + """Test multiple winners are handled correctly.""" + loop = asyncio.get_running_loop() + winners = [] + finish = loop.create_future() + + async def coro(idx): + await finish + winners.append(idx) + return idx + + coros = [partial(coro, idx) for idx in range(4)] + + task = loop.create_task(staggered_race(coros, delay=0.00001)) + await asyncio.sleep(0.1) + loop.call_soon(finish.set_result, None) + winner, index, excs = await task + assert len(winners) == 4 + assert winners == [0, 1, 2, 3] + assert winner == 0 + assert index == 0 + assert excs == [None, None, None, None] + + +@pytest.mark.skipif(sys.version_info < (3, 12), reason="requires python3.12 or higher") +def test_multiple_winners_eager_task_factory(): + """Test multiple winners are handled correctly.""" + loop = asyncio.new_event_loop() + eager_task_factory = asyncio.create_eager_task_factory(asyncio.Task) + loop.set_task_factory(eager_task_factory) + asyncio.set_event_loop(None) + + async def run(): + winners = [] + finish = loop.create_future() + + async def coro(idx): + await finish + winners.append(idx) + return idx + + coros = [partial(coro, idx) for idx in range(4)] + + task = loop.create_task(staggered_race(coros, delay=0.00001)) + await asyncio.sleep(0.1) + loop.call_soon(finish.set_result, None) + winner, index, excs = await task + assert len(winners) == 4 + assert winners == [0, 1, 2, 3] + assert winner == 0 + assert index == 0 + assert excs == [None, None, None, None] + + loop.run_until_complete(run()) diff --git a/tests/test_staggered_cpython.py b/tests/test_staggered_cpython.py new file mode 100644 index 0000000..8607658 --- /dev/null +++ b/tests/test_staggered_cpython.py @@ -0,0 +1,146 @@ +""" +Tests for staggered_race. + +These tests are copied from cpython to ensure our implementation is +compatible with the one in cpython. +""" + +import asyncio +import unittest + +from aiohappyeyeballs._staggered import staggered_race + + +def tearDownModule(): + asyncio.set_event_loop_policy(None) + + +class StaggeredTests(unittest.IsolatedAsyncioTestCase): + async def test_empty(self): + winner, index, excs = await staggered_race( + [], + delay=None, + ) + + self.assertIs(winner, None) + self.assertIs(index, None) + self.assertEqual(excs, []) + + async def test_one_successful(self): + async def coro(index): + return f"Res: {index}" + + winner, index, excs = await staggered_race( + [ + lambda: coro(0), + lambda: coro(1), + ], + delay=None, + ) + + self.assertEqual(winner, "Res: 0") + self.assertEqual(index, 0) + self.assertEqual(excs, [None]) + + async def test_first_error_second_successful(self): + async def coro(index): + if index == 0: + raise ValueError(index) + return f"Res: {index}" + + winner, index, excs = await staggered_race( + [ + lambda: coro(0), + lambda: coro(1), + ], + delay=None, + ) + + self.assertEqual(winner, "Res: 1") + self.assertEqual(index, 1) + self.assertEqual(len(excs), 2) + self.assertIsInstance(excs[0], ValueError) + self.assertIs(excs[1], None) + + async def test_first_timeout_second_successful(self): + async def coro(index): + if index == 0: + await asyncio.sleep(10) # much bigger than delay + return f"Res: {index}" + + winner, index, excs = await staggered_race( + [ + lambda: coro(0), + lambda: coro(1), + ], + delay=0.1, + ) + + self.assertEqual(winner, "Res: 1") + self.assertEqual(index, 1) + self.assertEqual(len(excs), 2) + self.assertIsInstance(excs[0], asyncio.CancelledError) + self.assertIs(excs[1], None) + + async def test_none_successful(self): + async def coro(index): + raise ValueError(index) + + for delay in [None, 0, 0.1, 1]: + with self.subTest(delay=delay): + winner, index, excs = await staggered_race( + [ + lambda: coro(0), + lambda: coro(1), + ], + delay=delay, + ) + + self.assertIs(winner, None) + self.assertIs(index, None) + self.assertEqual(len(excs), 2) + self.assertIsInstance(excs[0], ValueError) + self.assertIsInstance(excs[1], ValueError) + + async def test_long_delay_early_failure(self): + async def coro(index): + await asyncio.sleep(0) # Dummy coroutine for the 1 case + if index == 0: + await asyncio.sleep(0.1) # Dummy coroutine + raise ValueError(index) + + return f"Res: {index}" + + winner, index, excs = await staggered_race( + [ + lambda: coro(0), + lambda: coro(1), + ], + delay=10, + ) + + self.assertEqual(winner, "Res: 1") + self.assertEqual(index, 1) + self.assertEqual(len(excs), 2) + self.assertIsInstance(excs[0], ValueError) + self.assertIsNone(excs[1]) + + def test_loop_argument(self): + loop = asyncio.new_event_loop() + + async def coro(): + self.assertEqual(loop, asyncio.get_running_loop()) + return "coro" + + async def main(): + winner, index, excs = await staggered_race([coro], delay=0.1, loop=loop) + + self.assertEqual(winner, "coro") + self.assertEqual(index, 0) + + loop.run_until_complete(main()) + loop.close() + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_staggered_cpython_eager_task_factory.py b/tests/test_staggered_cpython_eager_task_factory.py new file mode 100644 index 0000000..8d3cf28 --- /dev/null +++ b/tests/test_staggered_cpython_eager_task_factory.py @@ -0,0 +1,96 @@ +""" +Tests staggered_race and eager_task_factory with asyncio.Task. + +These tests are copied from cpython to ensure our implementation is +compatible with the one in cpython. +""" + +import asyncio +import sys +import unittest + +from aiohappyeyeballs._staggered import staggered_race + + +def tearDownModule(): + asyncio.set_event_loop_policy(None) + + +class EagerTaskFactoryLoopTests(unittest.TestCase): + def close_loop(self, loop): + loop.close() + + def set_event_loop(self, loop, *, cleanup=True): + if loop is None: + raise AssertionError("loop is None") + # ensure that the event loop is passed explicitly in asyncio + asyncio.set_event_loop(None) + if cleanup: + self.addCleanup(self.close_loop, loop) + + def tearDown(self): + asyncio.set_event_loop(None) + self.doCleanups() + + def setUp(self): + if sys.version_info < (3, 12): + self.skipTest("eager_task_factory is only available in Python 3.12+") + + super().setUp() + self.loop = asyncio.new_event_loop() + self.eager_task_factory = asyncio.create_eager_task_factory(asyncio.Task) + self.loop.set_task_factory(self.eager_task_factory) + self.set_event_loop(self.loop) + + def test_staggered_race_with_eager_tasks(self): + # See https://github.com/python/cpython/issues/124309 + + async def fail(): + await asyncio.sleep(0) + raise ValueError("no good") + + async def run(): + winner, index, excs = await staggered_race( + [ + lambda: asyncio.sleep(2, result="sleep2"), + lambda: asyncio.sleep(1, result="sleep1"), + lambda: fail(), + ], + delay=0.25, + ) + self.assertEqual(winner, "sleep1") + self.assertEqual(index, 1) + assert index is not None + self.assertIsNone(excs[index]) + self.assertIsInstance(excs[0], asyncio.CancelledError) + self.assertIsInstance(excs[2], ValueError) + + self.loop.run_until_complete(run()) + + def test_staggered_race_with_eager_tasks_no_delay(self): + # See https://github.com/python/cpython/issues/124309 + async def fail(): + raise ValueError("no good") + + async def run(): + winner, index, excs = await staggered_race( + [ + lambda: fail(), + lambda: asyncio.sleep(1, result="sleep1"), + lambda: asyncio.sleep(0, result="sleep0"), + ], + delay=None, + ) + self.assertEqual(winner, "sleep1") + self.assertEqual(index, 1) + assert index is not None + self.assertIsNone(excs[index]) + self.assertIsInstance(excs[0], ValueError) + self.assertEqual(len(excs), 2) + + self.loop.run_until_complete(run()) + + +if __name__ == "__main__": + if sys.version_info >= (3, 12): + unittest.main()