From 175e9f425de3ca213a31cbae51101e9871ea9c35 Mon Sep 17 00:00:00 2001 From: Max Gabrielsson Date: Thu, 5 Sep 2024 17:06:35 +0200 Subject: [PATCH] fix so that we actually use distance metrics that match the semantics provided by usearch, add more tests, optimize min_by aggregates, use proper expression matcher, make index scan binding more accurate in case there are multiple indexes, fix misc bugs --- duckdb | 2 +- extension-ci-tools | 2 +- src/hnsw/CMakeLists.txt | 5 +- src/hnsw/hnsw_index.cpp | 51 +++- src/hnsw/hnsw_index_physical_create.cpp | 25 +- src/hnsw/hnsw_optimize_expr.cpp | 98 +++++++ ..._index_scan.cpp => hnsw_optimize_scan.cpp} | 129 +++++----- src/hnsw/hnsw_optimize_topk.cpp | 239 ++++++++++++++++++ src/hnsw/hnsw_topk_operator.cpp | 9 + src/include/hnsw/hnsw.hpp | 12 +- src/include/hnsw/hnsw_index.hpp | 3 +- test/sql/hnsw/hnsw_metrics.test | 4 +- test/sql/hnsw/hnsw_rewrite.test | 31 +++ test/sql/hnsw/hnsw_topk.test | 24 ++ 14 files changed, 534 insertions(+), 100 deletions(-) create mode 100644 src/hnsw/hnsw_optimize_expr.cpp rename src/hnsw/{hnsw_plan_index_scan.cpp => hnsw_optimize_scan.cpp} (63%) create mode 100644 src/hnsw/hnsw_optimize_topk.cpp create mode 100644 src/hnsw/hnsw_topk_operator.cpp create mode 100644 test/sql/hnsw/hnsw_rewrite.test create mode 100644 test/sql/hnsw/hnsw_topk.test diff --git a/duckdb b/duckdb index 19a3247..dffc4ff 160000 --- a/duckdb +++ b/duckdb @@ -1 +1 @@ -Subproject commit 19a32473166e0ad0d7472142a6f93872178c5fcf +Subproject commit dffc4ffad7d9cb7c181db87b1bfb51e261bcedf6 diff --git a/extension-ci-tools b/extension-ci-tools index c924560..638a972 160000 --- a/extension-ci-tools +++ b/extension-ci-tools @@ -1 +1 @@ -Subproject commit c9245601b70dba971b7d9a516c6a68fe5986ae00 +Subproject commit 638a97210d162f6133fea31c6b524c516d10e515 diff --git a/src/hnsw/CMakeLists.txt b/src/hnsw/CMakeLists.txt index 8bd40f2..f69ee88 100644 --- a/src/hnsw/CMakeLists.txt +++ b/src/hnsw/CMakeLists.txt @@ -8,6 +8,9 @@ 
set(EXTENSION_SOURCES ${CMAKE_CURRENT_SOURCE_DIR}/hnsw_index_pragmas.cpp ${CMAKE_CURRENT_SOURCE_DIR}/hnsw_index_scan.cpp ${CMAKE_CURRENT_SOURCE_DIR}/hnsw_plan_index_create.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/hnsw_plan_index_scan.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/hnsw_topk_operator.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/hnsw_optimize_topk.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/hnsw_optimize_expr.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/hnsw_optimize_scan.cpp PARENT_SCOPE ) \ No newline at end of file diff --git a/src/hnsw/hnsw_index.cpp b/src/hnsw/hnsw_index.cpp index 63108be..c7fba13 100644 --- a/src/hnsw/hnsw_index.cpp +++ b/src/hnsw/hnsw_index.cpp @@ -4,6 +4,7 @@ #include "duckdb/common/serializer/binary_serializer.hpp" #include "duckdb/execution/index/fixed_size_allocator.hpp" #include "duckdb/storage/table/scan_state.hpp" +#include "duckdb/planner/operator/logical_get.hpp" #include "hnsw/hnsw.hpp" namespace duckdb { @@ -227,23 +228,19 @@ string HNSWIndex::GetMetric() const { } } -bool HNSWIndex::IsDistanceFunction(const string &distance_function_name) { - auto accepted_functions = {"array_distance", "array_cosine_similarity", "array_inner_product"}; - return std::find(accepted_functions.begin(), accepted_functions.end(), distance_function_name) != - accepted_functions.end(); -} - -bool HNSWIndex::MatchesDistanceFunction(const string &distance_function_name) const { - if (distance_function_name == "array_distance" && +bool HNSWIndex::MatchesDistanceFunction(const string &name) const { + if ((name == "array_distance" || name == "<->") && index.metric().metric_kind() == unum::usearch::metric_kind_t::l2sq_k) { + // Note: usearch uses l2sq, for their metric, but its functionally equivalent to sqrt(l2sq) return true; } - if (distance_function_name == "array_cosine_similarity" && + if ((name == "array_cosine_distance" || name == "<=>") && index.metric().metric_kind() == unum::usearch::metric_kind_t::cos_k) { return true; } - if (distance_function_name == "array_inner_product" && + if 
((name == "array_negative_inner_product" || name == "<#>") && index.metric().metric_kind() == unum::usearch::metric_kind_t::ip_k) { + // Note: usearch uses (1.0 - ip) for their metric, but its functionally equivalent to (-ip) return true; } return false; @@ -536,6 +533,40 @@ void HNSWIndex::VerifyAllocations(IndexLock &state) { throw NotImplementedException("HNSWIndex::VerifyAllocations() not implemented"); } +//------------------------------------------------------------------------------ +// Can rewrite index expression? +//------------------------------------------------------------------------------ +static void RewriteIndexExpression(const Index &index, LogicalGet &get, Expression &expr, bool &rewrite_possible, + bool &any_column_ref) { + if (expr.type == ExpressionType::BOUND_COLUMN_REF) { + any_column_ref = true; + auto &bound_colref = expr.Cast(); + // bound column ref: rewrite to fit in the current set of bound column ids + bound_colref.binding.table_index = get.table_index; + auto &column_ids = index.GetColumnIds(); + auto &get_column_ids = get.GetColumnIds(); + column_t referenced_column = column_ids[bound_colref.binding.column_index]; + // search for the referenced column in the set of column_ids + for (idx_t i = 0; i < get_column_ids.size(); i++) { + if (get_column_ids[i] == referenced_column) { + bound_colref.binding.column_index = i; + return; + } + } + // column id not found in bound columns in the LogicalGet: rewrite not possible + rewrite_possible = false; + } + ExpressionIterator::EnumerateChildren( + expr, [&](Expression &child) { RewriteIndexExpression(index, get, child, rewrite_possible, any_column_ref); }); +} + +bool HNSWIndex::CanRewriteIndexExpression(LogicalGet &get, Expression &column_ref) const { + bool rewrite_possible = true; + bool any_column_ref = false; + RewriteIndexExpression(*this, get, column_ref, rewrite_possible, any_column_ref); + return any_column_ref && rewrite_possible; +} + 
//------------------------------------------------------------------------------ // Register Index Type //------------------------------------------------------------------------------ diff --git a/src/hnsw/hnsw_index_physical_create.cpp b/src/hnsw/hnsw_index_physical_create.cpp index 1d2d63c..543b042 100644 --- a/src/hnsw/hnsw_index_physical_create.cpp +++ b/src/hnsw/hnsw_index_physical_create.cpp @@ -14,13 +14,13 @@ namespace duckdb { -PhysicalCreateHNSWIndex::PhysicalCreateHNSWIndex(LogicalOperator &op, TableCatalogEntry &table, +PhysicalCreateHNSWIndex::PhysicalCreateHNSWIndex(LogicalOperator &op, TableCatalogEntry &table_p, const vector &column_ids, unique_ptr info, vector> unbound_expressions, idx_t estimated_cardinality) // Declare this operators as a EXTENSION operator : PhysicalOperator(PhysicalOperatorType::EXTENSION, op.types, estimated_cardinality), - table(table.Cast()), info(std::move(info)), unbound_expressions(std::move(unbound_expressions)), + table(table_p.Cast()), info(std::move(info)), unbound_expressions(std::move(unbound_expressions)), sorted(false) { // convert virtual column ids to storage column ids @@ -34,7 +34,8 @@ PhysicalCreateHNSWIndex::PhysicalCreateHNSWIndex(LogicalOperator &op, TableCatal //------------------------------------------------------------- class CreateHNSWIndexGlobalState final : public GlobalSinkState { public: - CreateHNSWIndexGlobalState(const PhysicalOperator &op_p) : op(op_p) {} + CreateHNSWIndexGlobalState(const PhysicalOperator &op_p) : op(op_p) { + } const PhysicalOperator &op; //! 
Global index to be added to the table @@ -262,21 +263,17 @@ class HNSWIndexConstructionEvent final : public BasePipelineEvent { // Create the index entry in the catalog auto &schema = table.schema; info.column_ids = storage_ids; - const auto index_entry = schema.CreateIndex(*gstate.context, info, table).get(); - if (!index_entry) { - D_ASSERT(info.on_conflict == OnCreateConflict::IGNORE_ON_CONFLICT); - // index already exists, but error ignored because of IF NOT EXISTS - // return SinkFinalizeType::READY; - return; + + if (schema.GetEntry(schema.GetCatalogTransaction(*gstate.context), CatalogType::INDEX_ENTRY, info.index_name)) { + if (info.on_conflict != OnCreateConflict::IGNORE_ON_CONFLICT) { + throw CatalogException("Index with name \"%s\" already exists", info.index_name); + } } - // Get the entry as a DuckIndexEntry + const auto index_entry = schema.CreateIndex(schema.GetCatalogTransaction(*gstate.context), info, table).get(); + D_ASSERT(index_entry); auto &duck_index = index_entry->Cast(); duck_index.initial_index_size = gstate.global_index->Cast().GetInMemorySize(); - duck_index.info = make_uniq(storage.GetDataTableInfo(), duck_index.name); - for (auto &parsed_expr : info.parsed_expressions) { - duck_index.parsed_expressions.push_back(parsed_expr->Copy()); - } // Finally add it to storage storage.AddIndex(std::move(gstate.global_index)); diff --git a/src/hnsw/hnsw_optimize_expr.cpp b/src/hnsw/hnsw_optimize_expr.cpp new file mode 100644 index 0000000..a05f1a6 --- /dev/null +++ b/src/hnsw/hnsw_optimize_expr.cpp @@ -0,0 +1,98 @@ +#include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/duck_table_entry.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/planner/expression_iterator.hpp" +#include "duckdb/catalog/catalog_entry/scalar_function_catalog_entry.hpp" +#include 
"duckdb/optimizer/column_binding_replacer.hpp" +#include "duckdb/optimizer/optimizer.hpp" + +#include "hnsw/hnsw.hpp" +#include "hnsw/hnsw_index.hpp" + +namespace duckdb { + +//------------------------------------------------------------------------------ +// Rewrite rules +//------------------------------------------------------------------------------ +// This optimizer rewrites expressions of the form: +// (1.0 - array_cosine_similarity) => (array_cosine_distance) +// (-array_inner_product) => (array_negative_inner_product) + +class CosineDistanceRule final : public Rule { +public: + explicit CosineDistanceRule(ExpressionRewriter &rewriter); + unique_ptr Apply(LogicalOperator &op, vector> &bindings, bool &changes_made, + bool is_root) override; +}; + +CosineDistanceRule::CosineDistanceRule(ExpressionRewriter &rewriter) : Rule(rewriter) { + auto func = make_uniq(); + func->matchers.push_back(make_uniq()); + func->matchers.push_back(make_uniq()); + func->policy = SetMatcher::Policy::UNORDERED; + func->function = make_uniq("array_cosine_similarity"); + + auto op = make_uniq(); + op->matchers.push_back(make_uniq()); + op->matchers[0]->type = make_uniq(LogicalType::FLOAT); + op->matchers.push_back(std::move(func)); + op->policy = SetMatcher::Policy::ORDERED; + op->function = make_uniq("-"); + op->type = make_uniq(LogicalType::FLOAT); + + root = std::move(op); +} + +unique_ptr CosineDistanceRule::Apply(LogicalOperator &op, vector> &bindings, + bool &changes_made, bool is_root) { + // auto &root_expr = bindings[0].get().Cast(); + const auto &const_expr = bindings[1].get().Cast(); + auto &similarity_expr = bindings[2].get().Cast(); + + if (!const_expr.value.IsNull() && const_expr.value.GetValue() == 1.0) { + // Create the new array_cosine_distance function + vector> args; + vector arg_types; + arg_types.push_back(similarity_expr.children[0]->return_type); + arg_types.push_back(similarity_expr.children[1]->return_type); + 
args.push_back(std::move(similarity_expr.children[0])); + args.push_back(std::move(similarity_expr.children[1])); + + auto &context = GetContext(); + auto func_entry = Catalog::GetEntry(context, "", "", "array_cosine_distance", + OnEntryNotFound::RETURN_NULL); + + if (!func_entry) { + return nullptr; + } + + changes_made = true; + auto func = func_entry->functions.GetFunctionByArguments(context, arg_types); + return make_uniq(similarity_expr.return_type, func, std::move(args), nullptr); + } + return nullptr; +} + +//------------------------------------------------------------------------------ +// Optimizer +//------------------------------------------------------------------------------ +class HNSWExprOptimizer : public OptimizerExtension { +public: + HNSWExprOptimizer() { + optimize_function = Optimize; + } + + static void Optimize(OptimizerExtensionInput &input, unique_ptr &plan) { + ExpressionRewriter rewriter(input.context); + rewriter.rules.push_back(make_uniq(rewriter)); + rewriter.VisitOperator(*plan); + } +}; + +void HNSWModule::RegisterExprOptimizer(DatabaseInstance &db) { + // Register the expression rewrite optimizer (HNSWExprOptimizer) + db.config.optimizer_extensions.push_back(HNSWExprOptimizer()); +} + +} // namespace duckdb \ No newline at end of file diff --git a/src/hnsw/hnsw_plan_index_scan.cpp b/src/hnsw/hnsw_optimize_scan.cpp similarity index 63% rename from src/hnsw/hnsw_plan_index_scan.cpp rename to src/hnsw/hnsw_optimize_scan.cpp index d6eb36a..4e9f547 100644 --- a/src/hnsw/hnsw_plan_index_scan.cpp +++ b/src/hnsw/hnsw_optimize_scan.cpp @@ -12,9 +12,35 @@ #include "hnsw/hnsw_index_scan.hpp" #include "duckdb/optimizer/remove_unused_columns.hpp" #include "duckdb/planner/expression_iterator.hpp" - +#include "duckdb/optimizer/matcher/expression_matcher.hpp" namespace duckdb { +//----------------------------------------------------------------------------- +// Matcher +//----------------------------------------------------------------------------- + +// bindings[0] = distance 
function +// bindings[1] = column reference +// bindings[2] = vector constant +static bool MatchDistanceFunction(vector> &bindings, Expression &distance_expr, + Expression &column_ref, idx_t vector_size) { + + unordered_set distance_functions = { + "array_distance", "<->", "array_cosine_distance", "<=>", "array_negative_inner_product", "<#>"}; + + auto distance_matcher = make_uniq(); + distance_matcher->function = make_uniq(distance_functions); + distance_matcher->expr_type = make_uniq(ExpressionType::BOUND_FUNCTION); + distance_matcher->policy = SetMatcher::Policy::UNORDERED; + distance_matcher->matchers.push_back(make_uniq(column_ref)); + + auto vector_matcher = make_uniq(); + vector_matcher->type = make_uniq(LogicalType::ARRAY(LogicalType::FLOAT, vector_size)); + distance_matcher->matchers.push_back(std::move(vector_matcher)); // The vector to match + + return distance_matcher->Match(distance_expr, bindings); +} + //----------------------------------------------------------------------------- // Plan rewriter //----------------------------------------------------------------------------- @@ -32,7 +58,6 @@ class HNSWIndexScanOptimizer : public OptimizerExtension { return false; } - // Look for a expression that is a distance expression auto &top_n = op.Cast(); if (top_n.orders.size() != 1) { @@ -40,7 +65,7 @@ class HNSWIndexScanOptimizer : public OptimizerExtension { return false; } - auto &order = top_n.orders[0]; + const auto &order = top_n.orders[0]; if (order.type != OrderType::ASCENDING) { // We can only optimize if the order by expression is ascending @@ -51,69 +76,26 @@ class HNSWIndexScanOptimizer : public OptimizerExtension { // The expression has to reference the child operator (a projection with the distance function) return false; } - auto &bound_column_ref = order.expression->Cast(); + const auto &bound_column_ref = order.expression->Cast(); // find the expression that is referenced - auto &immediate_child = top_n.children[0]; - if 
(immediate_child->type != LogicalOperatorType::LOGICAL_PROJECTION) { + if (top_n.children.size() != 1 || top_n.children.front()->type != LogicalOperatorType::LOGICAL_PROJECTION) { // The child has to be a projection return false; } - auto &projection = immediate_child->Cast(); - auto projection_index = bound_column_ref.binding.column_index; - if (projection.expressions[projection_index]->type != ExpressionType::BOUND_FUNCTION) { - // The expression has to be a function - return false; - } - auto &bound_function = projection.expressions[projection_index]->Cast(); - if (!HNSWIndex::IsDistanceFunction(bound_function.function.name)) { - // We can only optimize if the order by expression is a distance function - return false; - } + auto &projection = top_n.children.front()->Cast(); - // Figure out the query vector - Value target_value; - if (bound_function.children[0]->GetExpressionType() == ExpressionType::VALUE_CONSTANT) { - target_value = bound_function.children[0]->Cast().value; - } else if (bound_function.children[1]->GetExpressionType() == ExpressionType::VALUE_CONSTANT) { - target_value = bound_function.children[1]->Cast().value; - } else { - // We can only optimize if one of the children is a constant - return false; - } - - // TODO: We should check that the other argument to the distance function is a column reference - // that matches the column that the index is on. 
That also helps us identify the scan operator + // This is the expression that is referenced by the order by expression + const auto projection_index = bound_column_ref.binding.column_index; + const auto &projection_expr = projection.expressions[projection_index]; - auto value_type = target_value.type(); - if (value_type.id() != LogicalTypeId::ARRAY) { - // We can only optimize if the constant is an array + // The projection must sit on top of a get + if (projection.children.size() != 1 || projection.children.front()->type != LogicalOperatorType::LOGICAL_GET) { return false; } - auto array_size = ArrayType::GetSize(value_type); - auto array_inner_type = ArrayType::GetChildType(value_type); - if (array_inner_type.id() != LogicalTypeId::FLOAT) { - // Try to cast to float - bool ok = target_value.DefaultTryCastAs(LogicalType::ARRAY(LogicalType::FLOAT, array_size), true); - if (!ok) { - // We can only optimize if the array is of floats or we can cast it to floats - return false; - } - } - - // find any direct child or grandchild that is a get - auto child = top_n.children[0].get(); - while (child->type != LogicalOperatorType::LOGICAL_GET) { - // TODO: Handle joins? 
- if (child->children.size() != 1) { - // Either 0 or more than 1 child - return false; - } - child = child->children[0].get(); - } - auto &get = child->Cast(); + auto &get = projection.children.front()->Cast(); // Check if the get is a table scan if (get.function.name != "seq_scan") { return false; @@ -134,28 +116,39 @@ class HNSWIndexScanOptimizer : public OptimizerExtension { // Find the index unique_ptr bind_data = nullptr; - table_info.GetIndexes().BindAndScan(context, table_info, [&](HNSWIndex &index_entry) { - auto &hnsw_index = index_entry.Cast(); + vector> bindings; + + table_info.GetIndexes().BindAndScan(context, table_info, [&](HNSWIndex &hnsw_index) { + // Check that the HNSW index actually indexes the expression + const auto index_expr = hnsw_index.unbound_expressions[0]->Copy(); + if (!hnsw_index.CanRewriteIndexExpression(get, *index_expr)) { + return false; + } + + const auto vector_size = hnsw_index.GetVectorSize(); + + // Reset the bindings + bindings.clear(); - if (hnsw_index.GetVectorSize() != array_size) { - // The vector size of the index does not match the vector size of the query + if (!MatchDistanceFunction(bindings, *projection_expr, *index_expr, vector_size)) { + // The expression is not a distance function return false; } - if (!hnsw_index.MatchesDistanceFunction(bound_function.function.name)) { + const auto &distance_func = bindings[0].get().Cast(); + if (!hnsw_index.MatchesDistanceFunction(distance_func.function.name)) { // The distance function of the index does not match the distance function of the query return false; } - // Create a query vector from the constant value - auto query_vector = make_unsafe_uniq_array(array_size); - auto vector_elements = ArrayValue::GetChildren(target_value); - for (idx_t i = 0; i < array_size; i++) { + const auto &matched_vector = bindings[2].get().Cast().value; + auto query_vector = make_unsafe_uniq_array(vector_size); + auto vector_elements = ArrayValue::GetChildren(matched_vector); + for (idx_t i = 
0; i < vector_size; i++) { query_vector[i] = vector_elements[i].GetValue(); } - // Create the bind data for this index - bind_data = make_uniq(duck_table, index_entry, top_n.limit, std::move(query_vector)); + bind_data = make_uniq(duck_table, hnsw_index, top_n.limit, std::move(query_vector)); return true; }); @@ -167,7 +160,7 @@ class HNSWIndexScanOptimizer : public OptimizerExtension { // Replace the scan with our custom index scan function get.function = HNSWIndexScanFunction::GetFunction(); - auto cardinality = get.function.cardinality(context, bind_data.get()); + const auto cardinality = get.function.cardinality(context, bind_data.get()); get.has_estimated_cardinality = cardinality->has_estimated_cardinality; get.estimated_cardinality = cardinality->estimated_cardinality; get.bind_data = std::move(bind_data); @@ -242,7 +235,7 @@ class HNSWIndexScanOptimizer : public OptimizerExtension { //----------------------------------------------------------------------------- // Register //----------------------------------------------------------------------------- -void HNSWModule::RegisterPlanIndexScan(DatabaseInstance &db) { +void HNSWModule::RegisterScanOptimizer(DatabaseInstance &db) { // Register the optimizer extension db.config.optimizer_extensions.push_back(HNSWIndexScanOptimizer()); } diff --git a/src/hnsw/hnsw_optimize_topk.cpp b/src/hnsw/hnsw_optimize_topk.cpp new file mode 100644 index 0000000..2bebc27 --- /dev/null +++ b/src/hnsw/hnsw_optimize_topk.cpp @@ -0,0 +1,239 @@ +#include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/duck_table_entry.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/planner/operator/logical_aggregate.hpp" +#include "duckdb/planner/operator/logical_get.hpp" +#include "duckdb/optimizer/optimizer.hpp" +#include 
"duckdb/planner/expression/bound_aggregate_expression.hpp" +#include "duckdb/optimizer/matcher/expression_matcher.hpp" + +#include "hnsw/hnsw.hpp" +#include "hnsw/hnsw_index.hpp" +#include "hnsw/hnsw_index_scan.hpp" + +namespace duckdb { + +//------------------------------------------------------------------------------ +// Optimizer Helpers +//------------------------------------------------------------------------------ + +class AggregateFunctionExpressionMatcher : public ExpressionMatcher { +public: + AggregateFunctionExpressionMatcher() + : ExpressionMatcher(ExpressionClass::BOUND_AGGREGATE), policy(SetMatcher::Policy::INVALID) { + } + //! The matchers for the child expressions + vector> matchers; + //! The set matcher matching policy to use + SetMatcher::Policy policy; + //! The function name to match + unique_ptr function; + + bool Match(Expression &expr_p, vector> &bindings) override; +}; + +bool AggregateFunctionExpressionMatcher::Match(Expression &expr_p, vector> &bindings) { + if (!ExpressionMatcher::Match(expr_p, bindings)) { + return false; + } + auto &expr = expr_p.Cast(); + if (!FunctionMatcher::Match(function, expr.function.name)) { + return false; + } + if (!SetMatcher::Match(matchers, expr.children, bindings, policy)) { + return false; + } + return true; +} + +static unique_ptr CreateListOrderByExpr(ClientContext &context, unique_ptr elem_expr, + unique_ptr order_expr, + unique_ptr filter_expr) { + auto func_entry = + Catalog::GetEntry(context, "", "", "list", OnEntryNotFound::RETURN_NULL); + if (!func_entry) { + return nullptr; + } + + auto func = func_entry->functions.GetFunctionByOffset(0); + vector> arguments; + arguments.push_back(std::move(elem_expr)); + + auto agg_bind_data = func.bind(context, func, arguments); + auto new_agg_expr = + make_uniq(func, std::move(arguments), std::move(std::move(filter_expr)), + std::move(agg_bind_data), AggregateType::NON_DISTINCT); + + // We also need to order the list items by the distance + BoundOrderByNode 
order_by_node(OrderType::ASCENDING, OrderByNullType::NULLS_LAST, std::move(order_expr)); + new_agg_expr->order_bys = make_uniq(); + new_agg_expr->order_bys->orders.push_back(std::move(order_by_node)); + + return new_agg_expr; +} + +// bindings[0] = the aggregate_function +// bindings[1] = the column ref +// bindings[2] = the distance function +// bindings[3] = the arg ref +// bindings[4] = the matched vector +// bindings[5] = the k value +static bool MatchDistanceFunction(vector> &bindings, Expression &agg_expr, Expression &column_ref, + idx_t vector_size) { + + AggregateFunctionExpressionMatcher min_by_matcher; + min_by_matcher.function = make_uniq("min_by"); + min_by_matcher.policy = SetMatcher::Policy::ORDERED; + + unordered_set distance_functions = { + "array_distance", "<->", "array_cosine_distance", "<=>", "array_negative_inner_product", "<#>"}; + + auto distance_matcher = make_uniq(); + distance_matcher->function = make_uniq(distance_functions); + distance_matcher->expr_type = make_uniq(ExpressionType::BOUND_FUNCTION); + distance_matcher->policy = SetMatcher::Policy::UNORDERED; + distance_matcher->matchers.push_back(make_uniq(column_ref)); + + auto vector_matcher = make_uniq(); + vector_matcher->type = make_uniq(LogicalType::ARRAY(LogicalType::FLOAT, vector_size)); + distance_matcher->matchers.push_back(std::move(vector_matcher)); // The vector to match + + min_by_matcher.matchers.push_back(make_uniq()); // Dont care about the column + min_by_matcher.matchers.push_back(std::move(distance_matcher)); + min_by_matcher.matchers.push_back(make_uniq()); // The k value + + return min_by_matcher.Match(agg_expr, bindings); +} + +//------------------------------------------------------------------------------ +// Main Optimizer +//------------------------------------------------------------------------------ +// This optimizer rewrites +// +// AGG(MIN_BY(t1.col1, distance_func(t1.col2, query_vector), k)) <- TABLE_SCAN(t1) +// => +// AGG(LIST(col1 ORDER BY 
distance_func(col2, query_vector) ASC)) <- HNSW_INDEX_SCAN(t1, query_vector, k) +// + +class HNSWTopKOptimizer : public OptimizerExtension { +public: + HNSWTopKOptimizer() { + optimize_function = Optimize; + } + + static bool TryOptimize(Binder &binder, ClientContext &context, unique_ptr &plan) { + // Look for an aggregate operator + if (plan->type != LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY) { + return false; + } + // Look for an expression that is a distance expression + auto &agg = plan->Cast(); + if (!agg.groups.empty() || agg.expressions.size() != 1) { + return false; + } + + // we need the aggregate to have exactly one child + if (agg.children.size() != 1) { + return false; + } + + // we also need that child to be a get (a table scan that has a hnsw index) + if (agg.children[0]->type != LogicalOperatorType::LOGICAL_GET) { + return false; + } + + auto &get = agg.children[0]->Cast(); + if (get.function.name != "seq_scan") { + return false; + } + + // Get the table + auto &table = *get.GetTable(); + if (!table.IsDuckTable()) { + return false; + } + + auto &duck_table = table.Cast(); + auto &table_info = *table.GetStorage().GetDataTableInfo(); + + unique_ptr bind_data = nullptr; + vector> bindings; + + table_info.GetIndexes().BindAndScan(context, table_info, [&](HNSWIndex &hnsw_index) { + // Check that the HNSW index actually indexes the expression + const auto index_expr = hnsw_index.unbound_expressions[0]->Copy(); + if (!hnsw_index.CanRewriteIndexExpression(get, *index_expr)) { + return false; + } + + const auto vector_size = hnsw_index.GetVectorSize(); + + // Reset the bindings + bindings.clear(); + if (!MatchDistanceFunction(bindings, *agg.expressions[0], *index_expr, vector_size)) { + return false; + } + // bindings[0] = the aggregate_function + // bindings[1] = the column ref + // bindings[2] = the distance function + // bindings[3] = the arg ref + // bindings[4] = the matched vector + // bindings[5] = the k value + + const auto 
&distance_func = bindings[2].get().Cast(); + if (!hnsw_index.MatchesDistanceFunction(distance_func.function.name)) { + return false; + } + + const auto &matched_vector = bindings[4].get().Cast().value; + auto query_vector = make_unsafe_uniq_array(vector_size); + auto vector_elements = ArrayValue::GetChildren(matched_vector); + for (idx_t i = 0; i < vector_size; i++) { + query_vector[i] = vector_elements[i].GetValue(); + } + + const auto k_limit = bindings[5].get().Cast().value.GetValue(); + bind_data = make_uniq(duck_table, hnsw_index, k_limit, std::move(query_vector)); + + return true; + }); + + if (!bind_data) { + // No index found + return false; + } + + const auto &agg_expr = bindings[0].get().Cast(); + const auto &col_expr = bindings[1].get(); + const auto &distance_func = bindings[2].get().Cast(); + + // Replace the table scan with an index scan + get.function = HNSWIndexScanFunction::GetFunction(); + const auto cardinality = get.function.cardinality(context, bind_data.get()); + get.has_estimated_cardinality = cardinality->has_estimated_cardinality; + get.estimated_cardinality = cardinality->estimated_cardinality; + get.bind_data = std::move(bind_data); + + // Replace the aggregate with a list() aggregate function ordered by the distance + agg.expressions[0] = CreateListOrderByExpr(context, col_expr.Copy(), distance_func.Copy(), + agg_expr.filter ? 
agg_expr.filter->Copy() : nullptr); + return true; + } + + static void Optimize(OptimizerExtensionInput &input, unique_ptr &plan) { + if (!TryOptimize(input.optimizer.binder, input.context, plan)) { + // Recursively optimize the children + for (auto &child : plan->children) { + Optimize(input, child); + } + } + } +}; + +void HNSWModule::RegisterTopKOptimizer(DatabaseInstance &db) { + // Register the TopKOptimizer + db.config.optimizer_extensions.push_back(HNSWTopKOptimizer()); +} + +} // namespace duckdb \ No newline at end of file diff --git a/src/hnsw/hnsw_topk_operator.cpp b/src/hnsw/hnsw_topk_operator.cpp new file mode 100644 index 0000000..62f7200 --- /dev/null +++ b/src/hnsw/hnsw_topk_operator.cpp @@ -0,0 +1,9 @@ +#include "hnsw/hnsw.hpp" + +namespace duckdb { + +void HNSWModule::RegisterTopKOperator(DatabaseInstance &db) { + // Register the TopKOperator +} + +} // namespace duckdb \ No newline at end of file diff --git a/src/include/hnsw/hnsw.hpp b/src/include/hnsw/hnsw.hpp index 6247e4f..a94e71e 100644 --- a/src/include/hnsw/hnsw.hpp +++ b/src/include/hnsw/hnsw.hpp @@ -10,18 +10,26 @@ struct HNSWModule { RegisterIndex(db); RegisterIndexScan(db); RegisterIndexPragmas(db); - RegisterPlanIndexScan(db); RegisterPlanIndexCreate(db); RegisterMacros(db); + + // Optimizers + RegisterExprOptimizer(db); + RegisterScanOptimizer(db); + RegisterTopKOptimizer(db); } private: static void RegisterIndex(DatabaseInstance &db); static void RegisterIndexScan(DatabaseInstance &db); static void RegisterIndexPragmas(DatabaseInstance &db); - static void RegisterPlanIndexScan(DatabaseInstance &db); static void RegisterPlanIndexCreate(DatabaseInstance &db); static void RegisterMacros(DatabaseInstance &db); + static void RegisterTopKOptimizer(DatabaseInstance &db); + + static void RegisterExprOptimizer(DatabaseInstance &db); + static void RegisterTopKOperator(DatabaseInstance &db); + static void RegisterScanOptimizer(DatabaseInstance &db); }; } // namespace duckdb \ No newline at end 
of file diff --git a/src/include/hnsw/hnsw_index.hpp b/src/include/hnsw/hnsw_index.hpp index 82e527e..506c4b5 100644 --- a/src/include/hnsw/hnsw_index.hpp +++ b/src/include/hnsw/hnsw_index.hpp @@ -45,7 +45,6 @@ class HNSWIndex : public BoundIndex { idx_t Scan(IndexScanState &state, Vector &result); idx_t GetVectorSize() const; - static bool IsDistanceFunction(const string &distance_function_name); bool MatchesDistanceFunction(const string &distance_function_name) const; string GetMetric() const; @@ -102,6 +101,8 @@ class HNSWIndex : public BoundIndex { index_size = index.size(); } + bool CanRewriteIndexExpression(LogicalGet &get, Expression &column_ref) const; + private: bool is_dirty = false; StorageLock rwlock; diff --git a/test/sql/hnsw/hnsw_metrics.test b/test/sql/hnsw/hnsw_metrics.test index 87b35e1..f7c7c5e 100644 --- a/test/sql/hnsw/hnsw_metrics.test +++ b/test/sql/hnsw/hnsw_metrics.test @@ -21,12 +21,12 @@ CREATE INDEX my_l2sq_idx ON t1 USING HNSW (vec) WITH (metric = 'l2sq'); # Make sure we get the index scan plan on the index matching the distance measurement query II -EXPLAIN SELECT array_inner_product(vec, [1,2,3]::FLOAT[3]) as x FROM t1 ORDER BY x LIMIT 3; +EXPLAIN SELECT array_negative_inner_product(vec, [1,2,3]::FLOAT[3]) as x FROM t1 ORDER BY x LIMIT 3; ---- physical_plan :.*HNSW_INDEX_SCAN.*my_ip_idx.* query II -EXPLAIN SELECT array_cosine_similarity(vec, [1,2,3]::FLOAT[3]) as x FROM t1 ORDER BY x LIMIT 3; +EXPLAIN SELECT array_cosine_distance(vec, [1,2,3]::FLOAT[3]) as x FROM t1 ORDER BY x LIMIT 3; ---- physical_plan :.*HNSW_INDEX_SCAN.*my_cos_idx.* diff --git a/test/sql/hnsw/hnsw_rewrite.test b/test/sql/hnsw/hnsw_rewrite.test new file mode 100644 index 0000000..2074873 --- /dev/null +++ b/test/sql/hnsw/hnsw_rewrite.test @@ -0,0 +1,31 @@ +require vss + +# Test that we rewrite (1 - array_cosine_similarity) to array_cosine_distance + +statement ok +CREATE TABLE t1 (v FLOAT[3]); + +statement ok +INSERT INTO t1 VALUES ([0.8, 0.8, 0.8]); + +query II 
+EXPLAIN SELECT 1.0 - array_cosine_similarity(v, [0.2,0.2,0.2]::FLOAT[3]) FROM t1; +---- +physical_plan :.*array_cosine_distance.* + +statement ok +pragma disable_optimizer; + +query II +EXPLAIN SELECT 1.0 - array_cosine_similarity(v, [0.2,0.2,0.2]::FLOAT[3]) FROM t1; +---- +physical_plan :.*array_cosine_similarity.* + +query I rowsort RES +SELECT 1.0 - array_cosine_similarity(v, [0.2,0.2,0.2]::FLOAT[3]) FROM t1; + +statement ok +pragma enable_optimizer; + +query I rowsort RES +SELECT 1.0 - array_cosine_similarity(v, [0.2,0.2,0.2]::FLOAT[3]) FROM t1; \ No newline at end of file diff --git a/test/sql/hnsw/hnsw_topk.test b/test/sql/hnsw/hnsw_topk.test new file mode 100644 index 0000000..5b83969 --- /dev/null +++ b/test/sql/hnsw/hnsw_topk.test @@ -0,0 +1,24 @@ +require vss + +require noforcestorage + +statement ok +CREATE TABLE t1 (vec FLOAT[3]); + +statement ok +INSERT INTO t1 SELECT array_value(a,b,c) FROM range(1,10) ra(a), range(1,10) rb(b), range(1,10) rc(c); + +statement ok +CREATE INDEX my_idx ON t1 USING HNSW (vec); + +# Make sure we get the index scan plan +query II +EXPLAIN SELECT min_by(vec, array_distance(vec, [1,2,3]::FLOAT[3]), 3) as x FROM t1; +---- +physical_plan :.*HNSW_INDEX_SCAN.* + +query I +SELECT min_by(vec, array_distance(vec, [5,5,5]::FLOAT[3]), 3) as x FROM t1; +---- +[[5.0, 5.0, 5.0], [6.0, 5.0, 5.0], [5.0, 6.0, 5.0]] +