This repository has been archived by the owner on May 28, 2024. It is now read-only.

Batch request API (#25)
Signed-off-by: Antoni Baum <[email protected]>
Yard1 authored May 23, 2023
1 parent 00c7b1f commit 0340b3e
Showing 6 changed files with 138 additions and 8 deletions.
11 changes: 11 additions & 0 deletions aviary/api/cli.py
@@ -1,5 +1,6 @@
from aviary.api import sdk
import typer
from typing import List

app = typer.Typer()

@@ -15,6 +16,16 @@ def query(model: str, prompt: str):
    sdk.query(model, prompt)


@app.command(name="batch_query")
def batch_query(model: str, prompts: List[str]):
    """Batch query a model with multiple prompts.

    Args:
        model (str): The model to query. The model must already be running.
        prompts (List[str]): The prompts to use to query the model.
    """
    sdk.batch_query(model, prompts)


@app.command(name="run")
def run(*model: str):
    sdk.run(*model)
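For reference, a minimal sketch of how the new command might be invoked once the `aviary` console script from setup.py (below) is installed. The model name and prompts are hypothetical placeholders; the sketch assumes typer's usual handling of a List[str] argument, which collects every positional value after the model name into the list.

import subprocess

# Hypothetical invocation; "amazon/LightGPT" is a placeholder model name.
subprocess.run(
    ["aviary", "batch_query", "amazon/LightGPT", "What is 2+2?", "Name three colors."],
    check=True,
)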
6 changes: 6 additions & 0 deletions aviary/api/sdk.py
@@ -1,10 +1,16 @@
from typing import List

from aviary.api.env import assert_has_backend


def query(model: str, prompt: str):
    """Query Aviary"""
    print("Querying ", model, prompt)

def batch_query(model: str, prompts: List[str]):
    """Batch Query Aviary"""
    print("Batch Querying ", model, prompts)


def run(*model: str):
    """Run Aviary on the local ray cluster"""
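A hedged usage sketch of the SDK surface added above; at this point in the history both functions are print-only stubs, and the model name below is a placeholder rather than anything this commit ships.

from aviary.api import sdk

# Placeholder model identifier; the stubs only echo their arguments for now.
sdk.query("amazon/LightGPT", "What is the capital of France?")
sdk.batch_query("amazon/LightGPT", ["What is 2+2?", "Name three colors."])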
83 changes: 80 additions & 3 deletions aviary/backend/server/_batch.py
@@ -1,5 +1,8 @@
import asyncio
from dataclasses import dataclass, field
from enum import IntEnum
from functools import wraps
from typing import Any, Callable, List, Optional
from typing import Any, Callable, List, Optional, Tuple, Type

from ray.serve.batching import (
    _BatchQueue,
@@ -37,6 +40,69 @@ def extract_self_if_method_call(args: List[Any], func: Callable) -> Optional[object]:
    return None


class QueuePriority(IntEnum):
    """Lower value = higher priority"""

    GENERATE_TEXT = 0
    BATCH_GENERATE_TEXT = 1


@dataclass(order=True)
class _PriorityWrapper:
    """Wrapper allowing for priority queueing of arbitrary objects."""

    obj: Any = field(compare=False)
    priority: int = field(compare=True)


class PriorityQueueWithUnwrap(asyncio.PriorityQueue):
    def get_nowait(self) -> Any:
        # Get just the obj from _PriorityWrapper
        ret: _PriorityWrapper = super().get_nowait()
        return ret.obj


class _PriorityBatchQueue(_BatchQueue):
    # The kwarg of the batch function used to determine priority.
    _priority_kwarg: str = "priority"

    def __init__(
        self,
        max_batch_size: int,
        timeout_s: float,
        handle_batch_func: Optional[Callable] = None,
    ) -> None:
        """Async queue that accepts individual items and returns batches.

        Compared to base _BatchQueue, this class uses asyncio.PriorityQueue.

        Respects max_batch_size and timeout_s; a batch will be returned when
        max_batch_size elements are available or the timeout has passed since
        the previous get.

        If handle_batch_func is passed in, a background coroutine will run to
        poll from the queue and call handle_batch_func on the results.

        Arguments:
            max_batch_size: max number of elements to return in a batch.
            timeout_s: time to wait before returning an incomplete
                batch.
            handle_batch_func(Optional[Callable]): callback to run in the
                background to handle batches if provided.
        """
        super().__init__(max_batch_size, timeout_s, handle_batch_func)
        self.queue: PriorityQueueWithUnwrap[_SingleRequest] = PriorityQueueWithUnwrap()

    def put(
        self,
        request: Tuple[_SingleRequest, asyncio.Future],
        *,
        priority: int,
    ) -> None:
        # Lower index = higher priority
        super().put(_PriorityWrapper(obj=request, priority=int(priority)))

def _validate_max_batch_size(max_batch_size):
    if not isinstance(max_batch_size, int):
        if isinstance(max_batch_size, float) and max_batch_size.is_integer():
@@ -60,6 +126,8 @@ def batch(
    _func: Optional[Callable] = None,
    max_batch_size: int = 10,
    batch_wait_timeout_s: float = 0.0,
    *,
    batch_queue_cls: Type[_BatchQueue] = _BatchQueue,
):
    """Converts a function to asynchronously handle batches.
@@ -99,6 +167,7 @@ async def __call__(self, request: Request):
            one call to the underlying function.
        batch_wait_timeout_s: the maximum duration to wait for
            `max_batch_size` elements before running the current batch.
        batch_queue_cls: the class to use for the batch queue.
    """
    # `_func` will be None in the case when the decorator is parametrized.
    # See the comment at the end of this function for a detailed explanation.
@@ -120,6 +189,10 @@ async def __call__(self, request: Request):
    def _batch_decorator(_func):
        @wraps(_func)
        async def batch_wrapper(*args, **kwargs):
            priority_kwarg = getattr(batch_queue_cls, "_priority_kwarg", None)
            priority_kwargs = {}
            if priority_kwarg:
                priority_kwargs = {priority_kwarg: kwargs.pop(priority_kwarg)}
            self = extract_self_if_method_call(args, _func)
            flattened_args: List = flatten_args(extract_signature(_func), args, kwargs)
@@ -139,7 +212,9 @@ async def batch_wrapper(*args, **kwargs):
            # runs, we just get a reference to the attribute.
            batch_queue_attr = f"__serve_batch_queue_{_func.__name__}"
            if not hasattr(batch_queue_object, batch_queue_attr):
                batch_queue = _BatchQueue(max_batch_size, batch_wait_timeout_s, _func)
                batch_queue = batch_queue_cls(
                    max_batch_size, batch_wait_timeout_s, _func
                )
                setattr(batch_queue_object, batch_queue_attr, batch_queue)
            else:
                batch_queue = getattr(batch_queue_object, batch_queue_attr)
@@ -155,7 +230,9 @@ async def batch_wrapper(*args, **kwargs):
            batch_queue.timeout_s = new_batch_wait_timeout_s

            future = get_or_create_event_loop().create_future()
            batch_queue.put(_SingleRequest(self, flattened_args, future))
            batch_queue.put(
                _SingleRequest(self, flattened_args, future), **priority_kwargs
            )

            # This will raise if the underlying call raised an exception.
            return await future
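To make the priority mechanics concrete, here is a minimal, hedged sketch of applying the extended decorator outside Aviary's deployment code. The EchoHandler class and its items are made up for illustration; the real consumer is generate_text_batch in app.py below. It assumes, as the wrapper does, that callers pass the keyword named by _priority_kwarg, which is popped before argument flattening and forwarded to _PriorityBatchQueue.put so that lower QueuePriority values are dequeued first.

import asyncio
from typing import List

from aviary.backend.server._batch import QueuePriority, _PriorityBatchQueue, batch


class EchoHandler:
    # Hypothetical handler, standing in for the deployment class in app.py.
    @batch(
        max_batch_size=4,
        batch_wait_timeout_s=0.1,
        batch_queue_cls=_PriorityBatchQueue,
    )
    async def handle(self, items: List[str]) -> List[str]:
        # Called once per batch with up to max_batch_size items gathered from
        # concurrent callers; must return one result per input item.
        return [item.upper() for item in items]


async def main():
    handler = EchoHandler()
    # Each caller submits a single item plus the priority kwarg consumed by
    # the wrapper; interactive traffic outranks bulk traffic in the queue.
    results = await asyncio.gather(
        handler.handle("interactive", priority=QueuePriority.GENERATE_TEXT),
        handler.handle("bulk", priority=QueuePriority.BATCH_GENERATE_TEXT),
    )
    print(results)  # ["INTERACTIVE", "BULK"]


if __name__ == "__main__":
    asyncio.run(main())

Note that priority only affects dequeue order within the shared queue; batch size and timeout behavior are inherited unchanged from the base _BatchQueue.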
35 changes: 33 additions & 2 deletions aviary/backend/server/app.py
@@ -12,7 +12,7 @@

from aviary.backend.llm.predictor import LLMPredictor
from aviary.backend.logger import get_logger
from aviary.backend.server._batch import batch
from aviary.backend.server._batch import QueuePriority, _PriorityBatchQueue, batch
from aviary.backend.server.models import Args, DeepSpeed, Prompt

logger = get_logger(__name__)
@@ -41,9 +41,23 @@ def __init__(self, config: Optional[Dict[str, Any]] = None) -> None:

    @app.post("/")
    async def generate_text(self, prompt: Prompt):
        text = await self.generate_text_batch(prompt)
        text = await self.generate_text_batch(
            prompt, priority=QueuePriority.GENERATE_TEXT
        )
        return text

    @app.post("/batch")
    async def batch_generate_text(self, prompts: List[Prompt]):
        texts = await asyncio.gather(
            *[
                self.generate_text_batch(
                    prompt, priority=QueuePriority.BATCH_GENERATE_TEXT
                )
                for prompt in prompts
            ]
        )
        return texts

    @app.get("/metadata")
    async def metadata(self) -> dict:
        return self.args.dict(
@@ -132,6 +146,7 @@ def get_batch_wait_timeout_s(self):
    @batch(
        max_batch_size=get_max_batch_size,
        batch_wait_timeout_s=get_batch_wait_timeout_s,
        batch_queue_cls=_PriorityBatchQueue,
    )
    async def generate_text_batch(self, prompts: List[Prompt]):
        """Generate text from the given prompts in batch"""
@@ -212,6 +227,22 @@ async def query(self, model: str, prompt: Prompt) -> Dict[str, Dict[str, Any]]:
        logger.info(prompts)
        return {model: prompts}

    @app.post("/query/batch/{model}")
    async def batch_query(
        self, model: str, prompts: List[Prompt]
    ) -> Dict[str, List[Dict[str, Any]]]:
        model = model.replace("--", "/")
        prompts = await asyncio.gather(
            *(
                await asyncio.gather(
                    *[self._models[model].batch_generate_text.remote(prompts)]
                )
            )
        )
        prompts = prompts[0]
        logger.info(prompts)
        return {model: prompts}

    @app.get("/metadata/{model}")
    async def metadata(self, model) -> Dict[str, Dict[str, Any]]:
        model = model.replace("--", "/")
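For completeness, a hedged sketch of calling the new HTTP route once a model is running. The host, port, model name, and the exact JSON shape of Prompt are assumptions; the only details taken from the diff are the /query/batch/{model} route and the convention of encoding "/" in model names as "--".

import requests

base_url = "http://localhost:8000"  # assumed Serve HTTP address
model = "amazon--LightGPT"          # assumed model, with "/" encoded as "--"

# Assumed Prompt schema: an object with a "prompt" field.
payload = [{"prompt": "What is 2+2?"}, {"prompt": "Name three colors."}]

resp = requests.post(f"{base_url}/query/batch/{model}", json=payload)
print(resp.json())  # expected shape: {"amazon/LightGPT": [...]} per batch_query above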
5 changes: 3 additions & 2 deletions run_on_every_node.py
@@ -1,8 +1,9 @@
import argparse
import os
import subprocess

import ray
import ray.util.scheduling_strategies
import subprocess
import argparse


def force_on_node(node_id: str, remote_func_or_actor_class):
6 changes: 5 additions & 1 deletion setup.py
@@ -7,7 +7,11 @@
    packages=find_packages(include="aviary*"),
    include_package_data=True,
    package_data={"aviary": ["models/*"]},
    entry_points={"console_scripts": ["aviary=aviary.api.cli:app",]},
    entry_points={
        "console_scripts": [
            "aviary=aviary.api.cli:app",
        ]
    },
    install_requires=["typer>=0.9"],
    extras_require={
        # TODO(tchordia): test whether this works, and determine how we can keep requirements
