Skip to content

Commit

Permalink
Remove node output schemas
Browse files Browse the repository at this point in the history
  • Loading branch information
nfcampos committed Jul 18, 2024
1 parent 7bc489f commit b5b0f8d
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 47 deletions.
56 changes: 25 additions & 31 deletions libs/langgraph/langgraph/graph/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ class StateNodeSpec(NamedTuple):
runnable: Runnable
metadata: dict[str, Any]
input: Type[Any]
output: Type[Any]


class StateGraph(Graph):
Expand Down Expand Up @@ -136,6 +135,7 @@ def __init__(
if state_schema is None:
if input is None or output is None:
raise ValueError("Must provide state_schema or input and output")
state_schema = input
else:
if input is None:
input = state_schema
Expand Down Expand Up @@ -167,10 +167,12 @@ def _add_schema(self, schema: Type[Any]) -> None:
for key, channel in channels.items():
if key in self.channels:
if self.channels[key] != channel:
print(self.channels[key], channel)
raise ValueError(
f"Channel '{key}' already exists with a different type"
)
if isinstance(channel, LastValue):
pass
else:
raise ValueError(
f"Channel '{key}' already exists with a different type"
)
else:
self.channels[key] = channel
for key, managed in managed.items():
Expand All @@ -193,7 +195,6 @@ def add_node(
*,
metadata: Optional[dict[str, Any]] = None,
input: Optional[Type[Any]] = None,
output: Optional[Type[Any]] = None,
) -> None:
"""Adds a new node to the state graph.
Will take the name of the function/runnable as the node name.
Expand All @@ -217,7 +218,6 @@ def add_node(
*,
metadata: Optional[dict[str, Any]] = None,
input: Optional[Type[Any]] = None,
output: Optional[Type[Any]] = None,
) -> None:
"""Adds a new node to the state graph.
Expand All @@ -240,7 +240,6 @@ def add_node(
*,
metadata: Optional[dict[str, Any]] = None,
input: Optional[Type[Any]] = None,
output: Optional[Type[Any]] = None,
) -> None:
"""Adds a new node to the state graph.
Expand All @@ -249,6 +248,8 @@ def add_node(
Args:
node (Union[str, RunnableLike)]: The function or runnable this node will run.
action (Optional[RunnableLike]): The action associated with the node. (default: None)
metadata (Optional[dict[str, Any]]): The metadata associated with the node. (default: None)
input (Optional[Type[Any]]): The input schema for the node. (default: the graph's input schema)
Raises:
ValueError: If the key is already being used as a state key.
Expand Down Expand Up @@ -313,21 +314,14 @@ def add_node(
input_hint = hints[list(hints.keys())[0]]
if isinstance(input_hint, type) and get_type_hints(input_hint):
input = input_hint
if output is None:
output_hint = hints.get("return", Any)
if isinstance(output_hint, type) and get_type_hints(output_hint):
output = output_hint
except TypeError:
pass
if input is not None:
self._add_schema(input)
if output is not None:
self._add_schema(output)
self.nodes[node] = StateNodeSpec(
coerce_to_runnable(action, name=node, trace=False),
metadata,
input=input or self.schema,
output=output or self.schema,
)

def add_edge(self, start_key: Union[str, list[str]], end_key: str) -> None:
Expand Down Expand Up @@ -482,24 +476,17 @@ def get_output_schema(

def attach_node(self, key: str, node: Optional[StateNodeSpec]) -> None:
if key == START:
input_schema = self.builder.input
output_keys = [
k
for k, v in self.builder.schemas[self.builder.input].items()
if not isinstance(v, Context) and not is_managed_value(v)
]
else:
input_schema = node.input if node else self.builder.schema
input_values = {
k: v if is_managed_value(v) else k
for k, v in self.builder.schemas[input_schema].items()
}
is_single_input = len(input_values) == 1 and "__root__" in input_values

output_keys = [
k
for k, v in self.builder.schemas[
node.output if node else self.builder.schema
].items()
if not is_managed_value(v)
]
output_keys = list(self.builder.channels)

def _get_state_key(input: dict, config: RunnableConfig, *, key: str) -> Any:
def _get_state_key(
input: Union[None, dict, Any], config: RunnableConfig, *, key: str
) -> Any:
if input is None:
return SKIP_WRITE
elif isinstance(input, dict):
Expand Down Expand Up @@ -540,6 +527,13 @@ def _get_state_key(input: dict, config: RunnableConfig, *, key: str) -> Any:
],
)
else:
input_schema = node.input if node else self.builder.schema
input_values = {
k: v if is_managed_value(v) else k
for k, v in self.builder.schemas[input_schema].items()
}
is_single_input = len(input_values) == 1 and "__root__" in input_values

self.channels[key] = EphemeralValue(Any, guard=False)
self.nodes[key] = PregelNode(
triggers=[],
Expand Down
35 changes: 28 additions & 7 deletions libs/langgraph/tests/test_pregel.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,14 +232,17 @@ def logic(inp: str) -> str:
graph.invoke("", {"configurable": {"thread_id": "thread-1"}})


def test_node_schemas() -> None:
def test_node_schemas_custom_output() -> None:
from langchain_core.messages import HumanMessage

class State(TypedDict):
hello: str
bye: str
messages: Annotated[list[str], add_messages]

class Output(TypedDict):
messages: list[str]

class StateForA(TypedDict):
hello: str
messages: Annotated[list[str], add_messages]
Expand All @@ -254,14 +257,14 @@ class StateForB(TypedDict):
bye: str
now: int

def node_b(state: StateForB) -> StateForB:
def node_b(state: StateForB):
assert state == {
"bye": "world",
"now": None,
}
return {
"now": 123,
"hello": "again", # ignored because not in output schema
"hello": "again",
}

class StateForC(TypedDict):
Expand All @@ -270,11 +273,11 @@ class StateForC(TypedDict):

def node_c(state: StateForC) -> StateForC:
assert state == {
"hello": "there",
"hello": "again",
"now": 123,
}

builder = StateGraph(State)
builder = StateGraph(State, output=Output)
builder.add_node("a", node_a)
builder.add_node("b", node_b)
builder.add_node("c", node_c)
Expand All @@ -284,8 +287,26 @@ def node_c(state: StateForC) -> StateForC:
graph = builder.compile()

assert graph.invoke({"hello": "there", "bye": "world", "messages": "hello"}) == {
"hello": "there",
"bye": "world",
"messages": [HumanMessage(content="hello", id=AnyStr())],
}

builder = StateGraph(input=State, output=Output)
builder.add_node("a", node_a)
builder.add_node("b", node_b)
builder.add_node("c", node_c)
builder.add_edge(START, "a")
builder.add_edge("a", "b")
builder.add_edge("b", "c")
graph = builder.compile()

assert graph.invoke(
{
"hello": "there",
"bye": "world",
"messages": "hello",
"now": 345, # ignored because not in input schema
}
) == {
"messages": [HumanMessage(content="hello", id=AnyStr())],
}

Expand Down
39 changes: 30 additions & 9 deletions libs/langgraph/tests/test_pregel_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,19 +389,22 @@ async def alittlewhile(input: State) -> None:
await checkpointer.__aexit__(None, None, None)


async def test_node_schemas() -> None:
async def test_node_schemas_custom_output() -> None:
from langchain_core.messages import HumanMessage

class State(TypedDict):
hello: str
bye: str
messages: Annotated[list[str], add_messages]

class Output(TypedDict):
messages: list[str]

class StateForA(TypedDict):
hello: str
messages: Annotated[list[str], add_messages]

async def node_a(state: StateForA) -> State:
async def node_a(state: StateForA):
assert state == {
"hello": "there",
"messages": [HumanMessage(content="hello", id=AnyStr())],
Expand All @@ -411,27 +414,27 @@ class StateForB(TypedDict):
bye: str
now: int

async def node_b(state: StateForB) -> StateForB:
async def node_b(state: StateForB):
assert state == {
"bye": "world",
"now": None,
}
return {
"now": 123,
"hello": "again", # ignored because not in output schema
"hello": "again",
}

class StateForC(TypedDict):
hello: str
now: int

async def node_c(state: StateForC) -> StateForC:
async def node_c(state: StateForC):
assert state == {
"hello": "there",
"hello": "again",
"now": 123,
}

builder = StateGraph(State)
builder = StateGraph(State, output=Output)
builder.add_node("a", node_a)
builder.add_node("b", node_b)
builder.add_node("c", node_c)
Expand All @@ -443,8 +446,26 @@ async def node_c(state: StateForC) -> StateForC:
assert await graph.ainvoke(
{"hello": "there", "bye": "world", "messages": "hello"}
) == {
"hello": "there",
"bye": "world",
"messages": [HumanMessage(content="hello", id=AnyStr())],
}

builder = StateGraph(input=State, output=Output)
builder.add_node("a", node_a)
builder.add_node("b", node_b)
builder.add_node("c", node_c)
builder.add_edge(START, "a")
builder.add_edge("a", "b")
builder.add_edge("b", "c")
graph = builder.compile()

assert await graph.ainvoke(
{
"hello": "there",
"bye": "world",
"messages": "hello",
"now": 345, # ignored because not in input schema
}
) == {
"messages": [HumanMessage(content="hello", id=AnyStr())],
}

Expand Down

0 comments on commit b5b0f8d

Please sign in to comment.