Skip to content

Commit

Permalink
Merge branch 'main' into get-function
Browse files Browse the repository at this point in the history
  • Loading branch information
zainhoda authored Jun 7, 2024
2 parents 2107116 + 246bbe5 commit 8c7c5b0
Show file tree
Hide file tree
Showing 13 changed files with 392 additions and 6 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@ dist
htmlcov
chroma.sqlite3
*.bin
.coverage.*
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ bigquery = ["google-cloud-bigquery"]
snowflake = ["snowflake-connector-python"]
duckdb = ["duckdb"]
google = ["google-generativeai", "google-cloud-aiplatform"]
all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "mistralai", "chromadb", "anthropic", "zhipuai", "marqo", "google-generativeai", "google-cloud-aiplatform", "qdrant-client", "fastembed", "ollama", "httpx", "opensearch-py", "opensearch-dsl", "transformers"]
all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "mistralai", "chromadb", "anthropic", "zhipuai", "marqo", "google-generativeai", "google-cloud-aiplatform", "qdrant-client", "fastembed", "ollama", "httpx", "opensearch-py", "opensearch-dsl", "transformers", "pinecone-client"]
test = ["tox"]
chromadb = ["chromadb"]
openai = ["openai"]
Expand All @@ -45,5 +45,6 @@ zhipuai = ["zhipuai"]
ollama = ["ollama", "httpx"]
qdrant = ["qdrant-client", "fastembed"]
vllm = ["vllm"]
pinecone = ["pinecone-client", "fastembed"]
opensearch = ["opensearch-py", "opensearch-dsl"]
hf = ["transformers"]
4 changes: 2 additions & 2 deletions src/vanna/base/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,7 +555,7 @@ def get_sql_prompt(
"""

if initial_prompt is None:
initial_prompt = f"You are a {self.dialect} expert. "
initial_prompt = f"You are a {self.dialect} expert. " + \
"Please help to generate a SQL query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions. "

initial_prompt = self.add_ddl_to_prompt(
Expand Down Expand Up @@ -1701,7 +1701,7 @@ def ask(
return None
else:
return sql, None, None
return sql, df, None
return sql, df, fig

def train(
self,
Expand Down
3 changes: 3 additions & 0 deletions src/vanna/mock/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .embedding import MockEmbedding
from .llm import MockLLM
from .vectordb import MockVectorDB
11 changes: 11 additions & 0 deletions src/vanna/mock/embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from typing import List

from ..base import VannaBase


class MockEmbedding(VannaBase):
    """Embedding stand-in that returns the same fixed vector for any input."""

    def __init__(self, config=None):
        """Accept an optional ``config`` and ignore it; no state is set up."""

    def generate_embedding(self, data: str, **kwargs) -> List[float]:
        """Return the constant vector ``[1.0, 2.0, 3.0, 4.0, 5.0]``.

        ``data`` and ``kwargs`` are accepted but never inspected.
        """
        return [float(component) for component in range(1, 6)]
19 changes: 19 additions & 0 deletions src/vanna/mock/llm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@

from ..base import VannaBase


class MockLLM(VannaBase):
    """LLM stand-in that returns a canned reply instead of calling a model."""

    def __init__(self, config=None):
        """Accept an optional ``config`` and ignore it; no state is set up."""

    # The message builders were annotated ``-> any``, which names the builtin
    # function ``any`` rather than a type. Each one returns a dict literal, so
    # the annotation now says so.

    def system_message(self, message: str) -> dict:
        """Wrap ``message`` as a chat entry with role ``"system"``."""
        return {"role": "system", "content": message}

    def user_message(self, message: str) -> dict:
        """Wrap ``message`` as a chat entry with role ``"user"``."""
        return {"role": "user", "content": message}

    def assistant_message(self, message: str) -> dict:
        """Wrap ``message`` as a chat entry with role ``"assistant"``."""
        return {"role": "assistant", "content": message}

    def submit_prompt(self, prompt, **kwargs) -> str:
        """Return the fixed string ``"Mock LLM response"``.

        ``prompt`` and ``kwargs`` are accepted but never inspected.
        """
        return "Mock LLM response"
55 changes: 55 additions & 0 deletions src/vanna/mock/vectordb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import pandas as pd

from ..base import VannaBase


class MockVectorDB(VannaBase):
    """Vector-store stand-in that stores nothing and returns canned data."""

    def __init__(self, config=None):
        """Accept an optional ``config`` and ignore it; no state is set up."""

    def _get_id(self, value: str, **kwargs) -> str:
        """Return the builtin ``hash`` of ``value`` rendered as a string.

        Python randomizes ``str`` hashes per interpreter process unless
        ``PYTHONHASHSEED`` is fixed, so these IDs are only stable within a
        single run.
        """
        return str(hash(value))

    def add_ddl(self, ddl: str, **kwargs) -> str:
        """Return an ID derived from ``ddl``; nothing is stored."""
        return self._get_id(ddl)

    def add_documentation(self, doc: str, **kwargs) -> str:
        """Return an ID derived from ``doc``; nothing is stored."""
        return self._get_id(doc)

    def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
        """Return an ID derived from ``question`` only; ``sql`` is unused."""
        return self._get_id(question)

    def get_related_ddl(self, question: str, **kwargs) -> list:
        """Return an empty list regardless of ``question``."""
        return []

    def get_related_documentation(self, question: str, **kwargs) -> list:
        """Return an empty list regardless of ``question``."""
        return []

    def get_similar_question_sql(self, question: str, **kwargs) -> list:
        """Return an empty list regardless of ``question``."""
        return []

    def get_training_data(self, **kwargs) -> pd.DataFrame:
        """Return a fixed five-row frame of sample training data.

        The columns are ``id``, ``training_data_type``, ``question`` and
        ``content``. The rows are one DDL entry, one documentation entry and
        three question/SQL entries; ``question`` is ``None`` for the DDL and
        documentation rows.
        """
        return pd.DataFrame({
            'id': {
                0: '19546-ddl',
                1: '91597-sql',
                2: '133976-sql',
                3: '59851-doc',
                4: '73046-sql',
            },
            'training_data_type': {
                0: 'ddl',
                1: 'sql',
                2: 'sql',
                3: 'documentation',
                4: 'sql',
            },
            'question': {
                0: None,
                1: 'What are the top selling genres?',
                2: 'What are the low 7 artists by sales?',
                3: None,
                4: 'What is the total sales for each customer?',
            },
            'content': {
                0: 'CREATE TABLE [Invoice]\n(\n [InvoiceId] INTEGER NOT NULL,\n [CustomerId] INTEGER NOT NULL,\n [InvoiceDate] DATETIME NOT NULL,\n [BillingAddress] NVARCHAR(70),\n [BillingCity] NVARCHAR(40),\n [BillingState] NVARCHAR(40),\n [BillingCountry] NVARCHAR(40),\n [BillingPostalCode] NVARCHAR(10),\n [Total] NUMERIC(10,2) NOT NULL,\n CONSTRAINT [PK_Invoice] PRIMARY KEY ([InvoiceId]),\n FOREIGN KEY ([CustomerId]) REFERENCES [Customer] ([CustomerId]) \n\t\tON DELETE NO ACTION ON UPDATE NO ACTION\n)',
                1: 'SELECT g.Name AS Genre, SUM(il.Quantity) AS TotalSales\nFROM Genre g\nJOIN Track t ON g.GenreId = t.GenreId\nJOIN InvoiceLine il ON t.TrackId = il.TrackId\nGROUP BY g.GenreId, g.Name\nORDER BY TotalSales DESC;',
                2: 'SELECT a.ArtistId, a.Name, SUM(il.Quantity) AS TotalSales\nFROM Artist a\nINNER JOIN Album al ON a.ArtistId = al.ArtistId\nINNER JOIN Track t ON al.AlbumId = t.AlbumId\nINNER JOIN InvoiceLine il ON t.TrackId = il.TrackId\nGROUP BY a.ArtistId, a.Name\nORDER BY TotalSales ASC\nLIMIT 7;',
                3: 'This is a SQLite database. For dates rememeber to use SQLite syntax.',
                4: 'SELECT c.CustomerId, c.FirstName, c.LastName, SUM(i.Total) AS TotalSales\nFROM Customer c\nJOIN Invoice i ON c.CustomerId = i.CustomerId\nGROUP BY c.CustomerId, c.FirstName, c.LastName;',
            },
        })

    def remove_training_data(self, id: str, **kwargs) -> bool:
        """Report success for any ``id``; nothing is actually removed.

        The original signature omitted ``self``, so calling this on an
        instance bound the instance to ``id`` and raised ``TypeError`` as
        soon as an ID was passed.
        """
        return True
3 changes: 3 additions & 0 deletions src/vanna/pinecone/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .pinecone_vector import PineconeDB_VectorStore

__all__ = ["PineconeDB_VectorStore"]
Loading

0 comments on commit 8c7c5b0

Please sign in to comment.