Skip to content

Commit

Permalink
add index info pragma, only persist index if marked as dirty
Browse files Browse the repository at this point in the history
  • Loading branch information
Maxxen committed Mar 20, 2024
1 parent 882cef9 commit c80288a
Show file tree
Hide file tree
Showing 6 changed files with 266 additions and 41 deletions.
1 change: 1 addition & 0 deletions src/hnsw/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ set(EXTENSION_SOURCES
${CMAKE_CURRENT_SOURCE_DIR}/hnsw_index.cpp
${CMAKE_CURRENT_SOURCE_DIR}/hnsw_index_logical_create.cpp
${CMAKE_CURRENT_SOURCE_DIR}/hnsw_index_physical_create.cpp
${CMAKE_CURRENT_SOURCE_DIR}/hnsw_index_pragmas.cpp
${CMAKE_CURRENT_SOURCE_DIR}/hnsw_index_scan.cpp
${CMAKE_CURRENT_SOURCE_DIR}/hnsw_plan_index_create.cpp
${CMAKE_CURRENT_SOURCE_DIR}/hnsw_plan_index_scan.cpp
Expand Down
49 changes: 48 additions & 1 deletion src/hnsw/hnsw_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ namespace duckdb {

class LinkedBlock {
public:
// TODO: More testing with small block sizes. 64 works though.
static constexpr const idx_t BLOCK_SIZE = Storage::BLOCK_SIZE - sizeof(validity_t);
static constexpr const idx_t BLOCK_DATA_SIZE = BLOCK_SIZE - sizeof(IndexPointer);
static_assert(BLOCK_SIZE > sizeof(IndexPointer), "Block size must be larger than the size of an IndexPointer");
Expand Down Expand Up @@ -181,6 +180,19 @@ idx_t HNSWIndex::GetVectorSize() const {
return index.dimensions();
}

// Returns the short textual name of the distance metric used by the wrapped
// usearch index: "l2sq", "cosine" or "ip".
// Throws an InternalException for any other metric kind.
string HNSWIndex::GetMetric() const {
	using unum::usearch::metric_kind_t;

	const auto kind = index.metric().metric_kind();
	if (kind == metric_kind_t::l2sq_k) {
		return "l2sq";
	}
	if (kind == metric_kind_t::cos_k) {
		return "cosine";
	}
	if (kind == metric_kind_t::ip_k) {
		return "ip";
	}
	throw InternalException("Unknown metric kind");
}

bool HNSWIndex::IsDistanceFunction(const string &distance_function_name) {
auto accepted_functions = {"array_distance", "array_cosine_similarity", "array_inner_product"};
return std::find(accepted_functions.begin(), accepted_functions.end(), distance_function_name) != accepted_functions.end();
Expand Down Expand Up @@ -226,6 +238,23 @@ const unordered_map<uint8_t, unum::usearch::scalar_kind_t> HNSWIndex::SCALAR_KIN
{static_cast<uint8_t>(LogicalTypeId::UINTEGER), unum::usearch::scalar_kind_t::u32_k},
{static_cast<uint8_t>(LogicalTypeId::UBIGINT), unum::usearch::scalar_kind_t::u64_k}};


// Takes a snapshot of the wrapped index's statistics: max level, element
// count, capacity, reported memory usage and one stats record per level.
// The exclusive index lock is held for the duration of the snapshot.
unique_ptr<HNSWIndexStats> HNSWIndex::GetStats() {
	auto lock = rwlock.GetExclusiveLock();
	auto snapshot = make_uniq<HNSWIndexStats>();

	snapshot->max_level = index.max_level();
	snapshot->count = index.size();
	snapshot->capacity = index.capacity();
	snapshot->approx_size = index.memory_usage();

	// NOTE(review): this visits levels [0, max_level()). If max_level() is the
	// index of the top level rather than the number of levels, the top level is
	// never reported - confirm against the usearch API.
	idx_t level = 0;
	while (level < index.max_level()) {
		snapshot->level_stats.push_back(index.stats(level));
		level++;
	}

	return snapshot;
}

// Scan State
struct HNSWIndexScanState : public IndexScanState {
idx_t current_row = 0;
Expand Down Expand Up @@ -278,6 +307,9 @@ void HNSWIndex::Construct(DataChunk &input, Vector &row_ids, idx_t thread_idx) {
D_ASSERT(row_ids.GetType().InternalType() == ROW_TYPE);
D_ASSERT(logical_types[0] == input.data[0].GetType());

// Mark this index as dirty so we checkpoint it properly
is_dirty = true;

auto count = input.size();
input.Flatten();

Expand Down Expand Up @@ -318,6 +350,11 @@ void HNSWIndex::Construct(DataChunk &input, Vector &row_ids, idx_t thread_idx) {
}

void HNSWIndex::Compact() {
// Mark this index as dirty so we checkpoint it properly
is_dirty = true;

// Acquire an exclusive lock to compact the index
auto lock = rwlock.GetExclusiveLock();
// Re-compact the index
auto result = index.compact();
if(!result) {
Expand All @@ -326,6 +363,9 @@ void HNSWIndex::Compact() {
}

void HNSWIndex::Delete(IndexLock &lock, DataChunk &input, Vector &rowid_vec) {
// Mark this index as dirty so we checkpoint it properly
is_dirty = true;

auto count = input.size();
rowid_vec.Flatten(count);
auto row_id_data = FlatVector::GetData<row_t>(rowid_vec);
Expand Down Expand Up @@ -360,6 +400,11 @@ void HNSWIndex::PersistToDisk() {
// Acquire an exclusive lock to persist the index
auto lock = rwlock.GetExclusiveLock();

// If there haven't been any changes, we don't need to rewrite the index again
if(!is_dirty) {
return;
}

// Write

if (root_block_ptr.Get() == 0) {
Expand All @@ -372,6 +417,8 @@ void HNSWIndex::PersistToDisk() {
writer.WriteData(static_cast<const_data_ptr_t>(data), size);
return true;
});

is_dirty = false;
}

IndexStorageInfo HNSWIndex::GetStorageInfo(const bool get_buffers) {
Expand Down
203 changes: 203 additions & 0 deletions src/hnsw/hnsw_index_pragmas.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
#include "duckdb/catalog/catalog_entry/duck_table_entry.hpp"
#include "duckdb/catalog/dependency_list.hpp"
#include "duckdb/common/mutex.hpp"
#include "duckdb/function/function_set.hpp"
#include "duckdb/optimizer/matcher/expression_matcher.hpp"
#include "duckdb/planner/expression_iterator.hpp"
#include "duckdb/planner/operator/logical_get.hpp"
#include "duckdb/storage/table/scan_state.hpp"
#include "duckdb/transaction/duck_transaction.hpp"
#include "duckdb/transaction/local_storage.hpp"
#include "duckdb/main/extension_util.hpp"
#include "duckdb/catalog/catalog_entry/duck_index_entry.hpp"
#include "duckdb/storage/data_table.hpp"

#include "hnsw/hnsw.hpp"
#include "hnsw/hnsw_index.hpp"
#include "hnsw/hnsw_index_scan.hpp"

namespace duckdb {

// BIND
// Bind callback of the index info table function.
// Declares the result schema (one row per HNSW index) by filling `names` and
// `return_types` in lock-step. No bind data is needed, so nullptr is returned.
static unique_ptr<FunctionData> HNSWindexInfoBind(ClientContext &context, TableFunctionBindInput &input,
                                                  vector<LogicalType> &return_types, vector<string> &names) {
	// Appends a single output column (name + type) to the result schema
	auto add_column = [&](const char *column_name, LogicalType column_type) {
		names.emplace_back(column_name);
		return_types.emplace_back(std::move(column_type));
	};

	// Where the index lives
	add_column("catalog_name", LogicalType::VARCHAR);
	add_column("schema_name", LogicalType::VARCHAR);
	add_column("index_name", LogicalType::VARCHAR);
	add_column("table_name", LogicalType::VARCHAR);

	// Index configuration
	add_column("metric", LogicalType::VARCHAR);
	add_column("dimensions", LogicalType::BIGINT);

	// Index statistics
	add_column("count", LogicalType::BIGINT);
	add_column("capacity", LogicalType::BIGINT);
	add_column("approx_memory_usage", LogicalType::BIGINT);
	add_column("levels", LogicalType::BIGINT);

	// One struct of statistics per level
	add_column("levels_stats", LogicalType::LIST(LogicalType::STRUCT({{"nodes", LogicalType::BIGINT},
	                                                                  {"edges", LogicalType::BIGINT},
	                                                                  {"max_edges", LogicalType::BIGINT},
	                                                                  {"allocated_bytes", LogicalType::BIGINT}})));

	return nullptr;
}

// INIT GLOBAL
// Global state of the index info table function
struct HNSWIndexInfoGlobalState : public GlobalTableFunctionState {
	// Catalog entries of the indexes to report on
	vector<reference<IndexCatalogEntry>> entries;
	// Position within `entries` of the next index to emit
	idx_t offset = 0;
};

// Global-init callback of the index info table function.
// Walks the index entries of every schema returned by Catalog::GetAllSchemas
// and remembers those whose index type is HNSW.
static unique_ptr<GlobalTableFunctionState> HNSWIndexInfoInitGlobal(ClientContext &context,
                                                                    TableFunctionInitInput &input) {
	auto state = make_uniq<HNSWIndexInfoGlobalState>();
	auto &hnsw_entries = state->entries;

	// Keeps an index catalog entry only if it describes an HNSW index
	const auto collect_hnsw_index = [&hnsw_entries](CatalogEntry &entry) {
		auto &candidate = entry.Cast<IndexCatalogEntry>();
		if (candidate.index_type != HNSWIndex::TYPE_NAME) {
			return;
		}
		hnsw_entries.push_back(candidate);
	};

	auto all_schemas = Catalog::GetAllSchemas(context);
	for (auto &schema_ref : all_schemas) {
		schema_ref.get().Scan(context, CatalogType::INDEX_ENTRY, collect_hnsw_index);
	}

	return std::move(state);
}

// EXECUTE
// Execute callback of the index info table function.
// Emits one row per collected HNSW index, at most STANDARD_VECTOR_SIZE rows per
// call, resuming from the offset stored in the global state.
// Throws a BinderException if a catalog entry has no matching HNSW index in the
// index list of its table.
static void HNSWIndexInfoExecute(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) {
	auto &state = data_p.global_state->Cast<HNSWIndexInfoGlobalState>();
	if (state.offset >= state.entries.size()) {
		// Every index has already been emitted
		return;
	}

	// The element type of the "levels_stats" list is the same for every row
	const auto level_stats_type = LogicalType::STRUCT({{"nodes", LogicalType::BIGINT},
	                                                   {"edges", LogicalType::BIGINT},
	                                                   {"max_edges", LogicalType::BIGINT},
	                                                   {"allocated_bytes", LogicalType::BIGINT}});

	idx_t out_row = 0;
	for (; state.offset < state.entries.size() && out_row < STANDARD_VECTOR_SIZE; out_row++) {
		auto &index_entry = state.entries[state.offset++].get();
		auto &table_entry = index_entry.schema.catalog.GetEntry<TableCatalogEntry>(
		    context, index_entry.GetSchemaName(), index_entry.GetTableName());
		auto &storage = table_entry.GetStorage();

		// Locate the index instance that belongs to the catalog entry
		HNSWIndex *hnsw_index = nullptr;
		storage.info->indexes.Scan([&](Index &candidate) {
			const bool is_match =
			    candidate.name == index_entry.name && candidate.index_type == HNSWIndex::TYPE_NAME;
			if (is_match) {
				hnsw_index = &candidate.Cast<HNSWIndex>();
			}
			return is_match;
		});
		if (hnsw_index == nullptr) {
			throw BinderException("Index %s not found", index_entry.name);
		}

		// Writes a value into the next output column of the current row
		idx_t col = 0;
		auto emit = [&](const Value &value) {
			output.data[col++].SetValue(out_row, value);
		};

		emit(Value(index_entry.catalog.GetName()));
		emit(Value(index_entry.schema.name));
		emit(Value(index_entry.name));
		emit(Value(table_entry.name));

		auto stats = hnsw_index->GetStats();

		emit(Value(hnsw_index->GetMetric()));
		emit(Value::BIGINT(hnsw_index->GetVectorSize()));
		emit(Value::BIGINT(stats->count));
		emit(Value::BIGINT(stats->capacity));
		emit(Value::BIGINT(stats->approx_size));
		emit(Value::BIGINT(stats->max_level));

		// Convert the per-level statistics into a list of structs
		vector<Value> level_values;
		for (auto &level : stats->level_stats) {
			level_values.push_back(Value::STRUCT({{"nodes", Value::BIGINT(level.nodes)},
			                                      {"edges", Value::BIGINT(level.edges)},
			                                      {"max_edges", Value::BIGINT(level.max_edges)},
			                                      {"allocated_bytes", Value::BIGINT(level.allocated_bytes)}}));
		}
		emit(Value::LIST(level_stats_type, std::move(level_values)));
	}
	output.SetCardinality(out_row);
}



//-------------------------------------------------------------------------
// Compact PRAGMA
//-------------------------------------------------------------------------

// Implementation of the `hnsw_compact_index` pragma.
//
// Expects exactly one VARCHAR argument: the (optionally catalog/schema
// qualified) name of an HNSW index. The name is resolved through the catalog,
// the matching index is located on its table and compacted.
//
// Throws a BinderException if the argument count or type is wrong, or if the
// table has no HNSW index with the resolved name.
static void CompactIndexPragma(ClientContext &context, const FunctionParameters &parameters) {
	if (parameters.values.size() != 1) {
		throw BinderException("Expected one argument for hnsw_compact_index");
	}
	auto &param = parameters.values[0];
	if (param.type() != LogicalType::VARCHAR) {
		throw BinderException("Expected a string argument for hnsw_compact_index");
	}
	const auto index_name = param.GetValue<string>();

	auto qname = QualifiedName::Parse(index_name);

	// Look up the index (and the table it belongs to) in the catalog
	Binder::BindSchemaOrCatalog(context, qname.catalog, qname.schema);
	auto &index_entry =
	    Catalog::GetEntry(context, CatalogType::INDEX_ENTRY, qname.catalog, qname.schema, qname.name)
	        .Cast<IndexCatalogEntry>();
	auto &table_entry = Catalog::GetEntry(context, CatalogType::TABLE_ENTRY, qname.catalog,
	                                      index_entry.GetSchemaName(), index_entry.GetTableName())
	                        .Cast<TableCatalogEntry>();

	auto &storage = table_entry.GetStorage();
	bool found_index = false;
	storage.info->indexes.Scan([&](Index &index) {
		// Match on the name of the resolved catalog entry, not on the raw
		// argument: the argument may be qualified ("schema.index") and would
		// then never equal the bare index name.
		if (index.name != index_entry.name || index.index_type != HNSWIndex::TYPE_NAME) {
			return false;
		}
		index.Cast<HNSWIndex>().Compact();
		found_index = true;
		return true;
	});

	if (!found_index) {
		throw BinderException("Index %s not found", index_name);
	}
}

//-------------------------------------------------------------------------
// Register
//-------------------------------------------------------------------------
// Registers the HNSW pragmas with the database instance: the
// `hnsw_compact_index` pragma call and the `pragma_hnsw_index_info` table
// function.
void HNSWModule::RegisterIndexPragmas(DatabaseInstance &db) {
	// hnsw_compact_index: takes a single VARCHAR argument
	auto compact_pragma =
	    PragmaFunction::PragmaCall("hnsw_compact_index", CompactIndexPragma, {LogicalType::VARCHAR});
	ExtensionUtil::RegisterFunction(db, std::move(compact_pragma));

	// TODO: This is kind of ugly and maybe should just take a parameter instead...
	TableFunction info_function("pragma_hnsw_index_info", {}, HNSWIndexInfoExecute, HNSWindexInfoBind,
	                            HNSWIndexInfoInitGlobal);
	ExtensionUtil::RegisterFunction(db, info_function);
}

} // namespace duckdb
40 changes: 0 additions & 40 deletions src/hnsw/hnsw_index_scan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,51 +148,11 @@ TableFunction HNSWIndexScanFunction::GetFunction() {
return func;
}

//-------------------------------------------------------------------------
// Compact PRAGMA
//-------------------------------------------------------------------------

// Implementation of the `hnsw_compact_index` pragma.
//
// Expects exactly one VARCHAR argument: the (optionally catalog/schema
// qualified) name of an HNSW index. The name is resolved through the catalog,
// the matching index is located on its table and compacted.
//
// Throws a BinderException if the argument count or type is wrong, or if the
// table has no HNSW index with the resolved name.
static void CompactIndexPragma(ClientContext &context, const FunctionParameters &parameters) {
	if (parameters.values.size() != 1) {
		throw BinderException("Expected one argument for hnsw_compact_index");
	}
	auto &param = parameters.values[0];
	if (param.type() != LogicalType::VARCHAR) {
		throw BinderException("Expected a string argument for hnsw_compact_index");
	}
	const auto index_name = param.GetValue<string>();

	auto qname = QualifiedName::Parse(index_name);

	// Look up the index (and the table it belongs to) in the catalog
	Binder::BindSchemaOrCatalog(context, qname.catalog, qname.schema);
	auto &index_entry =
	    Catalog::GetEntry(context, CatalogType::INDEX_ENTRY, qname.catalog, qname.schema, qname.name)
	        .Cast<IndexCatalogEntry>();
	auto &table_entry = Catalog::GetEntry(context, CatalogType::TABLE_ENTRY, qname.catalog,
	                                      index_entry.GetSchemaName(), index_entry.GetTableName())
	                        .Cast<TableCatalogEntry>();

	auto &storage = table_entry.GetStorage();
	bool found_index = false;
	storage.info->indexes.Scan([&](Index &index) {
		// Match on the name of the resolved catalog entry, not on the raw
		// argument (which may be qualified), and only accept HNSW indexes so
		// the cast below is never applied to an index of another type.
		if (index.name != index_entry.name || index.index_type != HNSWIndex::TYPE_NAME) {
			return false;
		}
		index.Cast<HNSWIndex>().Compact();
		found_index = true;
		return true;
	});

	if (!found_index) {
		throw BinderException("Index %s not found", index_name);
	}
}

//-------------------------------------------------------------------------
// Register
//-------------------------------------------------------------------------
// Registers the HNSW index scan table function and the `hnsw_compact_index`
// pragma call with the database instance.
void HNSWModule::RegisterIndexScan(DatabaseInstance &db) {
	auto scan_function = HNSWIndexScanFunction::GetFunction();
	ExtensionUtil::RegisterFunction(db, std::move(scan_function));

	auto compact_pragma =
	    PragmaFunction::PragmaCall("hnsw_compact_index", CompactIndexPragma, {LogicalType::VARCHAR});
	ExtensionUtil::RegisterFunction(db, std::move(compact_pragma));
}

} // namespace duckdb
2 changes: 2 additions & 0 deletions src/include/hnsw/hnsw.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@ struct HNSWModule {
// Single entry point of the module: runs each of the private Register*
// helpers below, in this order, against the given database instance.
static void Register(DatabaseInstance &db) {
RegisterIndex(db);
RegisterIndexScan(db);
RegisterIndexPragmas(db);
RegisterPlanIndexScan(db);
RegisterPlanIndexCreate(db);
}

private:
static void RegisterIndex(DatabaseInstance &db);
static void RegisterIndexScan(DatabaseInstance &db);
static void RegisterIndexPragmas(DatabaseInstance &db);
static void RegisterPlanIndexScan(DatabaseInstance &db);
static void RegisterPlanIndexCreate(DatabaseInstance &db);
};
Expand Down
Loading

0 comments on commit c80288a

Please sign in to comment.