Skip to content

Commit

Permalink
[sapphire] Add type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
tysmith committed Mar 19, 2024
1 parent 2c1d36c commit 0764dac
Show file tree
Hide file tree
Showing 11 changed files with 297 additions and 283 deletions.
4 changes: 2 additions & 2 deletions grizzly/common/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,8 +335,8 @@ def test_runner_10(mocker, tmp_path):
inc3.write_bytes(b"a")
# build server map
smap = ServerMap()
smap.set_include("/", str(inc_path1))
smap.set_include("/test", str(inc_path2))
smap.set_include("/", inc_path1)
smap.set_include("/test", inc_path2)
with TestCase("a.b", "x") as test:
test.add_from_bytes(b"", test.entry_point)
serv_files = {
Expand Down
1 change: 1 addition & 0 deletions grizzly/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,7 @@ def run(
runner.launch(location, max_retries=launch_attempts, retry_delay=0)
runner.post_launch(delay=post_launch_delay)
# TODO: avoid running test case if runner.startup_failure is True
# especially if it is a hang!

# create and populate a test case
current_test = self.generate_testcase()
Expand Down
7 changes: 4 additions & 3 deletions sapphire/__main__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from argparse import ArgumentParser
from argparse import ArgumentParser, Namespace
from logging import DEBUG, INFO, basicConfig
from pathlib import Path
from typing import List, Optional

from .core import Sapphire


def configure_logging(log_level):
def configure_logging(log_level: int) -> None:
if log_level == DEBUG:
date_fmt = None
log_fmt = "%(asctime)s %(levelname).1s %(name)s | %(message)s"
Expand All @@ -18,7 +19,7 @@ def configure_logging(log_level):
basicConfig(format=log_fmt, datefmt=date_fmt, level=log_level)


def parse_args(argv=None):
def parse_args(argv: Optional[List[str]] = None) -> Namespace:
# log levels for console logging
level_map = {"DEBUG": DEBUG, "INFO": INFO}
parser = ArgumentParser()
Expand Down
28 changes: 19 additions & 9 deletions sapphire/connection_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from logging import getLogger
from socket import socket
from time import time
from traceback import format_exception
from typing import Any, Callable, List, Optional, Union

from .job import Job
from .worker import Worker

__author__ = "Tyson Smith"
Expand All @@ -27,24 +30,26 @@ class ConnectionManager:
"_socket",
)

def __init__(self, job, srv_socket, limit=1, poll=0.5):
def __init__(
self, job: Job, srv_socket: socket, limit: int = 1, poll: float = 0.5
) -> None:
assert limit > 0
assert poll > 0
self._deadline = None
self._deadline: Optional[float] = None
self._deadline_exceeded = False
self._job = job
self._limit = limit
self._next_poll = 0
self._next_poll = 0.0
self._poll = poll
self._socket = srv_socket

def __enter__(self):
def __enter__(self) -> "ConnectionManager":
return self

def __exit__(self, *exc):
def __exit__(self, *exc: Any) -> None:
self.close()

def _can_continue(self, continue_cb):
def _can_continue(self, continue_cb: Union[Callable[[], bool], None]) -> bool:
"""Check timeout and callback status.
Args:
Expand All @@ -68,7 +73,7 @@ def _can_continue(self, continue_cb):
return False
return True

def close(self):
def close(self) -> None:
"""Set job state to finished and raise any errors encountered by workers.
Args:
Expand All @@ -88,7 +93,7 @@ def close(self):
raise exc_obj

@staticmethod
def _join_workers(workers, timeout=0):
def _join_workers(workers: List[Worker], timeout: float = 0) -> List[Worker]:
"""Attempt to join workers.
Args:
Expand All @@ -106,7 +111,12 @@ def _join_workers(workers, timeout=0):
alive.append(worker)
return alive

def serve(self, timeout, continue_cb=None, shutdown_delay=SHUTDOWN_DELAY):
def serve(
self,
timeout: int,
continue_cb: Optional[Callable[[], bool]] = None,
shutdown_delay: float = SHUTDOWN_DELAY,
) -> bool:
"""Manage workers and serve job contents.
Args:
Expand Down
80 changes: 45 additions & 35 deletions sapphire/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,17 @@
"""
Sapphire HTTP server
"""
from argparse import Namespace
from logging import getLogger
from pathlib import Path
from socket import SO_REUSEADDR, SOL_SOCKET, gethostname, socket
from ssl import PROTOCOL_TLS_SERVER, SSLContext
from ssl import PROTOCOL_TLS_SERVER, SSLContext, SSLSocket
from time import sleep, time
from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union

from .connection_manager import ConnectionManager
from .job import Job, Served
from .server_map import ServerMap

__all__ = (
"BLOCKED_PORTS",
Expand Down Expand Up @@ -47,16 +50,21 @@
LOG = getLogger(__name__)


def create_listening_socket(attempts=10, port=0, remote=False, timeout=None):
def create_listening_socket(
attempts: int = 10,
port: int = 0,
remote: bool = False,
timeout: Optional[float] = None,
) -> socket:
"""Create listening socket. Search for an open socket if needed and
configure the socket. If a specific port is unavailable or no
available ports can be found socket.error will be raised.
Args:
attempts (int): Number of attempts to configure the socket.
port (int): Port to listen on. Use 0 for system assigned port.
remote (bool): Accept all (non-local) incoming connections.
timeout (float): Used to set socket timeout.
attempts: Number of attempts to configure the socket.
port: Port to listen on. Use 0 for system assigned port.
remote: Accept all (non-local) incoming connections.
timeout: Used to set socket timeout.
Returns:
socket: A listening socket.
Expand Down Expand Up @@ -100,13 +108,13 @@ class Sapphire:

def __init__(
self,
allow_remote=False,
auto_close=-1,
allow_remote: bool = False,
auto_close: int = -1,
certs=None,
max_workers=10,
port=0,
timeout=60,
):
max_workers: int = 10,
port: int = 0,
timeout: int = 60,
) -> None:
assert timeout >= 0
self._auto_close = auto_close # call 'window.close()' on 4xx error pages
self._max_workers = max_workers # limit worker threads
Expand All @@ -119,20 +127,22 @@ def __init__(
if certs:
context = SSLContext(PROTOCOL_TLS_SERVER)
context.load_cert_chain(certs.host, certs.key)
self._socket = context.wrap_socket(sock, server_side=True)
self._socket: Union[socket, SSLSocket] = context.wrap_socket(
sock, server_side=True
)
self.scheme = "https"
else:
self._socket = sock
self.scheme = "http"
self.timeout = timeout

def __enter__(self):
def __enter__(self) -> "Sapphire":
return self

def __exit__(self, *exc):
def __exit__(self, *exc: Any) -> None:
self.close()

def clear_backlog(self):
def clear_backlog(self) -> None:
"""Remove all pending connections from backlog. This should only be
called when there isn't anything actively trying to connect.
Expand All @@ -158,7 +168,7 @@ def clear_backlog(self):
assert deadline > time()
self._socket.settimeout(self.LISTEN_TIMEOUT)

def close(self):
def close(self) -> None:
"""Close listening server socket.
Args:
Expand All @@ -170,25 +180,25 @@ def close(self):
self._socket.close()

@property
def port(self):
def port(self) -> int:
"""Port number of listening socket.
Args:
None
Returns:
int: Listening port number.
Listening port number.
"""
return self._socket.getsockname()[1]
return int(self._socket.getsockname()[1])

def serve_path(
self,
path,
continue_cb=None,
forever=False,
required_files=None,
server_map=None,
):
path: Path,
continue_cb: Optional[Callable[[], bool]] = None,
forever: bool = False,
required_files: Optional[Iterable[str]] = None,
server_map: Optional[ServerMap] = None,
) -> Tuple[Served, Dict[str, Path]]:
"""Serve files in path.
The status codes include:
Expand All @@ -197,17 +207,17 @@ def serve_path(
- Served.REQUEST: Some files were requested
Args:
path (Path): Directory to use as wwwroot.
continue_cb (callable): A callback that can be used to exit the serve loop.
This must be a callable that returns a bool.
forever (bool): Continue to handle requests even after all files have
been served. This is meant to be used with continue_cb.
required_files (list(str)): Files that need to be served in order to exit
the serve loop.
path: Directory to use as wwwroot.
continue_cb: A callback that can be used to exit the serve loop.
This must be a callable that returns a bool.
forever: Continue to handle requests even after all files have been served.
This is meant to be used with continue_cb.
required_files: Files that need to be served in order to exit the
serve loop.
server_map (ServerMap):
Returns:
tuple(int, dict[str, Path]): Status code and files served.
Status code and files served.
"""
assert isinstance(path, Path)
assert self.timeout >= 0
Expand All @@ -225,7 +235,7 @@ def serve_path(
return (Served.TIMEOUT if timed_out else job.status, job.served)

@classmethod
def main(cls, args):
def main(cls, args: Namespace) -> None:
try:
with cls(
allow_remote=args.remote, port=args.port, timeout=args.timeout
Expand Down
Loading

0 comments on commit 0764dac

Please sign in to comment.