Skip to content

Commit

Permalink
Update faiss.py
Browse files Browse the repository at this point in the history
  • Loading branch information
m7mdhka committed Sep 5, 2024
1 parent 5917e49 commit fd8a928
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/vanna/faiss/faiss.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,14 +84,14 @@ def _save_metadata(self, metadata, filename):
with open(filepath, 'w') as f:
json.dump(metadata, f)

def _generate_embedding(self, data: str, **kwargs) -> List[float]:
def generate_embedding(self, data: str, **kwargs) -> List[float]:
embedding = self.embedding_model.encode(data)
assert embedding.shape[0] == self.embedding_dim, \
f"Embedding dimension mismatch: expected {self.embedding_dim}, got {embedding.shape[0]}"
return embedding.tolist()

def _add_to_index(self, index, metadata_list, text, extra_metadata=None) -> str:
embedding = self._generate_embedding(text)
embedding = self.generate_embedding(text)
index.add(np.array([embedding], dtype=np.float32))
entry_id = str(uuid.uuid4())
metadata_list.append({"id": entry_id, **(extra_metadata or {})})
Expand All @@ -116,7 +116,7 @@ def add_documentation(self, documentation: str, **kwargs) -> str:
return entry_id

def _get_similar(self, index, metadata_list, text, n_results) -> list:
embedding = self._generate_embedding(text)
embedding = self.generate_embedding(text)
D, I = index.search(np.array([embedding], dtype=np.float32), k=n_results)
return [] if len(I[0]) == 0 or I[0][0] == -1 else [metadata_list[i] for i in I[0]]

Expand Down Expand Up @@ -151,7 +151,7 @@ def remove_training_data(self, id: str, **kwargs) -> bool:
if item['id'] == id:
del metadata_list[i]
new_index = faiss.IndexFlatL2(self.embedding_dim)
embeddings = [self._generate_embedding(json.dumps(m)) for m in metadata_list]
embeddings = [self.generate_embedding(json.dumps(m)) for m in metadata_list]
if embeddings:
new_index.add(np.array(embeddings, dtype=np.float32))
setattr(self, index_name.split('.')[0], new_index)
Expand Down

0 comments on commit fd8a928

Please sign in to comment.