Skip to content

Commit

Permalink
more code review
Browse files Browse the repository at this point in the history
  • Loading branch information
vbarda committed Nov 1, 2024
1 parent 0ad50e9 commit 2029765
Show file tree
Hide file tree
Showing 4 changed files with 132 additions and 98 deletions.
4 changes: 2 additions & 2 deletions libs/langgraph/langgraph/prebuilt/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""langgraph.prebuilt exposes a higher-level API for creating and executing agents and tools."""

Check notice on line 1 in libs/langgraph/langgraph/prebuilt/__init__.py

View workflow job for this annotation

GitHub Actions / benchmark

Benchmark results

......................................... fanout_to_subgraph_10x: Mean +- std dev: 48.1 ms +- 0.9 ms ......................................... fanout_to_subgraph_10x_sync: Mean +- std dev: 43.5 ms +- 0.5 ms ......................................... fanout_to_subgraph_10x_checkpoint: Mean +- std dev: 76.3 ms +- 1.7 ms ......................................... fanout_to_subgraph_10x_checkpoint_sync: Mean +- std dev: 83.6 ms +- 0.7 ms ......................................... WARNING: the benchmark result may be unstable * the maximum (724 ms) is 56% greater than the mean (466 ms) Try to rerun the benchmark with more runs, values and/or loops. Run 'python -m pyperf system tune' command to reduce the system jitter. Use pyperf stats, pyperf dump and pyperf hist to analyze results. Use --quiet option to hide these warnings. fanout_to_subgraph_100x: Mean +- std dev: 466 ms +- 26 ms ......................................... fanout_to_subgraph_100x_sync: Mean +- std dev: 426 ms +- 19 ms ......................................... fanout_to_subgraph_100x_checkpoint: Mean +- std dev: 781 ms +- 41 ms ......................................... fanout_to_subgraph_100x_checkpoint_sync: Mean +- std dev: 830 ms +- 18 ms ......................................... WARNING: the benchmark result may be unstable * the standard deviation (3.31 ms) is 11% of the mean (31.0 ms) Try to rerun the benchmark with more runs, values and/or loops. Run 'python -m pyperf system tune' command to reduce the system jitter. Use pyperf stats, pyperf dump and pyperf hist to analyze results. Use --quiet option to hide these warnings. react_agent_10x: Mean +- std dev: 31.0 ms +- 3.3 ms ......................................... react_agent_10x_sync: Mean +- std dev: 22.4 ms +- 1.6 ms ......................................... react_agent_10x_checkpoint: Mean +- std dev: 46.3 ms +- 3.2 ms ......................................... react_agent_10x_checkpoint_sync: Mean +- std dev: 36.8 ms +- 3.0 ms ......................................... react_agent_100x: Mean +- std dev: 316 ms +- 6 ms ......................................... react_agent_100x_sync: Mean +- std dev: 252 ms +- 2 ms ......................................... react_agent_100x_checkpoint: Mean +- std dev: 900 ms +- 12 ms ......................................... react_agent_100x_checkpoint_sync: Mean +- std dev: 807 ms +- 5 ms ......................................... wide_state_25x300: Mean +- std dev: 18.4 ms +- 0.4 ms ......................................... wide_state_25x300_sync: Mean +- std dev: 10.8 ms +- 0.2 ms ......................................... wide_state_25x300_checkpoint: Mean +- std dev: 279 ms +- 12 ms ......................................... wide_state_25x300_checkpoint_sync: Mean +- std dev: 269 ms +- 12 ms ......................................... wide_state_15x600: Mean +- std dev: 21.2 ms +- 0.4 ms ......................................... wide_state_15x600_sync: Mean +- std dev: 12.4 ms +- 0.1 ms ......................................... wide_state_15x600_checkpoint: Mean +- std dev: 479 ms +- 12 ms ......................................... wide_state_15x600_checkpoint_sync: Mean +- std dev: 464 ms +- 13 ms ......................................... wide_state_9x1200: Mean +- std dev: 21.2 ms +- 0.4 ms ......................................... wide_state_9x1200_sync: Mean +- std dev: 12.4 ms +- 0.1 ms ......................................... wide_state_9x1200_checkpoint: Mean +- std dev: 312 ms +- 13 ms ......................................... wide_state_9x1200_checkpoint_sync: Mean +- std dev: 300 ms +- 12 ms

Check notice on line 1 in libs/langgraph/langgraph/prebuilt/__init__.py

View workflow job for this annotation

GitHub Actions / benchmark

Comparison against main

+-----------------------------------------+---------+-----------------------+ | Benchmark | main | changes | +=========================================+=========+=======================+ | fanout_to_subgraph_100x_checkpoint | 815 ms | 781 ms: 1.04x faster | +-----------------------------------------+---------+-----------------------+ | wide_state_9x1200_checkpoint_sync | 309 ms | 300 ms: 1.03x faster | +-----------------------------------------+---------+-----------------------+ | wide_state_9x1200_checkpoint | 321 ms | 312 ms: 1.03x faster | +-----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_100x | 479 ms | 466 ms: 1.03x faster | +-----------------------------------------+---------+-----------------------+ | react_agent_100x_checkpoint | 922 ms | 900 ms: 1.03x faster | +-----------------------------------------+---------+-----------------------+ | wide_state_9x1200_sync | 12.7 ms | 12.4 ms: 1.02x faster | +-----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_10x_checkpoint_sync | 85.5 ms | 83.6 ms: 1.02x faster | +-----------------------------------------+---------+-----------------------+ | react_agent_100x_checkpoint_sync | 825 ms | 807 ms: 1.02x faster | +-----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_100x_checkpoint_sync | 848 ms | 830 ms: 1.02x faster | +-----------------------------------------+---------+-----------------------+ | react_agent_10x_checkpoint | 47.2 ms | 46.3 ms: 1.02x faster | +-----------------------------------------+---------+-----------------------+ | wide_state_9x1200 | 21.6 ms | 21.2 ms: 1.02x faster | +-----------------------------------------+---------+-----------------------+ | wide_state_15x600_checkpoint_sync | 473 ms | 464 ms: 1.02x faster | +-----------------------------------------+---------+-----------------------+ | wide_state_25x300_checkpoint | 284 ms | 279 ms: 1.02x faster | +-----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_10x_checkpoint | 77.5 ms | 76.3 ms: 1.02x faster | +-----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_10x_sync | 44.1 ms | 43.5 ms: 1.01x faster | +-----------------------------------------+---------+-----------------------+ | wide_state_25x300_checkpoint_sync | 273 ms | 269 ms: 1.01x faster | +-----------------------------------------+---------+-----------------------+ | wide_state_25x300_sync | 10.9 ms | 10.8 ms: 1.01x faster | +-----------------------------------------+---------+-----------------------+ | wide_state_15x600_checkpoint | 484 ms | 479 ms: 1.01x faster | +-----------------------------------------+---------+-----------------------+ | wide_state_25x300 | 18.5 ms | 18.4 ms: 1.01x faster | +-----------------------------------------+---------+-----------------------+ | react_agent_100x_sync | 254 ms | 252 ms: 1.01x faster | +-----------------------------------------+---------+-----------------------+ | react_agent_100x | 318 ms | 316 ms: 1.01x faster | +-----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_10x | 48.4 ms | 48.1 ms: 1.01x faster | +-----------------------------------------+---------+-----------------------+ | wide_state_15x600_sync | 12.5 ms | 12.4 ms: 1.01x faster | +-----------------------------------------+---------+-----------------------+ | Geometric mean | (ref) | 1.02x faster | +-----------------------------------------+---------+-----------------------+ Benchmark hidden because not significant (5): react_agent_10x_sync, react_agent_10x, wide_state_15x600, react_agent_1

from langgraph.prebuilt.chain_executor import create_chain_executor
from langgraph.prebuilt.chain import create_chain
from langgraph.prebuilt.chat_agent_executor import create_react_agent
from langgraph.prebuilt.tool_executor import ToolExecutor, ToolInvocation
from langgraph.prebuilt.tool_node import (
Expand All @@ -13,7 +13,7 @@

__all__ = [
"create_react_agent",
"create_chain_executor",
"create_chain",
"ToolExecutor",
"ToolInvocation",
"ToolNode",
Expand Down
64 changes: 64 additions & 0 deletions libs/langgraph/langgraph/prebuilt/chain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from typing import Any, Optional, Type, Union, cast

from langchain_core.runnables.base import Runnable, RunnableLike

from langgraph.graph.state import END, START, StateGraph


def _get_step_name(step: RunnableLike) -> str:
if isinstance(step, Runnable):
return step.get_name()
elif callable(step):
return getattr(step, "__name__", step.__class__.__name__)
else:
raise TypeError(f"Unsupported step type: {step}")


def create_chain(
*steps: Union[RunnableLike, tuple[str, RunnableLike]],
state_schema: Type[Any],
input_schema: Optional[Type[Any]] = None,
output_schema: Optional[Type[Any]] = None,
) -> StateGraph:
"""Creates a chain graph that runs a series of provided steps in order.
Args:
*steps: A sequence of RunnableLike objects (e.g. a LangChain Runnable or a callable) or (name, RunnableLike) tuples.
If no names are provided, the name will be inferred from the step object (e.g. a runnable or a callable name).
Each step will be executed in the order provided.
state_schema: The state schema for the graph.
input_schema: The input schema for the graph.
output_schema: The output schema for the graph. Will only be used when calling `graph.invoke()`.
Returns:
A StateGraph object.
"""
if len(steps) < 1:
raise ValueError("Chain requires at least one step.")

node_names = set()
builder = StateGraph(state_schema, input=input_schema, output=output_schema)
previous_name: Optional[str] = None
for step in steps:
if isinstance(step, tuple) and len(step) == 2:
name, step = step
else:
name = _get_step_name(step)

if name in node_names:
raise ValueError(
f"Step names must be unique: step with the name '{name}' already exists. "
"If you need to use two different runnables/callables with the same name (for example, using `lambda`), please provide them as tuples (name, runnable/callable)."
)

node_names.add(name)
builder.add_node(name, step)
if previous_name is None:
builder.add_edge(START, name)
else:
builder.add_edge(previous_name, name)

previous_name = name

builder.add_edge(cast(str, previous_name), END)
return builder
89 changes: 0 additions & 89 deletions libs/langgraph/langgraph/prebuilt/chain_executor.py

This file was deleted.

73 changes: 66 additions & 7 deletions libs/langgraph/tests/test_prebuilt.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,11 @@
from langgraph.prebuilt import (
ToolNode,
ValidationNode,
create_chain_executor,
create_chain,
create_react_agent,
tools_condition,
)
from langgraph.prebuilt.chain import _get_step_name
from langgraph.prebuilt.tool_node import (
TOOL_CALL_ERROR_TEMPLATE,
InjectedState,
Expand Down Expand Up @@ -1352,7 +1353,39 @@ def foo(a: str, b: int) -> float:
assert _get_state_args(foo) == {"a": None, "b": "bar"}


def test_chain_executor():
def test__get_step_name() -> None:
# default runnable name
assert _get_step_name(RunnableLambda(func=lambda x: x)) == "RunnableLambda"
# custom runnable name
assert (
_get_step_name(RunnableLambda(name="my_runnable", func=lambda x: x))
== "my_runnable"
)

# lambda
assert _get_step_name(lambda x: x) == "<lambda>"

# regular function
def func(state):
return

assert _get_step_name(func) == "func"

class MyClass:
def __call__(self, state):
return

def class_method(self, state):
return

# callable class
assert _get_step_name(MyClass()) == "MyClass"

# class method
assert _get_step_name(MyClass().class_method) == "class_method"


def test_chain():
class State(TypedDict):
foo: Annotated[list[str], operator.add]
bar: str
Expand All @@ -1365,17 +1398,18 @@ def step2(state: State):

# test raising if less than 1 steps
with pytest.raises(ValueError):
create_chain_executor(state_schema=State)
create_chain(state_schema=State)

# test raising if duplicate step names
with pytest.raises(ValueError):
create_chain_executor(step1, step1, state_schema=State)
create_chain(step1, step1, state_schema=State)

with pytest.raises(ValueError):
create_chain_executor(("foo", step1), ("foo", step1), state_schema=State)
create_chain(("foo", step1), ("foo", step1), state_schema=State)

# test unnamed steps
executor = create_chain_executor(step1, step2, state_schema=State)
builder = create_chain(step1, step2, state_schema=State)
executor = builder.compile()
result = executor.invoke({"foo": []})
assert result == {"foo": ["step1", "step2"], "bar": "baz"}
stream_chunks = list(executor.stream({"foo": []}))
Expand All @@ -1385,13 +1419,38 @@ def step2(state: State):
]

# test named steps
executor_named_steps = create_chain_executor(
builder_named_steps = create_chain(
("meow1", step1), ("meow2", step2), state_schema=State
)
executor_named_steps = builder_named_steps.compile()
result = executor_named_steps.invoke({"foo": []})
stream_chunks = list(executor_named_steps.stream({"foo": []}))
assert result == {"foo": ["step1", "step2"], "bar": "baz"}
assert stream_chunks == [
{"meow1": {"foo": ["step1"], "bar": "baz"}},
{"meow2": {"foo": ["step2"]}},
]

# test input/output schema & functions w/ duplicate names
class Input(TypedDict):
foo: Annotated[list[str], operator.add]

class Output(TypedDict):
bar: str

builder_named_steps = create_chain(
("meow1", lambda state: {"foo": ["foo"]}),
("meow2", lambda state: {"bar": state["foo"][0] + "bar"}),
state_schema=State,
input_schema=Input,
output_schema=Output,
)
executor_named_steps = builder_named_steps.compile()
result = executor_named_steps.invoke({"foo": []})
stream_chunks = list(executor_named_steps.stream({"foo": []}))
# filtered by output schema
assert result == {"bar": "foobar"}
assert stream_chunks == [
{"meow1": {"foo": ["foo"]}},
{"meow2": {"bar": "foobar"}},
]

0 comments on commit 2029765

Please sign in to comment.