Skip to content

Commit

Permalink
langgraph: fix caller_ns in remote graph (#2250)
Browse files Browse the repository at this point in the history
  • Loading branch information
vbarda authored Oct 31, 2024
1 parent 3641e65 commit 0718d0a
Showing 1 changed file with 14 additions and 11 deletions.
25 changes: 14 additions & 11 deletions libs/langgraph/langgraph/pregel/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
CONFIG_KEY_CHECKPOINT_NS,
CONFIG_KEY_STREAM,
INTERRUPT,
NS_SEP,
)
from langgraph.errors import GraphInterrupt
from langgraph.pregel.protocol import PregelProtocol
Expand Down Expand Up @@ -584,12 +585,13 @@ def stream(
stream_subgraphs=subgraphs or stream is not None,
if_not_exists="create",
):
if "|" in chunk.event:
mode, ns_ = chunk.event.split("|", 1)
ns = tuple(ns_.split("|"))
if NS_SEP in chunk.event:
mode, ns_ = chunk.event.split(NS_SEP, 1)
ns = tuple(ns_.split(NS_SEP))
else:
mode, ns = chunk.event, ()
if caller_ns := (config or {}).get(CONF, {}).get(CONFIG_KEY_CHECKPOINT_NS):
caller_ns = tuple(caller_ns.split(NS_SEP))
ns = caller_ns + ns
if stream is not None and chunk.event in stream.modes:
stream((ns, mode, chunk.data))
Expand All @@ -600,12 +602,12 @@ def stream(
continue
elif chunk.event.startswith("error"):
raise RemoteException(chunk.data)
if chunk.event.split("|", 1)[0] not in stream_modes:
if chunk.event.split(NS_SEP, 1)[0] not in stream_modes:
continue
if subgraphs:
if "|" in chunk.event:
mode, ns_ = chunk.event.split("|", 1)
ns = tuple(ns_.split("|"))
if NS_SEP in chunk.event:
mode, ns_ = chunk.event.split(NS_SEP, 1)
ns = tuple(ns_.split(NS_SEP))
else:
mode, ns = chunk.event, ()
if req_single:
Expand Down Expand Up @@ -666,12 +668,13 @@ async def astream(
stream_subgraphs=subgraphs or stream is not None,
if_not_exists="create",
):
if "|" in chunk.event:
mode, ns_ = chunk.event.split("|", 1)
ns = tuple(ns_.split("|"))
if NS_SEP in chunk.event:
mode, ns_ = chunk.event.split(NS_SEP, 1)
ns = tuple(ns_.split(NS_SEP))
else:
mode, ns = chunk.event, ()
if caller_ns := (config or {}).get(CONF, {}).get(CONFIG_KEY_CHECKPOINT_NS):
caller_ns = tuple(caller_ns.split(NS_SEP))
ns = caller_ns + ns
if stream is not None and chunk.event in stream.modes:
stream((ns, mode, chunk.data))
Expand All @@ -682,7 +685,7 @@ async def astream(
continue
elif chunk.event.startswith("error"):
raise RemoteException(chunk.data)
if chunk.event.split("|", 1)[0] not in stream_modes:
if chunk.event.split(NS_SEP, 1)[0] not in stream_modes:
continue
if subgraphs:
if req_single:
Expand Down

0 comments on commit 0718d0a

Please sign in to comment.