From f758e3ed8269fd046597b9a0dcbd1dca59a0ac14 Mon Sep 17 00:00:00 2001 From: Tyson Smith Date: Mon, 17 Jun 2024 12:10:45 -0700 Subject: [PATCH] [sapphire] Set listening socket to non-blocking This will prevent a hang due to race between select() and accept(). --- grizzly/common/runner.py | 3 ++- sapphire/core.py | 23 +++++++++++++++-------- sapphire/test_sapphire.py | 24 ++++++++++++++++-------- sapphire/test_worker.py | 1 + sapphire/worker.py | 3 +++ 5 files changed, 37 insertions(+), 17 deletions(-) diff --git a/grizzly/common/runner.py b/grizzly/common/runner.py index 0aa975e0..bad0cd80 100644 --- a/grizzly/common/runner.py +++ b/grizzly/common/runner.py @@ -159,7 +159,8 @@ def launch(self, location: str, max_retries: int = 3, retry_delay: int = 0) -> N assert self._target is not None assert max_retries >= 0 assert retry_delay >= 0 - self._server.clear_backlog() + # nothing should be trying to connect, did the previous target.close() fail? + assert self._server.clear_backlog() self._tests_run = 0 self.startup_failure = False launch_duration: float = 0 diff --git a/sapphire/core.py b/sapphire/core.py index 656cb4fa..cec8fdbb 100644 --- a/sapphire/core.py +++ b/sapphire/core.py @@ -9,7 +9,7 @@ from pathlib import Path from socket import SO_REUSEADDR, SOL_SOCKET, gethostname, socket from ssl import PROTOCOL_TLS_SERVER, SSLContext, SSLSocket -from time import sleep, time +from time import perf_counter, sleep from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union, cast from .certificate_bundle import CertificateBundle @@ -83,6 +83,8 @@ def create_listening_socket( try: sock.bind(("0.0.0.0" if remote else "127.0.0.1", port)) sock.listen(5) + # put socket in non-blocking mode + sock.settimeout(0) except (OSError, PermissionError) as exc: sock.close() if remaining > 0: @@ -137,31 +139,36 @@ def __enter__(self) -> "Sapphire": def __exit__(self, *exc: Any) -> None: self.close() - def clear_backlog(self) -> None: + def clear_backlog(self, timeout: float = 
10) -> bool: """Remove all pending connections from backlog. This should only be called when there isn't anything actively trying to connect. Args: - None + timeout: Maximum number of seconds to run. Returns: - None + True if all connections are cleared from the backlog otherwise False. """ + # this assumes the socket is in non-blocking mode + assert not self._socket.getblocking() LOG.debug("clearing socket backlog") - self._socket.settimeout(0) - deadline = time() + 10 + deadline = perf_counter() + timeout while True: try: self._socket.accept()[0].close() except BlockingIOError: + # no remaining pending connections break except OSError as exc: LOG.debug("Error closing socket: %r", exc) else: LOG.debug("pending socket closed") # if this fires something is likely actively trying to connect - assert deadline > time() - self._socket.settimeout(None) + if deadline <= perf_counter(): + return False + # avoid hogging the cpu + sleep(0.1) + return True def close(self) -> None: """Close listening server socket. 
diff --git a/sapphire/test_sapphire.py b/sapphire/test_sapphire.py index 982772f0..72d378ab 100644 --- a/sapphire/test_sapphire.py +++ b/sapphire/test_sapphire.py @@ -6,7 +6,7 @@ import socket from hashlib import sha1 -from itertools import repeat +from itertools import count, repeat from os import urandom from pathlib import Path from platform import system @@ -700,16 +700,24 @@ def test_sapphire_26(client, tmp_path): def test_sapphire_27(mocker): """test Sapphire.clear_backlog()""" - mocker.patch("sapphire.core.socket", autospec=True) - mocker.patch("sapphire.core.time", autospec=True, return_value=1) + mocker.patch("sapphire.core.perf_counter", autospec=True, side_effect=count()) + mocker.patch("sapphire.core.sleep", autospec=True) + # test clearing backlog pending = mocker.Mock(spec_set=socket.socket) + pending.accept.side_effect = ((pending, None), OSError, BlockingIOError) + pending.getblocking.return_value = False + pending.getsockname.return_value = (None, 1337) + mocker.patch("sapphire.core.socket", return_value=pending) with Sapphire(timeout=10) as serv: - serv._socket = mocker.Mock(spec_set=socket.socket) - serv._socket.accept.side_effect = ((pending, None), OSError, BlockingIOError) - serv.clear_backlog() + assert serv.clear_backlog() assert serv._socket.accept.call_count == 3 - assert serv._socket.settimeout.call_count == 2 - assert pending.close.call_count == 1 + assert pending.close.call_count == 1 + pending.reset_mock() + # test hang + pending.accept.side_effect = None + pending.accept.return_value = (pending, None) + with Sapphire(timeout=1) as serv: + assert not serv.clear_backlog() @mark.skipif(system() != "Windows", reason="Only supported on Windows") diff --git a/sapphire/test_worker.py b/sapphire/test_worker.py index 69fcd510..7f439930 100644 --- a/sapphire/test_worker.py +++ b/sapphire/test_worker.py @@ -45,6 +45,7 @@ def test_worker_01(mocker): [ socket.timeout("test"), OSError("test"), + BlockingIOError("test"), ], ) def 
test_worker_02(mocker, exc): diff --git a/sapphire/worker.py b/sapphire/worker.py index 109ef0d4..0732e525 100644 --- a/sapphire/worker.py +++ b/sapphire/worker.py @@ -287,6 +287,9 @@ def launch( job.accepting.clear() w_thread.start() return cls(conn, w_thread) + except BlockingIOError: + # accept() would block due to a race between select() and accept() + pass except OSError as exc: LOG.debug("worker thread not launched: %s", exc) except ThreadError: