Skip to content

Commit

Permalink
add pure embedding retrieve api (#317)
Browse files Browse the repository at this point in the history
  • Loading branch information
IANTHEREAL authored Oct 8, 2024
1 parent 1af82ae commit 3b026e6
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 0 deletions.
11 changes: 11 additions & 0 deletions backend/app/api/admin_routes/retrieve.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,14 @@ async def retrieve_documents(
) -> List[Document]:
retrieve_service = RetrieveService(session, chat_engine)
return retrieve_service.retrieve(question, top_k=top_k)

@router.get("/admin/embedding_retrieve")
async def embedding_retrieve(
    session: SessionDep,
    user: CurrentSuperuserDep,
    question: str,
    chat_engine: str = "default",
    top_k: Optional[int] = 5,
) -> List[Document]:
    # Admin route: look up documents for `question` through the service's
    # embedding-only retrieval path and return them as the response body.
    #
    # `user` is not read in the body; it is declared only so the dependency
    # is resolved for this route.
    # NOTE(review): presumably CurrentSuperuserDep rejects non-superusers --
    # confirm against its definition.
    #
    # Comments are used instead of a docstring on purpose: FastAPI would
    # publish a docstring as this route's OpenAPI description.
    service = RetrieveService(session, chat_engine)
    documents = service._embedding_retrieve(question, top_k=top_k)
    return documents
19 changes: 19 additions & 0 deletions backend/app/rag/retrieve.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,25 @@ def _retrieve(self, question: str, top_k: int) -> List[Document]:

return source_documents

def _embedding_retrieve(self, question: str, top_k: int) -> List[Document]:
    """Return the source documents of the nodes most similar to *question*.

    Builds a vector index over the TiDB vector store using the default
    embedding model, retrieves the ``top_k`` most similar nodes for the
    question, and maps those nodes to their source documents.

    Args:
        question: The query text to embed and search with.
        top_k: Number of nodes to request from the vector retriever
            (passed through as ``similarity_top_k``).

    Returns:
        The documents resolved by ``self._get_source_documents`` for the
        retrieved nodes.
    """
    embed_model = get_default_embedding_model(self.db_session)

    vector_store = TiDBVectorStore(session=self.db_session)
    vector_index = VectorStoreIndex.from_vector_store(
        vector_store,
        embed_model=embed_model,
    )

    # Only retriever options are passed here. Node postprocessors (such as
    # a reranker) are a query-engine option; a plain retriever does not
    # apply them, so the previous ``node_postprocessors=[...]`` argument
    # had no effect and has been removed to keep this path purely
    # embedding-based.
    retriever = vector_index.as_retriever(similarity_top_k=top_k)

    node_list: List[NodeWithScore] = retriever.retrieve(question)
    return self._get_source_documents(node_list)

def _get_source_documents(self, node_list: List[NodeWithScore]) -> List[Document]:
source_nodes_ids = [s_n.node_id for s_n in node_list]
stmt = select(Document).where(
Expand Down

0 comments on commit 3b026e6

Please sign in to comment.