From 549c95b948dcddd6588f95545ad6c856f693c503 Mon Sep 17 00:00:00 2001 From: Sam Bull Date: Mon, 22 Jul 2024 16:09:44 +0100 Subject: [PATCH] Shutdown logic: Only wait on handlers (#8495) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: J. Nick Koston --- aiohttp/web.py | 26 ------------------- aiohttp/web_protocol.py | 8 ++++-- aiohttp/web_runner.py | 16 ++---------- aiohttp/web_server.py | 7 ++++- tests/test_run_app.py | 43 +++++++------------------------ tests/test_web_request_handler.py | 8 +++--- 6 files changed, 29 insertions(+), 79 deletions(-) diff --git a/aiohttp/web.py b/aiohttp/web.py index 1a30dd87775..68b29c79d0b 100644 --- a/aiohttp/web.py +++ b/aiohttp/web.py @@ -6,8 +6,6 @@ import warnings from argparse import ArgumentParser from collections.abc import Iterable -from contextlib import suppress -from functools import partial from importlib import import_module from typing import ( Any, @@ -21,7 +19,6 @@ Union, cast, ) -from weakref import WeakSet from .abc import AbstractAccessLogger from .helpers import AppKey @@ -300,23 +297,6 @@ async def _run_app( reuse_port: Optional[bool] = None, handler_cancellation: bool = False, ) -> None: - async def wait( - starting_tasks: "WeakSet[asyncio.Task[object]]", shutdown_timeout: float - ) -> None: - # Wait for pending tasks for a given time limit. - t = asyncio.current_task() - assert t is not None - starting_tasks.add(t) - with suppress(asyncio.TimeoutError): - await asyncio.wait_for(_wait(starting_tasks), timeout=shutdown_timeout) - - async def _wait(exclude: "WeakSet[asyncio.Task[object]]") -> None: - t = asyncio.current_task() - assert t is not None - exclude.add(t) - while tasks := asyncio.all_tasks().difference(exclude): - await asyncio.wait(tasks) - # An internal function to actually do all dirty job for application running if asyncio.iscoroutine(app): app = await app @@ -335,12 +315,6 @@ async def _wait(exclude: "WeakSet[asyncio.Task[object]]") -> None: ) await runner.setup() - # On shutdown we want to avoid waiting on tasks which run forever. - # It's very likely that all tasks which run forever will have been created by - # the time we have completed the application startup (in runner.setup()), - # so we just record all running tasks here and exclude them later. - starting_tasks: "WeakSet[asyncio.Task[object]]" = WeakSet(asyncio.all_tasks()) - runner.shutdown_callback = partial(wait, starting_tasks, shutdown_timeout) sites: List[BaseSite] = [] diff --git a/aiohttp/web_protocol.py b/aiohttp/web_protocol.py index db15958e88d..1b4e7e66cd8 100644 --- a/aiohttp/web_protocol.py +++ b/aiohttp/web_protocol.py @@ -278,7 +278,12 @@ async def shutdown(self, timeout: Optional[float] = 15.0) -> None: if self._waiter: self._waiter.cancel() - # wait for handlers + # Wait for graceful disconnection + if self._current_request is not None: + with suppress(asyncio.CancelledError, asyncio.TimeoutError): + async with ceil_timeout(timeout): + await self._current_request.wait_for_disconnection() + # Then cancel handler and wait with suppress(asyncio.CancelledError, asyncio.TimeoutError): async with ceil_timeout(timeout): if self._current_request is not None: @@ -461,7 +466,6 @@ async def _handle_request( start_time: float, request_handler: Callable[[BaseRequest], Awaitable[StreamResponse]], ) -> Tuple[StreamResponse, bool]: - assert self._request_handler is not None try: try: self._current_request = request diff --git a/aiohttp/web_runner.py b/aiohttp/web_runner.py index 2618875f6bd..f507be60341 100644 --- a/aiohttp/web_runner.py +++ b/aiohttp/web_runner.py @@ -2,7 +2,7 @@ import signal import socket from abc import ABC, abstractmethod -from typing import Any, Awaitable, Callable, List, Optional, Set, Type +from typing import Any, List, Optional, Set, Type from yarl import URL @@ -230,14 +230,7 @@ async def start(self) -> None: class BaseRunner(ABC): - __slots__ = ( - "shutdown_callback", - "_handle_signals", - "_kwargs", - "_server", - "_sites", - "_shutdown_timeout", - ) + __slots__ = ("_handle_signals", "_kwargs", "_server", "_sites", "_shutdown_timeout") def __init__( self, @@ -246,7 +239,6 @@ def __init__( shutdown_timeout: float = 60.0, **kwargs: Any, ) -> None: - self.shutdown_callback: Optional[Callable[[], Awaitable[None]]] = None self._handle_signals = handle_signals self._kwargs = kwargs self._server: Optional[Server] = None @@ -304,10 +296,6 @@ async def cleanup(self) -> None: await asyncio.sleep(0) self._server.pre_shutdown() await self.shutdown() - - if self.shutdown_callback: - await self.shutdown_callback() - await self._server.shutdown(self._shutdown_timeout) await self._cleanup_server() diff --git a/aiohttp/web_server.py b/aiohttp/web_server.py index 9d317bb12e1..f7dc971c6e1 100644 --- a/aiohttp/web_server.py +++ b/aiohttp/web_server.py @@ -50,7 +50,12 @@ def connection_lost( self, handler: RequestHandler, exc: Optional[BaseException] = None ) -> None: if handler in self._connections: - del self._connections[handler] + if handler._task_handler: + handler._task_handler.add_done_callback( + lambda f: self._connections.pop(handler, None) + ) + else: + del self._connections[handler] def _make_request( self, diff --git a/tests/test_run_app.py b/tests/test_run_app.py index b53637ad436..1c3ba0a6dd5 100644 --- a/tests/test_run_app.py +++ b/tests/test_run_app.py @@ -16,7 +16,7 @@ import pytest -from aiohttp import ClientConnectorError, ClientSession, WSCloseCode, web +from aiohttp import ClientConnectorError, ClientSession, ClientTimeout, WSCloseCode, web from aiohttp.test_utils import make_mocked_coro from aiohttp.web_runner import BaseRunner @@ -935,8 +935,12 @@ async def test() -> None: async with ClientSession() as sess: for _ in range(5): # pragma: no cover try: - async with sess.get(f"http://localhost:{port}/"): - pass + with pytest.raises(asyncio.TimeoutError): + async with sess.get( + f"http://localhost:{port}/", + timeout=ClientTimeout(total=0.1), + ): + pass except ClientConnectorError: await asyncio.sleep(0.5) else: @@ -956,6 +960,7 @@ async def run_test(app: web.Application) -> None: async def handler(request: web.Request) -> web.Response: nonlocal t t = asyncio.create_task(task()) + await t return web.Response(text="FOO") t = test_task = None @@ -968,7 +973,7 @@ async def handler(request: web.Request) -> web.Response: assert test_task.exception() is None return t - def test_shutdown_wait_for_task( + def test_shutdown_wait_for_handler( self, aiohttp_unused_port: Callable[[], int] ) -> None: port = aiohttp_unused_port() @@ -985,7 +990,7 @@ async def task(): assert t.done() assert not t.cancelled() - def test_shutdown_timeout_task( + def test_shutdown_timeout_handler( self, aiohttp_unused_port: Callable[[], int] ) -> None: port = aiohttp_unused_port() @@ -1002,34 +1007,6 @@ async def task(): assert t.done() assert t.cancelled() - def test_shutdown_wait_for_spawned_task( - self, aiohttp_unused_port: Callable[[], int] - ) -> None: - port = aiohttp_unused_port() - finished = False - finished_sub = False - sub_t = None - - async def sub_task(): - nonlocal finished_sub - await asyncio.sleep(1.5) - finished_sub = True - - async def task(): - nonlocal finished, sub_t - await asyncio.sleep(0.5) - sub_t = asyncio.create_task(sub_task()) - finished = True - - t = self.run_app(port, 3, task) - - assert finished is True - assert t.done() - assert not t.cancelled() - assert finished_sub is True - assert sub_t.done() - assert not sub_t.cancelled() - def test_shutdown_timeout_not_reached( self, aiohttp_unused_port: Callable[[], int] ) -> None: diff --git a/tests/test_web_request_handler.py b/tests/test_web_request_handler.py index 06f99be76c0..4837cab030e 100644 --- a/tests/test_web_request_handler.py +++ b/tests/test_web_request_handler.py @@ -22,19 +22,21 @@ async def test_connections() -> None: manager = web.Server(serve) assert manager.connections == [] - handler = object() + handler = mock.Mock(spec_set=web.RequestHandler) + handler._task_handler = None transport = object() manager.connection_made(handler, transport) # type: ignore[arg-type] assert manager.connections == [handler] - manager.connection_lost(handler, None) # type: ignore[arg-type] + manager.connection_lost(handler, None) assert manager.connections == [] async def test_shutdown_no_timeout() -> None: manager = web.Server(serve) - handler = mock.Mock() + handler = mock.Mock(spec_set=web.RequestHandler) + handler._task_handler = None handler.shutdown = make_mocked_coro(mock.Mock()) transport = mock.Mock() manager.connection_made(handler, transport)