Skip to content

Commit

Permalink
Merge pull request #2327 from langchain-ai/nc/4nov/send-tests
Browse files Browse the repository at this point in the history
lib: Add two more tests for Send
  • Loading branch information
nfcampos authored Nov 4, 2024
2 parents 895079b + 8138c88 commit 5e4c928
Show file tree
Hide file tree
Showing 3 changed files with 201 additions and 1 deletion.
6 changes: 6 additions & 0 deletions libs/langgraph/langgraph/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,12 @@ def __init__(
self.trigger = trigger
self.send = send

def __repr__(self) -> str:
contents = ", ".join(
f"{key}={value!r}" for key, value in self.__dict__.items() if value
)
return f"Control({contents})"


StreamChunk = tuple[tuple[str, ...], str, Any]

Expand Down
99 changes: 98 additions & 1 deletion libs/langgraph/tests/test_pregel.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@
from langgraph.pregel.retry import RetryPolicy
from langgraph.store.base import BaseStore
from langgraph.store.memory import InMemoryStore
from langgraph.types import Interrupt, PregelTask, Send, StreamWriter
from langgraph.types import Control, Interrupt, PregelTask, Send, StreamWriter
from tests.any_str import AnyDict, AnyStr, AnyVersion, FloatBetween, UnsortedSequence
from tests.conftest import (
ALL_CHECKPOINTERS_SYNC,
Expand Down Expand Up @@ -1743,6 +1743,103 @@ def route_to_three(state) -> Literal["3"]:
assert graph.invoke(["0"]) == ["0", "1", "2", "2", "3"]


def test_concurrent_emit_sends() -> None:
class Node:
def __init__(self, name: str):
self.name = name
setattr(self, "__name__", name)

def __call__(self, state):
return (
[self.name]
if isinstance(state, list)
else ["|".join((self.name, str(state)))]
)

def send_for_fun(state):
return [Send("2", 1), Send("2", 2), "3.1"]

def send_for_profit(state):
return [Send("2", 3), Send("2", 4)]

def route_to_three(state) -> Literal["3"]:
return "3"

builder = StateGraph(Annotated[list, operator.add])
builder.add_node(Node("1"))
builder.add_node(Node("1.1"))
builder.add_node(Node("2"))
builder.add_node(Node("3"))
builder.add_node(Node("3.1"))
builder.add_edge(START, "1")
builder.add_edge(START, "1.1")
builder.add_conditional_edges("1", send_for_fun)
builder.add_conditional_edges("1.1", send_for_profit)
builder.add_conditional_edges("2", route_to_three)
graph = builder.compile()
assert graph.invoke(["0"]) == [
"0",
"1",
"1.1",
"2|1",
"2|2",
"2|3",
"2|4",
"3.1",
"3",
]


def test_send_sequences() -> None:
class Node:
def __init__(self, name: str):
self.name = name
setattr(self, "__name__", name)

def __call__(self, state):
update = (
[self.name]
if isinstance(state, list) # or isinstance(state, Control)
else ["|".join((self.name, str(state)))]
)
if isinstance(state, Control):
state.update_state = update
return state
else:
return update

def send_for_fun(state):
return [
Send("2", Control(send=Send("2", 3))),
Send("2", Control(send=Send("2", 4))),
"3.1",
]

def route_to_three(state) -> Literal["3"]:
return "3"

builder = StateGraph(Annotated[list, operator.add])
builder.add_node(Node("1"))
builder.add_node(Node("2"))
builder.add_node(Node("3"))
builder.add_node(Node("3.1"))
builder.add_edge(START, "1")
builder.add_conditional_edges("1", send_for_fun)
builder.add_conditional_edges("2", route_to_three)
graph = builder.compile()
assert graph.invoke(["0"]) == [
"0",
"1",
"2|Control(send=Send(node='2', arg=3))",
"2|Control(send=Send(node='2', arg=4))",
"3.1",
"2|3",
"2|4",
"3",
"3",
]


@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
def test_invoke_checkpoint_three(
mocker: MockerFixture, request: pytest.FixtureRequest, checkpointer_name: str
Expand Down
97 changes: 97 additions & 0 deletions libs/langgraph/tests/test_pregel_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -1962,6 +1962,103 @@ async def route_to_three(state) -> Literal["3"]:
assert await graph.ainvoke(["0"]) == ["0", "1", "2", "2", "3"]


async def test_concurrent_emit_sends() -> None:
class Node:
def __init__(self, name: str):
self.name = name
setattr(self, "__name__", name)

async def __call__(self, state):
return (
[self.name]
if isinstance(state, list)
else ["|".join((self.name, str(state)))]
)

async def send_for_fun(state):
return [Send("2", 1), Send("2", 2), "3.1"]

async def send_for_profit(state):
return [Send("2", 3), Send("2", 4)]

async def route_to_three(state) -> Literal["3"]:
return "3"

builder = StateGraph(Annotated[list, operator.add])
builder.add_node(Node("1"))
builder.add_node(Node("1.1"))
builder.add_node(Node("2"))
builder.add_node(Node("3"))
builder.add_node(Node("3.1"))
builder.add_edge(START, "1")
builder.add_edge(START, "1.1")
builder.add_conditional_edges("1", send_for_fun)
builder.add_conditional_edges("1.1", send_for_profit)
builder.add_conditional_edges("2", route_to_three)
graph = builder.compile()
assert await graph.ainvoke(["0"]) == [
"0",
"1",
"1.1",
"2|1",
"2|2",
"2|3",
"2|4",
"3.1",
"3",
]


async def test_send_sequences() -> None:
class Node:
def __init__(self, name: str):
self.name = name
setattr(self, "__name__", name)

async def __call__(self, state):
update = (
[self.name]
if isinstance(state, list) # or isinstance(state, Control)
else ["|".join((self.name, str(state)))]
)
if isinstance(state, Control):
state.update_state = update
return state
else:
return update

async def send_for_fun(state):
return [
Send("2", Control(send=Send("2", 3))),
Send("2", Control(send=Send("2", 4))),
"3.1",
]

async def route_to_three(state) -> Literal["3"]:
return "3"

builder = StateGraph(Annotated[list, operator.add])
builder.add_node(Node("1"))
builder.add_node(Node("2"))
builder.add_node(Node("3"))
builder.add_node(Node("3.1"))
builder.add_edge(START, "1")
builder.add_conditional_edges("1", send_for_fun)
builder.add_conditional_edges("2", route_to_three)
graph = builder.compile()
assert await graph.ainvoke(["0"]) == [
"0",
"1",
"2|Control(send=Send(node='2', arg=3))",
"2|Control(send=Send(node='2', arg=4))",
"3.1",
"2|3",
"2|4",
"3",
"3",
]


@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_max_concurrency(checkpointer_name: str) -> None:
class Node:
Expand Down

0 comments on commit 5e4c928

Please sign in to comment.