From 3c0895acc1d997c00bcd01239837edd613115fe2 Mon Sep 17 00:00:00 2001 From: vbarda Date: Fri, 26 Jul 2024 15:06:05 -0400 Subject: [PATCH] use tool node --- backend/app/graphs/new_agent.py | 34 ++++----------------------------- 1 file changed, 4 insertions(+), 30 deletions(-) diff --git a/backend/app/graphs/new_agent.py b/backend/app/graphs/new_agent.py index e28844a0..57d0c20a 100644 --- a/backend/app/graphs/new_agent.py +++ b/backend/app/graphs/new_agent.py @@ -1,4 +1,4 @@ -from typing import Annotated, Any, Dict, TypedDict, cast +from typing import Annotated, Any, Dict, TypedDict from langchain_core.messages import ( AIMessage, @@ -12,7 +12,7 @@ from langgraph.graph import END, StateGraph from langgraph.graph.message import add_messages from langgraph.managed.few_shot import FewShotExamples -from langgraph.prebuilt import ToolExecutor, ToolInvocation +from langgraph.prebuilt import ToolNode from app.llms import LLMType, get_llm from app.message_types import LiberalToolMessage @@ -148,7 +148,6 @@ def should_continue(state): # Define the function to execute tools async def call_tool(state, config): - messages = state["messages"] _config = config["configurable"] tools = get_tools( _config.get("type==agent/tools"), @@ -157,33 +156,8 @@ async def call_tool(state, config): _config.get("type==agent/retrieval_description"), ) - tool_executor = ToolExecutor(tools) - actions: list[ToolInvocation] = [] - # Based on the continue condition - # we know the last message involves a function call - last_message = cast(AIMessage, messages[-1]) - for tool_call in last_message.tool_calls: - # We construct a ToolInvocation from the function_call - actions.append( - ToolInvocation( - tool=tool_call["name"], - tool_input=tool_call["args"], - ) - ) - # We call the tool_executor and get back a response - responses = await tool_executor.abatch(actions) - # We use the response to create a ToolMessage - tool_messages = [ - LiberalToolMessage( - tool_call_id=tool_call["id"], - name=tool_call["name"], - content=response, - ) - for tool_call, response in zip(last_message.tool_calls, responses) - ] - - # graph state is a dict, so return type must be dict - return {"messages": tool_messages} + tool_node = ToolNode(tools) + return await tool_node.ainvoke(state) workflow = StateGraph(BaseState)