Skip to content

Commit

Permalink
Refine the planner by including chat history
Browse files Browse the repository at this point in the history
  • Loading branch information
leehuwuj committed Sep 26, 2024
1 parent 66af273 commit 8696146
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 9 deletions.
31 changes: 28 additions & 3 deletions templates/types/multiagent/fastapi/app/agents/planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
SubTask,
)
from llama_index.core.bridge.pydantic import ValidationError
from llama_index.core.chat_engine.types import ChatMessage
from llama_index.core.llms.function_calling import FunctionCallingLLM
from llama_index.core.prompts import PromptTemplate
from llama_index.core.settings import Settings
Expand All @@ -24,6 +25,18 @@
step,
)

# Prompt template handed to the structured Planner to turn a user request into
# an initial task plan. Placeholders are filled when the plan is created:
# {tools_str} (rendered tool list), {chat_history} (prior conversation),
# and {task} (the user's overall request).
INITIAL_PLANNER_PROMPT = """\
Think step-by-step. Given a conversation, set of tools and a user request. Your responsibility is to create a plan to complete the task.
The plan must adapt with the user request and the conversation. It's fine to just start with needed tasks first and asking user for the next step approval.
The tools available are:
{tools_str}
Conversation: {chat_history}
Overall Task: {task}
"""


class ExecutePlanEvent(Event):
    """Marker workflow event carrying no payload.

    NOTE(review): by its name, presumably emitted to trigger execution of the
    freshly created plan — confirm against the workflow steps that handle it.
    """

    pass
Expand Down Expand Up @@ -62,14 +75,21 @@ def __init__(
tools: List[BaseTool] | None = None,
timeout: float = 360.0,
refine_plan: bool = False,
chat_history: Optional[List[ChatMessage]] = None,
**kwargs: Any,
) -> None:
super().__init__(*args, timeout=timeout, **kwargs)
self.name = name
self.refine_plan = refine_plan
self.chat_history = chat_history

self.tools = tools or []
self.planner = Planner(llm=llm, tools=self.tools, verbose=self._verbose)
self.planner = Planner(
llm=llm,
tools=self.tools,
initial_plan_prompt=INITIAL_PLANNER_PROMPT,
verbose=self._verbose,
)
# The executor is keeping the memory of all tool calls and decides to call the right tool for the task
self.executor = FunctionCallingAgent(
name="executor",
Expand All @@ -89,7 +109,9 @@ async def create_plan(
ctx.data["streaming"] = getattr(ev, "streaming", False)
ctx.data["task"] = ev.input

plan_id, plan = await self.planner.create_plan(input=ev.input)
plan_id, plan = await self.planner.create_plan(
input=ev.input, chat_history=self.chat_history
)
ctx.data["act_plan_id"] = plan_id

# inform about the new plan
Expand Down Expand Up @@ -213,7 +235,9 @@ def __init__(
plan_refine_prompt = PromptTemplate(plan_refine_prompt)
self.plan_refine_prompt = plan_refine_prompt

async def create_plan(self, input: str) -> Tuple[str, Plan]:
async def create_plan(
self, input: str, chat_history: Optional[List[ChatMessage]] = None
) -> Tuple[str, Plan]:
tools = self.tools
tools_str = ""
for tool in tools:
Expand All @@ -225,6 +249,7 @@ async def create_plan(self, input: str) -> Tuple[str, Plan]:
self.initial_plan_prompt,
tools_str=tools_str,
task=input,
chat_history=chat_history,
)
except (ValueError, ValidationError):
if self.verbose:
Expand Down
2 changes: 1 addition & 1 deletion templates/types/multiagent/fastapi/app/api/routers/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ async def chat(
):
try:
last_message_content = data.get_last_message_content()
messages = data.get_history_messages()
messages = data.get_history_messages(include_agent_messages=True)
# TODO: generate filters based on doc_ids
# for now just use all documents
# doc_ids = data.get_chat_document_ids()
Expand Down
49 changes: 44 additions & 5 deletions templates/types/multiagent/fastapi/app/api/routers/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,12 @@
import os
from typing import Any, Dict, List, Literal, Optional

from app.config import DATA_DIR
from llama_index.core.llms import ChatMessage, MessageRole
from llama_index.core.schema import NodeWithScore
from pydantic import BaseModel, Field, validator
from pydantic.alias_generators import to_camel

from app.config import DATA_DIR

logger = logging.getLogger("uvicorn")


Expand Down Expand Up @@ -50,9 +49,14 @@ class Config:
alias_generator = to_camel


class AgentAnnotation(BaseModel):
    """Payload of an "agent"-typed annotation attached to an assistant message."""

    # Name of the agent that produced the text (rendered as "Agent: <name>").
    agent: str
    # The agent's message text; entries starting with "Finished task" are
    # filtered out when building agent history.
    text: str


class Annotation(BaseModel):
type: str
data: AnnotationFileData | List[str]
data: AnnotationFileData | List[str] | AgentAnnotation

def to_content(self) -> str | None:
if self.type == "document_file":
Expand Down Expand Up @@ -119,14 +123,49 @@ def get_last_message_content(self) -> str:
break
return message_content

def get_history_messages(self) -> List[ChatMessage]:
def _get_agent_messages(self, max_messages: int = 5) -> List[str]:
    """
    Construct agent messages from the annotations in the chat messages.

    Scans assistant messages for annotations of type "agent" whose data is an
    AgentAnnotation, and renders each as a short transcript entry. Progress
    notifications (text starting with "Finished task") are skipped.

    Args:
        max_messages: Hard cap on the number of entries returned.

    Returns:
        At most ``max_messages`` rendered agent-message strings.
    """
    agent_messages: List[str] = []
    for message in self.messages:
        # Bug fix: the original only broke out of the inner annotation loop,
        # so the cap was not enforced across messages. Check it here too.
        if len(agent_messages) >= max_messages:
            break
        if message.role != MessageRole.ASSISTANT or message.annotations is None:
            continue
        for annotation in message.annotations:
            if len(agent_messages) >= max_messages:
                break
            if annotation.type == "agent" and isinstance(
                annotation.data, AgentAnnotation
            ):
                text = annotation.data.text
                # "Finished task" lines are progress updates, not content.
                if not text.startswith("Finished task"):
                    agent_messages.append(
                        f"\nAgent: {annotation.data.agent}\nsaid: {text}\n"
                    )
    return agent_messages

def get_history_messages(
    self, include_agent_messages: bool = False
) -> List[ChatMessage]:
    """
    Return the conversation history (all messages except the last one).

    When ``include_agent_messages`` is True and prior agent events exist, a
    single extra assistant message summarising those events is appended.
    """
    history = [
        ChatMessage(role=msg.role, content=msg.content)
        for msg in self.messages[:-1]
    ]
    if not include_agent_messages:
        return history
    agent_events = self._get_agent_messages(max_messages=5)
    if agent_events:
        summary = "Previous agent events: \n" + "\n".join(agent_events)
        history.append(ChatMessage(role=MessageRole.ASSISTANT, content=summary))
    return history

def is_last_message_from_user(self) -> bool:
    """True when the most recent message in the thread was sent by the user."""
    last_message = self.messages[-1]
    return last_message.role == MessageRole.USER
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,5 @@ def create_orchestrator(chat_history: Optional[List[ChatMessage]] = None):
return AgentOrchestrator(
agents=[writer, reviewer, researcher, publisher],
refine_plan=False,
chat_history=chat_history,
)

0 comments on commit 8696146

Please sign in to comment.