From c6a450b857d4740a70f7e12bd1d56673a0865e7d Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Fri, 11 Oct 2024 13:53:06 -0700 Subject: [PATCH 1/2] lib: Add interrupts to stream_mode=updates --- libs/langgraph/langgraph/pregel/loop.py | 11 ++----- libs/langgraph/tests/test_pregel.py | 39 +++++++++++++++++++++-- libs/langgraph/tests/test_pregel_async.py | 37 ++++++++++++++++----- libs/langgraph/tests/test_runnable.py | 4 +++ 4 files changed, 73 insertions(+), 18 deletions(-) diff --git a/libs/langgraph/langgraph/pregel/loop.py b/libs/langgraph/langgraph/pregel/loop.py index b3bffa544..7b5b343af 100644 --- a/libs/langgraph/langgraph/pregel/loop.py +++ b/libs/langgraph/langgraph/pregel/loop.py @@ -346,10 +346,7 @@ def tick( # after execution, check if we should interrupt if should_interrupt(self.checkpoint, interrupt_after, self.tasks.values()): self.status = "interrupt_after" - if self.is_nested: - raise GraphInterrupt() - else: - return False + raise GraphInterrupt() else: return False @@ -441,10 +438,7 @@ def tick( # before execution, check if we should interrupt if should_interrupt(self.checkpoint, interrupt_before, self.tasks.values()): self.status = "interrupt_before" - if self.is_nested: - raise GraphInterrupt() - else: - return False + raise GraphInterrupt() # produce debug output self._emit("debug", map_debug_tasks, self.step, self.tasks.values()) @@ -598,6 +592,7 @@ def _suppress_interrupt( self.output = read_channels(self.channels, self.output_keys) if suppress: # suppress interrupt + self._emit("updates", lambda: iter([{INTERRUPT: exc_value.args[0]}])) return True def _emit( diff --git a/libs/langgraph/tests/test_pregel.py b/libs/langgraph/tests/test_pregel.py index 7ee8dc926..0acf99fe8 100644 --- a/libs/langgraph/tests/test_pregel.py +++ b/libs/langgraph/tests/test_pregel.py @@ -896,6 +896,7 @@ def test_invoke_two_processes_in_out_interrupt( ] assert [c for c in app.stream(None, history[2].config, stream_mode="updates")] == [ {"one": {"inbox": 4}}, + {"__interrupt__": ()}, ] @@ -3198,6 +3199,7 @@ def should_continue(data: AgentState) -> str: ), } }, + {"__interrupt__": ()}, ] assert app_w_interrupt.get_state(config) == StateSnapshot( @@ -3295,6 +3297,7 @@ def should_continue(data: AgentState) -> str: ), } }, + {"__interrupt__": ()}, ] with assert_ctx_once(): @@ -3365,6 +3368,7 @@ def should_continue(data: AgentState) -> str: ), } }, + {"__interrupt__": ()}, ] assert app_w_interrupt.get_state(config) == StateSnapshot( @@ -3460,6 +3464,7 @@ def should_continue(data: AgentState) -> str: ), } }, + {"__interrupt__": ()}, ] app_w_interrupt.update_state( @@ -3520,7 +3525,9 @@ def should_continue(data: AgentState) -> str: assert [ c for c in app_w_interrupt.stream({"input": "what is weather in sf"}, config) - ] == [] + ] == [ + {"__interrupt__": ()}, + ] assert app_w_interrupt.get_state(config) == StateSnapshot( values={ @@ -3542,6 +3549,7 @@ def should_continue(data: AgentState) -> str: ), } }, + {"__interrupt__": ()}, ] assert app_w_interrupt.get_state(config) == StateSnapshot( @@ -3587,6 +3595,7 @@ def should_continue(data: AgentState) -> str: ], } }, + {"__interrupt__": ()}, ] assert app_w_interrupt.get_state(config) == StateSnapshot( @@ -3641,6 +3650,7 @@ def should_continue(data: AgentState) -> str: ), } }, + {"__interrupt__": ()}, ] # test w interrupt after all @@ -3661,6 +3671,7 @@ def should_continue(data: AgentState) -> str: ), } }, + {"__interrupt__": ()}, ] assert app_w_interrupt.get_state(config) == StateSnapshot( @@ -3706,6 +3717,7 @@ def should_continue(data: AgentState) 
-> str: ], } }, + {"__interrupt__": ()}, ] assert app_w_interrupt.get_state(config) == StateSnapshot( @@ -3760,6 +3772,7 @@ def should_continue(data: AgentState) -> str: ), } }, + {"__interrupt__": ()}, ] @@ -4630,6 +4643,7 @@ def tools_node(input: ToolCall, config: RunnableConfig) -> AgentState: ) } }, + {"__interrupt__": ()}, ] assert app_w_interrupt.get_state(config) == StateSnapshot( @@ -4759,6 +4773,7 @@ def tools_node(input: ToolCall, config: RunnableConfig) -> AgentState: ) }, }, + {"__interrupt__": ()}, ] assert app_w_interrupt.get_state(config) == StateSnapshot( @@ -5130,6 +5145,7 @@ def should_continue(messages): id="ai1", ) }, + {"__interrupt__": ()}, ] assert app_w_interrupt.get_state(config) == StateSnapshot( @@ -5241,6 +5257,7 @@ def should_continue(messages): id="ai2", ) }, + {"__interrupt__": ()}, ] assert app_w_interrupt.get_state(config) == StateSnapshot( @@ -5360,6 +5377,7 @@ def should_continue(messages): id="ai1", ) }, + {"__interrupt__": ()}, ] assert app_w_interrupt.get_state(config) == StateSnapshot( @@ -5471,6 +5489,7 @@ def should_continue(messages): id="ai2", ) }, + {"__interrupt__": ()}, ] assert app_w_interrupt.get_state(config) == StateSnapshot( @@ -5856,6 +5875,7 @@ class State(TypedDict): id="ai1", ) }, + {"__interrupt__": ()}, ] assert app_w_interrupt.get_state(config) == StateSnapshot( @@ -5967,6 +5987,7 @@ class State(TypedDict): id="ai2", ) }, + {"__interrupt__": ()}, ] assert app_w_interrupt.get_state(config) == StateSnapshot( @@ -6088,6 +6109,7 @@ class State(TypedDict): id="ai1", ) }, + {"__interrupt__": ()}, ] assert app_w_interrupt.get_state(config) == StateSnapshot( @@ -6199,6 +6221,7 @@ class State(TypedDict): id="ai2", ) }, + {"__interrupt__": ()}, ] assert app_w_interrupt.get_state(config) == StateSnapshot( @@ -7653,6 +7676,7 @@ def qa(data: State) -> State: {"analyzer_one": {"query": "analyzed: query: what is weather in sf"}}, {"retriever_two": {"docs": ["doc3", "doc4"]}}, {"retriever_one": {"docs": ["doc1", "doc2"]}}, + {"__interrupt__": ()}, ] assert [c for c in app_w_interrupt.stream(None, config)] == [ @@ -7672,6 +7696,7 @@ def qa(data: State) -> State: {"analyzer_one": {"query": "analyzed: query: what is weather in sf"}}, {"retriever_two": {"docs": ["doc3", "doc4"]}}, {"retriever_one": {"docs": ["doc1", "doc2"]}}, + {"__interrupt__": ()}, ] app_w_interrupt.update_state(config, {"docs": ["doc5"]}) @@ -7785,6 +7810,7 @@ def rewrite_query_then(data: State) -> Literal["retriever_two"]: {"analyzer_one": {"query": "analyzed: query: what is weather in sf"}}, {"retriever_two": {"docs": ["doc3", "doc4"]}}, {"retriever_one": {"docs": ["doc1", "doc2"]}}, + {"__interrupt__": ()}, ] assert [c for c in app_w_interrupt.stream(None, config)] == [ @@ -7941,6 +7967,7 @@ def decider(data: State) -> str: {"analyzer_one": {"query": "analyzed: query: what is weather in sf"}}, {"retriever_two": {"docs": ["doc3", "doc4"]}}, {"retriever_one": {"docs": ["doc1", "doc2"]}}, + {"__interrupt__": ()}, ] with assert_ctx_once(): @@ -8109,6 +8136,7 @@ def decider(data: State) -> str: {"analyzer_one": {"query": "analyzed: query: what is weather in sf"}}, {"retriever_two": {"docs": ["doc3", "doc4"]}}, {"retriever_one": {"docs": ["doc1", "doc2"]}}, + {"__interrupt__": ()}, ] with assert_ctx_once(): @@ -8217,6 +8245,7 @@ def qa(data: State) -> State: {"analyzer_one": {"query": "analyzed: query: what is weather in sf"}}, {"retriever_two": {"docs": ["doc3", "doc4"]}}, {"retriever_one": {"docs": ["doc1", "doc2"]}}, + {"__interrupt__": ()}, ] assert [c for c in 
app_w_interrupt.stream(None, config)] == [ @@ -8779,6 +8808,7 @@ def outer_2(state: State): # we got to parallel node first ((), {"outer_1": {"my_key": " and parallel"}}), ((AnyStr("inner:"),), {"inner_1": {"my_key": "got here", "my_other_key": ""}}), + ((), {"__interrupt__": ()}), ] assert [*app.stream(None, config)] == [ {"outer_1": {"my_key": " and parallel"}, "__metadata__": {"cached": True}}, @@ -8898,6 +8928,7 @@ def parent_2(state: State): config = {"configurable": {"thread_id": "2"}} assert [*app.stream({"my_key": "my value"}, config)] == [ {"parent_1": {"my_key": "hi my value"}}, + {"__interrupt__": ()}, ] assert [*app.stream(None, config)] == [ {"child": {"my_key": "hi my value here and there"}}, @@ -9493,6 +9524,7 @@ def parent_2(state: State): (AnyStr("child:"), AnyStr("child_1:")), {"grandchild_1": {"my_key": "hi my value here"}}, ), + ((), {"__interrupt__": ()}), ] # get state without subgraphs outer_state = app.get_state(config) @@ -10192,7 +10224,8 @@ def parent_2(state: State): ( (AnyStr("child:"), AnyStr("child_1:")), {"grandchild_1": {"my_key": "hi my value here"}}, - ) + ), + ((), {"__interrupt__": ()}), ] @@ -10644,6 +10677,7 @@ def weather_graph(state: RouterState): ] == [ ((), {"router_node": {"route": "weather"}}), ((AnyStr("weather_graph:"),), {"model_node": {"city": "San Francisco"}}), + ((), {"__interrupt__": ()}), ] # check current state @@ -10732,6 +10766,7 @@ def weather_graph(state: RouterState): ] == [ ((), {"router_node": {"route": "weather"}}), ((AnyStr("weather_graph:"),), {"model_node": {"city": "San Francisco"}}), + ((), {"__interrupt__": ()}), ] state = graph.get_state(config, subgraphs=True) assert state == StateSnapshot( diff --git a/libs/langgraph/tests/test_pregel_async.py b/libs/langgraph/tests/test_pregel_async.py index 5a25a3f01..044c1b800 100644 --- a/libs/langgraph/tests/test_pregel_async.py +++ b/libs/langgraph/tests/test_pregel_async.py @@ -291,12 +291,14 @@ async def tool_two_node(s: State) -> State: thread1 = {"configurable": {"thread_id": "1"}} # stop when about to enter node - assert await tool_two.ainvoke( - {"my_key": "value ⛰️", "market": "DE"}, thread1 - ) == { - "my_key": "value ⛰️", - "market": "DE", - } + assert [ + c + async for c in tool_two.astream( + {"my_key": "value ⛰️", "market": "DE"}, thread1 + ) + ] == [ + {"__interrupt__": [Interrupt(value="Just because...", when="during")]}, + ] assert [c.metadata async for c in tool_two.checkpointer.alist(thread1)] == [ { "parents": {}, @@ -330,7 +332,6 @@ async def tool_two_node(s: State) -> State: c async for c in tool_two.checkpointer.alist(thread1, limit=2) ][-1].config, ) - # TODO use aget_state_history @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC) @@ -1103,6 +1104,7 @@ async def test_invoke_two_processes_in_out_interrupt( c async for c in app.astream(None, history[2].config, stream_mode="updates") ] == [ {"one": {"inbox": 4}}, + {"__interrupt__": ()}, ] @@ -3452,6 +3454,7 @@ def should_continue(data: AgentState) -> str: ), } }, + {"__interrupt__": ()}, ] assert await app_w_interrupt.aget_state(config) == StateSnapshot( @@ -3559,6 +3562,7 @@ def should_continue(data: AgentState) -> str: ), } }, + {"__interrupt__": ()}, ] async with assert_ctx_once(): @@ -3637,6 +3641,7 @@ def should_continue(data: AgentState) -> str: ), } }, + {"__interrupt__": ()}, ] assert await app_w_interrupt.aget_state(config) == StateSnapshot( @@ -3740,6 +3745,7 @@ def should_continue(data: AgentState) -> str: ), } }, + {"__interrupt__": ()}, ] await app_w_interrupt.aupdate_state( 
@@ -4448,6 +4454,7 @@ async def tools_node(input: ToolCall, config: RunnableConfig) -> AgentState: ) } }, + {"__interrupt__": ()}, ] assert await app_w_interrupt.aget_state(config) == StateSnapshot( @@ -4581,6 +4588,7 @@ async def tools_node(input: ToolCall, config: RunnableConfig) -> AgentState: ) }, }, + {"__interrupt__": ()}, ] tup = await app_w_interrupt.checkpointer.aget_tuple(config) @@ -4922,6 +4930,7 @@ def should_continue(messages): id="ai1", ) }, + {"__interrupt__": ()}, ] tup = await app_w_interrupt.checkpointer.aget_tuple(config) @@ -5039,6 +5048,7 @@ def should_continue(messages): id="ai2", ) }, + {"__interrupt__": ()}, ] tup = await app_w_interrupt.checkpointer.aget_tuple(config) @@ -6434,6 +6444,7 @@ async def qa(data: State) -> State: {"analyzer_one": {"query": "analyzed: query: what is weather in sf"}}, {"retriever_two": {"docs": ["doc3", "doc4"]}}, {"retriever_one": {"docs": ["doc1", "doc2"]}}, + {"__interrupt__": ()}, ] assert [c async for c in app_w_interrupt.astream(None, config)] == [ @@ -6525,6 +6536,7 @@ async def qa(data: State) -> State: {"analyzer_one": {"query": "analyzed: query: what is weather in sf"}}, {"retriever_two": {"docs": ["doc3", "doc4"]}}, {"retriever_one": {"docs": ["doc1", "doc2"]}}, + {"__interrupt__": ()}, ] assert [c async for c in app_w_interrupt.astream(None, config)] == [ @@ -6668,6 +6680,7 @@ async def decider(data: State) -> str: {"analyzer_one": {"query": "analyzed: query: what is weather in sf"}}, {"retriever_two": {"docs": ["doc3", "doc4"]}}, {"retriever_one": {"docs": ["doc1", "doc2"]}}, + {"__interrupt__": ()}, ] async with assert_ctx_once(): @@ -6833,6 +6846,7 @@ async def decider(data: State) -> str: {"analyzer_one": {"query": "analyzed: query: what is weather in sf"}}, {"retriever_two": {"docs": ["doc3", "doc4"]}}, {"retriever_one": {"docs": ["doc1", "doc2"]}}, + {"__interrupt__": ()}, ] assert [c async for c in app_w_interrupt.astream(None, config)] == [ @@ -6939,6 +6953,7 @@ async def qa(data: State) -> State: {"analyzer_one": {"query": "analyzed: query: what is weather in sf"}}, {"retriever_two": {"docs": ["doc3", "doc4"]}}, {"retriever_one": {"docs": ["doc1", "doc2"]}}, + {"__interrupt__": ()}, ] assert [c async for c in app_w_interrupt.astream(None, config)] == [ @@ -7419,6 +7434,7 @@ async def outer_2(state: State): (AnyStr("inner:"),), {"inner_1": {"my_key": "got here", "my_other_key": ""}}, ), + ((), {"__interrupt__": ()}), ] assert [c async for c in app.astream(None, config)] == [ {"outer_1": {"my_key": " and parallel"}, "__metadata__": {"cached": True}}, @@ -7541,6 +7557,7 @@ async def parent_2(state: State): config = {"configurable": {"thread_id": "2"}} assert [c async for c in app.astream({"my_key": "my value"}, config)] == [ {"parent_1": {"my_key": "hi my value"}}, + {"__interrupt__": ()}, ] assert [c async for c in app.astream(None, config)] == [ {"child": {"my_key": "hi my value here and there"}}, @@ -8163,6 +8180,7 @@ def parent_2(state: State): (AnyStr("child:"), AnyStr("child_1:")), {"grandchild_1": {"my_key": "hi my value here"}}, ), + ((), {"__interrupt__": ()}), ] # get state without subgraphs outer_state = await app.aget_state(config) @@ -8895,7 +8913,8 @@ def parent_2(state: State): ( (AnyStr("child:"), AnyStr("child_1:")), {"grandchild_1": {"my_key": "hi my value here"}}, - ) + ), + ((), {"__interrupt__": ()}), ] @@ -9297,6 +9316,7 @@ def get_first_in_list(): ] == [ ((), {"router_node": {"route": "weather"}}), ((AnyStr("weather_graph:"),), {"model_node": {"city": "San Francisco"}}), + ((), {"__interrupt__": 
()}), ] # check current state @@ -9389,6 +9409,7 @@ def get_first_in_list(): ] == [ ((), {"router_node": {"route": "weather"}}), ((AnyStr("weather_graph:"),), {"model_node": {"city": "San Francisco"}}), + ((), {"__interrupt__": ()}), ] state = await graph.aget_state(config, subgraphs=True) assert state == StateSnapshot( diff --git a/libs/langgraph/tests/test_runnable.py b/libs/langgraph/tests/test_runnable.py index 2bd2e20d5..6cb4b5e9e 100644 --- a/libs/langgraph/tests/test_runnable.py +++ b/libs/langgraph/tests/test_runnable.py @@ -2,10 +2,14 @@ from typing import Any +import pytest + from langgraph.store.base import BaseStore from langgraph.types import StreamWriter from langgraph.utils.runnable import RunnableCallable +pytestmark = pytest.mark.anyio + def test_runnable_callable_func_accepts(): def sync_func(x: Any) -> str: From 561aa3080ebd61b2095e624ecc941c4c6ed477d1 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Fri, 11 Oct 2024 14:27:51 -0700 Subject: [PATCH 2/2] Lint --- libs/langgraph/langgraph/pregel/loop.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/libs/langgraph/langgraph/pregel/loop.py b/libs/langgraph/langgraph/pregel/loop.py index 7b5b343af..f54c07047 100644 --- a/libs/langgraph/langgraph/pregel/loop.py +++ b/libs/langgraph/langgraph/pregel/loop.py @@ -592,7 +592,10 @@ def _suppress_interrupt( self.output = read_channels(self.channels, self.output_keys) if suppress: # suppress interrupt - self._emit("updates", lambda: iter([{INTERRUPT: exc_value.args[0]}])) + self._emit( + "updates", + lambda: iter([{INTERRUPT: cast(GraphInterrupt, exc_value).args[0]}]), + ) return True def _emit(
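
For reviewers who want to try this change out, below is a minimal, self-contained sketch (not part of the patch) of how the new "__interrupt__" chunk surfaces in stream_mode="updates". The graph, the node names (step_1 / step_2) and the State schema are made up for illustration; the expected chunks mirror the updated tests above, where a static interrupt (interrupt_before/interrupt_after) appears as {"__interrupt__": ()} and a dynamic one carries the Interrupt objects, e.g. {"__interrupt__": [Interrupt(value="Just because...", when="during")]}.

    # Sketch only: node names, State schema and values are hypothetical.
    from typing import TypedDict

    from langgraph.checkpoint.memory import MemorySaver
    from langgraph.graph import END, START, StateGraph


    class State(TypedDict):
        value: str


    def step_1(state: State) -> State:
        return {"value": state["value"] + " -> step_1"}


    def step_2(state: State) -> State:
        return {"value": state["value"] + " -> step_2"}


    builder = StateGraph(State)
    builder.add_node("step_1", step_1)
    builder.add_node("step_2", step_2)
    builder.add_edge(START, "step_1")
    builder.add_edge("step_1", "step_2")
    builder.add_edge("step_2", END)

    # interrupt_before requires a checkpointer so the run can be resumed later
    graph = builder.compile(checkpointer=MemorySaver(), interrupt_before=["step_2"])

    config = {"configurable": {"thread_id": "1"}}

    # Before this change the stream simply ended at the interrupt; with it,
    # the suppressed GraphInterrupt is emitted as a final "updates" chunk.
    for chunk in graph.stream({"value": "start"}, config, stream_mode="updates"):
        print(chunk)
    # Expected, per the tests in this patch:
    #   {'step_1': {'value': 'start -> step_1'}}
    #   {'__interrupt__': ()}

    # Resuming with None continues from the checkpoint past the interrupt.
    for chunk in graph.stream(None, config, stream_mode="updates"):
        print(chunk)
    #   {'step_2': {'value': 'start -> step_1 -> step_2'}}

As the loop.py hunks show, tick() now always raises GraphInterrupt and the chunk is emitted from _suppress_interrupt only where the interrupt is actually suppressed (the root, non-nested loop), which is why subgraph runs re-raise and the parent stream reports the interrupt with an empty namespace, e.g. ((), {"__interrupt__": ()}) when streaming with subgraphs.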