make typing into list
ahuang11 committed Oct 29, 2024
1 parent 1dc6335 commit e53ccd8
Showing 3 changed files with 23 additions and 23 deletions.
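The change is the same throughout: the messages parameter of these async agent, assistant, and LLM methods is annotated as list rather than str, matching how the value is actually used (for example, assistant.py builds it via self.interface.serialize(...)[-4:] and slices it with messages[-2:]). Below is a minimal sketch of what such a messages value might look like; the role/content dict shape is an assumption for illustration, not something confirmed by this diff.

# Hypothetical example of the messages argument these methods now annotate as a list.
# The role/content dict layout is assumed from typical chat interfaces, not this commit.
messages = [
    {"role": "user", "content": "Which tables are available?"},
    {"role": "assistant", "content": "The current source exposes three tables."},
    {"role": "user", "content": "Show the table: 'penguins'"},
]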
32 changes: 16 additions & 16 deletions lumen/ai/agents.py
@@ -113,7 +113,7 @@ def __panel__(self):
return self.interface

async def _system_prompt_with_context(
- self, messages: str, context: str = ""
+ self, messages: list, context: str = ""
) -> str:
system_prompt = self.system_prompt
if self.embeddings:
@@ -178,10 +178,10 @@ async def _select_table(self, tables):
self.interface.pop(-1)
return tables

- async def requirements(self, messages: str):
+ async def requirements(self, messages: list):
return self.requires

- async def answer(self, messages: str):
+ async def answer(self, messages: list):
system_prompt = await self._system_prompt_with_context(messages)

message = None
@@ -192,7 +192,7 @@ async def answer(self, messages: str):
output, replace=True, message=message, user=self.user, max_width=self._max_width
)

- async def invoke(self, messages: str):
+ async def invoke(self, messages: list):
await self.answer(messages)


@@ -245,7 +245,7 @@ class ChatAgent(Agent):
requires = param.List(default=["current_source"], readonly=True)

@retry_llm_output()
- async def requirements(self, messages: str, errors=None):
+ async def requirements(self, messages: list, errors=None):
if 'current_data' in memory:
return self.requires

@@ -270,7 +270,7 @@ async def requirements(self, messages: str, errors=None):
return self.requires

async def _system_prompt_with_context(
- self, messages: str, context: str = ""
+ self, messages: list, context: str = ""
) -> str:
source = memory.get("current_source")
if not source:
@@ -321,7 +321,7 @@ class ChatDetailsAgent(ChatAgent):
)

async def _system_prompt_with_context(
- self, messages: str, context: str = ""
+ self, messages: list, context: str = ""
) -> str:
system_prompt = self.system_prompt
topic = (await self.llm.invoke(
@@ -384,7 +384,7 @@ def _use_table(self, event):
table = self._df.iloc[event.row, 0]
self.interface.send(f"Show the table: {table!r}")

- async def answer(self, messages: str):
+ async def answer(self, messages: list):
tables = []
for source in memory['available_sources']:
tables += source.get_tables()
@@ -407,7 +407,7 @@ async def answer(self, messages: str):
self.interface.stream(table_list, user="Lumen")
return tables

- async def invoke(self, messages: str):
+ async def invoke(self, messages: list):
await self.answer(messages)


@@ -640,7 +640,7 @@ async def find_join_tables(self, messages: list):
tables_to_source[a_table] = a_source
return tables_to_source

- async def answer(self, messages: str):
+ async def answer(self, messages: list):
"""
Steps:
1. Retrieve the current source and table from memory.
@@ -702,7 +702,7 @@ async def answer(self, messages: str):
print(sql_query)
return sql_query

- async def invoke(self, messages: str):
+ async def invoke(self, messages: list):
sql_query = await self.answer(messages)
self._render_sql(sql_query)

@@ -716,7 +716,7 @@ class BaseViewAgent(LumenBaseAgent):
async def _extract_spec(self, model: BaseModel):
return dict(model)

- async def answer(self, messages: str) -> hvPlotUIView:
+ async def answer(self, messages: list) -> hvPlotUIView:
pipeline = memory["current_pipeline"]

# Write prompts
@@ -746,7 +746,7 @@ async def answer(self, messages: str) -> hvPlotUIView:
memory["current_view"] = dict(spec, type=self.view_type)
return self.view_type(pipeline=pipeline, **spec)

- async def invoke(self, messages: str):
+ async def invoke(self, messages: list):
view = await self.answer(messages)
self._render_lumen(view)

@@ -861,7 +861,7 @@ class AnalysisAgent(LumenBaseAgent):
_output_type = AnalysisOutput

async def _system_prompt_with_context(
self, messages: str, context: str = "", analyses: list[Analysis] = []
self, messages: list, context: str = "", analyses: list[Analysis] = []
) -> str:
system_prompt = self.system_prompt
for name, analysis in analyses.items():
@@ -880,7 +880,7 @@ async def _system_prompt_with_context(
system_prompt += f"\n### CONTEXT: {context}".strip()
return system_prompt

- async def answer(self, messages: str, agents: list[Agent] | None = None):
+ async def answer(self, messages: list, agents: list[Agent] | None = None):
pipeline = memory['current_pipeline']
analyses = {a.name: a for a in self.analyses if await a.applies(pipeline)}
if not analyses:
@@ -945,7 +945,7 @@ async def answer(self, messages: str, agents: list[Agent] | None = None):
view = None
return view

- async def invoke(self, messages: str, agents=None):
+ async def invoke(self, messages: list, agents=None):
view = await self.answer(messages, agents=agents)
analysis = memory["current_analysis"]
if view is None and analysis.autorun:
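The pattern above repeats for every agent in agents.py: answer and invoke now declare messages: list. A short sketch of a hypothetical subclass following the corrected signatures; the EchoAgent class below is illustrative only and is not part of this commit.

# Hypothetical agent subclass showing the corrected signatures; assumes the
# Agent base class from lumen.ai.agents shown in the diff above.
from lumen.ai.agents import Agent

class EchoAgent(Agent):
    async def answer(self, messages: list):
        # messages arrives as a list of chat messages, not a raw string
        return messages[-1] if messages else None

    async def invoke(self, messages: list):
        return await self.answer(messages)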
6 changes: 3 additions & 3 deletions lumen/ai/assistant.py
@@ -339,7 +339,7 @@ async def _fill_model(self, messages, system, agent_model, errors=None):
)
return out

- async def _choose_agent(self, messages: str, agents: list[Agent] | None = None, primary: bool = False, unmet_dependencies: tuple[str] | None = None):
+ async def _choose_agent(self, messages: list, agents: list[Agent] | None = None, primary: bool = False, unmet_dependencies: tuple[str] | None = None):
if agents is None:
agents = self.agents
agents = [agent for agent in agents if await agent.applies()]
@@ -390,7 +390,7 @@ async def _resolve_dependencies(self, messages, agents: dict[str, Agent]) -> lis
step.success_title = f"Solved a dependency with {output.agent}"
return agent_chain[::-1]+[(agent, (), None)]

- async def _get_agent(self, messages: str):
+ async def _get_agent(self, messages: list):
if len(self.agents) == 1:
return self.agents[0]

@@ -448,7 +448,7 @@ def _serialize(self, obj, exclude_passwords=True):
obj = obj.value
return str(obj)

- async def invoke(self, messages: str) -> str:
+ async def invoke(self, messages: list) -> str:
messages = self.interface.serialize(custom_serializer=self._serialize)[-4:]
invalidation_assessment = await self._invalidate_memory(messages[-2:])
context_length = 3
8 changes: 4 additions & 4 deletions lumen/ai/llm.py
@@ -58,7 +58,7 @@ def _add_system_message(self, messages, system, input_kwargs):

async def invoke(
self,
- messages: str,
+ messages: list,
system: str = "",
response_model: BaseModel | None = None,
allow_partial: bool = False,
@@ -89,7 +89,7 @@ def _get_delta(cls, chunk):

async def stream(
self,
- messages: str,
+ messages: list,
system: str = "",
response_model: BaseModel | None = None,
field: str | None = None,
@@ -355,7 +355,7 @@ def _get_delta(cls, chunk):

async def invoke(
self,
- messages: str,
+ messages: list,
system: str = "",
response_model: BaseModel | None = None,
allow_partial: bool = False,
@@ -465,7 +465,7 @@ def _add_system_message(self, messages, system, input_kwargs):

async def invoke(
self,
- messages: str,
+ messages: list,
system: str = "",
response_model: BaseModel | None = None,
allow_partial: bool = False,
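With the annotations corrected, callers pass the message list straight through to invoke or stream. A minimal usage sketch follows, assuming an Llm implementation from lumen.ai.llm; the OpenAI class name and its no-argument construction are assumptions not shown in this diff, while the system keyword mirrors the invoke signatures above.

# Minimal usage sketch; the OpenAI wrapper name and default construction are assumed.
import asyncio
from lumen.ai.llm import OpenAI

async def main():
    llm = OpenAI()
    # messages is a list of chat messages, matching the new annotations
    messages = [{"role": "user", "content": "Summarize the current table."}]
    print(await llm.invoke(messages, system="You are a data assistant."))

asyncio.run(main())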
