Skip to content

Commit

Permalink
update tests & update/aupdate
Browse files Browse the repository at this point in the history
  • Loading branch information
vbarda committed Nov 4, 2024
1 parent c03f1c2 commit 902d5c6
Show file tree
Hide file tree
Showing 3 changed files with 165 additions and 10 deletions.
8 changes: 8 additions & 0 deletions libs/langgraph/langgraph/pregel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
CONF,
CONFIG_KEY_CHECKPOINT_NS,
CONFIG_KEY_CHECKPOINTER,
CONFIG_KEY_METADATA,
CONFIG_KEY_NODE_FINISHED,
CONFIG_KEY_READ,
CONFIG_KEY_RESUMING,
Expand Down Expand Up @@ -826,14 +827,17 @@ def update_state(
config,
{CONFIG_KEY_CHECKPOINT_NS: config[CONF].get(CONFIG_KEY_CHECKPOINT_NS, "")},
)
checkpoint_metadata = config.get(CONFIG_KEY_METADATA, {})
if saved:
checkpoint_config = patch_configurable(config, saved.config[CONF])
checkpoint_metadata = {**saved.metadata, **checkpoint_metadata}
# find last node that updated the state, if not provided
if values is None and as_node is None:
next_config = checkpointer.put(
checkpoint_config,
create_checkpoint(checkpoint, None, step),
{
**checkpoint_metadata,
"source": "update",
"step": step + 1,
"writes": {},
Expand Down Expand Up @@ -922,6 +926,7 @@ def update_state(
checkpoint_config,
checkpoint,
{
**checkpoint_metadata,
"source": "update",
"step": step + 1,
"writes": {as_node: values},
Expand Down Expand Up @@ -978,14 +983,17 @@ async def aupdate_state(
config,
{CONFIG_KEY_CHECKPOINT_NS: config[CONF].get(CONFIG_KEY_CHECKPOINT_NS, "")},
)
checkpoint_metadata = config.get(CONFIG_KEY_METADATA, {})
if saved:
checkpoint_config = patch_configurable(config, saved.config[CONF])
checkpoint_metadata = {**saved.metadata, **checkpoint_metadata}
# find last node that updated the state, if not provided
if values is None and as_node is None:
next_config = await checkpointer.aput(
checkpoint_config,
create_checkpoint(checkpoint, None, step),
{
**checkpoint_metadata,
"source": "update",
"step": step + 1,
"writes": {},
Expand Down
2 changes: 2 additions & 0 deletions libs/langgraph/tests/test_prebuilt.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ def test_no_modifier(request: pytest.FixtureRequest, checkpointer_name: str) ->
"source": "loop",
"writes": {"agent": {"messages": [AIMessage(content="hi?", id="0")]}},
"step": 1,
"thread_id": "123",
}
assert saved.pending_writes == []

Expand Down Expand Up @@ -189,6 +190,7 @@ async def test_no_modifier_async(checkpointer_name: str) -> None:
"source": "loop",
"writes": {"agent": {"messages": [AIMessage(content="hi?", id="0")]}},
"step": 1,
"thread_id": "123",
}
assert saved.pending_writes == []

Expand Down
Loading

0 comments on commit 902d5c6

Please sign in to comment.