Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Peddle committed Dec 8, 2023
1 parent 325f772 commit d8559b9
Show file tree
Hide file tree
Showing 6 changed files with 151 additions and 45 deletions.
43 changes: 30 additions & 13 deletions potassium/potassium.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
import time
import os
from types import GeneratorType
from typing import Callable, Literal, Union
from typing import Callable
from dataclasses import dataclass
from flask import Flask, request, make_response, abort, Response as FlaskResponse
from huggingface_hub.file_download import uuid
import uuid
from werkzeug.serving import make_server
from threading import Thread, Lock
from queue import Queue as ThreadQueue
Expand Down Expand Up @@ -37,12 +37,16 @@ def __init__(self, response_queue):
t.start()

def _response_handler(self):
while True:
request_id, payload = self._response_queue.get()
with self._lock:
if request_id not in self._mailbox:
self._mailbox[request_id] = ThreadQueue()
self._mailbox[request_id].put(payload)
try:
while True:
request_id, payload = self._response_queue.get()
with self._lock:
if request_id not in self._mailbox:
self._mailbox[request_id] = ThreadQueue()
self._mailbox[request_id].put(payload)
except EOFError:
# queue closed, this happens when the server is shutting down
pass

def get_response(self, request_id):
with self._lock:
Expand Down Expand Up @@ -119,16 +123,22 @@ def __init__(self, name):
self._status = PotassiumStatus(
num_started_inference_requests=0,
num_completed_inference_requests=0,
num_bad_requests=0,
num_workers=self._num_workers,
num_workers_started=0,
idle_start_timestamp=time.time(),
in_flight_request_start_times=[]
)

def _event_handler(self):
while True:
event = self._event_queue.get()
self._status = self._status.update(event)
try:
while True:
event = self._event_queue.get()
self._status = self._status.update(event)
except EOFError:
# this happens when the process is shutting down
pass


def init(self, func):
"""init runs once on server start, and is used to initialize the app's context.
Expand Down Expand Up @@ -210,6 +220,7 @@ def _create_flask_app(self):
def handle(path):
route = "/" + path
if route not in self._endpoints:
self._event_queue.put((StatusEvent.BAD_REQUEST_RECEIVED,))
abort(404)

endpoint = self._endpoints[route]
Expand All @@ -225,6 +236,7 @@ def handle(path):
except:
res = make_response()
res.status_code = 400
self._event_queue.put((StatusEvent.BAD_REQUEST_RECEIVED,))
return res

self._event_queue.put((StatusEvent.INFERENCE_REQUEST_RECEIVED,))
Expand Down Expand Up @@ -257,7 +269,7 @@ def warm():

# a bit of a hack but we need to send a start and end event to the event queue
# in order to update the status the way the load balancer expects
self._event_queue.put((StatusEvent.INFERENCE_START, request_id))
self._event_queue.put((StatusEvent.INFERENCE_REQUEST_RECEIVED,))
self._event_queue.put((StatusEvent.INFERENCE_END, request_id))
res = make_response({
"warm": True,
Expand Down Expand Up @@ -293,12 +305,17 @@ def _init_server(self):
Pool = ProcessPool
self._worker_pool = Pool(self._num_workers, init_worker, (index_queue, self._event_queue, self._response_queue, self._init_func, self._num_workers))

while True:
if self._status.num_workers_started == self._num_workers:
break
print(colored(f"Started {self._num_workers} workers", 'green'))

# serve runs the http server
def serve(self, host="0.0.0.0", port=8000):
print(colored("------\nStarting Potassium Server 🍌", 'yellow'))
self._init_server()
server = make_server(host, port, self._flask_app, threaded=True)
print(colored(f"Serving at http://{host}:{port}\n------", 'green'))
self._init_server()

server.serve_forever()

14 changes: 11 additions & 3 deletions potassium/status.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,14 @@ class StatusEvent(Enum):
INFERENCE_START = "INFERENCE_START"
INFERENCE_END = "INFERENCE_END"
WORKER_STARTED = "WORKER_STARTED"
BAD_REQUEST_RECEIVED = "BAD_REQUEST_RECEIVED"

@dataclass
class PotassiumStatus():
"""PotassiumStatus is a simple class that represents the status of a Potassium app."""
num_started_inference_requests: int
num_completed_inference_requests: int
num_bad_requests: int
num_workers: int
num_workers_started: int
idle_start_timestamp: float
Expand All @@ -31,7 +33,7 @@ def gpu_available(self):

@property
def sequence_number(self):
return self.num_started_inference_requests
return self.num_started_inference_requests + self.num_bad_requests

@property
def idle_time(self):
Expand All @@ -48,7 +50,7 @@ def longest_inference_time(self):

return time.time() - oldest_start_time

def update(self, event):
def update(self, event) -> "PotassiumStatus":
event_type = event[0]
event_data = event[1:]
if event_type not in event_handlers:
Expand All @@ -60,6 +62,7 @@ def clone(self):
return PotassiumStatus(
self.num_started_inference_requests,
self.num_completed_inference_requests,
self.num_bad_requests,
self.num_workers,
self.num_workers_started,
self.idle_start_timestamp,
Expand Down Expand Up @@ -87,11 +90,16 @@ def handle_worker_started(status: PotassiumStatus):
status.num_workers_started += 1
return status

def handle_bad_request_received(status: PotassiumStatus):
status.num_bad_requests += 1
return status

event_handlers = {
StatusEvent.INFERENCE_REQUEST_RECEIVED: handle_inference_request_received,
StatusEvent.INFERENCE_START: handle_start_inference,
StatusEvent.INFERENCE_END: handle_end_inference,
StatusEvent.WORKER_STARTED: handle_worker_started
StatusEvent.WORKER_STARTED: handle_worker_started,
StatusEvent.BAD_REQUEST_RECEIVED: handle_bad_request_received
}


13 changes: 9 additions & 4 deletions potassium/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,18 @@

class RequestHeaders():
def __init__(self, headers: Dict[str, str]):
self._headers = headers
self._headers = {}
for key in headers:
self._headers[self._normalize_key(key)] = headers[key]

def __getitem__(self, key):
def _normalize_key(self, key):
if not isinstance(key, str):
raise KeyError(key)
key = key.upper().replace("-", "_")

return key.upper().replace("-", "_")

def __getitem__(self, key):
print(self._headers)
key = self._normalize_key(key)
return self._headers[key]

def get(self, key, default=None):
Expand Down
8 changes: 6 additions & 2 deletions potassium/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,9 @@ def init_worker(index_queue, event_queue, response_queue, init_func, total_worke
stdout_redirect = FDRedirect(1)
stderr_redirect = FDRedirect(2)

stderr_redirect.set_prefix(f"[worker {worker_num}] ")
stdout_redirect.set_prefix(f"[worker {worker_num}] ")
if total_workers > 1:
stderr_redirect.set_prefix(f"[worker {worker_num}] ")
stdout_redirect.set_prefix(f"[worker {worker_num}] ")

# check if the init function takes in a worker number
try:
Expand Down Expand Up @@ -132,6 +133,9 @@ def run_worker(func, request, internal_id, use_response=False):
worker.response_queue.put((stream_id, None))


if worker.total_workers == 1:
worker.stderr_redirect.set_prefix("")
worker.stdout_redirect.set_prefix("")

worker.event_queue.put((StatusEvent.INFERENCE_END, internal_id))

60 changes: 60 additions & 0 deletions tests/app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import potassium

potassium_test_app = potassium.Potassium("test_app")

@potassium_test_app.init
def init():
return {}

@potassium_test_app.handler()
def handler(context: dict, request: potassium.Request) -> potassium.Response:
return potassium.Response(
json={"hello": "root"},
status=200
)

@potassium_test_app.handler("/some_path")
def handler2(context: dict, request: potassium.Request) -> potassium.Response:
return potassium.Response(
json={"hello": "some_path"},
status=200
)

@potassium_test_app.handler("/some_binary_response")
def handler3(context: dict, request: potassium.Request) -> potassium.Response:
return potassium.Response(
body=b"hello",
status=200,
headers={"Content-Type": "application/octet-stream"}
)

@potassium_test_app.handler("/some_path_byte_stream_response")
def handler4(context: dict, request: potassium.Request) -> potassium.Response:
def stream():
yield b"hello"
yield b"world"

return potassium.Response(
body=stream(),
status=200,
headers={"Content-Type": "application/octet-stream"}
)

@potassium_test_app.handler("/some_path/child_path")
def handler2_id(context: dict, request: potassium.Request) -> potassium.Response:
return potassium.Response(
json={"hello": f"some_path/child_path"},
status=200
)

@potassium_test_app.handler("/some_headers_request")
def handler5(context: dict, request: potassium.Request) -> potassium.Response:
assert request.headers["A"] == "a"
assert request.headers["B"] == "b"
assert request.headers["X-Banana-Request-Id"] == request.id
return potassium.Response(
headers={"A": "a", "B": "b", "X-Banana-Request-Id": request.id},
json={"hello": "some_headers_request", "id": request.id},
status=200
)

Loading

0 comments on commit d8559b9

Please sign in to comment.