From 543d9c95756d0d05b0ed3e06b7c031b753f8f1ab Mon Sep 17 00:00:00 2001 From: Vignesh Rao Date: Sat, 7 Sep 2024 14:27:35 -0500 Subject: [PATCH] Fix relative imports to get CLI working again --- pyninja/__init__.py | 6 ++---- pyninja/auth.py | 2 +- pyninja/database.py | 2 +- pyninja/main.py | 8 ++++---- pyninja/models.py | 2 +- pyninja/monitor/__init__.py | 8 +------- pyninja/monitor/authenticator.py | 22 +++++++++++----------- pyninja/monitor/config.py | 6 ++++++ pyninja/monitor/routes.py | 28 ++++++++++++++++------------ pyninja/rate_limit.py | 2 +- pyninja/routers.py | 2 +- pyninja/service.py | 2 +- pyninja/squire.py | 14 +++++++------- pyninja/version.py | 1 + pyproject.toml | 6 +++--- 15 files changed, 57 insertions(+), 54 deletions(-) create mode 100644 pyninja/version.py diff --git a/pyninja/__init__.py b/pyninja/__init__.py index 5ddacea..7638abd 100644 --- a/pyninja/__init__.py +++ b/pyninja/__init__.py @@ -5,9 +5,7 @@ import click -from pyninja.main import start # noqa: F401 - -version = "0.0.4" +from .main import start, version @click.command() @@ -50,7 +48,7 @@ def commandline(*args, **kwargs) -> None: for k, v in options.items() ) if kwargs.get("version"): - click.echo(f"PyNinja {version}") + click.echo(f"PyNinja {version.__version__}") sys.exit(0) if kwargs.get("help"): click.echo( diff --git a/pyninja/auth.py b/pyninja/auth.py index 287dc45..1d097ee 100644 --- a/pyninja/auth.py +++ b/pyninja/auth.py @@ -7,7 +7,7 @@ from fastapi import Request from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer -from pyninja import database, exceptions, models +from . import database, exceptions, models LOGGER = logging.getLogger("uvicorn.default") EPOCH = lambda: int(time.time()) # noqa: E731 diff --git a/pyninja/database.py b/pyninja/database.py index 05116bd..88e88ca 100644 --- a/pyninja/database.py +++ b/pyninja/database.py @@ -1,4 +1,4 @@ -from pyninja import models +from . import models def get_record(host: str) -> int | None: diff --git a/pyninja/main.py b/pyninja/main.py index 3c27c6d..0c38648 100644 --- a/pyninja/main.py +++ b/pyninja/main.py @@ -4,8 +4,8 @@ from fastapi import FastAPI, Request from fastapi.responses import JSONResponse, RedirectResponse -import pyninja -from pyninja import exceptions, models, monitor, routers, squire +from . import exceptions, models, routers, squire, version +from .monitor.config import static LOGGER = logging.getLogger("uvicorn.default") @@ -25,7 +25,7 @@ async def redirect_exception_handler( """ LOGGER.debug("Exception headers: %s", request.headers) LOGGER.debug("Exception cookies: %s", request.cookies) - if request.url.path == monitor.config.static.login_endpoint: + if request.url.path == static.login_endpoint: response = JSONResponse( content={"redirect_url": exception.location}, status_code=200 ) @@ -68,7 +68,7 @@ def start(**kwargs) -> None: routes=routers.get_all_routes(), title="PyNinja", description="Lightweight OS-agnostic service monitoring API", - version=pyninja.version, + version=version.__version__, ) app.add_exception_handler( exc_class_or_status_code=exceptions.RedirectException, diff --git a/pyninja/models.py b/pyninja/models.py index 3d2b2df..6ed7332 100644 --- a/pyninja/models.py +++ b/pyninja/models.py @@ -16,7 +16,7 @@ from pydantic_core import InitErrorDetails from pydantic_settings import BaseSettings -from pyninja import exceptions +from . import exceptions OPERATING_SYSTEM = platform.system() diff --git a/pyninja/monitor/__init__.py b/pyninja/monitor/__init__.py index 172bc30..706f8b1 100644 --- a/pyninja/monitor/__init__.py +++ b/pyninja/monitor/__init__.py @@ -1,15 +1,9 @@ -import os from typing import List from fastapi import Depends from fastapi.routing import APIRoute, APIWebSocketRoute -from fastapi.templating import Jinja2Templates -from pyninja.monitor import authenticator, config, routes, secure # noqa: F401 - -templates = Jinja2Templates( - directory=os.path.join(os.path.dirname(__file__), "templates") -) +from . import authenticator, config, routes, secure # noqa: F401 def get_all_monitor_routes( diff --git a/pyninja/monitor/authenticator.py b/pyninja/monitor/authenticator.py index 544626d..ab10324 100644 --- a/pyninja/monitor/authenticator.py +++ b/pyninja/monitor/authenticator.py @@ -8,8 +8,8 @@ from fastapi import Request, status from fastapi.responses import HTMLResponse -import pyninja -from pyninja import exceptions, models, monitor, squire +from .. import exceptions, models, squire, version +from . import config, secure LOGGER = logging.getLogger("uvicorn.default") @@ -55,9 +55,9 @@ async def extract_credentials(authorization: str, host: str) -> List[str]: """ if not authorization: await raise_error(host) - decoded_auth = await monitor.secure.base64_decode(authorization) + decoded_auth = await secure.base64_decode(authorization) # convert hex to a string - auth = await monitor.secure.hex_decode(decoded_auth) + auth = await secure.hex_decode(decoded_auth) return auth.split(",") @@ -70,13 +70,13 @@ async def verify_login(authorization: str, host: str) -> Dict[str, Union[str, in """ username, signature, timestamp = await extract_credentials(authorization, host) if secrets.compare_digest(username, models.env.monitor_username): - hex_user = await monitor.secure.hex_encode(models.env.monitor_username) - hex_pass = await monitor.secure.hex_encode(models.env.monitor_password) + hex_user = await secure.hex_encode(models.env.monitor_username) + hex_pass = await secure.hex_encode(models.env.monitor_password) else: LOGGER.warning("User '%s' not allowed", username) await raise_error(host) message = f"{hex_user}{hex_pass}{timestamp}" - expected_signature = await monitor.secure.calculate_hash(message) + expected_signature = await secure.calculate_hash(message) if secrets.compare_digest(signature, expected_signature): models.ws_session.invalid[host] = 0 key = squire.keygen() @@ -97,7 +97,7 @@ async def generate_cookie(auth_payload: dict) -> Dict[str, str | bool | int]: Dict[str, str | bool | int]: Returns a dictionary with cookie details """ - expiration = await monitor.config.get_expiry( + expiration = await config.get_expiry( lease_start=auth_payload["timestamp"], lease_duration=models.env.monitor_session ) LOGGER.info( @@ -128,13 +128,13 @@ async def session_error( HTMLResponse: Returns an HTML response templated using Jinja2. """ - return monitor.templates.TemplateResponse( + return config.templates.TemplateResponse( name="session.html", context={ "request": request, - "signin": monitor.config.static.login_endpoint, + "signin": config.static.login_endpoint, "reason": error.detail, - "version": f"v{pyninja.version}", + "version": f"v{version.__version__}", }, ) diff --git a/pyninja/monitor/config.py b/pyninja/monitor/config.py index fa64f6e..d787723 100644 --- a/pyninja/monitor/config.py +++ b/pyninja/monitor/config.py @@ -1,8 +1,14 @@ +import os import time from fastapi.responses import HTMLResponse +from fastapi.templating import Jinja2Templates from pydantic import BaseModel +templates = Jinja2Templates( + directory=os.path.join(os.path.dirname(__file__), "templates") +) + async def clear_session(response: HTMLResponse) -> HTMLResponse: """Clear the session token from the response. diff --git a/pyninja/monitor/routes.py b/pyninja/monitor/routes.py index 6322ac6..3caf4ce 100644 --- a/pyninja/monitor/routes.py +++ b/pyninja/monitor/routes.py @@ -7,8 +7,7 @@ from fastapi.responses import HTMLResponse, JSONResponse from fastapi.websockets import WebSocket, WebSocketDisconnect, WebSocketState -import pyninja -from pyninja import exceptions, models, monitor, squire +from .. import exceptions, models, monitor, squire, version LOGGER = logging.getLogger("uvicorn.default") @@ -24,12 +23,12 @@ async def error_endpoint(request: Request) -> HTMLResponse: Returns an HTML response templated using Jinja2. """ return await monitor.config.clear_session( - monitor.templates.TemplateResponse( + monitor.config.templates.TemplateResponse( name="unauthorized.html", context={ "request": request, "signin": monitor.config.static.login_endpoint, - "version": f"v{pyninja.version}", + "version": f"v{version.__version__}", }, ) ) @@ -52,14 +51,14 @@ async def logout_endpoint(request: Request) -> HTMLResponse: response = await monitor.authenticator.session_error(request, error) else: models.ws_session.client_auth.pop(request.client.host) - response = monitor.templates.TemplateResponse( + response = monitor.config.templates.TemplateResponse( name="logout.html", context={ "request": request, "detail": "Session Expired", "signin": monitor.config.static.login_endpoint, "show_login": True, - "version": f"v{pyninja.version}", + "version": f"v{version.__version__}", }, ) return await monitor.config.clear_session(response) @@ -109,7 +108,7 @@ async def monitor_endpoint(request: Request, session_token: str = Cookie(None)): await monitor.authenticator.session_error(request, error) ) else: - return monitor.templates.TemplateResponse( + return monitor.config.templates.TemplateResponse( name="main.html", context=dict( request=request, @@ -118,12 +117,12 @@ async def monitor_endpoint(request: Request, session_token: str = Cookie(None)): ), ) else: - return monitor.templates.TemplateResponse( + return monitor.config.templates.TemplateResponse( name="index.html", context={ "request": request, "signin": monitor.config.static.login_endpoint, - "version": f"v{pyninja.version}", + "version": f"v{version.__version__}", }, ) @@ -136,9 +135,14 @@ async def websocket_endpoint(websocket: WebSocket, session_token: str = Cookie(N session_token: Session token set after verifying username and password. """ await websocket.accept() - session_validity = await monitor.authenticator.validate_session( - websocket.client.host, session_token - ) + try: + session_validity = await monitor.authenticator.validate_session( + websocket.client.host, session_token + ) + except exceptions.SessionError as error: + await websocket.send_text(error.__str__()) + await websocket.close() + return if not session_validity: await websocket.send_text("Unauthorized") await websocket.close() diff --git a/pyninja/rate_limit.py b/pyninja/rate_limit.py index 29a8467..aca99a9 100644 --- a/pyninja/rate_limit.py +++ b/pyninja/rate_limit.py @@ -4,7 +4,7 @@ from fastapi import Request -from pyninja import exceptions, models +from . import exceptions, models class RateLimiter: diff --git a/pyninja/routers.py b/pyninja/routers.py index 349145f..15b47bf 100644 --- a/pyninja/routers.py +++ b/pyninja/routers.py @@ -10,7 +10,7 @@ from fastapi.security import HTTPAuthorizationCredentials, HTTPBasic, HTTPBearer from pydantic import PositiveFloat, PositiveInt -from pyninja import ( +from . import ( auth, dockerized, exceptions, diff --git a/pyninja/service.py b/pyninja/service.py index 03e1df8..4467dd8 100644 --- a/pyninja/service.py +++ b/pyninja/service.py @@ -2,7 +2,7 @@ import subprocess from http import HTTPStatus -from pyninja import models +from . import models LOGGER = logging.getLogger("uvicorn.default") diff --git a/pyninja/squire.py b/pyninja/squire.py index 6bad7f5..6c83af0 100644 --- a/pyninja/squire.py +++ b/pyninja/squire.py @@ -14,7 +14,7 @@ import yaml from pydantic import PositiveFloat, PositiveInt -from pyninja.models import EnvConfig +from . import models LOGGER = logging.getLogger("uvicorn.default") IP_REGEX = re.compile( @@ -146,7 +146,7 @@ def process_command( return result -def envfile_loader(filename: str | os.PathLike) -> EnvConfig: +def envfile_loader(filename: str | os.PathLike) -> models.EnvConfig: """Loads environment variables based on filetypes. Args: @@ -160,24 +160,24 @@ def envfile_loader(filename: str | os.PathLike) -> EnvConfig: if env_file.suffix.lower() == ".json": with open(env_file) as stream: env_data = json.load(stream) - return EnvConfig(**{k.lower(): v for k, v in env_data.items()}) + return models.EnvConfig(**{k.lower(): v for k, v in env_data.items()}) elif env_file.suffix.lower() in (".yaml", ".yml"): with open(env_file) as stream: env_data = yaml.load(stream, yaml.FullLoader) - return EnvConfig(**{k.lower(): v for k, v in env_data.items()}) + return models.EnvConfig(**{k.lower(): v for k, v in env_data.items()}) elif not env_file.suffix or env_file.suffix.lower() in ( ".text", ".txt", "", ): - return EnvConfig.from_env_file(env_file) + return models.EnvConfig.from_env_file(env_file) else: raise ValueError( "\n\tUnsupported format for 'env_file', can be one of (.json, .yaml, .yml, .txt, .text, or null)" ) -def load_env(**kwargs) -> EnvConfig: +def load_env(**kwargs) -> models.EnvConfig: """Merge env vars from env_file with kwargs, giving priority to kwargs. See Also: @@ -194,7 +194,7 @@ def load_env(**kwargs) -> EnvConfig: else: file_env = {} merged_env = {**file_env, **kwargs} - return EnvConfig(**merged_env) + return models.EnvConfig(**merged_env) def keygen() -> str: diff --git a/pyninja/version.py b/pyninja/version.py new file mode 100644 index 0000000..81f0fde --- /dev/null +++ b/pyninja/version.py @@ -0,0 +1 @@ +__version__ = "0.0.4" diff --git a/pyproject.toml b/pyproject.toml index e227cb2..bdb67cb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,12 +18,12 @@ keywords = ["service-monitor", "PyNinja"] requires-python = ">=3.10" [tool.setuptools] -packages = ["pyninja"] +packages = ["pyninja", "pyninja.monitor", "pyninja.monitor.templates"] [tool.setuptools.package-data] -"pyninja" = ["index.html"] +"pyninja.monitor.templates" = ["*.html"] [tool.setuptools.dynamic] -version = {attr = "pyninja.version"} +version = {attr = "pyninja.version.__version__"} dependencies = { file = ["requirements.txt"] } [project.optional-dependencies]