Skip to content
This repository has been archived by the owner on May 28, 2024. It is now read-only.

Commit

Permalink
versioned api endpoints (#126)
Browse files: browse the repository at this point in the history
Signed-off-by: Max Pumperla <[email protected]>
  • Loading branch information
maxpumperla committed Jun 14, 2023
1 parent e1b2612 commit 87ab81f
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 12 deletions.
30 changes: 18 additions & 12 deletions aviary/api/sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import requests

from aviary.common.constants import TIMEOUT
from aviary.common.constants import TIMEOUT, DEFAULT_API_VERSION
from aviary.api.utils import (
AviaryBackend,
BackendError,
Expand All @@ -14,7 +14,6 @@
_convert_to_aviary_format,
)


__all__ = ["models", "metadata", "completions", "batch_completions", "run",
"get_aviary_backend"]

Expand Down Expand Up @@ -45,25 +44,26 @@ def get_aviary_backend():
return AviaryBackend(aviary_url, bearer)


def models() -> List[str]:
def models(version: str = DEFAULT_API_VERSION) -> List[str]:
"""List available models"""
backend = get_aviary_backend()
url = backend.backend_url + "models"
response = requests.get(url, headers=backend.header, timeout=TIMEOUT)
request_url = backend.backend_url + version + "/models"
response = requests.get(request_url, headers=backend.header, timeout=TIMEOUT)
try:
result = response.json()
except requests.JSONDecodeError as e:
raise BackendError(
f"Error decoding JSON from {url}. Text response: {response.text}",
f"Error decoding JSON from {request_url}. Text response: {response.text}",
response=response,
) from e
return result


def metadata(model_id: str) -> Dict[str, Dict[str, Any]]:
def metadata(model_id: str,
version: str = DEFAULT_API_VERSION) -> Dict[str, Dict[str, Any]]:
"""Get model metadata"""
backend = get_aviary_backend()
url = backend.backend_url + "metadata/" + model_id.replace("/", "--")
url = backend.backend_url + version + "/metadata/" + model_id.replace("/", "--")
response = requests.get(url, headers=backend.header, timeout=TIMEOUT)
try:
result = response.json()
Expand All @@ -75,12 +75,16 @@ def metadata(model_id: str) -> Dict[str, Dict[str, Any]]:
return result


def completions(model: str, prompt: str) -> Dict[str, Union[str, float, int]]:
def completions(
model: str,
prompt: str,
version: str = DEFAULT_API_VERSION
) -> Dict[str, Union[str, float, int]]:
"""Query Aviary"""

if _is_aviary_model(model):
backend = get_aviary_backend()
url = backend.backend_url + "query/" + model.replace("/", "--")
url = backend.backend_url + version + "/query/" + model.replace("/", "--")
response = requests.post(
url,
headers=backend.header,
Expand All @@ -99,13 +103,15 @@ def completions(model: str, prompt: str) -> Dict[str, Union[str, float, int]]:


def batch_completions(
model: str, prompts: List[str]
model: str,
prompts: List[str],
version: str = DEFAULT_API_VERSION
) -> List[Dict[str, Union[str, float, int]]]:
"""Batch Query Aviary"""

if _is_aviary_model(model):
backend = get_aviary_backend()
url = backend.backend_url + "query/batch/" + model.replace("/", "--")
url = backend.backend_url + version + "/query/batch/" + model.replace("/", "--")
response = requests.post(
url,
headers=backend.header,
Expand Down
16 changes: 16 additions & 0 deletions aviary/backend/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import ray
import ray.util
from fastapi import FastAPI
from fastapi_versioning import VersionedFastAPI, version

from ray import serve
from ray.exceptions import RayActorError
from ray.serve.deployment import ClassNode
Expand Down Expand Up @@ -123,6 +125,7 @@ async def validate_prompt(self, prompt: Prompt) -> None:
)

@app.get("/metadata", include_in_schema=False)
@version(0)
async def metadata(self) -> dict:
return self.args.dict(
exclude={
Expand All @@ -131,6 +134,7 @@ async def metadata(self) -> dict:
)

@app.post("/", include_in_schema=False)
@version(0)
async def generate_text(self, prompt: Prompt):
await self.validate_prompt(prompt)
time.time()
Expand All @@ -143,6 +147,7 @@ async def generate_text(self, prompt: Prompt):
return text

@app.post("/batch", include_in_schema=False)
@version(0)
async def batch_generate_text(self, prompts: List[Prompt]):
for prompt in prompts:
await self.validate_prompt(prompt)
Expand Down Expand Up @@ -265,6 +270,7 @@ def __init__(
self._model_configurations = model_configurations

@app.post("/query/{model}")
@version(0)
async def query(self, model: str, prompt: Prompt) -> Dict[str, Dict[str, Any]]:
model = _replace_prefix(model)
results = await asyncio.gather(
Expand All @@ -275,6 +281,7 @@ async def query(self, model: str, prompt: Prompt) -> Dict[str, Dict[str, Any]]:
return {model: results}

@app.post("/query/batch/{model}")
@version(0)
async def batch_query(
self, model: str, prompts: List[Prompt]
) -> Dict[str, List[Dict[str, Any]]]:
Expand All @@ -291,6 +298,7 @@ async def batch_query(
return {model: results}

@app.get("/metadata/{model}")
@version(0)
async def metadata(self, model) -> Dict[str, Dict[str, Any]]:
model = _replace_prefix(model)
# This is what we want to do eventually, but it looks like reconfigure is blocking
Expand All @@ -308,5 +316,13 @@ async def metadata(self, model) -> Dict[str, Dict[str, Any]]:
return {"metadata": metadata}

@app.get("/models")
@version(0)
async def models(self) -> List[str]:
return list(self._models.keys())


# Wrap the application in fastapi-versioning's VersionedFastAPI so routes are
# served under a version prefix. With the major-only version format and the
# '/v{major}' prefix below, handlers tagged with @version(0) are exposed
# under '/v0' (e.g. '/v0/models').
# NOTE(review): this rebinds the module-level name `app`; confirm that nothing
# evaluated earlier (decorators that captured the unwrapped app) or later in
# the module depends on which of the two objects it sees.
app = VersionedFastAPI(
app,
version_format='{major}',
prefix_format='/v{major}'
)
2 changes: 2 additions & 0 deletions aviary/common/constants.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
PROJECT_NAME = "AviaryFrontend"

DEFAULT_API_VERSION = "v0"

NUM_LLM_OPTIONS = 3

# AWS timeout
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
"pydantic==1.10.7",
"einops",
"markdown-it-py[plugins]",
"fastapi-versioning",
],
"frontend": [
"gradio",
Expand Down

0 comments on commit 87ab81f

Please sign in to comment.