Skip to content

Commit

Permalink
[sapphire] Avoid accepting connections prematurely
Browse files Browse the repository at this point in the history
  • Loading branch information
tysmith committed May 8, 2024
1 parent 962079c commit 3e1a06c
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 36 deletions.
1 change: 0 additions & 1 deletion sapphire/connection_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,6 @@ def serve(
True unless the timeout is exceeded.
"""
assert self._job.pending or self._job.forever
assert self._socket.gettimeout() is not None
assert shutdown_delay >= 0
assert timeout >= 0
if continue_cb is not None and not callable(continue_cb):
Expand Down
13 changes: 2 additions & 11 deletions sapphire/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ def create_listening_socket(
attempts: int = 10,
port: int = 0,
remote: bool = False,
timeout: Optional[float] = None,
) -> socket:
"""Create listening socket. Search for an open socket if needed and
configure the socket. If a specific port is unavailable or no
Expand All @@ -64,20 +63,16 @@ def create_listening_socket(
attempts: Number of attempts to configure the socket.
port: Port to listen on. Use 0 for system assigned port.
remote: Accept all (non-local) incoming connections.
timeout: Used to set socket timeout.
Returns:
A listening socket.
"""
assert attempts > 0
assert 0 <= port <= 65535
assert timeout is None or timeout > 0

for remaining in reversed(range(attempts)):
sock = socket()
sock.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
if timeout is not None:
sock.settimeout(timeout)
# attempt to bind/listen
try:
sock.bind(("0.0.0.0" if remote else "127.0.0.1", port))
Expand Down Expand Up @@ -118,11 +113,7 @@ def __init__(
assert timeout >= 0
self._auto_close = auto_close # call 'window.close()' on 4xx error pages
self._max_workers = max_workers # limit worker threads
sock = create_listening_socket(
port=port,
remote=allow_remote,
timeout=self.LISTEN_TIMEOUT,
)
sock = create_listening_socket(port=port, remote=allow_remote)
# enable https if certificates are provided
if certs:
context = SSLContext(PROTOCOL_TLS_SERVER)
Expand Down Expand Up @@ -166,7 +157,7 @@ def clear_backlog(self) -> None:
LOG.debug("pending socket closed")
# if this fires something is likely actively trying to connect
assert deadline > time()
self._socket.settimeout(self.LISTEN_TIMEOUT)
self._socket.settimeout(None)

def close(self) -> None:
"""Close listening server socket.
Expand Down
5 changes: 5 additions & 0 deletions sapphire/test_connection_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def test_connection_manager_01(mocker, tmp_path, timeout):
clnt_sock.recv.return_value = b"GET /testfile HTTP/1.1"
serv_sock = mocker.Mock(spec_set=socket)
serv_sock.accept.return_value = (clnt_sock, None)
mocker.patch("sapphire.worker.select", return_value=([serv_sock], None, None))
assert not job.is_complete()
with ConnectionManager(job, serv_sock) as mgr:
assert mgr.serve(timeout)
Expand Down Expand Up @@ -51,6 +52,7 @@ def test_connection_manager_02(mocker, tmp_path, worker_limit):
)
serv_sock = mocker.Mock(spec_set=socket)
serv_sock.accept.return_value = (clnt_sock, None)
mocker.patch("sapphire.worker.select", return_value=([serv_sock], None, None))
assert not job.is_complete()
with ConnectionManager(job, serv_sock, limit=worker_limit) as mgr:
assert mgr.serve(10)
Expand All @@ -66,6 +68,7 @@ def test_connection_manager_03(mocker, tmp_path):
clnt_sock.recv.side_effect = Exception("worker exception")
serv_sock = mocker.Mock(spec_set=socket)
serv_sock.accept.return_value = (clnt_sock, None)
mocker.patch("sapphire.worker.select", return_value=([serv_sock], None, None))
with raises(Exception, match="worker exception"):
with ConnectionManager(job, serv_sock) as mgr:
mgr.serve(10)
Expand Down Expand Up @@ -97,6 +100,7 @@ def test_connection_manager_05(mocker, tmp_path):
clnt_sock.recv.return_value = b""
serv_sock = mocker.Mock(spec_set=socket)
serv_sock.accept.return_value = (clnt_sock, None)
mocker.patch("sapphire.worker.select", return_value=([serv_sock], None, None))
job = Job(tmp_path, required_files=["file"])
with ConnectionManager(job, serv_sock, poll=0.01) as mgr:
assert not mgr.serve(10)
Expand All @@ -111,6 +115,7 @@ def test_connection_manager_06(mocker, tmp_path):
clnt_sock = mocker.Mock(spec_set=socket)
serv_sock = mocker.Mock(spec_set=socket)
serv_sock.accept.return_value = (clnt_sock, None)
mocker.patch("sapphire.worker.select", return_value=([serv_sock], None, None))
job = Job(tmp_path, required_files=["file"])
mocker.patch.object(job, "worker_complete")
with ConnectionManager(job, serv_sock) as mgr:
Expand Down
3 changes: 1 addition & 2 deletions sapphire/test_sapphire.py
Original file line number Diff line number Diff line change
Expand Up @@ -745,10 +745,9 @@ def test_create_listening_socket_01(mocker, bind):
fake_sock = mocker.patch("sapphire.core.socket", autospec=True)
fake_sock.return_value.bind.side_effect = bind
bind_calls = len(bind)
assert create_listening_socket(timeout=0.25)
assert create_listening_socket()
assert fake_sock.return_value.close.call_count == bind_calls - 1
assert fake_sock.return_value.setsockopt.call_count == bind_calls
assert fake_sock.return_value.settimeout.call_count == bind_calls
assert fake_sock.return_value.bind.call_count == bind_calls
assert fake_sock.return_value.listen.call_count == 1
assert fake_sleep.call_count == bind_calls - 1
Expand Down
4 changes: 4 additions & 0 deletions sapphire/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def test_worker_02(mocker, exc):
serv_con = mocker.Mock(spec_set=socket.socket)
serv_job = mocker.Mock(spec_set=Job)
serv_con.accept.side_effect = exc
mocker.patch("sapphire.worker.select", return_value=([serv_con], None, None))
assert Worker.launch(serv_con, serv_job) is None
assert serv_job.accepting.clear.call_count == 0
assert serv_job.accepting.set.call_count == 0
Expand All @@ -64,6 +65,7 @@ def test_worker_03(mocker):
serv_job = mocker.Mock(spec_set=Job)
conn = mocker.Mock(spec_set=socket.socket)
serv_con.accept.return_value = (conn, None)
mocker.patch("sapphire.worker.select", return_value=([serv_con], None, None))
assert Worker.launch(serv_con, serv_job) is None
assert conn.close.call_count == 1
assert serv_job.accepting.clear.call_count == 0
Expand All @@ -88,6 +90,7 @@ def test_worker_04(mocker, tmp_path, url):
clnt_sock.recv.return_value = f"GET {url} HTTP/1.1".encode()
serv_sock = mocker.Mock(spec_set=socket.socket)
serv_sock.accept.return_value = (clnt_sock, None)
mocker.patch("sapphire.worker.select", return_value=([serv_sock], None, None))
worker = Worker.launch(serv_sock, job)
assert worker is not None
try:
Expand Down Expand Up @@ -117,6 +120,7 @@ def test_worker_05(mocker, tmp_path, req, response):
clnt_sock.recv.return_value = req
serv_sock = mocker.Mock(spec_set=socket.socket)
serv_sock.accept.return_value = (clnt_sock, None)
mocker.patch("sapphire.worker.select", return_value=([serv_sock], None, None))
worker = Worker.launch(serv_sock, job)
assert worker is not None
assert worker.join(timeout=10)
Expand Down
45 changes: 23 additions & 22 deletions sapphire/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"""
from logging import getLogger
from re import compile as re_compile
from select import select
from socket import SHUT_RDWR, socket
from socket import timeout as sock_timeout # Py3.10 socket.timeout => TimeoutError
from sys import exc_info
Expand Down Expand Up @@ -269,26 +270,26 @@ def launch(
) -> Optional["Worker"]:
assert timeout >= 0
assert job.accepting.is_set()
conn = None
try:
conn, _ = listen_sock.accept()
conn.settimeout(timeout)
# create a worker thread to handle client request
w_thread = Thread(target=cls.handle_request, args=(conn, job))
job.accepting.clear()
w_thread.start()
return cls(conn, w_thread)
except sock_timeout:
# no connections to accept
pass
except OSError as exc:
LOG.debug("worker thread not launched: %s", exc)
except ThreadError:
# reset accepting status
job.accepting.set()
LOG.warning("ThreadError (worker), threads: %d", active_count())
# wait for system resources to free up
sleep(0.1)
if conn is not None:
conn.close()
# TODO: is select() timeout value too short, too long?
readable, _, _ = select([listen_sock], (), (), 0.25)
if listen_sock in readable:
conn = None
try:
conn, _ = listen_sock.accept()
conn.settimeout(timeout)
# create a worker thread to handle client request
w_thread = Thread(target=cls.handle_request, args=(conn, job))
job.accepting.clear()
w_thread.start()
return cls(conn, w_thread)
except OSError as exc:
LOG.debug("worker thread not launched: %s", exc)
except ThreadError:
# reset accepting status
job.accepting.set()
LOG.warning("ThreadError (worker), threads: %d", active_count())
# wait for system resources to free up
sleep(0.1)
if conn is not None:
conn.close()
return None

0 comments on commit 3e1a06c

Please sign in to comment.