fix(backend): update anthropic_vertex
wd0517 committed Jul 24, 2024
1 parent 3435a8d commit 92711d9
Showing 1 changed file with 89 additions and 45 deletions.
134 changes: 89 additions & 45 deletions backend/app/rag/llms/anthropic_vertex/base.py
@@ -1,6 +1,12 @@
import anthropic
import json
from anthropic.types import ContentBlockDeltaEvent, TextBlock
from anthropic.types import (
ContentBlockDeltaEvent,
TextBlock,
TextDelta,
ContentBlockStartEvent,
ContentBlockStopEvent,
)
from anthropic.types.tool_use_block import ToolUseBlock
from typing import (
Any,
@@ -13,6 +19,7 @@
Union,
TYPE_CHECKING,
)

from google.oauth2 import service_account
from llama_index.core.base.llms.types import (
ChatMessage,
@@ -41,6 +48,7 @@
from llama_index.core.llms.function_calling import FunctionCallingLLM, ToolSelection
from llama_index.core.types import BaseOutputParser, PydanticProgramMode
from llama_index.core.utils import Tokenizer
from llama_index.core.llms.utils import parse_partial_json

from .utils import (
anthropic_modelname_to_contextsize,
@@ -50,7 +58,6 @@
)

if TYPE_CHECKING:
from llama_index.core.chat_engine.types import AgentChatResponse
from llama_index.core.tools.types import BaseTool


@@ -59,7 +66,7 @@


class AnthropicVertex(FunctionCallingLLM):
"""Anthropic LLM.
"""AnthropicVertex LLM.
Examples:
`pip install llama-index-llms-anthropic`
@@ -249,16 +256,48 @@ def stream_chat(

def gen() -> ChatResponseGen:
content = ""
cur_tool_calls: List[ToolUseBlock] = []
cur_tool_call: Optional[ToolUseBlock] = None
cur_tool_json: str = ""
role = MessageRole.ASSISTANT
for r in response:
if isinstance(r, ContentBlockDeltaEvent):
content_delta = r.delta.text
content += content_delta
if isinstance(r.delta, TextDelta):
content_delta = r.delta.text
content += content_delta
else:
if not isinstance(cur_tool_call, ToolUseBlock):
raise ValueError("Tool call not started")
content_delta = r.delta.partial_json
cur_tool_json += content_delta
try:
argument_dict = parse_partial_json(cur_tool_json)
cur_tool_call.input = argument_dict
except ValueError:
pass

if cur_tool_call is not None:
tool_calls_to_send = [*cur_tool_calls, cur_tool_call]
else:
tool_calls_to_send = cur_tool_calls
yield ChatResponse(
message=ChatMessage(role=role, content=content),
message=ChatMessage(
role=role,
content=content,
additional_kwargs={
"tool_calls": [t.dict() for t in tool_calls_to_send]
},
),
delta=content_delta,
raw=r,
)
elif isinstance(r, ContentBlockStartEvent):
if isinstance(r.content_block, ToolUseBlock):
cur_tool_call = r.content_block
cur_tool_json = ""
elif isinstance(r, ContentBlockStopEvent):
if isinstance(cur_tool_call, ToolUseBlock):
cur_tool_calls.append(cur_tool_call)

return gen()
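
Note: the streaming change above interleaves text deltas with tool-use blocks. Partial JSON for an in-flight tool call is accumulated in cur_tool_json, parsed opportunistically with parse_partial_json, and surfaced on every chunk via additional_kwargs["tool_calls"]. A minimal consumer sketch, not part of this commit — it assumes an already-constructed AnthropicVertex instance, and the helper name print_stream is illustrative:

    from llama_index.core.base.llms.types import ChatMessage, MessageRole

    def print_stream(llm, question: str) -> None:
        messages = [ChatMessage(role=MessageRole.USER, content=question)]
        for chunk in llm.stream_chat(messages):
            # Text deltas accumulate into chunk.message.content; for
            # tool-call chunks, delta carries the raw partial JSON.
            print(chunk.delta or "", end="", flush=True)
            # Tool-use blocks (if any) appear here incrementally, with
            # their arguments filled in as the partial JSON parses.
            tool_calls = chunk.message.additional_kwargs.get("tool_calls", [])
            if tool_calls:
                print(f"\n[tool calls so far: {tool_calls}]")

The async path (astream_chat below) mirrors this logic one-for-one.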

@@ -314,16 +353,48 @@ async def astream_chat(

async def gen() -> ChatResponseAsyncGen:
content = ""
cur_tool_calls: List[ToolUseBlock] = []
cur_tool_call: Optional[ToolUseBlock] = None
cur_tool_json: str = ""
role = MessageRole.ASSISTANT
async for r in response:
if isinstance(r, ContentBlockDeltaEvent):
content_delta = r.delta.text
content += content_delta
if isinstance(r.delta, TextDelta):
content_delta = r.delta.text
content += content_delta
else:
if not isinstance(cur_tool_call, ToolUseBlock):
raise ValueError("Tool call not started")
content_delta = r.delta.partial_json
cur_tool_json += content_delta
try:
argument_dict = parse_partial_json(cur_tool_json)
cur_tool_call.input = argument_dict
except ValueError:
pass

if cur_tool_call is not None:
tool_calls_to_send = [*cur_tool_calls, cur_tool_call]
else:
tool_calls_to_send = cur_tool_calls
yield ChatResponse(
message=ChatMessage(role=role, content=content),
message=ChatMessage(
role=role,
content=content,
additional_kwargs={
"tool_calls": [t.dict() for t in tool_calls_to_send]
},
),
delta=content_delta,
raw=r,
)
elif isinstance(r, ContentBlockStartEvent):
if isinstance(r.content_block, ToolUseBlock):
cur_tool_call = r.content_block
cur_tool_json = ""
elif isinstance(r, ContentBlockStopEvent):
if isinstance(cur_tool_call, ToolUseBlock):
cur_tool_calls.append(cur_tool_call)

return gen()

@@ -334,16 +405,16 @@ async def astream_complete(
astream_complete_fn = astream_chat_to_completion_decorator(self.astream_chat)
return await astream_complete_fn(prompt, **kwargs)

def chat_with_tools(
def _prepare_chat_with_tools(
self,
tools: List["BaseTool"],
user_msg: Optional[Union[str, ChatMessage]] = None,
chat_history: Optional[List[ChatMessage]] = None,
verbose: bool = False,
allow_parallel_tool_calls: bool = False,
**kwargs: Any,
) -> ChatResponse:
"""Predict and call the tool."""
) -> Dict[str, Any]:
"""Prepare the chat with tools."""
chat_history = chat_history or []

if isinstance(user_msg, str):
@@ -359,50 +430,23 @@ def chat_with_tools(
"input_schema": tool.metadata.get_parameters_dict(),
}
)
return {"messages": chat_history, "tools": tool_dicts or None, **kwargs}

response = self.chat(chat_history, tools=tool_dicts or None, **kwargs)

if not allow_parallel_tool_calls:
force_single_tool_call(response)

return response

async def achat_with_tools(
def _validate_chat_with_tools_response(
self,
response: ChatResponse,
tools: List["BaseTool"],
user_msg: Optional[Union[str, ChatMessage]] = None,
chat_history: Optional[List[ChatMessage]] = None,
verbose: bool = False,
allow_parallel_tool_calls: bool = False,
**kwargs: Any,
) -> ChatResponse:
"""Predict and call the tool."""
chat_history = chat_history or []

if isinstance(user_msg, str):
user_msg = ChatMessage(role=MessageRole.USER, content=user_msg)
chat_history.append(user_msg)

tool_dicts = []
for tool in tools:
tool_dicts.append(
{
"name": tool.metadata.name,
"description": tool.metadata.description,
"input_schema": tool.metadata.get_parameters_dict(),
}
)

response = await self.achat(chat_history, tools=tool_dicts or None, **kwargs)

"""Validate the response from chat_with_tools."""
if not allow_parallel_tool_calls:
force_single_tool_call(response)

return response

def get_tool_calls_from_response(
self,
response: "AgentChatResponse",
response: "ChatResponse",
error_on_no_tool_call: bool = True,
**kwargs: Any,
) -> List[ToolSelection]:
@@ -441,4 +485,4 @@ def get_tool_calls_from_response(
)
)

return tool_selections
return tool_selections
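
Note: the renamed hooks slot into llama-index-core's FunctionCallingLLM protocol: in recent versions, its inherited chat_with_tools calls _prepare_chat_with_tools, performs the chat, then runs _validate_chat_with_tools_response, and get_tool_calls_from_response extracts the structured selections. A usage sketch under that assumption — llm is an AnthropicVertex instance with valid Google Cloud credentials, and the multiply tool is illustrative:

    from llama_index.core.tools import FunctionTool

    def multiply(a: int, b: int) -> int:
        """Multiply two integers."""
        return a * b

    tool = FunctionTool.from_defaults(fn=multiply)

    # chat_with_tools is inherited from FunctionCallingLLM and drives the
    # _prepare/_validate hooks defined in this file.
    response = llm.chat_with_tools([tool], user_msg="What is 3 times 4?")
    for selection in llm.get_tool_calls_from_response(
        response, error_on_no_tool_call=False
    ):
        print(selection.tool_name, selection.tool_kwargs)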
