Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[rfc] langgraph: allow tools / ToolNode to modify state #2339

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 43 additions & 5 deletions libs/langgraph/langgraph/prebuilt/tool_node.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from __future__ import annotations

Check notice on line 1 in libs/langgraph/langgraph/prebuilt/tool_node.py

View workflow job for this annotation

GitHub Actions / benchmark

Benchmark results

......................................... fanout_to_subgraph_10x: Mean +- std dev: 57.4 ms +- 1.5 ms ......................................... fanout_to_subgraph_10x_sync: Mean +- std dev: 49.0 ms +- 0.8 ms ......................................... fanout_to_subgraph_10x_checkpoint: Mean +- std dev: 82.4 ms +- 1.6 ms ......................................... fanout_to_subgraph_10x_checkpoint_sync: Mean +- std dev: 90.8 ms +- 1.2 ms ......................................... fanout_to_subgraph_100x: Mean +- std dev: 565 ms +- 15 ms ......................................... fanout_to_subgraph_100x_sync: Mean +- std dev: 480 ms +- 6 ms ......................................... fanout_to_subgraph_100x_checkpoint: Mean +- std dev: 867 ms +- 36 ms ......................................... fanout_to_subgraph_100x_checkpoint_sync: Mean +- std dev: 886 ms +- 17 ms ......................................... react_agent_10x: Mean +- std dev: 30.3 ms +- 0.6 ms ......................................... react_agent_10x_sync: Mean +- std dev: 22.2 ms +- 0.3 ms ......................................... react_agent_10x_checkpoint: Mean +- std dev: 46.7 ms +- 1.0 ms ......................................... react_agent_10x_checkpoint_sync: Mean +- std dev: 36.4 ms +- 0.6 ms ......................................... react_agent_100x: Mean +- std dev: 337 ms +- 6 ms ......................................... react_agent_100x_sync: Mean +- std dev: 270 ms +- 3 ms ......................................... react_agent_100x_checkpoint: Mean +- std dev: 932 ms +- 14 ms ......................................... react_agent_100x_checkpoint_sync: Mean +- std dev: 836 ms +- 12 ms ......................................... wide_state_25x300: Mean +- std dev: 23.0 ms +- 0.4 ms ......................................... wide_state_25x300_sync: Mean +- std dev: 14.7 ms +- 0.1 ms ......................................... wide_state_25x300_checkpoint: Mean +- std dev: 278 ms +- 4 ms ......................................... wide_state_25x300_checkpoint_sync: Mean +- std dev: 266 ms +- 4 ms ......................................... wide_state_15x600: Mean +- std dev: 26.7 ms +- 0.4 ms ......................................... wide_state_15x600_sync: Mean +- std dev: 16.9 ms +- 0.1 ms ......................................... wide_state_15x600_checkpoint: Mean +- std dev: 479 ms +- 5 ms ......................................... wide_state_15x600_checkpoint_sync: Mean +- std dev: 467 ms +- 16 ms ......................................... wide_state_9x1200: Mean +- std dev: 26.7 ms +- 0.6 ms ......................................... wide_state_9x1200_sync: Mean +- std dev: 17.0 ms +- 0.2 ms ......................................... wide_state_9x1200_checkpoint: Mean +- std dev: 312 ms +- 7 ms ......................................... wide_state_9x1200_checkpoint_sync: Mean +- std dev: 299 ms +- 4 ms

Check notice on line 1 in libs/langgraph/langgraph/prebuilt/tool_node.py

View workflow job for this annotation

GitHub Actions / benchmark

Comparison against main

+-----------------------------------------+---------+-----------------------+ | Benchmark | main | changes | +=========================================+=========+=======================+ | fanout_to_subgraph_100x_checkpoint | 908 ms | 867 ms: 1.05x faster | +-----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_100x_checkpoint_sync | 898 ms | 886 ms: 1.01x faster | +-----------------------------------------+---------+-----------------------+ | react_agent_100x_checkpoint | 937 ms | 932 ms: 1.00x faster | +-----------------------------------------+---------+-----------------------+ | wide_state_15x600 | 26.8 ms | 26.7 ms: 1.00x faster | +-----------------------------------------+---------+-----------------------+ | wide_state_15x600_sync | 16.9 ms | 16.9 ms: 1.00x slower | +-----------------------------------------+---------+-----------------------+ | wide_state_15x600_checkpoint | 477 ms | 479 ms: 1.00x slower | +-----------------------------------------+---------+-----------------------+ | wide_state_9x1200_sync | 16.9 ms | 17.0 ms: 1.01x slower | +-----------------------------------------+---------+-----------------------+ | wide_state_25x300_checkpoint_sync | 264 ms | 266 ms: 1.01x slower | +-----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_10x_sync | 48.7 ms | 49.0 ms: 1.01x slower | +-----------------------------------------+---------+-----------------------+ | react_agent_100x_checkpoint_sync | 831 ms | 836 ms: 1.01x slower | +-----------------------------------------+---------+-----------------------+ | react_agent_100x_sync | 268 ms | 270 ms: 1.01x slower | +-----------------------------------------+---------+-----------------------+ | react_agent_10x_checkpoint | 46.4 ms | 46.7 ms: 1.01x slower | +-----------------------------------------+---------+-----------------------+ | wide_state_9x1200_checkpoint_sync | 297 ms | 299 ms: 1.01x slower | +-----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_10x_checkpoint_sync | 89.8 ms | 90.8 ms: 1.01x slower | +-----------------------------------------+---------+-----------------------+ | react_agent_10x_checkpoint_sync | 36.0 ms | 36.4 ms: 1.01x slower | +-----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_100x_sync | 472 ms | 480 ms: 1.02x slower | +-----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_100x | 529 ms | 565 ms: 1.07x slower | +-----------------------------------------+---------+-----------------------+ | Geometric mean | (ref) | 1.00x slower | +-----------------------------------------+---------+-----------------------+ Benchmark hidden because not significant (11): fanout_to_subgraph_10x, wide_state_25x300, fanout_to_subgraph_10x_checkpoint, wide_state_9x1200, wide_state_25x300_checkpoint, wide_state_9x1200_checkpoint, wide_state_25x300_sync, react_agent_10x, react_agent_100x, react_agent_10x_sync, wide_state_15x600_checkpoint_sync

import asyncio
import inspect
Expand All @@ -11,6 +11,7 @@
Dict,
List,
Literal,
NamedTuple,
Optional,
Sequence,
Tuple,
Expand Down Expand Up @@ -50,6 +51,11 @@
TOOL_CALL_ERROR_TEMPLATE = "Error: {error}\n Please fix your mistakes."


class StateUpdateArtifact(NamedTuple):
state_update: dict[str, Any]
artifact: Any = None


def msg_content_output(output: Any) -> str | List[dict]:
recognized_content_block_types = ("image", "image_url", "text", "json")
if isinstance(output, str):
Expand Down Expand Up @@ -221,8 +227,33 @@
config_list = get_config_list(config, len(tool_calls))
with get_executor_for_config(config) as executor:
outputs = [*executor.map(self._run_one, tool_calls, config_list)]
outputs, state_updates = zip(*outputs)
combined_state_updates = {}
for state_update in state_updates:
for k, v in state_update.items():
if k == self.messages_key:
raise ValueError(
"Cannot return state updates for the messages key."
)

if k in combined_state_updates:
raise ValueError(
f"Received multiple state updates for the key: {k}"
)

combined_state_updates[k] = v

if output_type == "list" and combined_state_updates:
raise ValueError(
"Cannot return state updates for a list input to ToolNode."
)

# TypedDict, pydantic, dataclass, etc. should all be able to load from dict
return outputs if output_type == "list" else {self.messages_key: outputs}
return (
outputs
if output_type == "list"
else {self.messages_key: list(outputs), **combined_state_updates}
)

def invoke(
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
Expand Down Expand Up @@ -256,9 +287,11 @@
# TypedDict, pydantic, dataclass, etc. should all be able to load from dict
return outputs if output_type == "list" else {self.messages_key: outputs}

def _run_one(self, call: ToolCall, config: RunnableConfig) -> ToolMessage:
def _run_one(
self, call: ToolCall, config: RunnableConfig
) -> tuple[ToolMessage, dict[str, Any]]:
if invalid_tool_message := self._validate_tool_call(call):
return invalid_tool_message
return invalid_tool_message, {}

try:
input = {**call, **{"type": "tool_call"}}
Expand All @@ -268,7 +301,12 @@
tool_message.content = cast(
Union[str, list], msg_content_output(tool_message.content)
)
return tool_message
if isinstance(tool_message.artifact, StateUpdateArtifact):
state_update = tool_message.artifact.state_update
tool_message.artifact = tool_message.artifact.artifact
return tool_message, state_update

return tool_message, {}
# GraphInterrupt is a special exception that will always be raised.
# It can be triggered in the following scenarios:
# (1) a NodeInterrupt is raised inside a tool
Expand All @@ -295,7 +333,7 @@

return ToolMessage(
content=content, name=call["name"], tool_call_id=call["id"], status="error"
)
), {}

async def _arun_one(self, call: ToolCall, config: RunnableConfig) -> ToolMessage:
if invalid_tool_message := self._validate_tool_call(call):
Expand Down
Loading