diff --git a/libs/langgraph/langgraph/managed/__init__.py b/libs/langgraph/langgraph/managed/__init__.py index 2c101e400..966348e6f 100644 --- a/libs/langgraph/langgraph/managed/__init__.py +++ b/libs/langgraph/langgraph/managed/__init__.py @@ -1,3 +1,3 @@ -from langgraph.managed.is_last_step import IsLastStep +from langgraph.managed.is_last_step import IsLastStep, RemainingSteps -__all__ = ["IsLastStep"] +__all__ = ["IsLastStep", "RemainingSteps"] diff --git a/libs/langgraph/langgraph/managed/base.py b/libs/langgraph/langgraph/managed/base.py index a8fc27c8e..36962e156 100644 --- a/libs/langgraph/langgraph/managed/base.py +++ b/libs/langgraph/langgraph/managed/base.py @@ -13,22 +13,23 @@ Union, ) -from langchain_core.runnables import RunnableConfig from typing_extensions import Self, TypeGuard +from langgraph.types import LoopProtocol + V = TypeVar("V") U = TypeVar("U") class ManagedValue(ABC, Generic[V]): - def __init__(self, config: RunnableConfig) -> None: - self.config = config + def __init__(self, loop: LoopProtocol) -> None: + self.loop = loop @classmethod @contextmanager - def enter(cls, config: RunnableConfig, **kwargs: Any) -> Iterator[Self]: + def enter(cls, loop: LoopProtocol, **kwargs: Any) -> Iterator[Self]: try: - value = cls(config, **kwargs) + value = cls(loop, **kwargs) yield value finally: # because managed value and Pregel have reference to each other @@ -40,9 +41,9 @@ def enter(cls, config: RunnableConfig, **kwargs: Any) -> Iterator[Self]: @classmethod @asynccontextmanager - async def aenter(cls, config: RunnableConfig, **kwargs: Any) -> AsyncIterator[Self]: + async def aenter(cls, loop: LoopProtocol, **kwargs: Any) -> AsyncIterator[Self]: try: - value = cls(config, **kwargs) + value = cls(loop, **kwargs) yield value finally: # because managed value and Pregel have reference to each other @@ -53,7 +54,7 @@ async def aenter(cls, config: RunnableConfig, **kwargs: Any) -> AsyncIterator[Se pass @abstractmethod - def __call__(self, step: int) -> V: ... + def __call__(self) -> V: ... class WritableManagedValue(Generic[V, U], ManagedValue[V], ABC): diff --git a/libs/langgraph/langgraph/managed/context.py b/libs/langgraph/langgraph/managed/context.py index df64419ec..d1713c11a 100644 --- a/libs/langgraph/langgraph/managed/context.py +++ b/libs/langgraph/langgraph/managed/context.py @@ -13,10 +13,10 @@ Union, ) -from langchain_core.runnables import RunnableConfig from typing_extensions import Self from langgraph.managed.base import ConfiguredManagedValue, ManagedValue, V +from langgraph.types import LoopProtocol class Context(ManagedValue[V], Generic[V]): @@ -46,14 +46,14 @@ def of( @classmethod @contextmanager - def enter(cls, config: RunnableConfig, **kwargs: Any) -> Iterator[Self]: - with super().enter(config, **kwargs) as self: + def enter(cls, loop: LoopProtocol, **kwargs: Any) -> Iterator[Self]: + with super().enter(loop, **kwargs) as self: if self.ctx is None: raise ValueError( "Synchronous context manager not found. Please initialize Context value with a sync context manager, or invoke your graph asynchronously." ) ctx = ( - self.ctx(config) # type: ignore[call-arg] + self.ctx(loop.config) # type: ignore[call-arg] if signature(self.ctx).parameters.get("config") else self.ctx() ) @@ -63,17 +63,17 @@ def enter(cls, config: RunnableConfig, **kwargs: Any) -> Iterator[Self]: @classmethod @asynccontextmanager - async def aenter(cls, config: RunnableConfig, **kwargs: Any) -> AsyncIterator[Self]: - async with super().aenter(config, **kwargs) as self: + async def aenter(cls, loop: LoopProtocol, **kwargs: Any) -> AsyncIterator[Self]: + async with super().aenter(loop, **kwargs) as self: if self.actx is not None: ctx = ( - self.actx(config) # type: ignore[call-arg] + self.actx(loop.config) # type: ignore[call-arg] if signature(self.actx).parameters.get("config") else self.actx() ) elif self.ctx is not None: ctx = ( - self.ctx(config) # type: ignore + self.ctx(loop.config) # type: ignore if signature(self.ctx).parameters.get("config") else self.ctx() ) @@ -96,7 +96,7 @@ async def aenter(cls, config: RunnableConfig, **kwargs: Any) -> AsyncIterator[Se def __init__( self, - config: RunnableConfig, + loop: LoopProtocol, *, ctx: Union[None, Type[ContextManager[V]], Type[AsyncContextManager[V]]] = None, actx: Optional[Type[AsyncContextManager[V]]] = None, @@ -104,5 +104,5 @@ def __init__( self.ctx = ctx self.actx = actx - def __call__(self, step: int) -> V: + def __call__(self) -> V: return self.value diff --git a/libs/langgraph/langgraph/managed/is_last_step.py b/libs/langgraph/langgraph/managed/is_last_step.py index d8b4f9102..a4ac277d3 100644 --- a/libs/langgraph/langgraph/managed/is_last_step.py +++ b/libs/langgraph/langgraph/managed/is_last_step.py @@ -4,8 +4,16 @@ class IsLastStepManager(ManagedValue[bool]): - def __call__(self, step: int) -> bool: - return step == self.config.get("recursion_limit", 0) - 1 + def __call__(self) -> bool: + return self.loop.step == self.loop.stop IsLastStep = Annotated[bool, IsLastStepManager] + + +class RemainingStepsManager(ManagedValue[bool]): + def __call__(self) -> bool: + return self.loop.stop - self.loop.step + + +RemainingSteps = Annotated[bool, RemainingStepsManager] diff --git a/libs/langgraph/langgraph/managed/shared_value.py b/libs/langgraph/langgraph/managed/shared_value.py index 7c8f45b14..300d36c7d 100644 --- a/libs/langgraph/langgraph/managed/shared_value.py +++ b/libs/langgraph/langgraph/managed/shared_value.py @@ -7,13 +7,11 @@ Optional, Sequence, Type, - cast, ) -from langchain_core.runnables import RunnableConfig from typing_extensions import NotRequired, Required, Self -from langgraph.constants import CONF, CONFIG_KEY_STORE +from langgraph.constants import CONF from langgraph.errors import InvalidUpdateError from langgraph.managed.base import ( ChannelKeyPlaceholder, @@ -21,7 +19,8 @@ ConfiguredManagedValue, WritableManagedValue, ) -from langgraph.store.base import BaseStore, PutOp +from langgraph.store.base import PutOp +from langgraph.types import LoopProtocol V = dict[str, Any] @@ -55,25 +54,26 @@ def on(scope: str) -> ConfiguredManagedValue: @classmethod @contextmanager - def enter(cls, config: RunnableConfig, **kwargs: Any) -> Iterator[Self]: - with super().enter(config, **kwargs) as value: - if value.store is not None: - saved = value.store.search(value.ns) + def enter(cls, loop: LoopProtocol, **kwargs: Any) -> Iterator[Self]: + with super().enter(loop, **kwargs) as value: + if loop.store is not None: + saved = loop.store.search(value.ns) value.value = {it.key: it.value for it in saved} yield value @classmethod @asynccontextmanager - async def aenter(cls, config: RunnableConfig, **kwargs: Any) -> AsyncIterator[Self]: - async with super().aenter(config, **kwargs) as value: - if value.store is not None: - saved = await value.store.asearch(value.ns) + async def aenter(cls, loop: LoopProtocol, **kwargs: Any) -> AsyncIterator[Self]: + async with super().aenter(loop, **kwargs) as value: + if loop.store is not None: + saved = await loop.store.asearch(value.ns) value.value = {it.key: it.value for it in saved} yield value def __init__( - self, config: RunnableConfig, *, typ: Type[Any], scope: str, key: str + self, loop: LoopProtocol, *, typ: Type[Any], scope: str, key: str ) -> None: + super().__init__(loop) if typ := _strip_extras(typ): if typ not in ( dict, @@ -83,18 +83,17 @@ def __init__( raise ValueError("SharedValue must be a dict") self.scope = scope self.value: Value = {} - self.store = cast(BaseStore, config[CONF].get(CONFIG_KEY_STORE)) - if self.store is None: + if self.loop.store is None: pass - elif scope_value := config[CONF].get(self.scope): + elif scope_value := self.loop.config[CONF].get(self.scope): self.ns = ("scoped", scope, key, scope_value) else: raise ValueError( f"Scope {scope} for shared state key not in config.configurable" ) - def __call__(self, step: int) -> Value: - return self.value.copy() + def __call__(self) -> Value: + return self.value def _process_update(self, values: Sequence[Update]) -> list[PutOp]: writes: list[PutOp] = [] @@ -112,13 +111,13 @@ def _process_update(self, values: Sequence[Update]) -> list[PutOp]: return writes def update(self, values: Sequence[Update]) -> None: - if self.store is None: + if self.loop.store is None: self._process_update(values) else: - return self.store.batch(self._process_update(values)) + return self.loop.store.batch(self._process_update(values)) async def aupdate(self, writes: Sequence[Update]) -> None: - if self.store is None: + if self.loop.store is None: self._process_update(writes) else: - return await self.store.abatch(self._process_update(writes)) + return await self.loop.store.abatch(self._process_update(writes)) diff --git a/libs/langgraph/langgraph/prebuilt/chat_agent_executor.py b/libs/langgraph/langgraph/prebuilt/chat_agent_executor.py index 898460176..5f739c0a6 100644 --- a/libs/langgraph/langgraph/prebuilt/chat_agent_executor.py +++ b/libs/langgraph/langgraph/prebuilt/chat_agent_executor.py @@ -14,7 +14,7 @@ from langgraph.graph import StateGraph from langgraph.graph.graph import CompiledGraph from langgraph.graph.message import add_messages -from langgraph.managed import IsLastStep +from langgraph.managed import IsLastStep, RemainingSteps from langgraph.prebuilt.tool_executor import ToolExecutor from langgraph.prebuilt.tool_node import ToolNode from langgraph.store.base import BaseStore @@ -33,6 +33,8 @@ class AgentState(TypedDict): is_last_step: IsLastStep + remaining_steps: RemainingSteps + StateSchema = TypeVar("StateSchema", bound=AgentState) StateSchemaType = Type[StateSchema] @@ -529,10 +531,28 @@ def should_continue(state: AgentState) -> Literal["tools", "__end__"]: # Define the function that calls the model def call_model(state: AgentState, config: RunnableConfig) -> AgentState: response = model_runnable.invoke(state, config) + has_tool_calls = isinstance(response, AIMessage) and response.tool_calls + all_tools_return_direct = ( + all(call["name"] in should_return_direct for call in response.tool_calls) + if isinstance(response, AIMessage) + else False + ) if ( - state["is_last_step"] - and isinstance(response, AIMessage) - and response.tool_calls + ( + "remaining_steps" not in state + and state["is_last_step"] + and has_tool_calls + ) + or ( + "remaining_steps" in state + and state["remaining_steps"] < 1 + and all_tools_return_direct + ) + or ( + "remaining_steps" in state + and state["remaining_steps"] < 2 + and has_tool_calls + ) ): return { "messages": [ @@ -547,10 +567,28 @@ def call_model(state: AgentState, config: RunnableConfig) -> AgentState: async def acall_model(state: AgentState, config: RunnableConfig) -> AgentState: response = await model_runnable.ainvoke(state, config) + has_tool_calls = isinstance(response, AIMessage) and response.tool_calls + all_tools_return_direct = ( + all(call["name"] in should_return_direct for call in response.tool_calls) + if isinstance(response, AIMessage) + else False + ) if ( - state["is_last_step"] - and isinstance(response, AIMessage) - and response.tool_calls + ( + "remaining_steps" not in state + and state["is_last_step"] + and has_tool_calls + ) + or ( + "remaining_steps" in state + and state["remaining_steps"] < 1 + and all_tools_return_direct + ) + or ( + "remaining_steps" in state + and state["remaining_steps"] < 2 + and has_tool_calls + ) ): return { "messages": [ diff --git a/libs/langgraph/langgraph/pregel/__init__.py b/libs/langgraph/langgraph/pregel/__init__.py index 5436d426d..04e302232 100644 --- a/libs/langgraph/langgraph/pregel/__init__.py +++ b/libs/langgraph/langgraph/pregel/__init__.py @@ -88,7 +88,7 @@ from langgraph.pregel.validate import validate_graph, validate_keys from langgraph.pregel.write import ChannelWrite, ChannelWriteEntry from langgraph.store.base import BaseStore -from langgraph.types import All, Checkpointer, StateSnapshot, StreamMode +from langgraph.types import All, Checkpointer, LoopProtocol, StateSnapshot, StreamMode from langgraph.utils.config import ( ensure_config, merge_configs, @@ -433,7 +433,14 @@ def _prepare_state_snapshot( ) with ChannelsManager( - self.channels, saved.checkpoint, saved.config, skip_context=True + self.channels, + saved.checkpoint, + LoopProtocol( + config=saved.config, + step=saved.metadata.get("step", -1) + 1, + stop=saved.metadata.get("step", -1) + 2, + ), + skip_context=True, ) as (channels, managed): # tasks for this checkpoint next_tasks = prepare_next_tasks( @@ -511,7 +518,14 @@ async def _aprepare_state_snapshot( ) async with AsyncChannelsManager( - self.channels, saved.checkpoint, saved.config, skip_context=True + self.channels, + saved.checkpoint, + LoopProtocol( + config=saved.config, + step=saved.metadata.get("step", -1) + 1, + stop=saved.metadata.get("step", -1) + 2, + ), + skip_context=True, ) as ( channels, managed, @@ -835,7 +849,11 @@ def update_state( if as_node not in self.nodes: raise InvalidUpdateError(f"Node {as_node} does not exist") # update channels - with ChannelsManager(self.channels, checkpoint, config) as ( + with ChannelsManager( + self.channels, + checkpoint, + LoopProtocol(config=config, step=step + 1, stop=step + 2), + ) as ( channels, managed, ): @@ -981,7 +999,11 @@ async def aupdate_state( if as_node not in self.nodes: raise InvalidUpdateError(f"Node {as_node} does not exist") # update channels, acting as the chosen node - async with AsyncChannelsManager(self.channels, checkpoint, config) as ( + async with AsyncChannelsManager( + self.channels, + checkpoint, + LoopProtocol(config=config, step=step + 1, stop=step + 2), + ) as ( channels, managed, ): diff --git a/libs/langgraph/langgraph/pregel/algo.py b/libs/langgraph/langgraph/pregel/algo.py index 98ac8576f..4a2455aa1 100644 --- a/libs/langgraph/langgraph/pregel/algo.py +++ b/libs/langgraph/langgraph/pregel/algo.py @@ -57,7 +57,7 @@ from langgraph.pregel.manager import ChannelsManager from langgraph.pregel.read import PregelNode from langgraph.store.base import BaseStore -from langgraph.types import All, PregelExecutableTask, PregelTask +from langgraph.types import All, LoopProtocol, PregelExecutableTask, PregelTask from langgraph.utils.config import merge_configs, patch_config GetNextVersion = Callable[[Optional[V], BaseChannel], V] @@ -148,7 +148,7 @@ def local_read( with ChannelsManager( {k: v for k, v in channels.items() if k in updated}, checkpoint, - config, + LoopProtocol(config=config, step=step, stop=step + 1), skip_context=True, ) as (local_channels, _): apply_writes(copy_checkpoint(checkpoint), local_channels, [task], None) @@ -156,7 +156,7 @@ def local_read( else: values = read_channels(channels, select) if managed_keys: - values.update({k: managed[k](step) for k in managed_keys}) + values.update({k: managed[k]() for k in managed_keys}) return values @@ -493,9 +493,7 @@ def prepare_single_task( ): try: val = next( - _proc_input( - step, proc, managed, channels, for_execution=for_execution - ) + _proc_input(proc, managed, channels, for_execution=for_execution) ) except StopIteration: return @@ -583,7 +581,6 @@ def prepare_single_task( def _proc_input( - step: int, proc: PregelNode, managed: ManagedValueMapping, channels: Mapping[str, BaseChannel], @@ -605,7 +602,7 @@ def _proc_input( except EmptyChannelError: continue else: - val[k] = managed[k](step) + val[k] = managed[k]() except EmptyChannelError: return elif isinstance(proc.channels, list): diff --git a/libs/langgraph/langgraph/pregel/loop.py b/libs/langgraph/langgraph/pregel/loop.py index d4f3a52c3..f25fa2300 100644 --- a/libs/langgraph/langgraph/pregel/loop.py +++ b/libs/langgraph/langgraph/pregel/loop.py @@ -100,7 +100,7 @@ from langgraph.pregel.read import PregelNode from langgraph.pregel.utils import get_new_channel_versions from langgraph.store.base import BaseStore -from langgraph.types import All, PregelExecutableTask, StreamMode +from langgraph.types import All, LoopProtocol, PregelExecutableTask, StreamProtocol from langgraph.utils.config import patch_configurable V = TypeVar("V") @@ -112,22 +112,6 @@ SPECIAL_CHANNELS = (ERROR, INTERRUPT, SCHEDULED) -class StreamProtocol: - __slots__ = ("modes", "__call__") - - modes: set[StreamMode] - - __call__: Callable[[StreamChunk], None] - - def __init__( - self, - __call__: Callable[[StreamChunk], None], - modes: set[StreamMode], - ) -> None: - self.__call__ = __call__ - self.modes = modes - - def DuplexStream(*streams: StreamProtocol) -> StreamProtocol: def __call__(value: StreamChunk) -> None: for stream in streams: @@ -137,16 +121,13 @@ def __call__(value: StreamChunk) -> None: return StreamProtocol(__call__, {mode for s in streams for mode in s.modes}) -class PregelLoop: +class PregelLoop(LoopProtocol): input: Optional[Any] - config: RunnableConfig - store: Optional[BaseStore] checkpointer: Optional[BaseCheckpointSaver] nodes: Mapping[str, PregelNode] specs: Mapping[str, Union[BaseChannel, ManagedValueSpec]] output_keys: Union[str, Sequence[str]] stream_keys: Union[str, Sequence[str]] - stream: Optional[StreamProtocol] skip_done_tasks: bool is_nested: bool @@ -177,8 +158,6 @@ class PregelLoop: checkpoint_previous_versions: dict[str, Union[str, float, int]] prev_checkpoint_config: Optional[RunnableConfig] - step: int - stop: int status: Literal[ "pending", "done", "interrupt_before", "interrupt_after", "out_of_steps" ] @@ -202,10 +181,14 @@ def __init__( check_subgraphs: bool = True, debug: bool = False, ) -> None: - self.stream = stream + super().__init__( + step=0, + stop=0, + config=config, + stream=stream, + store=store, + ) self.input = input - self.config = config - self.store = store self.checkpointer = checkpointer self.nodes = nodes self.specs = specs @@ -730,7 +713,7 @@ def __enter__(self) -> Self: self.submit = self.stack.enter_context(BackgroundExecutor(self.config)) self.channels, self.managed = self.stack.enter_context( - ChannelsManager(self.specs, self.checkpoint, self.config, self.store) + ChannelsManager(self.specs, self.checkpoint, self) ) self.stack.push(self._suppress_interrupt) self.status = "pending" @@ -858,7 +841,7 @@ async def __aenter__(self) -> Self: self.submit = await self.stack.enter_async_context(AsyncBackgroundExecutor()) self.channels, self.managed = await self.stack.enter_async_context( - AsyncChannelsManager(self.specs, self.checkpoint, self.config, self.store) + AsyncChannelsManager(self.specs, self.checkpoint, self) ) self.stack.push(self._suppress_interrupt) self.status = "pending" diff --git a/libs/langgraph/langgraph/pregel/manager.py b/libs/langgraph/langgraph/pregel/manager.py index c6d6c07aa..641e1d8fe 100644 --- a/libs/langgraph/langgraph/pregel/manager.py +++ b/libs/langgraph/langgraph/pregel/manager.py @@ -1,33 +1,27 @@ import asyncio from contextlib import AsyncExitStack, ExitStack, asynccontextmanager, contextmanager -from typing import AsyncIterator, Iterator, Mapping, Optional, Union - -from langchain_core.runnables import RunnableConfig +from typing import AsyncIterator, Iterator, Mapping, Union from langgraph.channels.base import BaseChannel from langgraph.checkpoint.base import Checkpoint -from langgraph.constants import CONFIG_KEY_STORE from langgraph.managed.base import ( ConfiguredManagedValue, ManagedValueMapping, ManagedValueSpec, ) from langgraph.managed.context import Context -from langgraph.store.base import BaseStore -from langgraph.utils.config import patch_configurable +from langgraph.types import LoopProtocol @contextmanager def ChannelsManager( specs: Mapping[str, Union[BaseChannel, ManagedValueSpec]], checkpoint: Checkpoint, - config: RunnableConfig, - store: Optional[BaseStore] = None, + loop: LoopProtocol, *, skip_context: bool = False, ) -> Iterator[tuple[Mapping[str, BaseChannel], ManagedValueMapping]]: """Manage channels for the lifetime of a Pregel invocation (multiple steps).""" - config_for_managed = patch_configurable(config, {CONFIG_KEY_STORE: store}) channel_specs: dict[str, BaseChannel] = {} managed_specs: dict[str, ManagedValueSpec] = {} for k, v in specs.items(): @@ -48,9 +42,9 @@ def ChannelsManager( ManagedValueMapping( { key: stack.enter_context( - value.cls.enter(config_for_managed, **value.kwargs) + value.cls.enter(loop, **value.kwargs) if isinstance(value, ConfiguredManagedValue) - else value.enter(config_for_managed) + else value.enter(loop) ) for key, value in managed_specs.items() } @@ -62,13 +56,11 @@ def ChannelsManager( async def AsyncChannelsManager( specs: Mapping[str, Union[BaseChannel, ManagedValueSpec]], checkpoint: Checkpoint, - config: RunnableConfig, - store: Optional[BaseStore] = None, + loop: LoopProtocol, *, skip_context: bool = False, ) -> AsyncIterator[tuple[Mapping[str, BaseChannel], ManagedValueMapping]]: """Manage channels for the lifetime of a Pregel invocation (multiple steps).""" - config_for_managed = patch_configurable(config, {CONFIG_KEY_STORE: store}) channel_specs: dict[str, BaseChannel] = {} managed_specs: dict[str, ManagedValueSpec] = {} for k, v in specs.items(): @@ -85,9 +77,9 @@ async def AsyncChannelsManager( if tasks := { asyncio.create_task( stack.enter_async_context( - value.cls.aenter(config_for_managed, **value.kwargs) + value.cls.aenter(loop, **value.kwargs) if isinstance(value, ConfiguredManagedValue) - else value.aenter(config_for_managed) + else value.aenter(loop) ) ): key for key, value in managed_specs.items() diff --git a/libs/langgraph/langgraph/types.py b/libs/langgraph/langgraph/types.py index 42bb149db..c29cc1480 100644 --- a/libs/langgraph/langgraph/types.py +++ b/libs/langgraph/langgraph/types.py @@ -1,6 +1,7 @@ from collections import deque from dataclasses import dataclass from typing import ( + TYPE_CHECKING, Any, Callable, Literal, @@ -15,6 +16,9 @@ from langgraph.checkpoint.base import BaseCheckpointSaver, CheckpointMetadata +if TYPE_CHECKING: + from langgraph.store.base import BaseStore + All = Literal["*"] """Special value to indicate that graph should interrupt on all nodes.""" @@ -213,3 +217,45 @@ def __eq__(self, value: object) -> bool: and self.node == value.node and self.arg == value.arg ) + + +StreamChunk = tuple[tuple[str, ...], str, Any] + + +class StreamProtocol: + __slots__ = ("modes", "__call__") + + modes: set[StreamMode] + + __call__: Callable[[StreamChunk], None] + + def __init__( + self, + __call__: Callable[[StreamChunk], None], + modes: set[StreamMode], + ) -> None: + self.__call__ = __call__ + self.modes = modes + + +class LoopProtocol: + config: RunnableConfig + store: Optional["BaseStore"] + stream: Optional[StreamProtocol] + step: int + stop: int + + def __init__( + self, + *, + step: int, + stop: int, + config: RunnableConfig, + store: Optional["BaseStore"] = None, + stream: Optional[StreamProtocol] = None, + ) -> None: + self.stream = stream + self.config = config + self.store = store + self.step = step + self.stop = stop diff --git a/libs/scheduler-kafka/langgraph/scheduler/kafka/executor.py b/libs/scheduler-kafka/langgraph/scheduler/kafka/executor.py index c2ad752e4..a7c99900d 100644 --- a/libs/scheduler-kafka/langgraph/scheduler/kafka/executor.py +++ b/libs/scheduler-kafka/langgraph/scheduler/kafka/executor.py @@ -37,7 +37,7 @@ Sendable, Topics, ) -from langgraph.types import RetryPolicy +from langgraph.types import LoopProtocol, RetryPolicy from langgraph.utils.config import patch_configurable @@ -183,7 +183,14 @@ async def attempt(self, msg: MessageToExecutor) -> None: if saved.checkpoint["id"] != msg["config"]["configurable"]["checkpoint_id"]: raise CheckpointNotLatest() async with AsyncChannelsManager( - graph.channels, saved.checkpoint, msg["config"], self.graph.store + graph.channels, + saved.checkpoint, + LoopProtocol( + config=msg["config"], + store=self.graph.store, + step=saved.metadata["step"] + 1, + stop=saved.metadata["step"] + 2, + ), ) as (channels, managed), AsyncBackgroundExecutor() as submit: if task := await asyncio.to_thread( prepare_single_task, @@ -379,7 +386,14 @@ def attempt(self, msg: MessageToExecutor) -> None: if saved.checkpoint["id"] != msg["config"]["configurable"]["checkpoint_id"]: raise CheckpointNotLatest() with ChannelsManager( - graph.channels, saved.checkpoint, msg["config"], self.graph.store + graph.channels, + saved.checkpoint, + LoopProtocol( + config=msg["config"], + store=self.graph.store, + step=saved.metadata["step"] + 1, + stop=saved.metadata["step"] + 2, + ), ) as (channels, managed), BackgroundExecutor({}) as submit: if task := prepare_single_task( msg["task"]["path"],