From 1dc63354fa662d433b479801cdbc6672c47e824c Mon Sep 17 00:00:00 2001
From: Andrew Huang
Date: Tue, 29 Oct 2024 10:51:06 -0700
Subject: [PATCH] drop str message typing

---
 lumen/ai/agents.py    | 40 ++++++++++++++++++++--------------------
 lumen/ai/assistant.py |  6 +++---
 lumen/ai/llm.py       | 16 ++++------------
 3 files changed, 27 insertions(+), 35 deletions(-)

diff --git a/lumen/ai/agents.py b/lumen/ai/agents.py
index 336736e2..0456e4db 100644
--- a/lumen/ai/agents.py
+++ b/lumen/ai/agents.py
@@ -113,7 +113,7 @@ def __panel__(self):
         return self.interface
 
     async def _system_prompt_with_context(
-        self, messages: list | str, context: str = ""
+        self, messages: list, context: str = ""
     ) -> str:
         system_prompt = self.system_prompt
         if self.embeddings:
@@ -178,10 +178,10 @@ async def _select_table(self, tables):
             self.interface.pop(-1)
         return tables
 
-    async def requirements(self, messages: list | str):
+    async def requirements(self, messages: list):
         return self.requires
 
-    async def answer(self, messages: list | str):
+    async def answer(self, messages: list):
         system_prompt = await self._system_prompt_with_context(messages)
 
         message = None
@@ -192,7 +192,7 @@ async def answer(self, messages: list | str):
             output, replace=True, message=message, user=self.user, max_width=self._max_width
         )
 
-    async def invoke(self, messages: list | str):
+    async def invoke(self, messages: list):
         await self.answer(messages)
 
 
@@ -212,13 +212,13 @@ class SourceAgent(Agent):
 
     _extensions = ('filedropper',)
 
-    async def answer(self, messages: list | str):
+    async def answer(self, messages: list[str]):
         source_controls = SourceControls(multiple=True, replace_controls=True, select_existing=False)
         self.interface.send(source_controls, respond=False, user="SourceAgent")
         while not source_controls._add_button.clicks > 0:
             await asyncio.sleep(0.05)
 
-    async def invoke(self, messages: list[str] | str):
+    async def invoke(self, messages: list[str]):
         await self.answer(messages)
 
 
@@ -245,7 +245,7 @@ class ChatAgent(Agent):
     requires = param.List(default=["current_source"], readonly=True)
 
     @retry_llm_output()
-    async def requirements(self, messages: list | str, errors=None):
+    async def requirements(self, messages: list, errors=None):
         if 'current_data' in memory:
             return self.requires
 
@@ -270,7 +270,7 @@ async def requirements(self, messages: list | str, errors=None):
         return self.requires
 
     async def _system_prompt_with_context(
-        self, messages: list | str, context: str = ""
+        self, messages: list, context: str = ""
     ) -> str:
         source = memory.get("current_source")
         if not source:
@@ -321,7 +321,7 @@ class ChatDetailsAgent(ChatAgent):
     )
 
     async def _system_prompt_with_context(
-        self, messages: list | str, context: str = ""
+        self, messages: list, context: str = ""
     ) -> str:
         system_prompt = self.system_prompt
         topic = (await self.llm.invoke(
@@ -384,7 +384,7 @@ def _use_table(self, event):
         table = self._df.iloc[event.row, 0]
         self.interface.send(f"Show the table: {table!r}")
 
-    async def answer(self, messages: list | str):
+    async def answer(self, messages: list):
         tables = []
         for source in memory['available_sources']:
             tables += source.get_tables()
@@ -407,7 +407,7 @@ async def answer(self, messages: list | str):
             self.interface.stream(table_list, user="Lumen")
         return tables
 
-    async def invoke(self, messages: list | str):
+    async def invoke(self, messages: list):
         await self.answer(messages)
 
 
@@ -592,7 +592,7 @@ async def check_join_required(self, messages, schema, table):
             step.success_title = 'Query requires join' if join_required else 'No join required'
         return join_required
 
-    async def find_join_tables(self, messages: list | str):
+    async def find_join_tables(self, messages: list):
         multi_source = len(memory['available_sources']) > 1
         if multi_source:
             available_tables = [
@@ -640,7 +640,7 @@ async def find_join_tables(self, messages: list | str):
                 tables_to_source[a_table] = a_source
         return tables_to_source
 
-    async def answer(self, messages: list | str):
+    async def answer(self, messages: list):
         """
         Steps:
         1. Retrieve the current source and table from memory.
@@ -702,7 +702,7 @@ async def answer(self, messages: list | str):
         print(sql_query)
         return sql_query
 
-    async def invoke(self, messages: list | str):
+    async def invoke(self, messages: list):
         sql_query = await self.answer(messages)
         self._render_sql(sql_query)
 
@@ -716,7 +716,7 @@ class BaseViewAgent(LumenBaseAgent):
     async def _extract_spec(self, model: BaseModel):
         return dict(model)
 
-    async def answer(self, messages: list | str) -> hvPlotUIView:
+    async def answer(self, messages: list) -> hvPlotUIView:
         pipeline = memory["current_pipeline"]
 
         # Write prompts
@@ -746,7 +746,7 @@ async def answer(self, messages: list | str) -> hvPlotUIView:
         memory["current_view"] = dict(spec, type=self.view_type)
         return self.view_type(pipeline=pipeline, **spec)
 
-    async def invoke(self, messages: list | str):
+    async def invoke(self, messages: list):
         view = await self.answer(messages)
         self._render_lumen(view)
 
@@ -861,7 +861,7 @@ class AnalysisAgent(LumenBaseAgent):
     _output_type = AnalysisOutput
 
     async def _system_prompt_with_context(
-        self, messages: list | str, context: str = "", analyses: list[Analysis] = []
+        self, messages: list, context: str = "", analyses: list[Analysis] = []
     ) -> str:
         system_prompt = self.system_prompt
         for name, analysis in analyses.items():
@@ -880,7 +880,7 @@ async def _system_prompt_with_context(
             system_prompt += f"\n### CONTEXT: {context}".strip()
         return system_prompt
 
-    async def answer(self, messages: list | str, agents: list[Agent] | None = None):
+    async def answer(self, messages: list, agents: list[Agent] | None = None):
         pipeline = memory['current_pipeline']
         analyses = {a.name: a for a in self.analyses if await a.applies(pipeline)}
         if not analyses:
@@ -888,7 +888,7 @@ async def answer(self, messages: list | str, agents: list[Agent] | None = None):
             return None
 
         # Short cut analysis selection if there's an exact match
-        if isinstance(messages, list) and messages:
+        if len(messages):
             analysis = messages[0].get('content').replace('Apply ', '')
             if analysis in analyses:
                 analyses = {analysis: analyses[analysis]}
@@ -945,7 +945,7 @@ async def answer(self, messages: list | str, agents: list[Agent] | None = None):
             view = None
         return view
 
-    async def invoke(self, messages: list | str, agents=None):
+    async def invoke(self, messages: list, agents=None):
         view = await self.answer(messages, agents=agents)
         analysis = memory["current_analysis"]
         if view is None and analysis.autorun:
diff --git a/lumen/ai/assistant.py b/lumen/ai/assistant.py
index eba2c8bc..4beda914 100644
--- a/lumen/ai/assistant.py
+++ b/lumen/ai/assistant.py
@@ -339,7 +339,7 @@ async def _fill_model(self, messages, system, agent_model, errors=None):
         )
         return out
 
-    async def _choose_agent(self, messages: list | str, agents: list[Agent] | None = None, primary: bool = False, unmet_dependencies: tuple[str] | None = None):
+    async def _choose_agent(self, messages: list, agents: list[Agent] | None = None, primary: bool = False, unmet_dependencies: tuple[str] | None = None):
         if agents is None:
             agents = self.agents
         agents = [agent for agent in agents if await agent.applies()]
@@ -390,7 +390,7 @@ async def _resolve_dependencies(self, messages, agents: dict[str, Agent]) -> lis
                 step.success_title = f"Solved a dependency with {output.agent}"
         return agent_chain[::-1]+[(agent, (), None)]
 
-    async def _get_agent(self, messages: list | str):
+    async def _get_agent(self, messages: list):
         if len(self.agents) == 1:
             return self.agents[0]
 
@@ -448,7 +448,7 @@ def _serialize(self, obj, exclude_passwords=True):
             obj = obj.value
         return str(obj)
 
-    async def invoke(self, messages: list | str) -> str:
+    async def invoke(self, messages: list) -> str:
         messages = self.interface.serialize(custom_serializer=self._serialize)[-4:]
         invalidation_assessment = await self._invalidate_memory(messages[-2:])
         context_length = 3
diff --git a/lumen/ai/llm.py b/lumen/ai/llm.py
index b3b48068..b5c4783c 100644
--- a/lumen/ai/llm.py
+++ b/lumen/ai/llm.py
@@ -58,7 +58,7 @@ def _add_system_message(self, messages, system, input_kwargs):
 
     async def invoke(
         self,
-        messages: list | str,
+        messages: list,
         system: str = "",
         response_model: BaseModel | None = None,
         allow_partial: bool = False,
@@ -66,8 +66,6 @@ async def invoke(
         **input_kwargs,
     ) -> BaseModel:
         system = system.strip().replace("\n\n", "\n")
-        if isinstance(messages, str):
-            messages = [{"role": "user", "content": messages}]
         messages, input_kwargs = self._add_system_message(messages, system, input_kwargs)
 
         kwargs = dict(self._client_kwargs)
@@ -91,7 +89,7 @@ def _get_delta(cls, chunk):
 
     async def stream(
         self,
-        messages: list | str,
+        messages: list,
         system: str = "",
         response_model: BaseModel | None = None,
         field: str | None = None,
@@ -357,16 +355,13 @@ def _get_delta(cls, chunk):
 
     async def invoke(
         self,
-        messages: list | str,
+        messages: list,
         system: str = "",
         response_model: BaseModel | None = None,
         allow_partial: bool = False,
         model_key: str = "default",
         **input_kwargs,
     ) -> BaseModel:
-        if isinstance(messages, str):
-            messages = [{"role": "user", "content": messages}]
-
         if messages[0]["role"] == "assistant":
             # Mistral cannot start with assistant
             messages = messages[1:]
@@ -470,16 +465,13 @@ def _add_system_message(self, messages, system, input_kwargs):
 
     async def invoke(
        self,
-        messages: list | str,
+        messages: list,
         system: str = "",
         response_model: BaseModel | None = None,
         allow_partial: bool = False,
         model_key: str = "default",
         **input_kwargs,
     ) -> BaseModel:
-        if isinstance(messages, str):
-            messages = [{"role": "user", "content": messages}]
-
         # check that first message is user message; if not, insert empty message
         if messages[0]["role"] != "user":
             messages.insert(0, {"role": "user", "content": "--"})
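
Usage note (illustrative, not part of the patch): with the isinstance(messages, str) shortcut removed, callers of Llm.invoke and Llm.stream are expected to build the message list themselves. A minimal sketch, assuming an already-configured Llm subclass instance passed in as `llm`; the function name and prompt text below are made up for illustration:

    async def ask(llm):
        # The caller now supplies the role/content dicts directly; a bare string
        # is no longer wrapped into [{"role": "user", "content": ...}] internally.
        messages = [{"role": "user", "content": "Which tables are available?"}]
        return await llm.invoke(messages, system="You are a data assistant.")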