Skip to content

Commit

Permalink
lib: Make Control object serializable
Browse files Browse the repository at this point in the history
  • Loading branch information
nfcampos committed Nov 5, 2024
1 parent 18a3fa4 commit 010564c
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 2 deletions.
17 changes: 16 additions & 1 deletion libs/checkpoint/langgraph/checkpoint/serde/jsonplus.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from zoneinfo import ZoneInfo

from langgraph.checkpoint.serde.base import SerializerProtocol
from langgraph.checkpoint.serde.types import SendProtocol
from langgraph.checkpoint.serde.types import ControlProtocol, SendProtocol
from langgraph.store.base import Item

LC_REVIVER = Reviver()
Expand Down Expand Up @@ -402,6 +402,21 @@ def _msgpack_default(obj: Any) -> Union[str, msgpack.ExtType]:
(obj.__class__.__module__, obj.__class__.__name__, (obj.node, obj.arg)),
),
)
elif isinstance(obj, ControlProtocol):
return msgpack.ExtType(
EXT_CONSTRUCTOR_KW_ARGS,
_msgpack_enc(
(
obj.__class__.__module__,
obj.__class__.__name__,
{
"update_state": obj.update_state,
"trigger": obj.trigger,
"send": obj.send,
},
),
),
)
elif dataclasses.is_dataclass(obj):
# doesn't use dataclasses.asdict to avoid deepcopy and recursion
return msgpack.ExtType(
Expand Down
11 changes: 11 additions & 0 deletions libs/checkpoint/langgraph/checkpoint/serde/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
Protocol,
Sequence,
TypeVar,
Union,
runtime_checkable,
)

Expand Down Expand Up @@ -48,3 +49,13 @@ def __hash__(self) -> int: ...
def __repr__(self) -> str: ...

def __eq__(self, value: object) -> bool: ...


@runtime_checkable
class ControlProtocol(Protocol):
# Mirrors langgraph.constants.Control
update_state: Optional[dict[str, Any]]
trigger: Union[str, Sequence[str]]
send: Union[Any, Sequence[Any]]

def __repr__(self) -> str: ...
8 changes: 7 additions & 1 deletion libs/langgraph/langgraph/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,10 @@ def __eq__(self, value: object) -> bool:


class Control(Generic[N]):
"""A control object to update the graph's state, trigger nodes, and send messages."""

__slots__ = ("update_state", "trigger", "send")

def __init__(
self,
*,
Expand All @@ -240,7 +244,9 @@ def __init__(

def __repr__(self) -> str:
contents = ", ".join(
f"{key}={value!r}" for key, value in self.__dict__.items() if value
f"{key}={value!r}"
for key in self.__slots__
if (value := getattr(self, key))
)
return f"Control({contents})"

Expand Down

0 comments on commit 010564c

Please sign in to comment.