feat: hybrid search #2353

Draft · wants to merge 1 commit into base: main
6 changes: 4 additions & 2 deletions backend/modules/brain/rags/quivr_rag.py
@@ -50,6 +50,7 @@
When answering, use markdown to make it concise and neat.
Use the following pieces of context from files provided by the user, which are stored in a brain, to answer the user's question in the same language as the user's question. Your name is Quivr. You're a helpful assistant.
If you don't know the answer from the context provided by the files, just say that you don't know; don't try to make up an answer.
The relevance of the context is ranked from 0 to 2, with 2 being the most relevant and 0 the least relevant. Give more weight to the more relevant context when answering.
User instruction to follow if provided to answer: {custom_instructions}
"""

@@ -65,7 +66,7 @@
# How we format documents

DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(
template="File: {file_name} Content: {page_content}"
template="File: {file_name} Content: {page_content} Relevance: {similarity}"
)
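
As a quick illustration (not part of the diff), here is a minimal sketch of how the updated document prompt renders one retrieved chunk; the file name, content, and similarity value are made up, and the import path is assumed to match the LangChain version used by the backend:

```python
from langchain.prompts import PromptTemplate

# Same template as the PR's DEFAULT_DOCUMENT_PROMPT.
DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(
    template="File: {file_name} Content: {page_content} Relevance: {similarity}"
)

# Hypothetical values for one retrieved chunk; `similarity` would come from
# the hybrid_match_vectors RPC introduced in this PR.
print(
    DEFAULT_DOCUMENT_PROMPT.format(
        file_name="handbook.pdf",
        page_content="Employees accrue 25 days of paid leave per year.",
        similarity=1.8,
    )
)
# File: handbook.pdf Content: Employees accrue 25 days of paid leave per year. Relevance: 1.8
```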


@@ -226,6 +227,7 @@ def get_chain(self):
| CONDENSE_QUESTION_PROMPT
| ChatLiteLLM(temperature=0, model=self.model, api_base=api_base)
| StrOutputParser(),
"question": lambda x: x["question"],
}

prompt_custom_user = self.prompt_to_use()
@@ -236,7 +238,7 @@
# Now we retrieve the documents
retrieved_documents = {
"docs": itemgetter("standalone_question") | retriever_doc,
"question": lambda x: x["standalone_question"],
"question": itemgetter("question"),
"custom_instructions": lambda x: prompt_to_use,
}

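The last hunk swaps the input to the answering step from the condensed standalone question back to the user's original question. A small sketch of the itemgetter pattern involved (the sample state and the stand-in custom instruction are hypothetical, not the PR's full chain):

```python
from operator import itemgetter

# Sample chain state after the condensing step (values are made up).
state = {
    "question": "¿Cómo configuro Quivr?",
    "standalone_question": "How do I configure Quivr?",
}

# Before: the answer prompt saw the condensed English question.
# After:  itemgetter("question") forwards the original wording, so the model
#         can answer in the same language the user asked in.
retrieved_documents = {
    "question": itemgetter("question"),
    "custom_instructions": lambda x: "be brief",  # stand-in for prompt_to_use
}

print({key: fn(state) for key, fn in retrieved_documents.items()})
# {'question': '¿Cómo configuro Quivr?', 'custom_instructions': 'be brief'}
```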
4 changes: 3 additions & 1 deletion backend/packages/files/parsers/common.py
@@ -42,9 +42,11 @@ async def process_file(
for doc in file.documents: # pyright: ignore reportPrivateUsage=none
new_metadata = metadata.copy()
len_chunk = len(enc.encode(doc.page_content))
page_content_encoded = doc.page_content.encode("unicode_escape").decode(
page_content_encoded = doc.page_content.replace("\n", " ")
page_content_encoded = page_content_encoded.encode("unicode_escape").decode(
"ascii", "replace"
)
# Replace \n with space

new_metadata["chunk_size"] = len_chunk
doc_with_metadata = DocumentSerializable(
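To make the parser change concrete, a standalone sketch of what the updated code does to a chunk's text (the sample string is illustrative):

```python
sample = "Résumé:\nhybrid search combines\nfull-text and semantic retrieval."

# Flatten newlines to spaces, as the updated parser does.
flattened = sample.replace("\n", " ")

# Escape non-ASCII characters via unicode_escape, as in the existing code.
encoded = flattened.encode("unicode_escape").decode("ascii", "replace")

print(encoded)
# R\xe9sum\xe9: hybrid search combines full-text and semantic retrieval.
```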
18 changes: 16 additions & 2 deletions backend/vectorstore/supabase.py
@@ -68,19 +68,28 @@ def find_brain_closest_query(
def similarity_search(
self,
query: str,
full_text_weight: float = 2.0,  # weight applied to the full-text (keyword) ranking
semantic_weight: float = 1.0,  # weight applied to the semantic (embedding) ranking
rrf_k: int = 1,  # reciprocal-rank-fusion smoothing constant
k: int = 40,
table: str = "match_vectors",
table: str = "hybrid_match_vectors",
threshold: float = 0.5,
**kwargs: Any,
) -> List[Document]:
vectors = self._embedding.embed_documents([query])
query_embedding = vectors[0]
query_lower = query.lower()
res = self._client.rpc(
table,
{
"query_text": query_lower,
"match_count": 500,
"query_embedding": query_embedding,
"max_chunk_sum": self.max_input,
"full_text_weight": full_text_weight, # Add this line
"semantic_weight": semantic_weight, # Add this line
"rrf_k": rrf_k, # Add this line
"p_brain_id": str(self.brain_id),
"max_chunk_sum": self.max_input,
},
).execute()

@@ -96,5 +105,10 @@ def similarity_search(
for search in res.data
if search.get("content")
]
for search in res.data:
if search.get("content"):
logger.info("ft_rank: %s", search.get("ft_rank", 0.0))
logger.info("similarity: %s", search.get("similarity", 0.0))
logger.info("rank_ix: %s", search.get("rank_ix", 0))

return match_result
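
A hedged usage sketch of the extended similarity_search; the `vector_store` variable is assumed to be an instance of this class, already wired to a Supabase client, an embedding model, and a brain_id, and the metadata lookup assumes the similarity score is copied into each Document as the document prompt above expects:

```python
# `vector_store` is assumed to be an instance of the vector store defined in
# backend/vectorstore/supabase.py, already configured with a Supabase client,
# an embedding model, and a brain_id.
docs = vector_store.similarity_search(
    query="How do I enable hybrid search?",
    full_text_weight=2.0,   # favour exact keyword matches
    semantic_weight=1.0,    # weight for embedding similarity
    rrf_k=1,                # reciprocal-rank-fusion smoothing constant
    k=40,
)

for doc in docs:
    # `similarity` is the combined RRF score returned by hybrid_match_vectors,
    # assuming it is copied into the Document metadata.
    print(doc.metadata.get("similarity"), doc.page_content[:80])
```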
86 changes: 86 additions & 0 deletions supabase/migrations/20240316075202_hybrid.sql
@@ -0,0 +1,86 @@
alter table "public"."vectors" add column "fts" tsvector generated always as (to_tsvector('english'::regconfig, content)) stored;

CREATE INDEX vectors_fts_idx ON public.vectors USING gin (fts);

set check_function_bodies = off;

CREATE OR REPLACE FUNCTION public.hybrid_match_vectors(query_text text, query_embedding vector, p_brain_id uuid, match_count integer, max_chunk_sum integer, full_text_weight double precision DEFAULT 1.0, semantic_weight double precision DEFAULT 1.0, rrf_k integer DEFAULT 50)
RETURNS TABLE(id uuid, brain_id uuid, content text, metadata jsonb, embedding vector, similarity double precision, ft_rank double precision, rank_ix integer)
LANGUAGE plpgsql
AS $function$
BEGIN
RETURN QUERY
WITH full_text AS (
SELECT
v.id,
ts_rank_cd(v.fts, websearch_to_tsquery(query_text))::double precision AS ft_rank,
row_number() OVER (ORDER BY ts_rank_cd(v.fts, websearch_to_tsquery(query_text)) DESC)::integer AS rank_ix,
(v.metadata->>'chunk_size')::integer AS chunk_size
FROM
vectors v
INNER JOIN
brains_vectors bv ON v.id = bv.vector_id
WHERE
bv.brain_id = p_brain_id AND
v.fts @@ websearch_to_tsquery(query_text)
LIMIT LEAST(match_count, 30) * 2
), semantic AS (
SELECT
v.id,
(1 - (v.embedding <#> query_embedding))::double precision AS semantic_similarity,
row_number() OVER (ORDER BY (v.embedding <#> query_embedding))::integer AS rank_ix
FROM
vectors v
INNER JOIN
brains_vectors bv ON v.id = bv.vector_id
WHERE
bv.brain_id = p_brain_id
LIMIT LEAST(match_count, 30) * 2
), combined AS (
SELECT
coalesce(ft.id, st.id) AS id,
(coalesce(1.0 / (rrf_k + ft.rank_ix), 0)::double precision * full_text_weight + coalesce(1.0 / (rrf_k + st.rank_ix), 0)::double precision * semantic_weight)::double precision AS combined_score,
ft.ft_rank,
ft.rank_ix,
ft.chunk_size
FROM
full_text ft
FULL OUTER JOIN
semantic st ON ft.id = st.id
), ranked_vectors AS (
SELECT
c.id,
c.combined_score,
sum(c.chunk_size) OVER (ORDER BY c.combined_score DESC, c.rank_ix)::integer AS running_total,
c.ft_rank,
c.rank_ix,
c.chunk_size
FROM
combined c
)
SELECT
v.id,
bv.brain_id,
v.content,
v.metadata,
v.embedding,
c.combined_score::double precision AS similarity,
c.ft_rank::double precision,
c.rank_ix::integer
FROM
ranked_vectors c
JOIN
vectors v ON v.id = c.id
JOIN
brains_vectors bv ON v.id = bv.vector_id
WHERE
c.running_total <= max_chunk_sum
ORDER BY
c.combined_score DESC, c.rank_ix
LIMIT
LEAST(match_count, 30);
END;
$function$
;
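
For clarity, a minimal Python sketch of the reciprocal-rank-fusion scoring implemented by the combined CTE above; the worked example uses the defaults passed from the Python caller (full_text_weight=2.0, semantic_weight=1.0, rrf_k=1):

```python
def rrf_score(ft_rank_ix, semantic_rank_ix,
              full_text_weight=1.0, semantic_weight=1.0, rrf_k=50):
    """Mirror of the `combined` CTE: weighted sum of 1 / (rrf_k + rank).

    rank_ix values are 1-based positions; None means the document was absent
    from that ranking (the coalesce(..., 0) case in the SQL).
    """
    ft = full_text_weight / (rrf_k + ft_rank_ix) if ft_rank_ix else 0.0
    sem = semantic_weight / (rrf_k + semantic_rank_ix) if semantic_rank_ix else 0.0
    return ft + sem

# A chunk ranked 1st by full text and 3rd semantically:
print(rrf_score(1, 3, full_text_weight=2.0, semantic_weight=1.0, rrf_k=1))
# 2/(1+1) + 1/(1+3) = 1.25
```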