From d0c86b62782a5e6f228dc80b2907d4cb519c2524 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Mon, 20 Nov 2023 12:54:51 +0000 Subject: [PATCH] Make messages optional --- backend/app/api/runs.py | 7 +++---- .../packages/agent-executor/agent_executor/permchain.py | 2 +- backend/packages/gizmo-agent/gizmo_agent/main.py | 2 +- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/backend/app/api/runs.py b/backend/app/api/runs.py index 78022fbf..a3ba09ba 100644 --- a/backend/app/api/runs.py +++ b/backend/app/api/runs.py @@ -17,7 +17,7 @@ from langserve.serialization import WellKnownLCSerializer from langserve.server import _get_base_run_id_as_str, _unpack_input from langsmith.utils import tracing_is_enabled -from pydantic import BaseModel +from pydantic import BaseModel, Field from sse_starlette import EventSourceResponse from app.schema import OpengptsUserId @@ -33,7 +33,7 @@ class AgentInput(BaseModel): """An input into an agent.""" - messages: Sequence[AnyMessage] + messages: Sequence[AnyMessage] = Field(default_factory=list) class CreateRunPayload(BaseModel): @@ -42,8 +42,7 @@ class CreateRunPayload(BaseModel): assistant_id: str thread_id: str stream: bool - # TODO make optional - input: AgentInput + input: AgentInput = Field(default_factory=AgentInput) @router.post("") diff --git a/backend/packages/agent-executor/agent_executor/permchain.py b/backend/packages/agent-executor/agent_executor/permchain.py index 4ff7c83a..1bee2808 100644 --- a/backend/packages/agent-executor/agent_executor/permchain.py +++ b/backend/packages/agent-executor/agent_executor/permchain.py @@ -113,7 +113,7 @@ def get_agent_executor( def route_last_message(input: dict[str, bool | Sequence[AnyMessage]]) -> Runnable: if not input["messages"]: # no messages, do nothing - return RunnablePassthrough() + return agent_chain message: AnyMessage = input["messages"][-1] if isinstance(message.additional_kwargs.get("agent"), AgentFinish): diff --git a/backend/packages/gizmo-agent/gizmo_agent/main.py b/backend/packages/gizmo-agent/gizmo_agent/main.py index 8fa93027..8ebc4c8f 100644 --- a/backend/packages/gizmo-agent/gizmo_agent/main.py +++ b/backend/packages/gizmo-agent/gizmo_agent/main.py @@ -75,7 +75,7 @@ def __init__( class AgentInput(BaseModel): - messages: Sequence[AnyMessage] + messages: Sequence[AnyMessage] = Field(default_factory=list) class AgentOutput(BaseModel):