Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix non-euclidean distance metrics + more optimizer rules #27

Merged
merged 2 commits into from
Sep 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,11 @@ CREATE INDEX my_hnsw_cosine_index ON my_vector_table USING HNSW (vec) WITH (metr

The following table shows the supported distance metrics and their corresponding DuckDB functions

| Description | Metric | Function |
| --- | --- | --- |
| Euclidean distance | `l2sq` | `array_distance` |
| Cosine similarity | `cosine` | `array_cosine_similarity` |
| Inner product | `ip` | `array_inner_product` |
| Description | Metric | Function |
| --- | --- |--------------------------------|
| Euclidean distance | `l2sq` | `array_distance` |
| Cosine distance | `cosine` | `array_cosine_distance` |
| Negative inner product | `ip` | `array_negative_inner_product` |

## Inserts, Updates, Deletes and Re-Compaction

Expand Down
2 changes: 1 addition & 1 deletion duckdb
Submodule duckdb updated 475 files
5 changes: 4 additions & 1 deletion src/hnsw/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ set(EXTENSION_SOURCES
${CMAKE_CURRENT_SOURCE_DIR}/hnsw_index_pragmas.cpp
${CMAKE_CURRENT_SOURCE_DIR}/hnsw_index_scan.cpp
${CMAKE_CURRENT_SOURCE_DIR}/hnsw_plan_index_create.cpp
${CMAKE_CURRENT_SOURCE_DIR}/hnsw_plan_index_scan.cpp
${CMAKE_CURRENT_SOURCE_DIR}/hnsw_topk_operator.cpp
${CMAKE_CURRENT_SOURCE_DIR}/hnsw_optimize_topk.cpp
${CMAKE_CURRENT_SOURCE_DIR}/hnsw_optimize_expr.cpp
${CMAKE_CURRENT_SOURCE_DIR}/hnsw_optimize_scan.cpp
PARENT_SCOPE
)
51 changes: 41 additions & 10 deletions src/hnsw/hnsw_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "duckdb/common/serializer/binary_serializer.hpp"
#include "duckdb/execution/index/fixed_size_allocator.hpp"
#include "duckdb/storage/table/scan_state.hpp"
#include "duckdb/planner/operator/logical_get.hpp"
#include "hnsw/hnsw.hpp"

namespace duckdb {
Expand Down Expand Up @@ -227,23 +228,19 @@ string HNSWIndex::GetMetric() const {
}
}

bool HNSWIndex::IsDistanceFunction(const string &distance_function_name) {
auto accepted_functions = {"array_distance", "array_cosine_similarity", "array_inner_product"};
return std::find(accepted_functions.begin(), accepted_functions.end(), distance_function_name) !=
accepted_functions.end();
}

bool HNSWIndex::MatchesDistanceFunction(const string &distance_function_name) const {
if (distance_function_name == "array_distance" &&
bool HNSWIndex::MatchesDistanceFunction(const string &name) const {
if ((name == "array_distance" || name == "<->") &&
index.metric().metric_kind() == unum::usearch::metric_kind_t::l2sq_k) {
// Note: usearch uses l2sq, for their metric, but its functionally equivalent to sqrt(l2sq)
return true;
}
if (distance_function_name == "array_cosine_similarity" &&
if ((name == "array_cosine_distance" || name == "<=>") &&
index.metric().metric_kind() == unum::usearch::metric_kind_t::cos_k) {
return true;
}
if (distance_function_name == "array_inner_product" &&
if ((name == "array_negative_inner_product" || name == "<#>") &&
index.metric().metric_kind() == unum::usearch::metric_kind_t::ip_k) {
// Note: usearch uses (1.0 - ip) for their metric, but its functionally equivalent to (-ip)
return true;
}
return false;
Expand Down Expand Up @@ -536,6 +533,40 @@ void HNSWIndex::VerifyAllocations(IndexLock &state) {
throw NotImplementedException("HNSWIndex::VerifyAllocations() not implemented");
}

//------------------------------------------------------------------------------
// Can rewrite index expression?
//------------------------------------------------------------------------------
static void RewriteIndexExpression(const Index &index, LogicalGet &get, Expression &expr, bool &rewrite_possible,
bool &any_column_ref) {
if (expr.type == ExpressionType::BOUND_COLUMN_REF) {
any_column_ref = true;
auto &bound_colref = expr.Cast<BoundColumnRefExpression>();
// bound column ref: rewrite to fit in the current set of bound column ids
bound_colref.binding.table_index = get.table_index;
auto &column_ids = index.GetColumnIds();
auto &get_column_ids = get.GetColumnIds();
column_t referenced_column = column_ids[bound_colref.binding.column_index];
// search for the referenced column in the set of column_ids
for (idx_t i = 0; i < get_column_ids.size(); i++) {
if (get_column_ids[i] == referenced_column) {
bound_colref.binding.column_index = i;
return;
}
}
// column id not found in bound columns in the LogicalGet: rewrite not possible
rewrite_possible = false;
}
ExpressionIterator::EnumerateChildren(
expr, [&](Expression &child) { RewriteIndexExpression(index, get, child, rewrite_possible, any_column_ref); });
}

// Attempts to rewrite the column references in `column_ref` so they bind against `get`.
//
// Returns true only when the expression contains at least one bound column reference and every
// referenced column could be located among the column ids of `get`.
// Note that `column_ref` is modified in place by the rewrite attempt, even when false is returned.
bool HNSWIndex::CanRewriteIndexExpression(LogicalGet &get, Expression &column_ref) const {
	bool found_column_ref = false;
	bool all_columns_mapped = true;
	RewriteIndexExpression(*this, get, column_ref, all_columns_mapped, found_column_ref);
	if (!found_column_ref) {
		return false;
	}
	return all_columns_mapped;
}

//------------------------------------------------------------------------------
// Register Index Type
//------------------------------------------------------------------------------
Expand Down
25 changes: 11 additions & 14 deletions src/hnsw/hnsw_index_physical_create.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@

namespace duckdb {

PhysicalCreateHNSWIndex::PhysicalCreateHNSWIndex(LogicalOperator &op, TableCatalogEntry &table,
PhysicalCreateHNSWIndex::PhysicalCreateHNSWIndex(LogicalOperator &op, TableCatalogEntry &table_p,
const vector<column_t> &column_ids, unique_ptr<CreateIndexInfo> info,
vector<unique_ptr<Expression>> unbound_expressions,
idx_t estimated_cardinality)
// Declare this operators as a EXTENSION operator
: PhysicalOperator(PhysicalOperatorType::EXTENSION, op.types, estimated_cardinality),
table(table.Cast<DuckTableEntry>()), info(std::move(info)), unbound_expressions(std::move(unbound_expressions)),
table(table_p.Cast<DuckTableEntry>()), info(std::move(info)), unbound_expressions(std::move(unbound_expressions)),
sorted(false) {

// convert virtual column ids to storage column ids
Expand All @@ -34,7 +34,8 @@ PhysicalCreateHNSWIndex::PhysicalCreateHNSWIndex(LogicalOperator &op, TableCatal
//-------------------------------------------------------------
class CreateHNSWIndexGlobalState final : public GlobalSinkState {
public:
CreateHNSWIndexGlobalState(const PhysicalOperator &op_p) : op(op_p) {}
CreateHNSWIndexGlobalState(const PhysicalOperator &op_p) : op(op_p) {
}

const PhysicalOperator &op;
//! Global index to be added to the table
Expand Down Expand Up @@ -262,21 +263,17 @@ class HNSWIndexConstructionEvent final : public BasePipelineEvent {
// Create the index entry in the catalog
auto &schema = table.schema;
info.column_ids = storage_ids;
const auto index_entry = schema.CreateIndex(*gstate.context, info, table).get();
if (!index_entry) {
D_ASSERT(info.on_conflict == OnCreateConflict::IGNORE_ON_CONFLICT);
// index already exists, but error ignored because of IF NOT EXISTS
// return SinkFinalizeType::READY;
return;

if (schema.GetEntry(schema.GetCatalogTransaction(*gstate.context), CatalogType::INDEX_ENTRY, info.index_name)) {
if (info.on_conflict != OnCreateConflict::IGNORE_ON_CONFLICT) {
throw CatalogException("Index with name \"%s\" already exists", info.index_name);
}
}

// Get the entry as a DuckIndexEntry
const auto index_entry = schema.CreateIndex(schema.GetCatalogTransaction(*gstate.context), info, table).get();
D_ASSERT(index_entry);
auto &duck_index = index_entry->Cast<DuckIndexEntry>();
duck_index.initial_index_size = gstate.global_index->Cast<BoundIndex>().GetInMemorySize();
duck_index.info = make_uniq<IndexDataTableInfo>(storage.GetDataTableInfo(), duck_index.name);
for (auto &parsed_expr : info.parsed_expressions) {
duck_index.parsed_expressions.push_back(parsed_expr->Copy());
}

// Finally add it to storage
storage.AddIndex(std::move(gstate.global_index));
Expand Down
98 changes: 98 additions & 0 deletions src/hnsw/hnsw_optimize_expr.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
#include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp"
#include "duckdb/catalog/catalog_entry/duck_table_entry.hpp"
#include "duckdb/planner/expression/bound_constant_expression.hpp"
#include "duckdb/planner/expression/bound_function_expression.hpp"
#include "duckdb/planner/expression_iterator.hpp"
#include "duckdb/catalog/catalog_entry/scalar_function_catalog_entry.hpp"
#include "duckdb/optimizer/column_binding_replacer.hpp"
#include "duckdb/optimizer/optimizer.hpp"

#include "hnsw/hnsw.hpp"
#include "hnsw/hnsw_index.hpp"

namespace duckdb {

//------------------------------------------------------------------------------
// Rewrite rules
//------------------------------------------------------------------------------
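// This optimizer rewrites expressions of the form:
//   (1.0 - array_cosine_similarity(a, b)) => array_cosine_distance(a, b)
//   (-array_inner_product(a, b))          => array_negative_inner_product(a, b)
// Note: only the cosine rewrite has a rule registered in this file; the inner-product rewrite
// is listed here but no corresponding rule is implemented yet.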

// Expression rewrite rule targeting a FLOAT subtraction of the form
// `<FLOAT constant> - array_cosine_similarity(x, y)`.
// When the constant is 1.0 the expression is replaced by `array_cosine_distance(x, y)`.
class CosineDistanceRule final : public Rule {
public:
// Builds the matcher tree (stored in `root`) describing the pattern this rule applies to.
explicit CosineDistanceRule(ExpressionRewriter &rewriter);
// Invoked for an expression that matched `root`; returns the replacement expression,
// or nullptr when no rewrite is performed.
unique_ptr<Expression> Apply(LogicalOperator &op, vector<reference<Expression>> &bindings, bool &changes_made,
bool is_root) override;
};

// Sets up the matcher for `<FLOAT constant> - array_cosine_similarity(<any>, <any>)`.
// The subtraction operands are matched in order (constant first, similarity call second),
// while the two arguments of the similarity call may match in either order.
CosineDistanceRule::CosineDistanceRule(ExpressionRewriter &rewriter) : Rule(rewriter) {
	// Second operand: a call to array_cosine_similarity taking two arbitrary expressions
	auto similarity_matcher = make_uniq<FunctionExpressionMatcher>();
	similarity_matcher->function = make_uniq<SpecificFunctionMatcher>("array_cosine_similarity");
	similarity_matcher->policy = SetMatcher::Policy::UNORDERED;
	similarity_matcher->matchers.push_back(make_uniq<ExpressionMatcher>());
	similarity_matcher->matchers.push_back(make_uniq<ExpressionMatcher>());

	// First operand: a constant of type FLOAT
	auto constant_matcher = make_uniq<ConstantExpressionMatcher>();
	constant_matcher->type = make_uniq<SpecificTypeMatcher>(LogicalType::FLOAT);

	// The enclosing "-" function, itself returning FLOAT
	auto subtract_matcher = make_uniq<FunctionExpressionMatcher>();
	subtract_matcher->function = make_uniq<SpecificFunctionMatcher>("-");
	subtract_matcher->type = make_uniq<SpecificTypeMatcher>(LogicalType::FLOAT);
	subtract_matcher->policy = SetMatcher::Policy::ORDERED;
	subtract_matcher->matchers.push_back(std::move(constant_matcher));
	subtract_matcher->matchers.push_back(std::move(similarity_matcher));

	root = std::move(subtract_matcher);
}

// Rewrites `1.0 - array_cosine_similarity(a, b)` into `array_cosine_distance(a, b)`.
//
// `bindings[0]` is the matched subtraction, `bindings[1]` the constant operand and `bindings[2]`
// the array_cosine_similarity call.
//
// Returns the replacement expression and sets `changes_made` to true on success.
// Returns nullptr, leaving the matched expression untouched, when the constant is NULL or not
// equal to 1.0, or when no `array_cosine_distance` function is present in the catalog.
unique_ptr<Expression> CosineDistanceRule::Apply(LogicalOperator &op, vector<reference<Expression>> &bindings,
                                                 bool &changes_made, bool is_root) {
	const auto &const_expr = bindings[1].get().Cast<BoundConstantExpression>();
	auto &similarity_expr = bindings[2].get().Cast<BoundFunctionExpression>();

	// Only the exact form (1.0 - similarity) corresponds to the cosine distance
	if (const_expr.value.IsNull() || const_expr.value.GetValue<float>() != 1.0) {
		return nullptr;
	}

	// Resolve the replacement function BEFORE touching the matched expression. Returning nullptr
	// signals "no change", so the similarity call must still own its arguments on that path.
	auto &context = GetContext();
	auto func_entry = Catalog::GetEntry<ScalarFunctionCatalogEntry>(context, "", "", "array_cosine_distance",
	                                                                OnEntryNotFound::RETURN_NULL);
	if (!func_entry) {
		return nullptr;
	}

	vector<LogicalType> arg_types;
	arg_types.push_back(similarity_expr.children[0]->return_type);
	arg_types.push_back(similarity_expr.children[1]->return_type);
	auto func = func_entry->functions.GetFunctionByArguments(context, arg_types);

	// The rewrite is now certain: steal the arguments from the similarity call
	vector<unique_ptr<Expression>> args;
	args.push_back(std::move(similarity_expr.children[0]));
	args.push_back(std::move(similarity_expr.children[1]));

	changes_made = true;
	return make_uniq<BoundFunctionExpression>(similarity_expr.return_type, func, std::move(args), nullptr);
}

//------------------------------------------------------------------------------
// Optimizer
//------------------------------------------------------------------------------
class HNSWExprOptimizer : public OptimizerExtension {
public:
HNSWExprOptimizer() {
optimize_function = Optimize;
}

static void Optimize(OptimizerExtensionInput &input, unique_ptr<LogicalOperator> &plan) {
ExpressionRewriter rewriter(input.context);
rewriter.rules.push_back(make_uniq<CosineDistanceRule>(rewriter));
rewriter.VisitOperator(*plan);
}
};

// Adds the HNSW expression optimizer to the optimizer extensions of the database configuration.
void HNSWModule::RegisterExprOptimizer(DatabaseInstance &db) {
	HNSWExprOptimizer expr_optimizer;
	db.config.optimizer_extensions.push_back(std::move(expr_optimizer));
}

} // namespace duckdb
Loading
Loading