Skip to content

Commit

Permalink
Merge branch 'main' into pgvector_fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
zainhoda authored Oct 23, 2024
2 parents 85586ac + ed26e2a commit 4915144
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 62 deletions.
2 changes: 1 addition & 1 deletion papers/ai-sql-accuracy-2023-08-17.md
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ By providing just those 3 example queries, we see substantial improvements to th

Enterprise data warehouses often contain 100s (or even 1000s) of tables, and an order of magnitude more queries that cover all the use cases within their organizations. Given the limited size of the context windows of modern LLMs, we can’t just shove all the prior queries and schema definitions into the prompt.

Our final approach to context is a more sophisticated ML approach - load embeddings of prior queries and the table schemas into a vector database, and only choose the most relevant queries / tables to the question asked. Here's a diagram of what we are doing - note the contextual relevance search in the red box -
Our final approach to context is a more sophisticated ML approach - load embeddings of prior queries and the table schemas into a vector database, and only choose the most relevant queries / tables to the question asked. Here's a diagram of what we are doing - note the contextual relevance search in the green box -

![](https://raw.githubusercontent.com/vanna-ai/vanna/main/papers/img/using-contextually-relevant-examples.png)

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ bigquery = ["google-cloud-bigquery"]
snowflake = ["snowflake-connector-python"]
duckdb = ["duckdb"]
google = ["google-generativeai", "google-cloud-aiplatform"]
all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "qianfan", "mistralai>=1.0.0", "chromadb", "anthropic", "zhipuai", "marqo", "google-generativeai", "google-cloud-aiplatform", "qdrant-client", "fastembed", "ollama", "httpx", "opensearch-py", "opensearch-dsl", "transformers", "pinecone-client", "pymilvus[model]","weaviate-client", "azure-search-documents", "azure-identity", "azure-common", "faiss-cpu", "boto", "botocore"]
all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "qianfan", "mistralai>=1.0.0", "chromadb", "anthropic", "zhipuai", "marqo", "google-generativeai", "google-cloud-aiplatform", "qdrant-client", "fastembed", "ollama", "httpx", "opensearch-py", "opensearch-dsl", "transformers", "pinecone-client", "pymilvus[model]","weaviate-client", "azure-search-documents", "azure-identity", "azure-common", "faiss-cpu", "boto", "boto3", "botocore", "langchain_core", "langchain_postgres"]
test = ["tox"]
chromadb = ["chromadb"]
openai = ["openai"]
Expand Down
49 changes: 34 additions & 15 deletions src/vanna/google/bigquery_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
import os
import uuid
from typing import List, Optional
from vertexai.language_models import (
TextEmbeddingInput,
TextEmbeddingModel
)

import pandas as pd
from google.cloud import bigquery
Expand All @@ -23,17 +27,15 @@ def __init__(self, config: dict, **kwargs):
or set as an environment variable, assign it.
"""
print("Configuring genai")
self.type = "GEMINI"
import google.generativeai as genai

genai.configure(api_key=config["api_key"])

self.genai = genai
else:
self.type = "VERTEX_AI"
# Authenticate using VertexAI
from vertexai.language_models import (
TextEmbeddingInput,
TextEmbeddingModel,
)

if self.config.get("project_id"):
self.project_id = self.config.get("project_id")
Expand Down Expand Up @@ -139,25 +141,42 @@ def fetch_similar_training_data(self, training_data_type: str, question: str, n_
results = self.conn.query(query).result().to_dataframe()
return results

def generate_question_embedding(self, data: str, **kwargs) -> List[float]:
result = self.genai.embed_content(
def get_embeddings(self, data: str, task: str) -> List[float]:
embeddings = None

if self.type == "VERTEX_AI":
input = [TextEmbeddingInput(data, task)]
model = TextEmbeddingModel.from_pretrained("text-embedding-004")

result = model.get_embeddings(input)

if len(result) > 0:
embeddings = result[0].values
else:
# Use Gemini Consumer API
result = self.genai.embed_content(
model="models/text-embedding-004",
content=data,
task_type="retrieval_query")
task_type=task)

if 'embedding' in result:
return result['embedding']
if 'embedding' in result:
embeddings = result['embedding']

return embeddings

def generate_question_embedding(self, data: str, **kwargs) -> List[float]:
result = self.get_embeddings(data, "RETRIEVAL_QUERY")

if result != None:
return result
else:
raise ValueError("No embeddings returned")

def generate_storage_embedding(self, data: str, **kwargs) -> List[float]:
result = self.genai.embed_content(
model="models/text-embedding-004",
content=data,
task_type="retrieval_document")
result = self.get_embeddings(data, "RETRIEVAL_DOCUMENT")

if 'embedding' in result:
return result['embedding']
if result != None:
return result
else:
raise ValueError("No embeddings returned")

Expand Down
4 changes: 2 additions & 2 deletions src/vanna/google/gemini_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def __init__(self, config=None):
if "model_name" in config:
model_name = config["model_name"]
else:
model_name = "gemini-1.0-pro"
model_name = "gemini-1.5-pro"

self.google_api_key = None

Expand All @@ -30,7 +30,7 @@ def __init__(self, config=None):
self.chat_model = genai.GenerativeModel(model_name)
else:
# Authenticate using VertexAI
from vertexai.preview.generative_models import GenerativeModel
from vertexai.generative_models import GenerativeModel
self.chat_model = GenerativeModel(model_name)

def system_message(self, message: str) -> any:
Expand Down
82 changes: 39 additions & 43 deletions tests/test_pgvector.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,52 +2,48 @@

from dotenv import load_dotenv

from vanna.pgvector import PG_VectorStore
from vanna.openai import OpenAI_Chat

# from vanna.pgvector import PG_VectorStore
# from vanna.openai import OpenAI_Chat

# assume .env file placed next to file with provided env vars
load_dotenv()


def get_vanna_connection_string():
server = os.environ.get("PG_SERVER")
driver = "psycopg"
port = os.environ.get("PG_PORT", 5432)
database = os.environ.get("PG_DATABASE")
username = os.environ.get("PG_USERNAME")
password = os.environ.get("PG_PASSWORD")

return f"postgresql+psycopg://{username}:{password}@{server}:{port}/{database}"


def test_pgvector_e2e():
# configure Vanna to use OpenAI and PGVector
class VannaCustom(PG_VectorStore, OpenAI_Chat):
def __init__(self, config=None):
PG_VectorStore.__init__(self, config=config)
OpenAI_Chat.__init__(self, config=config)
# def get_vanna_connection_string():
# server = os.environ.get("PG_SERVER")
# driver = "psycopg"
# port = os.environ.get("PG_PORT", 5432)
# database = os.environ.get("PG_DATABASE")
# username = os.environ.get("PG_USERNAME")
# password = os.environ.get("PG_PASSWORD")

# def test_pgvector_e2e():
# # configure Vanna to use OpenAI and PGVector
# class VannaCustom(PG_VectorStore, OpenAI_Chat):
# def __init__(self, config=None):
# PG_VectorStore.__init__(self, config=config)
# OpenAI_Chat.__init__(self, config=config)

vn = VannaCustom(config={
'api_key': os.environ['OPENAI_API_KEY'],
'model': 'gpt-3.5-turbo',
"connection_string": get_vanna_connection_string(),
})

# connect to SQLite database
vn.connect_to_sqlite('https://vanna.ai/Chinook.sqlite')

# train Vanna on DDLs
df_ddl = vn.run_sql("SELECT type, sql FROM sqlite_master WHERE sql is not null")
for ddl in df_ddl['sql'].to_list():
vn.train(ddl=ddl)
assert len(vn.get_related_ddl("dummy question")) == 10 # assume 10 DDL chunks are retrieved by default
# vn = VannaCustom(config={
# 'api_key': os.environ['OPENAI_API_KEY'],
# 'model': 'gpt-3.5-turbo',
# "connection_string": get_vanna_connection_string(),
# })

# # connect to SQLite database
# vn.connect_to_sqlite('https://vanna.ai/Chinook.sqlite')

# # train Vanna on DDLs
# df_ddl = vn.run_sql("SELECT type, sql FROM sqlite_master WHERE sql is not null")
# for ddl in df_ddl['sql'].to_list():
# vn.train(ddl=ddl)
# assert len(vn.get_related_ddl("dummy question")) == 10 # assume 10 DDL chunks are retrieved by default

question = "What are the top 7 customers by sales?"
sql = vn.generate_sql(question)
df = vn.run_sql(sql)
assert len(df) == 7

# test if Vanna can generate an answer
answer = vn.ask(question)
assert answer is not None
# question = "What are the top 7 customers by sales?"
# sql = vn.generate_sql(question)
# df = vn.run_sql(sql)
# assert len(df) == 7

# # test if Vanna can generate an answer
# answer = vn.ask(question)
# assert answer is not None

0 comments on commit 4915144

Please sign in to comment.