diff --git a/libs/langgraph/langgraph/pregel/__init__.py b/libs/langgraph/langgraph/pregel/__init__.py index 666560024..d69fbc75f 100644 --- a/libs/langgraph/langgraph/pregel/__init__.py +++ b/libs/langgraph/langgraph/pregel/__init__.py @@ -941,7 +941,7 @@ def update_state( if not writers: raise InvalidUpdateError(f"Node {as_node} has no writers") writes: deque[tuple[str, Any]] = deque() - task = PregelTaskWrites(as_node, writes, [INTERRUPT]) + task = PregelTaskWrites((), as_node, writes, [INTERRUPT]) task_id = str(uuid5(UUID(checkpoint["id"]), INTERRUPT)) run = RunnableSequence(*writers) if len(writers) > 1 else writers[0] # execute task @@ -1121,7 +1121,7 @@ async def aupdate_state( if not writers: raise InvalidUpdateError(f"Node {as_node} has no writers") writes: deque[tuple[str, Any]] = deque() - task = PregelTaskWrites(as_node, writes, [INTERRUPT]) + task = PregelTaskWrites((), as_node, writes, [INTERRUPT]) task_id = str(uuid5(UUID(checkpoint["id"]), INTERRUPT)) run = RunnableSequence(*writers) if len(writers) > 1 else writers[0] # execute task diff --git a/libs/langgraph/langgraph/pregel/algo.py b/libs/langgraph/langgraph/pregel/algo.py index 4a2455aa1..af71294ae 100644 --- a/libs/langgraph/langgraph/pregel/algo.py +++ b/libs/langgraph/langgraph/pregel/algo.py @@ -67,6 +67,9 @@ class WritesProtocol(Protocol): """Protocol for objects containing writes to be applied to checkpoint. Implemented by PregelTaskWrites and PregelExecutableTask.""" + @property + def path(self) -> tuple[Union[str, int], ...]: ... + @property def name(self) -> str: ... @@ -81,6 +84,7 @@ class PregelTaskWrites(NamedTuple): """Simplest implementation of WritesProtocol, for usage with writes that don't originate from a runnable task, eg. graph input, update_state, etc.""" + path: tuple[Union[str, int], ...] name: str writes: Sequence[tuple[str, Any]] triggers: Sequence[str] @@ -190,6 +194,9 @@ def apply_writes( """Apply writes from a set of tasks (usually the tasks from a Pregel step) to the checkpoint and channels, and return managed values writes to be applied externally.""" + # sort tasks on path + tasks = sorted(tasks, key=lambda t: t.path) + # update seen versions for task in tasks: checkpoint["versions_seen"].setdefault(task.name, {}).update( @@ -444,7 +451,9 @@ def prepare_single_task( checkpoint, channels, managed, - PregelTaskWrites(packet.node, writes, triggers), + PregelTaskWrites( + task_path, packet.node, writes, triggers + ), config, ), CONFIG_KEY_STORE: ( @@ -552,7 +561,7 @@ def prepare_single_task( checkpoint, channels, managed, - PregelTaskWrites(name, writes, triggers), + PregelTaskWrites(task_path, name, writes, triggers), config, ), CONFIG_KEY_STORE: ( diff --git a/libs/langgraph/langgraph/pregel/loop.py b/libs/langgraph/langgraph/pregel/loop.py index 14254da75..4cfd550fc 100644 --- a/libs/langgraph/langgraph/pregel/loop.py +++ b/libs/langgraph/langgraph/pregel/loop.py @@ -477,7 +477,10 @@ def _first(self, *, input_keys: Union[str, Sequence[str]]) -> None: mv_writes = apply_writes( self.checkpoint, self.channels, - [*discard_tasks.values(), PregelTaskWrites(INPUT, input_writes, [])], + [ + *discard_tasks.values(), + PregelTaskWrites((), INPUT, input_writes, []), + ], self.checkpointer_get_next_version, ) assert not mv_writes, "Can't write to SharedValues in graph input" diff --git a/libs/langgraph/tests/test_pregel.py b/libs/langgraph/tests/test_pregel.py index dd3a9f714..0ec87b881 100644 --- a/libs/langgraph/tests/test_pregel.py +++ b/libs/langgraph/tests/test_pregel.py @@ -1787,11 +1787,11 @@ def route_to_three(state) -> Literal["3"]: "0", "1", "1.1", + "3.1", "2|1", "2|2", "2|3", "2|4", - "3.1", "3", ] @@ -1836,13 +1836,13 @@ def route_to_three(state) -> Literal["3"]: assert graph.invoke(["0"]) == [ "0", "1", + "3.1", "2|Control(send=Send(node='2', arg=3))", "2|Control(send=Send(node='2', arg=4))", - "3.1", + "3", "2|3", "2|4", "3", - "3", ] diff --git a/libs/langgraph/tests/test_pregel_async.py b/libs/langgraph/tests/test_pregel_async.py index 87d2679e3..b73f0a5ee 100644 --- a/libs/langgraph/tests/test_pregel_async.py +++ b/libs/langgraph/tests/test_pregel_async.py @@ -2004,11 +2004,11 @@ async def route_to_three(state) -> Literal["3"]: "0", "1", "1.1", + "3.1", "2|1", "2|2", "2|3", "2|4", - "3.1", "3", ] @@ -2053,13 +2053,13 @@ async def route_to_three(state) -> Literal["3"]: assert await graph.ainvoke(["0"]) == [ "0", "1", + "3.1", "2|Control(send=Send(node='2', arg=3))", "2|Control(send=Send(node='2', arg=4))", - "3.1", + "3", "2|3", "2|4", "3", - "3", ]