From 07104b18b7d523789142a357a5e5117146b4eef3 Mon Sep 17 00:00:00 2001 From: yangwenz Date: Tue, 2 Jul 2024 22:26:10 +0800 Subject: [PATCH] Support streaming APIs --- examples/streaming/model.py | 26 ++++++++++++++++++++++++ examples/streaming/test.py | 6 ++++++ kservehelper/kserve/model_server.py | 2 +- kservehelper/kserve/rest/server.py | 2 ++ kservehelper/kserve/rest/v1_endpoints.py | 8 ++++++++ kservehelper/model.py | 18 ++++++++++++++++ 6 files changed, 61 insertions(+), 1 deletion(-) create mode 100644 examples/streaming/model.py create mode 100644 examples/streaming/test.py diff --git a/examples/streaming/model.py b/examples/streaming/model.py new file mode 100644 index 0000000..651ee20 --- /dev/null +++ b/examples/streaming/model.py @@ -0,0 +1,26 @@ +import time +from kservehelper.model import KServeModel +from kservehelper.types import Input + + +class Model: + + def load(self): + pass + + def generate( + self, + repeat: int = Input( + description="The number of repeats", + default=5 + ) + ): + def _generator(): + for _ in range(repeat): + yield "Hello World!" + time.sleep(1) + return _generator + + +if __name__ == "__main__": + KServeModel.serve("streaming", Model) diff --git a/examples/streaming/test.py b/examples/streaming/test.py new file mode 100644 index 0000000..8041395 --- /dev/null +++ b/examples/streaming/test.py @@ -0,0 +1,6 @@ +import requests + +url = "http://localhost:8080/v1/models/streaming:generate" +with requests.post(url, stream=True, json={"repeat": 5}) as r: + for chunk in r.iter_content(16): + print(chunk) diff --git a/kservehelper/kserve/model_server.py b/kservehelper/kserve/model_server.py index 577b65f..8c9114e 100644 --- a/kservehelper/kserve/model_server.py +++ b/kservehelper/kserve/model_server.py @@ -92,7 +92,7 @@ async def generate( body = self.decode(body, headers) model = self.get_model(model_name) - response = await model.generate(body, headers=headers) + response = model.generate(body, headers=headers) return response, headers diff --git a/kservehelper/kserve/rest/server.py b/kservehelper/kserve/rest/server.py index 5ed2127..9adfe2f 100644 --- a/kservehelper/kserve/rest/server.py +++ b/kservehelper/kserve/rest/server.py @@ -88,6 +88,8 @@ def create_application(self) -> FastAPI: FastAPIRoute(r"/v1/models/{model_name}", v1_endpoints.model_ready, tags=["V1"]), FastAPIRoute(r"/v1/models/{model_name}:predict", v1_endpoints.predict, methods=["POST"], tags=["V1"]), + FastAPIRoute(r"/v1/models/{model_name}:generate", + v1_endpoints.generate, methods=["POST"], tags=["V1"]), FastAPIRoute(r"/v1/models/{model_name}:explain", v1_endpoints.explain, methods=["POST"], tags=["V1"]), # Model docs diff --git a/kservehelper/kserve/rest/v1_endpoints.py b/kservehelper/kserve/rest/v1_endpoints.py index ee7ff81..45e1f24 100644 --- a/kservehelper/kserve/rest/v1_endpoints.py +++ b/kservehelper/kserve/rest/v1_endpoints.py @@ -14,6 +14,7 @@ from typing import Optional, Union, Dict, List from fastapi import Request, Response +from fastapi.responses import StreamingResponse from kserve.errors import ModelNotReady from kserve.protocol.dataplane import DataPlane @@ -72,6 +73,13 @@ async def predict(self, model_name: str, request: Request) -> Union[Response, Di return Response(content=response, headers=response_headers) return response + async def generate(self, model_name: str, request: Request) -> StreamingResponse: + body = await request.body() + headers = dict(request.headers.items()) + results_generator, response_headers = \ + await self.dataplane.generate(model_name=model_name, body=body, headers=headers) + return StreamingResponse(results_generator()) + async def explain(self, model_name: str, request: Request) -> Union[Response, Dict]: """Explain handler. diff --git a/kservehelper/model.py b/kservehelper/model.py index e40a672..4a67bfc 100644 --- a/kservehelper/model.py +++ b/kservehelper/model.py @@ -150,6 +150,17 @@ def _build_functions(model_class): setattr(KServeModel, "predict", KServeModel._predict) KServeModel.HAS_PREDICT = True + # Streaming generation + # Note that the streaming generation method doesn't support preprocess or postprocess function + method = getattr(model_class, "generate", None) + if callable(method): + if KServeModel.HAS_PREDICT: + raise ValueError("The model can only have one of `predict` and `generate` methods") + KServeModel.MODEL_IO_INFO.set_input_signatures(method) + KServeModel.MODEL_IO_INFO.set_output_signatures(method) + setattr(KServeModel, "generate", KServeModel._generate) + KServeModel.HAS_PREDICT = True + # Preprocess function method = getattr(model_class, "preprocess", None) if callable(method): @@ -211,6 +222,13 @@ def _predict(self, payload: Dict, headers: Dict[str, str] = None) -> Dict: results["running_time"] = f"{time.time() - start_time}s" return results + @staticmethod + def _generate(self, payload: Dict, headers: Dict[str, str] = None): + payload.pop("upload_webhook", None) + payload = KServeModel._process_payload(payload) + generator = self.model.generate(**payload) + return generator + @staticmethod def _preprocess(self, payload: Dict, headers: Dict[str, str] = None) -> Dict: self.predict_start_time = time.time()