Skip to content

Commit

Permalink
Merge pull request #485 from vanna-ai/get-function
Browse files Browse the repository at this point in the history
Function RAG for SQL Generation
  • Loading branch information
zainhoda authored Jun 7, 2024
2 parents 246bbe5 + 8c7c5b0 commit 64dd560
Show file tree
Hide file tree
Showing 5 changed files with 367 additions and 26 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi"

[project]
name = "vanna"
version = "0.5.5"
version = "0.6.0"
authors = [
{ name="Zain Hoda", email="[email protected]" },
]
Expand Down
26 changes: 26 additions & 0 deletions src/vanna/advanced/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from abc import ABC, abstractmethod


class VannaAdvanced(ABC):
def __init__(self, config=None):
self.config = config

@abstractmethod
def get_function(self, question: str, additional_data: dict = {}) -> dict:
pass

@abstractmethod
def create_function(self, question: str, sql: str, plotly_code: str, **kwargs) -> dict:
pass

@abstractmethod
def update_function(self, old_function_name: str, updated_function: dict) -> bool:
pass

@abstractmethod
def delete_function(self, function_name: str) -> bool:
pass

@abstractmethod
def get_all_functions(self) -> list:
pass
133 changes: 125 additions & 8 deletions src/vanna/flask/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import json
import logging
import os
import sys
import uuid
from abc import ABC, abstractmethod
from functools import wraps

import flask
import requests
from flask import Flask, Response, jsonify, request
from flask import Flask, Response, jsonify, request, send_from_directory
from flask_sock import Sock

from .assets import css_content, html_content, js_content
Expand Down Expand Up @@ -151,7 +152,10 @@ def __init__(self, vn, cache: Cache = MemoryCache(),
auto_fix_sql=True,
ask_results_correct=True,
followup_questions=True,
summarization=True
summarization=True,
function_generation=True,
index_html_path=None,
assets_folder=None,
):
"""
Expose a Flask app that can be used to interact with a Vanna instance.
Expand All @@ -176,6 +180,8 @@ def __init__(self, vn, cache: Cache = MemoryCache(),
ask_results_correct: Whether to ask the user if the results are correct. Defaults to True.
followup_questions: Whether to show followup questions. Defaults to True.
summarization: Whether to show summarization. Defaults to True.
index_html_path: Path to the index.html. Defaults to None, which will use the default index.html
assets_folder: The location where you'd like to serve the static assets from. Defaults to None, which will use hardcoded Python variables.
Returns:
None
Expand All @@ -202,6 +208,9 @@ def __init__(self, vn, cache: Cache = MemoryCache(),
self.ask_results_correct = ask_results_correct
self.followup_questions = followup_questions
self.summarization = summarization
self.function_generation = function_generation and hasattr(vn, "get_function")
self.index_html_path = index_html_path
self.assets_folder = assets_folder

log = logging.getLogger("werkzeug")
log.setLevel(logging.ERROR)
Expand Down Expand Up @@ -247,6 +256,7 @@ def get_config(user: any):
"ask_results_correct": self.ask_results_correct,
"followup_questions": self.followup_questions,
"summarization": self.summarization,
"function_generation": self.function_generation,
}

config = self.auth.override_config_for_user(user, config)
Expand Down Expand Up @@ -345,6 +355,56 @@ def generate_sql(user: any):
}
)

@self.flask_app.route("/api/v0/get_function", methods=["GET"])
@self.requires_auth
def get_function(user: any):
question = flask.request.args.get("question")

if question is None:
return jsonify({"type": "error", "error": "No question provided"})

if not hasattr(vn, "get_function"):
return jsonify({"type": "error", "error": "This setup does not support function generation."})

id = self.cache.generate_id(question=question)
function = vn.get_function(question=question)

if function is None:
return jsonify({"type": "error", "error": "No function found"})

if 'instantiated_sql' not in function:
self.vn.log(f"No instantiated SQL found for {question} in {function}")
return jsonify({"type": "error", "error": "No instantiated SQL found"})

self.cache.set(id=id, field="question", value=question)
self.cache.set(id=id, field="sql", value=function['instantiated_sql'])

if 'instantiated_post_processing_code' in function and function['instantiated_post_processing_code'] is not None and len(function['instantiated_post_processing_code']) > 0:
self.cache.set(id=id, field="plotly_code", value=function['instantiated_post_processing_code'])

return jsonify(
{
"type": "function",
"id": id,
"function": function,
}
)

@self.flask_app.route("/api/v0/get_all_functions", methods=["GET"])
@self.requires_auth
def get_all_functions(user: any):
if not hasattr(vn, "get_all_functions"):
return jsonify({"type": "error", "error": "This setup does not support function generation."})

functions = vn.get_all_functions()

return jsonify(
{
"type": "functions",
"functions": functions,
}
)

@self.flask_app.route("/api/v0/run_sql", methods=["GET"])
@self.requires_auth
@self.requires_cache(["sql"])
Expand Down Expand Up @@ -438,11 +498,18 @@ def generate_plotly_figure(user: any, id: str, df, question, sql):
question = f"{question}. When generating the chart, use these special instructions: {chart_instructions}"

try:
code = vn.generate_plotly_code(
question=question,
sql=sql,
df_metadata=f"Running df.dtypes gives:\n {df.dtypes}",
)
# If chart_instructions is not set then attempt to retrieve the code from the cache
if chart_instructions is None or len(chart_instructions) == 0:
code = self.cache.get(id=id, field="plotly_code")

if code is None:
code = vn.generate_plotly_code(
question=question,
sql=sql,
df_metadata=f"Running df.dtypes gives:\n {df.dtypes}",
)
self.cache.set(id=id, field="plotly_code", value=code)

fig = vn.get_plotly_figure(plotly_code=code, df=df, dark_mode=False)
fig_json = fig.to_json()

Expand Down Expand Up @@ -518,6 +585,49 @@ def add_training_data(user: any):
print("TRAINING ERROR", e)
return jsonify({"type": "error", "error": str(e)})

@self.flask_app.route("/api/v0/create_function", methods=["POST"])
@self.requires_auth
def create_function(user: any):
question = flask.request.json.get("question")
sql = flask.request.json.get("sql")
id = flask.request.json.get("id")

plotly_code = self.cache.get(id=id, field="plotly_code")

if plotly_code is None:
plotly_code = ""

function_data = self.vn.create_function(question=question, sql=sql, plotly_code=plotly_code)

return jsonify(
{
"type": "function_template",
"id": id,
"function_template": function_data,
}
)

@self.flask_app.route("/api/v0/update_function", methods=["POST"])
@self.requires_auth
def update_function(user: any):
old_function_name = flask.request.json.get("old_function_name")
updated_function = flask.request.json.get("updated_function")

print("old_function_name", old_function_name)
print("updated_function", updated_function)

updated = vn.update_function(old_function_name=old_function_name, updated_function=updated_function)

return jsonify({"success": updated})

@self.flask_app.route("/api/v0/delete_function", methods=["POST"])
@self.requires_auth
def delete_function(user: any):
function_name = flask.request.json.get("function_name")

return jsonify({"success": vn.delete_function(function_name=function_name)})


@self.flask_app.route("/api/v0/generate_followup_questions", methods=["GET"])
@self.requires_auth
@self.requires_cache(["df", "question", "sql"])
Expand Down Expand Up @@ -616,6 +726,9 @@ def catch_all(catch_all):

@self.flask_app.route("/assets/<path:filename>")
def proxy_assets(filename):
if self.assets_folder:
return send_from_directory(self.assets_folder, filename)

if ".css" in filename:
return Response(css_content, mimetype="text/css")

Expand Down Expand Up @@ -663,6 +776,10 @@ def sock_log(ws):
@self.flask_app.route("/", defaults={"path": ""})
@self.flask_app.route("/<path:path>")
def hello(path: str):
if self.index_html_path:
directory = os.path.dirname(self.index_html_path)
filename = os.path.basename(self.index_html_path)
return send_from_directory(directory=directory, path=filename)
return html_content

def run(self, *args, **kwargs):
Expand Down Expand Up @@ -692,4 +809,4 @@ def run(self, *args, **kwargs):
print("Your app is running at:")
print("http://localhost:8084")

self.flask_app.run(host="0.0.0.0", port=8084, debug=self.debug)
self.flask_app.run(host="0.0.0.0", port=8084, debug=self.debug, use_reloader=False)
52 changes: 36 additions & 16 deletions src/vanna/flask/assets.py

Large diffs are not rendered by default.

Loading

0 comments on commit 64dd560

Please sign in to comment.