Skip to content

Commit

Permalink
langgraph: changes for compatibility pydantic v2 / langchain-core==0.3 (
Browse files Browse the repository at this point in the history
  • Loading branch information
vbarda authored Sep 5, 2024
1 parent 7461978 commit 7bf99a5
Show file tree
Hide file tree
Showing 18 changed files with 639 additions and 1,774 deletions.
12 changes: 10 additions & 2 deletions .github/workflows/_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,11 @@ jobs:
- "3.10"
- "3.11"
- "3.12"
name: "test #${{ matrix.python-version }}"
core-version:
- ">=0.3.0.dev1,<0.4.0"
- "latest"

name: "test #${{ matrix.python-version }} (langchain-core: ${{ matrix.core-version }})"
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }} + Poetry ${{ env.POETRY_VERSION }}
Expand All @@ -35,7 +39,11 @@ jobs:
- name: Install dependencies
shell: bash
working-directory: ${{ inputs.working-directory }}
run: poetry install --with dev
run: |
poetry install --with dev
if [ "${{ matrix.core-version }}" != "latest" ]; then
poetry run pip install "langchain-core${{ matrix.core-version }}"
fi
- name: Run core tests
shell: bash
Expand Down
2 changes: 1 addition & 1 deletion examples/llm-compiler/math_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
from langchain.chains.openai_functions import create_structured_output_runnable
from langchain_core.messages import SystemMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.runnables import RunnableConfig
from langchain_core.tools import StructuredTool
from langchain_openai import ChatOpenAI
from pydantic import BaseModel, Field

_MATH_DESCRIPTION = (
"math(problem: str, context: Optional[list[str]]) -> float:\n"
Expand Down
11 changes: 6 additions & 5 deletions libs/checkpoint/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion libs/checkpoint/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ packages = [{ include = "langgraph" }]

[tool.poetry.dependencies]
python = "^3.9.0,<4.0"
langchain-core = ">=0.2.22,<0.3"
langchain-core = ">=0.2.38,<0.4"

[tool.poetry.group.dev.dependencies]
ruff = "^0.6.2"
Expand Down
4 changes: 2 additions & 2 deletions libs/checkpoint/tests/test_jsonplus.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
from ipaddress import IPv4Address

import dataclasses_json
from langchain_core.pydantic_v1 import BaseModel as LcBaseModel
from langchain_core.runnables import RunnableMap
from pydantic import BaseModel
from pydantic.v1 import BaseModel as BaseModelV1
from zoneinfo import ZoneInfo

from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
Expand All @@ -23,7 +23,7 @@ class MyPydantic(BaseModel):
bar: int


class MyFunnyPydantic(LcBaseModel):
class MyFunnyPydantic(BaseModelV1):
foo: str
bar: int

Expand Down
2 changes: 1 addition & 1 deletion libs/cli/examples/graphs/storm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@
)
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.runnables import RunnableConfig, RunnableLambda
from langchain_core.runnables import chain as as_runnable
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langgraph.graph import END, StateGraph
from pydantic import BaseModel, Field
from typing_extensions import TypedDict

fast_llm = ChatOpenAI(model="gpt-3.5-turbo")
Expand Down
11 changes: 4 additions & 7 deletions libs/langgraph/langgraph/graph/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from langchain_core.runnables.base import RunnableLike
from langchain_core.runnables.utils import create_model
from pydantic import BaseModel
from pydantic.v1 import BaseModel as BaseModelV1

from langgraph.channels.base import BaseChannel
from langgraph.channels.binop import BinaryOperatorAggregate
Expand Down Expand Up @@ -473,10 +474,8 @@ class CompiledStateGraph(CompiledGraph):
def get_input_schema(
self, config: Optional[RunnableConfig] = None
) -> type[BaseModel]:
from pydantic import BaseModel as BaseModelP

if isclass(self.builder.input) and issubclass(
self.builder.input, (BaseModel, BaseModelP)
self.builder.input, (BaseModel, BaseModelV1)
):
return self.builder.input
else:
Expand Down Expand Up @@ -508,10 +507,8 @@ def get_input_schema(
def get_output_schema(
self, config: Optional[RunnableConfig] = None
) -> type[BaseModel]:
from pydantic import BaseModel as BaseModelP

if isclass(self.builder.input) and issubclass(
self.builder.output, (BaseModel, BaseModelP)
if isclass(self.builder.output) and issubclass(
self.builder.output, (BaseModel, BaseModelV1)
):
return self.builder.output

Expand Down
29 changes: 20 additions & 9 deletions libs/langgraph/langgraph/prebuilt/tool_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,20 +24,22 @@
ToolCall,
ToolMessage,
)
from langchain_core.pydantic_v1 import BaseModel, ValidationError
from langchain_core.runnables import (
RunnableConfig,
)
from langchain_core.runnables.config import get_executor_for_config
from langchain_core.tools import BaseTool, create_schema_from_function
from pydantic import BaseModel as BaseModelV2
from pydantic import ValidationError as ValidationErrorV2
from pydantic import BaseModel, ValidationError
from pydantic.v1 import BaseModel as BaseModelV1
from pydantic.v1 import ValidationError as ValidationErrorV1

from langgraph.utils.runnable import RunnableCallable


def _default_format_error(
error: BaseException, call: ToolCall, schema: Type[BaseModel]
error: BaseException,
call: ToolCall,
schema: Union[Type[BaseModel], Type[BaseModelV1]],
) -> str:
"""Default error formatting function."""
return f"{repr(error)}\n\nRespond after fixing all validation errors."
Expand Down Expand Up @@ -75,7 +77,7 @@ class ValidationNode(RunnableCallable):
>>> from typing import Literal, Annotated, TypedDict
...
>>> from langchain_anthropic import ChatAnthropic
>>> from langchain_core.pydantic_v1 import BaseModel, validator
>>> from pydantic import BaseModel, validator
...
>>> from langgraph.graph import END, START, StateGraph
>>> from langgraph.prebuilt import ValidationNode
Expand Down Expand Up @@ -176,7 +178,7 @@ def __init__(
)
self.schemas_by_name[schema.name] = schema.args_schema
elif isinstance(schema, type) and issubclass(
schema, (BaseModel, BaseModelV2)
schema, (BaseModel, BaseModelV1)
):
self.schemas_by_name[schema.__name__] = cast(Type[BaseModel], schema)
elif callable(schema):
Expand Down Expand Up @@ -212,13 +214,22 @@ def _func(
def run_one(call: ToolCall):
schema = self.schemas_by_name[call["name"]]
try:
output = schema.validate(call["args"])
if issubclass(schema, BaseModel):
output = schema.model_validate(call["args"])
content = output.model_dump_json()
elif issubclass(schema, BaseModelV1):
output = schema.validate(call["args"])
content = output.json()
else:
raise ValueError(
f"Unsupported schema type: {type(schema)}. Expected BaseModel or BaseModelV1."
)
return ToolMessage(
content=output.json(),
content=content,
name=call["name"],
tool_call_id=cast(str, call["id"]),
)
except (ValidationError, ValidationErrorV2) as e:
except (ValidationError, ValidationErrorV1) as e:
return ToolMessage(
content=self._format_error(e, call, schema),
name=call["name"],
Expand Down
26 changes: 13 additions & 13 deletions libs/langgraph/langgraph/utils/runnable.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,19 +59,19 @@ def __init__(
recurse: bool = True,
**kwargs: Any,
) -> None:
if name is not None:
self.name = name
elif func:
try:
if func.__name__ != "<lambda>":
self.name = func.__name__
except AttributeError:
pass
elif afunc:
try:
self.name = afunc.__name__
except AttributeError:
pass
self.name = name
if self.name is None:
if func:
try:
if func.__name__ != "<lambda>":
self.name = func.__name__
except AttributeError:
pass
elif afunc:
try:
self.name = afunc.__name__
except AttributeError:
pass
self.func = func
if func is not None:
self.func_accepts_config = accepts_config(func)
Expand Down
Loading

0 comments on commit 7bf99a5

Please sign in to comment.