Skip to content

Commit

Permalink
use tool node
Browse files Browse the repository at this point in the history
  • Loading branch information
vbarda committed Jul 26, 2024
1 parent d11cf3b commit 3c0895a
Showing 1 changed file with 4 additions and 30 deletions.
34 changes: 4 additions & 30 deletions backend/app/graphs/new_agent.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Annotated, Any, Dict, TypedDict, cast
from typing import Annotated, Any, Dict, TypedDict

from langchain_core.messages import (
AIMessage,
Expand All @@ -12,7 +12,7 @@
from langgraph.graph import END, StateGraph
from langgraph.graph.message import add_messages
from langgraph.managed.few_shot import FewShotExamples
from langgraph.prebuilt import ToolExecutor, ToolInvocation
from langgraph.prebuilt import ToolNode

from app.llms import LLMType, get_llm
from app.message_types import LiberalToolMessage
Expand Down Expand Up @@ -148,7 +148,6 @@ def should_continue(state):

# Define the function to execute tools
async def call_tool(state, config):
messages = state["messages"]
_config = config["configurable"]
tools = get_tools(
_config.get("type==agent/tools"),
Expand All @@ -157,33 +156,8 @@ async def call_tool(state, config):
_config.get("type==agent/retrieval_description"),
)

tool_executor = ToolExecutor(tools)
actions: list[ToolInvocation] = []
# Based on the continue condition
# we know the last message involves a function call
last_message = cast(AIMessage, messages[-1])
for tool_call in last_message.tool_calls:
# We construct a ToolInvocation from the function_call
actions.append(
ToolInvocation(
tool=tool_call["name"],
tool_input=tool_call["args"],
)
)
# We call the tool_executor and get back a response
responses = await tool_executor.abatch(actions)
# We use the response to create a ToolMessage
tool_messages = [
LiberalToolMessage(
tool_call_id=tool_call["id"],
name=tool_call["name"],
content=response,
)
for tool_call, response in zip(last_message.tool_calls, responses)
]

# graph state is a dict, so return type must be dict
return {"messages": tool_messages}
tool_node = ToolNode(tools)
return await tool_node.ainvoke(state)


workflow = StateGraph(BaseState)
Expand Down

0 comments on commit 3c0895a

Please sign in to comment.