From 8138c88b41f2823f582371f55cdf1476cb760253 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Mon, 4 Nov 2024 11:56:31 -0800 Subject: [PATCH] Add async versions --- libs/langgraph/tests/test_pregel_async.py | 97 +++++++++++++++++++++++ 1 file changed, 97 insertions(+) diff --git a/libs/langgraph/tests/test_pregel_async.py b/libs/langgraph/tests/test_pregel_async.py index 25072eeaf..d8299c930 100644 --- a/libs/langgraph/tests/test_pregel_async.py +++ b/libs/langgraph/tests/test_pregel_async.py @@ -1962,6 +1962,103 @@ async def route_to_three(state) -> Literal["3"]: assert await graph.ainvoke(["0"]) == ["0", "1", "2", "2", "3"] +async def test_concurrent_emit_sends() -> None: + class Node: + def __init__(self, name: str): + self.name = name + setattr(self, "__name__", name) + + async def __call__(self, state): + return ( + [self.name] + if isinstance(state, list) + else ["|".join((self.name, str(state)))] + ) + + async def send_for_fun(state): + return [Send("2", 1), Send("2", 2), "3.1"] + + async def send_for_profit(state): + return [Send("2", 3), Send("2", 4)] + + async def route_to_three(state) -> Literal["3"]: + return "3" + + builder = StateGraph(Annotated[list, operator.add]) + builder.add_node(Node("1")) + builder.add_node(Node("1.1")) + builder.add_node(Node("2")) + builder.add_node(Node("3")) + builder.add_node(Node("3.1")) + builder.add_edge(START, "1") + builder.add_edge(START, "1.1") + builder.add_conditional_edges("1", send_for_fun) + builder.add_conditional_edges("1.1", send_for_profit) + builder.add_conditional_edges("2", route_to_three) + graph = builder.compile() + assert await graph.ainvoke(["0"]) == [ + "0", + "1", + "1.1", + "2|1", + "2|2", + "2|3", + "2|4", + "3.1", + "3", + ] + + +async def test_send_sequences() -> None: + class Node: + def __init__(self, name: str): + self.name = name + setattr(self, "__name__", name) + + async def __call__(self, state): + update = ( + [self.name] + if isinstance(state, list) # or isinstance(state, Control) + else ["|".join((self.name, str(state)))] + ) + if isinstance(state, Control): + state.update_state = update + return state + else: + return update + + async def send_for_fun(state): + return [ + Send("2", Control(send=Send("2", 3))), + Send("2", Control(send=Send("2", 4))), + "3.1", + ] + + async def route_to_three(state) -> Literal["3"]: + return "3" + + builder = StateGraph(Annotated[list, operator.add]) + builder.add_node(Node("1")) + builder.add_node(Node("2")) + builder.add_node(Node("3")) + builder.add_node(Node("3.1")) + builder.add_edge(START, "1") + builder.add_conditional_edges("1", send_for_fun) + builder.add_conditional_edges("2", route_to_three) + graph = builder.compile() + assert await graph.ainvoke(["0"]) == [ + "0", + "1", + "2|Control(send=Send(node='2', arg=3))", + "2|Control(send=Send(node='2', arg=4))", + "3.1", + "2|3", + "2|4", + "3", + "3", + ] + + @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC) async def test_max_concurrency(checkpointer_name: str) -> None: class Node: