Skip to content

Commit

Permalink
feat: add o1-series models support in Agent App (ReACT only) (#8350)
Browse files Browse the repository at this point in the history
  • Loading branch information
takatost authored Sep 13, 2024
1 parent 8d2269f commit 4637dda
Showing 1 changed file with 27 additions and 2 deletions.
29 changes: 27 additions & 2 deletions api/core/model_runtime/model_providers/openai/llm/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,6 +620,9 @@ def _chat_generate(
if "stream_options" in extra_model_kwargs:
del extra_model_kwargs["stream_options"]

if "stop" in extra_model_kwargs:
del extra_model_kwargs["stop"]

# chat model
response = client.chat.completions.create(
messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages],
Expand All @@ -635,14 +638,15 @@ def _chat_generate(
block_result = self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)

if block_as_stream:
return self._handle_chat_block_as_stream_response(block_result, prompt_messages)
return self._handle_chat_block_as_stream_response(block_result, prompt_messages, stop)

return block_result

def _handle_chat_block_as_stream_response(
self,
block_result: LLMResult,
prompt_messages: list[PromptMessage],
stop: Optional[list[str]] = None,
) -> Generator[LLMResultChunk, None, None]:
"""
Handle llm chat response
Expand All @@ -652,15 +656,22 @@ def _handle_chat_block_as_stream_response(
:param response: response
:param prompt_messages: prompt messages
:param tools: tools for tool calling
:param stop: stop words
:return: llm response chunk generator
"""
text = block_result.message.content
text = cast(str, text)

if stop:
text = self.enforce_stop_tokens(text, stop)

yield LLMResultChunk(
model=block_result.model,
prompt_messages=prompt_messages,
system_fingerprint=block_result.system_fingerprint,
delta=LLMResultChunkDelta(
index=0,
message=block_result.message,
message=AssistantPromptMessage(content=text),
finish_reason="stop",
usage=block_result.usage,
),
Expand Down Expand Up @@ -912,6 +923,20 @@ def _clear_illegal_prompt_messages(self, model: str, prompt_messages: list[Promp
]
)

if model.startswith("o1"):
system_message_count = len([m for m in prompt_messages if isinstance(m, SystemPromptMessage)])
if system_message_count > 0:
new_prompt_messages = []
for prompt_message in prompt_messages:
if isinstance(prompt_message, SystemPromptMessage):
prompt_message = UserPromptMessage(
content=prompt_message.content,
name=prompt_message.name,
)

new_prompt_messages.append(prompt_message)
prompt_messages = new_prompt_messages

return prompt_messages

def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
Expand Down

0 comments on commit 4637dda

Please sign in to comment.