From 5a8a867ef66b6e1807b3f371023dceb6708edbc4 Mon Sep 17 00:00:00 2001
From: Antoni Baum
Date: Mon, 31 Jul 2023 15:07:19 -0700
Subject: [PATCH] Fix max_total_tokens calculation

---
 aviary/backend/llm/continuous/scheduler.py          | 14 +++++++++-----
 .../llm/predictor/continuous_batching_predictor.py  | 13 ++++++-------
 aviary/backend/server/models.py                     |  1 -
 3 files changed, 15 insertions(+), 13 deletions(-)

diff --git a/aviary/backend/llm/continuous/scheduler.py b/aviary/backend/llm/continuous/scheduler.py
index d365812c..8b84a6c3 100644
--- a/aviary/backend/llm/continuous/scheduler.py
+++ b/aviary/backend/llm/continuous/scheduler.py
@@ -215,9 +215,16 @@ def process_request(
         self,
         input_text: str,
         params: Dict[str, Any],
-        max_new_tokens: int = 256,
+        max_new_tokens: Optional[int] = 256,
         max_length: int = 1024,
+        max_total_tokens: int = 1024,
     ) -> TokenStream:
+        request_input_length = self._tokenizer.get_input_length(input_text)
+        upper_max_new_tokens = max_total_tokens - request_input_length
+        if max_new_tokens is None:
+            max_new_tokens = max_total_tokens - request_input_length
+        else:
+            max_new_tokens = min(max_new_tokens, upper_max_new_tokens)
         request = Request(
             id=get_request_id(),
             inputs=input_text,
@@ -225,12 +232,9 @@ def process_request(
             max_new_tokens=max_new_tokens,
             params=params,
         )
-        return self._add_request(request)
-
-    def _add_request(self, request: Request) -> TokenStream:
         pending_request = InferenceRequest.from_request(
             request,
-            request_input_length=self._tokenizer.get_input_length(request.inputs),
+            request_input_length=request_input_length,
         )
         self._request_queue.put_nowait(pending_request)
         self._queue_put_event.set()
diff --git a/aviary/backend/llm/predictor/continuous_batching_predictor.py b/aviary/backend/llm/predictor/continuous_batching_predictor.py
index 6f9cf7ae..aea56398 100644
--- a/aviary/backend/llm/predictor/continuous_batching_predictor.py
+++ b/aviary/backend/llm/predictor/continuous_batching_predictor.py
@@ -318,15 +318,17 @@ async def _create_worker_group(
         return worker_group
 
     def process_request(
-        self, prompt: str, max_new_tokens: int, sampling_params: Dict[str, Any]
+        self,
+        prompt: str,
+        max_new_tokens: Optional[int],
+        sampling_params: Dict[str, Any],
     ):
-        # TODO improve error message
-        assert max_new_tokens + self.max_input_length <= self.max_total_tokens
         return self.scheduler.process_request(
             prompt,
             sampling_params,
             max_new_tokens=max_new_tokens,
             max_length=self.max_input_length,
+            max_total_tokens=self.max_total_tokens,
         )
 
     def validate_prompt(self, prompt: Prompt) -> None:
@@ -374,10 +376,7 @@ async def _stream_async(
         generate_kwargs = merge_dicts(
             prompt.parameters or {}, model_config.generation.generate_kwargs
         )
-        max_new_tokens = min(
-            generate_kwargs.get("max_new_tokens", 512),
-            self.max_total_tokens - self.max_input_length,
-        )
+        max_new_tokens = generate_kwargs.get("max_new_tokens", None)
         result = self.process_request(
             prompt_text,
             max_new_tokens=max_new_tokens,
diff --git a/aviary/backend/server/models.py b/aviary/backend/server/models.py
index 3ebdfc4a..077cafb7 100644
--- a/aviary/backend/server/models.py
+++ b/aviary/backend/server/models.py
@@ -427,7 +427,6 @@ class ContinuousBatchingInitializationConfig(InitializationConfig):
 class GenerationConfig(BaseModelExtended):
     prompt_format: PromptFormat
     generate_kwargs: Dict[str, Any] = {
-        "max_new_tokens": 256,
         "do_sample": True,
         "top_p": 0.92,
         "top_k": 0,
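
The behavioral core of the patch is the clamping rule that now lives in Scheduler.process_request: the effective max_new_tokens is bounded by max_total_tokens minus the tokenized prompt length, replacing the assertion previously performed in the predictor. The snippet below is a minimal standalone sketch of that rule, not aviary code; the helper name resolve_max_new_tokens and the example numbers are illustrative assumptions.

# Standalone sketch of the clamping rule introduced by this patch.
# The helper name and the example numbers are illustrative, not part of aviary.
from typing import Optional


def resolve_max_new_tokens(
    input_length: int,
    max_total_tokens: int,
    max_new_tokens: Optional[int] = None,
) -> int:
    """Return the number of new tokens that fits within the total token budget."""
    # Tokens left for generation once the prompt has been accounted for.
    upper_max_new_tokens = max_total_tokens - input_length
    if max_new_tokens is None:
        # No explicit limit requested: use the whole remaining budget.
        return upper_max_new_tokens
    # Explicit limit requested: never exceed the remaining budget.
    return min(max_new_tokens, upper_max_new_tokens)


# A 900-token prompt with max_total_tokens=1024 leaves room for at most
# 124 new tokens, even if the caller asked for 256.
assert resolve_max_new_tokens(900, 1024, 256) == 124
assert resolve_max_new_tokens(900, 1024, None) == 124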