From 0764dacca05be696e97d43ccc11a418e22d6aac8 Mon Sep 17 00:00:00 2001 From: Tyson Smith Date: Mon, 18 Mar 2024 11:37:15 -0700 Subject: [PATCH] [sapphire] Add type hints --- grizzly/common/test_runner.py | 4 +- grizzly/session.py | 1 + sapphire/__main__.py | 7 +- sapphire/connection_manager.py | 28 ++++-- sapphire/core.py | 80 +++++++++-------- sapphire/job.py | 156 ++++++++++++++------------------- sapphire/server_map.py | 89 +++++++++++-------- sapphire/test_job.py | 89 +++++++++---------- sapphire/test_sapphire.py | 29 +++--- sapphire/test_server_map.py | 35 +++++--- sapphire/worker.py | 62 ++++++------- 11 files changed, 297 insertions(+), 283 deletions(-) diff --git a/grizzly/common/test_runner.py b/grizzly/common/test_runner.py index 0cbfaa55..f2465e72 100644 --- a/grizzly/common/test_runner.py +++ b/grizzly/common/test_runner.py @@ -335,8 +335,8 @@ def test_runner_10(mocker, tmp_path): inc3.write_bytes(b"a") # build server map smap = ServerMap() - smap.set_include("/", str(inc_path1)) - smap.set_include("/test", str(inc_path2)) + smap.set_include("/", inc_path1) + smap.set_include("/test", inc_path2) with TestCase("a.b", "x") as test: test.add_from_bytes(b"", test.entry_point) serv_files = { diff --git a/grizzly/session.py b/grizzly/session.py index 826eaff8..24db3a65 100644 --- a/grizzly/session.py +++ b/grizzly/session.py @@ -202,6 +202,7 @@ def run( runner.launch(location, max_retries=launch_attempts, retry_delay=0) runner.post_launch(delay=post_launch_delay) # TODO: avoid running test case if runner.startup_failure is True + # especially if it is a hang! # create and populate a test case current_test = self.generate_testcase() diff --git a/sapphire/__main__.py b/sapphire/__main__.py index 81999500..a98a442b 100644 --- a/sapphire/__main__.py +++ b/sapphire/__main__.py @@ -1,14 +1,15 @@ # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. -from argparse import ArgumentParser +from argparse import ArgumentParser, Namespace from logging import DEBUG, INFO, basicConfig from pathlib import Path +from typing import List, Optional from .core import Sapphire -def configure_logging(log_level): +def configure_logging(log_level: int) -> None: if log_level == DEBUG: date_fmt = None log_fmt = "%(asctime)s %(levelname).1s %(name)s | %(message)s" @@ -18,7 +19,7 @@ def configure_logging(log_level): basicConfig(format=log_fmt, datefmt=date_fmt, level=log_level) -def parse_args(argv=None): +def parse_args(argv: Optional[List[str]] = None) -> Namespace: # log levels for console logging level_map = {"DEBUG": DEBUG, "INFO": INFO} parser = ArgumentParser() diff --git a/sapphire/connection_manager.py b/sapphire/connection_manager.py index 56cc9f60..de722762 100644 --- a/sapphire/connection_manager.py +++ b/sapphire/connection_manager.py @@ -2,9 +2,12 @@ # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from logging import getLogger +from socket import socket from time import time from traceback import format_exception +from typing import Any, Callable, List, Optional, Union +from .job import Job from .worker import Worker __author__ = "Tyson Smith" @@ -27,24 +30,26 @@ class ConnectionManager: "_socket", ) - def __init__(self, job, srv_socket, limit=1, poll=0.5): + def __init__( + self, job: Job, srv_socket: socket, limit: int = 1, poll: float = 0.5 + ) -> None: assert limit > 0 assert poll > 0 - self._deadline = None + self._deadline: Optional[float] = None self._deadline_exceeded = False self._job = job self._limit = limit - self._next_poll = 0 + self._next_poll = 0.0 self._poll = poll self._socket = srv_socket - def __enter__(self): + def __enter__(self) -> "ConnectionManager": return self - def __exit__(self, *exc): + def __exit__(self, *exc: Any) -> None: self.close() - def _can_continue(self, continue_cb): + def _can_continue(self, continue_cb: Union[Callable[[], bool], None]) -> bool: """Check timeout and callback status. Args: @@ -68,7 +73,7 @@ def _can_continue(self, continue_cb): return False return True - def close(self): + def close(self) -> None: """Set job state to finished and raise any errors encountered by workers. Args: @@ -88,7 +93,7 @@ def close(self): raise exc_obj @staticmethod - def _join_workers(workers, timeout=0): + def _join_workers(workers: List[Worker], timeout: float = 0) -> List[Worker]: """Attempt to join workers. Args: @@ -106,7 +111,12 @@ def _join_workers(workers, timeout=0): alive.append(worker) return alive - def serve(self, timeout, continue_cb=None, shutdown_delay=SHUTDOWN_DELAY): + def serve( + self, + timeout: int, + continue_cb: Optional[Callable[[], bool]] = None, + shutdown_delay: float = SHUTDOWN_DELAY, + ) -> bool: """Manage workers and serve job contents. Args: diff --git a/sapphire/core.py b/sapphire/core.py index 31bd8348..7579b42c 100644 --- a/sapphire/core.py +++ b/sapphire/core.py @@ -4,14 +4,17 @@ """ Sapphire HTTP server """ +from argparse import Namespace from logging import getLogger from pathlib import Path from socket import SO_REUSEADDR, SOL_SOCKET, gethostname, socket -from ssl import PROTOCOL_TLS_SERVER, SSLContext +from ssl import PROTOCOL_TLS_SERVER, SSLContext, SSLSocket from time import sleep, time +from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union from .connection_manager import ConnectionManager from .job import Job, Served +from .server_map import ServerMap __all__ = ( "BLOCKED_PORTS", @@ -47,16 +50,21 @@ LOG = getLogger(__name__) -def create_listening_socket(attempts=10, port=0, remote=False, timeout=None): +def create_listening_socket( + attempts: int = 10, + port: int = 0, + remote: bool = False, + timeout: Optional[float] = None, +) -> socket: """Create listening socket. Search for an open socket if needed and configure the socket. If a specific port is unavailable or no available ports can be found socket.error will be raised. Args: - attempts (int): Number of attempts to configure the socket. - port (int): Port to listen on. Use 0 for system assigned port. - remote (bool): Accept all (non-local) incoming connections. - timeout (float): Used to set socket timeout. + attempts: Number of attempts to configure the socket. + port: Port to listen on. Use 0 for system assigned port. + remote: Accept all (non-local) incoming connections. + timeout: Used to set socket timeout. Returns: socket: A listening socket. @@ -100,13 +108,13 @@ class Sapphire: def __init__( self, - allow_remote=False, - auto_close=-1, + allow_remote: bool = False, + auto_close: int = -1, certs=None, - max_workers=10, - port=0, - timeout=60, - ): + max_workers: int = 10, + port: int = 0, + timeout: int = 60, + ) -> None: assert timeout >= 0 self._auto_close = auto_close # call 'window.close()' on 4xx error pages self._max_workers = max_workers # limit worker threads @@ -119,20 +127,22 @@ def __init__( if certs: context = SSLContext(PROTOCOL_TLS_SERVER) context.load_cert_chain(certs.host, certs.key) - self._socket = context.wrap_socket(sock, server_side=True) + self._socket: Union[socket, SSLSocket] = context.wrap_socket( + sock, server_side=True + ) self.scheme = "https" else: self._socket = sock self.scheme = "http" self.timeout = timeout - def __enter__(self): + def __enter__(self) -> "Sapphire": return self - def __exit__(self, *exc): + def __exit__(self, *exc: Any) -> None: self.close() - def clear_backlog(self): + def clear_backlog(self) -> None: """Remove all pending connections from backlog. This should only be called when there isn't anything actively trying to connect. @@ -158,7 +168,7 @@ def clear_backlog(self): assert deadline > time() self._socket.settimeout(self.LISTEN_TIMEOUT) - def close(self): + def close(self) -> None: """Close listening server socket. Args: @@ -170,25 +180,25 @@ def close(self): self._socket.close() @property - def port(self): + def port(self) -> int: """Port number of listening socket. Args: None Returns: - int: Listening port number. + Listening port number. """ - return self._socket.getsockname()[1] + return int(self._socket.getsockname()[1]) def serve_path( self, - path, - continue_cb=None, - forever=False, - required_files=None, - server_map=None, - ): + path: Path, + continue_cb: Optional[Callable[[], bool]] = None, + forever: bool = False, + required_files: Optional[Iterable[str]] = None, + server_map: Optional[ServerMap] = None, + ) -> Tuple[Served, Dict[str, Path]]: """Serve files in path. The status codes include: @@ -197,17 +207,17 @@ def serve_path( - Served.REQUEST: Some files were requested Args: - path (Path): Directory to use as wwwroot. - continue_cb (callable): A callback that can be used to exit the serve loop. - This must be a callable that returns a bool. - forever (bool): Continue to handle requests even after all files have - been served. This is meant to be used with continue_cb. - required_files (list(str)): Files that need to be served in order to exit - the serve loop. + path: Directory to use as wwwroot. + continue_cb: A callback that can be used to exit the serve loop. + This must be a callable that returns a bool. + forever: Continue to handle requests even after all files have been served. + This is meant to be used with continue_cb. + required_files: Files that need to be served in order to exit the + serve loop. server_map (ServerMap): Returns: - tuple(int, dict[str, Path]): Status code and files served. + Status code and files served. """ assert isinstance(path, Path) assert self.timeout >= 0 @@ -225,7 +235,7 @@ def serve_path( return (Served.TIMEOUT if timed_out else job.status, job.served) @classmethod - def main(cls, args): + def main(cls, args: Namespace) -> None: try: with cls( allow_remote=args.remote, port=args.port, timeout=args.timeout diff --git a/sapphire/job.py b/sapphire/job.py index 37f9dfa7..ac1c1790 100644 --- a/sapphire/job.py +++ b/sapphire/job.py @@ -4,17 +4,18 @@ """ Sapphire HTTP server job """ -from collections import namedtuple from enum import Enum, unique from errno import ENAMETOOLONG +from itertools import chain from logging import getLogger from mimetypes import guess_type from os.path import splitext from pathlib import Path from queue import Queue from threading import Event, Lock +from typing import Any, Dict, Iterable, NamedTuple, Optional, Set, Tuple, Union, cast -from .server_map import Resource +from .server_map import DynamicResource, FileResource, RedirectResource, ServerMap __author__ = "Tyson Smith" __credits__ = ["Tyson Smith"] @@ -36,7 +37,14 @@ class Served(Enum): TIMEOUT = 3 -Tracker = namedtuple("Tracker", "files lock") +class PendingTracker(NamedTuple): + files: Set[str] + lock: Lock + + +class ServedTracker(NamedTuple): + files: Set[FileResource] + lock: Lock class Job: @@ -65,20 +73,20 @@ class Job: def __init__( self, - wwwroot, - auto_close=-1, - forever=False, - required_files=None, - server_map=None, - ): + wwwroot: Path, + auto_close: int = -1, + forever: bool = False, + required_files: Optional[Iterable[str]] = None, + server_map: Optional[ServerMap] = None, + ) -> None: self._complete = Event() - self._pending = Tracker(files=set(), lock=Lock()) - self._served = Tracker(files=set(), lock=Lock()) + self._pending = PendingTracker(files=set(), lock=Lock()) + self._served = ServedTracker(files=set(), lock=Lock()) self._wwwroot = wwwroot.resolve() self.accepting = Event() self.accepting.set() self.auto_close = auto_close - self.exceptions = Queue() + self.exceptions: Queue[Tuple[Any, Any, Any]] = Queue() self.forever = forever self.server_map = server_map self.worker_complete = Event() @@ -86,12 +94,12 @@ def __init__( if not self._pending.files and not self.forever: raise RuntimeError("Empty Job") - def _build_pending(self, required_files): + def _build_pending(self, required_files: Union[Iterable[str], None]) -> None: """Build file list to track files that must be served. Note: This is intended to only be called once by __init__(). Args: - required_files (list(str)): List of file paths relative to wwwroot. + required_files: File paths (relative to wwwroot) that must be served. Returns: None @@ -110,26 +118,27 @@ def _build_pending(self, required_files): if not self._pending.files and not self._wwwroot.is_dir(): raise OSError(f"wwwroot '{self._wwwroot}' does not exist") if self.server_map: - for redirect, resource in self.server_map.redirect.items(): - if resource.required: - self._pending.files.add(redirect) - LOG.debug("required: %r -> %r", redirect, resource.target) - for dyn_resp, resource in self.server_map.dynamic.items(): + for url, resource in cast( + Iterable[Tuple[str, Union[DynamicResource, RedirectResource]]], + chain( + self.server_map.redirect.items(), self.server_map.dynamic.items() + ), + ): if resource.required: - self._pending.files.add(dyn_resp) - LOG.debug("required: %r -> %r", dyn_resp, resource.target) + self._pending.files.add(url) + LOG.debug("required: %r -> %r", url, resource.target) LOG.debug("job has %d required file(s)", len(self._pending.files)) @classmethod - def lookup_mime(cls, url): + def lookup_mime(cls, url: str) -> str: """Determine mime type for a given URL. Args: - url (str): URL to inspect. + url: URL to inspect. Returns: - str: Mime type of URL or 'application/octet-stream' if the mime type - cannot be determined. + Mime type of URL. 'application/octet-stream' is returned if the mime + type cannot be determined. """ mime = cls.MIME_MAP.get(splitext(url)[-1].lower()) if mime is None: @@ -137,11 +146,13 @@ def lookup_mime(cls, url): mime = guess_type(url)[0] or "application/octet-stream" return mime - def lookup_resource(self, path): + def lookup_resource( + self, path: str + ) -> Optional[Union[FileResource, DynamicResource, RedirectResource]]: """Find the Resource mapped to a given URL path. Args: - path (str): URL path. + path: URL path. Returns: Resource: Resource for the given URL path or None if one is not found. @@ -152,15 +163,10 @@ def lookup_resource(self, path): local = self._wwwroot / path if local.is_file(): local = local.resolve() - with self._pending.lock: - required = str(local) in self._pending.files - return Resource( - Resource.URL_FILE, - local, - mime=self.lookup_mime(path), - required=required, - url=path, - ) + if self._wwwroot in local.parents: + with self._pending.lock: + required = str(local) in self._pending.files + return FileResource(path, required, local, self.lookup_mime(path)) except OSError as exc: if exc.errno == ENAMETOOLONG: # file name is too long to look up so ignore it @@ -177,20 +183,20 @@ def lookup_resource(self, path): LOG.debug("checking include %r", inc) # strip include prefix from potential file name file = path[len(inc) :].lstrip("/") - local = Path(self.server_map.include[inc].target) / file - if not local.is_file(): - continue - # file exists, look up resource - return Resource( - Resource.URL_INCLUDE, - local.resolve(), - mime=self.server_map.include[inc].mime or self.lookup_mime(file), - required=self.server_map.include[inc].required, - url=f"{inc}/{file}" if inc else file, - ) + local = self.server_map.include[inc].target / file + # check that the file exists within the include path + if local.is_file(): + local = local.resolve() + if self.server_map.include[inc].target in local.parents: + return FileResource( + f"{inc}/{file}" if inc else file, + self.server_map.include[inc].required, + local, + self.lookup_mime(file), + ) return None - def finish(self): + def finish(self) -> None: """Mark Job as complete. Args: @@ -201,63 +207,34 @@ def finish(self): """ self._complete.set() - def mark_served(self, item): + def mark_served(self, item: FileResource) -> None: """Mark a Resource as served to track served Resources. Args: - item (Resource): Resource to track. + item: Resource to track. Returns: None """ - assert isinstance(item, Resource) - assert item.type in (Resource.URL_FILE, Resource.URL_INCLUDE) with self._served.lock: if item.url not in self._served.files: self._served.files.add(item) - def is_complete(self, wait=None): + def is_complete(self, wait: Optional[float] = None) -> bool: """Check if a Job has been marked as complete. Args: - wait (float): Time to wait in seconds. + wait: Time to wait in seconds. Returns: - boot: True if Job complete flag is set otherwise False. + True if Job complete flag is set otherwise False. """ if wait is not None: return self._complete.wait(wait) return self._complete.is_set() - def is_forbidden(self, target, is_include=False): - """Check if a path is forbidden. Anything outside of wwwroot and not - added by an include is forbidden. - - Note: It is assumed that the files exist on disk and that the - paths are absolute and sanitized. - - Args: - target (Path or str): Path to check. - is_include (bool): Indicates if given path is an include. - - Returns: - bool: True if no forbidden otherwise False. - """ - target = str(target) - if not is_include: - # check if target is in wwwroot - if target.startswith(str(self._wwwroot)): - return False - elif self.server_map: - # check if target is in an included path - for resource in self.server_map.include.values(): - if target.startswith(resource.target): - # target is in a valid include path - return False - return True - @property - def pending(self): + def pending(self) -> int: """Number of pending files. Args: @@ -269,7 +246,7 @@ def pending(self): with self._pending.lock: return len(self._pending.files) - def remove_pending(self, file_name): + def remove_pending(self, file_name: str) -> bool: """Remove a file from pending list. Args: @@ -279,32 +256,31 @@ def remove_pending(self, file_name): bool: True when all files have been removed otherwise False. """ with self._pending.lock: - if self._pending.files: - self._pending.files.discard(file_name) + self._pending.files.discard(file_name) return not self._pending.files @property - def served(self): + def served(self) -> Dict[str, Path]: """Served files. Args: None Returns: - dict: Mapping of URLs to files on disk. + Mapping of URLs to files on disk. """ with self._served.lock: return {entry.url: entry.target for entry in self._served.files} @property - def status(self): + def status(self) -> Served: """Job Status. Args: None Returns: - Served: Current status. + Current status. """ with self._pending.lock: if not self._served.files: diff --git a/sapphire/server_map.py b/sapphire/server_map.py index d5dc3a77..42abf8ca 100644 --- a/sapphire/server_map.py +++ b/sapphire/server_map.py @@ -2,10 +2,12 @@ # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. +from dataclasses import dataclass from inspect import signature from logging import getLogger -from os.path import abspath, isdir, relpath +from pathlib import Path from re import search as re_search +from typing import Callable, Dict __all__ = ("Resource", "ServerMap") __author__ = "Tyson Smith" @@ -22,32 +24,46 @@ class MapCollisionError(Exception): """Raised when a URL is already in use by ServerMap""" +@dataclass(frozen=True, eq=False) class Resource: - URL_DYNAMIC = 0 - URL_FILE = 1 - URL_INCLUDE = 2 - URL_REDIRECT = 3 + url: str + required: bool - __slots__ = ("mime", "required", "target", "type", "url") - def __init__(self, resource_type, target, mime=None, required=False, url=None): - self.mime = mime - self.required = required - self.target = target - self.type = resource_type - self.url = url +@dataclass(frozen=True, eq=False) +class DynamicResource(Resource): + target: Callable[[str], bytes] + mime: str + + +@dataclass(frozen=True, eq=False) +class FileResource(Resource): + target: Path + mime: str + + +@dataclass(frozen=True, eq=False) +class IncludeResource(Resource): + target: Path + + +@dataclass(frozen=True, eq=False) +class RedirectResource(Resource): + target: str class ServerMap: __slots__ = ("dynamic", "include", "redirect") - def __init__(self): - self.dynamic = {} - self.include = {} # mapping of directories that can be requested - self.redirect = {} # document paths to map to file names using 307s + def __init__(self) -> None: + self.dynamic: Dict[str, DynamicResource] = {} + # mapping of directories that can be requested + self.include: Dict[str, IncludeResource] = {} + # document paths to map to file names using 307s + self.redirect: Dict[str, RedirectResource] = {} @staticmethod - def _check_url(url): + def _check_url(url: str) -> str: # check and sanitize URL url = url.strip("/") if re_search(r"\W", url) is not None: @@ -55,8 +71,12 @@ def _check_url(url): return url def set_dynamic_response( - self, url, callback, mime_type="application/octet-stream", required=False - ): + self, + url: str, + callback: Callable[[str], bytes], + mime_type: str = "application/octet-stream", + required: bool = False, + ) -> None: url = self._check_url(url) if not callable(callback): raise TypeError("callback must be callable") @@ -67,17 +87,14 @@ def set_dynamic_response( if url in self.include or url in self.redirect: raise MapCollisionError(f"URL collision on {url!r}") LOG.debug("mapping dynamic response %r -> %r (%r)", url, callback, mime_type) - self.dynamic[url] = Resource( - Resource.URL_DYNAMIC, callback, mime=mime_type, required=required - ) + self.dynamic[url] = DynamicResource(url, required, callback, mime_type) - def set_include(self, url, target_path): + def set_include(self, url: str, target: Path) -> None: url = self._check_url(url) - if not isdir(target_path): - raise OSError(f"Include path not found: {target_path}") + if not target.is_dir(): + raise OSError(f"Include path not found: {target}") if url in self.dynamic or url in self.redirect: raise MapCollisionError(f"URL collision on {url!r}") - target_path = abspath(target_path) # sanity check to prevent mapping overlapping paths # Note: This was added to help map file served via includes back to # the files on disk. This is a temporary workaround until mapping of @@ -86,20 +103,20 @@ def set_include(self, url, target_path): if url == existing_url: # allow overwriting entry continue - if not relpath(target_path, resource.target).startswith(".."): - LOG.error("%r mapping includes path %r", existing_url, target_path) + if resource.target in target.parents: + LOG.error("%r mapping includes path '%s'", existing_url, target) raise MapCollisionError( - f"{url!r} and {existing_url!r} include {target_path!r}" + f"{url!r} and {existing_url!r} include '{target}'" ) - if not relpath(resource.target, target_path).startswith(".."): - LOG.error("%r mapping includes path %r", url, resource.target) + if target in resource.target.parents: + LOG.error("%r mapping includes path '%s'", url, resource.target) raise MapCollisionError( - f"{url!r} and {existing_url!r} include {resource.target!r}" + f"{url!r} and {existing_url!r} include '{resource.target}'" ) - LOG.debug("mapping include %r -> %r", url, target_path) - self.include[url] = Resource(Resource.URL_INCLUDE, target_path) + LOG.debug("mapping include %r -> '%s'", url, target) + self.include[url] = IncludeResource(url, False, target) - def set_redirect(self, url, target, required=True): + def set_redirect(self, url: str, target: str, required: bool = True) -> None: url = self._check_url(url) if not isinstance(target, str): raise TypeError("target must be of type 'str'") @@ -107,4 +124,4 @@ def set_redirect(self, url, target, required=True): raise TypeError("target must not be an empty string") if url in self.dynamic or url in self.include: raise MapCollisionError(f"URL collision on {url!r}") - self.redirect[url] = Resource(Resource.URL_REDIRECT, target, required=required) + self.redirect[url] = RedirectResource(url, required, target) diff --git a/sapphire/test_job.py b/sapphire/test_job.py index b2d1dd85..a4158c48 100644 --- a/sapphire/test_job.py +++ b/sapphire/test_job.py @@ -8,15 +8,21 @@ from pytest import mark, raises from .job import Job, Served -from .server_map import Resource, ServerMap +from .server_map import ( + DynamicResource, + FileResource, + IncludeResource, + RedirectResource, + ServerMap, +) def test_job_01(tmp_path): """test creating a simple Job""" - test_file = tmp_path / "test.txt" - test_file.touch() with raises(RuntimeError, match="Empty Job"): Job(tmp_path) + test_file = tmp_path / "test.txt" + test_file.touch() job = Job(tmp_path, required_files=[test_file.name]) assert not job.forever assert job.status == Served.NONE @@ -25,11 +31,11 @@ def test_job_01(tmp_path): assert job.lookup_resource("test/test/") is None assert job.lookup_resource("test/../../") is None assert job.lookup_resource("\x00\x0B\xAD\xF0\x0D") is None - assert not job.is_forbidden(tmp_path) - assert not job.is_forbidden(tmp_path / "missing_file") + assert job.lookup_resource("test.txt") assert job.pending == 1 assert not job.is_complete() assert job.remove_pending(str(test_file)) + assert job.pending == 0 job.finish() assert not any(job.served) assert job.is_complete(wait=0.01) @@ -52,11 +58,10 @@ def test_job_02(tmp_path): assert job.status == Served.NONE assert not job.is_complete() resource = job.lookup_resource("req_file_1.txt") + assert isinstance(resource, FileResource) assert resource.required assert job.pending == 2 assert resource.target == tmp_path / "req_file_1.txt" - assert resource.type == Resource.URL_FILE - assert not job.is_forbidden(req[0]) assert not job.remove_pending("no_file.test") assert job.pending == 2 assert not job.remove_pending(str(req[0])) @@ -74,16 +79,16 @@ def test_job_02(tmp_path): job.mark_served(resource) assert len(job._served.files) == 2 resource = job.lookup_resource("opt_file_1.txt") + assert isinstance(resource, FileResource) assert not resource.required assert resource.target == opt[0] - assert resource.type == Resource.URL_FILE assert job.remove_pending(str(opt[0])) job.mark_served(resource) assert len(job._served.files) == 3 assert len(job.served) == 3 resource = job.lookup_resource("nested/opt_file_2.txt") + assert isinstance(resource, FileResource) assert resource.target == opt[1] - assert resource.type == Resource.URL_FILE assert not resource.required job.finish() assert job.is_complete() @@ -97,10 +102,9 @@ def test_job_03(tmp_path): job = Job(tmp_path, server_map=smap) assert job.status == Served.NONE resource = job.lookup_resource("one") - assert resource.type == Resource.URL_REDIRECT + assert isinstance(resource, RedirectResource) resource = job.lookup_resource("two") - assert resource is not None - assert resource.type == Resource.URL_REDIRECT + assert isinstance(resource, RedirectResource) assert job.pending == 1 assert job.remove_pending("two") assert job.pending == 0 @@ -128,12 +132,14 @@ def test_job_04(mocker, tmp_path): # restrictive to allow testing of some functionality mocker.patch.object(ServerMap, "_check_url", side_effect=lambda x: x) smap = ServerMap() - smap.set_include("testinc", str(srv_include)) + smap.set_include("testinc", srv_include) # add manually to avoid sanity checks in ServerMap.set_include() - smap.include["testinc/fakedir"] = Resource(Resource.URL_INCLUDE, srv_include) - smap.include["testinc/1/2/3"] = Resource(Resource.URL_INCLUDE, srv_include) - smap.include[""] = Resource(Resource.URL_INCLUDE, srv_include) - smap.set_include("testinc/inc2", str(srv_include_2)) + smap.include["testinc/fakedir"] = IncludeResource( + "testinc/fakedir", False, srv_include + ) + smap.include["testinc/1/2/3"] = IncludeResource("testinc/1/2/3", False, srv_include) + smap.include[""] = IncludeResource("", False, srv_include) + smap.set_include("testinc/inc2", srv_include_2) job = Job(srv_root, server_map=smap, required_files=[test_1.name]) assert job.status == Served.NONE # test include path pointing to a missing file @@ -146,33 +152,27 @@ def test_job_04(mocker, tmp_path): continue request = "/".join([incl, "test_file.txt"]) resource = job.lookup_resource(request) - assert resource.type == Resource.URL_INCLUDE + assert isinstance(resource, FileResource) assert resource.target == inc_1 assert resource.url == request.lstrip("/") # test nested include path pointing to a different include request = "testinc/inc2/test_file_2.txt" resource = job.lookup_resource(request) - assert resource.type == Resource.URL_INCLUDE + assert isinstance(resource, FileResource) assert resource.target == inc_2 assert resource.url == request # test redirect root without leading '/' request = "test_file.txt" resource = job.lookup_resource(request) - assert resource.type == Resource.URL_INCLUDE + assert isinstance(resource, FileResource) assert resource.target == srv_include / "test_file.txt" assert resource.url == request # test redirect with file in a nested directory request = "/".join(["testinc", "nested", "nested_file.txt"]) resource = job.lookup_resource(request) - assert resource.type == Resource.URL_INCLUDE + assert isinstance(resource, FileResource) assert resource.target == nst_1 assert resource.url == request - assert not job.is_forbidden( - (srv_root / ".." / "test" / "test_file.txt").resolve(), is_include=True - ) - assert not job.is_forbidden( - (srv_include / ".." / "root" / "req_file.txt").resolve(), is_include=False - ) def test_job_05(tmp_path): @@ -190,10 +190,10 @@ def test_job_05(tmp_path): inc_file2.write_bytes(b"a") # test url matching part of the file name smap = ServerMap() - smap.include["inc_url"] = Resource(Resource.URL_INCLUDE, str(inc_dir)) + smap.include["inc_url"] = IncludeResource("inc_url", False, inc_dir) job = Job(srv_root, server_map=smap, required_files=[req.name]) resource = job.lookup_resource("inc_url/sub/include.js") - assert resource.type == Resource.URL_INCLUDE + assert isinstance(resource, FileResource) assert resource.target == inc_file1 # test checking only the include url assert job.lookup_resource("inc_url") is None @@ -203,9 +203,9 @@ def test_job_05(tmp_path): inc_a.write_bytes(b"a") file_a = srv_root / "a.bin" file_a.write_bytes(b"a") - smap.include["/"] = Resource(Resource.URL_INCLUDE, str(inc_dir)) + smap.include["/"] = IncludeResource("", False, inc_dir) resource = job.lookup_resource("a.bin") - assert resource.type == Resource.URL_FILE + assert isinstance(resource, FileResource) assert resource.target == file_a # TODO: inc and inc subdir collision can fail. # /inc/a file @@ -222,12 +222,11 @@ def test_job_06(tmp_path): assert job.status == Served.NONE assert job.pending == 1 resource = job.lookup_resource("cb1") - assert resource.type == Resource.URL_DYNAMIC + assert isinstance(resource, DynamicResource) assert callable(resource.target) assert isinstance(resource.mime, str) resource = job.lookup_resource("cb2") - assert resource is not None - assert resource.type == Resource.URL_DYNAMIC + assert isinstance(resource, DynamicResource) assert callable(resource.target) assert isinstance(resource.mime, str) @@ -243,11 +242,7 @@ def test_job_07(tmp_path): job = Job(srv_root, required_files=[test_file.name]) assert job.status == Served.NONE assert job.pending == 1 - resource = job.lookup_resource("../no_access.txt") - assert resource.target == no_access - assert resource.type == Resource.URL_FILE - assert not job.is_forbidden(test_file) - assert job.is_forbidden((srv_root / ".." / "no_access.txt").resolve()) + assert job.lookup_resource("../no_access.txt") is None @mark.skipif(system() == "Windows", reason="Unsupported on Windows") @@ -283,34 +278,28 @@ def test_job_11(tmp_path): job = Job(tmp_path, required_files=["test.txt"]) assert not any(job.served) # add first resource - resource = Resource(Resource.URL_FILE, tmp_path / "a.bin", url="a.bin") + resource = FileResource("a.bin", False, tmp_path / "a.bin", "mine/mime") job.mark_served(resource) assert "a.bin" in job.served assert job.served[resource.url] == resource.target assert len(job.served) == 1 # add a resource with the same url - job.mark_served(Resource(Resource.URL_FILE, tmp_path / "a.bin", url="a.bin")) + job.mark_served(FileResource("a.bin", False, tmp_path / "a.bin", "a/a")) assert len(job.served) == 1 # add a nested resource - resource = Resource( - Resource.URL_FILE, tmp_path / "nested" / "b.bin", url="nested/b.bin" - ) + resource = FileResource("nested/b.bin", False, tmp_path / "nested" / "b.bin", "a/a") job.mark_served(resource) assert "nested/b.bin" in job.served assert job.served[resource.url] == resource.target assert len(job.served) == 2 # add an include resource - resource = Resource( - Resource.URL_INCLUDE, Path("/some/include/path/inc.bin"), url="inc.bin" - ) + resource = IncludeResource("inc.bin", False, Path("/some/include/path/inc.bin")) job.mark_served(resource) assert "inc.bin" in job.served assert job.served[resource.url] == resource.target assert len(job.served) == 3 # add an include resource pointing to a common file with unique url - resource = Resource( - Resource.URL_INCLUDE, Path("/some/include/path/inc.bin"), url="alt_path" - ) + resource = IncludeResource("alt_path", False, Path("/some/include/path/inc.bin")) job.mark_served(resource) assert "alt_path" in job.served assert len(job.served) == 4 diff --git a/sapphire/test_sapphire.py b/sapphire/test_sapphire.py index 092b411c..b154bc09 100644 --- a/sapphire/test_sapphire.py +++ b/sapphire/test_sapphire.py @@ -25,13 +25,11 @@ class _TestFile: def __init__(self, url, url_prefix=None): + assert isinstance(url, str) self.code = None self.content_type = None self.custom_request = None - if url_prefix: - self.file = "".join((url_prefix, url)) - else: - self.file = url + self.file = f"{url_prefix}{url}" if url_prefix else url self.len_org = 0 # original file length self.len_srv = 0 # served file length self.lock = Lock() @@ -45,6 +43,7 @@ def __init__(self, url, url_prefix=None): def _create_test(fname, path, data=b"Test!", calc_hash=False, url_prefix=None): + assert isinstance(path, Path) test = _TestFile(fname, url_prefix=url_prefix) with (path / fname).open("w+b") as out_fp: out_fp.write(data) @@ -117,13 +116,13 @@ def test_sapphire_03(client, tmp_path): root_dir.mkdir() invalid = Path(__file__) to_serve = [ - # missing file + # 0 - missing file _TestFile("does_not_exist.html"), - # add invalid file + # 1 - add invalid file _TestFile(str(invalid.resolve())), - # add file in parent of root_dir + # 2 - add file in parent of root_dir _create_test("no_access.html", tmp_path, data=b"no_access", url_prefix="../"), - # add valid test + # 3 - add valid test _create_test("test_case.html", root_dir), ] required = [to_serve[-1].file] @@ -138,7 +137,7 @@ def test_sapphire_03(client, tmp_path): assert invalid.name in to_serve[1].file assert to_serve[1].code == 404 assert "no_access.html" in to_serve[2].file - assert to_serve[2].code == 403 + assert to_serve[2].code == 404 assert to_serve[3].code == 200 @@ -362,18 +361,18 @@ def test_sapphire_14(client, tmp_path): inc404 = _TestFile("inc_test/included_file_404.html") assert not (nest_path / "included_file_404.html").is_file() to_serve.append(inc404) - # test 403 - inc403 = _create_test( + # test 404 with file outside of include path + inc_ext = _create_test( "no_access.html", tmp_path, data=b"no_access", url_prefix="inc_test/../" ) assert (tmp_path / "no_access.html").is_file() - to_serve.append(inc403) + to_serve.append(inc_ext) # test file (used to keep sever job alive) test = _create_test("test_case.html", root_path) to_serve.append(test) # add include paths - smap.set_include("/", str(inc1_path)) # mount at '/' - smap.set_include("inc_test", str(inc2_path)) # mount at '/inc_test' + smap.set_include("/", inc1_path) # mount at '/' + smap.set_include("inc_test", inc2_path) # mount at '/inc_test' client.launch("127.0.0.1", serv.port, to_serve, in_order=True) status, served = serv.serve_path( root_path, server_map=smap, required_files=[x.file for x in to_serve] @@ -393,7 +392,7 @@ def test_sapphire_14(client, tmp_path): assert test.code == 200 assert nest_404.code == 404 assert inc404.code == 404 - assert inc403.code == 403 + assert inc_ext.code == 404 @mark.parametrize( diff --git a/sapphire/test_server_map.py b/sapphire/test_server_map.py index 6d44dd64..8cdf79ce 100644 --- a/sapphire/test_server_map.py +++ b/sapphire/test_server_map.py @@ -2,9 +2,18 @@ # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. +from pathlib import Path + from pytest import raises -from .server_map import InvalidURLError, MapCollisionError, Resource, ServerMap +from .server_map import ( + DynamicResource, + IncludeResource, + InvalidURLError, + MapCollisionError, + RedirectResource, + ServerMap, +) def test_servermap_01(): @@ -20,10 +29,10 @@ def test_servermap_02(tmp_path): srv_map = ServerMap() srv_map.set_dynamic_response("url_01", lambda _: 0, mime_type="test/type") assert len(srv_map.dynamic) == 1 + assert isinstance(srv_map.dynamic["url_01"], DynamicResource) assert "url_01" in srv_map.dynamic assert srv_map.dynamic["url_01"].mime == "test/type" assert callable(srv_map.dynamic["url_01"].target) - assert srv_map.dynamic["url_01"].type == Resource.URL_DYNAMIC srv_map.set_dynamic_response("url_02", lambda _: 0, mime_type="foo") assert len(srv_map.dynamic) == 2 assert not srv_map.include @@ -36,7 +45,7 @@ def test_servermap_02(tmp_path): srv_map.set_dynamic_response("x", lambda _: 0, None) # test detecting collisions with raises(MapCollisionError): - srv_map.set_include("url_01", str(tmp_path)) + srv_map.set_include("url_01", tmp_path) with raises(MapCollisionError): srv_map.set_redirect("url_01", "test_file") @@ -45,21 +54,22 @@ def test_servermap_03(tmp_path): """test ServerMap includes""" srv_map = ServerMap() with raises(IOError, match="Include path not found: no_dir"): - srv_map.set_include("test_url", "no_dir") + srv_map.set_include("test_url", Path("no_dir")) assert not srv_map.include - srv_map.set_include("url_01", str(tmp_path)) + srv_map.set_include("url_01", tmp_path) assert len(srv_map.include) == 1 + assert isinstance(srv_map.include["url_01"], IncludeResource) assert "url_01" in srv_map.include - assert srv_map.include["url_01"].target == str(tmp_path) + assert srv_map.include["url_01"].target == tmp_path # overwrite existing inc1 = tmp_path / "includes" / "a" inc1.mkdir(parents=True) - srv_map.set_include("url_01", str(inc1)) - assert srv_map.include["url_01"].target == str(inc1) + srv_map.set_include("url_01", inc1) + assert srv_map.include["url_01"].target == inc1 # add another inc2 = tmp_path / "includes" / "b" inc2.mkdir() - srv_map.set_include("url_02", str(inc2)) + srv_map.set_include("url_02", inc2) assert len(srv_map.include) == 2 assert not srv_map.dynamic assert not srv_map.redirect @@ -70,11 +80,11 @@ def test_servermap_03(tmp_path): srv_map.set_dynamic_response("url_01", lambda _: 0, mime_type="test/type") # test overlapping includes with raises(MapCollisionError, match=r"'url_01' and '\w+' include"): - srv_map.set_include("url_01", str(tmp_path)) + srv_map.set_include("url_01", tmp_path) inc3 = tmp_path / "includes" / "b" / "c" inc3.mkdir() with raises(MapCollisionError, match=r"'url_01' and '\w+' include"): - srv_map.set_include("url_01", str(inc3)) + srv_map.set_include("url_01", inc3) def test_servermap_04(tmp_path): @@ -82,6 +92,7 @@ def test_servermap_04(tmp_path): srv_map = ServerMap() srv_map.set_redirect("url_01", "test_file", required=True) assert len(srv_map.redirect) == 1 + assert isinstance(srv_map.redirect["url_01"], RedirectResource) assert "url_01" in srv_map.redirect assert srv_map.redirect["url_01"].target == "test_file" assert srv_map.redirect["url_01"].required @@ -96,7 +107,7 @@ def test_servermap_04(tmp_path): srv_map.set_redirect("x", None) # test detecting collisions with raises(MapCollisionError): - srv_map.set_include("url_01", str(tmp_path)) + srv_map.set_include("url_01", tmp_path) with raises(MapCollisionError): srv_map.set_dynamic_response("url_01", lambda _: 0, mime_type="test/type") diff --git a/sapphire/worker.py b/sapphire/worker.py index bb98488f..a90c751e 100644 --- a/sapphire/worker.py +++ b/sapphire/worker.py @@ -6,14 +6,19 @@ """ from logging import getLogger from re import compile as re_compile -from socket import SHUT_RDWR +from socket import SHUT_RDWR, socket from socket import timeout as sock_timeout # Py3.10 socket.timeout => TimeoutError from sys import exc_info from threading import Thread, ThreadError, active_count from time import sleep -from urllib.parse import quote, unquote, urlparse +from typing import Optional +from urllib.parse import ParseResult, quote, unquote, urlparse + +from .job import Job +from .server_map import DynamicResource, FileResource, RedirectResource + +# TODO: urlparse -> urlsplit -from .server_map import Resource __author__ = "Tyson Smith" __credits__ = ["Tyson Smith"] @@ -26,12 +31,12 @@ class Request: __slots__ = ("method", "url") - def __init__(self, method, url): + def __init__(self, method: str, url: ParseResult) -> None: self.method = method self.url = url @classmethod - def parse(cls, raw_data): + def parse(cls, raw_data: bytes) -> Optional["Request"]: assert isinstance(raw_data, bytes) req_match = cls.REQ_PATTERN.match(raw_data) if not req_match: @@ -64,12 +69,12 @@ class Worker: __slots__ = ("_conn", "_thread") - def __init__(self, conn, thread): + def __init__(self, conn: socket, thread: Thread) -> None: self._conn = conn - self._thread = thread + self._thread: Optional[Thread] = thread @staticmethod - def _200_header(c_length, c_type, encoding="ascii"): + def _200_header(c_length: int, c_type: str) -> bytes: assert c_type is not None data = ( "HTTP/1.1 200 OK\r\n" @@ -78,19 +83,19 @@ def _200_header(c_length, c_type, encoding="ascii"): f"Content-Type: {c_type}\r\n" "Connection: close\r\n\r\n" ) - return data.encode(encoding) + return data.encode(encoding="ascii") @staticmethod - def _307_redirect(redirect_to, encoding="ascii"): + def _307_redirect(redirect_to: str) -> bytes: data = ( "HTTP/1.1 307 Temporary Redirect\r\n" f"Location: {redirect_to}\r\n" "Connection: close\r\n\r\n" ) - return data.encode(encoding) + return data.encode(encoding="ascii") @staticmethod - def _4xx_page(code, hdr_msg, close=-1, encoding="ascii"): + def _4xx_page(code: int, hdr_msg: str, close: int = -1) -> bytes: if close < 0: content = f"

{code}!

" else: @@ -109,9 +114,9 @@ def _4xx_page(code, hdr_msg, close=-1, encoding="ascii"): "Content-Type: text/html\r\n" f"Connection: close\r\n\r\n{content}" ) - return data.encode(encoding) + return data.encode(encoding="ascii") - def close(self): + def close(self) -> None: # workers that are no longer running will have had close() called if self.is_alive(): # shutdown socket to avoid hang @@ -122,11 +127,11 @@ def close(self): LOG.debug("close - shutdown(): %s", exc) self._conn.close() - def is_alive(self): + def is_alive(self) -> bool: return self._thread is not None and self._thread.is_alive() @classmethod - def handle_request(cls, conn, serv_job): + def handle_request(cls, conn: socket, serv_job: Job) -> None: finish_job = False # call finish() on return try: # socket operations should not block forever @@ -162,9 +167,9 @@ def handle_request(cls, conn, serv_job): LOG.debug("lookup resource %r", request.url.path) resource = serv_job.lookup_resource(request.url.path) if resource: - if resource.type in (Resource.URL_FILE, Resource.URL_INCLUDE): + if isinstance(resource, FileResource): finish_job = serv_job.remove_pending(str(resource.target)) - elif resource.type in (Resource.URL_DYNAMIC, Resource.URL_REDIRECT): + elif isinstance(resource, (DynamicResource, RedirectResource)): finish_job = serv_job.remove_pending(request.url.path.lstrip("/")) else: # pragma: no cover # this should never happen @@ -186,7 +191,7 @@ def handle_request(cls, conn, serv_job): request.url.path[-40:], serv_job.pending, ) - elif resource.type == Resource.URL_REDIRECT: + elif isinstance(resource, RedirectResource): redirect_to = [quote(resource.target)] if request.url.query: LOG.debug("appending query %r", request.url.query) @@ -198,7 +203,7 @@ def handle_request(cls, conn, serv_job): resource.target, serv_job.pending, ) - elif resource.type == Resource.URL_DYNAMIC: + elif isinstance(resource, DynamicResource): # pass query string to callback data = resource.target(request.url.query) if not isinstance(data, bytes): @@ -211,16 +216,8 @@ def handle_request(cls, conn, serv_job): request.url.path, serv_job.pending, ) - elif serv_job.is_forbidden( - resource.target, is_include=resource.type == Resource.URL_INCLUDE - ): - # NOTE: this does info leak if files exist on disk. - # We could replace 403 with 404 if it turns out we care. - # However this is meant to only be accessible via localhost. - LOG.debug("target %r", str(resource.target)) - conn.sendall(cls._4xx_page(403, "Forbidden", serv_job.auto_close)) - LOG.debug("403 %r (%d to go)", request.url.path, serv_job.pending) else: + assert isinstance(resource, FileResource) # serve the file data_size = resource.target.stat().st_size LOG.debug( @@ -229,6 +226,7 @@ def handle_request(cls, conn, serv_job): resource.mime, resource.target, ) + assert resource.mime is not None with resource.target.open("rb") as in_fp: conn.sendall(cls._200_header(data_size, resource.mime)) offset = 0 @@ -257,7 +255,7 @@ def handle_request(cls, conn, serv_job): serv_job.finish() serv_job.worker_complete.set() - def join(self, timeout=30): + def join(self, timeout: float = 30) -> bool: assert timeout >= 0 if self._thread is not None: self._thread.join(timeout=timeout) @@ -266,7 +264,9 @@ def join(self, timeout=30): return self._thread is None @classmethod - def launch(cls, listen_sock, job, timeout=30): + def launch( + cls, listen_sock: socket, job: Job, timeout: float = 30 + ) -> Optional["Worker"]: assert timeout >= 0 assert job.accepting.is_set() conn = None