import json
import os
import uuid
from typing import Any, Dict, List

import faiss
import numpy as np
import pandas as pd

from ..base import VannaBase
from ..exceptions import DependencyError


class FAISS(VannaBase):
    """FAISS-backed vector store for Vanna training data.

    Maintains three parallel stores — question/SQL pairs, DDL statements and
    free-form documentation. Each store is a ``faiss.IndexFlatL2`` plus a JSON
    metadata list; row *i* of an index always corresponds to element *i* of
    its metadata list, so the two must be mutated in lockstep.

    Config keys:
        path: directory for persisted indexes/metadata (default ``"."``).
        embedding_dim: dimensionality of the embedding model (default 384,
            matching ``all-MiniLM-L6-v2``).
        n_results / n_results_sql / n_results_ddl / n_results_documentation:
            top-k for similarity queries (default 10).
        client: ``"persistent"`` (load/save under ``path``), ``"in-memory"``,
            or a list of three pre-built ``faiss.Index`` objects in the order
            [sql, ddl, doc].
        embedding_model: sentence-transformers model name.
    """

    def __init__(self, config=None):
        if config is None:
            config = {}

        VannaBase.__init__(self, config=config)

        # Re-import inside __init__ so a missing optional dependency surfaces
        # as a DependencyError with install instructions rather than a bare
        # ImportError at class-construction time.
        try:
            import faiss
        except ImportError:
            raise DependencyError(
                "FAISS is not installed. Please install it with 'pip install faiss-cpu' or 'pip install faiss-gpu'."
            )

        try:
            from sentence_transformers import SentenceTransformer
        except ImportError:
            raise DependencyError(
                "SentenceTransformer is not installed. Please install it with 'pip install sentence-transformers'."
            )

        self.path = config.get("path", ".")
        self.embedding_dim = config.get("embedding_dim", 384)
        # One shared fallback so "n_results" applies uniformly to all stores.
        default_n = config.get("n_results", 10)
        self.n_results_sql = config.get("n_results_sql", default_n)
        self.n_results_ddl = config.get("n_results_ddl", default_n)
        self.n_results_documentation = config.get("n_results_documentation", default_n)
        self.curr_client = config.get("client", "persistent")

        if self.curr_client == "persistent":
            self.sql_index = self._load_or_create_index("sql_index.faiss")
            self.ddl_index = self._load_or_create_index("ddl_index.faiss")
            self.doc_index = self._load_or_create_index("doc_index.faiss")
        elif self.curr_client == "in-memory":
            self.sql_index = faiss.IndexFlatL2(self.embedding_dim)
            self.ddl_index = faiss.IndexFlatL2(self.embedding_dim)
            self.doc_index = faiss.IndexFlatL2(self.embedding_dim)
        elif (
            isinstance(self.curr_client, list)
            and len(self.curr_client) == 3
            and all(isinstance(idx, faiss.Index) for idx in self.curr_client)
        ):
            # Caller supplied pre-built indexes: [sql, ddl, doc].
            self.sql_index, self.ddl_index, self.doc_index = self.curr_client
        else:
            raise ValueError(f"Unsupported storage type was set in config: {self.curr_client}")

        self.sql_metadata: List[Dict[str, Any]] = self._load_or_create_metadata("sql_metadata.json")
        self.ddl_metadata: List[Dict[str, str]] = self._load_or_create_metadata("ddl_metadata.json")
        self.doc_metadata: List[Dict[str, str]] = self._load_or_create_metadata("doc_metadata.json")

        model_name = config.get("embedding_model", "all-MiniLM-L6-v2")
        self.embedding_model = SentenceTransformer(model_name)

    def _load_or_create_index(self, filename):
        """Read a persisted FAISS index from ``path``, or create an empty one."""
        filepath = os.path.join(self.path, filename)
        if os.path.exists(filepath):
            return faiss.read_index(filepath)
        return faiss.IndexFlatL2(self.embedding_dim)

    def _load_or_create_metadata(self, filename):
        """Read a persisted metadata list from ``path``, or return an empty list."""
        filepath = os.path.join(self.path, filename)
        if os.path.exists(filepath):
            with open(filepath, "r") as f:
                return json.load(f)
        return []

    def _save_index(self, index, filename):
        """Persist *index* under ``path`` (no-op unless client is 'persistent')."""
        if self.curr_client == "persistent":
            faiss.write_index(index, os.path.join(self.path, filename))

    def _save_metadata(self, metadata, filename):
        """Persist *metadata* as JSON under ``path`` (no-op unless 'persistent')."""
        if self.curr_client == "persistent":
            with open(os.path.join(self.path, filename), "w") as f:
                json.dump(metadata, f)

    def generate_embedding(self, data: str, **kwargs) -> List[float]:
        """Embed *data* with the configured sentence-transformers model.

        Raises:
            ValueError: if the model's output dimension does not match
                ``embedding_dim`` (a ``config['embedding_dim']`` /
                ``config['embedding_model']`` mismatch).
        """
        embedding = self.embedding_model.encode(data)
        # Explicit raise instead of assert: assert is stripped under -O and a
        # dimension mismatch would otherwise fail deep inside faiss.add().
        if embedding.shape[0] != self.embedding_dim:
            raise ValueError(
                f"Embedding dimension mismatch: expected {self.embedding_dim}, got {embedding.shape[0]}"
            )
        return embedding.tolist()

    @staticmethod
    def _embedding_text(item: Dict[str, Any]) -> str:
        """Return the exact text that was embedded when *item* was added.

        Used when rebuilding an index after a removal so the surviving vectors
        are reproduced byte-for-byte identical to their add-time versions.
        """
        if "sql" in item:
            return item["question"] + " " + item["sql"]
        if "ddl" in item:
            return item["ddl"]
        return item["documentation"]

    def _add_to_index(self, index, metadata_list, text, extra_metadata=None) -> str:
        """Embed *text*, append to *index* and *metadata_list*; return new id."""
        embedding = self.generate_embedding(text)
        index.add(np.array([embedding], dtype=np.float32))
        entry_id = str(uuid.uuid4())
        metadata_list.append({"id": entry_id, **(extra_metadata or {})})
        return entry_id

    def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
        """Store a question/SQL training pair; returns its id."""
        entry_id = self._add_to_index(
            self.sql_index,
            self.sql_metadata,
            question + " " + sql,
            {"question": question, "sql": sql},
        )
        self._save_index(self.sql_index, "sql_index.faiss")
        self._save_metadata(self.sql_metadata, "sql_metadata.json")
        return entry_id

    def add_ddl(self, ddl: str, **kwargs) -> str:
        """Store a DDL statement; returns its id."""
        entry_id = self._add_to_index(self.ddl_index, self.ddl_metadata, ddl, {"ddl": ddl})
        self._save_index(self.ddl_index, "ddl_index.faiss")
        self._save_metadata(self.ddl_metadata, "ddl_metadata.json")
        return entry_id

    def add_documentation(self, documentation: str, **kwargs) -> str:
        """Store a documentation snippet; returns its id."""
        entry_id = self._add_to_index(
            self.doc_index, self.doc_metadata, documentation, {"documentation": documentation}
        )
        self._save_index(self.doc_index, "doc_index.faiss")
        self._save_metadata(self.doc_metadata, "doc_metadata.json")
        return entry_id

    def _get_similar(self, index, metadata_list, text, n_results) -> list:
        """Return metadata entries for the top-k nearest neighbours of *text*."""
        embedding = self.generate_embedding(text)
        _, indices = index.search(np.array([embedding], dtype=np.float32), k=n_results)
        # FAISS pads with -1 when the index holds fewer than k vectors; filter
        # every -1, not just a leading one, or metadata_list[-1] would be
        # silently returned for each padding slot.
        return [metadata_list[i] for i in indices[0] if i != -1]

    def get_similar_question_sql(self, question: str, **kwargs) -> list:
        return self._get_similar(self.sql_index, self.sql_metadata, question, self.n_results_sql)

    def get_related_ddl(self, question: str, **kwargs) -> list:
        return [
            metadata["ddl"]
            for metadata in self._get_similar(self.ddl_index, self.ddl_metadata, question, self.n_results_ddl)
        ]

    def get_related_documentation(self, question: str, **kwargs) -> list:
        return [
            metadata["documentation"]
            for metadata in self._get_similar(
                self.doc_index, self.doc_metadata, question, self.n_results_documentation
            )
        ]

    def get_training_data(self, **kwargs) -> pd.DataFrame:
        """Return all stored training data tagged with its type."""
        sql_data = pd.DataFrame(self.sql_metadata)
        sql_data["training_data_type"] = "sql"

        ddl_data = pd.DataFrame(self.ddl_metadata)
        ddl_data["training_data_type"] = "ddl"

        doc_data = pd.DataFrame(self.doc_metadata)
        doc_data["training_data_type"] = "documentation"

        return pd.concat([sql_data, ddl_data, doc_data], ignore_index=True)

    def remove_training_data(self, id: str, **kwargs) -> bool:
        """Remove the entry with *id* from whichever store contains it.

        IndexFlatL2 has no delete, so the affected index is rebuilt from the
        surviving metadata. Embeddings are regenerated from the same text used
        at add-time (via ``_embedding_text``) so the rebuilt vectors match the
        originals; metadata is saved under the same filename ``__init__``
        loads (``{prefix}_metadata.json``).

        Returns:
            True if an entry was removed, False if *id* was not found.
        """
        stores = [
            ("sql", self.sql_metadata),
            ("ddl", self.ddl_metadata),
            ("doc", self.doc_metadata),
        ]
        for prefix, metadata_list in stores:
            for i, item in enumerate(metadata_list):
                if item["id"] == id:
                    del metadata_list[i]
                    new_index = faiss.IndexFlatL2(self.embedding_dim)
                    embeddings = [
                        self.generate_embedding(self._embedding_text(m)) for m in metadata_list
                    ]
                    if embeddings:
                        new_index.add(np.array(embeddings, dtype=np.float32))
                    setattr(self, f"{prefix}_index", new_index)
                    # _save_* are already no-ops for non-persistent clients.
                    self._save_index(new_index, f"{prefix}_index.faiss")
                    self._save_metadata(metadata_list, f"{prefix}_metadata.json")
                    return True
        return False

    def remove_collection(self, collection_name: str) -> bool:
        """Reset an entire store ('sql', 'ddl' or 'documentation').

        Returns:
            True if *collection_name* was recognised, else False.
        """
        # Public name "documentation" maps to the internal "doc_*" attributes
        # and filenames; a literal f"{collection_name}_index" would create a
        # new, unused attribute and leave doc_index untouched.
        prefixes = {"sql": "sql", "ddl": "ddl", "documentation": "doc"}
        prefix = prefixes.get(collection_name)
        if prefix is None:
            return False

        empty_index = faiss.IndexFlatL2(self.embedding_dim)
        setattr(self, f"{prefix}_index", empty_index)
        setattr(self, f"{prefix}_metadata", [])
        self._save_index(empty_index, f"{prefix}_index.faiss")
        self._save_metadata([], f"{prefix}_metadata.json")
        return True