Skip to content

Commit

Permalink
Merge pull request #2329 from langchain-ai/nc/4nov/get-state-apply-pe…
Browse files Browse the repository at this point in the history
…nding-writes

lib: In calls to get_state apply pending writes
  • Loading branch information
nfcampos authored Nov 4, 2024
2 parents 5e4c928 + de3b654 commit 1b85764
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 7 deletions.
35 changes: 31 additions & 4 deletions libs/langgraph/langgraph/pregel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
)
from langgraph.constants import (
CONF,
CONFIG_KEY_CHECKPOINT_ID,
CONFIG_KEY_CHECKPOINT_NS,
CONFIG_KEY_CHECKPOINTER,
CONFIG_KEY_NODE_FINISHED,
Expand Down Expand Up @@ -439,6 +440,7 @@ def _prepare_state_snapshot(
config: RunnableConfig,
saved: Optional[CheckpointTuple],
recurse: Optional[BaseCheckpointSaver] = None,
apply_pending_writes: bool = False,
) -> StateSnapshot:
if not saved:
return StateSnapshot(
Expand Down Expand Up @@ -469,7 +471,10 @@ def _prepare_state_snapshot(
managed,
saved.config,
saved.metadata.get("step", -1) + 1,
for_execution=False,
for_execution=True,
store=self.store,
checkpointer=self.checkpointer or None,
manager=None,
)
# get the subgraphs
subgraphs = dict(self.get_subgraphs())
Expand Down Expand Up @@ -503,6 +508,12 @@ def _prepare_state_snapshot(
task_states[task.id] = subgraphs[task.name].get_state(
config, subgraphs=True
)
# apply pending writes
if apply_pending_writes and saved.pending_writes:
for tid, *t in saved.pending_writes:
next_tasks[tid].writes.append(t) # type: ignore[arg-type]
if tasks := [t for t in next_tasks.values() if t.writes]:
apply_writes(saved.checkpoint, channels, tasks, None)
# assemble the state snapshot
return StateSnapshot(
read_channels(channels, self.stream_channels_asis),
Expand All @@ -524,6 +535,7 @@ async def _aprepare_state_snapshot(
config: RunnableConfig,
saved: Optional[CheckpointTuple],
recurse: Optional[BaseCheckpointSaver] = None,
apply_pending_writes: bool = False,
) -> StateSnapshot:
if not saved:
return StateSnapshot(
Expand Down Expand Up @@ -557,7 +569,10 @@ async def _aprepare_state_snapshot(
managed,
saved.config,
saved.metadata.get("step", -1) + 1,
for_execution=False,
for_execution=True,
store=self.store,
checkpointer=self.checkpointer or None,
manager=None,
)
# get the subgraphs
subgraphs = {n: g async for n, g in self.aget_subgraphs()}
Expand Down Expand Up @@ -591,6 +606,12 @@ async def _aprepare_state_snapshot(
task_states[task.id] = await subgraphs[task.name].aget_state(
config, subgraphs=True
)
# apply pending writes
if apply_pending_writes and saved.pending_writes:
for tid, *t in saved.pending_writes:
next_tasks[tid].writes.append(t) # type: ignore[arg-type]
if tasks := [t for t in next_tasks.values() if t.writes]:
apply_writes(saved.checkpoint, channels, tasks, None)
# assemble the state snapshot
return StateSnapshot(
read_channels(channels, self.stream_channels_asis),
Expand Down Expand Up @@ -638,7 +659,10 @@ def get_state(
config = merge_configs(self.config, config) if self.config else config
saved = checkpointer.get_tuple(config)
return self._prepare_state_snapshot(
config, saved, recurse=checkpointer if subgraphs else None
config,
saved,
recurse=checkpointer if subgraphs else None,
apply_pending_writes=CONFIG_KEY_CHECKPOINT_ID not in config[CONF],
)

async def aget_state(
Expand Down Expand Up @@ -672,7 +696,10 @@ async def aget_state(
config = merge_configs(self.config, config) if self.config else config
saved = await checkpointer.aget_tuple(config)
return await self._aprepare_state_snapshot(
config, saved, recurse=checkpointer if subgraphs else None
config,
saved,
recurse=checkpointer if subgraphs else None,
apply_pending_writes=CONFIG_KEY_CHECKPOINT_ID not in config[CONF],
)

def get_state_history(
Expand Down
7 changes: 6 additions & 1 deletion libs/langgraph/tests/test_pregel.py
Original file line number Diff line number Diff line change
Expand Up @@ -1513,9 +1513,10 @@ def reset(self):
assert two.calls == 2 # two attempts

# latest checkpoint should be before nodes "one", "two"
# but we should have applied the write from "one"
state = graph.get_state(thread1)
assert state is not None
assert state.values == {"value": 1}
assert state.values == {"value": 3}
assert state.next == ("one", "two")
assert state.tasks == (
PregelTask(AnyStr(), "one", (PULL, "one"), result={"value": 2}),
Expand All @@ -1528,6 +1529,10 @@ def reset(self):
"writes": None,
"thread_id": "1",
}
# get_state with checkpoint_id should not apply any pending writes
state = graph.get_state(state.config)
assert state is not None
assert state.values == {"value": 1}
# should contain pending write of "one"
checkpoint = checkpointer.get_tuple(thread1)
assert checkpoint is not None
Expand Down
12 changes: 10 additions & 2 deletions libs/langgraph/tests/test_pregel_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,10 +482,11 @@ async def alittlewhile(input: State) -> None:
assert awhile.started is False

# checkpoint with output of "alittlewhile" should not be saved
# but we should have applied pending writes
if checkpointer is not None:
state = await graph.aget_state(thread1)
assert state is not None
assert state.values == {"value": 1}
assert state.values == {"value": 3} # 1 + 2
assert state.next == (
"aparallelwhile",
"alittlewhile",
Expand Down Expand Up @@ -1722,9 +1723,10 @@ def reset(self):
assert two.calls == 2

# latest checkpoint should be before nodes "one", "two"
# but we should have applied pending writes from "one"
state = await graph.aget_state(thread1)
assert state is not None
assert state.values == {"value": 1}
assert state.values == {"value": 3}
assert state.next == ("one", "two")
assert state.tasks == (
PregelTask(AnyStr(), "one", (PULL, "one"), result={"value": 2}),
Expand All @@ -1742,6 +1744,10 @@ def reset(self):
"writes": None,
"thread_id": "1",
}
# get_state with checkpoint_id should not apply any pending writes
state = await graph.aget_state(state.config)
assert state is not None
assert state.values == {"value": 1}
# should contain pending write of "one"
checkpoint = await checkpointer.aget_tuple(thread1)
assert checkpoint is not None
Expand Down Expand Up @@ -2117,6 +2123,8 @@ async def route_to_three(state) -> Literal["3"]:
thread1 = {"max_concurrency": 10, "configurable": {"thread_id": "1"}}

assert await graph.ainvoke(["0"], thread1) == ["0", "1"]
state = await graph.aget_state(thread1)
assert state.values == ["0", "1"]
assert await graph.ainvoke(None, thread1) == ["0", "1", *range(100), "3"]


Expand Down

0 comments on commit 1b85764

Please sign in to comment.