Skip to content

Commit

Permalink
Merge pull request #2333 from langchain-ai/nc/apply-writes-order
Browse files Browse the repository at this point in the history
lib: Enforce write application order in apply_writes
  • Loading branch information
nfcampos authored Nov 5, 2024
2 parents 90639e6 + dec0b7f commit 36e6b89
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 11 deletions.
4 changes: 2 additions & 2 deletions libs/langgraph/langgraph/pregel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -941,7 +941,7 @@ def update_state(
if not writers:
raise InvalidUpdateError(f"Node {as_node} has no writers")
writes: deque[tuple[str, Any]] = deque()
task = PregelTaskWrites(as_node, writes, [INTERRUPT])
task = PregelTaskWrites((), as_node, writes, [INTERRUPT])
task_id = str(uuid5(UUID(checkpoint["id"]), INTERRUPT))
run = RunnableSequence(*writers) if len(writers) > 1 else writers[0]
# execute task
Expand Down Expand Up @@ -1121,7 +1121,7 @@ async def aupdate_state(
if not writers:
raise InvalidUpdateError(f"Node {as_node} has no writers")
writes: deque[tuple[str, Any]] = deque()
task = PregelTaskWrites(as_node, writes, [INTERRUPT])
task = PregelTaskWrites((), as_node, writes, [INTERRUPT])
task_id = str(uuid5(UUID(checkpoint["id"]), INTERRUPT))
run = RunnableSequence(*writers) if len(writers) > 1 else writers[0]
# execute task
Expand Down
13 changes: 11 additions & 2 deletions libs/langgraph/langgraph/pregel/algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ class WritesProtocol(Protocol):
"""Protocol for objects containing writes to be applied to checkpoint.
Implemented by PregelTaskWrites and PregelExecutableTask."""

@property
def path(self) -> tuple[Union[str, int], ...]: ...

@property
def name(self) -> str: ...

Expand All @@ -81,6 +84,7 @@ class PregelTaskWrites(NamedTuple):
"""Simplest implementation of WritesProtocol, for usage with writes that
don't originate from a runnable task, eg. graph input, update_state, etc."""

path: tuple[Union[str, int], ...]
name: str
writes: Sequence[tuple[str, Any]]
triggers: Sequence[str]
Expand Down Expand Up @@ -190,6 +194,9 @@ def apply_writes(
"""Apply writes from a set of tasks (usually the tasks from a Pregel step)
to the checkpoint and channels, and return managed values writes to be applied
externally."""
# sort tasks on path
tasks = sorted(tasks, key=lambda t: t.path)

# update seen versions
for task in tasks:
checkpoint["versions_seen"].setdefault(task.name, {}).update(
Expand Down Expand Up @@ -444,7 +451,9 @@ def prepare_single_task(
checkpoint,
channels,
managed,
PregelTaskWrites(packet.node, writes, triggers),
PregelTaskWrites(
task_path, packet.node, writes, triggers
),
config,
),
CONFIG_KEY_STORE: (
Expand Down Expand Up @@ -552,7 +561,7 @@ def prepare_single_task(
checkpoint,
channels,
managed,
PregelTaskWrites(name, writes, triggers),
PregelTaskWrites(task_path, name, writes, triggers),
config,
),
CONFIG_KEY_STORE: (
Expand Down
5 changes: 4 additions & 1 deletion libs/langgraph/langgraph/pregel/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,10 @@ def _first(self, *, input_keys: Union[str, Sequence[str]]) -> None:
mv_writes = apply_writes(
self.checkpoint,
self.channels,
[*discard_tasks.values(), PregelTaskWrites(INPUT, input_writes, [])],
[
*discard_tasks.values(),
PregelTaskWrites((), INPUT, input_writes, []),
],
self.checkpointer_get_next_version,
)
assert not mv_writes, "Can't write to SharedValues in graph input"
Expand Down
6 changes: 3 additions & 3 deletions libs/langgraph/tests/test_pregel.py
Original file line number Diff line number Diff line change
Expand Up @@ -1787,11 +1787,11 @@ def route_to_three(state) -> Literal["3"]:
"0",
"1",
"1.1",
"3.1",
"2|1",
"2|2",
"2|3",
"2|4",
"3.1",
"3",
]

Expand Down Expand Up @@ -1836,13 +1836,13 @@ def route_to_three(state) -> Literal["3"]:
assert graph.invoke(["0"]) == [
"0",
"1",
"3.1",
"2|Control(send=Send(node='2', arg=3))",
"2|Control(send=Send(node='2', arg=4))",
"3.1",
"3",
"2|3",
"2|4",
"3",
"3",
]


Expand Down
6 changes: 3 additions & 3 deletions libs/langgraph/tests/test_pregel_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -2004,11 +2004,11 @@ async def route_to_three(state) -> Literal["3"]:
"0",
"1",
"1.1",
"3.1",
"2|1",
"2|2",
"2|3",
"2|4",
"3.1",
"3",
]

Expand Down Expand Up @@ -2053,13 +2053,13 @@ async def route_to_three(state) -> Literal["3"]:
assert await graph.ainvoke(["0"]) == [
"0",
"1",
"3.1",
"2|Control(send=Send(node='2', arg=3))",
"2|Control(send=Send(node='2', arg=4))",
"3.1",
"3",
"2|3",
"2|4",
"3",
"3",
]


Expand Down

0 comments on commit 36e6b89

Please sign in to comment.