Merge pull request #305 from andrewnguonly/migrate-state-graph
Update chatbot and RAG assistant to use `StateGraph` in the backend
nfcampos authored Apr 18, 2024
2 parents b925344 + deab0ed commit f0c25df
Showing 18 changed files with 315 additions and 119 deletions.
31 changes: 30 additions & 1 deletion API.md
@@ -52,6 +52,8 @@ This creates an assistant with the name `"bar"`, with GPT 3.5 Turbo, with a prom…
Available tool names can be found in the `AvailableTools` class in backend/packages/gizmo-agent/gizmo_agent/tools.py.
Available LLMs can be found in `GizmoAgentType` in backend/packages/gizmo-agent/gizmo_agent/agent_types/__init__.py.

Note: If a RAGBot assistant is created (`type` equals `chat_retrieval`), then subsequent API requests/responses for the threads APIs are slightly modified and noted below.

## Create a thread

We can now create a thread.
@@ -85,7 +87,11 @@ requests.get(
).content
```
```shell
b'{"messages":[]}'
b'{"values":[]}'
```
For RAGBot:
```shell
b'{"values":{"messages":[]}}'
```
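
Since thread state now comes back in two shapes, a client can normalize them. A minimal sketch, assuming a GET endpoint of the form `/threads/{thread_id}/state` (the URLs in the snippets here are truncated, so the exact path, and any auth, are assumptions):

```python
import requests

def get_thread_messages(base_url: str, thread_id: str) -> list:
    # Endpoint path assumed; the snippets in this document truncate the URL,
    # and any auth headers/cookies are omitted.
    state = requests.get(f"{base_url}/threads/{thread_id}/state").json()
    values = state["values"]
    # Agent assistants return a bare list; RAGBot returns {"messages": [...]}.
    return values["messages"] if isinstance(values, dict) else values
```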

Let's add a message to the thread!
@@ -102,6 +108,17 @@ requests.post(
}
).content
```
For RAGBot:
```
{
    "values": {
        "messages": [{
            "content": "hi! my name is bob",
            "type": "human",
        }]
    }
}
```

If we now run the command to see the thread, we can see that there is a message on that thread.

@@ -115,6 +132,10 @@ requests.get(
```shell
b'{"values":[{"content":"hi! my name is bob","additional_kwargs":{},"type":"human","example":false}],"next":[]}'
```
For RAGBot:
```shell
b'{"values":{"messages":[...]},"next":[]}'
```

## Run the assistant on that thread

@@ -141,6 +162,10 @@ requests.get('http://127.0.0.1:8100/threads/231dc7f3-33ee-4040-98fe-27f6e2aa8b2b…
```shell
b'{"values":[{"content":"hi! my name is bob","additional_kwargs":{},"type":"human","example":false},{"content":"Hello, Bob! How can I assist you today?","additional_kwargs":{"agent":{"return_values":{"output":"Hello, Bob! How can I assist you today?"},"log":"Hello, Bob! How can I assist you today?","type":"AgentFinish"}},"type":"ai","example":false}],"next":[]}'
```
For RAGBot:
```shell
b'{"values":{"messages":[...]},"next":[]}'
```

## Run the assistant on the thread with new messages

@@ -171,6 +196,10 @@ requests.get('http://127.0.0.1:8100/threads/231dc7f3-33ee-4040-98fe-27f6e2aa8b2b…
```shell
b'{"values":[{"content":"hi! my name is bob","additional_kwargs":{},"type":"human","example":false},{"content":"Hello, Bob! How can I assist you today?","additional_kwargs":{"agent":{"return_values":{"output":"Hello, Bob! How can I assist you today?"},"log":"Hello, Bob! How can I assist you today?","type":"AgentFinish"}},"type":"ai","example":false},{"content":"whats my name? respond in spanish","additional_kwargs":{},"type":"human","example":false},{"content":"Tu nombre es Bob.","additional_kwargs":{"agent":{"return_values":{"output":"Tu nombre es Bob."},"log":"Tu nombre es Bob.","type":"AgentFinish"}},"type":"ai","example":false}],"next":[]}'
```
For RAGBot:
```shell
b'{"values":{"messages":[...]},"next":[]}'
```

## Stream
One thing we can do is stream back responses.
6 changes: 6 additions & 0 deletions README.md
@@ -301,6 +301,8 @@ This can be a really powerful and flexible architecture. This is probably closes…
these can also be unreliable, and generally only work with the more performant models (and even then they can
mess up). Therefore, we introduced a few simpler architectures.
Assistants are implemented with [LangGraph](https://github.com/langchain-ai/langgraph) `MessageGraph`. A `MessageGraph` is a graph that models its state as a `list` of messages.
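
For a sense of what that means in code, here is a minimal sketch of a `MessageGraph` (the model and node are illustrative; the import path matches the one this PR removes from `backend/app/chatbot.py` below):

```python
from langchain_core.messages import HumanMessage
from langchain_openai import ChatOpenAI
from langgraph.graph import END
from langgraph.graph.message import MessageGraph

llm = ChatOpenAI(model="gpt-3.5-turbo")  # any chat model works here

workflow = MessageGraph()        # graph state is a list of messages
workflow.add_node("model", llm)  # returned messages are appended to the state
workflow.set_entry_point("model")
workflow.add_edge("model", END)

app = workflow.compile()
app.invoke([HumanMessage(content="hi! my name is bob")])  # -> list of messages
```
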
**RAGBot**
One of the big use cases of the GPT store is uploading files and giving the bot knowledge of those files. What would it…
@@ -321,6 +323,8 @@ pretty well with a wider variety of models (including lots of open source models…
you don’t NEED the flexibility of an assistant (eg you know users will be looking up information every time) then it
can be more focused. And third, compared to the final architecture below it can use external knowledge.
RAGBot is implemented with [LangGraph](https://github.com/langchain-ai/langgraph) `StateGraph`. A `StateGraph` is a generalized graph that can model arbitrary state (i.e. `dict`), not just a `list` of messages.
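
A sketch of the contrast: a `StateGraph` is parameterized by a state schema, and its nodes return partial state updates (the `TypedDict` and stub node below are illustrative, not RAGBot's actual retrieval nodes):

```python
import operator
from typing import Annotated, TypedDict

from langgraph.graph.state import StateGraph

class State(TypedDict):
    # The Annotated reducer controls how updates merge: operator.add appends.
    messages: Annotated[list, operator.add]

def respond(state: State) -> dict:
    # Stub node; RAGBot's real nodes retrieve documents and call an LLM.
    return {"messages": ["(model reply)"]}

workflow = StateGraph(State)
workflow.add_node("respond", respond)
workflow.set_entry_point("respond")
workflow.set_finish_point("respond")
app = workflow.compile()

app.invoke({"messages": ["hi! my name is bob"]})  # -> final state dict
```
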
**ChatBot**
The final architecture is dead simple - just a call to a language model, parameterized by a system message. This allows
@@ -331,6 +335,8 @@ well.
![](_static/chatbot.png)
ChatBot is implemented with [LangGraph](https://github.com/langchain-ai/langgraph) `StateGraph`. A `StateGraph` is a generalized graph that can model arbitrary state (i.e. `dict`), not just a `list` of messages.
### LLMs
You can choose between different LLMs to use.
24 changes: 18 additions & 6 deletions backend/app/agent.py
@@ -1,12 +1,15 @@
import pickle
from enum import Enum
from typing import Any, Mapping, Optional, Sequence, Union
from typing import Any, Dict, Mapping, Optional, Sequence, Union

from langchain_core.messages import AnyMessage
from langchain_core.runnables import (
    ConfigurableField,
    RunnableBinding,
)
from langgraph.checkpoint import CheckpointAt
from langgraph.graph.message import Messages
from langgraph.pregel import Pregel

from app.agent_types.tools_agent import get_tools_agent_executor
from app.agent_types.xml_agent import get_xml_agent_executor
@@ -70,7 +73,7 @@ class AgentType(str, Enum):

DEFAULT_SYSTEM_MESSAGE = "You are a helpful assistant."

CHECKPOINTER = PostgresCheckpoint(at=CheckpointAt.END_OF_STEP)
CHECKPOINTER = PostgresCheckpoint(serde=pickle, at=CheckpointAt.END_OF_STEP)
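
Passing the `pickle` module as `serde` works because langgraph's serializer protocol only requires `dumps`/`loads`, and the module provides both. A rough sketch of the shape (this `Protocol` is illustrative, not langgraph's actual definition):

```python
import pickle
from typing import Any, Protocol

class SerializerLike(Protocol):
    # Roughly what a `serde` object must provide.
    def dumps(self, obj: Any) -> bytes: ...
    def loads(self, data: bytes) -> Any: ...

serde: SerializerLike = pickle  # a module can satisfy the protocol structurally
assert serde.loads(serde.dumps({"messages": []})) == {"messages": []}
```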


def get_agent_executor(
@@ -244,7 +247,10 @@ def __init__(
        llm=ConfigurableField(id="llm_type", name="LLM Type"),
        system_message=ConfigurableField(id="system_message", name="Instructions"),
    )
    .with_types(input_type=Sequence[AnyMessage], output_type=Sequence[AnyMessage])
    .with_types(
        input_type=Messages,
        output_type=Sequence[AnyMessage],
    )
)
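
`Messages` (imported above from `langgraph.graph.message`) is roughly `Union[list[AnyMessage], AnyMessage]`, so the new input type accepts a single message as well as a list; a sketch:

```python
from langchain_core.messages import HumanMessage
from langgraph.graph.message import Messages

# Both shapes validate against the new input type.
single: Messages = HumanMessage(content="hi! my name is bob")
many: Messages = [HumanMessage(content="hi! my name is bob")]
```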


@@ -306,11 +312,14 @@ def __init__(
        ),
        thread_id=ConfigurableField(id="thread_id", name="Thread ID", is_shared=True),
    )
    .with_types(input_type=Sequence[AnyMessage], output_type=Sequence[AnyMessage])
    .with_types(
        input_type=Dict[str, Any],
        output_type=Dict[str, Any],
    )
)


agent = (
agent: Pregel = (
    ConfigurableAgent(
        agent=AgentType.GPT_35_TURBO,
        tools=[],
@@ -343,7 +352,10 @@ def __init__(
        chatbot=chatbot,
        chat_retrieval=chat_retrieval,
    )
    .with_types(input_type=Sequence[AnyMessage], output_type=Sequence[AnyMessage])
    .with_types(
        input_type=Messages,
        output_type=Sequence[AnyMessage],
    )
)

if __name__ == "__main__":
14 changes: 5 additions & 9 deletions backend/app/api/runs.py
@@ -7,15 +7,14 @@
from langchain_core.messages import AnyMessage
from langchain_core.runnables import RunnableConfig
from langserve.schema import FeedbackCreateRequest
from langserve.server import _unpack_input
from langsmith.utils import tracing_is_enabled
from pydantic import BaseModel, Field
from sse_starlette import EventSourceResponse

from app.agent import agent
from app.auth.handlers import AuthedUser
from app.storage import get_assistant, get_thread
from app.stream import astream_messages, to_sse
from app.stream import astream_state, to_sse

router = APIRouter()

@@ -51,15 +50,12 @@ async def _run_input_and_config(payload: CreateRunPayload, user_id: str):
    }

    try:
        input_ = (
            _unpack_input(agent.get_input_schema(config).validate(payload.input))
            if payload.input is not None
            else None
        )
        if payload.input is not None:
            agent.get_input_schema(config).validate(payload.input)
    except ValidationError as e:
        raise RequestValidationError(e.errors(), body=payload)

    return input_, config
    return payload.input, config


@router.post("")
@@ -82,7 +78,7 @@ async def stream_run(
"""Create a run."""
input_, config = await _run_input_and_config(payload, user["user_id"])

return EventSourceResponse(to_sse(astream_messages(agent, input_, config)))
return EventSourceResponse(to_sse(astream_state(agent, input_, config)))


@router.get("/input_schema")
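
The run stream is delivered as server-sent events over the run's state. A hedged client sketch (the `/runs/stream` path and the payload field names are assumptions; `CreateRunPayload` defines the real schema):

```python
import json
import httpx

payload = {
    # Field names assumed for illustration.
    "assistant_id": "<assistant-id>",
    "thread_id": "<thread-id>",
    "input": [{"content": "hi! my name is bob", "type": "human"}],
}

with httpx.stream("POST", "http://127.0.0.1:8100/runs/stream", json=payload) as resp:
    for line in resp.iter_lines():
        if line.startswith("data:"):
            print(json.loads(line[len("data:"):]))
```
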
12 changes: 7 additions & 5 deletions backend/app/chatbot.py
@@ -1,8 +1,10 @@
from typing import Annotated, List

from app.message_types import add_messages_liberal
from langchain_core.language_models.base import LanguageModelLike
from langchain_core.messages import SystemMessage
from langchain_core.messages import BaseMessage, SystemMessage
from langgraph.checkpoint import BaseCheckpointSaver
from langgraph.graph import END
from langgraph.graph.message import MessageGraph
from langgraph.graph.state import StateGraph


def get_chatbot_executor(
@@ -15,9 +17,9 @@ def _get_messages(messages):

    chatbot = _get_messages | llm

    workflow = MessageGraph()
    workflow = StateGraph(Annotated[List[BaseMessage], add_messages_liberal])
    workflow.add_node("chatbot", chatbot)
    workflow.set_entry_point("chatbot")
    workflow.add_edge("chatbot", END)
    workflow.set_finish_point("chatbot")
    app = workflow.compile(checkpointer=checkpoint)
    return app
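
A hedged usage sketch of the new executor; the parameter names (elided in the hunk above) and `MemorySaver`'s import path are assumptions about the surrounding code and the langgraph version in use:

```python
from langchain_core.messages import HumanMessage
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.memory import MemorySaver  # import path assumed

from app.chatbot import get_chatbot_executor

app = get_chatbot_executor(
    llm=ChatOpenAI(model="gpt-3.5-turbo"),  # parameter names assumed
    system_message="You are a helpful assistant.",
    checkpoint=MemorySaver(),
)

# The checkpointer keys saved state by thread_id.
config = {"configurable": {"thread_id": "demo-thread"}}
app.invoke([HumanMessage(content="hi! my name is bob")], config)
```
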
17 changes: 14 additions & 3 deletions backend/app/checkpoint.py
@@ -5,7 +5,13 @@
from langchain_core.messages import BaseMessage
from langchain_core.runnables import ConfigurableFieldSpec, RunnableConfig
from langgraph.checkpoint import BaseCheckpointSaver
from langgraph.checkpoint.base import Checkpoint, CheckpointThreadTs, CheckpointTuple
from langgraph.checkpoint.base import (
    Checkpoint,
    CheckpointAt,
    CheckpointThreadTs,
    CheckpointTuple,
    SerializerProtocol,
)

from app.lifespan import get_pg_pool

@@ -19,8 +25,13 @@ def loads(value: bytes) -> Checkpoint:


class PostgresCheckpoint(BaseCheckpointSaver):
    class Config:
        arbitrary_types_allowed = True
    def __init__(
        self,
        *,
        serde: Optional[SerializerProtocol] = None,
        at: Optional[CheckpointAt] = None,
    ) -> None:
        super().__init__(serde=serde, at=at)

    @property
    def config_specs(self) -> list[ConfigurableFieldSpec]:
36 changes: 34 additions & 2 deletions backend/app/message_types.py
@@ -1,6 +1,12 @@
from typing import Any
from typing import Any, get_args

from langchain_core.messages import FunctionMessage, ToolMessage
from langchain_core.messages import (
    AnyMessage,
    FunctionMessage,
    MessageLikeRepresentation,
    ToolMessage,
)
from langgraph.graph.message import add_messages, Messages


class LiberalFunctionMessage(FunctionMessage):
@@ -9,3 +15,29 @@ class LiberalFunctionMessage(FunctionMessage):

class LiberalToolMessage(ToolMessage):
    content: Any


def _convert_pydantic_dict_to_message(
    data: MessageLikeRepresentation
) -> MessageLikeRepresentation:
    if (
        isinstance(data, dict)
        and "content" in data
        and isinstance(data.get("type"), str)
    ):
        for cls in get_args(AnyMessage):
            if data["type"] == cls(content="").type:
                return cls(**data)
    return data


def add_messages_liberal(left: Messages, right: Messages):
    # coerce to list
    if not isinstance(left, list):
        left = [left]
    if not isinstance(right, list):
        right = [right]
    return add_messages(
        [_convert_pydantic_dict_to_message(m) for m in left],
        [_convert_pydantic_dict_to_message(m) for m in right],
    )
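
What the new reducer buys, as a quick sketch: plain dicts carrying a `type` field are coerced into message objects before `add_messages` merges the two sides:

```python
from langchain_core.messages import HumanMessage

from app.message_types import add_messages_liberal

merged = add_messages_liberal(
    [HumanMessage(content="hi! my name is bob")],
    {"content": "Hello, Bob!", "type": "ai"},  # dict form, coerced to AIMessage
)
# merged is [HumanMessage(...), AIMessage(content="Hello, Bob!")]
```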
