diff --git a/potassium/potassium.py b/potassium/potassium.py index 805040a..b79ddd4 100644 --- a/potassium/potassium.py +++ b/potassium/potassium.py @@ -2,10 +2,10 @@ import time import os from types import GeneratorType -from typing import Callable, Literal, Union +from typing import Callable from dataclasses import dataclass from flask import Flask, request, make_response, abort, Response as FlaskResponse -from huggingface_hub.file_download import uuid +import uuid from werkzeug.serving import make_server from threading import Thread, Lock from queue import Queue as ThreadQueue @@ -37,12 +37,16 @@ def __init__(self, response_queue): t.start() def _response_handler(self): - while True: - request_id, payload = self._response_queue.get() - with self._lock: - if request_id not in self._mailbox: - self._mailbox[request_id] = ThreadQueue() - self._mailbox[request_id].put(payload) + try: + while True: + request_id, payload = self._response_queue.get() + with self._lock: + if request_id not in self._mailbox: + self._mailbox[request_id] = ThreadQueue() + self._mailbox[request_id].put(payload) + except EOFError: + # queue closed, this happens when the server is shutting down + pass def get_response(self, request_id): with self._lock: @@ -119,6 +123,7 @@ def __init__(self, name): self._status = PotassiumStatus( num_started_inference_requests=0, num_completed_inference_requests=0, + num_bad_requests=0, num_workers=self._num_workers, num_workers_started=0, idle_start_timestamp=time.time(), @@ -126,9 +131,14 @@ def __init__(self, name): ) def _event_handler(self): - while True: - event = self._event_queue.get() - self._status = self._status.update(event) + try: + while True: + event = self._event_queue.get() + self._status = self._status.update(event) + except EOFError: + # this happens when the process is shutting down + pass + def init(self, func): """init runs once on server start, and is used to initialize the app's context. @@ -210,6 +220,7 @@ def _create_flask_app(self): def handle(path): route = "/" + path if route not in self._endpoints: + self._event_queue.put((StatusEvent.BAD_REQUEST_RECEIVED,)) abort(404) endpoint = self._endpoints[route] @@ -225,6 +236,7 @@ def handle(path): except: res = make_response() res.status_code = 400 + self._event_queue.put((StatusEvent.BAD_REQUEST_RECEIVED,)) return res self._event_queue.put((StatusEvent.INFERENCE_REQUEST_RECEIVED,)) @@ -257,7 +269,7 @@ def warm(): # a bit of a hack but we need to send a start and end event to the event queue # in order to update the status the way the load balancer expects - self._event_queue.put((StatusEvent.INFERENCE_START, request_id)) + self._event_queue.put((StatusEvent.INFERENCE_REQUEST_RECEIVED,)) self._event_queue.put((StatusEvent.INFERENCE_END, request_id)) res = make_response({ "warm": True, @@ -293,12 +305,17 @@ def _init_server(self): Pool = ProcessPool self._worker_pool = Pool(self._num_workers, init_worker, (index_queue, self._event_queue, self._response_queue, self._init_func, self._num_workers)) + while True: + if self._status.num_workers_started == self._num_workers: + break + print(colored(f"Started {self._num_workers} workers", 'green')) + # serve runs the http server def serve(self, host="0.0.0.0", port=8000): print(colored("------\nStarting Potassium Server 🍌", 'yellow')) + self._init_server() server = make_server(host, port, self._flask_app, threaded=True) print(colored(f"Serving at http://{host}:{port}\n------", 'green')) - self._init_server() server.serve_forever() diff --git a/potassium/status.py b/potassium/status.py index 3693866..55a1c04 100644 --- a/potassium/status.py +++ b/potassium/status.py @@ -10,12 +10,14 @@ class StatusEvent(Enum): INFERENCE_START = "INFERENCE_START" INFERENCE_END = "INFERENCE_END" WORKER_STARTED = "WORKER_STARTED" + BAD_REQUEST_RECEIVED = "BAD_REQUEST_RECEIVED" @dataclass class PotassiumStatus(): """PotassiumStatus is a simple class that represents the status of a Potassium app.""" num_started_inference_requests: int num_completed_inference_requests: int + num_bad_requests: int num_workers: int num_workers_started: int idle_start_timestamp: float @@ -31,7 +33,7 @@ def gpu_available(self): @property def sequence_number(self): - return self.num_started_inference_requests + return self.num_started_inference_requests + self.num_bad_requests @property def idle_time(self): @@ -48,7 +50,7 @@ def longest_inference_time(self): return time.time() - oldest_start_time - def update(self, event): + def update(self, event) -> "PotassiumStatus": event_type = event[0] event_data = event[1:] if event_type not in event_handlers: @@ -60,6 +62,7 @@ def clone(self): return PotassiumStatus( self.num_started_inference_requests, self.num_completed_inference_requests, + self.num_bad_requests, self.num_workers, self.num_workers_started, self.idle_start_timestamp, @@ -87,11 +90,16 @@ def handle_worker_started(status: PotassiumStatus): status.num_workers_started += 1 return status +def handle_bad_request_received(status: PotassiumStatus): + status.num_bad_requests += 1 + return status + event_handlers = { StatusEvent.INFERENCE_REQUEST_RECEIVED: handle_inference_request_received, StatusEvent.INFERENCE_START: handle_start_inference, StatusEvent.INFERENCE_END: handle_end_inference, - StatusEvent.WORKER_STARTED: handle_worker_started + StatusEvent.WORKER_STARTED: handle_worker_started, + StatusEvent.BAD_REQUEST_RECEIVED: handle_bad_request_received } diff --git a/potassium/types.py b/potassium/types.py index c2d758b..11b8b47 100644 --- a/potassium/types.py +++ b/potassium/types.py @@ -4,13 +4,18 @@ class RequestHeaders(): def __init__(self, headers: Dict[str, str]): - self._headers = headers + self._headers = {} + for key in headers: + self._headers[self._normalize_key(key)] = headers[key] - def __getitem__(self, key): + def _normalize_key(self, key): if not isinstance(key, str): raise KeyError(key) - key = key.upper().replace("-", "_") - + return key.upper().replace("-", "_") + + def __getitem__(self, key): + print(self._headers) + key = self._normalize_key(key) return self._headers[key] def get(self, key, default=None): diff --git a/potassium/worker.py b/potassium/worker.py index 66ceaed..612e48f 100644 --- a/potassium/worker.py +++ b/potassium/worker.py @@ -60,8 +60,9 @@ def init_worker(index_queue, event_queue, response_queue, init_func, total_worke stdout_redirect = FDRedirect(1) stderr_redirect = FDRedirect(2) - stderr_redirect.set_prefix(f"[worker {worker_num}] ") - stdout_redirect.set_prefix(f"[worker {worker_num}] ") + if total_workers > 1: + stderr_redirect.set_prefix(f"[worker {worker_num}] ") + stdout_redirect.set_prefix(f"[worker {worker_num}] ") # check if the init function takes in a worker number try: @@ -132,6 +133,9 @@ def run_worker(func, request, internal_id, use_response=False): worker.response_queue.put((stream_id, None)) + if worker.total_workers == 1: + worker.stderr_redirect.set_prefix("") + worker.stdout_redirect.set_prefix("") worker.event_queue.put((StatusEvent.INFERENCE_END, internal_id)) diff --git a/tests/app.py b/tests/app.py new file mode 100644 index 0000000..de16028 --- /dev/null +++ b/tests/app.py @@ -0,0 +1,60 @@ +import potassium + +potassium_test_app = potassium.Potassium("test_app") + +@potassium_test_app.init +def init(): + return {} + +@potassium_test_app.handler() +def handler(context: dict, request: potassium.Request) -> potassium.Response: + return potassium.Response( + json={"hello": "root"}, + status=200 + ) + +@potassium_test_app.handler("/some_path") +def handler2(context: dict, request: potassium.Request) -> potassium.Response: + return potassium.Response( + json={"hello": "some_path"}, + status=200 + ) + +@potassium_test_app.handler("/some_binary_response") +def handler3(context: dict, request: potassium.Request) -> potassium.Response: + return potassium.Response( + body=b"hello", + status=200, + headers={"Content-Type": "application/octet-stream"} + ) + +@potassium_test_app.handler("/some_path_byte_stream_response") +def handler4(context: dict, request: potassium.Request) -> potassium.Response: + def stream(): + yield b"hello" + yield b"world" + + return potassium.Response( + body=stream(), + status=200, + headers={"Content-Type": "application/octet-stream"} + ) + +@potassium_test_app.handler("/some_path/child_path") +def handler2_id(context: dict, request: potassium.Request) -> potassium.Response: + return potassium.Response( + json={"hello": f"some_path/child_path"}, + status=200 + ) + +@potassium_test_app.handler("/some_headers_request") +def handler5(context: dict, request: potassium.Request) -> potassium.Response: + assert request.headers["A"] == "a" + assert request.headers["B"] == "b" + assert request.headers["X-Banana-Request-Id"] == request.id + return potassium.Response( + headers={"A": "a", "B": "b", "X-Banana-Request-Id": request.id}, + json={"hello": "some_headers_request", "id": request.id}, + status=200 + ) + diff --git a/tests/test_endpoints.py b/tests/test_endpoints.py index edecd06..89b451a 100644 --- a/tests/test_endpoints.py +++ b/tests/test_endpoints.py @@ -4,11 +4,9 @@ import pytest import potassium - def test_handler(): app = potassium.Potassium("my_app") - global init @app.init def init(): return {} @@ -102,10 +100,19 @@ def handler5(context: dict, request: potassium.Request) -> potassium.Response: assert res.status_code == 400 # check status - res = client.get("/__status__") - assert res.status_code == 200 - assert res.json is not None - assert res.json["gpu_available"] == True + count = 0 + while True: + res = client.get("/__status__") + assert res.status_code == 200 + assert res.json is not None + + if res.json["gpu_available"] == True: + break + elif count > 10: + assert False, "GPU never became available" + else: + time.sleep(0.1) + count += 1 # parameterized test for path collisions @pytest.mark.parametrize("paths", [ @@ -115,7 +122,6 @@ def handler5(context: dict, request: potassium.Request) -> potassium.Response: def test_path_collision(paths): app = potassium.Potassium("my_app") - global init @app.init def init(): return {} @@ -141,7 +147,6 @@ def test_status(): resolve_background_condition = threading.Condition() - global init @app.init def init(): return {} @@ -199,6 +204,8 @@ def background(context: dict, request: potassium.Request): res = client.post("/this_path_does_not_exist", json={}) assert res.status_code == 404 + # takes a split second for the status to update + time.sleep(0.1) res = client.get("/__status__", json={}) assert res.status_code == 200 assert res.json is not None @@ -212,7 +219,6 @@ def test_wait_for_background_task(): order_of_execution_queue = queue.Queue() resolve_background_condition = threading.Condition() - global init @app.init def init(): return {} @@ -223,35 +229,40 @@ def background(context: dict, request: potassium.Request): resolve_background_condition.wait() - def wait_for_background_task(): - app._read_event_chan() - order_of_execution_queue.put("background_task_completed") - - thread = threading.Thread(target=wait_for_background_task) - thread.start() - client = app.test_client() # send background post in separate thread - order_of_execution_queue.put("send_background_task") res = client.post("/background", json={}) assert res.status_code == 200 + + time.sleep(0.1) + + res = client.get("/__status__", json={}) + + assert res.status_code == 200 + assert res.json is not None + assert res.json["gpu_available"] == False + assert res.json["sequence_number"] == 1 + # notify background thread to continue with resolve_background_condition: resolve_background_condition.notify() - thread.join() + time.sleep(0.1) + + res = client.get("/__status__", json={}) + + assert res.status_code == 200 + assert res.json is not None + assert res.json["gpu_available"] == True + assert res.json["sequence_number"] == 1 + - # assert order of execution - assert order_of_execution_queue.get() == "send_background_task" - assert order_of_execution_queue.get() == "background_task_completed" def test_warmup(): app = potassium.Potassium("my_app") - - global init @app.init def init(): return {} @@ -266,6 +277,7 @@ def handler(context: dict, request: potassium.Request) -> potassium.Response: assert res.status_code == 200 assert res.json == {"warm": True} + time.sleep(0.1) res = client.get("/__status__", json={}) assert res.status_code == 200 assert res.json is not None