This repository has been archived by the owner on May 28, 2024. It is now read-only.

Commit

Merge pull request #35 from ray-project/master
Fix max_total_tokens calculation
akshay-anyscale authored Aug 1, 2023
2 parents 07bc2ae + 86178d0 commit bcaffd4
Showing 3 changed files with 15 additions and 13 deletions.
aviary/backend/llm/continuous/scheduler.py (9 additions, 5 deletions)
@@ -215,22 +215,26 @@ def process_request(
         self,
         input_text: str,
         params: Dict[str, Any],
-        max_new_tokens: int = 256,
+        max_new_tokens: Optional[int] = 256,
         max_length: int = 1024,
+        max_total_tokens: int = 1024,
     ) -> TokenStream:
+        request_input_length = self._tokenizer.get_input_length(input_text)
+        upper_max_new_tokens = max_total_tokens - request_input_length
+        if max_new_tokens is None:
+            max_new_tokens = max_total_tokens - request_input_length
+        else:
+            max_new_tokens = min(max_new_tokens, upper_max_new_tokens)
         request = Request(
             id=get_request_id(),
             inputs=input_text,
             truncate=max_length,
             max_new_tokens=max_new_tokens,
             params=params,
         )
-        return self._add_request(request)
-
-    def _add_request(self, request: Request) -> TokenStream:
         pending_request = InferenceRequest.from_request(
             request,
-            request_input_length=self._tokenizer.get_input_length(request.inputs),
+            request_input_length=request_input_length,
         )
         self._request_queue.put_nowait(pending_request)
         self._queue_put_event.set()
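In effect, the scheduler now derives the per-request generation cap from the actual tokenized prompt length rather than a fixed default. A minimal standalone sketch of that logic, assuming only what the hunk shows (the helper name effective_max_new_tokens and the example numbers are hypothetical, not part of the diff):

from typing import Optional

def effective_max_new_tokens(
    request_input_length: int,
    max_total_tokens: int,
    max_new_tokens: Optional[int],
) -> int:
    """Cap generation so prompt tokens plus generated tokens stay within max_total_tokens."""
    # Budget left after the actual tokenized prompt, not a configured maximum.
    upper_max_new_tokens = max_total_tokens - request_input_length
    if max_new_tokens is None:
        # No explicit request: grant the whole remaining budget.
        return upper_max_new_tokens
    # Explicit request: honor it, but never exceed the remaining budget.
    return min(max_new_tokens, upper_max_new_tokens)

# Example values (hypothetical): a 100-token prompt against a 1024-token total budget.
assert effective_max_new_tokens(100, 1024, None) == 924
assert effective_max_new_tokens(100, 1024, 256) == 256
assert effective_max_new_tokens(900, 1024, 256) == 124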
aviary/backend/llm/predictor/continuous_batching_predictor.py (6 additions, 7 deletions)
@@ -318,15 +318,17 @@ async def _create_worker_group(
         return worker_group

     def process_request(
-        self, prompt: str, max_new_tokens: int, sampling_params: Dict[str, Any]
+        self,
+        prompt: str,
+        max_new_tokens: Optional[int],
+        sampling_params: Dict[str, Any],
     ):
-        # TODO improve error message
-        assert max_new_tokens + self.max_input_length <= self.max_total_tokens
         return self.scheduler.process_request(
             prompt,
             sampling_params,
             max_new_tokens=max_new_tokens,
             max_length=self.max_input_length,
+            max_total_tokens=self.max_total_tokens,
         )

     def validate_prompt(self, prompt: Prompt) -> None:
@@ -374,10 +376,7 @@ async def _stream_async(
         generate_kwargs = merge_dicts(
             prompt.parameters or {}, model_config.generation.generate_kwargs
         )
-        max_new_tokens = min(
-            generate_kwargs.get("max_new_tokens", 512),
-            self.max_total_tokens - self.max_input_length,
-        )
+        max_new_tokens = generate_kwargs.get("max_new_tokens", None)
         result = self.process_request(
             prompt_text,
             max_new_tokens=max_new_tokens,
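On the predictor side, the hard assert against self.max_input_length and the min(..., max_total_tokens - max_input_length) clamp are removed; the requested max_new_tokens (possibly None) and max_total_tokens are simply forwarded to the scheduler. A small sketch contrasting the two behaviors, with hypothetical helper names (old_cap, new_cap) and hypothetical deployment numbers:

from typing import Any, Dict, Optional

def old_cap(generate_kwargs: Dict[str, Any], max_total_tokens: int, max_input_length: int) -> int:
    # Pre-fix behavior: clamp against the *configured* maximum input length,
    # regardless of how long the actual prompt is.
    return min(generate_kwargs.get("max_new_tokens", 512),
               max_total_tokens - max_input_length)

def new_cap(generate_kwargs: Dict[str, Any]) -> Optional[int]:
    # Post-fix behavior: forward the requested value (or None); the scheduler
    # clamps it against the actual tokenized prompt length.
    return generate_kwargs.get("max_new_tokens", None)

# Hypothetical deployment: max_total_tokens=1024, max_input_length=768.
# A 50-token prompt used to be capped at 1024 - 768 = 256 new tokens even
# though 974 tokens of the budget were actually unused.
print(old_cap({}, max_total_tokens=1024, max_input_length=768))  # 256
print(new_cap({}))  # None -> scheduler computes 1024 - 50 = 974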
aviary/backend/server/models.py (0 additions, 1 deletion)
@@ -427,7 +427,6 @@ class ContinuousBatchingInitializationConfig(InitializationConfig):
 class GenerationConfig(BaseModelExtended):
     prompt_format: PromptFormat
     generate_kwargs: Dict[str, Any] = {
-        "max_new_tokens": 256,
         "do_sample": True,
         "top_p": 0.92,
         "top_k": 0,
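With the default "max_new_tokens": 256 dropped from generate_kwargs, the predictor's generate_kwargs.get("max_new_tokens", None) yields None unless the caller sets a value, which routes requests through the whole-remaining-budget path in the scheduler. A tiny illustration (only the keys visible in this hunk are shown; the real dict may contain more):

# Defaults before and after this commit, restricted to the keys in the hunk.
old_defaults = {"max_new_tokens": 256, "do_sample": True, "top_p": 0.92, "top_k": 0}
new_defaults = {"do_sample": True, "top_p": 0.92, "top_k": 0}

print(old_defaults.get("max_new_tokens", None))  # 256: every request was silently capped at 256 new tokens
print(new_defaults.get("max_new_tokens", None))  # None: the scheduler grants the remaining token budget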
