From ae282e3ae10327307c889ddd2061a0400fb4d5e8 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Mon, 4 Nov 2024 15:41:02 -0800 Subject: [PATCH] lib: Add test for react architecture using Send + interrupt_before - Both for cond edge and edgeless graphs --- libs/langgraph/tests/test_pregel.py | 502 +++++++++++++++++++++- libs/langgraph/tests/test_pregel_async.py | 496 ++++++++++++++++++++- 2 files changed, 996 insertions(+), 2 deletions(-) diff --git a/libs/langgraph/tests/test_pregel.py b/libs/langgraph/tests/test_pregel.py index a8a614408..c9de54ec8 100644 --- a/libs/langgraph/tests/test_pregel.py +++ b/libs/langgraph/tests/test_pregel.py @@ -58,7 +58,7 @@ from langgraph.errors import InvalidUpdateError, MultipleSubgraphsError, NodeInterrupt from langgraph.graph import END, Graph from langgraph.graph.graph import START -from langgraph.graph.message import MessageGraph, add_messages +from langgraph.graph.message import MessageGraph, MessagesState, add_messages from langgraph.graph.state import StateGraph from langgraph.managed.shared_value import SharedValue from langgraph.prebuilt.chat_agent_executor import ( @@ -1845,6 +1845,506 @@ def route_to_three(state) -> Literal["3"]: ] +@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC) +def test_send_react_interrupt( + request: pytest.FixtureRequest, checkpointer_name: str +) -> None: + from langchain_core.messages import AIMessage, HumanMessage, ToolCall, ToolMessage + + checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}") + + ai_message = AIMessage( + "", + id="ai1", + tool_calls=[ToolCall(name="foo", args={"hi": [1, 2, 3]}, id=AnyStr())], + ) + + def agent(state): + return {"messages": ai_message} + + def route(state): + if isinstance(state["messages"][-1], AIMessage): + return [ + Send(call["name"], call) for call in state["messages"][-1].tool_calls + ] + + foo_called = 0 + + def foo(call: ToolCall): + nonlocal foo_called + foo_called += 1 + return {"messages": ToolMessage(str(call["args"]), tool_call_id=call["id"])} + + builder = StateGraph(MessagesState) + builder.add_node(agent) + builder.add_node(foo) + builder.add_edge(START, "agent") + builder.add_conditional_edges("agent", route) + graph = builder.compile() + + assert graph.invoke({"messages": [HumanMessage("hello")]}) == { + "messages": [ + _AnyIdHumanMessage(content="hello"), + _AnyIdAIMessage( + content="", + tool_calls=[ + { + "name": "foo", + "args": {"hi": [1, 2, 3]}, + "id": "", + "type": "tool_call", + } + ], + ), + _AnyIdToolMessage( + content="{'hi': [1, 2, 3]}", + tool_call_id=AnyStr(), + ), + ] + } + assert foo_called == 1 + + # simple interrupt-resume flow + foo_called = 0 + graph = builder.compile(checkpointer=checkpointer, interrupt_before=["foo"]) + thread1 = {"configurable": {"thread_id": "1"}} + assert graph.invoke({"messages": [HumanMessage("hello")]}, thread1) == { + "messages": [ + _AnyIdHumanMessage(content="hello"), + _AnyIdAIMessage( + content="", + tool_calls=[ + { + "name": "foo", + "args": {"hi": [1, 2, 3]}, + "id": "", + "type": "tool_call", + } + ], + ), + ] + } + assert foo_called == 0 + assert graph.invoke(None, thread1) == { + "messages": [ + _AnyIdHumanMessage(content="hello"), + _AnyIdAIMessage( + content="", + tool_calls=[ + { + "name": "foo", + "args": {"hi": [1, 2, 3]}, + "id": "", + "type": "tool_call", + } + ], + ), + _AnyIdToolMessage( + content="{'hi': [1, 2, 3]}", + tool_call_id=AnyStr(), + ), + ] + } + assert foo_called == 1 + + # interrupt-update-resume flow + foo_called = 0 + graph = builder.compile(checkpointer=checkpointer, interrupt_before=["foo"]) + thread1 = {"configurable": {"thread_id": "2"}} + assert graph.invoke({"messages": [HumanMessage("hello")]}, thread1) == { + "messages": [ + _AnyIdHumanMessage(content="hello"), + _AnyIdAIMessage( + content="", + tool_calls=[ + { + "name": "foo", + "args": {"hi": [1, 2, 3]}, + "id": "", + "type": "tool_call", + } + ], + ), + ] + } + assert foo_called == 0 + + # get state should show the pending task + state = graph.get_state(thread1) + assert state == StateSnapshot( + values={ + "messages": [ + _AnyIdHumanMessage(content="hello"), + _AnyIdAIMessage( + content="", + tool_calls=[ + { + "name": "foo", + "args": {"hi": [1, 2, 3]}, + "id": "", + "type": "tool_call", + } + ], + ), + ] + }, + next=("foo",), + config={ + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + metadata={ + "step": 1, + "source": "loop", + "writes": { + "agent": { + "messages": _AnyIdAIMessage( + content="", + tool_calls=[ + { + "name": "foo", + "args": {"hi": [1, 2, 3]}, + "id": "", + "type": "tool_call", + } + ], + ) + } + }, + "parents": {}, + "thread_id": "2", + }, + created_at=AnyStr(), + parent_config={ + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + tasks=( + PregelTask( + id=AnyStr(), + name="foo", + path=("__pregel_push", 0), + error=None, + interrupts=(), + state=None, + result=None, + ), + ), + ) + + # remove the tool call, clearing the pending task + graph.update_state( + thread1, {"messages": AIMessage("Bye now", id=ai_message.id, tool_calls=[])} + ) + + # tool call no longer in pending tasks + assert graph.get_state(thread1) == StateSnapshot( + values={ + "messages": [ + _AnyIdHumanMessage(content="hello"), + _AnyIdAIMessage( + content="Bye now", + tool_calls=[], + ), + ] + }, + next=(), + config={ + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + metadata={ + "step": 2, + "source": "update", + "writes": { + "agent": { + "messages": _AnyIdAIMessage( + content="Bye now", + tool_calls=[], + ) + } + }, + "parents": {}, + "thread_id": "2", + }, + created_at=AnyStr(), + parent_config={ + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + tasks=(), + ) + + # tool call not executed + assert graph.invoke(None, thread1) == { + "messages": [ + _AnyIdHumanMessage(content="hello"), + _AnyIdAIMessage(content="Bye now"), + ] + } + assert foo_called == 0 + + +@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC) +def test_send_react_interrupt_control( + request: pytest.FixtureRequest, checkpointer_name: str +) -> None: + from langchain_core.messages import AIMessage, HumanMessage, ToolCall, ToolMessage + + checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}") + + ai_message = AIMessage( + "", + id="ai1", + tool_calls=[ToolCall(name="foo", args={"hi": [1, 2, 3]}, id=AnyStr())], + ) + + def agent(state) -> Control[Literal["foo"]]: + return Control( + update_state={"messages": ai_message}, + send=[Send(call["name"], call) for call in ai_message.tool_calls], + ) + + foo_called = 0 + + def foo(call: ToolCall): + nonlocal foo_called + foo_called += 1 + return {"messages": ToolMessage(str(call["args"]), tool_call_id=call["id"])} + + builder = StateGraph(MessagesState) + builder.add_node(agent) + builder.add_node(foo) + builder.add_edge(START, "agent") + graph = builder.compile() + + assert graph.invoke({"messages": [HumanMessage("hello")]}) == { + "messages": [ + _AnyIdHumanMessage(content="hello"), + _AnyIdAIMessage( + content="", + tool_calls=[ + { + "name": "foo", + "args": {"hi": [1, 2, 3]}, + "id": "", + "type": "tool_call", + } + ], + ), + _AnyIdToolMessage( + content="{'hi': [1, 2, 3]}", + tool_call_id=AnyStr(), + ), + ] + } + assert foo_called == 1 + + # simple interrupt-resume flow + foo_called = 0 + graph = builder.compile(checkpointer=checkpointer, interrupt_before=["foo"]) + thread1 = {"configurable": {"thread_id": "1"}} + assert graph.invoke({"messages": [HumanMessage("hello")]}, thread1) == { + "messages": [ + _AnyIdHumanMessage(content="hello"), + _AnyIdAIMessage( + content="", + tool_calls=[ + { + "name": "foo", + "args": {"hi": [1, 2, 3]}, + "id": "", + "type": "tool_call", + } + ], + ), + ] + } + assert foo_called == 0 + assert graph.invoke(None, thread1) == { + "messages": [ + _AnyIdHumanMessage(content="hello"), + _AnyIdAIMessage( + content="", + tool_calls=[ + { + "name": "foo", + "args": {"hi": [1, 2, 3]}, + "id": "", + "type": "tool_call", + } + ], + ), + _AnyIdToolMessage( + content="{'hi': [1, 2, 3]}", + tool_call_id=AnyStr(), + ), + ] + } + assert foo_called == 1 + + # interrupt-update-resume flow + foo_called = 0 + graph = builder.compile(checkpointer=checkpointer, interrupt_before=["foo"]) + thread1 = {"configurable": {"thread_id": "2"}} + assert graph.invoke({"messages": [HumanMessage("hello")]}, thread1) == { + "messages": [ + _AnyIdHumanMessage(content="hello"), + _AnyIdAIMessage( + content="", + tool_calls=[ + { + "name": "foo", + "args": {"hi": [1, 2, 3]}, + "id": "", + "type": "tool_call", + } + ], + ), + ] + } + assert foo_called == 0 + + # get state should show the pending task + state = graph.get_state(thread1) + assert state == StateSnapshot( + values={ + "messages": [ + _AnyIdHumanMessage(content="hello"), + _AnyIdAIMessage( + content="", + tool_calls=[ + { + "name": "foo", + "args": {"hi": [1, 2, 3]}, + "id": "", + "type": "tool_call", + } + ], + ), + ] + }, + next=("foo",), + config={ + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + metadata={ + "step": 1, + "source": "loop", + "writes": { + "agent": { + "messages": _AnyIdAIMessage( + content="", + tool_calls=[ + { + "name": "foo", + "args": {"hi": [1, 2, 3]}, + "id": "", + "type": "tool_call", + } + ], + ) + } + }, + "parents": {}, + "thread_id": "2", + }, + created_at=AnyStr(), + parent_config={ + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + tasks=( + PregelTask( + id=AnyStr(), + name="foo", + path=("__pregel_push", 0), + error=None, + interrupts=(), + state=None, + result=None, + ), + ), + ) + + # remove the tool call, clearing the pending task + graph.update_state( + thread1, {"messages": AIMessage("Bye now", id=ai_message.id, tool_calls=[])} + ) + + # tool call no longer in pending tasks + assert graph.get_state(thread1) == StateSnapshot( + values={ + "messages": [ + _AnyIdHumanMessage(content="hello"), + _AnyIdAIMessage( + content="Bye now", + tool_calls=[], + ), + ] + }, + next=(), + config={ + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + metadata={ + "step": 2, + "source": "update", + "writes": { + "agent": { + "messages": _AnyIdAIMessage( + content="Bye now", + tool_calls=[], + ) + } + }, + "parents": {}, + "thread_id": "2", + }, + created_at=AnyStr(), + parent_config={ + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + tasks=(), + ) + + # tool call not executed + assert graph.invoke(None, thread1) == { + "messages": [ + _AnyIdHumanMessage(content="hello"), + _AnyIdAIMessage(content="Bye now"), + ] + } + assert foo_called == 0 + + @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC) def test_invoke_checkpoint_three( mocker: MockerFixture, request: pytest.FixtureRequest, checkpointer_name: str diff --git a/libs/langgraph/tests/test_pregel_async.py b/libs/langgraph/tests/test_pregel_async.py index b9d7936db..1235d5cc0 100644 --- a/libs/langgraph/tests/test_pregel_async.py +++ b/libs/langgraph/tests/test_pregel_async.py @@ -55,7 +55,7 @@ from langgraph.errors import InvalidUpdateError, MultipleSubgraphsError, NodeInterrupt from langgraph.graph import END, Graph, StateGraph from langgraph.graph.graph import START -from langgraph.graph.message import MessageGraph, add_messages +from langgraph.graph.message import MessageGraph, MessagesState, add_messages from langgraph.managed.shared_value import SharedValue from langgraph.prebuilt.chat_agent_executor import create_tool_calling_executor from langgraph.prebuilt.tool_node import ToolNode @@ -2065,6 +2065,500 @@ async def route_to_three(state) -> Literal["3"]: ] +@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC) +async def test_send_react_interrupt(checkpointer_name: str) -> None: + from langchain_core.messages import AIMessage, HumanMessage, ToolCall, ToolMessage + + ai_message = AIMessage( + "", + id="ai1", + tool_calls=[ToolCall(name="foo", args={"hi": [1, 2, 3]}, id=AnyStr())], + ) + + async def agent(state): + return {"messages": ai_message} + + def route(state): + if isinstance(state["messages"][-1], AIMessage): + return [ + Send(call["name"], call) for call in state["messages"][-1].tool_calls + ] + + foo_called = 0 + + async def foo(call: ToolCall): + nonlocal foo_called + foo_called += 1 + return {"messages": ToolMessage(str(call["args"]), tool_call_id=call["id"])} + + builder = StateGraph(MessagesState) + builder.add_node(agent) + builder.add_node(foo) + builder.add_edge(START, "agent") + builder.add_conditional_edges("agent", route) + graph = builder.compile() + + assert await graph.ainvoke({"messages": [HumanMessage("hello")]}) == { + "messages": [ + _AnyIdHumanMessage(content="hello"), + _AnyIdAIMessage( + content="", + tool_calls=[ + { + "name": "foo", + "args": {"hi": [1, 2, 3]}, + "id": "", + "type": "tool_call", + } + ], + ), + _AnyIdToolMessage( + content="{'hi': [1, 2, 3]}", + tool_call_id=AnyStr(), + ), + ] + } + assert foo_called == 1 + + async with awith_checkpointer(checkpointer_name) as checkpointer: + # simple interrupt-resume flow + foo_called = 0 + graph = builder.compile(checkpointer=checkpointer, interrupt_before=["foo"]) + thread1 = {"configurable": {"thread_id": "1"}} + assert await graph.ainvoke({"messages": [HumanMessage("hello")]}, thread1) == { + "messages": [ + _AnyIdHumanMessage(content="hello"), + _AnyIdAIMessage( + content="", + tool_calls=[ + { + "name": "foo", + "args": {"hi": [1, 2, 3]}, + "id": "", + "type": "tool_call", + } + ], + ), + ] + } + assert foo_called == 0 + assert await graph.ainvoke(None, thread1) == { + "messages": [ + _AnyIdHumanMessage(content="hello"), + _AnyIdAIMessage( + content="", + tool_calls=[ + { + "name": "foo", + "args": {"hi": [1, 2, 3]}, + "id": "", + "type": "tool_call", + } + ], + ), + _AnyIdToolMessage( + content="{'hi': [1, 2, 3]}", + tool_call_id=AnyStr(), + ), + ] + } + assert foo_called == 1 + + # interrupt-update-resume flow + foo_called = 0 + graph = builder.compile(checkpointer=checkpointer, interrupt_before=["foo"]) + thread1 = {"configurable": {"thread_id": "2"}} + assert await graph.ainvoke({"messages": [HumanMessage("hello")]}, thread1) == { + "messages": [ + _AnyIdHumanMessage(content="hello"), + _AnyIdAIMessage( + content="", + tool_calls=[ + { + "name": "foo", + "args": {"hi": [1, 2, 3]}, + "id": "", + "type": "tool_call", + } + ], + ), + ] + } + assert foo_called == 0 + + # get state should show the pending task + state = await graph.aget_state(thread1) + assert state == StateSnapshot( + values={ + "messages": [ + _AnyIdHumanMessage(content="hello"), + _AnyIdAIMessage( + content="", + tool_calls=[ + { + "name": "foo", + "args": {"hi": [1, 2, 3]}, + "id": "", + "type": "tool_call", + } + ], + ), + ] + }, + next=("foo",), + config={ + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + metadata={ + "step": 1, + "source": "loop", + "writes": { + "agent": { + "messages": _AnyIdAIMessage( + content="", + tool_calls=[ + { + "name": "foo", + "args": {"hi": [1, 2, 3]}, + "id": "", + "type": "tool_call", + } + ], + ) + } + }, + "parents": {}, + "thread_id": "2", + }, + created_at=AnyStr(), + parent_config={ + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + tasks=( + PregelTask( + id=AnyStr(), + name="foo", + path=("__pregel_push", 0), + error=None, + interrupts=(), + state=None, + result=None, + ), + ), + ) + + # remove the tool call, clearing the pending task + await graph.aupdate_state( + thread1, {"messages": AIMessage("Bye now", id=ai_message.id, tool_calls=[])} + ) + + # tool call no longer in pending tasks + assert await graph.aget_state(thread1) == StateSnapshot( + values={ + "messages": [ + _AnyIdHumanMessage(content="hello"), + _AnyIdAIMessage( + content="Bye now", + tool_calls=[], + ), + ] + }, + next=(), + config={ + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + metadata={ + "step": 2, + "source": "update", + "writes": { + "agent": { + "messages": _AnyIdAIMessage( + content="Bye now", + tool_calls=[], + ) + } + }, + "parents": {}, + "thread_id": "2", + }, + created_at=AnyStr(), + parent_config={ + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + tasks=(), + ) + + # tool call not executed + assert await graph.ainvoke(None, thread1) == { + "messages": [ + _AnyIdHumanMessage(content="hello"), + _AnyIdAIMessage(content="Bye now"), + ] + } + assert foo_called == 0 + + +@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC) +async def test_send_react_interrupt_control(checkpointer_name: str) -> None: + from langchain_core.messages import AIMessage, HumanMessage, ToolCall, ToolMessage + + ai_message = AIMessage( + "", + id="ai1", + tool_calls=[ToolCall(name="foo", args={"hi": [1, 2, 3]}, id=AnyStr())], + ) + + async def agent(state) -> Control[Literal["foo"]]: + return Control( + update_state={"messages": ai_message}, + send=[Send(call["name"], call) for call in ai_message.tool_calls], + ) + + foo_called = 0 + + async def foo(call: ToolCall): + nonlocal foo_called + foo_called += 1 + return {"messages": ToolMessage(str(call["args"]), tool_call_id=call["id"])} + + builder = StateGraph(MessagesState) + builder.add_node(agent) + builder.add_node(foo) + builder.add_edge(START, "agent") + graph = builder.compile() + + assert await graph.ainvoke({"messages": [HumanMessage("hello")]}) == { + "messages": [ + _AnyIdHumanMessage(content="hello"), + _AnyIdAIMessage( + content="", + tool_calls=[ + { + "name": "foo", + "args": {"hi": [1, 2, 3]}, + "id": "", + "type": "tool_call", + } + ], + ), + _AnyIdToolMessage( + content="{'hi': [1, 2, 3]}", + tool_call_id=AnyStr(), + ), + ] + } + assert foo_called == 1 + + async with awith_checkpointer(checkpointer_name) as checkpointer: + # simple interrupt-resume flow + foo_called = 0 + graph = builder.compile(checkpointer=checkpointer, interrupt_before=["foo"]) + thread1 = {"configurable": {"thread_id": "1"}} + assert await graph.ainvoke({"messages": [HumanMessage("hello")]}, thread1) == { + "messages": [ + _AnyIdHumanMessage(content="hello"), + _AnyIdAIMessage( + content="", + tool_calls=[ + { + "name": "foo", + "args": {"hi": [1, 2, 3]}, + "id": "", + "type": "tool_call", + } + ], + ), + ] + } + assert foo_called == 0 + assert await graph.ainvoke(None, thread1) == { + "messages": [ + _AnyIdHumanMessage(content="hello"), + _AnyIdAIMessage( + content="", + tool_calls=[ + { + "name": "foo", + "args": {"hi": [1, 2, 3]}, + "id": "", + "type": "tool_call", + } + ], + ), + _AnyIdToolMessage( + content="{'hi': [1, 2, 3]}", + tool_call_id=AnyStr(), + ), + ] + } + assert foo_called == 1 + + # interrupt-update-resume flow + foo_called = 0 + graph = builder.compile(checkpointer=checkpointer, interrupt_before=["foo"]) + thread1 = {"configurable": {"thread_id": "2"}} + assert await graph.ainvoke({"messages": [HumanMessage("hello")]}, thread1) == { + "messages": [ + _AnyIdHumanMessage(content="hello"), + _AnyIdAIMessage( + content="", + tool_calls=[ + { + "name": "foo", + "args": {"hi": [1, 2, 3]}, + "id": "", + "type": "tool_call", + } + ], + ), + ] + } + assert foo_called == 0 + + # get state should show the pending task + state = await graph.aget_state(thread1) + assert state == StateSnapshot( + values={ + "messages": [ + _AnyIdHumanMessage(content="hello"), + _AnyIdAIMessage( + content="", + tool_calls=[ + { + "name": "foo", + "args": {"hi": [1, 2, 3]}, + "id": "", + "type": "tool_call", + } + ], + ), + ] + }, + next=("foo",), + config={ + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + metadata={ + "step": 1, + "source": "loop", + "writes": { + "agent": { + "messages": _AnyIdAIMessage( + content="", + tool_calls=[ + { + "name": "foo", + "args": {"hi": [1, 2, 3]}, + "id": "", + "type": "tool_call", + } + ], + ) + } + }, + "parents": {}, + "thread_id": "2", + }, + created_at=AnyStr(), + parent_config={ + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + tasks=( + PregelTask( + id=AnyStr(), + name="foo", + path=("__pregel_push", 0), + error=None, + interrupts=(), + state=None, + result=None, + ), + ), + ) + + # remove the tool call, clearing the pending task + await graph.aupdate_state( + thread1, {"messages": AIMessage("Bye now", id=ai_message.id, tool_calls=[])} + ) + + # tool call no longer in pending tasks + assert await graph.aget_state(thread1) == StateSnapshot( + values={ + "messages": [ + _AnyIdHumanMessage(content="hello"), + _AnyIdAIMessage( + content="Bye now", + tool_calls=[], + ), + ] + }, + next=(), + config={ + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + metadata={ + "step": 2, + "source": "update", + "writes": { + "agent": { + "messages": _AnyIdAIMessage( + content="Bye now", + tool_calls=[], + ) + } + }, + "parents": {}, + "thread_id": "2", + }, + created_at=AnyStr(), + parent_config={ + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + tasks=(), + ) + + # tool call not executed + assert await graph.ainvoke(None, thread1) == { + "messages": [ + _AnyIdHumanMessage(content="hello"), + _AnyIdAIMessage(content="Bye now"), + ] + } + assert foo_called == 0 + + @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC) async def test_max_concurrency(checkpointer_name: str) -> None: class Node: