Skip to content

Commit

Permalink
fix some failure cases
Browse files Browse the repository at this point in the history
  • Loading branch information
Peddle committed Dec 7, 2023
1 parent 9f4e0f7 commit 9241ee9
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 34 deletions.
1 change: 1 addition & 0 deletions example.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def stream():
yield f"{i}\n"
time.sleep(1)


return Response(
body=stream(),
status=200,
Expand Down
45 changes: 23 additions & 22 deletions potassium/potassium.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from enum import Enum
import time
import os
from types import GeneratorType
from typing import Callable, Literal, Union
from dataclasses import dataclass
from flask import Flask, request, make_response, abort, Response as FlaskResponse
from huggingface_hub.file_download import uuid
from werkzeug.serving import make_server
Expand All @@ -13,7 +16,16 @@
from .status import PotassiumStatus, StatusEvent
from .worker import run_worker, init_worker
from .exceptions import RouteAlreadyInUseException, InvalidEndpointTypeException
from .types import Request, Endpoint, RequestHeaders, Response
from .types import Request, RequestHeaders, Response

class HandlerType(Enum):
HANDLER = "HANDLER"
BACKGROUND = "BACKGROUND"

@dataclass
class Endpoint():
type: HandlerType
func: Callable

class ResponseMailbox():
def __init__(self, response_queue):
Expand Down Expand Up @@ -149,10 +161,7 @@ def _standardize_route(route):

return route

# handler is a blocking http POST handler
def handler(self, route: str = "/"):
"handler is a blocking http POST handler"

def _base_decorator(self, route: str, handler_type: HandlerType):
route = self._standardize_route(route)
if route in self._endpoints:
raise RouteAlreadyInUseException()
Expand All @@ -173,27 +182,19 @@ def wrapper(context, request):
return out


self._endpoints[route] = Endpoint(type="handler", func=wrapper)
self._endpoints[route] = Endpoint(type=handler_type, func=wrapper)
return wrapper
return actual_decorator

# handler is a blocking http POST handler
def handler(self, route: str = "/"):
"handler is a blocking http POST handler"
return self._base_decorator(route, HandlerType.HANDLER)

# background is a non-blocking http POST handler
def background(self, route: str = "/"):
"background is a non-blocking http POST handler"
route = self._standardize_route(route)
if route in self._endpoints:
raise RouteAlreadyInUseException()

def actual_decorator(func):
@functools.wraps(func)
def wrapper(request):
# send in app's stateful context if GPU, and the request
return func(self._context, request)

self._endpoints[route] = Endpoint(
type="background", func=wrapper)
return wrapper
return actual_decorator
return self._base_decorator(route, HandlerType.BACKGROUND)

def test_client(self):
"test_client returns a Flask test client for the app"
Expand Down Expand Up @@ -232,7 +233,7 @@ def handle(path):
# use an internal id for critical path to prevent user from accidentally
# breaking things by sending multiple requests with the same id
internal_id = str(uuid.uuid4())
if endpoint.type == "handler":
if endpoint.type == HandlerType.HANDLER:
self._worker_pool.apply_async(run_worker, args=(endpoint.func, req, internal_id, True))
resp = self._response_mailbox.get_response(internal_id)

Expand All @@ -241,7 +242,7 @@ def handle(path):
status=resp.status,
headers=resp.headers
)
elif endpoint.type == "background":
elif endpoint.type == HandlerType.BACKGROUND:
self._worker_pool.apply_async(run_worker, args=(endpoint.func, req, internal_id))

flask_response = make_response({'started': True})
Expand Down
5 changes: 0 additions & 5 deletions potassium/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,6 @@
from typing import Any, Callable, Dict, Generator, Optional, Union, Generator, Optional, Union
import json as jsonlib

@dataclass
class Endpoint():
type: str
func: Callable

class RequestHeaders():
def __init__(self, headers: Dict[str, str]):
self._headers = headers
Expand Down
25 changes: 18 additions & 7 deletions potassium/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,25 +51,36 @@ class Worker():
stderr_redirect: FDRedirect
stdout_redirect: FDRedirect


def init_worker(index_queue, event_queue, response_queue, init_func):
global worker
worker_num = index_queue.get()

stdout_redirect = FDRedirect(1)
stderr_redirect = FDRedirect(2)

stderr_redirect.set_prefix(f"[worker {worker_num}] ")
stdout_redirect.set_prefix(f"[worker {worker_num}] ")

# check if the init function takes in a worker number
if len(inspect.signature(init_func).parameters) == 0:
context = init_func()
else:
context = init_func(worker_num)
try:
if len(inspect.signature(init_func).parameters) == 0:
context = init_func()
else:
context = init_func(worker_num)
except Exception as e:
tb_str = traceback.format_exc()
print(colored(tb_str, "red"))
raise e


event_queue.put((StatusEvent.WORKER_STARTED,))

worker = Worker(
context,
event_queue,
response_queue,
FDRedirect(1),
FDRedirect(2)
stdout_redirect,
stderr_redirect
)

def run_worker(func, request, internal_id, use_response=False):
Expand Down

0 comments on commit 9241ee9

Please sign in to comment.