diff --git a/CHANGES/8608.misc.rst b/CHANGES/8608.misc.rst new file mode 100644 index 00000000000..76e845bf997 --- /dev/null +++ b/CHANGES/8608.misc.rst @@ -0,0 +1,3 @@ +Improved websocket performance when messages are sent or received frequently -- by :user:`bdraco`. + +The WebSocket heartbeat scheduling algorithm was improved to reduce the ``asyncio`` scheduling overhead by decreasing the number of ``asyncio.TimerHandle`` creations and cancellations. diff --git a/aiohttp/client_ws.py b/aiohttp/client_ws.py index e6c720b4264..5008c0bc336 100644 --- a/aiohttp/client_ws.py +++ b/aiohttp/client_ws.py @@ -7,7 +7,7 @@ from .client_exceptions import ClientError, ServerTimeoutError from .client_reqrep import ClientResponse -from .helpers import call_later, set_result +from .helpers import calculate_timeout_when, set_result from .http import ( WS_CLOSED_MESSAGE, WS_CLOSING_MESSAGE, @@ -72,6 +72,7 @@ def __init__( self._autoping = autoping self._heartbeat = heartbeat self._heartbeat_cb: Optional[asyncio.TimerHandle] = None + self._heartbeat_when: float = 0.0 if heartbeat is not None: self._pong_heartbeat = heartbeat / 2.0 self._pong_response_cb: Optional[asyncio.TimerHandle] = None @@ -85,52 +86,64 @@ def __init__( self._reset_heartbeat() def _cancel_heartbeat(self) -> None: - if self._pong_response_cb is not None: - self._pong_response_cb.cancel() - self._pong_response_cb = None - + self._cancel_pong_response_cb() if self._heartbeat_cb is not None: self._heartbeat_cb.cancel() self._heartbeat_cb = None - def _reset_heartbeat(self) -> None: - self._cancel_heartbeat() + def _cancel_pong_response_cb(self) -> None: + if self._pong_response_cb is not None: + self._pong_response_cb.cancel() + self._pong_response_cb = None - if self._heartbeat is not None: - self._heartbeat_cb = call_later( - self._send_heartbeat, - self._heartbeat, - self._loop, - timeout_ceil_threshold=( - self._conn._connector._timeout_ceil_threshold - if self._conn is not None - else 5 - ), - ) + def _reset_heartbeat(self) -> None: + if self._heartbeat is None: + return + self._cancel_pong_response_cb() + loop = self._loop + assert loop is not None + conn = self._conn + timeout_ceil_threshold = ( + conn._connector._timeout_ceil_threshold if conn is not None else 5 + ) + now = loop.time() + when = calculate_timeout_when(now, self._heartbeat, timeout_ceil_threshold) + self._heartbeat_when = when + if self._heartbeat_cb is None: + # We do not cancel the previous heartbeat_cb here because + # it generates a significant amount of TimerHandle churn + # which causes asyncio to rebuild the heap frequently. + # Instead _send_heartbeat() will reschedule the next + # heartbeat if it fires too early. + self._heartbeat_cb = loop.call_at(when, self._send_heartbeat) def _send_heartbeat(self) -> None: - if self._heartbeat is not None and not self._closed: - # fire-and-forget a task is not perfect but maybe ok for - # sending ping. Otherwise we need a long-living heartbeat - # task in the class. - self._loop.create_task(self._writer.ping()) # type: ignore[unused-awaitable] - - if self._pong_response_cb is not None: - self._pong_response_cb.cancel() - self._pong_response_cb = call_later( - self._pong_not_received, - self._pong_heartbeat, - self._loop, - timeout_ceil_threshold=( - self._conn._connector._timeout_ceil_threshold - if self._conn is not None - else 5 - ), + self._heartbeat_cb = None + loop = self._loop + now = loop.time() + if now < self._heartbeat_when: + # Heartbeat fired too early, reschedule + self._heartbeat_cb = loop.call_at( + self._heartbeat_when, self._send_heartbeat ) + return + + # fire-and-forget a task is not perfect but maybe ok for + # sending ping. Otherwise we need a long-living heartbeat + # task in the class. + loop.create_task(self._writer.ping()) # type: ignore[unused-awaitable] + + conn = self._conn + timeout_ceil_threshold = ( + conn._connector._timeout_ceil_threshold if conn is not None else 5 + ) + when = calculate_timeout_when(now, self._pong_heartbeat, timeout_ceil_threshold) + self._cancel_pong_response_cb() + self._pong_response_cb = loop.call_at(when, self._pong_not_received) def _pong_not_received(self) -> None: if not self._closed: - self._closed = True + self._set_closed() self._close_code = WSCloseCode.ABNORMAL_CLOSURE self._exception = ServerTimeoutError() self._response.close() @@ -139,6 +152,22 @@ def _pong_not_received(self) -> None: WSMessage(WSMsgType.ERROR, self._exception, None) ) + def _set_closed(self) -> None: + """Set the connection to closed. + + Cancel any heartbeat timers and set the closed flag. + """ + self._closed = True + self._cancel_heartbeat() + + def _set_closing(self) -> None: + """Set the connection to closing. + + Cancel any heartbeat timers and set the closing flag. + """ + self._closing = True + self._cancel_heartbeat() + @property def closed(self) -> bool: return self._closed @@ -203,13 +232,12 @@ async def close(self, *, code: int = WSCloseCode.OK, message: bytes = b"") -> bo if self._waiting and not self._closing: assert self._loop is not None self._close_wait = self._loop.create_future() - self._closing = True + self._set_closing() self._reader.feed_data(WS_CLOSING_MESSAGE) await self._close_wait if not self._closed: - self._cancel_heartbeat() - self._closed = True + self._set_closed() try: await self._writer.close(code, message) except asyncio.CancelledError: @@ -278,7 +306,8 @@ async def receive(self, timeout: Optional[float] = None) -> WSMessage: await self.close() return WSMessage(WSMsgType.CLOSED, None, None) except ClientError: - self._closed = True + # Likely ServerDisconnectedError when connection is lost + self._set_closed() self._close_code = WSCloseCode.ABNORMAL_CLOSURE return WS_CLOSED_MESSAGE except WebSocketError as exc: @@ -287,19 +316,19 @@ async def receive(self, timeout: Optional[float] = None) -> WSMessage: return WSMessage(WSMsgType.ERROR, exc, None) except Exception as exc: self._exception = exc - self._closing = True + self._set_closing() self._close_code = WSCloseCode.ABNORMAL_CLOSURE await self.close() return WSMessage(WSMsgType.ERROR, exc, None) if msg.type is WSMsgType.CLOSE: - self._closing = True + self._set_closing() self._close_code = msg.data # Could be closed elsewhere while awaiting reader if not self._closed and self._autoclose: # type: ignore[redundant-expr] await self.close() elif msg.type is WSMsgType.CLOSING: - self._closing = True + self._set_closing() elif msg.type is WSMsgType.PING and self._autoping: await self.pong(msg.data) continue diff --git a/aiohttp/helpers.py b/aiohttp/helpers.py index 9811ad05054..496f3ad1d05 100644 --- a/aiohttp/helpers.py +++ b/aiohttp/helpers.py @@ -598,12 +598,23 @@ def call_later( loop: asyncio.AbstractEventLoop, timeout_ceil_threshold: float = 5, ) -> Optional[asyncio.TimerHandle]: - if timeout is not None and timeout > 0: - when = loop.time() + timeout - if timeout > timeout_ceil_threshold: - when = ceil(when) - return loop.call_at(when, cb) - return None + if timeout is None or timeout <= 0: + return None + now = loop.time() + when = calculate_timeout_when(now, timeout, timeout_ceil_threshold) + return loop.call_at(when, cb) + + +def calculate_timeout_when( + loop_time: float, + timeout: float, + timeout_ceiling_threshold: float, +) -> float: + """Calculate when to execute a timeout.""" + when = loop_time + timeout + if timeout > timeout_ceiling_threshold: + return ceil(when) + return when class TimeoutHandle: diff --git a/aiohttp/web_ws.py b/aiohttp/web_ws.py index 522d1e0f0a9..7cc206b4ac1 100644 --- a/aiohttp/web_ws.py +++ b/aiohttp/web_ws.py @@ -11,7 +11,7 @@ from . import hdrs from .abc import AbstractStreamWriter -from .helpers import call_later, set_exception, set_result +from .helpers import calculate_timeout_when, set_exception, set_result from .http import ( WS_CLOSED_MESSAGE, WS_CLOSING_MESSAGE, @@ -74,6 +74,7 @@ class WebSocketResponse(StreamResponse): "_autoclose", "_autoping", "_heartbeat", + "_heartbeat_when", "_heartbeat_cb", "_pong_heartbeat", "_pong_response_cb", @@ -112,6 +113,7 @@ def __init__( self._autoclose = autoclose self._autoping = autoping self._heartbeat = heartbeat + self._heartbeat_when = 0.0 self._heartbeat_cb: Optional[asyncio.TimerHandle] = None if heartbeat is not None: self._pong_heartbeat = heartbeat / 2.0 @@ -120,57 +122,76 @@ def __init__( self._max_msg_size = max_msg_size def _cancel_heartbeat(self) -> None: - if self._pong_response_cb is not None: - self._pong_response_cb.cancel() - self._pong_response_cb = None - + self._cancel_pong_response_cb() if self._heartbeat_cb is not None: self._heartbeat_cb.cancel() self._heartbeat_cb = None - def _reset_heartbeat(self) -> None: - self._cancel_heartbeat() + def _cancel_pong_response_cb(self) -> None: + if self._pong_response_cb is not None: + self._pong_response_cb.cancel() + self._pong_response_cb = None - if self._heartbeat is not None: - assert self._loop is not None - self._heartbeat_cb = call_later( - self._send_heartbeat, - self._heartbeat, - self._loop, - timeout_ceil_threshold=( - self._req._protocol._timeout_ceil_threshold - if self._req is not None - else 5 - ), - ) + def _reset_heartbeat(self) -> None: + if self._heartbeat is None: + return + self._cancel_pong_response_cb() + req = self._req + timeout_ceil_threshold = ( + req._protocol._timeout_ceil_threshold if req is not None else 5 + ) + loop = self._loop + assert loop is not None + now = loop.time() + when = calculate_timeout_when(now, self._heartbeat, timeout_ceil_threshold) + self._heartbeat_when = when + if self._heartbeat_cb is None: + # We do not cancel the previous heartbeat_cb here because + # it generates a significant amount of TimerHandle churn + # which causes asyncio to rebuild the heap frequently. + # Instead _send_heartbeat() will reschedule the next + # heartbeat if it fires too early. + self._heartbeat_cb = loop.call_at(when, self._send_heartbeat) def _send_heartbeat(self) -> None: - if self._heartbeat is not None and not self._closed: - assert self._loop is not None and self._writer is not None - # fire-and-forget a task is not perfect but maybe ok for - # sending ping. Otherwise we need a long-living heartbeat - # task in the class. - self._loop.create_task(self._writer.ping()) # type: ignore[unused-awaitable] - - if self._pong_response_cb is not None: - self._pong_response_cb.cancel() - self._pong_response_cb = call_later( - self._pong_not_received, - self._pong_heartbeat, - self._loop, - timeout_ceil_threshold=( - self._req._protocol._timeout_ceil_threshold - if self._req is not None - else 5 - ), + self._heartbeat_cb = None + loop = self._loop + assert loop is not None and self._writer is not None + now = loop.time() + if now < self._heartbeat_when: + # Heartbeat fired too early, reschedule + self._heartbeat_cb = loop.call_at( + self._heartbeat_when, self._send_heartbeat ) + return + + # fire-and-forget a task is not perfect but maybe ok for + # sending ping. Otherwise we need a long-living heartbeat + # task in the class. + loop.create_task(self._writer.ping()) # type: ignore[unused-awaitable] + + req = self._req + timeout_ceil_threshold = ( + req._protocol._timeout_ceil_threshold if req is not None else 5 + ) + when = calculate_timeout_when(now, self._pong_heartbeat, timeout_ceil_threshold) + self._cancel_pong_response_cb() + self._pong_response_cb = loop.call_at(when, self._pong_not_received) def _pong_not_received(self) -> None: if self._req is not None and self._req.transport is not None: - self._closed = True + self._set_closed() self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE) self._exception = asyncio.TimeoutError() + def _set_closed(self) -> None: + """Set the connection to closed. + + Cancel any heartbeat timers and set the closed flag. + """ + self._closed = True + self._cancel_heartbeat() + async def prepare(self, request: BaseRequest) -> AbstractStreamWriter: # make pre-check to don't hide it by do_handshake() exceptions if self._payload_writer is not None: @@ -410,7 +431,7 @@ async def close( if self._closed: return False - self._closed = True + self._set_closed() try: await self._writer.close(code, message) writer = self._payload_writer @@ -454,6 +475,7 @@ def _set_closing(self, code: WSCloseCode) -> None: """Set the close code and mark the connection as closing.""" self._closing = True self._close_code = code + self._cancel_heartbeat() def _set_code_close_transport(self, code: WSCloseCode) -> None: """Set the close code and close the transport.""" @@ -566,5 +588,6 @@ def _cancel(self, exc: BaseException) -> None: # web_protocol calls this from connection_lost # or when the server is shutting down. self._closing = True + self._cancel_heartbeat() if self._reader is not None: set_exception(self._reader, exc) diff --git a/tests/test_client_ws.py b/tests/test_client_ws.py index bee7333a327..7521bd24ac5 100644 --- a/tests/test_client_ws.py +++ b/tests/test_client_ws.py @@ -9,6 +9,7 @@ import aiohttp from aiohttp import client, hdrs +from aiohttp.client_exceptions import ServerDisconnectedError from aiohttp.client_ws import ClientWSTimeout from aiohttp.http import WS_KEY from aiohttp.streams import EofStream @@ -431,6 +432,39 @@ async def test_close_eofstream( await session.close() +async def test_close_connection_lost( + loop: asyncio.AbstractEventLoop, ws_key: bytes, key_data: bytes +) -> None: + """Test the websocket client handles the connection being closed out from under it.""" + mresp = mock.Mock(spec_set=client.ClientResponse) + mresp.status = 101 + mresp.headers = { + hdrs.UPGRADE: "websocket", + hdrs.CONNECTION: "upgrade", + hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, + } + mresp.connection.protocol.read_timeout = None + with mock.patch("aiohttp.client.WebSocketWriter"), mock.patch( + "aiohttp.client.os" + ) as m_os, mock.patch("aiohttp.client.ClientSession.request") as m_req: + m_os.urandom.return_value = key_data + m_req.return_value = loop.create_future() + m_req.return_value.set_result(mresp) + + session = aiohttp.ClientSession() + resp = await session.ws_connect("http://test.org") + assert not resp.closed + + exc = ServerDisconnectedError() + resp._reader.set_exception(exc) + + msg = await resp.receive() + assert msg.type is aiohttp.WSMsgType.CLOSED + assert resp.closed + + await session.close() + + async def test_close_exc( loop: asyncio.AbstractEventLoop, ws_key: bytes, key_data: bytes ) -> None: diff --git a/tests/test_client_ws_functional.py b/tests/test_client_ws_functional.py index 1d1de51b854..81bc6e6c7c8 100644 --- a/tests/test_client_ws_functional.py +++ b/tests/test_client_ws_functional.py @@ -662,6 +662,7 @@ async def handler(request: web.Request) -> NoReturn: async def test_heartbeat_no_pong(aiohttp_client: AiohttpClient) -> None: + """Test that the connection is closed if no pong is received without sending messages.""" ping_received = False async def handler(request: web.Request) -> NoReturn: @@ -686,6 +687,78 @@ async def handler(request: web.Request) -> NoReturn: assert resp.close_code is WSCloseCode.ABNORMAL_CLOSURE +async def test_heartbeat_no_pong_after_receive_many_messages( + aiohttp_client: AiohttpClient, +) -> None: + """Test that the connection is closed if no pong is received after receiving many messages.""" + ping_received = False + + async def handler(request: web.Request) -> NoReturn: + nonlocal ping_received + ws = web.WebSocketResponse(autoping=False) + await ws.prepare(request) + for _ in range(5): + await ws.send_str("test") + await asyncio.sleep(0.05) + for _ in range(5): + await ws.send_str("test") + msg = await ws.receive() + ping_received = msg.type is aiohttp.WSMsgType.PING + await ws.receive() + assert False + + app = web.Application() + app.router.add_route("GET", "/", handler) + + client = await aiohttp_client(app) + resp = await client.ws_connect("/", heartbeat=0.1) + + for _ in range(10): + test_msg = await resp.receive() + assert test_msg.data == "test" + # Connection should be closed roughly after 1.5x heartbeat. + + await asyncio.sleep(0.2) + assert ping_received + assert resp.close_code is WSCloseCode.ABNORMAL_CLOSURE + + +async def test_heartbeat_no_pong_after_send_many_messages( + aiohttp_client: AiohttpClient, +) -> None: + """Test that the connection is closed if no pong is received after sending many messages.""" + ping_received = False + + async def handler(request: web.Request) -> NoReturn: + nonlocal ping_received + ws = web.WebSocketResponse(autoping=False) + await ws.prepare(request) + for _ in range(10): + msg = await ws.receive() + assert msg.data == "test" + assert msg.type is aiohttp.WSMsgType.TEXT + msg = await ws.receive() + ping_received = msg.type is aiohttp.WSMsgType.PING + await ws.receive() + assert False + + app = web.Application() + app.router.add_route("GET", "/", handler) + + client = await aiohttp_client(app) + resp = await client.ws_connect("/", heartbeat=0.1) + + for _ in range(5): + await resp.send_str("test") + await asyncio.sleep(0.05) + for _ in range(5): + await resp.send_str("test") + # Connection should be closed roughly after 1.5x heartbeat. + await asyncio.sleep(0.2) + assert ping_received + assert resp.close_code is WSCloseCode.ABNORMAL_CLOSURE + + async def test_heartbeat_no_pong_concurrent_receive( aiohttp_client: AiohttpClient, ) -> None: diff --git a/tests/test_web_websocket_functional.py b/tests/test_web_websocket_functional.py index 7d990294840..af7addf29b9 100644 --- a/tests/test_web_websocket_functional.py +++ b/tests/test_web_websocket_functional.py @@ -715,6 +715,63 @@ async def handler(request): await ws.close() +async def test_heartbeat_no_pong_send_many_messages( + loop: Any, aiohttp_client: Any +) -> None: + """Test no pong after sending many messages.""" + + async def handler(request): + ws = web.WebSocketResponse(heartbeat=0.05) + await ws.prepare(request) + for _ in range(10): + await ws.send_str("test") + + await ws.receive() + return ws + + app = web.Application() + app.router.add_get("/", handler) + + client = await aiohttp_client(app) + ws = await client.ws_connect("/", autoping=False) + for _ in range(10): + msg = await ws.receive() + assert msg.type is aiohttp.WSMsgType.TEXT + assert msg.data == "test" + + msg = await ws.receive() + assert msg.type is aiohttp.WSMsgType.PING + await ws.close() + + +async def test_heartbeat_no_pong_receive_many_messages( + loop: Any, aiohttp_client: Any +) -> None: + """Test no pong after receiving many messages.""" + + async def handler(request): + ws = web.WebSocketResponse(heartbeat=0.05) + await ws.prepare(request) + for _ in range(10): + server_msg = await ws.receive() + assert server_msg.type is aiohttp.WSMsgType.TEXT + + await ws.receive() + return ws + + app = web.Application() + app.router.add_get("/", handler) + + client = await aiohttp_client(app) + ws = await client.ws_connect("/", autoping=False) + for _ in range(10): + await ws.send_str("test") + + msg = await ws.receive() + assert msg.type is aiohttp.WSMsgType.PING + await ws.close() + + async def test_server_ws_async_for(loop: Any, aiohttp_server: Any) -> None: closed = loop.create_future()