Skip to content

Commit

Permalink
Implement serialization with msgpack library (#1716)
Browse files Browse the repository at this point in the history
* Implement serialization with msgpack library

- encode custom python objects with a msgpack extension type, with constructor path string, and args encoded as nested msgpack doc

* Smaller msgpack extension types

* Update lock files

* lock

* Don't delegate to pydantic json

* Fix kafka serde

- should use our serializer to load, as inputs to subgraphs are serialized using it
  • Loading branch information
nfcampos authored Sep 16, 2024
1 parent afc64c4 commit 3b05279
Show file tree
Hide file tree
Showing 12 changed files with 895 additions and 116 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -226,13 +226,7 @@ def _load_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]:
return self.jsonplus_serde.loads(self.jsonplus_serde.dumps(metadata))

def _dump_metadata(self, metadata) -> str:
serialized_metadata_type, serialized_metadata = self.jsonplus_serde.dumps_typed(
metadata
)
if serialized_metadata_type != "json":
raise TypeError(
f"Failed to properly serialize metadata -- expected 'json', got '{serialized_metadata_type}'"
)
serialized_metadata = self.jsonplus_serde.dumps(metadata)
return serialized_metadata.decode()

def get_next_version(self, current: Optional[str], channel: ChannelProtocol) -> str:
Expand Down
158 changes: 146 additions & 12 deletions libs/checkpoint-postgres/poetry.lock

Large diffs are not rendered by default.

191 changes: 179 additions & 12 deletions libs/checkpoint-sqlite/poetry.lock

Large diffs are not rendered by default.

276 changes: 275 additions & 1 deletion libs/checkpoint/langgraph/checkpoint/serde/jsonplus.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from typing import Any, Optional, Sequence
from uuid import UUID

import msgpack
from langchain_core.load.load import Reviver
from langchain_core.load.serializable import Serializable
from zoneinfo import ZoneInfo
Expand Down Expand Up @@ -184,7 +185,10 @@ def dumps_typed(self, obj: Any) -> tuple[str, bytes]:
elif isinstance(obj, bytearray):
return "bytearray", obj
else:
return "json", self.dumps(obj)
try:
return "msgpack", _msgpack_enc(obj)
except UnicodeEncodeError:
return "json", self.dumps(obj)

def loads(self, data: bytes) -> Any:
return json.loads(data, object_hook=self._reviver)
Expand All @@ -197,5 +201,275 @@ def loads_typed(self, data: tuple[str, bytes]) -> Any:
return bytearray(data_)
elif type_ == "json":
return self.loads(data_)
elif type_ == "msgpack":
return msgpack.unpackb(data_, ext_hook=_msgpack_ext_hook)
else:
raise NotImplementedError(f"Unknown serialization type: {type_}")


# --- msgpack ---

EXT_CONSTRUCTOR_SINGLE_ARG = 0
EXT_CONSTRUCTOR_POS_ARGS = 1
EXT_CONSTRUCTOR_KW_ARGS = 2
EXT_METHOD_SINGLE_ARG = 3
EXT_PYDANTIC_V1 = 4
EXT_PYDANTIC_V2 = 5


def _msgpack_default(obj):
if hasattr(obj, "model_dump") and callable(obj.model_dump): # pydantic v2
return msgpack.ExtType(
EXT_PYDANTIC_V2,
_msgpack_enc(
(
obj.__class__.__module__,
obj.__class__.__name__,
obj.model_dump(),
"model_validate_json",
),
),
)
elif hasattr(obj, "dict") and callable(obj.dict): # pydantic v1
return msgpack.ExtType(
EXT_PYDANTIC_V1,
_msgpack_enc(
(
obj.__class__.__module__,
obj.__class__.__name__,
obj.dict(),
),
),
)
elif hasattr(obj, "_asdict") and callable(obj._asdict): # namedtuple
return msgpack.ExtType(
EXT_CONSTRUCTOR_KW_ARGS,
_msgpack_enc(
(
obj.__class__.__module__,
obj.__class__.__name__,
obj._asdict(),
),
),
)
elif isinstance(obj, pathlib.Path):
return msgpack.ExtType(
EXT_CONSTRUCTOR_POS_ARGS,
_msgpack_enc(
(obj.__class__.__module__, obj.__class__.__name__, obj.parts),
),
)
elif isinstance(obj, re.Pattern):
return msgpack.ExtType(
EXT_CONSTRUCTOR_POS_ARGS,
_msgpack_enc(
("re", "compile", (obj.pattern, obj.flags)),
),
)
elif isinstance(obj, UUID):
return msgpack.ExtType(
EXT_CONSTRUCTOR_SINGLE_ARG,
_msgpack_enc(
(obj.__class__.__module__, obj.__class__.__name__, obj.hex),
),
)
elif isinstance(obj, decimal.Decimal):
return msgpack.ExtType(
EXT_CONSTRUCTOR_SINGLE_ARG,
_msgpack_enc(
(obj.__class__.__module__, obj.__class__.__name__, str(obj)),
),
)
elif isinstance(obj, (set, frozenset, deque)):
return msgpack.ExtType(
EXT_CONSTRUCTOR_SINGLE_ARG,
_msgpack_enc(
(obj.__class__.__module__, obj.__class__.__name__, tuple(obj)),
),
)
elif isinstance(obj, (IPv4Address, IPv4Interface, IPv4Network)):
return msgpack.ExtType(
EXT_CONSTRUCTOR_SINGLE_ARG,
_msgpack_enc(
(obj.__class__.__module__, obj.__class__.__name__, str(obj)),
),
)
elif isinstance(obj, (IPv6Address, IPv6Interface, IPv6Network)):
return msgpack.ExtType(
EXT_CONSTRUCTOR_SINGLE_ARG,
_msgpack_enc(
(obj.__class__.__module__, obj.__class__.__name__, str(obj)),
),
)
elif isinstance(obj, datetime):
return msgpack.ExtType(
EXT_METHOD_SINGLE_ARG,
_msgpack_enc(
(
obj.__class__.__module__,
obj.__class__.__name__,
obj.isoformat(),
"fromisoformat",
),
),
)
elif isinstance(obj, timedelta):
return msgpack.ExtType(
EXT_CONSTRUCTOR_POS_ARGS,
_msgpack_enc(
(
obj.__class__.__module__,
obj.__class__.__name__,
(obj.days, obj.seconds, obj.microseconds),
),
),
)
elif isinstance(obj, date):
return msgpack.ExtType(
EXT_CONSTRUCTOR_POS_ARGS,
_msgpack_enc(
(
obj.__class__.__module__,
obj.__class__.__name__,
(obj.year, obj.month, obj.day),
),
),
)
elif isinstance(obj, time):
return msgpack.ExtType(
EXT_CONSTRUCTOR_KW_ARGS,
_msgpack_enc(
(
obj.__class__.__module__,
obj.__class__.__name__,
{
"hour": obj.hour,
"minute": obj.minute,
"second": obj.second,
"microsecond": obj.microsecond,
"tzinfo": obj.tzinfo,
"fold": obj.fold,
},
),
),
)
elif isinstance(obj, timezone):
return msgpack.ExtType(
EXT_CONSTRUCTOR_POS_ARGS,
_msgpack_enc(
(
obj.__class__.__module__,
obj.__class__.__name__,
obj.__getinitargs__(),
),
),
)
elif isinstance(obj, ZoneInfo):
return msgpack.ExtType(
EXT_CONSTRUCTOR_SINGLE_ARG,
_msgpack_enc(
(obj.__class__.__module__, obj.__class__.__name__, obj.key),
),
)
elif isinstance(obj, Enum):
return msgpack.ExtType(
EXT_CONSTRUCTOR_SINGLE_ARG,
_msgpack_enc(
(obj.__class__.__module__, obj.__class__.__name__, obj.value),
),
)
elif isinstance(obj, SendProtocol):
return msgpack.ExtType(
EXT_CONSTRUCTOR_POS_ARGS,
_msgpack_enc(
(obj.__class__.__module__, obj.__class__.__name__, (obj.node, obj.arg)),
),
)
elif dataclasses.is_dataclass(obj):
# doesn't use dataclasses.asdict to avoid deepcopy and recursion
return msgpack.ExtType(
EXT_CONSTRUCTOR_KW_ARGS,
_msgpack_enc(
(
obj.__class__.__module__,
obj.__class__.__name__,
{
field.name: getattr(obj, field.name)
for field in dataclasses.fields(obj)
},
),
),
)
elif isinstance(obj, BaseException):
return repr(obj)
else:
raise TypeError(f"Object of type {obj.__class__.__name__} is not serializable")


def _msgpack_ext_hook(code: int, data: bytes):
if code == EXT_CONSTRUCTOR_SINGLE_ARG:
try:
tup = msgpack.unpackb(data, ext_hook=_msgpack_ext_hook)
# module, name, arg
return getattr(importlib.import_module(tup[0]), tup[1])(tup[2])
except Exception:
return
elif code == EXT_CONSTRUCTOR_POS_ARGS:
try:
tup = msgpack.unpackb(data, ext_hook=_msgpack_ext_hook)
# module, name, args
return getattr(importlib.import_module(tup[0]), tup[1])(*tup[2])
except Exception:
return
elif code == EXT_CONSTRUCTOR_KW_ARGS:
try:
tup = msgpack.unpackb(data, ext_hook=_msgpack_ext_hook)
# module, name, args
return getattr(importlib.import_module(tup[0]), tup[1])(**tup[2])
except Exception:
return
elif code == EXT_METHOD_SINGLE_ARG:
try:
tup = msgpack.unpackb(data, ext_hook=_msgpack_ext_hook)
# module, name, arg, method
return getattr(getattr(importlib.import_module(tup[0]), tup[1]), tup[3])(
tup[2]
)
except Exception:
return
elif code == EXT_PYDANTIC_V1:
try:
tup = msgpack.unpackb(data, ext_hook=_msgpack_ext_hook)
# module, name, kwargs
cls = getattr(importlib.import_module(tup[0]), tup[1])
try:
return cls(**tup[2])
except Exception:
return cls.construct(**tup[2])
except Exception:
return
elif code == EXT_PYDANTIC_V2:
try:
tup = msgpack.unpackb(data, ext_hook=_msgpack_ext_hook)
# module, name, kwargs, method
cls = getattr(importlib.import_module(tup[0]), tup[1])
try:
return cls(**tup[2])
except Exception:
return cls.model_construct(**tup[2])
except Exception:
return


ENC_POOL = deque(maxlen=32)


def _msgpack_enc(data: Any) -> bytes:
try:
enc = ENC_POOL.popleft()
except IndexError:
enc = msgpack.Packer(default=_msgpack_default)
try:
return enc.pack(data)
finally:
ENC_POOL.append(enc)
Loading

0 comments on commit 3b05279

Please sign in to comment.