Skip to content

Commit

Permalink
Merge pull request #1826 from langchain-ai/nc/24sep/get-subgraphs-nam…
Browse files Browse the repository at this point in the history
…espace

Add namesapce filter to get_subgraphs
  • Loading branch information
nfcampos authored Sep 24, 2024
2 parents 5235ad0 + 8205c85 commit 2f7826e
Showing 1 changed file with 75 additions and 63 deletions.
138 changes: 75 additions & 63 deletions libs/langgraph/langgraph/pregel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,13 +382,23 @@ def stream_channels_asis(self) -> Union[str, Sequence[str]]:
k for k in self.channels if isinstance(self.channels[k], BaseChannel)
]

def get_subgraphs(self, recurse: bool = False) -> Iterator[tuple[str, Pregel]]:
def get_subgraphs(
self, *, namespace: Optional[str] = None, recurse: bool = False
) -> Iterator[tuple[str, Pregel]]:
for name, node in self.nodes.items():
# filter by prefix
if namespace is not None:
if not namespace.startswith(name):
continue
# find the subgraph, if any
graph: Optional[Pregel] = None
candidates = [node.bound]
for candidate in candidates:
if isinstance(candidate, Pregel):
if (
isinstance(candidate, Pregel)
# subgraphs that disabled checkpointing are not considered
and candidate.checkpointer is not False
):
graph = candidate
break
elif isinstance(candidate, RunnableSequence):
Expand All @@ -408,17 +418,25 @@ def get_subgraphs(self, recurse: bool = False) -> Iterator[tuple[str, Pregel]]:
)
# if found, yield recursively
if graph:
yield name, graph
if name == namespace:
yield name, graph
return # we found it, stop searching
if namespace is None:
yield name, graph
if recurse:
if namespace is not None:
namespace = namespace[len(name) + 1 :]
yield from (
(f"{name}{NS_SEP}{n}", s)
for n, s in graph.get_subgraphs(recurse=recurse)
for n, s in graph.get_subgraphs(
namespace=namespace, recurse=recurse
)
)

async def aget_subgraphs(
self, recurse: bool = False
self, *, namespace: Optional[str] = None, recurse: bool = False
) -> AsyncIterator[tuple[str, Pregel]]:
for name, node in self.get_subgraphs(recurse=recurse):
for name, node in self.get_subgraphs(namespace=namespace, recurse=recurse):
yield name, node

def _prepare_state_snapshot(
Expand Down Expand Up @@ -588,14 +606,13 @@ def get_state(
part.split(NS_END)[0] for part in checkpoint_ns.split(NS_SEP)
)
# find the subgraph with the matching name
for name, pregel in self.get_subgraphs(recurse=True):
if name == recast_checkpoint_ns:
return pregel.get_state(
patch_configurable(
config, {CONFIG_KEY_CHECKPOINTER: checkpointer}
),
subgraphs=subgraphs,
)
for _, pregel in self.get_subgraphs(
namespace=recast_checkpoint_ns, recurse=True
):
return pregel.get_state(
patch_configurable(config, {CONFIG_KEY_CHECKPOINTER: checkpointer}),
subgraphs=subgraphs,
)
else:
raise ValueError(f"Subgraph {recast_checkpoint_ns} not found")

Expand Down Expand Up @@ -623,14 +640,13 @@ async def aget_state(
part.split(NS_END)[0] for part in checkpoint_ns.split(NS_SEP)
)
# find the subgraph with the matching name
async for name, pregel in self.aget_subgraphs(recurse=True):
if name == recast_checkpoint_ns:
return await pregel.aget_state(
patch_configurable(
config, {CONFIG_KEY_CHECKPOINTER: checkpointer}
),
subgraphs=subgraphs,
)
async for _, pregel in self.aget_subgraphs(
namespace=recast_checkpoint_ns, recurse=True
):
return await pregel.aget_state(
patch_configurable(config, {CONFIG_KEY_CHECKPOINTER: checkpointer}),
subgraphs=subgraphs,
)
else:
raise ValueError(f"Subgraph {recast_checkpoint_ns} not found")

Expand Down Expand Up @@ -663,17 +679,16 @@ def get_state_history(
part.split(NS_END)[0] for part in checkpoint_ns.split(NS_SEP)
)
# find the subgraph with the matching name
for name, pregel in self.get_subgraphs(recurse=True):
if name == recast_checkpoint_ns:
yield from pregel.get_state_history(
patch_configurable(
config, {CONFIG_KEY_CHECKPOINTER: checkpointer}
),
filter=filter,
before=before,
limit=limit,
)
return
for _, pregel in self.get_subgraphs(
namespace=recast_checkpoint_ns, recurse=True
):
yield from pregel.get_state_history(
patch_configurable(config, {CONFIG_KEY_CHECKPOINTER: checkpointer}),
filter=filter,
before=before,
limit=limit,
)
return
else:
raise ValueError(f"Subgraph {recast_checkpoint_ns} not found")

Expand Down Expand Up @@ -713,18 +728,17 @@ async def aget_state_history(
part.split(NS_END)[0] for part in checkpoint_ns.split(NS_SEP)
)
# find the subgraph with the matching name
async for name, pregel in self.aget_subgraphs(recurse=True):
if name == recast_checkpoint_ns:
async for state in pregel.aget_state_history(
patch_configurable(
config, {CONFIG_KEY_CHECKPOINTER: checkpointer}
),
filter=filter,
before=before,
limit=limit,
):
yield state
return
async for _, pregel in self.aget_subgraphs(
namespace=recast_checkpoint_ns, recurse=True
):
async for state in pregel.aget_state_history(
patch_configurable(config, {CONFIG_KEY_CHECKPOINTER: checkpointer}),
filter=filter,
before=before,
limit=limit,
):
yield state
return
else:
raise ValueError(f"Subgraph {recast_checkpoint_ns} not found")

Expand Down Expand Up @@ -769,15 +783,14 @@ def update_state(
part.split(NS_END)[0] for part in checkpoint_ns.split(NS_SEP)
)
# find the subgraph with the matching name
for name, pregel in self.get_subgraphs(recurse=True):
if name == recast_checkpoint_ns:
return pregel.update_state(
patch_configurable(
config, {CONFIG_KEY_CHECKPOINTER: checkpointer}
),
values,
as_node,
)
for _, pregel in self.get_subgraphs(
namespace=recast_checkpoint_ns, recurse=True
):
return pregel.update_state(
patch_configurable(config, {CONFIG_KEY_CHECKPOINTER: checkpointer}),
values,
as_node,
)
else:
raise ValueError(f"Subgraph {recast_checkpoint_ns} not found")

Expand Down Expand Up @@ -917,15 +930,14 @@ async def aupdate_state(
part.split(NS_END)[0] for part in checkpoint_ns.split(NS_SEP)
)
# find the subgraph with the matching name
async for name, pregel in self.aget_subgraphs(recurse=True):
if name == recast_checkpoint_ns:
return await pregel.aupdate_state(
patch_configurable(
config, {CONFIG_KEY_CHECKPOINTER: checkpointer}
),
values,
as_node,
)
async for _, pregel in self.aget_subgraphs(
namespace=recast_checkpoint_ns, recurse=True
):
return await pregel.aupdate_state(
patch_configurable(config, {CONFIG_KEY_CHECKPOINTER: checkpointer}),
values,
as_node,
)
else:
raise ValueError(f"Subgraph {recast_checkpoint_ns} not found")

Expand Down

0 comments on commit 2f7826e

Please sign in to comment.