diff --git a/example.py b/example.py index e615dd7..96c458c 100644 --- a/example.py +++ b/example.py @@ -37,6 +37,7 @@ def stream(): yield f"{i}\n" time.sleep(1) + return Response( body=stream(), status=200, diff --git a/potassium/potassium.py b/potassium/potassium.py index b4ae24a..572696b 100644 --- a/potassium/potassium.py +++ b/potassium/potassium.py @@ -1,6 +1,9 @@ +from enum import Enum import time import os from types import GeneratorType +from typing import Callable, Literal, Union +from dataclasses import dataclass from flask import Flask, request, make_response, abort, Response as FlaskResponse from huggingface_hub.file_download import uuid from werkzeug.serving import make_server @@ -13,7 +16,16 @@ from .status import PotassiumStatus, StatusEvent from .worker import run_worker, init_worker from .exceptions import RouteAlreadyInUseException, InvalidEndpointTypeException -from .types import Request, Endpoint, RequestHeaders, Response +from .types import Request, RequestHeaders, Response + +class HandlerType(Enum): + HANDLER = "HANDLER" + BACKGROUND = "BACKGROUND" + +@dataclass +class Endpoint(): + type: HandlerType + func: Callable class ResponseMailbox(): def __init__(self, response_queue): @@ -149,10 +161,7 @@ def _standardize_route(route): return route - # handler is a blocking http POST handler - def handler(self, route: str = "/"): - "handler is a blocking http POST handler" - + def _base_decorator(self, route: str, handler_type: HandlerType): route = self._standardize_route(route) if route in self._endpoints: raise RouteAlreadyInUseException() @@ -173,27 +182,19 @@ def wrapper(context, request): return out - self._endpoints[route] = Endpoint(type="handler", func=wrapper) + self._endpoints[route] = Endpoint(type=handler_type, func=wrapper) return wrapper return actual_decorator + # handler is a blocking http POST handler + def handler(self, route: str = "/"): + "handler is a blocking http POST handler" + return self._base_decorator(route, HandlerType.HANDLER) + # background is a non-blocking http POST handler def background(self, route: str = "/"): "background is a non-blocking http POST handler" - route = self._standardize_route(route) - if route in self._endpoints: - raise RouteAlreadyInUseException() - - def actual_decorator(func): - @functools.wraps(func) - def wrapper(request): - # send in app's stateful context if GPU, and the request - return func(self._context, request) - - self._endpoints[route] = Endpoint( - type="background", func=wrapper) - return wrapper - return actual_decorator + return self._base_decorator(route, HandlerType.BACKGROUND) def test_client(self): "test_client returns a Flask test client for the app" @@ -232,7 +233,7 @@ def handle(path): # use an internal id for critical path to prevent user from accidentally # breaking things by sending multiple requests with the same id internal_id = str(uuid.uuid4()) - if endpoint.type == "handler": + if endpoint.type == HandlerType.HANDLER: self._worker_pool.apply_async(run_worker, args=(endpoint.func, req, internal_id, True)) resp = self._response_mailbox.get_response(internal_id) @@ -241,7 +242,7 @@ def handle(path): status=resp.status, headers=resp.headers ) - elif endpoint.type == "background": + elif endpoint.type == HandlerType.BACKGROUND: self._worker_pool.apply_async(run_worker, args=(endpoint.func, req, internal_id)) flask_response = make_response({'started': True}) diff --git a/potassium/types.py b/potassium/types.py index 6358cff..c2d758b 100644 --- a/potassium/types.py +++ b/potassium/types.py @@ -2,11 +2,6 @@ from typing import Any, Callable, Dict, Generator, Optional, Union, Generator, Optional, Union import json as jsonlib -@dataclass -class Endpoint(): - type: str - func: Callable - class RequestHeaders(): def __init__(self, headers: Dict[str, str]): self._headers = headers diff --git a/potassium/worker.py b/potassium/worker.py index ee3a65e..94d6525 100644 --- a/potassium/worker.py +++ b/potassium/worker.py @@ -51,16 +51,27 @@ class Worker(): stderr_redirect: FDRedirect stdout_redirect: FDRedirect - def init_worker(index_queue, event_queue, response_queue, init_func): global worker worker_num = index_queue.get() + stdout_redirect = FDRedirect(1) + stderr_redirect = FDRedirect(2) + + stderr_redirect.set_prefix(f"[worker {worker_num}] ") + stdout_redirect.set_prefix(f"[worker {worker_num}] ") + # check if the init function takes in a worker number - if len(inspect.signature(init_func).parameters) == 0: - context = init_func() - else: - context = init_func(worker_num) + try: + if len(inspect.signature(init_func).parameters) == 0: + context = init_func() + else: + context = init_func(worker_num) + except Exception as e: + tb_str = traceback.format_exc() + print(colored(tb_str, "red")) + raise e + event_queue.put((StatusEvent.WORKER_STARTED,)) @@ -68,8 +79,8 @@ def init_worker(index_queue, event_queue, response_queue, init_func): context, event_queue, response_queue, - FDRedirect(1), - FDRedirect(2) + stdout_redirect, + stderr_redirect ) def run_worker(func, request, internal_id, use_response=False):