Async data calls (#714)
ahuang11 authored Sep 25, 2024
1 parent cc86118 commit b777a29
Showing 5 changed files with 140 additions and 106 deletions.
60 changes: 33 additions & 27 deletions lumen/ai/agents.py
@@ -39,7 +39,8 @@
 )
 from .translate import param_to_pydantic
 from .utils import (
-    clean_sql, describe_data, get_schema, render_template, retry_llm_output,
+    clean_sql, describe_data, get_data, get_pipeline, get_schema,
+    render_template, retry_llm_output,
 )
 from .views import AnalysisOutput, LumenOutput, SQLOutput

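The updated import list pulls two new helpers, get_data and get_pipeline, from lumen/ai/utils.py; that file is among the five changed files but its diff is not rendered on this page. The sketch below is only an assumption about what those awaitable helpers might look like, namely thin wrappers that push blocking Pipeline work onto a worker thread with asyncio.to_thread:

import asyncio

from lumen.pipeline import Pipeline

# Hypothetical sketch: the real implementations live in lumen/ai/utils.py and
# are not shown in this diff. Assumes the helpers only offload blocking work.

async def get_pipeline(**kwargs):
    # Build a Pipeline in a worker thread so the event loop is not blocked.
    return await asyncio.to_thread(Pipeline, **kwargs)

async def get_data(pipeline):
    # Accessing pipeline.data can trigger a blocking query; run it off-loop.
    return await asyncio.to_thread(lambda: pipeline.data)

Every call site in this diff then switches from Pipeline(...) and pipeline.data to await get_pipeline(...) and await get_data(pipeline).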
@@ -274,7 +275,7 @@ async def _system_prompt_with_context(
             context = f"Available tables: {', '.join(closest_tables)}"
         else:
             memory["current_table"] = table = memory.get("current_table", tables[0])
-            schema = get_schema(memory["current_source"], table)
+            schema = await get_schema(memory["current_source"], table)
             if schema:
                 context = f"{table} with schema: {schema}"

@@ -389,7 +390,7 @@ async def answer(self, messages: list | str):
         for table in source.get_tables():
             tables_to_source[table] = source
             if isinstance(source, DuckDBSource) and source.ephemeral:
-                schema = get_schema(source, table, include_min_max=False, include_enum=True, limit=1)
+                schema = await get_schema(source, table, include_min_max=False, include_enum=True, limit=1)
                 tables_schema_str += f"### {table}\nSchema:\n```yaml\n{yaml.dump(schema)}```\n"
             else:
                 tables_schema_str += f"### {table}\n"
@@ -435,12 +436,12 @@ async def answer(self, messages: list | str):
             get_kwargs['sql_transforms'] = [SQLLimit(limit=1_000_000)]
         memory["current_source"] = source
         memory["current_table"] = table
-        memory["current_pipeline"] = pipeline = Pipeline(
+        memory["current_pipeline"] = pipeline = await get_pipeline(
             source=source, table=table, **get_kwargs
         )
-        df = pipeline.data
+        df = await get_data(pipeline)
         if len(df) > 0:
-            memory["current_data"] = describe_data(df)
+            memory["current_data"] = await describe_data(df)
         if self.debug:
             print(f"{self.name} thinks that the user is talking about {table=!r}.")
         return pipeline
@@ -581,7 +582,7 @@ async def _create_valid_sql(self, messages, system, tables_to_source, errors=Non
                 # Get validated query
                 sql_query = sql_expr_source.tables[expr_slug]
                 sql_transforms = [SQLLimit(limit=1_000_000)]
-                pipeline = Pipeline(
+                pipeline = await get_pipeline(
                     source=sql_expr_source, table=expr_slug, sql_transforms=sql_transforms
                 )
         except InstructorRetryException as e:
@@ -605,9 +606,9 @@
                 step.status = "failed"
                 raise e

-        df = pipeline.data
+        df = await get_data(pipeline)
         if len(df) > 0:
-            memory["current_data"] = describe_data(df)
+            memory["current_data"] = await describe_data(df)

         memory["available_sources"].append(sql_expr_source)
         memory["current_source"] = sql_expr_source
@@ -701,7 +702,7 @@ async def answer(self, messages: list | str):
         if not hasattr(source, "get_sql_expr"):
             return None

-        schema = get_schema(source, table, include_min_max=False)
+        schema = await get_schema(source, table, include_min_max=False)
         join_required = await self.check_join_required(messages, schema, table)
         if join_required:
             tables_to_source = await self.find_join_tables(messages)
@@ -713,7 +714,7 @@
             if source_table == table:
                 table_schema = schema
             else:
-                table_schema = get_schema(source, source_table, include_min_max=False)
+                table_schema = await get_schema(source, source_table, include_min_max=False)
             table_schemas[source_table] = {
                 "schema": yaml.dump(table_schema),
                 "sql": source.get_sql_expr(source_table)
@@ -754,13 +755,14 @@ async def answer(self, messages: list | str) -> Transform:
         if "current_pipeline" in memory:
             pipeline = memory["current_pipeline"]
         else:
-            pipeline = Pipeline(
+            pipeline = await get_pipeline(
                 source=memory["current_source"],
                 table=memory["current_table"],
             )
         memory["current_pipeline"] = pipeline
-        pipeline._update_data(force=True)
-        memory["current_data"] = describe_data(pipeline.data)
+        await asyncio.to_thread(pipeline._update_data, force=True)
+        data = await get_data(pipeline)
+        memory["current_data"] = await describe_data(data)
         return pipeline

     async def invoke(self, messages: list | str):
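The forced refresh above is a synchronous method, so instead of an awaitable wrapper it is handed to asyncio.to_thread, which runs it in a thread while the event loop keeps serving other coroutines. A minimal, self-contained illustration of that pattern (blocking_refresh is a stand-in, not Lumen code):

import asyncio
import time

def blocking_refresh(force=False):
    # Stand-in for a blocking call such as pipeline._update_data(force=True).
    time.sleep(1)
    return force

async def main():
    # Keyword arguments are forwarded to the function running in the thread.
    refreshed = await asyncio.to_thread(blocking_refresh, force=True)
    print(refreshed)

asyncio.run(main())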
@@ -867,7 +869,7 @@ async def _construct_transform(
         self, messages: list | str, transform: type[Transform], system_prompt: str
     ) -> Transform:
         excluded = transform._internal_params + ["controls", "type"]
-        schema = get_schema(memory["current_pipeline"])
+        schema = await get_schema(memory["current_pipeline"])
         table = memory["current_table"]
         model = param_to_pydantic(transform, excluded=excluded, schema=schema)[
             transform.__name__
@@ -912,8 +914,9 @@ async def answer(self, messages: list | str) -> Transform:
         else:
             pipeline.add_transform(transform)

-        pipeline._update_data(force=True)
-        memory["current_data"] = describe_data(pipeline.data)
+        await asyncio.to_thread(pipeline._update_data, force=True)
+        data = await get_data(pipeline)
+        memory["current_data"] = await describe_data(data)
         return pipeline

     async def invoke(self, messages: list | str):
@@ -927,15 +930,15 @@ class BaseViewAgent(LumenBaseAgent):

     provides = param.List(default=["current_plot"], readonly=True)

-    def _extract_spec(self, model: BaseModel):
+    async def _extract_spec(self, model: BaseModel):
         return dict(model)

     async def answer(self, messages: list | str) -> hvPlotUIView:
         pipeline = memory["current_pipeline"]

         # Write prompts
         system_prompt = await self._system_prompt_with_context(messages)
-        schema = get_schema(pipeline, include_min_max=False)
+        schema = await get_schema(pipeline, include_min_max=False)
         view_prompt = render_template(
             "plot_agent.jinja2",
             schema=yaml.dump(schema),
@@ -951,7 +954,7 @@ async def answer(self, messages: list | str) -> hvPlotUIView:
             system=system_prompt + view_prompt,
             response_model=self._get_model(schema),
         )
-        spec = self._extract_spec(output)
+        spec = await self._extract_spec(output)
         chain_of_thought = spec.pop("chain_of_thought")
         with self.interface.add_step(title="Generating view...") as step:
             step.stream(chain_of_thought)
@@ -1002,7 +1005,7 @@ def _get_model(cls, schema):
         })
         return model[cls.view_type.__name__]

-    def _extract_spec(self, model):
+    async def _extract_spec(self, model):
         pipeline = memory["current_pipeline"]
         spec = {
             key: val for key, val in dict(model).items()
@@ -1014,7 +1017,8 @@

         # Add defaults
         spec["responsive"] = True
-        if len(pipeline.data) > 20000 and spec["kind"] in ("line", "scatter", "points"):
+        data = await get_data(pipeline)
+        if len(data) > 20000 and spec["kind"] in ("line", "scatter", "points"):
             spec["rasterize"] = True
             spec["cnorm"] = "log"
         return spec
@@ -1039,7 +1043,7 @@ class VegaLiteAgent(BaseViewAgent):
     def _get_model(cls, schema):
         return VegaLiteSpec

-    def _extract_spec(self, model):
+    async def _extract_spec(self, model):
         vega_spec = json.loads(model.json_spec)
         if "$schema" not in vega_spec:
             vega_spec["$schema"] = "https://vega.github.io/schema/vega-lite/v5.json"
@@ -1092,7 +1096,7 @@ async def _system_prompt_with_context(

     async def answer(self, messages: list | str, agents: list[Agent] | None = None):
         pipeline = memory['current_pipeline']
-        analyses = {a.name: a for a in self.analyses if a.applies(pipeline)}
+        analyses = {a.name: a for a in self.analyses if await a.applies(pipeline)}
         if not analyses:
             print("NONE found...")
             return None
@@ -1125,8 +1129,10 @@ async def answer(self, messages: list | str, agents: list[Agent] | None = None):
         with self.interface.add_step(title="Creating view...", user="Assistant") as step:
             await asyncio.sleep(0.1)  # necessary to give it time to render before calling sync function...
             analysis_callable = analyses[analysis_name].instance(agents=agents)
+
+            data = await get_data(pipeline)
             for field in analysis_callable._field_params:
-                analysis_callable.param[field].objects = list(pipeline.data.columns)
+                analysis_callable.param[field].objects = list(data.columns)
             memory["current_analysis"] = analysis_callable

             if analysis_callable.autorun:
@@ -1143,8 +1149,8 @@ async def answer(self, messages: list | str, agents: list[Agent] | None = None):
             # Ensure current_data reflects processed pipeline
             if pipeline is not memory['current_pipeline']:
                 pipeline = memory['current_pipeline']
-            if len(pipeline.data) > 0:
-                memory["current_data"] = describe_data(pipeline.data)
+            if len(data) > 0:
+                memory["current_data"] = await describe_data(data)
             yaml_spec = yaml.dump(spec)
             step.stream(f"Generated view\n```yaml\n{yaml_spec}\n```")
             step.success_title = "Generated view"
12 changes: 7 additions & 5 deletions lumen/ai/analysis.py
@@ -1,10 +1,11 @@
 import panel as pn
 import param

+from lumen.ai.utils import get_data
+
 from ..base import Component
 from .controls import SourceControls
 from .memory import memory
-from .utils import get_schema


 class Analysis(param.ParameterizedFunction):
@@ -34,13 +35,14 @@ class Analysis(param.ParameterizedFunction):
     _field_params = []

     @classmethod
-    def applies(cls, pipeline) -> bool:
+    async def applies(cls, pipeline) -> bool:
         applies = True
+        data = await get_data(pipeline)
         for col in cls.columns:
             if isinstance(col, tuple):
-                applies &= any(c in pipeline.data.columns for c in col)
+                applies &= any(c in data.columns for c in col)
             else:
-                applies &= col in pipeline.data.columns
+                applies &= col in data.columns
         return applies

     def controls(self):
@@ -80,7 +82,7 @@ def controls(self):
         table = memory.get("current_table")
         self._previous_source = source
         self._previous_table = table
-        columns = list(get_schema(source, table=table).keys())
+        columns = list(source.get_schema(table).keys())
         index_col = pn.widgets.AutocompleteInput.from_param(
             self.param.index_col, options=columns, name="Join on",
             placeholder="Start typing column name", search_strategy="includes",
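Since Analysis.applies is now a coroutine, any Analysis subclass that overrides it has to be declared async and should await get_data rather than reading pipeline.data directly. A hypothetical subclass, with an invented column name purely for illustration:

from lumen.ai.analysis import Analysis
from lumen.ai.utils import get_data

class WindAnalysis(Analysis):
    # Hypothetical example; "wind_speed" is a made-up column name.
    columns = ["wind_speed"]

    @classmethod
    async def applies(cls, pipeline) -> bool:
        # Await the data instead of touching pipeline.data directly.
        data = await get_data(pipeline)
        return all(col in data.columns for col in cls.columns)

The call sites updated in this commit (AnalysisAgent.answer in agents.py and _add_analysis_suggestions in assistant.py) now await this method, so synchronous overrides would need the same treatment.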
10 changes: 5 additions & 5 deletions lumen/ai/assistant.py
@@ -186,7 +186,7 @@ async def use_suggestion(event):
                 else:
                     return
                 await agent.invoke([{'role': 'user', 'content': contents}], agents=self.agents)
-                self._add_analysis_suggestions()
+                await self._add_analysis_suggestions()
             else:
                 self.interface.send(contents)

@@ -233,13 +233,13 @@ async def run_demo(event):
         self.interface.param.watch(hide_suggestions, "objects")
         return message

-    def _add_analysis_suggestions(self):
+    async def _add_analysis_suggestions(self):
         pipeline = memory['current_pipeline']
         current_analysis = memory.get("current_analysis")
         allow_consecutive = getattr(current_analysis, '_consecutive_calls', True)
         applicable_analyses = []
         for analysis in self._analyses:
-            if analysis.applies(pipeline) and (allow_consecutive or analysis is not type(current_analysis)):
+            if await analysis.applies(pipeline) and (allow_consecutive or analysis is not type(current_analysis)):
                 applicable_analyses.append(analysis)
         self._add_suggestions_to_footer(
             [f"Apply {analysis.__name__}" for analysis in applicable_analyses],
@@ -263,7 +263,7 @@ async def _invalidate_memory(self, messages):
             raise KeyError(f'Table {table} could not be found in available sources.')

         try:
-            spec = get_schema(source, table=table, include_count=True)
+            spec = await get_schema(source, table=table, include_count=True)
         except Exception:
             # If the selected table cannot be fetched we should invalidate it
             spec = None
@@ -482,7 +482,7 @@ async def invoke(self, messages: list | str) -> str:
         await agent.invoke(messages[-context_length:], **kwargs)
         self._current_agent.object = "## No agent active"
         if "current_pipeline" in agent.provides:
-            self._add_analysis_suggestions()
+            await self._add_analysis_suggestions()
         print("\033[92mDONE\033[0m", "\n\n")

     def controls(self):
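Beyond keeping the Panel event loop responsive, exposing these calls as coroutines makes it possible to overlap independent lookups. A purely illustrative sketch, not part of this commit, that fetches several table schemas concurrently using the same get_schema signature seen in the diff:

import asyncio

from lumen.ai.utils import get_schema

async def fetch_schemas(source, tables):
    # Issue all schema requests at once and gather the results.
    schemas = await asyncio.gather(
        *(get_schema(source, table, include_min_max=False) for table in tables)
    )
    return dict(zip(tables, schemas))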