Skip to content

Commit

Permalink
Merge pull request #660 from edlouth/pgvector_fixes
Browse files Browse the repository at this point in the history
Pgvector fixes
  • Loading branch information
zainhoda authored Oct 23, 2024
2 parents ed26e2a + 4915144 commit ac1a841
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 29 deletions.
32 changes: 10 additions & 22 deletions src/vanna/pgvector/pgvector.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,20 +28,20 @@ def __init__(self, config=None):
if config and "embedding_function" in config:
self.embedding_function = config.get("embedding_function")
else:
from sentence_transformers import SentenceTransformer
self.embedding_function = SentenceTransformer("sentence-transformers/all-MiniLM-l6-v2")
from langchain_huggingface import HuggingFaceEmbeddings
self.embedding_function = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")

self.sql_vectorstore = PGVector(
self.sql_collection = PGVector(
embeddings=self.embedding_function,
collection_name="sql",
connection=self.connection_string,
)
self.ddl_vectorstore = PGVector(
self.ddl_collection = PGVector(
embeddings=self.embedding_function,
collection_name="ddl",
connection=self.connection_string,
)
self.documentation_vectorstore = PGVector(
self.documentation_collection = PGVector(
embeddings=self.embedding_function,
collection_name="documentation",
connection=self.connection_string,
Expand Down Expand Up @@ -94,16 +94,16 @@ def get_collection(self, collection_name):
case _:
raise ValueError("Specified collection does not exist.")

async def get_similar_question_sql(self, question: str) -> list:
def get_similar_question_sql(self, question: str) -> list:
documents = self.sql_collection.similarity_search(query=question, k=self.n_results)
return [ast.literal_eval(document.page_content) for document in documents]

async def get_related_ddl(self, question: str, **kwargs) -> list:
documents = await self.ddl_collection.similarity_search(query=question, k=self.n_results)
def get_related_ddl(self, question: str, **kwargs) -> list:
documents = self.ddl_collection.similarity_search(query=question, k=self.n_results)
return [document.page_content for document in documents]

async def get_related_documentation(self, question: str, **kwargs) -> list:
documents = await self.documentation_collection.similarity_search(query=question, k=self.n_results)
def get_related_documentation(self, question: str, **kwargs) -> list:
documents = self.documentation_collection.similarity_search(query=question, k=self.n_results)
return [document.page_content for document in documents]

def train(
Expand Down Expand Up @@ -251,15 +251,3 @@ def remove_collection(self, collection_name: str) -> bool:

def generate_embedding(self, *args, **kwargs):
pass

def submit_prompt(self, *args, **kwargs):
pass

def system_message(self, message: str) -> any:
return {"role": "system", "content": message}

def user_message(self, message: str) -> any:
return {"role": "user", "content": message}

def assistant_message(self, message: str) -> any:
return {"role": "assistant", "content": message}
39 changes: 32 additions & 7 deletions tests/test_pgvector.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,47 @@
from dotenv import load_dotenv

# from vanna.pgvector import PG_VectorStore
# from vanna.openai import OpenAI_Chat

# assume .env file placed next to file with provided env vars
load_dotenv()

# Removing thiese tests for now until the dependencies are sorted out
# def get_vanna_connection_string():
# server = os.environ.get("PG_SERVER")
# driver = "psycopg"
# port = 5434
# port = os.environ.get("PG_PORT", 5432)
# database = os.environ.get("PG_DATABASE")
# username = os.environ.get("PG_USERNAME")
# password = os.environ.get("PG_PASSWORD")

# return f"postgresql+psycopg://{username}:{password}@{server}:{port}/{database}"
# def test_pgvector_e2e():
# # configure Vanna to use OpenAI and PGVector
# class VannaCustom(PG_VectorStore, OpenAI_Chat):
# def __init__(self, config=None):
# PG_VectorStore.__init__(self, config=config)
# OpenAI_Chat.__init__(self, config=config)

# vn = VannaCustom(config={
# 'api_key': os.environ['OPENAI_API_KEY'],
# 'model': 'gpt-3.5-turbo',
# "connection_string": get_vanna_connection_string(),
# })

# # connect to SQLite database
# vn.connect_to_sqlite('https://vanna.ai/Chinook.sqlite')

# # train Vanna on DDLs
# df_ddl = vn.run_sql("SELECT type, sql FROM sqlite_master WHERE sql is not null")
# for ddl in df_ddl['sql'].to_list():
# vn.train(ddl=ddl)
# assert len(vn.get_related_ddl("dummy question")) == 10 # assume 10 DDL chunks are retrieved by default

# question = "What are the top 7 customers by sales?"
# sql = vn.generate_sql(question)
# df = vn.run_sql(sql)
# assert len(df) == 7

# # test if Vanna can generate an answer
# answer = vn.ask(question)
# assert answer is not None

# def test_pgvector():
# connection_string = get_vanna_connection_string()
# pgclient = PG_VectorStore(config={"connection_string": connection_string})
# assert pgclient is not None

0 comments on commit ac1a841

Please sign in to comment.