Skip to content

Commit

Permalink
Support streaming APIs
Browse files Browse the repository at this point in the history
  • Loading branch information
yangwenz committed Jul 2, 2024
1 parent 8d48526 commit 07104b1
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 1 deletion.
26 changes: 26 additions & 0 deletions examples/streaming/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import time
from kservehelper.model import KServeModel
from kservehelper.types import Input


class Model:
    """Toy model demonstrating KServeModel's streaming-generation support."""

    def load(self):
        """Nothing to initialize for this example model."""
        pass

    def generate(
            self,
            repeat: int = Input(
                description="The number of repeats",
                default=5
            )
    ):
        """Return a generator *function* that yields "Hello World!" `repeat` times.

        The serving layer calls the returned callable to start the stream;
        each chunk is followed by a one-second pause to simulate slow
        token-by-token production.
        """
        def _stream():
            emitted = 0
            while emitted < repeat:
                yield "Hello World!"
                time.sleep(1)
                emitted += 1

        return _stream


if __name__ == "__main__":
    # Launch the KServe HTTP server, registering Model under the name
    # "streaming" (so it answers at /v1/models/streaming:generate).
    KServeModel.serve("streaming", Model)
6 changes: 6 additions & 0 deletions examples/streaming/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import requests

# Smoke test for the streaming endpoint: POST a payload and print the
# response body as it arrives, 16 bytes at a time.
endpoint = "http://localhost:8080/v1/models/streaming:generate"
with requests.post(endpoint, stream=True, json={"repeat": 5}) as response:
    for piece in response.iter_content(16):
        print(piece)
2 changes: 1 addition & 1 deletion kservehelper/kserve/model_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ async def generate(
body = self.decode(body, headers)

model = self.get_model(model_name)
response = await model.generate(body, headers=headers)
response = model.generate(body, headers=headers)
return response, headers


Expand Down
2 changes: 2 additions & 0 deletions kservehelper/kserve/rest/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ def create_application(self) -> FastAPI:
FastAPIRoute(r"/v1/models/{model_name}", v1_endpoints.model_ready, tags=["V1"]),
FastAPIRoute(r"/v1/models/{model_name}:predict",
v1_endpoints.predict, methods=["POST"], tags=["V1"]),
FastAPIRoute(r"/v1/models/{model_name}:generate",
v1_endpoints.generate, methods=["POST"], tags=["V1"]),
FastAPIRoute(r"/v1/models/{model_name}:explain",
v1_endpoints.explain, methods=["POST"], tags=["V1"]),
# Model docs
Expand Down
8 changes: 8 additions & 0 deletions kservehelper/kserve/rest/v1_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from typing import Optional, Union, Dict, List

from fastapi import Request, Response
from fastapi.responses import StreamingResponse

from kserve.errors import ModelNotReady
from kserve.protocol.dataplane import DataPlane
Expand Down Expand Up @@ -72,6 +73,13 @@ async def predict(self, model_name: str, request: Request) -> Union[Response, Di
return Response(content=response, headers=response_headers)
return response

async def generate(self, model_name: str, request: Request) -> StreamingResponse:
    """Streaming-generate handler.

    Reads the raw request body and headers, passes them to the data plane's
    ``generate`` for the named model, and wraps the resulting generator in a
    ``StreamingResponse`` so chunks are flushed to the client as produced.
    """
    body = await request.body()
    headers = dict(request.headers.items())
    # `results_generator` is a generator *function* (factory), not a generator
    # object -- it is invoked below to start the stream.
    results_generator, response_headers = \
        await self.dataplane.generate(model_name=model_name, body=body, headers=headers)
    # NOTE(review): `response_headers` is computed but not forwarded to the
    # StreamingResponse. It derives from the request headers and may carry a
    # stale Content-Length that would break chunked streaming if attached --
    # confirm whether any of it (e.g. content type) should be propagated.
    return StreamingResponse(results_generator())

async def explain(self, model_name: str, request: Request) -> Union[Response, Dict]:
"""Explain handler.
Expand Down
18 changes: 18 additions & 0 deletions kservehelper/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,17 @@ def _build_functions(model_class):
setattr(KServeModel, "predict", KServeModel._predict)
KServeModel.HAS_PREDICT = True

# Streaming generation
# Note that the streaming generation method doesn't support preprocess or postprocess function
method = getattr(model_class, "generate", None)
if callable(method):
if KServeModel.HAS_PREDICT:
raise ValueError("The model can only have one of `predict` and `generate` methods")
KServeModel.MODEL_IO_INFO.set_input_signatures(method)
KServeModel.MODEL_IO_INFO.set_output_signatures(method)
setattr(KServeModel, "generate", KServeModel._generate)
KServeModel.HAS_PREDICT = True

# Preprocess function
method = getattr(model_class, "preprocess", None)
if callable(method):
Expand Down Expand Up @@ -211,6 +222,13 @@ def _predict(self, payload: Dict, headers: Dict[str, str] = None) -> Dict:
results["running_time"] = f"{time.time() - start_time}s"
return results

@staticmethod
def _generate(self, payload: Dict, headers: Dict[str, str] = None):
    """Streaming entry point bound onto KServeModel at build time.

    Strips the upload-webhook key (unused for streaming), normalizes the
    payload, and returns whatever generator (factory) the wrapped model's
    ``generate`` produces. No pre/post-processing hooks apply here.
    """
    payload.pop("upload_webhook", None)
    cleaned = KServeModel._process_payload(payload)
    return self.model.generate(**cleaned)

@staticmethod
def _preprocess(self, payload: Dict, headers: Dict[str, str] = None) -> Dict:
self.predict_start_time = time.time()
Expand Down

0 comments on commit 07104b1

Please sign in to comment.