Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
Maxxen committed Sep 26, 2024
1 parent 685a131 commit 7ae0e10
Show file tree
Hide file tree
Showing 6 changed files with 155 additions and 138 deletions.
97 changes: 48 additions & 49 deletions src/hnsw/hnsw_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class LinkedBlockReader {

public:
// Constructs a reader that starts at root_pointer: current_pointer is initialised to
// root_pointer and position_in_block to 0, so the first read begins at the root block.
LinkedBlockReader(FixedSizeAllocator &allocator, IndexPointer root_pointer)
: allocator(allocator), root_pointer(root_pointer), current_pointer(root_pointer), position_in_block(0) {
: allocator(allocator), root_pointer(root_pointer), current_pointer(root_pointer), position_in_block(0) {
}

void Reset() {
Expand Down Expand Up @@ -77,7 +77,7 @@ class LinkedBlockWriter {

public:
// Constructs a writer that starts at root_pointer: current_pointer is initialised to
// root_pointer and position_in_block to 0, so the first write begins at the root block.
LinkedBlockWriter(FixedSizeAllocator &allocator, IndexPointer root_pointer)
: allocator(allocator), root_pointer(root_pointer), current_pointer(root_pointer), position_in_block(0) {
: allocator(allocator), root_pointer(root_pointer), current_pointer(root_pointer), position_in_block(0) {
}

void ClearCurrentBlock() {
Expand Down Expand Up @@ -119,10 +119,10 @@ class LinkedBlockWriter {

// Constructor
HNSWIndex::HNSWIndex(const string &name, IndexConstraintType index_constraint_type, const vector<column_t> &column_ids,
TableIOManager &table_io_manager, const vector<unique_ptr<Expression>> &unbound_expressions,
AttachedDatabase &db, const case_insensitive_map_t<Value> &options, const IndexStorageInfo &info,
idx_t estimated_cardinality)
: BoundIndex(name, TYPE_NAME, index_constraint_type, column_ids, table_io_manager, unbound_expressions, db) {
TableIOManager &table_io_manager, const vector<unique_ptr<Expression>> &unbound_expressions,
AttachedDatabase &db, const case_insensitive_map_t<Value> &options, const IndexStorageInfo &info,
idx_t estimated_cardinality)
: BoundIndex(name, TYPE_NAME, index_constraint_type, column_ids, table_io_manager, unbound_expressions, db) {

if (index_constraint_type != IndexConstraintType::NONE) {
throw NotImplementedException("HNSW indexes do not support unique or primary key constraints");
Expand Down Expand Up @@ -202,14 +202,13 @@ HNSWIndex::HNSWIndex(const string &name, IndexConstraintType index_constraint_ty
if (!info.allocator_infos[0].buffer_ids.empty()) {
LinkedBlockReader reader(*linked_block_allocator, root_block_ptr);
index.load_from_stream(
[&](void *data, size_t size) { return size == reader.ReadData(static_cast<data_ptr_t>(data), size); });
[&](void *data, size_t size) { return size == reader.ReadData(static_cast<data_ptr_t>(data), size); });
}
} else {
index.reserve(MinValue(static_cast<idx_t>(32), estimated_cardinality));
}
index_size = index.size();


function_matcher = MakeFunctionMatcher();
}

Expand All @@ -231,31 +230,31 @@ string HNSWIndex::GetMetric() const {
}

// Lookup table from metric name to the usearch metric kind.
// Only "l2sq", "cosine" and "ip" are enabled; the remaining usearch metrics are
// commented out pending the TODO below.
// NOTE(review): presumably consulted when resolving the index's metric option - confirm at the call site.
const case_insensitive_map_t<unum::usearch::metric_kind_t> HNSWIndex::METRIC_KIND_MAP = {
{"l2sq", unum::usearch::metric_kind_t::l2sq_k},
{"cosine", unum::usearch::metric_kind_t::cos_k},
{"ip", unum::usearch::metric_kind_t::ip_k},
/* TODO: Add the rest of these later
{"divergence", unum::usearch::metric_kind_t::divergence_k},
{"hamming", unum::usearch::metric_kind_t::hamming_k},
{"jaccard", unum::usearch::metric_kind_t::jaccard_k},
{"haversine", unum::usearch::metric_kind_t::haversine_k},
{"pearson", unum::usearch::metric_kind_t::pearson_k},
{"sorensen", unum::usearch::metric_kind_t::sorensen_k},
{"tanimoto", unum::usearch::metric_kind_t::tanimoto_k}
*/
{"l2sq", unum::usearch::metric_kind_t::l2sq_k},
{"cosine", unum::usearch::metric_kind_t::cos_k},
{"ip", unum::usearch::metric_kind_t::ip_k},
/* TODO: Add the rest of these later
{"divergence", unum::usearch::metric_kind_t::divergence_k},
{"hamming", unum::usearch::metric_kind_t::hamming_k},
{"jaccard", unum::usearch::metric_kind_t::jaccard_k},
{"haversine", unum::usearch::metric_kind_t::haversine_k},
{"pearson", unum::usearch::metric_kind_t::pearson_k},
{"sorensen", unum::usearch::metric_kind_t::sorensen_k},
{"tanimoto", unum::usearch::metric_kind_t::tanimoto_k}
*/
};

// Lookup table from a LogicalTypeId (cast to uint8_t for use as the key) to the
// corresponding usearch scalar kind. Covers FLOAT, DOUBLE and the 8/16/32/64-bit
// signed and unsigned integer types; any other type id has no entry.
const unordered_map<uint8_t, unum::usearch::scalar_kind_t> HNSWIndex::SCALAR_KIND_MAP = {
{static_cast<uint8_t>(LogicalTypeId::FLOAT), unum::usearch::scalar_kind_t::f32_k},
{static_cast<uint8_t>(LogicalTypeId::DOUBLE), unum::usearch::scalar_kind_t::f64_k},
{static_cast<uint8_t>(LogicalTypeId::TINYINT), unum::usearch::scalar_kind_t::i8_k},
{static_cast<uint8_t>(LogicalTypeId::SMALLINT), unum::usearch::scalar_kind_t::i16_k},
{static_cast<uint8_t>(LogicalTypeId::INTEGER), unum::usearch::scalar_kind_t::i32_k},
{static_cast<uint8_t>(LogicalTypeId::BIGINT), unum::usearch::scalar_kind_t::i64_k},
{static_cast<uint8_t>(LogicalTypeId::UTINYINT), unum::usearch::scalar_kind_t::u8_k},
{static_cast<uint8_t>(LogicalTypeId::USMALLINT), unum::usearch::scalar_kind_t::u16_k},
{static_cast<uint8_t>(LogicalTypeId::UINTEGER), unum::usearch::scalar_kind_t::u32_k},
{static_cast<uint8_t>(LogicalTypeId::UBIGINT), unum::usearch::scalar_kind_t::u64_k}};
{static_cast<uint8_t>(LogicalTypeId::FLOAT), unum::usearch::scalar_kind_t::f32_k},
{static_cast<uint8_t>(LogicalTypeId::DOUBLE), unum::usearch::scalar_kind_t::f64_k},
{static_cast<uint8_t>(LogicalTypeId::TINYINT), unum::usearch::scalar_kind_t::i8_k},
{static_cast<uint8_t>(LogicalTypeId::SMALLINT), unum::usearch::scalar_kind_t::i16_k},
{static_cast<uint8_t>(LogicalTypeId::INTEGER), unum::usearch::scalar_kind_t::i32_k},
{static_cast<uint8_t>(LogicalTypeId::BIGINT), unum::usearch::scalar_kind_t::i64_k},
{static_cast<uint8_t>(LogicalTypeId::UTINYINT), unum::usearch::scalar_kind_t::u8_k},
{static_cast<uint8_t>(LogicalTypeId::USMALLINT), unum::usearch::scalar_kind_t::u16_k},
{static_cast<uint8_t>(LogicalTypeId::UINTEGER), unum::usearch::scalar_kind_t::u32_k},
{static_cast<uint8_t>(LogicalTypeId::UBIGINT), unum::usearch::scalar_kind_t::u64_k}};

unique_ptr<HNSWIndexStats> HNSWIndex::GetStats() {
auto lock = rwlock.GetExclusiveLock();
Expand Down Expand Up @@ -375,7 +374,6 @@ void HNSWIndex::ResetMultiScan(IndexScanState &state) {
scan_state.row_ids.clear();
}


void HNSWIndex::CommitDrop(IndexLock &index_lock) {
// Acquire an exclusive lock to drop the index
auto lock = rwlock.GetExclusiveLock();
Expand Down Expand Up @@ -573,18 +571,18 @@ void HNSWIndex::VerifyAllocations(IndexLock &state) {
// Can rewrite index expression?
//------------------------------------------------------------------------------
static void TryBindIndexExpressionInternal(Expression &expr, idx_t table_idx, const vector<column_t> &index_columns,
const vector<column_t> &table_columns, bool &success, bool &found) {
const vector<column_t> &table_columns, bool &success, bool &found) {

if(expr.type == ExpressionType::BOUND_COLUMN_REF) {
if (expr.type == ExpressionType::BOUND_COLUMN_REF) {
found = true;
auto &ref = expr.Cast<BoundColumnRefExpression>();

// Rewrite the column reference to fit in the current set of bound column ids
ref.binding.table_index = table_idx;

const auto referenced_column = index_columns[ref.binding.column_index];
for(idx_t i = 0; i < table_columns.size(); i++) {
if(table_columns[i] == referenced_column) {
for (idx_t i = 0; i < table_columns.size(); i++) {
if (table_columns[i] == referenced_column) {
ref.binding.column_index = i;
return;
}
Expand All @@ -609,31 +607,32 @@ bool HNSWIndex::TryBindIndexExpression(LogicalGet &get, unique_ptr<Expression> &

TryBindIndexExpressionInternal(expr, get.table_index, index_columns, table_columns, success, found);

if(success && found) {
if (success && found) {
result = std::move(expr_ptr);
return true;
}
return false;
}

// Tests expr against this index's function_matcher and returns the matcher's result.
// bindings is passed straight through to Match.
// expr is dereferenced unconditionally, so the caller must pass a non-null pointer.
// NOTE(review): presumably Match fills bindings with the matched sub-expressions - confirm against ExpressionMatcher.
bool HNSWIndex::TryMatchDistanceFunction(const unique_ptr<Expression>& expr, vector<reference<Expression>> &bindings) const {
bool HNSWIndex::TryMatchDistanceFunction(const unique_ptr<Expression> &expr,
vector<reference<Expression>> &bindings) const {
return function_matcher->Match(*expr, bindings);
}

unique_ptr<ExpressionMatcher> HNSWIndex::MakeFunctionMatcher() const {
unordered_set<string> distance_functions;
switch(index.metric().metric_kind()) {
case unum::usearch::metric_kind_t::l2sq_k:
distance_functions = { "array_distance", "<->" };
break;
case unum::usearch::metric_kind_t::cos_k:
distance_functions = { "array_cosine_distance", "<=>" };
break;
case unum::usearch::metric_kind_t::ip_k:
distance_functions = { "array_negative_inner_product", "<#>" };
break;
default:
throw NotImplementedException("Unknown metric kind");
switch (index.metric().metric_kind()) {
case unum::usearch::metric_kind_t::l2sq_k:
distance_functions = {"array_distance", "<->"};
break;
case unum::usearch::metric_kind_t::cos_k:
distance_functions = {"array_cosine_distance", "<=>"};
break;
case unum::usearch::metric_kind_t::ip_k:
distance_functions = {"array_negative_inner_product", "<#>"};
break;
default:
throw NotImplementedException("Unknown metric kind");
}

auto matcher = make_uniq<FunctionExpressionMatcher>();
Expand Down
Loading

0 comments on commit 7ae0e10

Please sign in to comment.