From 2ff2d865570fe17c6a9a47515a0bd89bdb34a5fc Mon Sep 17 00:00:00 2001 From: Chip Davis Date: Wed, 8 May 2024 13:14:56 -0500 Subject: [PATCH 1/9] feat: pinecone vectorstore --- .gitignore | 1 + pyproject.toml | 1 + src/vanna/pinecone/__init__.py | 3 + src/vanna/pinecone/pinecone_vector.py | 275 ++++++++++++++++++++++++++ 4 files changed, 280 insertions(+) create mode 100644 src/vanna/pinecone/__init__.py create mode 100644 src/vanna/pinecone/pinecone_vector.py diff --git a/.gitignore b/.gitignore index 0bbaa6e5..1bfeb59c 100644 --- a/.gitignore +++ b/.gitignore @@ -17,3 +17,4 @@ dist htmlcov chroma.sqlite3 *.bin +env diff --git a/pyproject.toml b/pyproject.toml index d1b2d826..15cd6ed5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,5 +45,6 @@ zhipuai = ["zhipuai"] ollama = ["ollama", "httpx"] qdrant = ["qdrant-client", "fastembed"] vllm = ["vllm"] +pinecone = ["pinecone-client", "fastembed"] opensearch = ["opensearch-py", "opensearch-dsl"] hf = ["transformers"] diff --git a/src/vanna/pinecone/__init__.py b/src/vanna/pinecone/__init__.py new file mode 100644 index 00000000..1a1b98c5 --- /dev/null +++ b/src/vanna/pinecone/__init__.py @@ -0,0 +1,3 @@ +from .pinecone_vector import PineconeDB_VectorStore + +__all__ = ["PineconeDB_VectorStore"] diff --git a/src/vanna/pinecone/pinecone_vector.py b/src/vanna/pinecone/pinecone_vector.py new file mode 100644 index 00000000..dfc6a49d --- /dev/null +++ b/src/vanna/pinecone/pinecone_vector.py @@ -0,0 +1,275 @@ +import json +from typing import List + +from pinecone import Pinecone, PodSpec, ServerlessSpec +import pandas as pd +from ..base import VannaBase +from ..utils import deterministic_uuid + +from fastembed import TextEmbedding + + +class PineconeDB_VectorStore(VannaBase): + """ + Vectorstore using PineconeDB + + Args: + config (dict): Configuration dictionary. Defaults to {}. You must provide either a Pinecone Client or an API key in the config. + - client (Pinecone, optional): Pinecone client. Defaults to None. 
+ - api_key (str, optional): Pinecone API key. Defaults to None. + - n_results (int, optional): Number of results to return. Defaults to 10. + - dimensions (int, optional): Dimensions of the embeddings. Defaults to 384 which corresponds to the dimensions of BAAI/bge-small-en-v1.5. + - fastembed_model (str, optional): Fastembed model to use. Defaults to "BAAI/bge-small-en-v1.5". + - documentation_namespace (str, optional): Namespace for documentation. Defaults to "documentation". + - distance_metric (str, optional): Distance metric to use. Defaults to "cosine". + - ddl_namespace (str, optional): Namespace for DDL. Defaults to "ddl". + - sql_namespace (str, optional): Namespace for SQL. Defaults to "sql". + - index_name (str, optional): Name of the index. Defaults to "vanna-index". + - metadata_config (dict, optional): Metadata configuration if using a pinecone pod. Defaults to {}. + - server_type (str, optional): Type of Pinecone server to use. Defaults to "serverless". Options are "serverless" or "pod". + - podspec (PodSpec, optional): PodSpec configuration if using a pinecone pod. Defaults to PodSpec(environment="us-west-2", pod_type="p1.x1", metadata_config=self.metadata_config). + - serverless_spec (ServerlessSpec, optional): ServerlessSpec configuration if using a pinecone serverless index. Defaults to ServerlessSpec(cloud="aws", region="us-west-2"). + Raises: + ValueError: If config is None, neither api_key nor client is provided, client is not an instance of Pinecone, or server_type is not "serverless" or "pod". + """ + + def __init__(self, config=None): + VannaBase.__init__(self, config=config) + if config is None: + raise ValueError( + "config is required, pass either a Pinecone client or an API key in the config."
+ ) + client = config.get("client") + api_key = config.get("api_key") + if not api_key and not client: + raise ValueError( + "api_key is required in config or pass a configured client" + ) + if not client and api_key: + self._client = Pinecone(api_key=api_key) + elif not isinstance(client, Pinecone): + raise ValueError("client must be an instance of Pinecone") + else: + self._client = client + + self.n_results = config.get("n_results", 10) + self.dimensions = config.get("dimensions", 384) + self.fastembed_model = config.get("fastembed_model", "BAAI/bge-small-en-v1.5") + self.documentation_namespace = config.get( + "documentation_namespace", "documentation" + ) + self.distance_metric = config.get("distance_metric", "cosine") + self.ddl_namespace = config.get("ddl_namespace", "ddl") + self.sql_namespace = config.get("sql_namespace", "sql") + self.index_name = config.get("index_name", "vanna-index") + self.metadata_config = config.get("metadata_config", {}) + self.server_type = config.get("server_type", "serverless") + if self.server_type not in ["serverless", "pod"]: + raise ValueError("server_type must be either 'serverless' or 'pod'") + self.podspec = config.get( + "podspec", + PodSpec( + environment="us-west-2", + pod_type="p1.x1", + metadata_config=self.metadata_config, + ), + ) + self.serverless_spec = config.get( + "serverless_spec", ServerlessSpec(cloud="aws", region="us-west-2") + ) + self._setup_index() + + def _set_index_host(self, host: str) -> None: + self.Index = self._client.Index(host=host) + + def _setup_index(self): + existing_indexes = self._get_indexes() + if self.index_name not in existing_indexes and self.server_type == "serverless": + self._client.create_index( + name=self.index_name, + dimension=self.dimensions, + metric=self.distance_metric, + spec=self.serverless_spec, + ) + pinecone_index_host = self._client.describe_index(self.index_name)["host"] + self._set_index_host(pinecone_index_host) + elif self.index_name not in existing_indexes and 
self.server_type == "pod": + self._client.create_index( + name=self.index_name, + dimension=self.dimensions, + metric=self.distance_metric, + spec=self.podspec, + ) + pinecone_index_host = self._client.describe_index(self.index_name)["host"] + self._set_index_host(pinecone_index_host) + else: + pinecone_index_host = self._client.describe_index(self.index_name)["host"] + self._set_index_host(pinecone_index_host) + + def _get_indexes(self): + return [index["name"] for index in self._client.list_indexes()] + + def _check_if_embedding_exists(self, id: str, namespace: str) -> bool: + fetch_response = self.Index.fetch(ids=[id], namespace=namespace) + if fetch_response["vectors"] == {}: + return False + return True + + def add_ddl(self, ddl: str, **kwargs) -> str: + id = deterministic_uuid(ddl) + "-ddl" + if self._check_if_embedding_exists(id=id, namespace=self.ddl_namespace): + print(f"DDL with id: {id} already exists in the index. Skipping...") + return id + self.Index.upsert( + vectors=[(id, self.generate_embedding(ddl), {"ddl": ddl})], + namespace=self.ddl_namespace, + ) + return id + + def add_documentation(self, doc: str, **kwargs) -> str: + id = deterministic_uuid(doc) + "-doc" + + if self._check_if_embedding_exists( + id=id, namespace=self.documentation_namespace + ): + print( + f"Documentation with id: {id} already exists in the index. Skipping..." + ) + return id + self.Index.upsert( + vectors=[(id, self.generate_embedding(doc), {"documentation": doc})], + namespace=self.documentation_namespace, + ) + return id + + def add_question_sql(self, question: str, sql: str, **kwargs) -> str: + question_sql_json = json.dumps( + { + "question": question, + "sql": sql, + }, + ensure_ascii=False, + ) + id = deterministic_uuid(question_sql_json) + "-sql" + if self._check_if_embedding_exists(id=id, namespace=self.sql_namespace): + print( + f"Question-SQL with id: {id} already exists in the index. Skipping..." 
+ ) + return id + self.Index.upsert( + vectors=[ + ( + id, + self.generate_embedding(question_sql_json), + {"sql": question_sql_json}, + ) + ], + namespace=self.sql_namespace, + ) + return id + + def get_related_ddl(self, question: str, **kwargs) -> list: + res = self.Index.query( + namespace=self.ddl_namespace, + vector=self.generate_embedding(question), + top_k=self.n_results, + include_values=True, + include_metadata=True, + ) + return [match["metadata"]["ddl"] for match in res["matches"]] if res else [] + + def get_related_documentation(self, question: str, **kwargs) -> list: + res = self.Index.query( + namespace=self.documentation_namespace, + vector=self.generate_embedding(question), + top_k=self.n_results, + include_values=True, + include_metadata=True, + ) + return ( + [match["metadata"]["documentation"] for match in res["matches"]] + if res + else [] + ) + + def get_similar_question_sql(self, question: str, **kwargs) -> list: + res = self.Index.query( + namespace=self.sql_namespace, + vector=self.generate_embedding(question), + top_k=self.n_results, + include_values=True, + include_metadata=True, + ) + return ( + [ + { + key: value + for key, value in json.loads(match["metadata"]["sql"]).items() + } + for match in res["matches"] + ] + if res + else [] + ) + + def get_training_data(self, **kwargs) -> pd.DataFrame: + # Pinecone does not support getting all vectors in a namespace, so we have to query for the top_k vectors with a dummy vector + df = pd.DataFrame() + namespaces = { + "sql": self.sql_namespace, + "ddl": self.ddl_namespace, + "documentation": self.documentation_namespace, + } + + for data_type, namespace in namespaces.items(): + data = self.Index.query( + top_k=10000, # max results that pinecone allows + namespace=namespace, + include_values=True, + include_metadata=True, + vector=[0.0] * self.dimensions, + ) + + if data is not None: + id_list = [match["id"] for match in data["matches"]] + content_list = [ + match["metadata"][data_type] for match 
in data["matches"] + ] + question_list = [ + ( + json.loads(match["metadata"][data_type])["question"] + if data_type == "sql" + else None + ) + for match in data["matches"] + ] + + df_data = pd.DataFrame( + { + "id": id_list, + "question": question_list, + "content": content_list, + } + ) + df_data["training_data_type"] = data_type + df = pd.concat([df, df_data]) + + return df + + def remove_training_data(self, id: str, **kwargs) -> bool: + if id.endswith("-sql"): + self.Index.delete(ids=[id], namespace=self.sql_namespace) + return True + elif id.endswith("-ddl"): + self.Index.delete(ids=[id], namespace=self.ddl_namespace) + return True + elif id.endswith("-doc"): + self.Index.delete(ids=[id], namespace=self.documentation_namespace) + return True + else: + return False + + def generate_embedding(self, data: str, **kwargs) -> List[float]: + embedding_model = TextEmbedding(model_name=self.fastembed_model) + embedding = next(embedding_model.embed(data)) + return embedding.tolist() From 1baf6c3c447e32356e66d58aa715287fa5877777 Mon Sep 17 00:00:00 2001 From: Chip Davis Date: Wed, 8 May 2024 13:17:46 -0500 Subject: [PATCH 2/9] chore: removed env from gitignore, renamed to venv --- .gitignore | 1 - 1 file changed, 1 deletion(-) diff --git a/.gitignore b/.gitignore index 1bfeb59c..744428b2 100644 --- a/.gitignore +++ b/.gitignore @@ -12,7 +12,6 @@ docs/*.html .tox/ notebooks/chroma.sqlite3 dist -.env *.sqlite htmlcov chroma.sqlite3 From 106bcba10d9cce61a75f27b23cbf1d7e7ba8da10 Mon Sep 17 00:00:00 2001 From: Chip Davis Date: Wed, 8 May 2024 13:51:11 -0500 Subject: [PATCH 3/9] chore: return types --- src/vanna/pinecone/pinecone_vector.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/vanna/pinecone/pinecone_vector.py b/src/vanna/pinecone/pinecone_vector.py index dfc6a49d..1e82ecf2 100644 --- a/src/vanna/pinecone/pinecone_vector.py +++ b/src/vanna/pinecone/pinecone_vector.py @@ -82,7 +82,7 @@ def __init__(self, config=None): def 
_set_index_host(self, host: str) -> None: self.Index = self._client.Index(host=host) - def _setup_index(self): + def _setup_index(self) -> None: existing_indexes = self._get_indexes() if self.index_name not in existing_indexes and self.server_type == "serverless": self._client.create_index( @@ -106,7 +106,7 @@ def _setup_index(self): pinecone_index_host = self._client.describe_index(self.index_name)["host"] self._set_index_host(pinecone_index_host) - def _get_indexes(self): + def _get_indexes(self) -> list: return [index["name"] for index in self._client.list_indexes()] def _check_if_embedding_exists(self, id: str, namespace: str) -> bool: From 15289025bbc9ff112f45d3bcc43eb9f35f4c42df Mon Sep 17 00:00:00 2001 From: Chip Davis Date: Wed, 8 May 2024 15:31:53 -0500 Subject: [PATCH 4/9] chore: fixed gitignore --- .gitignore | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 744428b2..0bbaa6e5 100644 --- a/.gitignore +++ b/.gitignore @@ -12,8 +12,8 @@ docs/*.html .tox/ notebooks/chroma.sqlite3 dist +.env *.sqlite htmlcov chroma.sqlite3 *.bin -env From 974ecedff596dd63c764335ba38e62402dee5e60 Mon Sep 17 00:00:00 2001 From: huchengyi Date: Sun, 12 May 2024 19:02:22 +0800 Subject: [PATCH 5/9] Fixed a bug where the ask function did not return fig --- src/vanna/base/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/vanna/base/base.py b/src/vanna/base/base.py index c13489b7..d7e1f487 100644 --- a/src/vanna/base/base.py +++ b/src/vanna/base/base.py @@ -1701,7 +1701,7 @@ def ask( return None else: return sql, None, None - return sql, df, None + return sql, df, fig def train( self, From b2a5fcd7879d965188a40b42d5098360ff372260 Mon Sep 17 00:00:00 2001 From: peilongchencc Date: Mon, 20 May 2024 16:58:51 +0800 Subject: [PATCH 6/9] Fix string concatenation in initial_prompt --- src/vanna/base/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/vanna/base/base.py 
b/src/vanna/base/base.py index c13489b7..2ca32d7d 100644 --- a/src/vanna/base/base.py +++ b/src/vanna/base/base.py @@ -555,7 +555,7 @@ def get_sql_prompt( """ if initial_prompt is None: - initial_prompt = f"You are a {self.dialect} expert. " + initial_prompt = f"You are a {self.dialect} expert. " + \ "Please help to generate a SQL query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions. " initial_prompt = self.add_ddl_to_prompt( From 19d0e5aa85d3dd97abe1f449d2c86fff72de2224 Mon Sep 17 00:00:00 2001 From: Zain Hoda <7146154+zainhoda@users.noreply.github.com> Date: Mon, 20 May 2024 08:29:43 -0600 Subject: [PATCH 7/9] tests --- .gitignore | 1 + pyproject.toml | 2 +- src/vanna/mock/__init__.py | 3 ++ src/vanna/mock/embedding.py | 11 ++++++++ src/vanna/mock/llm.py | 19 +++++++++++++ src/vanna/mock/vectordb.py | 55 +++++++++++++++++++++++++++++++++++++ tests/test_imports.py | 2 ++ tests/test_instantiation.py | 1 + tox.ini | 4 +-- 9 files changed, 95 insertions(+), 3 deletions(-) create mode 100644 src/vanna/mock/__init__.py create mode 100644 src/vanna/mock/embedding.py create mode 100644 src/vanna/mock/llm.py create mode 100644 src/vanna/mock/vectordb.py create mode 100644 tests/test_instantiation.py diff --git a/.gitignore b/.gitignore index 0bbaa6e5..69698f91 100644 --- a/.gitignore +++ b/.gitignore @@ -17,3 +17,4 @@ dist htmlcov chroma.sqlite3 *.bin +.coverage.* diff --git a/pyproject.toml b/pyproject.toml index 15cd6ed5..4ebc14c6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,7 +33,7 @@ bigquery = ["google-cloud-bigquery"] snowflake = ["snowflake-connector-python"] duckdb = ["duckdb"] google = ["google-generativeai", "google-cloud-aiplatform"] -all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "mistralai", "chromadb", "anthropic", "zhipuai", "marqo", "google-generativeai", 
"google-cloud-aiplatform", "qdrant-client", "fastembed", "ollama", "httpx", "opensearch-py", "opensearch-dsl", "transformers"] +all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "mistralai", "chromadb", "anthropic", "zhipuai", "marqo", "google-generativeai", "google-cloud-aiplatform", "qdrant-client", "fastembed", "ollama", "httpx", "opensearch-py", "opensearch-dsl", "transformers", "pinecone-client"] test = ["tox"] chromadb = ["chromadb"] openai = ["openai"] diff --git a/src/vanna/mock/__init__.py b/src/vanna/mock/__init__.py new file mode 100644 index 00000000..ce493b97 --- /dev/null +++ b/src/vanna/mock/__init__.py @@ -0,0 +1,3 @@ +from .embedding import MockEmbedding +from .llm import MockLLM +from .vectordb import MockVectorDB diff --git a/src/vanna/mock/embedding.py b/src/vanna/mock/embedding.py new file mode 100644 index 00000000..a744b342 --- /dev/null +++ b/src/vanna/mock/embedding.py @@ -0,0 +1,11 @@ +from typing import List + +from ..base import VannaBase + + +class MockEmbedding(VannaBase): + def __init__(self, config=None): + pass + + def generate_embedding(self, data: str, **kwargs) -> List[float]: + return [1.0, 2.0, 3.0, 4.0, 5.0] diff --git a/src/vanna/mock/llm.py b/src/vanna/mock/llm.py new file mode 100644 index 00000000..a0196a28 --- /dev/null +++ b/src/vanna/mock/llm.py @@ -0,0 +1,19 @@ + +from ..base import VannaBase + + +class MockLLM(VannaBase): + def __init__(self, config=None): + pass + + def system_message(self, message: str) -> any: + return {"role": "system", "content": message} + + def user_message(self, message: str) -> any: + return {"role": "user", "content": message} + + def assistant_message(self, message: str) -> any: + return {"role": "assistant", "content": message} + + def submit_prompt(self, prompt, **kwargs) -> str: + return "Mock LLM response" diff --git a/src/vanna/mock/vectordb.py b/src/vanna/mock/vectordb.py new file mode 100644 index 
00000000..9259ff23 --- /dev/null +++ b/src/vanna/mock/vectordb.py @@ -0,0 +1,55 @@ +import pandas as pd + +from ..base import VannaBase + + +class MockVectorDB(VannaBase): + def __init__(self, config=None): + pass + + def _get_id(self, value: str, **kwargs) -> str: + # Hash the value and return the ID + return str(hash(value)) + + def add_ddl(self, ddl: str, **kwargs) -> str: + return self._get_id(ddl) + + def add_documentation(self, doc: str, **kwargs) -> str: + return self._get_id(doc) + + def add_question_sql(self, question: str, sql: str, **kwargs) -> str: + return self._get_id(question) + + def get_related_ddl(self, question: str, **kwargs) -> list: + return [] + + def get_related_documentation(self, question: str, **kwargs) -> list: + return [] + + def get_similar_question_sql(self, question: str, **kwargs) -> list: + return [] + + def get_training_data(self, **kwargs) -> pd.DataFrame: + return pd.DataFrame({'id': {0: '19546-ddl', + 1: '91597-sql', + 2: '133976-sql', + 3: '59851-doc', + 4: '73046-sql'}, + 'training_data_type': {0: 'ddl', + 1: 'sql', + 2: 'sql', + 3: 'documentation', + 4: 'sql'}, + 'question': {0: None, + 1: 'What are the top selling genres?', + 2: 'What are the low 7 artists by sales?', + 3: None, + 4: 'What is the total sales for each customer?'}, + 'content': {0: 'CREATE TABLE [Invoice]\n(\n [InvoiceId] INTEGER NOT NULL,\n [CustomerId] INTEGER NOT NULL,\n [InvoiceDate] DATETIME NOT NULL,\n [BillingAddress] NVARCHAR(70),\n [BillingCity] NVARCHAR(40),\n [BillingState] NVARCHAR(40),\n [BillingCountry] NVARCHAR(40),\n [BillingPostalCode] NVARCHAR(10),\n [Total] NUMERIC(10,2) NOT NULL,\n CONSTRAINT [PK_Invoice] PRIMARY KEY ([InvoiceId]),\n FOREIGN KEY ([CustomerId]) REFERENCES [Customer] ([CustomerId]) \n\t\tON DELETE NO ACTION ON UPDATE NO ACTION\n)', + 1: 'SELECT g.Name AS Genre, SUM(il.Quantity) AS TotalSales\nFROM Genre g\nJOIN Track t ON g.GenreId = t.GenreId\nJOIN InvoiceLine il ON t.TrackId = il.TrackId\nGROUP BY g.GenreId, g.Name\nORDER 
BY TotalSales DESC;', + 2: 'SELECT a.ArtistId, a.Name, SUM(il.Quantity) AS TotalSales\nFROM Artist a\nINNER JOIN Album al ON a.ArtistId = al.ArtistId\nINNER JOIN Track t ON al.AlbumId = t.AlbumId\nINNER JOIN InvoiceLine il ON t.TrackId = il.TrackId\nGROUP BY a.ArtistId, a.Name\nORDER BY TotalSales ASC\nLIMIT 7;', + 3: 'This is a SQLite database. For dates rememeber to use SQLite syntax.', + 4: 'SELECT c.CustomerId, c.FirstName, c.LastName, SUM(i.Total) AS TotalSales\nFROM Customer c\nJOIN Invoice i ON c.CustomerId = i.CustomerId\nGROUP BY c.CustomerId, c.FirstName, c.LastName;'}}) + + def remove_training_data(id: str, **kwargs) -> bool: + return True diff --git a/tests/test_imports.py b/tests/test_imports.py index 3141d37e..92db890d 100644 --- a/tests/test_imports.py +++ b/tests/test_imports.py @@ -12,6 +12,7 @@ def test_regular_imports(): from vanna.openai.openai_chat import OpenAI_Chat from vanna.openai.openai_embeddings import OpenAI_Embeddings from vanna.opensearch.opensearch_vector import OpenSearch_VectorStore + from vanna.pinecone.pinecone_vector import PineconeDB_VectorStore from vanna.remote import VannaDefault from vanna.vannadb.vannadb_vector import VannaDB_VectorStore from vanna.ZhipuAI.ZhipuAI_Chat import ZhipuAI_Chat @@ -27,6 +28,7 @@ def test_shortcut_imports(): from vanna.ollama import Ollama from vanna.openai import OpenAI_Chat, OpenAI_Embeddings from vanna.opensearch import OpenSearch_VectorStore + from vanna.pinecone import PineconeDB_VectorStore from vanna.vannadb import VannaDB_VectorStore from vanna.vllm import Vllm from vanna.ZhipuAI import ZhipuAI_Chat, ZhipuAI_Embeddings diff --git a/tests/test_instantiation.py b/tests/test_instantiation.py new file mode 100644 index 00000000..f6d58134 --- /dev/null +++ b/tests/test_instantiation.py @@ -0,0 +1 @@ +from vanna.mock import MockEmbedding, MockLLM, MockVectorDB diff --git a/tox.ini b/tox.ini index a406f326..15819568 100644 --- a/tox.ini +++ b/tox.ini @@ -23,8 +23,8 @@ deps= python-dotenv extras 
= all basepython = python -commands = - pytest -v --cov=tests/ --cov-report=term --cov-report=html +commands = + pytest -x -v --cov=tests/ --cov-report=term --cov-report=html [testenv:flake8] exclude = .tox/* From a72b842d420cf1fa061e5f97d45ea08051651ebb Mon Sep 17 00:00:00 2001 From: Zain Hoda <7146154+zainhoda@users.noreply.github.com> Date: Mon, 20 May 2024 08:40:10 -0600 Subject: [PATCH 8/9] Update pyproject.toml --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 4ebc14c6..466089c0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi" [project] name = "vanna" -version = "0.5.4" +version = "0.5.5" authors = [ { name="Zain Hoda", email="zain@vanna.ai" }, ] From 747169156c88d3f9137b6d2d784052b934b50156 Mon Sep 17 00:00:00 2001 From: ishita Date: Thu, 23 May 2024 22:31:42 +0530 Subject: [PATCH 9/9] Add auth-key support and validation to Vllm --- src/vanna/vllm/vllm.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/src/vanna/vllm/vllm.py b/src/vanna/vllm/vllm.py index 0dd67e4f..53990821 100644 --- a/src/vanna/vllm/vllm.py +++ b/src/vanna/vllm/vllm.py @@ -17,6 +17,11 @@ def __init__(self, config=None): else: self.model = config["model"] + if "auth-key" in config: + self.auth_key = config["auth-key"] + else: + self.auth_key = None + def system_message(self, message: str) -> any: return {"role": "system", "content": message} @@ -67,7 +72,17 @@ def submit_prompt(self, prompt, **kwargs) -> str: "messages": prompt, } - response = requests.post(url, json=data) + if self.auth_key is not None: + headers = { + 'Content-Type': 'application/json', + 'Authorization': f'Bearer {self.auth_key}' + } + + response = requests.post(url, headers=headers,json=data) + + + else: + response = requests.post(url, json=data) response_dict = response.json()