diff --git a/docassemble/ALToolbox/llms.py b/docassemble/ALToolbox/llms.py
index dd724b4..9a94d74 100644
--- a/docassemble/ALToolbox/llms.py
+++ b/docassemble/ALToolbox/llms.py
@@ -120,6 +120,8 @@ def chat_completion(
     model: str = "gpt-3.5-turbo",
     messages: Optional[List[Dict[str, str]]] = None,
     skip_moderation: bool = False,
+    max_input_tokens: Optional[int] = None,
+    max_output_tokens: Optional[int] = None,
 ) -> Union[List[Any], Dict[str, Any], str]:
     """A light wrapper on the OpenAI chat endpoint.
 
@@ -140,6 +142,8 @@ def chat_completion(
         model (str): The model to use for the GPT API
         messages (Optional[List[Dict[str, str]]]): A list of messages to send to the chat engine. If provided, system_message and user_message will be ignored.
         skip_moderation (bool): Whether to skip the OpenAI moderation step, which may save seconds but risks banning your account. Only enable when you have full control over the inputs.
+        max_input_tokens (Optional[int]): The maximum number of tokens to allow in the input. If not provided, falls back to the model maximums known as of the last update.
+        max_output_tokens (Optional[int]): The maximum number of tokens to allow in the output. If not provided, falls back to the model maximums known as of the last update.
 
     Returns:
         A string with the response from the API endpoint or JSON data if json_mode is True
@@ -194,15 +198,16 @@ def chat_completion(
     encoding = tiktoken.encoding_for_model(model)
     token_count = len(encoding.encode(str(messages)))
 
-    if model.startswith("gpt-4-"):  # E.g., "gpt-4-turbo"
-        max_input_tokens = 128000
-        max_output_tokens = 4096
-    elif model.startswith("gpt-3.5-turbo"):
-        max_input_tokens = 16385
-        max_output_tokens = 4096
-    else:
-        max_input_tokens = 4096
-        max_output_tokens = 4096 - token_count - 100  # small safety margin
+    if not max_output_tokens and not max_input_tokens:
+        if model.startswith("gpt-4-") or model.startswith("gpt-4o"):  # E.g., "gpt-4-turbo"
+            max_input_tokens = 128000
+            max_output_tokens = 4096
+        elif model.startswith("gpt-3.5-turbo"):
+            max_input_tokens = 16385
+            max_output_tokens = 4096
+        else:
+            max_input_tokens = 4096
+            max_output_tokens = 4096 - token_count - 100  # small safety margin
 
     if token_count > max_input_tokens:
         raise Exception(
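
Example usage of the new parameters (a minimal sketch; the prompt strings and token limits below are illustrative, and the call otherwise follows the existing chat_completion signature):

    from docassemble.ALToolbox.llms import chat_completion

    # Cap the prompt at 8,000 tokens and the reply at 1,024 tokens instead of
    # relying on the per-model fallbacks added in this change.
    response = chat_completion(
        system_message="You are a helpful legal aid assistant.",
        user_message="Summarize the intake notes in plain language.",
        model="gpt-4-turbo",
        max_input_tokens=8000,
        max_output_tokens=1024,
    )

    # When max_input_tokens and max_output_tokens are both left as None, the
    # function falls back to the model maximums shown above (e.g., 128000 input
    # and 4096 output tokens for gpt-4-* and gpt-4o models).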