Skip to content

Commit

Permalink
Add async versions
Browse files Browse the repository at this point in the history
  • Loading branch information
nfcampos committed Nov 4, 2024
1 parent 38332fd commit 8138c88
Showing 1 changed file with 97 additions and 0 deletions.
97 changes: 97 additions & 0 deletions libs/langgraph/tests/test_pregel_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -1962,6 +1962,103 @@ async def route_to_three(state) -> Literal["3"]:
assert await graph.ainvoke(["0"]) == ["0", "1", "2", "2", "3"]


async def test_concurrent_emit_sends() -> None:
class Node:
def __init__(self, name: str):
self.name = name
setattr(self, "__name__", name)

async def __call__(self, state):
return (
[self.name]
if isinstance(state, list)
else ["|".join((self.name, str(state)))]
)

async def send_for_fun(state):
return [Send("2", 1), Send("2", 2), "3.1"]

async def send_for_profit(state):
return [Send("2", 3), Send("2", 4)]

async def route_to_three(state) -> Literal["3"]:
return "3"

builder = StateGraph(Annotated[list, operator.add])
builder.add_node(Node("1"))
builder.add_node(Node("1.1"))
builder.add_node(Node("2"))
builder.add_node(Node("3"))
builder.add_node(Node("3.1"))
builder.add_edge(START, "1")
builder.add_edge(START, "1.1")
builder.add_conditional_edges("1", send_for_fun)
builder.add_conditional_edges("1.1", send_for_profit)
builder.add_conditional_edges("2", route_to_three)
graph = builder.compile()
assert await graph.ainvoke(["0"]) == [
"0",
"1",
"1.1",
"2|1",
"2|2",
"2|3",
"2|4",
"3.1",
"3",
]


async def test_send_sequences() -> None:
class Node:
def __init__(self, name: str):
self.name = name
setattr(self, "__name__", name)

async def __call__(self, state):
update = (
[self.name]
if isinstance(state, list) # or isinstance(state, Control)
else ["|".join((self.name, str(state)))]
)
if isinstance(state, Control):
state.update_state = update
return state
else:
return update

async def send_for_fun(state):
return [
Send("2", Control(send=Send("2", 3))),
Send("2", Control(send=Send("2", 4))),
"3.1",
]

async def route_to_three(state) -> Literal["3"]:
return "3"

builder = StateGraph(Annotated[list, operator.add])
builder.add_node(Node("1"))
builder.add_node(Node("2"))
builder.add_node(Node("3"))
builder.add_node(Node("3.1"))
builder.add_edge(START, "1")
builder.add_conditional_edges("1", send_for_fun)
builder.add_conditional_edges("2", route_to_three)
graph = builder.compile()
assert await graph.ainvoke(["0"]) == [
"0",
"1",
"2|Control(send=Send(node='2', arg=3))",
"2|Control(send=Send(node='2', arg=4))",
"3.1",
"2|3",
"2|4",
"3",
"3",
]


@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_max_concurrency(checkpointer_name: str) -> None:
class Node:
Expand Down

0 comments on commit 8138c88

Please sign in to comment.