Skip to content

Commit

Permalink
feat: support simple metadata post filter (#201)
Browse files Browse the repository at this point in the history
New configuration options to enable metadata filters:

```json
{
  "knowledge_graph": {
    "relationship_meta_filters": {
      "product": "tidbcloud"
    },
  },
  "vector_search": {
    "metadata_post_filters": {
      "filters": [
        {
          "key": "product",
          "value": "tidbcloud"
        }
      ]
    }
  }
}
```
  • Loading branch information
Mini256 authored Aug 1, 2024
1 parent e3028ef commit 9737d20
Show file tree
Hide file tree
Showing 8 changed files with 154 additions and 7 deletions.
9 changes: 7 additions & 2 deletions backend/app/rag/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def __init__(
self.chat_engine_config = ChatEngineConfig.load_from_db(db_session, engine_name)
self.db_chat_engine = self.chat_engine_config.get_db_chat_engine()
self._reranker = self.chat_engine_config.get_reranker(db_session)
self._metadata_filter = self.chat_engine_config.get_metadata_filter()

def chat(
self, chat_messages: List[ChatMessage], chat_id: Optional[UUID] = None
Expand Down Expand Up @@ -220,7 +221,10 @@ def _get_langfuse_callback_manager():
},
) as event:
result = graph_index.intent_based_search(
user_question, chat_history, include_meta=True
user_question,
chat_history,
include_meta=True,
relationship_meta_filters=kg_config.relationship_meta_filters,
)
event.on_end(payload={"graph": result["queries"]})

Expand All @@ -239,6 +243,7 @@ def _get_langfuse_callback_manager():
depth=kg_config.depth,
include_meta=kg_config.include_meta,
with_degree=kg_config.with_degree,
relationship_meta_filters=kg_config.relationship_meta_filters,
with_chunks=False,
)
graph_knowledges = get_prompt_by_jinja2_template(
Expand Down Expand Up @@ -306,7 +311,7 @@ def _get_langfuse_callback_manager():
)
query_engine = vector_index.as_query_engine(
llm=_llm,
node_postprocessors=[self._reranker],
node_postprocessors=[self._metadata_filter, self._reranker],
streaming=True,
text_qa_template=text_qa_template,
refine_template=refine_template,
Expand Down
15 changes: 15 additions & 0 deletions backend/app/rag/chat_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
from google.oauth2 import service_account
from google.auth.transport.requests import Request

from app.rag.node_postprocessor import MetadataPostFilter
from app.rag.node_postprocessor.metadata_post_filter import MetadataFilters
from app.types import LLMProvider, EmbeddingProvider
from app.rag.default_prompt import (
DEFAULT_INTENT_GRAPH_KNOWLEDGE,
Expand Down Expand Up @@ -45,17 +47,23 @@ class LLMOption(BaseModel):
refine_prompt: str = DEFAULT_REFINE_PROMPT


class VectorSearchOption(BaseModel):
    """Configuration for vector search retrieval."""

    # Optional metadata post-filters applied to retrieved nodes; None disables
    # filtering.
    # BUG FIX: the original line ended with a trailing comma
    # (`= None,`), which makes the field default the tuple `(None,)`
    # instead of `None`.
    metadata_post_filters: Optional[MetadataFilters] = None


class KnowledgeGraphOption(BaseModel):
    """Configuration for knowledge-graph retrieval in the chat engine."""

    # Whether knowledge graph retrieval is enabled.
    enabled: bool = True
    # Traversal depth passed to the graph search.
    depth: int = 2
    # Whether to include relationship metadata in search results.
    include_meta: bool = True
    # Whether to include degree information in the graph search.
    with_degree: bool = False
    # Whether to use intent-based search (decompose the question into
    # sub-queries) instead of a single direct graph search.
    using_intent_search: bool = True
    # Exact-match metadata filters applied to graph relationships;
    # None disables relationship filtering.
    relationship_meta_filters: Optional[dict] = None


class ChatEngineConfig(BaseModel):
llm: LLMOption = LLMOption()
knowledge_graph: KnowledgeGraphOption = KnowledgeGraphOption()
vector_search: VectorSearchOption = VectorSearchOption()

_db_chat_engine: Optional[DBChatEngine] = None
_db_llm: Optional[DBLLM] = None
Expand Down Expand Up @@ -114,6 +122,9 @@ def get_fast_dspy_lm(self, session: Session) -> dspy.LM:
def get_reranker(self, session: Session) -> BaseNodePostprocessor:
return get_default_reranker(session)

def get_metadata_filter(self) -> BaseNodePostprocessor:
    """Build the node post-processor that applies the configured
    vector-search metadata filters."""
    configured_filters = self.vector_search.metadata_post_filters
    return get_metadata_post_filter(configured_filters)

def screenshot(self) -> dict:
return self.model_dump_json(
exclude={
Expand Down Expand Up @@ -214,3 +225,7 @@ def get_default_reranker(session: Session) -> BaseNodePostprocessor:
model="jina-reranker-v2-base-multilingual",
top_n=10,
)


def get_metadata_post_filter(
    filters: Optional[MetadataFilters] = None,
) -> BaseNodePostprocessor:
    """Create a ``MetadataPostFilter`` post-processor for the given filters.

    A ``None`` value yields a pass-through filter that keeps all nodes.
    """
    return MetadataPostFilter(filters=filters)
2 changes: 2 additions & 0 deletions backend/app/rag/knowledge_graph/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ def intent_based_search(
chat_history: list = [],
depth: int = 2,
include_meta: bool = False,
relationship_meta_filters: Dict = {},
) -> Mapping[str, Any]:
chat_content = query
if len(chat_history) > 0:
Expand Down Expand Up @@ -281,6 +282,7 @@ def process_query(sub_query):
depth,
include_meta,
with_chunks=False,
relationship_meta_filters=relationship_meta_filters,
session=tmp_session,
)
except Exception as exc:
Expand Down
3 changes: 3 additions & 0 deletions backend/app/rag/node_postprocessor/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .metadata_post_filter import MetadataPostFilter

__all__ = ["MetadataPostFilter"]
120 changes: 120 additions & 0 deletions backend/app/rag/node_postprocessor/metadata_post_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
import logging
from enum import Enum
from typing import List, Optional, Any, Union, Annotated

from llama_index.core import QueryBundle
from llama_index.core.postprocessor.types import BaseNodePostprocessor
from llama_index.core.schema import NodeWithScore
from pydantic import BaseModel


class FilterOperator(str, Enum):
    """Vector store filter operator.

    Only ``EQ`` is currently honored by ``MetadataPostFilter``; the other
    operators are declared for forward compatibility.
    """

    # TODO add more operators
    EQ = "=="  # default operator (string, int, float)
    GT = ">"  # greater than (int, float)
    LT = "<"  # less than (int, float)
    NE = "!="  # not equal to (string, int, float)
    GTE = ">="  # greater than or equal to (int, float)
    LTE = "<="  # less than or equal to (int, float)
    IN = "in"  # in array (string or number)
    NIN = "nin"  # not in array (string or number)
    ANY = "any"  # contains any (array of strings)
    ALL = "all"  # contains all (array of strings)
    TEXT_MATCH = "text_match"  # full text match (search for a specific
    # substring, token or phrase within the text field)
    CONTAINS = "contains"  # metadata array contains value (string or number)


class FilterCondition(str, Enum):
    """Boolean condition used to combine multiple metadata filters."""

    AND = "and"
    OR = "or"


class MetadataFilter(BaseModel):
    """A single metadata filter comparing one metadata key to a value."""

    # Metadata key to look up on the node.
    key: str
    # Comparison value; scalars for simple operators, lists for the
    # set-style operators (IN/NIN/ANY/ALL).
    value: Union[
        int,
        float,
        str,
        List[int],
        List[float],
        List[str],
    ]
    # Comparison operator; defaults to exact equality.
    operator: FilterOperator = FilterOperator.EQ


# Notice:
#
# llama-index still heavily uses pydantic v1 to define its data models. Using
# llama-index classes to define FastAPI parameters may cause the following
# error:
#
#   TypeError: BaseModel.validate() takes 2 positional arguments but 3 were given
#
# That is why these filter models are re-declared here with pydantic v2 instead
# of reusing llama-index's own filter classes.
#
# See: https://github.com/run-llama/llama_index/issues/14807#issuecomment-2241285940
class MetadataFilters(BaseModel):
    """Metadata filters for vector stores."""

    # Exact-match filters and advanced filters with operators like >, <, >=,
    # <=, !=, etc. Entries may themselves be MetadataFilters, forming nested
    # filter groups.
    filters: List[Union[MetadataFilter, "MetadataFilters"]]
    # and/or condition for combining the filters above.
    condition: Optional[FilterCondition] = FilterCondition.AND


_logger = logging.getLogger(__name__)


class MetadataPostFilter(BaseNodePostprocessor):
    """Node post-processor that drops retrieved nodes whose metadata does not
    satisfy all configured exact-match filters.

    Only simple AND-combined equality filters are supported. Advanced
    features (OR conditions, non-EQ operators, nested filter groups) are
    logged and skipped rather than applied.
    """

    # Filters to apply; None disables filtering entirely.
    filters: Optional[MetadataFilters] = None

    def __init__(
        self,
        filters: Optional[MetadataFilters] = None,
        **kwargs: Any,
    ):
        super().__init__(**kwargs)
        self.filters = filters

    def _postprocess_nodes(
        self,
        nodes: List[NodeWithScore],
        query_bundle: Optional[QueryBundle] = None,
    ) -> List[NodeWithScore]:
        """Return only the nodes whose metadata matches all filters."""
        if self.filters is None:
            return nodes

        # TODO: support advanced post filtering.
        return [n for n in nodes if self.match_all_filters(n.node)]

    def match_all_filters(self, node: Any) -> bool:
        """Check whether *node*'s metadata satisfies every configured filter.

        Returns True (pass-through) when filtering is disabled or when the
        configuration only uses unsupported advanced features.
        """
        # BUG FIX: the original guard tested `isinstance(node, MetadataFilters)`.
        # A retrieved node is never a MetadataFilters instance, so the guard
        # always returned True and the filter never filtered anything.
        # The intent was to validate the *configuration* object.
        if self.filters is None or not isinstance(self.filters, MetadataFilters):
            return True

        if self.filters.condition != FilterCondition.AND:
            _logger.warning(
                f"Advanced filtering is not supported yet. "
                f"Filter condition {self.filters.condition} is ignored."
            )
            return True

        for f in self.filters.filters:
            if isinstance(f, MetadataFilters):
                # Nested filter groups are an advanced feature; skip this
                # entry (the original code would have crashed on `f.key`).
                _logger.warning(
                    "Advanced filtering is not supported yet. "
                    "Nested metadata filters are ignored."
                )
                continue

            if f.operator is not None and f.operator != FilterOperator.EQ:
                # Skip only this filter instead of accepting the node
                # outright, so the remaining equality filters still apply —
                # matching the "operator ... is ignored" message.
                _logger.warning(
                    f"Advanced filtering is not supported yet. "
                    f"Filter operator {f.operator} is ignored."
                )
                continue

            if f.key not in node.extra_info:
                return False

            if node.extra_info[f.key] != f.value:
                return False

        return True
2 changes: 1 addition & 1 deletion backend/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ dependencies = [
"celery>=5.4.0",
"dspy-ai>=2.4.9",
"langfuse==2.36.2",
"llama-index>=0.10.57",
"llama-index>=0.10.59",
"alembic>=1.13.1",
"pydantic>=2.8.2",
"pydantic-settings>=2.3.3",
Expand Down
5 changes: 3 additions & 2 deletions backend/requirements-dev.lock
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
# all-features: false
# with-sources: false
# generate-hashes: false
# universal: false

aiohttp==3.9.5
# via datasets
Expand Down Expand Up @@ -278,13 +279,13 @@ langsmith==0.1.88
# via langchain-core
llama-cloud==0.0.10
# via llama-index-indices-managed-llama-cloud
llama-index==0.10.57
llama-index==0.10.59
llama-index-agent-openai==0.2.7
# via llama-index
# via llama-index-program-openai
llama-index-cli==0.1.12
# via llama-index
llama-index-core==0.10.57
llama-index-core==0.10.59
# via llama-index
# via llama-index-agent-openai
# via llama-index-cli
Expand Down
5 changes: 3 additions & 2 deletions backend/requirements.lock
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
# all-features: false
# with-sources: false
# generate-hashes: false
# universal: false

aiohttp==3.9.5
# via datasets
Expand Down Expand Up @@ -278,13 +279,13 @@ langsmith==0.1.88
# via langchain-core
llama-cloud==0.0.10
# via llama-index-indices-managed-llama-cloud
llama-index==0.10.57
llama-index==0.10.59
llama-index-agent-openai==0.2.7
# via llama-index
# via llama-index-program-openai
llama-index-cli==0.1.12
# via llama-index
llama-index-core==0.10.57
llama-index-core==0.10.59
# via llama-index
# via llama-index-agent-openai
# via llama-index-cli
Expand Down

0 comments on commit 9737d20

Please sign in to comment.