diff --git a/CHANGES/8680.bugfix.rst b/CHANGES/8680.bugfix.rst new file mode 100644 index 00000000000..2149f12aaaf --- /dev/null +++ b/CHANGES/8680.bugfix.rst @@ -0,0 +1 @@ +Fixed a race closing the server-side WebSocket where the close code would not reach the client. -- by :user:`bdraco`. diff --git a/aiohttp/web_ws.py b/aiohttp/web_ws.py index 3b76ab8eead..d0f9e8de8d2 100644 --- a/aiohttp/web_ws.py +++ b/aiohttp/web_ws.py @@ -431,23 +431,10 @@ async def close( if self._writer is None: raise RuntimeError("Call .prepare() first") - self._cancel_heartbeat() - reader = self._reader - assert reader is not None - - # we need to break `receive()` cycle first, - # `close()` may be called from different task - if self._waiting and not self._closed: - if not self._close_wait: - assert self._loop is not None - self._close_wait = self._loop.create_future() - reader.feed_data(WS_CLOSING_MESSAGE) - await self._close_wait - if self._closed: return False - self._set_closed() + try: await self._writer.close(code, message) writer = self._payload_writer @@ -462,12 +449,21 @@ async def close( self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE) return True + reader = self._reader + assert reader is not None + # we need to break `receive()` cycle before we can call + # `reader.read()` as `close()` may be called from different task + if self._waiting: + assert self._loop is not None + assert self._close_wait is None + self._close_wait = self._loop.create_future() + reader.feed_data(WS_CLOSING_MESSAGE) + await self._close_wait + if self._closing: self._close_transport() return True - reader = self._reader - assert reader is not None try: async with async_timeout.timeout(self._timeout): msg = await reader.read() diff --git a/tests/test_web_websocket_functional.py b/tests/test_web_websocket_functional.py index af7addf29b9..394cbf23355 100644 --- a/tests/test_web_websocket_functional.py +++ b/tests/test_web_websocket_functional.py @@ -4,6 +4,7 @@ import asyncio import contextlib import sys +import weakref from typing import Any, Optional import pytest @@ -11,6 +12,7 @@ import aiohttp from aiohttp import WSServerHandshakeError, web from aiohttp.http import WSCloseCode, WSMsgType +from aiohttp.pytest_plugin import AiohttpClient async def test_websocket_can_prepare(loop: Any, aiohttp_client: Any) -> None: @@ -1019,3 +1021,61 @@ async def handler(request): await ws.close(code=WSCloseCode.OK, message="exit message") await closed + + +async def test_websocket_shutdown(aiohttp_client: AiohttpClient) -> None: + """Test that the client websocket gets the close message when the server is shutting down.""" + url = "/ws" + app = web.Application() + websockets = web.AppKey("websockets", weakref.WeakSet) + app[websockets] = weakref.WeakSet() + + # need for send signal shutdown server + shutdown_websockets = web.AppKey("shutdown_websockets", weakref.WeakSet) + app[shutdown_websockets] = weakref.WeakSet() + + async def websocket_handler(request: web.Request) -> web.WebSocketResponse: + websocket = web.WebSocketResponse() + await websocket.prepare(request) + request.app[websockets].add(websocket) + request.app[shutdown_websockets].add(websocket) + + try: + async for message in websocket: + await websocket.send_json({"ok": True, "message": message.json()}) + finally: + request.app[websockets].discard(websocket) + + return websocket + + async def on_shutdown(app: web.Application) -> None: + while app[shutdown_websockets]: + websocket = app[shutdown_websockets].pop() + await websocket.close( + code=aiohttp.WSCloseCode.GOING_AWAY, + message="Server shutdown", + ) + + app.router.add_get(url, websocket_handler) + app.on_shutdown.append(on_shutdown) + + client = await aiohttp_client(app) + + websocket = await client.ws_connect(url) + + message = {"message": "hi"} + await websocket.send_json(message) + reply = await websocket.receive_json() + assert reply == {"ok": True, "message": message} + + await app.shutdown() + + assert websocket.closed is False + + reply = await websocket.receive() + + assert reply.type is aiohttp.http.WSMsgType.CLOSE + assert reply.data == aiohttp.WSCloseCode.GOING_AWAY + assert reply.extra == "Server shutdown" + + assert websocket.closed is True