Skip to content

Commit

Permalink
Merge pull request #2048 from langchain-ai/dqbd/debug-tasks-state
Browse files Browse the repository at this point in the history
feat(debug): send tasks info
  • Loading branch information
dqbd authored Oct 9, 2024
2 parents 28ff7fd + a53a566 commit 5d0afa3
Show file tree
Hide file tree
Showing 6 changed files with 491 additions and 45 deletions.
33 changes: 4 additions & 29 deletions libs/langgraph/langgraph/pregel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from langchain_core.globals import get_debug
from langchain_core.runnables import (
Runnable,
RunnableLambda,
RunnableSequence,
)
from langchain_core.runnables.base import Input, Output
Expand All @@ -37,7 +36,6 @@
)
from langchain_core.runnables.utils import (
ConfigurableFieldSpec,
get_function_nonlocals,
get_unique_config_specs,
)
from langchain_core.tracers._streaming import _StreamingCallbackHandler
Expand Down Expand Up @@ -86,7 +84,7 @@
from langgraph.pregel.read import PregelNode
from langgraph.pregel.retry import RetryPolicy
from langgraph.pregel.runner import PregelRunner
from langgraph.pregel.utils import get_new_channel_versions
from langgraph.pregel.utils import find_subgraph_pregel, get_new_channel_versions
from langgraph.pregel.validate import validate_graph, validate_keys
from langgraph.pregel.write import ChannelWrite, ChannelWriteEntry
from langgraph.store.base import BaseStore
Expand All @@ -100,7 +98,6 @@
)
from langgraph.utils.pydantic import create_model
from langgraph.utils.queue import AsyncQueue, SyncQueue # type: ignore[attr-defined]
from langgraph.utils.runnable import RunnableCallable

WriteValue = Union[Callable[[Input], Output], Any]

Expand Down Expand Up @@ -391,32 +388,10 @@ def get_subgraphs(
if namespace is not None:
if not namespace.startswith(name):
continue

# find the subgraph, if any
graph: Optional[Pregel] = None
candidates = [node.bound]
for candidate in candidates:
if (
isinstance(candidate, Pregel)
# subgraphs that disabled checkpointing are not considered
and candidate.checkpointer is not False
):
graph = candidate
break
elif isinstance(candidate, RunnableSequence):
candidates.extend(candidate.steps)
elif isinstance(candidate, RunnableLambda):
candidates.extend(candidate.deps)
elif isinstance(candidate, RunnableCallable):
if candidate.func is not None:
candidates.extend(
nl.__self__ if hasattr(nl, "__self__") else nl
for nl in get_function_nonlocals(candidate.func)
)
if candidate.afunc is not None:
candidates.extend(
nl.__self__ if hasattr(nl, "__self__") else nl
for nl in get_function_nonlocals(candidate.afunc)
)
graph = cast(Optional[Pregel], find_subgraph_pregel(node.bound))

# if found, yield recursively
if graph:
if name == namespace:
Expand Down
37 changes: 35 additions & 2 deletions libs/langgraph/langgraph/pregel/debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,17 @@

from langgraph.channels.base import BaseChannel
from langgraph.checkpoint.base import Checkpoint, CheckpointMetadata, PendingWrite
from langgraph.constants import ERROR, INTERRUPT, TAG_HIDDEN
from langgraph.constants import (
CONF,
CONFIG_KEY_CHECKPOINT_NS,
ERROR,
INTERRUPT,
NS_END,
NS_SEP,
TAG_HIDDEN,
)
from langgraph.pregel.io import read_channels
from langgraph.pregel.utils import find_subgraph_pregel
from langgraph.types import PregelExecutableTask, PregelTask, StateSnapshot


Expand All @@ -45,6 +54,7 @@ class CheckpointTask(TypedDict):
name: str
error: Optional[str]
interrupts: list[dict]
state: Optional[RunnableConfig]


class CheckpointPayload(TypedDict):
Expand Down Expand Up @@ -140,6 +150,27 @@ def map_debug_checkpoint(
parent_config: Optional[RunnableConfig],
) -> Iterator[DebugOutputCheckpoint]:
"""Produce "checkpoint" events for stream_mode=debug."""

parent_ns = config[CONF].get(CONFIG_KEY_CHECKPOINT_NS, "")
task_states: dict[str, Union[RunnableConfig, StateSnapshot]] = {}

for task in tasks:
if not find_subgraph_pregel(task.proc):
continue

# assemble checkpoint_ns for this task
task_ns = f"{task.name}{NS_END}{task.id}"
if parent_ns:
task_ns = f"{parent_ns}{NS_SEP}{task_ns}"

# set config as signal that subgraph checkpoints exist
task_states[task.id] = {
CONF: {
"thread_id": config[CONF]["thread_id"],
CONFIG_KEY_CHECKPOINT_NS: task_ns,
}
}

yield {
"type": "checkpoint",
"timestamp": checkpoint["ts"],
Expand All @@ -155,14 +186,16 @@ def map_debug_checkpoint(
"id": t.id,
"name": t.name,
"error": t.error,
"state": t.state,
}
if t.error
else {
"id": t.id,
"name": t.name,
"interrupts": tuple(asdict(i) for i in t.interrupts),
"state": t.state,
}
for t in tasks_w_writes(tasks, pending_writes, None)
for t in tasks_w_writes(tasks, pending_writes, task_states)
],
},
}
Expand Down
37 changes: 37 additions & 0 deletions libs/langgraph/langgraph/pregel/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
from typing import Optional

from langchain_core.runnables import RunnableLambda, RunnableSequence
from langchain_core.runnables.utils import get_function_nonlocals

from langgraph.checkpoint.base import ChannelVersions
from langgraph.utils.runnable import Runnable, RunnableCallable, RunnableSeq


def get_new_channel_versions(
Expand All @@ -17,3 +23,34 @@ def get_new_channel_versions(
new_versions = current_versions

return new_versions


def find_subgraph_pregel(candidate: Runnable) -> Optional[Runnable]:
from langgraph.pregel import Pregel

candidates: list[Runnable] = [candidate]

for c in candidates:
if (
isinstance(c, Pregel)
# subgraphs that disabled checkpointing are not considered
and c.checkpointer is not False
):
return c
elif isinstance(c, RunnableSequence) or isinstance(c, RunnableSeq):
candidates.extend(c.steps)
elif isinstance(c, RunnableLambda):
candidates.extend(c.deps)
elif isinstance(c, RunnableCallable):
if c.func is not None:
candidates.extend(
nl.__self__ if hasattr(nl, "__self__") else nl
for nl in get_function_nonlocals(c.func)
)
if c.afunc is not None:
candidates.extend(
nl.__self__ if hasattr(nl, "__self__") else nl
for nl in get_function_nonlocals(c.afunc)
)

return None
2 changes: 1 addition & 1 deletion libs/langgraph/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "langgraph"
version = "0.2.34"
version = "0.2.35"
description = "Building stateful, multi-actor applications with LLMs"
authors = []
license = "MIT"
Expand Down
Loading

0 comments on commit 5d0afa3

Please sign in to comment.