From 3e1a06c23251e10a2d36f496851709f638937fcf Mon Sep 17 00:00:00 2001 From: Tyson Smith Date: Tue, 7 May 2024 11:41:30 -0700 Subject: [PATCH] [sapphire] Avoid accepting connections prematurely --- sapphire/connection_manager.py | 1 - sapphire/core.py | 13 ++------- sapphire/test_connection_manager.py | 5 ++++ sapphire/test_sapphire.py | 3 +- sapphire/test_worker.py | 4 +++ sapphire/worker.py | 45 +++++++++++++++-------------- 6 files changed, 35 insertions(+), 36 deletions(-) diff --git a/sapphire/connection_manager.py b/sapphire/connection_manager.py index 63116813..ee3da6db 100644 --- a/sapphire/connection_manager.py +++ b/sapphire/connection_manager.py @@ -129,7 +129,6 @@ def serve( True unless the timeout is exceeded. """ assert self._job.pending or self._job.forever - assert self._socket.gettimeout() is not None assert shutdown_delay >= 0 assert timeout >= 0 if continue_cb is not None and not callable(continue_cb): diff --git a/sapphire/core.py b/sapphire/core.py index 97afe959..b3f66ff0 100644 --- a/sapphire/core.py +++ b/sapphire/core.py @@ -54,7 +54,6 @@ def create_listening_socket( attempts: int = 10, port: int = 0, remote: bool = False, - timeout: Optional[float] = None, ) -> socket: """Create listening socket. Search for an open socket if needed and configure the socket. If a specific port is unavailable or no @@ -64,20 +63,16 @@ def create_listening_socket( attempts: Number of attempts to configure the socket. port: Port to listen on. Use 0 for system assigned port. remote: Accept all (non-local) incoming connections. - timeout: Used to set socket timeout. Returns: A listening socket. """ assert attempts > 0 assert 0 <= port <= 65535 - assert timeout is None or timeout > 0 for remaining in reversed(range(attempts)): sock = socket() sock.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1) - if timeout is not None: - sock.settimeout(timeout) # attempt to bind/listen try: sock.bind(("0.0.0.0" if remote else "127.0.0.1", port)) @@ -118,11 +113,7 @@ def __init__( assert timeout >= 0 self._auto_close = auto_close # call 'window.close()' on 4xx error pages self._max_workers = max_workers # limit worker threads - sock = create_listening_socket( - port=port, - remote=allow_remote, - timeout=self.LISTEN_TIMEOUT, - ) + sock = create_listening_socket(port=port, remote=allow_remote) # enable https if certificates are provided if certs: context = SSLContext(PROTOCOL_TLS_SERVER) @@ -166,7 +157,7 @@ def clear_backlog(self) -> None: LOG.debug("pending socket closed") # if this fires something is likely actively trying to connect assert deadline > time() - self._socket.settimeout(self.LISTEN_TIMEOUT) + self._socket.settimeout(None) def close(self) -> None: """Close listening server socket. diff --git a/sapphire/test_connection_manager.py b/sapphire/test_connection_manager.py index 36da17ad..af168b86 100644 --- a/sapphire/test_connection_manager.py +++ b/sapphire/test_connection_manager.py @@ -22,6 +22,7 @@ def test_connection_manager_01(mocker, tmp_path, timeout): clnt_sock.recv.return_value = b"GET /testfile HTTP/1.1" serv_sock = mocker.Mock(spec_set=socket) serv_sock.accept.return_value = (clnt_sock, None) + mocker.patch("sapphire.worker.select", return_value=([serv_sock], None, None)) assert not job.is_complete() with ConnectionManager(job, serv_sock) as mgr: assert mgr.serve(timeout) @@ -51,6 +52,7 @@ def test_connection_manager_02(mocker, tmp_path, worker_limit): ) serv_sock = mocker.Mock(spec_set=socket) serv_sock.accept.return_value = (clnt_sock, None) + mocker.patch("sapphire.worker.select", return_value=([serv_sock], None, None)) assert not job.is_complete() with ConnectionManager(job, serv_sock, limit=worker_limit) as mgr: assert mgr.serve(10) @@ -66,6 +68,7 @@ def test_connection_manager_03(mocker, tmp_path): clnt_sock.recv.side_effect = Exception("worker exception") serv_sock = mocker.Mock(spec_set=socket) serv_sock.accept.return_value = (clnt_sock, None) + mocker.patch("sapphire.worker.select", return_value=([serv_sock], None, None)) with raises(Exception, match="worker exception"): with ConnectionManager(job, serv_sock) as mgr: mgr.serve(10) @@ -97,6 +100,7 @@ def test_connection_manager_05(mocker, tmp_path): clnt_sock.recv.return_value = b"" serv_sock = mocker.Mock(spec_set=socket) serv_sock.accept.return_value = (clnt_sock, None) + mocker.patch("sapphire.worker.select", return_value=([serv_sock], None, None)) job = Job(tmp_path, required_files=["file"]) with ConnectionManager(job, serv_sock, poll=0.01) as mgr: assert not mgr.serve(10) @@ -111,6 +115,7 @@ def test_connection_manager_06(mocker, tmp_path): clnt_sock = mocker.Mock(spec_set=socket) serv_sock = mocker.Mock(spec_set=socket) serv_sock.accept.return_value = (clnt_sock, None) + mocker.patch("sapphire.worker.select", return_value=([serv_sock], None, None)) job = Job(tmp_path, required_files=["file"]) mocker.patch.object(job, "worker_complete") with ConnectionManager(job, serv_sock) as mgr: diff --git a/sapphire/test_sapphire.py b/sapphire/test_sapphire.py index 400f4f14..278e5132 100644 --- a/sapphire/test_sapphire.py +++ b/sapphire/test_sapphire.py @@ -745,10 +745,9 @@ def test_create_listening_socket_01(mocker, bind): fake_sock = mocker.patch("sapphire.core.socket", autospec=True) fake_sock.return_value.bind.side_effect = bind bind_calls = len(bind) - assert create_listening_socket(timeout=0.25) + assert create_listening_socket() assert fake_sock.return_value.close.call_count == bind_calls - 1 assert fake_sock.return_value.setsockopt.call_count == bind_calls - assert fake_sock.return_value.settimeout.call_count == bind_calls assert fake_sock.return_value.bind.call_count == bind_calls assert fake_sock.return_value.listen.call_count == 1 assert fake_sleep.call_count == bind_calls - 1 diff --git a/sapphire/test_worker.py b/sapphire/test_worker.py index 4d30ccd5..ee1a9abd 100644 --- a/sapphire/test_worker.py +++ b/sapphire/test_worker.py @@ -51,6 +51,7 @@ def test_worker_02(mocker, exc): serv_con = mocker.Mock(spec_set=socket.socket) serv_job = mocker.Mock(spec_set=Job) serv_con.accept.side_effect = exc + mocker.patch("sapphire.worker.select", return_value=([serv_con], None, None)) assert Worker.launch(serv_con, serv_job) is None assert serv_job.accepting.clear.call_count == 0 assert serv_job.accepting.set.call_count == 0 @@ -64,6 +65,7 @@ def test_worker_03(mocker): serv_job = mocker.Mock(spec_set=Job) conn = mocker.Mock(spec_set=socket.socket) serv_con.accept.return_value = (conn, None) + mocker.patch("sapphire.worker.select", return_value=([serv_con], None, None)) assert Worker.launch(serv_con, serv_job) is None assert conn.close.call_count == 1 assert serv_job.accepting.clear.call_count == 0 @@ -88,6 +90,7 @@ def test_worker_04(mocker, tmp_path, url): clnt_sock.recv.return_value = f"GET {url} HTTP/1.1".encode() serv_sock = mocker.Mock(spec_set=socket.socket) serv_sock.accept.return_value = (clnt_sock, None) + mocker.patch("sapphire.worker.select", return_value=([serv_sock], None, None)) worker = Worker.launch(serv_sock, job) assert worker is not None try: @@ -117,6 +120,7 @@ def test_worker_05(mocker, tmp_path, req, response): clnt_sock.recv.return_value = req serv_sock = mocker.Mock(spec_set=socket.socket) serv_sock.accept.return_value = (clnt_sock, None) + mocker.patch("sapphire.worker.select", return_value=([serv_sock], None, None)) worker = Worker.launch(serv_sock, job) assert worker is not None assert worker.join(timeout=10) diff --git a/sapphire/worker.py b/sapphire/worker.py index a90c751e..84683357 100644 --- a/sapphire/worker.py +++ b/sapphire/worker.py @@ -6,6 +6,7 @@ """ from logging import getLogger from re import compile as re_compile +from select import select from socket import SHUT_RDWR, socket from socket import timeout as sock_timeout # Py3.10 socket.timeout => TimeoutError from sys import exc_info @@ -269,26 +270,26 @@ def launch( ) -> Optional["Worker"]: assert timeout >= 0 assert job.accepting.is_set() - conn = None - try: - conn, _ = listen_sock.accept() - conn.settimeout(timeout) - # create a worker thread to handle client request - w_thread = Thread(target=cls.handle_request, args=(conn, job)) - job.accepting.clear() - w_thread.start() - return cls(conn, w_thread) - except sock_timeout: - # no connections to accept - pass - except OSError as exc: - LOG.debug("worker thread not launched: %s", exc) - except ThreadError: - # reset accepting status - job.accepting.set() - LOG.warning("ThreadError (worker), threads: %d", active_count()) - # wait for system resources to free up - sleep(0.1) - if conn is not None: - conn.close() + # TODO: is select() timeout value too short, too long? + readable, _, _ = select([listen_sock], (), (), 0.25) + if listen_sock in readable: + conn = None + try: + conn, _ = listen_sock.accept() + conn.settimeout(timeout) + # create a worker thread to handle client request + w_thread = Thread(target=cls.handle_request, args=(conn, job)) + job.accepting.clear() + w_thread.start() + return cls(conn, w_thread) + except OSError as exc: + LOG.debug("worker thread not launched: %s", exc) + except ThreadError: + # reset accepting status + job.accepting.set() + LOG.warning("ThreadError (worker), threads: %d", active_count()) + # wait for system resources to free up + sleep(0.1) + if conn is not None: + conn.close() return None