Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid deadlock in Sapphire #454

Merged
merged 3 commits into from
Jun 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion grizzly/common/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,8 @@ def launch(self, location: str, max_retries: int = 3, retry_delay: int = 0) -> N
assert self._target is not None
assert max_retries >= 0
assert retry_delay >= 0
self._server.clear_backlog()
# nothing should be trying to connect, did the previous target.close() fail?
assert self._server.clear_backlog()
self._tests_run = 0
self.startup_failure = False
launch_duration: float = 0
Expand Down
7 changes: 1 addition & 6 deletions sapphire/connection_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,13 +103,8 @@ def _join_workers(workers: List[Worker], timeout: float = 0) -> List[Worker]:
Returns:
Workers that do not join before the timeout is reached.
"""
assert timeout >= 0
alive = []
deadline = time() + timeout
for worker in workers:
if not worker.join(timeout=max(deadline - time(), 0)):
alive.append(worker)
return alive
return [x for x in workers if not x.join(timeout=max(deadline - time(), 0))]

def serve(
self,
Expand Down
70 changes: 41 additions & 29 deletions sapphire/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from pathlib import Path
from socket import SO_REUSEADDR, SOL_SOCKET, gethostname, socket
from ssl import PROTOCOL_TLS_SERVER, SSLContext, SSLSocket
from time import sleep, time
from time import perf_counter, sleep
from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union, cast

from .certificate_bundle import CertificateBundle
Expand All @@ -29,24 +29,26 @@
# collection of ports to avoid
# see: searchfox.org/mozilla-central/source/netwerk/base/nsIOService.cpp
# include ports above 1024
BLOCKED_PORTS = (
1719,
1720,
1723,
2049,
3659,
4045,
5060,
5061,
6000,
6566,
6665,
6666,
6667,
6668,
6669,
6697,
10080,
BLOCKED_PORTS = frozenset(
(
1719,
1720,
1723,
2049,
3659,
4045,
5060,
5061,
6000,
6566,
6665,
6666,
6667,
6668,
6669,
6697,
10080,
)
)
LOG = getLogger(__name__)

Expand All @@ -56,9 +58,9 @@ def create_listening_socket(
port: int = 0,
remote: bool = False,
) -> socket:
"""Create listening socket. Search for an open socket if needed and
configure the socket. If a specific port is unavailable or no
available ports can be found socket.error will be raised.
"""Create listening socket. Search for an open socket if needed and configure the
socket. If the specified port is unavailable an OSError or PermissionError will be
raised. If an available port cannot be found a RuntimeError will be raised.

Args:
attempts: Number of attempts to configure the socket.
Expand All @@ -71,13 +73,18 @@ def create_listening_socket(
assert attempts > 0
assert 0 <= port <= 65535

if port in BLOCKED_PORTS or 0 < port <= 1024:
raise ValueError("Cannot bind to blocked ports or ports <= 1024")

for remaining in reversed(range(attempts)):
sock = socket()
sock.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
# attempt to bind/listen
try:
sock.bind(("0.0.0.0" if remote else "127.0.0.1", port))
sock.listen(5)
# put socket in non-blocking mode
sock.settimeout(0)
except (OSError, PermissionError) as exc:
sock.close()
if remaining > 0:
Expand Down Expand Up @@ -132,31 +139,36 @@ def __enter__(self) -> "Sapphire":
def __exit__(self, *exc: Any) -> None:
self.close()

def clear_backlog(self) -> None:
def clear_backlog(self, timeout: float = 10) -> bool:
"""Remove all pending connections from backlog. This should only be
called when there isn't anything actively trying to connect.

Args:
None
timeout: Maximum number of seconds to run.

Returns:
None
True if all connections are cleared from the backlog otherwise False.
"""
# this assumes the socket is in non-blocking mode
assert not self._socket.getblocking()
LOG.debug("clearing socket backlog")
self._socket.settimeout(0)
deadline = time() + 10
deadline = perf_counter() + timeout
while True:
try:
self._socket.accept()[0].close()
except BlockingIOError:
# no remaining pending connections
break
except OSError as exc:
LOG.debug("Error closing socket: %r", exc)
else:
LOG.debug("pending socket closed")
# if this fires something is likely actively trying to connect
assert deadline > time()
self._socket.settimeout(None)
if deadline <= perf_counter():
return False
# avoid hogging the cpu
sleep(0.1)
return True

def close(self) -> None:
"""Close listening server socket.
Expand Down
2 changes: 1 addition & 1 deletion sapphire/test_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def test_job_04(mocker, tmp_path):
assert resource.target == srv_include / "test_file.txt"
assert resource.url == request
# test redirect with file in a nested directory
request = "/".join(["testinc", "nested", "nested_file.txt"])
request = "testinc/nested/nested_file.txt"
resource = job.lookup_resource(request)
assert isinstance(resource, FileResource)
assert resource.target == nst_1
Expand Down
63 changes: 39 additions & 24 deletions sapphire/test_sapphire.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import socket
from hashlib import sha1
from itertools import repeat
from itertools import count, repeat
from os import urandom
from pathlib import Path
from platform import system
Expand Down Expand Up @@ -53,15 +53,15 @@ def create(cls, fname, path, data=b"Test!", calc_hash=False, url_prefix=None):
return test


@mark.parametrize("count", [1, 100])
def test_sapphire_01(client, tmp_path, count):
@mark.parametrize("files", [1, 100])
def test_sapphire_01(client, tmp_path, files):
"""test serving files"""
_TestFile.create("unrelated.bin", tmp_path)
to_serve = [
_TestFile.create(
f"test_{i:04d}.html", tmp_path, data=urandom(5), calc_hash=True
)
for i in range(count)
for i in range(files)
]
# all files are required
required = [x.file for x in to_serve]
Expand All @@ -81,22 +81,22 @@ def test_sapphire_01(client, tmp_path, count):


@mark.parametrize(
"count, req_idx",
"files, req_idx",
[
# multiple files (skip optional)
(5, 0),
# multiple files (serve optional)
(5, 4),
],
)
def test_sapphire_02(client, tmp_path, count, req_idx):
def test_sapphire_02(client, tmp_path, files, req_idx):
"""test serving files"""
_TestFile.create("unrelated.bin", tmp_path)
to_serve = [
_TestFile.create(
f"test_{i:04d}.html", tmp_path, data=urandom(5), calc_hash=True
)
for i in range(count)
for i in range(files)
]
required = to_serve[req_idx].file
with Sapphire(timeout=10) as serv:
Expand Down Expand Up @@ -318,7 +318,7 @@ def test_sapphire_13(client, tmp_path, path, query):
with Sapphire(timeout=10) as serv:
# target will be requested indirectly via the redirect
target = _TestFile.create(path, tmp_path, data=b"Redirect DATA!")
request_path = "redirect" if query is None else "?".join(("redirect", query))
request_path = "redirect" if query is None else f"redirect?{query}"
redirect = _TestFile(request_path)
# point "redirect" at target
smap.set_redirect("redirect", target.file, required=True)
Expand Down Expand Up @@ -420,7 +420,7 @@ def test_sapphire_15(client, tmp_path, query, required):
_data = b"dynamic response -- TEST DATA!"
# build request
path = "dyn_test"
request = path if query is None else "?".join([path, query])
request = path if query is None else f"{path}?{query}"

# setup custom callback
def dr_callback(data):
Expand Down Expand Up @@ -523,8 +523,7 @@ def test_sapphire_19(client_factory, tmp_path):
with Sapphire(max_workers=max_workers, timeout=60) as serv:
clients = []
try:
for _ in range(max_workers): # number of clients to spawn
clients.append(client_factory(rx_size=1))
clients = [client_factory(rx_size=1) for _ in range(max_workers)]
for client in clients:
client.launch(
"127.0.0.1", serv.port, to_serve, in_order=True, throttle=0.05
Expand Down Expand Up @@ -701,16 +700,24 @@ def test_sapphire_26(client, tmp_path):

def test_sapphire_27(mocker):
"""test Sapphire.clear_backlog()"""
mocker.patch("sapphire.core.socket", autospec=True)
mocker.patch("sapphire.core.time", autospec=True, return_value=1)
mocker.patch("sapphire.core.perf_counter", autospec=True, side_effect=count())
mocker.patch("sapphire.core.sleep", autospec=True)
# test clearing backlog
pending = mocker.Mock(spec_set=socket.socket)
pending.accept.side_effect = ((pending, None), OSError, BlockingIOError)
pending.getblocking.return_value = False
pending.getsockname.return_value = (None, 1337)
mocker.patch("sapphire.core.socket", return_value=pending)
with Sapphire(timeout=10) as serv:
serv._socket = mocker.Mock(spec_set=socket.socket)
serv._socket.accept.side_effect = ((pending, None), OSError, BlockingIOError)
serv.clear_backlog()
assert serv.clear_backlog()
assert serv._socket.accept.call_count == 3
assert serv._socket.settimeout.call_count == 2
assert pending.close.call_count == 1
assert pending.close.call_count == 1
pending.reset_mock()
# test hang
pending.accept.side_effect = None
pending.accept.return_value = (pending, None)
with Sapphire(timeout=1) as serv:
assert not serv.clear_backlog()


@mark.skipif(system() != "Windows", reason="Only supported on Windows")
Expand All @@ -731,12 +738,14 @@ def test_sapphire_28(client, tmp_path):
assert test.len_srv == test.len_org


def test_sapphire_29(tmp_path):
def test_sapphire_29():
"""test Sapphire with certificates"""
certs = CertificateBundle.create(path=tmp_path)
with Sapphire(timeout=10, certs=certs) as serv:
assert serv.scheme == "https"
certs.cleanup()
certs = CertificateBundle.create()
try:
with Sapphire(timeout=10, certs=certs) as serv:
assert serv.scheme == "https"
finally:
certs.cleanup()


@mark.parametrize(
Expand Down Expand Up @@ -776,14 +785,20 @@ def test_create_listening_socket_02(mocker, bind, attempts, raised):
mocker.patch("sapphire.core.sleep", autospec=True)
fake_sock = mocker.patch("sapphire.core.socket", autospec=True)
fake_sock.return_value.bind.side_effect = bind
with raises(raised):
with raises(raised, match="foo"):
create_listening_socket(attempts=attempts)
assert fake_sock.return_value.close.call_count == attempts


def test_create_listening_socket_03(mocker):
"""test create_listening_socket() - fail to find port"""
fake_sock = mocker.patch("sapphire.core.socket", autospec=True)
# specify blocked port
with raises(ValueError, match="Cannot bind to blocked ports"):
create_listening_socket(port=6000, attempts=1)
# specify reserved port
with raises(ValueError, match="Cannot bind to blocked ports"):
create_listening_socket(port=123, attempts=1)
# always choose a blocked port
fake_sock.return_value.getsockname.return_value = (None, 6665)
with raises(RuntimeError, match="Could not find available port"):
Expand Down
12 changes: 10 additions & 2 deletions sapphire/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from random import randint
from threading import Thread, ThreadError

from pytest import mark
from pytest import mark, raises

from .job import Job
from .worker import Request, Worker
Expand Down Expand Up @@ -45,6 +45,7 @@ def test_worker_01(mocker):
[
socket.timeout("test"),
OSError("test"),
BlockingIOError("test"),
],
)
def test_worker_02(mocker, exc):
Expand Down Expand Up @@ -211,7 +212,14 @@ def test_response_02(req):
assert Request.parse(req) is None


def test_response_03():
def test_response_03(mocker):
"""test Request.parse() fail to parse"""
mocker.patch("sapphire.worker.urlparse", side_effect=ValueError("foo"))
with raises(ValueError, match="foo"):
Request.parse(b"GET http://foo HTTP/1.1")


def test_response_04():
"""test Request.parse() by passing random urls"""
for _ in range(1000):
# create random 'netloc', for example '%1A%EF%09'
Expand Down
3 changes: 3 additions & 0 deletions sapphire/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,9 @@ def launch(
job.accepting.clear()
w_thread.start()
return cls(conn, w_thread)
except BlockingIOError:
# accept() can block because of race between select() and accept()
pass
except OSError as exc:
LOG.debug("worker thread not launched: %s", exc)
except ThreadError:
Expand Down
Loading