Skip to content

Commit

Permalink
Filter on distance (#126)
Browse files Browse the repository at this point in the history
  • Loading branch information
richard-epsilla authored Feb 12, 2024
1 parent 306b535 commit b878fef
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 12 deletions.
6 changes: 3 additions & 3 deletions engine/db/execution/vec_search_executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -751,7 +751,7 @@ bool VecSearchExecutor::BruteForceSearch(
num_result = 0;
// remove the invalid entries
for (; iter < end - start; ++iter) {
if (!deleted.test(iter + start) && expr_evaluator.LogicalEvaluate(root_node_index, iter + start)) {
if (!deleted.test(iter + start) && expr_evaluator.LogicalEvaluate(root_node_index, iter + start, brute_force_queue_[iter].distance_)) {
if (iter != num_result) {
brute_force_queue_[num_result] = brute_force_queue_[iter];
}
Expand Down Expand Up @@ -833,7 +833,7 @@ Status VecSearchExecutor::Search(
result_size = 0;
for (int64_t k_i = 0; k_i < candidateNum && result_size < searchLimit; ++k_i) {
auto id = set_L_[k_i + master_queue_start].id_;
if (deleted.test(id) || !expr_evaluator.LogicalEvaluate(filter_root_index, id)) {
if (deleted.test(id) || !expr_evaluator.LogicalEvaluate(filter_root_index, id, set_L_[k_i + master_queue_start].distance_)) {
continue;
}
search_result_[result_size] = set_L_[k_i + master_queue_start].id_;
Expand All @@ -846,7 +846,7 @@ Status VecSearchExecutor::Search(
const int64_t master_queue_start = local_queues_starts_[num_threads_ - 1];
for (int64_t k_i = 0; k_i < candidateNum && result_size < searchLimit; ++k_i) {
auto id = set_L_[k_i + master_queue_start].id_;
if (deleted.test(id) || !expr_evaluator.LogicalEvaluate(filter_root_index, id)) {
if (deleted.test(id) || !expr_evaluator.LogicalEvaluate(filter_root_index, id, set_L_[k_i + master_queue_start].distance_)) {
continue;
}
search_result_[result_size] = set_L_[k_i + master_queue_start].id_;
Expand Down
18 changes: 17 additions & 1 deletion engine/query/expr/expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ enum class State {
Number,
String,
Attribute,
Operator,
Operator
};

bool isArithChar(char c) {
Expand Down Expand Up @@ -111,6 +111,14 @@ Status SplitTokens(std::string& expression, std::vector<std::string>& tokens) {
state = State::String;
} else if (c == '&' || c == '|' || c == '^') {
return Status(NOT_IMPLEMENTED_ERROR, "Epsilla does not support bitwise operators yet.");
} else if (c == '@') {
if (i + 9 <= last_index && expression.substr(i, 9) == "@distance") {
state = State::Attribute;
cur_token = "@distance";
i += 9;
} else {
return Status(INVALID_EXPR, "Filter expression is not valid.");
}
} else {
return Status(INVALID_EXPR, "Filter expression is not valid.");
}
Expand Down Expand Up @@ -247,6 +255,10 @@ bool isIntConstant(const std::string& str) {
return std::regex_match(str, integerPattern);
};

bool isDistance(const std::string& str) {
return str == "@distance";
};

bool isDoubleConstant(const std::string& str) {
std::regex doublePattern("^[-+]?\\d+\\.\\d+(?:[eE][-+]?\\d+)?$");

Expand Down Expand Up @@ -414,6 +426,10 @@ Status GenerateNodes(
node->node_type = NodeType::DoubleConst;
node->value_type = ValueType::DOUBLE;
node->double_value = std::stod(token);
} else if (isDistance(token)) {
node->field_name = token;
node->node_type = NodeType::DoubleAttr;
node->value_type = ValueType::DOUBLE;
} else {
if (field_map.find(token) == field_map.end()) {
return Status(INVALID_EXPR, "Invalid filter expression: field name '" + token + "' not found.");
Expand Down
21 changes: 14 additions & 7 deletions engine/query/expr/expr_evaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ std::string ExprEvaluator::StrEvaluate(const int& node_index, const int64_t& can
return "";
}

double ExprEvaluator::NumEvaluate(const int& node_index, const int64_t& cand_ind) {
double ExprEvaluator::NumEvaluate(const int& node_index, const int64_t& cand_ind, const double distance) {
ExprNodePtr root = nodes_[node_index];
auto node_type = root->node_type;
if (node_type == NodeType::IntConst) {
Expand All @@ -107,10 +107,13 @@ double ExprEvaluator::NumEvaluate(const int& node_index, const int64_t& cand_ind
return GetIntFieldValue(name, cand_ind, node_type);
} else if (node_type == NodeType::DoubleAttr || node_type == NodeType::FloatAttr) {
auto name = root->field_name;
if (name == "@distance") {
return distance;
}
return GetRealNumberFieldValue(name, cand_ind, node_type);
} else if (root->left != -1 && root->right != -1) {
auto left = NumEvaluate(root->left, cand_ind);
auto right = NumEvaluate(root->right, cand_ind);
auto left = NumEvaluate(root->left, cand_ind, distance);
auto right = NumEvaluate(root->right, cand_ind, distance);
switch (node_type) {
case NodeType::Add:
return left + right;
Expand All @@ -128,6 +131,10 @@ double ExprEvaluator::NumEvaluate(const int& node_index, const int64_t& cand_ind
}

bool ExprEvaluator::LogicalEvaluate(const int& node_index, const int64_t& cand_ind) {
return LogicalEvaluate(node_index, cand_ind, 0);
}

bool ExprEvaluator::LogicalEvaluate(const int& node_index, const int64_t& cand_ind, const double distance) {
if (node_index < 0) {
return true;
}
Expand Down Expand Up @@ -159,17 +166,17 @@ bool ExprEvaluator::LogicalEvaluate(const int& node_index, const int64_t& cand_i
auto right = LogicalEvaluate(right_index, cand_ind);
return node_type == NodeType::EQ ? left == right : left != right;
} else {
auto left = NumEvaluate(left_index, cand_ind);
auto right = NumEvaluate(right_index, cand_ind);
auto left = NumEvaluate(left_index, cand_ind, distance);
auto right = NumEvaluate(right_index, cand_ind, distance);
return node_type == NodeType::EQ ? left == right : left != right;
}
} else if (node_type == NodeType::AND || node_type == NodeType::OR) {
auto left = LogicalEvaluate(left_index, cand_ind);
auto right = LogicalEvaluate(right_index, cand_ind);
return node_type == NodeType::AND ? (left && right) : (left || right);
} else {
auto left = NumEvaluate(left_index, cand_ind);
auto right = NumEvaluate(right_index, cand_ind);
auto left = NumEvaluate(left_index, cand_ind, distance);
auto right = NumEvaluate(right_index, cand_ind, distance);
switch (node_type) {
case NodeType::GT:
return left > right;
Expand Down
3 changes: 2 additions & 1 deletion engine/query/expr/expr_evaluator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,15 @@ class ExprEvaluator {
~ExprEvaluator();

bool LogicalEvaluate(const int& node_index, const int64_t& cand_ind);
bool LogicalEvaluate(const int& node_index, const int64_t& cand_ind, const double distance);

private:
std::string GetStrFieldValue(const std::string& field_name, const int64_t& cand_ind);
bool GetBoolFieldValue(const std::string& field_name, const int64_t& cand_ind);
int64_t GetIntFieldValue(const std::string& field_name, const int64_t& cand_ind, NodeType& node_type);
double GetRealNumberFieldValue(const std::string& field_name, const int64_t& cand_ind, NodeType& node_type);
std::string StrEvaluate(const int& node_index, const int64_t& cand_ind);
double NumEvaluate(const int& node_index, const int64_t& cand_ind);
double NumEvaluate(const int& node_index, const int64_t& cand_ind, const double distance);

public:
std::vector<ExprNodePtr>& nodes_;
Expand Down

0 comments on commit b878fef

Please sign in to comment.