Skip to content

Commit

Permalink
Add Docstrings (#1057)
Browse files Browse the repository at this point in the history
  • Loading branch information
hinthornw authored Jul 18, 2024
1 parent 254926c commit 246b7da
Show file tree
Hide file tree
Showing 2 changed files with 139 additions and 6 deletions.
6 changes: 5 additions & 1 deletion docs/docs/reference/checkpoints.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@ You can [compile][langgraph.graph.MessageGraph.compile] any LangGraph workflow w

### Checkpoint

::: langgraph.checkpoint.Checkpoint
::: langgraph.checkpoint.base.Checkpoint

### CheckpointMetadata

::: langgraph.checkpoint.base.CheckpointMetadata

### BaseCheckpointSaver

Expand Down
139 changes: 134 additions & 5 deletions libs/langgraph/langgraph/checkpoint/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,19 @@ class CheckpointTuple(NamedTuple):


class BaseCheckpointSaver(ABC):
"""Base class for creating a graph checkpointer.
Checkpointers allow LangGraph agents to persist their state
within and across multiple interactions.
Attributes:
serde (SerializerProtocol): Serializer for encoding/decoding checkpoints.
Note:
When creating a custom checkpoint saver, consider implementing async
versions to avoid blocking the main thread.
"""

serde: SerializerProtocol = JsonPlusSerializer()

def __init__(
Expand All @@ -152,13 +165,37 @@ def __init__(

@property
def config_specs(self) -> list[ConfigurableFieldSpec]:
"""Define the configuration options for the checkpoint saver.
Returns:
list[ConfigurableFieldSpec]: List of configuration field specs.
"""
return [CheckpointThreadId, CheckpointThreadTs]

def get(self, config: RunnableConfig) -> Optional[Checkpoint]:
"""Fetch a checkpoint using the given configuration.
Args:
config (RunnableConfig): Configuration specifying which checkpoint to retrieve.
Returns:
Optional[Checkpoint]: The requested checkpoint, or None if not found.
"""
if value := self.get_tuple(config):
return value.checkpoint

def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
"""Fetch a checkpoint tuple using the given configuration.
Args:
config (RunnableConfig): Configuration specifying which checkpoint to retrieve.
Returns:
Optional[CheckpointTuple]: The requested checkpoint tuple, or None if not found.
Raises:
NotImplementedError: Implement this method in your custom checkpoint saver.
"""
raise NotImplementedError

def list(
Expand All @@ -169,6 +206,20 @@ def list(
before: Optional[RunnableConfig] = None,
limit: Optional[int] = None,
) -> Iterator[CheckpointTuple]:
"""List checkpoints that match the given criteria.
Args:
config (Optional[RunnableConfig]): Base configuration for filtering checkpoints.
filter (Optional[Dict[str, Any]]): Additional filtering criteria.
before (Optional[RunnableConfig]): List checkpoints created before this configuration.
limit (Optional[int]): Maximum number of checkpoints to return.
Returns:
Iterator[CheckpointTuple]: Iterator of matching checkpoint tuples.
Raises:
NotImplementedError: Implement this method in your custom checkpoint saver.
"""
raise NotImplementedError

def put(
Expand All @@ -177,6 +228,19 @@ def put(
checkpoint: Checkpoint,
metadata: CheckpointMetadata,
) -> RunnableConfig:
"""Store a checkpoint with its configuration and metadata.
Args:
config (RunnableConfig): Configuration for the checkpoint.
checkpoint (Checkpoint): The checkpoint to store.
metadata (CheckpointMetadata): Additional metadata for the checkpoint.
Returns:
RunnableConfig: Updated configuration after storing the checkpoint.
Raises:
NotImplementedError: Implement this method in your custom checkpoint saver.
"""
raise NotImplementedError

def put_writes(
Expand All @@ -185,25 +249,60 @@ def put_writes(
writes: List[Tuple[str, Any]],
task_id: str,
) -> None:
"""Store intermediate writes linked to a checkpoint.
Args:
config (RunnableConfig): Configuration of the related checkpoint.
writes (List[Tuple[str, Any]]): List of writes to store.
task_id (str): Identifier for the task creating the writes.
Raises:
NotImplementedError: Implement this method in your custom checkpoint saver.
"""
raise NotImplementedError(
"This method was added in langgraph 0.1.7. Please update your checkpointer to implement it."
"This method was added in langgraph 0.1.7. Please update your checkpoint saver to implement it."
)

async def aget(self, config: RunnableConfig) -> Optional[Checkpoint]:
"""
Asynchronously fetch a checkpoint using the given configuration.
Args:
config (RunnableConfig): Configuration specifying which checkpoint to retrieve.
"""
if value := await self.aget_tuple(config):
return value.checkpoint

async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
"""Asynchronously fetch a checkpoint tuple using the given configuration.
Args:
config (RunnableConfig): Configuration specifying which checkpoint to retrieve.
Returns:
Optional[CheckpointTuple]: The requested checkpoint tuple, or None if not found.
"""
raise NotImplementedError

def alist(
async def alist(
self,
config: Optional[RunnableConfig],
*,
filter: Optional[Dict[str, Any]] = None,
before: Optional[RunnableConfig] = None,
limit: Optional[int] = None,
) -> AsyncIterator[CheckpointTuple]:
"""Asynchronously list checkpoints that match the given criteria.
Args:
config (Optional[RunnableConfig]): Base configuration for filtering checkpoints.
filter (Optional[Dict[str, Any]]): Additional filtering criteria.
before (Optional[RunnableConfig]): List checkpoints created before this configuration.
limit (Optional[int]): Maximum number of checkpoints to return.
Returns:
AsyncIterator[CheckpointTuple]: Async iterator of matching checkpoint tuples.
"""
raise NotImplementedError
yield

Expand All @@ -213,6 +312,16 @@ async def aput(
checkpoint: Checkpoint,
metadata: CheckpointMetadata,
) -> RunnableConfig:
"""Asynchronously store a checkpoint with its configuration and metadata.
Args:
config (RunnableConfig): Configuration for the checkpoint.
checkpoint (Checkpoint): The checkpoint to store.
metadata (CheckpointMetadata): Additional metadata for the checkpoint.
Returns:
RunnableConfig: Updated configuration after storing the checkpoint.
"""
raise NotImplementedError

async def aput_writes(
Expand All @@ -221,11 +330,31 @@ async def aput_writes(
writes: List[Tuple[str, Any]],
task_id: str,
) -> None:
"""Asynchronously store intermediate writes linked to a checkpoint.
Args:
config (RunnableConfig): Configuration of the related checkpoint.
writes (List[Tuple[str, Any]]): List of writes to store.
task_id (str): Identifier for the task creating the writes.
Raises:
NotImplementedError: Implement this method in your custom checkpoint saver.
"""
raise NotImplementedError(
"This method was added in langgraph 0.1.7. Please update your checkpointer to implement it."
"This method was added in langgraph 0.1.7. Please update your checkpoint saver to implement it."
)

def get_next_version(self, current: Optional[V], channel: BaseChannel) -> V:
"""Get the next version of a channel. Default is to use int versions, incrementing by 1. If you override, you can use str/int/float versions,
as long as they are monotonically increasing."""
"""Generate the next version ID for a channel.
Default is to use integer versions, incrementing by 1. If you override, you can use str/int/float versions,
as long as they are monotonically increasing.
Args:
current (Optional[V]): The current version identifier (int, float, or str).
channel (BaseChannel): The channel being versioned.
Returns:
V: The next version identifier, which must be increasing.
"""
return current + 1 if current is not None else 1

0 comments on commit 246b7da

Please sign in to comment.