diff --git a/aviary/api/sdk.py b/aviary/api/sdk.py index 931e7032..dbdbc337 100644 --- a/aviary/api/sdk.py +++ b/aviary/api/sdk.py @@ -3,7 +3,7 @@ import requests -from aviary.common.constants import TIMEOUT +from aviary.common.constants import TIMEOUT, DEFAULT_API_VERSION from aviary.api.utils import ( AviaryBackend, BackendError, @@ -14,7 +14,6 @@ _convert_to_aviary_format, ) - __all__ = ["models", "metadata", "completions", "batch_completions", "run", "get_aviary_backend"] @@ -45,25 +44,26 @@ def get_aviary_backend(): return AviaryBackend(aviary_url, bearer) -def models() -> List[str]: +def models(version: str = DEFAULT_API_VERSION) -> List[str]: """List available models""" backend = get_aviary_backend() - url = backend.backend_url + "models" - response = requests.get(url, headers=backend.header, timeout=TIMEOUT) + request_url = backend.backend_url + version + "/models" + response = requests.get(request_url, headers=backend.header, timeout=TIMEOUT) try: result = response.json() except requests.JSONDecodeError as e: raise BackendError( - f"Error decoding JSON from {url}. Text response: {response.text}", + f"Error decoding JSON from {request_url}. Text response: {response.text}", response=response, ) from e return result -def metadata(model_id: str) -> Dict[str, Dict[str, Any]]: +def metadata(model_id: str, + version: str = DEFAULT_API_VERSION) -> Dict[str, Dict[str, Any]]: """Get model metadata""" backend = get_aviary_backend() - url = backend.backend_url + "metadata/" + model_id.replace("/", "--") + url = backend.backend_url + version + "/metadata/" + model_id.replace("/", "--") response = requests.get(url, headers=backend.header, timeout=TIMEOUT) try: result = response.json() @@ -75,12 +75,16 @@ def metadata(model_id: str) -> Dict[str, Dict[str, Any]]: return result -def completions(model: str, prompt: str) -> Dict[str, Union[str, float, int]]: +def completions( + model: str, + prompt: str, + version: str = DEFAULT_API_VERSION +) -> Dict[str, Union[str, float, int]]: """Query Aviary""" if _is_aviary_model(model): backend = get_aviary_backend() - url = backend.backend_url + "query/" + model.replace("/", "--") + url = backend.backend_url + version + "/query/" + model.replace("/", "--") response = requests.post( url, headers=backend.header, @@ -99,13 +103,15 @@ def completions(model: str, prompt: str) -> Dict[str, Union[str, float, int]]: def batch_completions( - model: str, prompts: List[str] + model: str, + prompts: List[str], + version: str = DEFAULT_API_VERSION ) -> List[Dict[str, Union[str, float, int]]]: """Batch Query Aviary""" if _is_aviary_model(model): backend = get_aviary_backend() - url = backend.backend_url + "query/batch/" + model.replace("/", "--") + url = backend.backend_url + version + "/query/batch/" + model.replace("/", "--") response = requests.post( url, headers=backend.header, diff --git a/aviary/backend/server/app.py b/aviary/backend/server/app.py index 4061798b..22e1c1fb 100644 --- a/aviary/backend/server/app.py +++ b/aviary/backend/server/app.py @@ -7,6 +7,8 @@ import ray import ray.util from fastapi import FastAPI +from fastapi_versioning import VersionedFastAPI, version + from ray import serve from ray.exceptions import RayActorError from ray.serve.deployment import ClassNode @@ -123,6 +125,7 @@ async def validate_prompt(self, prompt: Prompt) -> None: ) @app.get("/metadata", include_in_schema=False) + @version(0) async def metadata(self) -> dict: return self.args.dict( exclude={ @@ -131,6 +134,7 @@ async def metadata(self) -> dict: ) @app.post("/", include_in_schema=False) + @version(0) async def generate_text(self, prompt: Prompt): await self.validate_prompt(prompt) time.time() @@ -143,6 +147,7 @@ async def generate_text(self, prompt: Prompt): return text @app.post("/batch", include_in_schema=False) + @version(0) async def batch_generate_text(self, prompts: List[Prompt]): for prompt in prompts: await self.validate_prompt(prompt) @@ -265,6 +270,7 @@ def __init__( self._model_configurations = model_configurations @app.post("/query/{model}") + @version(0) async def query(self, model: str, prompt: Prompt) -> Dict[str, Dict[str, Any]]: model = _replace_prefix(model) results = await asyncio.gather( @@ -275,6 +281,7 @@ async def query(self, model: str, prompt: Prompt) -> Dict[str, Dict[str, Any]]: return {model: results} @app.post("/query/batch/{model}") + @version(0) async def batch_query( self, model: str, prompts: List[Prompt] ) -> Dict[str, List[Dict[str, Any]]]: @@ -291,6 +298,7 @@ async def batch_query( return {model: results} @app.get("/metadata/{model}") + @version(0) async def metadata(self, model) -> Dict[str, Dict[str, Any]]: model = _replace_prefix(model) # This is what we want to do eventually, but it looks like reconfigure is blocking @@ -308,5 +316,13 @@ async def metadata(self, model) -> Dict[str, Dict[str, Any]]: return {"metadata": metadata} @app.get("/models") + @version(0) async def models(self) -> List[str]: return list(self._models.keys()) + + +app = VersionedFastAPI( + app, + version_format='{major}', + prefix_format='/v{major}' +) diff --git a/aviary/common/constants.py b/aviary/common/constants.py index 2f94b968..22c930e7 100644 --- a/aviary/common/constants.py +++ b/aviary/common/constants.py @@ -1,5 +1,7 @@ PROJECT_NAME = "AviaryFrontend" +DEFAULT_API_VERSION = "v0" + NUM_LLM_OPTIONS = 3 # AWS timeout diff --git a/setup.py b/setup.py index ca4d6623..ca838593 100644 --- a/setup.py +++ b/setup.py @@ -42,6 +42,7 @@ "pydantic==1.10.7", "einops", "markdown-it-py[plugins]", + "fastapi-versioning", ], "frontend": [ "gradio",