diff --git a/engine/db/execution/vec_search_executor.cpp b/engine/db/execution/vec_search_executor.cpp index 2b30edef..c499ec8b 100644 --- a/engine/db/execution/vec_search_executor.cpp +++ b/engine/db/execution/vec_search_executor.cpp @@ -751,7 +751,7 @@ bool VecSearchExecutor::BruteForceSearch( num_result = 0; // remove the invalid entries for (; iter < end - start; ++iter) { - if (!deleted.test(iter + start) && expr_evaluator.LogicalEvaluate(root_node_index, iter + start)) { + if (!deleted.test(iter + start) && expr_evaluator.LogicalEvaluate(root_node_index, iter + start, brute_force_queue_[iter].distance_)) { if (iter != num_result) { brute_force_queue_[num_result] = brute_force_queue_[iter]; } @@ -833,7 +833,7 @@ Status VecSearchExecutor::Search( result_size = 0; for (int64_t k_i = 0; k_i < candidateNum && result_size < searchLimit; ++k_i) { auto id = set_L_[k_i + master_queue_start].id_; - if (deleted.test(id) || !expr_evaluator.LogicalEvaluate(filter_root_index, id)) { + if (deleted.test(id) || !expr_evaluator.LogicalEvaluate(filter_root_index, id, set_L_[k_i + master_queue_start].distance_)) { continue; } search_result_[result_size] = set_L_[k_i + master_queue_start].id_; @@ -846,7 +846,7 @@ Status VecSearchExecutor::Search( const int64_t master_queue_start = local_queues_starts_[num_threads_ - 1]; for (int64_t k_i = 0; k_i < candidateNum && result_size < searchLimit; ++k_i) { auto id = set_L_[k_i + master_queue_start].id_; - if (deleted.test(id) || !expr_evaluator.LogicalEvaluate(filter_root_index, id)) { + if (deleted.test(id) || !expr_evaluator.LogicalEvaluate(filter_root_index, id, set_L_[k_i + master_queue_start].distance_)) { continue; } search_result_[result_size] = set_L_[k_i + master_queue_start].id_; diff --git a/engine/query/expr/expr.cpp b/engine/query/expr/expr.cpp index 891fe3d6..1f12405d 100644 --- a/engine/query/expr/expr.cpp +++ b/engine/query/expr/expr.cpp @@ -18,7 +18,7 @@ enum class State { Number, String, Attribute, - Operator, + Operator }; bool isArithChar(char c) { @@ -111,6 +111,14 @@ Status SplitTokens(std::string& expression, std::vector& tokens) { state = State::String; } else if (c == '&' || c == '|' || c == '^') { return Status(NOT_IMPLEMENTED_ERROR, "Epsilla does not support bitwise operators yet."); + } else if (c == '@') { + if (i + 9 <= last_index && expression.substr(i, 9) == "@distance") { + state = State::Attribute; + cur_token = "@distance"; + i += 9; + } else { + return Status(INVALID_EXPR, "Filter expression is not valid."); + } } else { return Status(INVALID_EXPR, "Filter expression is not valid."); } @@ -247,6 +255,10 @@ bool isIntConstant(const std::string& str) { return std::regex_match(str, integerPattern); }; +bool isDistance(const std::string& str) { + return str == "@distance"; +}; + bool isDoubleConstant(const std::string& str) { std::regex doublePattern("^[-+]?\\d+\\.\\d+(?:[eE][-+]?\\d+)?$"); @@ -414,6 +426,10 @@ Status GenerateNodes( node->node_type = NodeType::DoubleConst; node->value_type = ValueType::DOUBLE; node->double_value = std::stod(token); + } else if (isDistance(token)) { + node->field_name = token; + node->node_type = NodeType::DoubleAttr; + node->value_type = ValueType::DOUBLE; } else { if (field_map.find(token) == field_map.end()) { return Status(INVALID_EXPR, "Invalid filter expression: field name '" + token + "' not found."); diff --git a/engine/query/expr/expr_evaluator.cpp b/engine/query/expr/expr_evaluator.cpp index 4747e332..4d003e53 100644 --- a/engine/query/expr/expr_evaluator.cpp +++ b/engine/query/expr/expr_evaluator.cpp @@ -91,7 +91,7 @@ std::string ExprEvaluator::StrEvaluate(const int& node_index, const int64_t& can return ""; } -double ExprEvaluator::NumEvaluate(const int& node_index, const int64_t& cand_ind) { +double ExprEvaluator::NumEvaluate(const int& node_index, const int64_t& cand_ind, const double distance) { ExprNodePtr root = nodes_[node_index]; auto node_type = root->node_type; if (node_type == NodeType::IntConst) { @@ -107,10 +107,13 @@ double ExprEvaluator::NumEvaluate(const int& node_index, const int64_t& cand_ind return GetIntFieldValue(name, cand_ind, node_type); } else if (node_type == NodeType::DoubleAttr || node_type == NodeType::FloatAttr) { auto name = root->field_name; + if (name == "@distance") { + return distance; + } return GetRealNumberFieldValue(name, cand_ind, node_type); } else if (root->left != -1 && root->right != -1) { - auto left = NumEvaluate(root->left, cand_ind); - auto right = NumEvaluate(root->right, cand_ind); + auto left = NumEvaluate(root->left, cand_ind, distance); + auto right = NumEvaluate(root->right, cand_ind, distance); switch (node_type) { case NodeType::Add: return left + right; @@ -128,6 +131,10 @@ double ExprEvaluator::NumEvaluate(const int& node_index, const int64_t& cand_ind } bool ExprEvaluator::LogicalEvaluate(const int& node_index, const int64_t& cand_ind) { + return LogicalEvaluate(node_index, cand_ind, 0); +} + +bool ExprEvaluator::LogicalEvaluate(const int& node_index, const int64_t& cand_ind, const double distance) { if (node_index < 0) { return true; } @@ -159,8 +166,8 @@ bool ExprEvaluator::LogicalEvaluate(const int& node_index, const int64_t& cand_i auto right = LogicalEvaluate(right_index, cand_ind); return node_type == NodeType::EQ ? left == right : left != right; } else { - auto left = NumEvaluate(left_index, cand_ind); - auto right = NumEvaluate(right_index, cand_ind); + auto left = NumEvaluate(left_index, cand_ind, distance); + auto right = NumEvaluate(right_index, cand_ind, distance); return node_type == NodeType::EQ ? left == right : left != right; } } else if (node_type == NodeType::AND || node_type == NodeType::OR) { @@ -168,8 +175,8 @@ bool ExprEvaluator::LogicalEvaluate(const int& node_index, const int64_t& cand_i auto right = LogicalEvaluate(right_index, cand_ind); return node_type == NodeType::AND ? (left && right) : (left || right); } else { - auto left = NumEvaluate(left_index, cand_ind); - auto right = NumEvaluate(right_index, cand_ind); + auto left = NumEvaluate(left_index, cand_ind, distance); + auto right = NumEvaluate(right_index, cand_ind, distance); switch (node_type) { case NodeType::GT: return left > right; diff --git a/engine/query/expr/expr_evaluator.hpp b/engine/query/expr/expr_evaluator.hpp index bf8d6d9f..f6e7fd5b 100644 --- a/engine/query/expr/expr_evaluator.hpp +++ b/engine/query/expr/expr_evaluator.hpp @@ -24,6 +24,7 @@ class ExprEvaluator { ~ExprEvaluator(); bool LogicalEvaluate(const int& node_index, const int64_t& cand_ind); + bool LogicalEvaluate(const int& node_index, const int64_t& cand_ind, const double distance); private: std::string GetStrFieldValue(const std::string& field_name, const int64_t& cand_ind); @@ -31,7 +32,7 @@ class ExprEvaluator { int64_t GetIntFieldValue(const std::string& field_name, const int64_t& cand_ind, NodeType& node_type); double GetRealNumberFieldValue(const std::string& field_name, const int64_t& cand_ind, NodeType& node_type); std::string StrEvaluate(const int& node_index, const int64_t& cand_ind); - double NumEvaluate(const int& node_index, const int64_t& cand_ind); + double NumEvaluate(const int& node_index, const int64_t& cand_ind, const double distance); public: std::vector& nodes_;