diff --git a/libs/langgraph/langgraph/types.py b/libs/langgraph/langgraph/types.py index f7f330b83..1ed6e43ad 100644 --- a/libs/langgraph/langgraph/types.py +++ b/libs/langgraph/langgraph/types.py @@ -238,6 +238,12 @@ def __init__( self.trigger = trigger self.send = send + def __repr__(self) -> str: + contents = ", ".join( + f"{key}={value!r}" for key, value in self.__dict__.items() if value + ) + return f"Control({contents})" + StreamChunk = tuple[tuple[str, ...], str, Any] diff --git a/libs/langgraph/tests/test_pregel.py b/libs/langgraph/tests/test_pregel.py index 563ce295e..8536a08c3 100644 --- a/libs/langgraph/tests/test_pregel.py +++ b/libs/langgraph/tests/test_pregel.py @@ -74,7 +74,7 @@ from langgraph.pregel.retry import RetryPolicy from langgraph.store.base import BaseStore from langgraph.store.memory import InMemoryStore -from langgraph.types import Interrupt, PregelTask, Send, StreamWriter +from langgraph.types import Control, Interrupt, PregelTask, Send, StreamWriter from tests.any_str import AnyDict, AnyStr, AnyVersion, FloatBetween, UnsortedSequence from tests.conftest import ( ALL_CHECKPOINTERS_SYNC, @@ -1743,6 +1743,103 @@ def route_to_three(state) -> Literal["3"]: assert graph.invoke(["0"]) == ["0", "1", "2", "2", "3"] +def test_concurrent_emit_sends() -> None: + class Node: + def __init__(self, name: str): + self.name = name + setattr(self, "__name__", name) + + def __call__(self, state): + return ( + [self.name] + if isinstance(state, list) + else ["|".join((self.name, str(state)))] + ) + + def send_for_fun(state): + return [Send("2", 1), Send("2", 2), "3.1"] + + def send_for_profit(state): + return [Send("2", 3), Send("2", 4)] + + def route_to_three(state) -> Literal["3"]: + return "3" + + builder = StateGraph(Annotated[list, operator.add]) + builder.add_node(Node("1")) + builder.add_node(Node("1.1")) + builder.add_node(Node("2")) + builder.add_node(Node("3")) + builder.add_node(Node("3.1")) + builder.add_edge(START, "1") + builder.add_edge(START, "1.1") + builder.add_conditional_edges("1", send_for_fun) + builder.add_conditional_edges("1.1", send_for_profit) + builder.add_conditional_edges("2", route_to_three) + graph = builder.compile() + assert graph.invoke(["0"]) == [ + "0", + "1", + "1.1", + "2|1", + "2|2", + "2|3", + "2|4", + "3.1", + "3", + ] + + +def test_send_sequences() -> None: + class Node: + def __init__(self, name: str): + self.name = name + setattr(self, "__name__", name) + + def __call__(self, state): + update = ( + [self.name] + if isinstance(state, list) # or isinstance(state, Control) + else ["|".join((self.name, str(state)))] + ) + if isinstance(state, Control): + state.update_state = update + return state + else: + return update + + def send_for_fun(state): + return [ + Send("2", Control(send=Send("2", 3))), + Send("2", Control(send=Send("2", 4))), + "3.1", + ] + + def route_to_three(state) -> Literal["3"]: + return "3" + + builder = StateGraph(Annotated[list, operator.add]) + builder.add_node(Node("1")) + builder.add_node(Node("2")) + builder.add_node(Node("3")) + builder.add_node(Node("3.1")) + builder.add_edge(START, "1") + builder.add_conditional_edges("1", send_for_fun) + builder.add_conditional_edges("2", route_to_three) + graph = builder.compile() + assert graph.invoke(["0"]) == [ + "0", + "1", + "2|Control(send=Send(node='2', arg=3))", + "2|Control(send=Send(node='2', arg=4))", + "3.1", + "2|3", + "2|4", + "3", + "3", + ] + + @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC) def test_invoke_checkpoint_three( mocker: MockerFixture, request: pytest.FixtureRequest, checkpointer_name: str diff --git a/libs/langgraph/tests/test_pregel_async.py b/libs/langgraph/tests/test_pregel_async.py index 25072eeaf..d8299c930 100644 --- a/libs/langgraph/tests/test_pregel_async.py +++ b/libs/langgraph/tests/test_pregel_async.py @@ -1962,6 +1962,103 @@ async def route_to_three(state) -> Literal["3"]: assert await graph.ainvoke(["0"]) == ["0", "1", "2", "2", "3"] +async def test_concurrent_emit_sends() -> None: + class Node: + def __init__(self, name: str): + self.name = name + setattr(self, "__name__", name) + + async def __call__(self, state): + return ( + [self.name] + if isinstance(state, list) + else ["|".join((self.name, str(state)))] + ) + + async def send_for_fun(state): + return [Send("2", 1), Send("2", 2), "3.1"] + + async def send_for_profit(state): + return [Send("2", 3), Send("2", 4)] + + async def route_to_three(state) -> Literal["3"]: + return "3" + + builder = StateGraph(Annotated[list, operator.add]) + builder.add_node(Node("1")) + builder.add_node(Node("1.1")) + builder.add_node(Node("2")) + builder.add_node(Node("3")) + builder.add_node(Node("3.1")) + builder.add_edge(START, "1") + builder.add_edge(START, "1.1") + builder.add_conditional_edges("1", send_for_fun) + builder.add_conditional_edges("1.1", send_for_profit) + builder.add_conditional_edges("2", route_to_three) + graph = builder.compile() + assert await graph.ainvoke(["0"]) == [ + "0", + "1", + "1.1", + "2|1", + "2|2", + "2|3", + "2|4", + "3.1", + "3", + ] + + +async def test_send_sequences() -> None: + class Node: + def __init__(self, name: str): + self.name = name + setattr(self, "__name__", name) + + async def __call__(self, state): + update = ( + [self.name] + if isinstance(state, list) # or isinstance(state, Control) + else ["|".join((self.name, str(state)))] + ) + if isinstance(state, Control): + state.update_state = update + return state + else: + return update + + async def send_for_fun(state): + return [ + Send("2", Control(send=Send("2", 3))), + Send("2", Control(send=Send("2", 4))), + "3.1", + ] + + async def route_to_three(state) -> Literal["3"]: + return "3" + + builder = StateGraph(Annotated[list, operator.add]) + builder.add_node(Node("1")) + builder.add_node(Node("2")) + builder.add_node(Node("3")) + builder.add_node(Node("3.1")) + builder.add_edge(START, "1") + builder.add_conditional_edges("1", send_for_fun) + builder.add_conditional_edges("2", route_to_three) + graph = builder.compile() + assert await graph.ainvoke(["0"]) == [ + "0", + "1", + "2|Control(send=Send(node='2', arg=3))", + "2|Control(send=Send(node='2', arg=4))", + "3.1", + "2|3", + "2|4", + "3", + "3", + ] + + @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC) async def test_max_concurrency(checkpointer_name: str) -> None: class Node: