Skip to content

Commit

Permalink
Merge pull request #2105 from langchain-ai/nc/14oct/is-last-step-fix
Browse files Browse the repository at this point in the history
Fix IsLastStep counter for runs with checkpointers
  • Loading branch information
nfcampos authored Oct 15, 2024
2 parents edec5c0 + d48faec commit e8b8759
Show file tree
Hide file tree
Showing 12 changed files with 211 additions and 111 deletions.
4 changes: 2 additions & 2 deletions libs/langgraph/langgraph/managed/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from langgraph.managed.is_last_step import IsLastStep
from langgraph.managed.is_last_step import IsLastStep, RemainingSteps

__all__ = ["IsLastStep"]
__all__ = ["IsLastStep", "RemainingSteps"]
17 changes: 9 additions & 8 deletions libs/langgraph/langgraph/managed/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,23 @@
Union,
)

from langchain_core.runnables import RunnableConfig
from typing_extensions import Self, TypeGuard

from langgraph.types import LoopProtocol

V = TypeVar("V")
U = TypeVar("U")


class ManagedValue(ABC, Generic[V]):
def __init__(self, config: RunnableConfig) -> None:
self.config = config
def __init__(self, loop: LoopProtocol) -> None:
self.loop = loop

@classmethod
@contextmanager
def enter(cls, config: RunnableConfig, **kwargs: Any) -> Iterator[Self]:
def enter(cls, loop: LoopProtocol, **kwargs: Any) -> Iterator[Self]:
try:
value = cls(config, **kwargs)
value = cls(loop, **kwargs)
yield value
finally:
# because managed value and Pregel have reference to each other
Expand All @@ -40,9 +41,9 @@ def enter(cls, config: RunnableConfig, **kwargs: Any) -> Iterator[Self]:

@classmethod
@asynccontextmanager
async def aenter(cls, config: RunnableConfig, **kwargs: Any) -> AsyncIterator[Self]:
async def aenter(cls, loop: LoopProtocol, **kwargs: Any) -> AsyncIterator[Self]:
try:
value = cls(config, **kwargs)
value = cls(loop, **kwargs)
yield value
finally:
# because managed value and Pregel have reference to each other
Expand All @@ -53,7 +54,7 @@ async def aenter(cls, config: RunnableConfig, **kwargs: Any) -> AsyncIterator[Se
pass

@abstractmethod
def __call__(self, step: int) -> V: ...
def __call__(self) -> V: ...


class WritableManagedValue(Generic[V, U], ManagedValue[V], ABC):
Expand Down
20 changes: 10 additions & 10 deletions libs/langgraph/langgraph/managed/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@
Union,
)

from langchain_core.runnables import RunnableConfig
from typing_extensions import Self

from langgraph.managed.base import ConfiguredManagedValue, ManagedValue, V
from langgraph.types import LoopProtocol


class Context(ManagedValue[V], Generic[V]):
Expand Down Expand Up @@ -46,14 +46,14 @@ def of(

@classmethod
@contextmanager
def enter(cls, config: RunnableConfig, **kwargs: Any) -> Iterator[Self]:
with super().enter(config, **kwargs) as self:
def enter(cls, loop: LoopProtocol, **kwargs: Any) -> Iterator[Self]:
with super().enter(loop, **kwargs) as self:
if self.ctx is None:
raise ValueError(
"Synchronous context manager not found. Please initialize Context value with a sync context manager, or invoke your graph asynchronously."
)
ctx = (
self.ctx(config) # type: ignore[call-arg]
self.ctx(loop.config) # type: ignore[call-arg]
if signature(self.ctx).parameters.get("config")
else self.ctx()
)
Expand All @@ -63,17 +63,17 @@ def enter(cls, config: RunnableConfig, **kwargs: Any) -> Iterator[Self]:

@classmethod
@asynccontextmanager
async def aenter(cls, config: RunnableConfig, **kwargs: Any) -> AsyncIterator[Self]:
async with super().aenter(config, **kwargs) as self:
async def aenter(cls, loop: LoopProtocol, **kwargs: Any) -> AsyncIterator[Self]:
async with super().aenter(loop, **kwargs) as self:
if self.actx is not None:
ctx = (
self.actx(config) # type: ignore[call-arg]
self.actx(loop.config) # type: ignore[call-arg]
if signature(self.actx).parameters.get("config")
else self.actx()
)
elif self.ctx is not None:
ctx = (
self.ctx(config) # type: ignore
self.ctx(loop.config) # type: ignore
if signature(self.ctx).parameters.get("config")
else self.ctx()
)
Expand All @@ -96,13 +96,13 @@ async def aenter(cls, config: RunnableConfig, **kwargs: Any) -> AsyncIterator[Se

def __init__(
self,
config: RunnableConfig,
loop: LoopProtocol,
*,
ctx: Union[None, Type[ContextManager[V]], Type[AsyncContextManager[V]]] = None,
actx: Optional[Type[AsyncContextManager[V]]] = None,
) -> None:
self.ctx = ctx
self.actx = actx

def __call__(self, step: int) -> V:
def __call__(self) -> V:
return self.value
12 changes: 10 additions & 2 deletions libs/langgraph/langgraph/managed/is_last_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,16 @@


class IsLastStepManager(ManagedValue[bool]):
def __call__(self, step: int) -> bool:
return step == self.config.get("recursion_limit", 0) - 1
def __call__(self) -> bool:
return self.loop.step == self.loop.stop


IsLastStep = Annotated[bool, IsLastStepManager]


class RemainingStepsManager(ManagedValue[bool]):
def __call__(self) -> bool:
return self.loop.stop - self.loop.step


RemainingSteps = Annotated[bool, RemainingStepsManager]
43 changes: 21 additions & 22 deletions libs/langgraph/langgraph/managed/shared_value.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,20 @@
Optional,
Sequence,
Type,
cast,
)

from langchain_core.runnables import RunnableConfig
from typing_extensions import NotRequired, Required, Self

from langgraph.constants import CONF, CONFIG_KEY_STORE
from langgraph.constants import CONF
from langgraph.errors import InvalidUpdateError
from langgraph.managed.base import (
ChannelKeyPlaceholder,
ChannelTypePlaceholder,
ConfiguredManagedValue,
WritableManagedValue,
)
from langgraph.store.base import BaseStore, PutOp
from langgraph.store.base import PutOp
from langgraph.types import LoopProtocol

V = dict[str, Any]

Expand Down Expand Up @@ -55,25 +54,26 @@ def on(scope: str) -> ConfiguredManagedValue:

@classmethod
@contextmanager
def enter(cls, config: RunnableConfig, **kwargs: Any) -> Iterator[Self]:
with super().enter(config, **kwargs) as value:
if value.store is not None:
saved = value.store.search(value.ns)
def enter(cls, loop: LoopProtocol, **kwargs: Any) -> Iterator[Self]:
with super().enter(loop, **kwargs) as value:
if loop.store is not None:
saved = loop.store.search(value.ns)
value.value = {it.key: it.value for it in saved}
yield value

@classmethod
@asynccontextmanager
async def aenter(cls, config: RunnableConfig, **kwargs: Any) -> AsyncIterator[Self]:
async with super().aenter(config, **kwargs) as value:
if value.store is not None:
saved = await value.store.asearch(value.ns)
async def aenter(cls, loop: LoopProtocol, **kwargs: Any) -> AsyncIterator[Self]:
async with super().aenter(loop, **kwargs) as value:
if loop.store is not None:
saved = await loop.store.asearch(value.ns)
value.value = {it.key: it.value for it in saved}
yield value

def __init__(
self, config: RunnableConfig, *, typ: Type[Any], scope: str, key: str
self, loop: LoopProtocol, *, typ: Type[Any], scope: str, key: str
) -> None:
super().__init__(loop)
if typ := _strip_extras(typ):
if typ not in (
dict,
Expand All @@ -83,18 +83,17 @@ def __init__(
raise ValueError("SharedValue must be a dict")
self.scope = scope
self.value: Value = {}
self.store = cast(BaseStore, config[CONF].get(CONFIG_KEY_STORE))
if self.store is None:
if self.loop.store is None:
pass
elif scope_value := config[CONF].get(self.scope):
elif scope_value := self.loop.config[CONF].get(self.scope):
self.ns = ("scoped", scope, key, scope_value)
else:
raise ValueError(
f"Scope {scope} for shared state key not in config.configurable"
)

def __call__(self, step: int) -> Value:
return self.value.copy()
def __call__(self) -> Value:
return self.value

def _process_update(self, values: Sequence[Update]) -> list[PutOp]:
writes: list[PutOp] = []
Expand All @@ -112,13 +111,13 @@ def _process_update(self, values: Sequence[Update]) -> list[PutOp]:
return writes

def update(self, values: Sequence[Update]) -> None:
if self.store is None:
if self.loop.store is None:
self._process_update(values)
else:
return self.store.batch(self._process_update(values))
return self.loop.store.batch(self._process_update(values))

async def aupdate(self, writes: Sequence[Update]) -> None:
if self.store is None:
if self.loop.store is None:
self._process_update(writes)
else:
return await self.store.abatch(self._process_update(writes))
return await self.loop.store.abatch(self._process_update(writes))
52 changes: 45 additions & 7 deletions libs/langgraph/langgraph/prebuilt/chat_agent_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from langgraph.graph import StateGraph
from langgraph.graph.graph import CompiledGraph
from langgraph.graph.message import add_messages
from langgraph.managed import IsLastStep
from langgraph.managed import IsLastStep, RemainingSteps
from langgraph.prebuilt.tool_executor import ToolExecutor
from langgraph.prebuilt.tool_node import ToolNode
from langgraph.store.base import BaseStore
Expand All @@ -33,6 +33,8 @@ class AgentState(TypedDict):

is_last_step: IsLastStep

remaining_steps: RemainingSteps


StateSchema = TypeVar("StateSchema", bound=AgentState)
StateSchemaType = Type[StateSchema]
Expand Down Expand Up @@ -529,10 +531,28 @@ def should_continue(state: AgentState) -> Literal["tools", "__end__"]:
# Define the function that calls the model
def call_model(state: AgentState, config: RunnableConfig) -> AgentState:
response = model_runnable.invoke(state, config)
has_tool_calls = isinstance(response, AIMessage) and response.tool_calls
all_tools_return_direct = (
all(call["name"] in should_return_direct for call in response.tool_calls)
if isinstance(response, AIMessage)
else False
)
if (
state["is_last_step"]
and isinstance(response, AIMessage)
and response.tool_calls
(
"remaining_steps" not in state
and state["is_last_step"]
and has_tool_calls
)
or (
"remaining_steps" in state
and state["remaining_steps"] < 1
and all_tools_return_direct
)
or (
"remaining_steps" in state
and state["remaining_steps"] < 2
and has_tool_calls
)
):
return {
"messages": [
Expand All @@ -547,10 +567,28 @@ def call_model(state: AgentState, config: RunnableConfig) -> AgentState:

async def acall_model(state: AgentState, config: RunnableConfig) -> AgentState:
response = await model_runnable.ainvoke(state, config)
has_tool_calls = isinstance(response, AIMessage) and response.tool_calls
all_tools_return_direct = (
all(call["name"] in should_return_direct for call in response.tool_calls)
if isinstance(response, AIMessage)
else False
)
if (
state["is_last_step"]
and isinstance(response, AIMessage)
and response.tool_calls
(
"remaining_steps" not in state
and state["is_last_step"]
and has_tool_calls
)
or (
"remaining_steps" in state
and state["remaining_steps"] < 1
and all_tools_return_direct
)
or (
"remaining_steps" in state
and state["remaining_steps"] < 2
and has_tool_calls
)
):
return {
"messages": [
Expand Down
32 changes: 27 additions & 5 deletions libs/langgraph/langgraph/pregel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@
from langgraph.pregel.validate import validate_graph, validate_keys
from langgraph.pregel.write import ChannelWrite, ChannelWriteEntry
from langgraph.store.base import BaseStore
from langgraph.types import All, Checkpointer, StateSnapshot, StreamMode
from langgraph.types import All, Checkpointer, LoopProtocol, StateSnapshot, StreamMode
from langgraph.utils.config import (
ensure_config,
merge_configs,
Expand Down Expand Up @@ -433,7 +433,14 @@ def _prepare_state_snapshot(
)

with ChannelsManager(
self.channels, saved.checkpoint, saved.config, skip_context=True
self.channels,
saved.checkpoint,
LoopProtocol(
config=saved.config,
step=saved.metadata.get("step", -1) + 1,
stop=saved.metadata.get("step", -1) + 2,
),
skip_context=True,
) as (channels, managed):
# tasks for this checkpoint
next_tasks = prepare_next_tasks(
Expand Down Expand Up @@ -511,7 +518,14 @@ async def _aprepare_state_snapshot(
)

async with AsyncChannelsManager(
self.channels, saved.checkpoint, saved.config, skip_context=True
self.channels,
saved.checkpoint,
LoopProtocol(
config=saved.config,
step=saved.metadata.get("step", -1) + 1,
stop=saved.metadata.get("step", -1) + 2,
),
skip_context=True,
) as (
channels,
managed,
Expand Down Expand Up @@ -835,7 +849,11 @@ def update_state(
if as_node not in self.nodes:
raise InvalidUpdateError(f"Node {as_node} does not exist")
# update channels
with ChannelsManager(self.channels, checkpoint, config) as (
with ChannelsManager(
self.channels,
checkpoint,
LoopProtocol(config=config, step=step + 1, stop=step + 2),
) as (
channels,
managed,
):
Expand Down Expand Up @@ -981,7 +999,11 @@ async def aupdate_state(
if as_node not in self.nodes:
raise InvalidUpdateError(f"Node {as_node} does not exist")
# update channels, acting as the chosen node
async with AsyncChannelsManager(self.channels, checkpoint, config) as (
async with AsyncChannelsManager(
self.channels,
checkpoint,
LoopProtocol(config=config, step=step + 1, stop=step + 2),
) as (
channels,
managed,
):
Expand Down
Loading

0 comments on commit e8b8759

Please sign in to comment.