From 58f40f7001651b3fcddc76eb6f567344739e9ecc Mon Sep 17 00:00:00 2001 From: Max Gabrielsson Date: Tue, 19 Mar 2024 20:38:24 +0100 Subject: [PATCH] update usearch, add lock, we should switch to merging instead --- README.md | 4 +- src/hnsw/hnsw_index.cpp | 26 +- src/include/usearch/index.hpp | 6961 ++++++++++++------------- src/include/usearch/index_dense.hpp | 820 ++- src/include/usearch/index_plugins.hpp | 1207 ++--- 5 files changed, 4101 insertions(+), 4917 deletions(-) diff --git a/README.md b/README.md index a41912e..8a0b4bc 100644 --- a/README.md +++ b/README.md @@ -1,13 +1,13 @@ # DuckDB-VSS -Vector Similarity Search Extension (based on usearch) +Vector Similarity Search for DuckDB This is an experimental extension for DuckDB that adds indexing support to accelerate Vector Similarity Search using DuckDB's new fixed-size `ARRAY` type added in version v0.10.0. This extension is based on the [usearch](https://github.com/unum-cloud/usearch) library and serves as a proof of concept for providing a custom index type, in this case a HNSW index, from within an extension and exposing it to DuckDB. ## Usage -To create a new HNSW index on a table, use the `CREATE INDEX` statement with the `USING HNSW` clause. For example: +To create a new HNSW index on a table with an `ARRAY` column, use the `CREATE INDEX` statement with the `USING HNSW` clause. For example: ```sql CREATE TABLE my_vector_table (vec FLOAT[3]); INSERT INTO my_vector_table SELECT array_value(a,b,c) FROM range(1,10) ra(a), range(1,10) rb(b), range(1,10) rc(c); diff --git a/src/hnsw/hnsw_index.cpp b/src/hnsw/hnsw_index.cpp index b67d50f..f19faf0 100644 --- a/src/hnsw/hnsw_index.cpp +++ b/src/hnsw/hnsw_index.cpp @@ -267,6 +267,18 @@ void HNSWIndex::CommitDrop(IndexLock &index_lock) { root_block_ptr.Clear(); } +inline idx_t NextPowerOfTwo(idx_t v) { + v--; + v |= v >> 1; + v |= v >> 2; + v |= v >> 4; + v |= v >> 8; + v |= v >> 16; + v |= v >> 32; + v++; + return v; +} + void HNSWIndex::Construct(DataChunk &input, Vector &row_ids, idx_t thread_idx) { D_ASSERT(row_ids.GetType().InternalType() == ROW_TYPE); D_ASSERT(logical_types[0] == input.data[0].GetType()); @@ -274,15 +286,23 @@ void HNSWIndex::Construct(DataChunk &input, Vector &row_ids, idx_t thread_idx) { auto count = input.size(); input.Flatten(); - // TODO: Do we need to track this atomically globally? - index.reserve(index.capacity() + count); - auto &vec_vec = input.data[0]; auto &vec_child_vec = ArrayVector::GetEntry(vec_vec); auto array_size = ArrayType::GetSize(vec_vec.GetType()); auto vec_child_data = FlatVector::GetData(vec_child_vec); auto rowid_data = FlatVector::GetData(row_ids); + + // lock_guard lock(hnsw_index_mutex); + // TODO: Do we need to track this atomically globally? + // Better strategy: Create multiple small indexes and merge! + static mutex hnsw_index_mutex; + lock_guard lock(hnsw_index_mutex); + + if(!index.reserve(NextPowerOfTwo(index.size() + count))) { + throw InternalException("Failed to reserve space in the HNSW index"); + } + for (idx_t out_idx = 0; out_idx < count; out_idx++) { auto rowid = rowid_data[out_idx]; auto result = index.add(rowid, vec_child_data + (out_idx * array_size), thread_idx); diff --git a/src/include/usearch/index.hpp b/src/include/usearch/index.hpp index 0323b4b..6e5150d 100644 --- a/src/include/usearch/index.hpp +++ b/src/include/usearch/index.hpp @@ -1,17 +1,17 @@ /** - * @file index.hpp - * @author Ash Vardanian - * @brief Single-header Vector Search. 
- * @date 2023-04-26 - * - * @copyright Copyright (c) 2023 - */ +* @file index.hpp +* @author Ash Vardanian +* @brief Single-header Vector Search. +* @date 2023-04-26 +* +* @copyright Copyright (c) 2023 +*/ #ifndef UNUM_USEARCH_HPP #define UNUM_USEARCH_HPP #define USEARCH_VERSION_MAJOR 2 -#define USEARCH_VERSION_MINOR 8 -#define USEARCH_VERSION_PATCH 14 +#define USEARCH_VERSION_MINOR 9 +#define USEARCH_VERSION_PATCH 2 // Inferring C++ version // https://stackoverflow.com/a/61552074 @@ -92,9 +92,9 @@ // Zero means we are only going to read from that memory. // Three means high temporal locality and suggests to keep // the data in all layers of cache. -#define prefetch_m(ptr) __builtin_prefetch((void *)(ptr), 0, 3) +#define prefetch_m(ptr) __builtin_prefetch((void*)(ptr), 0, 3) #elif defined(USEARCH_DEFINED_X86) -#define prefetch_m(ptr) _mm_prefetch((void *)(ptr), _MM_HINT_T0) +#define prefetch_m(ptr) _mm_prefetch((void*)(ptr), _MM_HINT_T0) #else #define prefetch_m(ptr) #endif @@ -104,7 +104,7 @@ #define usearch_pack_m #define usearch_align_m __declspec(align(64)) #else -#define usearch_pack_m __attribute__((packed)) +#define usearch_pack_m __attribute__((packed)) #define usearch_align_m __attribute__((aligned(64))) #endif @@ -114,9 +114,9 @@ #define usearch_noexcept_m noexcept #else #define usearch_assert_m(must_be_true, message) \ - if (!(must_be_true)) { \ - throw std::runtime_error(message); \ - } + if (!(must_be_true)) { \ + throw std::runtime_error(message); \ + } #define usearch_noexcept_m #endif @@ -125,783 +125,632 @@ namespace usearch { using byte_t = char; -template -std::size_t divide_round_up(std::size_t num) noexcept { - return (num + multiple_ak - 1) / multiple_ak; +template std::size_t divide_round_up(std::size_t num) noexcept { + return (num + multiple_ak - 1) / multiple_ak; } inline std::size_t divide_round_up(std::size_t num, std::size_t denominator) noexcept { - return (num + denominator - 1) / denominator; + return (num + denominator - 1) / denominator; } inline std::size_t ceil2(std::size_t v) noexcept { - v--; - v |= v >> 1; - v |= v >> 2; - v |= v >> 4; - v |= v >> 8; - v |= v >> 16; + v--; + v |= v >> 1; + v |= v >> 2; + v |= v >> 4; + v |= v >> 8; + v |= v >> 16; #ifdef USEARCH_64BIT_ENV - v |= v >> 32; + v |= v >> 32; #endif - v++; - return v; + v++; + return v; } /// @brief Simply dereferencing misaligned pointers can be dangerous. -template -void misaligned_store(void *ptr, at v) noexcept { - static_assert(!std::is_reference::value, "Can't store a reference"); - std::memcpy(ptr, &v, sizeof(at)); +template void misaligned_store(void* ptr, at v) noexcept { + static_assert(!std::is_reference::value, "Can't store a reference"); + std::memcpy(ptr, &v, sizeof(at)); } /// @brief Simply dereferencing misaligned pointers can be dangerous. -template -at misaligned_load(void *ptr) noexcept { - static_assert(!std::is_reference::value, "Can't load a reference"); - at v; - std::memcpy(&v, ptr, sizeof(at)); - return v; +template at misaligned_load(void* ptr) noexcept { + static_assert(!std::is_reference::value, "Can't load a reference"); + at v; + std::memcpy(&v, ptr, sizeof(at)); + return v; } /// @brief The `std::exchange` alternative for C++11. 
-template -at exchange(at &obj, other_at &&new_value) { - at old_value = std::move(obj); - obj = std::forward(new_value); - return old_value; +template at exchange(at& obj, other_at&& new_value) { + at old_value = std::move(obj); + obj = std::forward(new_value); + return old_value; } /// @brief The `std::destroy_at` alternative for C++11. template -typename std::enable_if::value>::type destroy_at(at *) { -} +typename std::enable_if::value>::type destroy_at(at*) {} template -typename std::enable_if::value>::type destroy_at(at *obj) { - obj->~sfinae_at(); +typename std::enable_if::value>::type destroy_at(at* obj) { + obj->~sfinae_at(); } /// @brief The `std::construct_at` alternative for C++11. template -typename std::enable_if::value>::type construct_at(at *) { -} +typename std::enable_if::value>::type construct_at(at*) {} template -typename std::enable_if::value>::type construct_at(at *obj) { - new (obj) at(); +typename std::enable_if::value>::type construct_at(at* obj) { + new (obj) at(); } /** - * @brief A reference to a misaligned memory location with a specific type. - * It is needed to avoid Undefined Behavior when dereferencing addresses - * indivisible by `sizeof(at)`. - */ -template -class misaligned_ref_gt { - using element_t = at; - using mutable_t = typename std::remove_const::type; - byte_t *ptr_; +* @brief A reference to a misaligned memory location with a specific type. +* It is needed to avoid Undefined Behavior when dereferencing addresses +* indivisible by `sizeof(at)`. +*/ +template class misaligned_ref_gt { + using element_t = at; + using mutable_t = typename std::remove_const::type; + byte_t* ptr_; public: - misaligned_ref_gt(byte_t *ptr) noexcept : ptr_(ptr) { - } - operator mutable_t() const noexcept { - return misaligned_load(ptr_); - } - misaligned_ref_gt &operator=(mutable_t const &v) noexcept { - misaligned_store(ptr_, v); - return *this; - } - - void reset(byte_t *ptr) noexcept { - ptr_ = ptr; - } - byte_t *ptr() const noexcept { - return ptr_; - } + misaligned_ref_gt(byte_t* ptr) noexcept : ptr_(ptr) {} + operator mutable_t() const noexcept { return misaligned_load(ptr_); } + misaligned_ref_gt& operator=(mutable_t const& v) noexcept { + misaligned_store(ptr_, v); + return *this; + } + + void reset(byte_t* ptr) noexcept { ptr_ = ptr; } + byte_t* ptr() const noexcept { return ptr_; } }; /** - * @brief A pointer to a misaligned memory location with a specific type. - * It is needed to avoid Undefined Behavior when dereferencing addresses - * indivisible by `sizeof(at)`. - */ -template -class misaligned_ptr_gt { - using element_t = at; - using mutable_t = typename std::remove_const::type; - byte_t *ptr_; +* @brief A pointer to a misaligned memory location with a specific type. +* It is needed to avoid Undefined Behavior when dereferencing addresses +* indivisible by `sizeof(at)`. 
+*/ +template class misaligned_ptr_gt { + using element_t = at; + using mutable_t = typename std::remove_const::type; + byte_t* ptr_; public: - using iterator_category = std::random_access_iterator_tag; - using value_type = element_t; - using difference_type = std::ptrdiff_t; - using pointer = misaligned_ptr_gt; - using reference = misaligned_ref_gt; - - reference operator*() const noexcept { - return {ptr_}; - } - reference operator[](std::size_t i) noexcept { - return reference(ptr_ + i * sizeof(element_t)); - } - value_type operator[](std::size_t i) const noexcept { - return misaligned_load(ptr_ + i * sizeof(element_t)); - } - - misaligned_ptr_gt(byte_t *ptr) noexcept : ptr_(ptr) { - } - misaligned_ptr_gt operator++(int) noexcept { - return misaligned_ptr_gt(ptr_ + sizeof(element_t)); - } - misaligned_ptr_gt operator--(int) noexcept { - return misaligned_ptr_gt(ptr_ - sizeof(element_t)); - } - misaligned_ptr_gt operator+(difference_type d) noexcept { - return misaligned_ptr_gt(ptr_ + d * sizeof(element_t)); - } - misaligned_ptr_gt operator-(difference_type d) noexcept { - return misaligned_ptr_gt(ptr_ - d * sizeof(element_t)); - } - - // clang-format off + using iterator_category = std::random_access_iterator_tag; + using value_type = element_t; + using difference_type = std::ptrdiff_t; + using pointer = misaligned_ptr_gt; + using reference = misaligned_ref_gt; + + reference operator*() const noexcept { return {ptr_}; } + reference operator[](std::size_t i) noexcept { return reference(ptr_ + i * sizeof(element_t)); } + value_type operator[](std::size_t i) const noexcept { + return misaligned_load(ptr_ + i * sizeof(element_t)); + } + + misaligned_ptr_gt(byte_t* ptr) noexcept : ptr_(ptr) {} + misaligned_ptr_gt operator++(int) noexcept { return misaligned_ptr_gt(ptr_ + sizeof(element_t)); } + misaligned_ptr_gt operator--(int) noexcept { return misaligned_ptr_gt(ptr_ - sizeof(element_t)); } + misaligned_ptr_gt operator+(difference_type d) noexcept { return misaligned_ptr_gt(ptr_ + d * sizeof(element_t)); } + misaligned_ptr_gt operator-(difference_type d) noexcept { return misaligned_ptr_gt(ptr_ - d * sizeof(element_t)); } + + // clang-format off misaligned_ptr_gt& operator++() noexcept { ptr_ += sizeof(element_t); return *this; } misaligned_ptr_gt& operator--() noexcept { ptr_ -= sizeof(element_t); return *this; } misaligned_ptr_gt& operator+=(difference_type d) noexcept { ptr_ += d * sizeof(element_t); return *this; } misaligned_ptr_gt& operator-=(difference_type d) noexcept { ptr_ -= d * sizeof(element_t); return *this; } - // clang-format on - - bool operator==(misaligned_ptr_gt const &other) noexcept { - return ptr_ == other.ptr_; - } - bool operator!=(misaligned_ptr_gt const &other) noexcept { - return ptr_ != other.ptr_; - } + // clang-format on + + bool operator==(misaligned_ptr_gt const& other) noexcept { return ptr_ == other.ptr_; } + bool operator!=(misaligned_ptr_gt const& other) noexcept { return ptr_ != other.ptr_; } }; /** - * @brief Non-owning memory range view, similar to `std::span`, but for C++11. - */ -template -class span_gt { - scalar_at *data_; - std::size_t size_; +* @brief Non-owning memory range view, similar to `std::span`, but for C++11. 
+*/ +template class span_gt { + scalar_at* data_; + std::size_t size_; public: - span_gt() noexcept : data_(nullptr), size_(0u) { - } - span_gt(scalar_at *begin, scalar_at *end) noexcept : data_(begin), size_(end - begin) { - } - span_gt(scalar_at *begin, std::size_t size) noexcept : data_(begin), size_(size) { - } - scalar_at *data() const noexcept { - return data_; - } - std::size_t size() const noexcept { - return size_; - } - scalar_at *begin() const noexcept { - return data_; - } - scalar_at *end() const noexcept { - return data_ + size_; - } - operator scalar_at *() const noexcept { - return data(); - } + span_gt() noexcept : data_(nullptr), size_(0u) {} + span_gt(scalar_at* begin, scalar_at* end) noexcept : data_(begin), size_(end - begin) {} + span_gt(scalar_at* begin, std::size_t size) noexcept : data_(begin), size_(size) {} + scalar_at* data() const noexcept { return data_; } + std::size_t size() const noexcept { return size_; } + scalar_at* begin() const noexcept { return data_; } + scalar_at* end() const noexcept { return data_ + size_; } + operator scalar_at*() const noexcept { return data(); } }; /** - * @brief Similar to `std::vector`, but doesn't support dynamic resizing. - * On the bright side, this can't throw exceptions. - */ -template > -class buffer_gt { - scalar_at *data_; - std::size_t size_; +* @brief Similar to `std::vector`, but doesn't support dynamic resizing. +* On the bright side, this can't throw exceptions. +*/ +template > class buffer_gt { + scalar_at* data_; + std::size_t size_; public: - buffer_gt() noexcept : data_(nullptr), size_(0u) { - } - buffer_gt(std::size_t size) noexcept : data_(allocator_at {}.allocate(size)), size_(data_ ? size : 0u) { - if (!std::is_trivially_default_constructible::value) - for (std::size_t i = 0; i != size_; ++i) - construct_at(data_ + i); - } - ~buffer_gt() noexcept { - if (!std::is_trivially_destructible::value) - for (std::size_t i = 0; i != size_; ++i) - destroy_at(data_ + i); - allocator_at {}.deallocate(data_, size_); - data_ = nullptr; - size_ = 0; - } - scalar_at *data() const noexcept { - return data_; - } - std::size_t size() const noexcept { - return size_; - } - scalar_at *begin() const noexcept { - return data_; - } - scalar_at *end() const noexcept { - return data_ + size_; - } - operator scalar_at *() const noexcept { - return data(); - } - scalar_at &operator[](std::size_t i) noexcept { - return data_[i]; - } - scalar_at const &operator[](std::size_t i) const noexcept { - return data_[i]; - } - explicit operator bool() const noexcept { - return data_; - } - scalar_at *release() noexcept { - size_ = 0; - return exchange(data_, nullptr); - } - - buffer_gt(buffer_gt const &) = delete; - buffer_gt &operator=(buffer_gt const &) = delete; - - buffer_gt(buffer_gt &&other) noexcept : data_(exchange(other.data_, nullptr)), size_(exchange(other.size_, 0)) { - } - buffer_gt &operator=(buffer_gt &&other) noexcept { - std::swap(data_, other.data_); - std::swap(size_, other.size_); - return *this; - } + buffer_gt() noexcept : data_(nullptr), size_(0u) {} + buffer_gt(std::size_t size) noexcept : data_(allocator_at{}.allocate(size)), size_(data_ ? 
size : 0u) { + if (!std::is_trivially_default_constructible::value) + for (std::size_t i = 0; i != size_; ++i) + construct_at(data_ + i); + } + ~buffer_gt() noexcept { + if (!std::is_trivially_destructible::value) + for (std::size_t i = 0; i != size_; ++i) + destroy_at(data_ + i); + allocator_at{}.deallocate(data_, size_); + data_ = nullptr; + size_ = 0; + } + scalar_at* data() const noexcept { return data_; } + std::size_t size() const noexcept { return size_; } + scalar_at* begin() const noexcept { return data_; } + scalar_at* end() const noexcept { return data_ + size_; } + operator scalar_at*() const noexcept { return data(); } + scalar_at& operator[](std::size_t i) noexcept { return data_[i]; } + scalar_at const& operator[](std::size_t i) const noexcept { return data_[i]; } + explicit operator bool() const noexcept { return data_; } + scalar_at* release() noexcept { + size_ = 0; + return exchange(data_, nullptr); + } + + buffer_gt(buffer_gt const&) = delete; + buffer_gt& operator=(buffer_gt const&) = delete; + + buffer_gt(buffer_gt&& other) noexcept : data_(exchange(other.data_, nullptr)), size_(exchange(other.size_, 0)) {} + buffer_gt& operator=(buffer_gt&& other) noexcept { + std::swap(data_, other.data_); + std::swap(size_, other.size_); + return *this; + } }; /** - * @brief A lightweight error class for handling error messages, - * which are expected to be allocated in static memory. - */ +* @brief A lightweight error class for handling error messages, +* which are expected to be allocated in static memory. +*/ class error_t { - char const *message_ {}; + char const* message_{}; public: - error_t(char const *message = nullptr) noexcept : message_(message) { - } - error_t &operator=(char const *message) noexcept { - message_ = message; - return *this; - } - - error_t(error_t const &) = delete; - error_t &operator=(error_t const &) = delete; - error_t(error_t &&other) noexcept : message_(exchange(other.message_, nullptr)) { - } - error_t &operator=(error_t &&other) noexcept { - std::swap(message_, other.message_); - return *this; - } - explicit operator bool() const noexcept { - return message_ != nullptr; - } - char const *what() const noexcept { - return message_; - } - char const *release() noexcept { - return exchange(message_, nullptr); - } + error_t(char const* message = nullptr) noexcept : message_(message) {} + error_t& operator=(char const* message) noexcept { + message_ = message; + return *this; + } + + error_t(error_t const&) = delete; + error_t& operator=(error_t const&) = delete; + error_t(error_t&& other) noexcept : message_(exchange(other.message_, nullptr)) {} + error_t& operator=(error_t&& other) noexcept { + std::swap(message_, other.message_); + return *this; + } + explicit operator bool() const noexcept { return message_ != nullptr; } + char const* what() const noexcept { return message_; } + char const* release() noexcept { return exchange(message_, nullptr); } #if defined(__cpp_exceptions) || defined(__EXCEPTIONS) - ~error_t() noexcept(false) { + ~error_t() noexcept(false) { #if defined(USEARCH_DEFINED_CPP17) - if (message_ && std::uncaught_exceptions() == 0) + if (message_ && std::uncaught_exceptions() == 0) #else - if (message_ && std::uncaught_exception() == 0) + if (message_ && std::uncaught_exception() == 0) #endif - raise(); - } - void raise() noexcept(false) { - if (message_) - throw std::runtime_error(exchange(message_, nullptr)); - } + raise(); + } + void raise() noexcept(false) { + if (message_) + throw std::runtime_error(exchange(message_, 
nullptr)); + } #else - ~error_t() noexcept { - raise(); - } - void raise() noexcept { - if (message_) - std::terminate(); - } + ~error_t() noexcept { raise(); } + void raise() noexcept { + if (message_) + std::terminate(); + } #endif }; /** - * @brief Similar to `std::expected` in C++23, wraps a statement evaluation result, - * or an error. It's used to avoid raising exception, and gracefully propagate - * the error. - * - * @tparam result_at The type of the expected result. - */ -template -struct expected_gt { - result_at result; - error_t error; - - operator result_at &() & { - error.raise(); - return result; - } - operator result_at &&() && { - error.raise(); - return std::move(result); - } - result_at const &operator*() const noexcept { - return result; - } - explicit operator bool() const noexcept { - return !error; - } - expected_gt failed(error_t message) noexcept { - error = std::move(message); - return std::move(*this); - } +* @brief Similar to `std::expected` in C++23, wraps a statement evaluation result, +* or an error. It's used to avoid raising exception, and gracefully propagate +* the error. +* +* @tparam result_at The type of the expected result. +*/ +template struct expected_gt { + result_at result; + error_t error; + + operator result_at&() & { + error.raise(); + return result; + } + operator result_at&&() && { + error.raise(); + return std::move(result); + } + result_at const& operator*() const noexcept { return result; } + explicit operator bool() const noexcept { return !error; } + expected_gt failed(error_t message) noexcept { + error = std::move(message); + return std::move(*this); + } }; /** - * @brief Light-weight bitset implementation to sync nodes updates during graph mutations. - * Extends basic functionality with @b atomic operations. - */ -template > -class bitset_gt { - using allocator_t = allocator_at; - using byte_t = typename allocator_t::value_type; - static_assert(sizeof(byte_t) == 1, "Allocator must allocate separate addressable bytes"); - - using compressed_slot_t = unsigned long; - - static constexpr std::size_t bits_per_slot() { - return sizeof(compressed_slot_t) * CHAR_BIT; - } - static constexpr compressed_slot_t bits_mask() { - return sizeof(compressed_slot_t) * CHAR_BIT - 1; - } - static constexpr std::size_t slots(std::size_t bits) { - return divide_round_up(bits); - } - - compressed_slot_t *slots_ {}; - /// @brief Number of slots. - std::size_t count_ {}; +* @brief Light-weight bitset implementation to sync nodes updates during graph mutations. +* Extends basic functionality with @b atomic operations. +*/ +template > class bitset_gt { + using allocator_t = allocator_at; + using byte_t = typename allocator_t::value_type; + static_assert(sizeof(byte_t) == 1, "Allocator must allocate separate addressable bytes"); + + using compressed_slot_t = unsigned long; + + static constexpr std::size_t bits_per_slot() { return sizeof(compressed_slot_t) * CHAR_BIT; } + static constexpr compressed_slot_t bits_mask() { return sizeof(compressed_slot_t) * CHAR_BIT - 1; } + static constexpr std::size_t slots(std::size_t bits) { return divide_round_up(bits); } + + compressed_slot_t* slots_{}; + /// @brief Number of slots. 
+ std::size_t count_{}; public: - bitset_gt() noexcept { - } - ~bitset_gt() noexcept { - reset(); - } - - explicit operator bool() const noexcept { - return slots_; - } - void clear() noexcept { - if (slots_) - std::memset(slots_, 0, count_ * sizeof(compressed_slot_t)); - } - - void reset() noexcept { - if (slots_) - allocator_t {}.deallocate((byte_t *)slots_, count_ * sizeof(compressed_slot_t)); - slots_ = nullptr; - count_ = 0; - } - - bitset_gt(std::size_t capacity) noexcept - : slots_((compressed_slot_t *)allocator_t {}.allocate(slots(capacity) * sizeof(compressed_slot_t))), - count_(slots_ ? slots(capacity) : 0u) { - clear(); - } - - bitset_gt(bitset_gt &&other) noexcept { - slots_ = exchange(other.slots_, nullptr); - count_ = exchange(other.count_, 0); - } - - bitset_gt &operator=(bitset_gt &&other) noexcept { - std::swap(slots_, other.slots_); - std::swap(count_, other.count_); - return *this; - } - - bitset_gt(bitset_gt const &) = delete; - bitset_gt &operator=(bitset_gt const &) = delete; - - inline bool test(std::size_t i) const noexcept { - return slots_[i / bits_per_slot()] & (1ul << (i & bits_mask())); - } - inline bool set(std::size_t i) noexcept { - compressed_slot_t &slot = slots_[i / bits_per_slot()]; - compressed_slot_t mask {1ul << (i & bits_mask())}; - bool value = slot & mask; - slot |= mask; - return value; - } + bitset_gt() noexcept {} + ~bitset_gt() noexcept { reset(); } + + explicit operator bool() const noexcept { return slots_; } + void clear() noexcept { + if (slots_) + std::memset(slots_, 0, count_ * sizeof(compressed_slot_t)); + } + + void reset() noexcept { + if (slots_) + allocator_t{}.deallocate((byte_t*)slots_, count_ * sizeof(compressed_slot_t)); + slots_ = nullptr; + count_ = 0; + } + + bitset_gt(std::size_t capacity) noexcept + : slots_((compressed_slot_t*)allocator_t{}.allocate(slots(capacity) * sizeof(compressed_slot_t))), + count_(slots_ ? 
slots(capacity) : 0u) { + clear(); + } + + bitset_gt(bitset_gt&& other) noexcept { + slots_ = exchange(other.slots_, nullptr); + count_ = exchange(other.count_, 0); + } + + bitset_gt& operator=(bitset_gt&& other) noexcept { + std::swap(slots_, other.slots_); + std::swap(count_, other.count_); + return *this; + } + + bitset_gt(bitset_gt const&) = delete; + bitset_gt& operator=(bitset_gt const&) = delete; + + inline bool test(std::size_t i) const noexcept { return slots_[i / bits_per_slot()] & (1ul << (i & bits_mask())); } + inline bool set(std::size_t i) noexcept { + compressed_slot_t& slot = slots_[i / bits_per_slot()]; + compressed_slot_t mask{1ul << (i & bits_mask())}; + bool value = slot & mask; + slot |= mask; + return value; + } #if defined(USEARCH_DEFINED_WINDOWS) - inline bool atomic_set(std::size_t i) noexcept { - compressed_slot_t mask {1ul << (i & bits_mask())}; - return InterlockedOr((long volatile *)&slots_[i / bits_per_slot()], mask) & mask; - } + inline bool atomic_set(std::size_t i) noexcept { + compressed_slot_t mask{1ul << (i & bits_mask())}; + return InterlockedOr((long volatile*)&slots_[i / bits_per_slot()], mask) & mask; + } - inline void atomic_reset(std::size_t i) noexcept { - compressed_slot_t mask {1ul << (i & bits_mask())}; - InterlockedAnd((long volatile *)&slots_[i / bits_per_slot()], ~mask); - } + inline void atomic_reset(std::size_t i) noexcept { + compressed_slot_t mask{1ul << (i & bits_mask())}; + InterlockedAnd((long volatile*)&slots_[i / bits_per_slot()], ~mask); + } #else - inline bool atomic_set(std::size_t i) noexcept { - compressed_slot_t mask {1ul << (i & bits_mask())}; - return __atomic_fetch_or(&slots_[i / bits_per_slot()], mask, __ATOMIC_ACQUIRE) & mask; - } + inline bool atomic_set(std::size_t i) noexcept { + compressed_slot_t mask{1ul << (i & bits_mask())}; + return __atomic_fetch_or(&slots_[i / bits_per_slot()], mask, __ATOMIC_ACQUIRE) & mask; + } - inline void atomic_reset(std::size_t i) noexcept { - compressed_slot_t mask {1ul << (i & bits_mask())}; - __atomic_fetch_and(&slots_[i / bits_per_slot()], ~mask, __ATOMIC_RELEASE); - } + inline void atomic_reset(std::size_t i) noexcept { + compressed_slot_t mask{1ul << (i & bits_mask())}; + __atomic_fetch_and(&slots_[i / bits_per_slot()], ~mask, __ATOMIC_RELEASE); + } #endif - class lock_t { - bitset_gt &bitset_; - std::size_t bit_offset_; - - public: - inline ~lock_t() noexcept { - bitset_.atomic_reset(bit_offset_); - } - inline lock_t(bitset_gt &bitset, std::size_t bit_offset) noexcept : bitset_(bitset), bit_offset_(bit_offset) { - while (bitset_.atomic_set(bit_offset_)) - ; - } - }; - - inline lock_t lock(std::size_t i) noexcept { - return {*this, i}; - } + class lock_t { + bitset_gt& bitset_; + std::size_t bit_offset_; + + public: + inline ~lock_t() noexcept { bitset_.atomic_reset(bit_offset_); } + inline lock_t(bitset_gt& bitset, std::size_t bit_offset) noexcept : bitset_(bitset), bit_offset_(bit_offset) { + while (bitset_.atomic_set(bit_offset_)) + ; + } + }; + + inline lock_t lock(std::size_t i) noexcept { return {*this, i}; } }; using bitset_t = bitset_gt<>; /** - * @brief Similar to `std::priority_queue`, but allows raw access to underlying - * memory, in case you want to shuffle it or sort. Good for collections - * from 100s to 10'000s elements. - */ +* @brief Similar to `std::priority_queue`, but allows raw access to underlying +* memory, in case you want to shuffle it or sort. Good for collections +* from 100s to 10'000s elements. +*/ template , // is needed before C++14. 
- typename allocator_at = std::allocator> // + typename comparator_at = std::less, // is needed before C++14. + typename allocator_at = std::allocator> // class max_heap_gt { public: - using element_t = element_at; - using comparator_t = comparator_at; - using allocator_t = allocator_at; + using element_t = element_at; + using comparator_t = comparator_at; + using allocator_t = allocator_at; - using value_type = element_t; + using value_type = element_t; - static_assert(std::is_trivially_destructible(), "This heap is designed for trivial structs"); - static_assert(std::is_trivially_copy_constructible(), "This heap is designed for trivial structs"); + static_assert(std::is_trivially_destructible(), "This heap is designed for trivial structs"); + static_assert(std::is_trivially_copy_constructible(), "This heap is designed for trivial structs"); private: - element_t *elements_; - std::size_t size_; - std::size_t capacity_; + element_t* elements_; + std::size_t size_; + std::size_t capacity_; public: - max_heap_gt() noexcept : elements_(nullptr), size_(0), capacity_(0) { - } - - max_heap_gt(max_heap_gt &&other) noexcept - : elements_(exchange(other.elements_, nullptr)), size_(exchange(other.size_, 0)), - capacity_(exchange(other.capacity_, 0)) { - } - - max_heap_gt &operator=(max_heap_gt &&other) noexcept { - std::swap(elements_, other.elements_); - std::swap(size_, other.size_); - std::swap(capacity_, other.capacity_); - return *this; - } - - max_heap_gt(max_heap_gt const &) = delete; - max_heap_gt &operator=(max_heap_gt const &) = delete; - - ~max_heap_gt() noexcept { - reset(); - } - - void reset() noexcept { - if (elements_) - allocator_t {}.deallocate(elements_, capacity_); - elements_ = nullptr; - capacity_ = 0; - size_ = 0; - } - - inline bool empty() const noexcept { - return !size_; - } - inline std::size_t size() const noexcept { - return size_; - } - inline std::size_t capacity() const noexcept { - return capacity_; - } - - /// @brief Selects the largest element in the heap. - /// @return Reference to the stored element. - inline element_t const &top() const noexcept { - return elements_[0]; - } - inline void clear() noexcept { - size_ = 0; - } - - bool reserve(std::size_t new_capacity) noexcept { - if (new_capacity < capacity_) - return true; - - new_capacity = ceil2(new_capacity); - new_capacity = (std::max)(new_capacity, (std::max)(capacity_ * 2u, 16u)); - auto allocator = allocator_t {}; - auto new_elements = allocator.allocate(new_capacity); - if (!new_elements) - return false; - - if (elements_) { - std::memcpy(new_elements, elements_, size_ * sizeof(element_t)); - allocator.deallocate(elements_, capacity_); - } - elements_ = new_elements; - capacity_ = new_capacity; - return new_elements; - } - - bool insert(element_t &&element) noexcept { - if (!reserve(size_ + 1)) - return false; - - insert_reserved(std::move(element)); - return true; - } - - inline void insert_reserved(element_t &&element) noexcept { - new (&elements_[size_]) element_t(element); - size_++; - shift_up(size_ - 1); - } - - inline element_t pop() noexcept { - element_t result = top(); - std::swap(elements_[0], elements_[size_ - 1]); - size_--; - elements_[size_].~element_t(); - shift_down(0); - return result; - } - - /** @brief Invalidates the "max-heap" property, transforming into ascending range. 
*/ - inline void sort_ascending() noexcept { - std::sort_heap(elements_, elements_ + size_, &less); - } - inline void shrink(std::size_t n) noexcept { - size_ = (std::min)(n, size_); - } - - inline element_t *data() noexcept { - return elements_; - } - inline element_t const *data() const noexcept { - return elements_; - } + max_heap_gt() noexcept : elements_(nullptr), size_(0), capacity_(0) {} + + max_heap_gt(max_heap_gt&& other) noexcept + : elements_(exchange(other.elements_, nullptr)), size_(exchange(other.size_, 0)), + capacity_(exchange(other.capacity_, 0)) {} + + max_heap_gt& operator=(max_heap_gt&& other) noexcept { + std::swap(elements_, other.elements_); + std::swap(size_, other.size_); + std::swap(capacity_, other.capacity_); + return *this; + } + + max_heap_gt(max_heap_gt const&) = delete; + max_heap_gt& operator=(max_heap_gt const&) = delete; + + ~max_heap_gt() noexcept { reset(); } + + void reset() noexcept { + if (elements_) + allocator_t{}.deallocate(elements_, capacity_); + elements_ = nullptr; + capacity_ = 0; + size_ = 0; + } + + inline bool empty() const noexcept { return !size_; } + inline std::size_t size() const noexcept { return size_; } + inline std::size_t capacity() const noexcept { return capacity_; } + + /// @brief Selects the largest element in the heap. + /// @return Reference to the stored element. + inline element_t const& top() const noexcept { return elements_[0]; } + inline void clear() noexcept { size_ = 0; } + + bool reserve(std::size_t new_capacity) noexcept { + if (new_capacity < capacity_) + return true; + + new_capacity = ceil2(new_capacity); + new_capacity = (std::max)(new_capacity, (std::max)(capacity_ * 2u, 16u)); + auto allocator = allocator_t{}; + auto new_elements = allocator.allocate(new_capacity); + if (!new_elements) + return false; + + if (elements_) { + std::memcpy(new_elements, elements_, size_ * sizeof(element_t)); + allocator.deallocate(elements_, capacity_); + } + elements_ = new_elements; + capacity_ = new_capacity; + return new_elements; + } + + bool insert(element_t&& element) noexcept { + if (!reserve(size_ + 1)) + return false; + + insert_reserved(std::move(element)); + return true; + } + + inline void insert_reserved(element_t&& element) noexcept { + new (&elements_[size_]) element_t(element); + size_++; + shift_up(size_ - 1); + } + + inline element_t pop() noexcept { + element_t result = top(); + std::swap(elements_[0], elements_[size_ - 1]); + size_--; + elements_[size_].~element_t(); + shift_down(0); + return result; + } + + /** @brief Invalidates the "max-heap" property, transforming into ascending range. 
*/ + inline void sort_ascending() noexcept { std::sort_heap(elements_, elements_ + size_, &less); } + inline void shrink(std::size_t n) noexcept { size_ = (std::min)(n, size_); } + + inline element_t* data() noexcept { return elements_; } + inline element_t const* data() const noexcept { return elements_; } private: - inline std::size_t parent_idx(std::size_t i) const noexcept { - return (i - 1u) / 2u; - } - inline std::size_t left_child_idx(std::size_t i) const noexcept { - return (i * 2u) + 1u; - } - inline std::size_t right_child_idx(std::size_t i) const noexcept { - return (i * 2u) + 2u; - } - static bool less(element_t const &a, element_t const &b) noexcept { - return comparator_t {}(a, b); - } - - void shift_up(std::size_t i) noexcept { - for (; i && less(elements_[parent_idx(i)], elements_[i]); i = parent_idx(i)) - std::swap(elements_[parent_idx(i)], elements_[i]); - } - - void shift_down(std::size_t i) noexcept { - std::size_t max_idx = i; - - std::size_t left = left_child_idx(i); - if (left < size_ && less(elements_[max_idx], elements_[left])) - max_idx = left; - - std::size_t right = right_child_idx(i); - if (right < size_ && less(elements_[max_idx], elements_[right])) - max_idx = right; - - if (i != max_idx) { - std::swap(elements_[i], elements_[max_idx]); - shift_down(max_idx); - } - } + inline std::size_t parent_idx(std::size_t i) const noexcept { return (i - 1u) / 2u; } + inline std::size_t left_child_idx(std::size_t i) const noexcept { return (i * 2u) + 1u; } + inline std::size_t right_child_idx(std::size_t i) const noexcept { return (i * 2u) + 2u; } + static bool less(element_t const& a, element_t const& b) noexcept { return comparator_t{}(a, b); } + + void shift_up(std::size_t i) noexcept { + for (; i && less(elements_[parent_idx(i)], elements_[i]); i = parent_idx(i)) + std::swap(elements_[parent_idx(i)], elements_[i]); + } + + void shift_down(std::size_t i) noexcept { + std::size_t max_idx = i; + + std::size_t left = left_child_idx(i); + if (left < size_ && less(elements_[max_idx], elements_[left])) + max_idx = left; + + std::size_t right = right_child_idx(i); + if (right < size_ && less(elements_[max_idx], elements_[right])) + max_idx = right; + + if (i != max_idx) { + std::swap(elements_[i], elements_[max_idx]); + shift_down(max_idx); + } + } }; /** - * @brief Similar to `std::priority_queue`, but allows raw access to underlying - * memory and always keeps the data sorted. Ideal for small collections - * under 128 elements. - */ +* @brief Similar to `std::priority_queue`, but allows raw access to underlying +* memory and always keeps the data sorted. Ideal for small collections +* under 128 elements. +*/ template , // is needed before C++14. - typename allocator_at = std::allocator> // + typename comparator_at = std::less, // is needed before C++14. 
+ typename allocator_at = std::allocator> // class sorted_buffer_gt { public: - using element_t = element_at; - using comparator_t = comparator_at; - using allocator_t = allocator_at; + using element_t = element_at; + using comparator_t = comparator_at; + using allocator_t = allocator_at; - static_assert(std::is_trivially_destructible(), "This heap is designed for trivial structs"); - static_assert(std::is_trivially_copy_constructible(), "This heap is designed for trivial structs"); + static_assert(std::is_trivially_destructible(), "This heap is designed for trivial structs"); + static_assert(std::is_trivially_copy_constructible(), "This heap is designed for trivial structs"); - using value_type = element_t; + using value_type = element_t; private: - element_t *elements_; - std::size_t size_; - std::size_t capacity_; + element_t* elements_; + std::size_t size_; + std::size_t capacity_; public: - sorted_buffer_gt() noexcept : elements_(nullptr), size_(0), capacity_(0) { - } - - sorted_buffer_gt(sorted_buffer_gt &&other) noexcept - : elements_(exchange(other.elements_, nullptr)), size_(exchange(other.size_, 0)), - capacity_(exchange(other.capacity_, 0)) { - } - - sorted_buffer_gt &operator=(sorted_buffer_gt &&other) noexcept { - std::swap(elements_, other.elements_); - std::swap(size_, other.size_); - std::swap(capacity_, other.capacity_); - return *this; - } - - sorted_buffer_gt(sorted_buffer_gt const &) = delete; - sorted_buffer_gt &operator=(sorted_buffer_gt const &) = delete; - - ~sorted_buffer_gt() noexcept { - reset(); - } - - void reset() noexcept { - if (elements_) - allocator_t {}.deallocate(elements_, capacity_); - elements_ = nullptr; - capacity_ = 0; - size_ = 0; - } - - inline bool empty() const noexcept { - return !size_; - } - inline std::size_t size() const noexcept { - return size_; - } - inline std::size_t capacity() const noexcept { - return capacity_; - } - inline element_t const &top() const noexcept { - return elements_[size_ - 1]; - } - inline void clear() noexcept { - size_ = 0; - } - - bool reserve(std::size_t new_capacity) noexcept { - if (new_capacity < capacity_) - return true; - - new_capacity = ceil2(new_capacity); - new_capacity = (std::max)(new_capacity, (std::max)(capacity_ * 2u, 16u)); - auto allocator = allocator_t {}; - auto new_elements = allocator.allocate(new_capacity); - if (!new_elements) - return false; - - if (size_) - std::memcpy(new_elements, elements_, size_ * sizeof(element_t)); - if (elements_) - allocator.deallocate(elements_, capacity_); - - elements_ = new_elements; - capacity_ = new_capacity; - return true; - } - - inline void insert_reserved(element_t &&element) noexcept { - std::size_t slot = size_ ? std::lower_bound(elements_, elements_ + size_, element, &less) - elements_ : 0; - std::size_t to_move = size_ - slot; - element_t *source = elements_ + size_ - 1; - for (; to_move; --to_move, --source) - source[1] = source[0]; - elements_[slot] = element; - size_++; - } - - /** - * @return `true` if the entry was added, `false` if it wasn't relevant enough. - */ - inline bool insert(element_t &&element, std::size_t limit) noexcept { - std::size_t slot = size_ ? 
std::lower_bound(elements_, elements_ + size_, element, &less) - elements_ : 0; - if (slot == limit) - return false; - std::size_t to_move = size_ - slot - (size_ == limit); - element_t *source = elements_ + size_ - 1 - (size_ == limit); - for (; to_move; --to_move, --source) - source[1] = source[0]; - elements_[slot] = element; - size_ += size_ != limit; - return true; - } - - inline element_t pop() noexcept { - size_--; - element_t result = elements_[size_]; - elements_[size_].~element_t(); - return result; - } - - void sort_ascending() noexcept { - } - inline void shrink(std::size_t n) noexcept { - size_ = (std::min)(n, size_); - } - - inline element_t *data() noexcept { - return elements_; - } - inline element_t const *data() const noexcept { - return elements_; - } + sorted_buffer_gt() noexcept : elements_(nullptr), size_(0), capacity_(0) {} + + sorted_buffer_gt(sorted_buffer_gt&& other) noexcept + : elements_(exchange(other.elements_, nullptr)), size_(exchange(other.size_, 0)), + capacity_(exchange(other.capacity_, 0)) {} + + sorted_buffer_gt& operator=(sorted_buffer_gt&& other) noexcept { + std::swap(elements_, other.elements_); + std::swap(size_, other.size_); + std::swap(capacity_, other.capacity_); + return *this; + } + + sorted_buffer_gt(sorted_buffer_gt const&) = delete; + sorted_buffer_gt& operator=(sorted_buffer_gt const&) = delete; + + ~sorted_buffer_gt() noexcept { reset(); } + + void reset() noexcept { + if (elements_) + allocator_t{}.deallocate(elements_, capacity_); + elements_ = nullptr; + capacity_ = 0; + size_ = 0; + } + + inline bool empty() const noexcept { return !size_; } + inline std::size_t size() const noexcept { return size_; } + inline std::size_t capacity() const noexcept { return capacity_; } + inline element_t const& top() const noexcept { return elements_[size_ - 1]; } + inline void clear() noexcept { size_ = 0; } + + bool reserve(std::size_t new_capacity) noexcept { + if (new_capacity < capacity_) + return true; + + new_capacity = ceil2(new_capacity); + new_capacity = (std::max)(new_capacity, (std::max)(capacity_ * 2u, 16u)); + auto allocator = allocator_t{}; + auto new_elements = allocator.allocate(new_capacity); + if (!new_elements) + return false; + + if (size_) + std::memcpy(new_elements, elements_, size_ * sizeof(element_t)); + if (elements_) + allocator.deallocate(elements_, capacity_); + + elements_ = new_elements; + capacity_ = new_capacity; + return true; + } + + inline void insert_reserved(element_t&& element) noexcept { + std::size_t slot = size_ ? std::lower_bound(elements_, elements_ + size_, element, &less) - elements_ : 0; + std::size_t to_move = size_ - slot; + element_t* source = elements_ + size_ - 1; + for (; to_move; --to_move, --source) + source[1] = source[0]; + elements_[slot] = element; + size_++; + } + + /** + * @return `true` if the entry was added, `false` if it wasn't relevant enough. + */ + inline bool insert(element_t&& element, std::size_t limit) noexcept { + std::size_t slot = size_ ? 
std::lower_bound(elements_, elements_ + size_, element, &less) - elements_ : 0; + if (slot == limit) + return false; + std::size_t to_move = size_ - slot - (size_ == limit); + element_t* source = elements_ + size_ - 1 - (size_ == limit); + for (; to_move; --to_move, --source) + source[1] = source[0]; + elements_[slot] = element; + size_ += size_ != limit; + return true; + } + + inline element_t pop() noexcept { + size_--; + element_t result = elements_[size_]; + elements_[size_].~element_t(); + return result; + } + + void sort_ascending() noexcept {} + inline void shrink(std::size_t n) noexcept { size_ = (std::min)(n, size_); } + + inline element_t* data() noexcept { return elements_; } + inline element_t const* data() const noexcept { return elements_; } private: - static bool less(element_t const &a, element_t const &b) noexcept { - return comparator_t {}(a, b); - } + static bool less(element_t const& a, element_t const& b) noexcept { return comparator_t{}(a, b); } }; #if defined(USEARCH_DEFINED_WINDOWS) @@ -909,63 +758,53 @@ class sorted_buffer_gt { #endif /** - * @brief Five-byte integer type to address node clouds with over 4B entries. - * - * @note Avoid usage in 32bit environment - */ +* @brief Five-byte integer type to address node clouds with over 4B entries. +* +* @note Avoid usage in 32bit environment +*/ class usearch_pack_m uint40_t { - unsigned char octets[5]; + unsigned char octets[5]; - inline uint40_t &broadcast(unsigned char c) { - std::memset(octets, c, 5); - return *this; - } + inline uint40_t& broadcast(unsigned char c) { + std::memset(octets, c, 5); + return *this; + } public: - inline uint40_t() noexcept { - broadcast(0); - } - inline uint40_t(std::uint32_t n) noexcept { - std::memcpy(&octets[1], &n, 4); - } + inline uint40_t() noexcept { broadcast(0); } + inline uint40_t(std::uint32_t n) noexcept { std::memcpy(&octets[1], &n, 4); } #ifdef USEARCH_64BIT_ENV - inline uint40_t(std::uint64_t n) noexcept { - std::memcpy(octets, &n, 5); - } + inline uint40_t(std::uint64_t n) noexcept { std::memcpy(octets, &n, 5); } #endif - uint40_t(uint40_t &&) = default; - uint40_t(uint40_t const &) = default; - uint40_t &operator=(uint40_t &&) = default; - uint40_t &operator=(uint40_t const &) = default; + uint40_t(uint40_t&&) = default; + uint40_t(uint40_t const&) = default; + uint40_t& operator=(uint40_t&&) = default; + uint40_t& operator=(uint40_t const&) = default; #if defined(USEARCH_DEFINED_CLANG) && defined(USEARCH_DEFINED_APPLE) - inline uint40_t(std::size_t n) noexcept { + inline uint40_t(std::size_t n) noexcept { #ifdef USEARCH_64BIT_ENV - std::memcpy(octets, &n, 5); + std::memcpy(octets, &n, 5); #else - std::memcpy(octets, &n, 4); + std::memcpy(octets, &n, 4); #endif - } + } #endif - inline operator std::size_t() const noexcept { - std::size_t result = 0; + inline operator std::size_t() const noexcept { + std::size_t result = 0; #ifdef USEARCH_64BIT_ENV - std::memcpy(&result, octets, 5); + std::memcpy(&result, octets, 5); #else - std::memcpy(&result, octets + 1, 4); + std::memcpy(&result, octets + 1, 4); #endif - return result; - } - - inline static uint40_t max() noexcept { - return uint40_t {}.broadcast(0xFF); - } - inline static uint40_t min() noexcept { - return uint40_t {}.broadcast(0); - } + return result; + } + + inline static uint40_t max() noexcept { return uint40_t{}.broadcast(0xFF); } + inline static uint40_t min() noexcept { return uint40_t{}.broadcast(0); } }; #if defined(USEARCH_DEFINED_WINDOWS) @@ -980,900 +819,785 @@ template ::value && 
!std::is_same::value>::type* = nullptr> key_at default_free_value() { return key_at(); } // clang-format on -template -struct hash_gt { - std::size_t operator()(element_at const &element) const noexcept { - return std::hash {}(element); - } +template struct hash_gt { + std::size_t operator()(element_at const& element) const noexcept { return std::hash{}(element); } }; -template <> -struct hash_gt { - std::size_t operator()(uint40_t const &element) const noexcept { - return std::hash {}(element); - } +template <> struct hash_gt { + std::size_t operator()(uint40_t const& element) const noexcept { return std::hash{}(element); } }; /** - * @brief Minimalistic hash-set implementation to track visited nodes during graph traversal. - * - * It doesn't support deletion of separate objects, but supports `clear`-ing all at once. - * It expects `reserve` to be called ahead of all insertions, so no resizes are needed. - * It also assumes `0xFF...FF` slots to be unused, to simplify the design. - * It uses linear probing, the number of slots is always a power of two, and it uses linear-probing - * in case of bucket collisions. - */ +* @brief Minimalistic hash-set implementation to track visited nodes during graph traversal. +* +* It doesn't support deletion of separate objects, but supports `clear`-ing all at once. +* It expects `reserve` to be called ahead of all insertions, so no resizes are needed. +* It also assumes `0xFF...FF` slots to be unused, to simplify the design. +* It uses linear probing, the number of slots is always a power of two, and it uses linear-probing +* in case of bucket collisions. +*/ template , typename allocator_at = std::allocator> class growing_hash_set_gt { - using element_t = element_at; - using hasher_t = hasher_at; + using element_t = element_at; + using hasher_t = hasher_at; - using allocator_t = allocator_at; - using byte_t = typename allocator_t::value_type; - static_assert(sizeof(byte_t) == 1, "Allocator must allocate separate addressable bytes"); + using allocator_t = allocator_at; + using byte_t = typename allocator_t::value_type; + static_assert(sizeof(byte_t) == 1, "Allocator must allocate separate addressable bytes"); - element_t *slots_ {}; - /// @brief Number of slots. - std::size_t capacity_ {}; - /// @brief Number of populated. - std::size_t count_ {}; - hasher_t hasher_ {}; + element_t* slots_{}; + /// @brief Number of slots. + std::size_t capacity_{}; + /// @brief Number of populated. + std::size_t count_{}; + hasher_t hasher_{}; public: - growing_hash_set_gt() noexcept { - } - ~growing_hash_set_gt() noexcept { - reset(); - } - - explicit operator bool() const noexcept { - return slots_; - } - std::size_t size() const noexcept { - return count_; - } - - void clear() noexcept { - if (slots_) - std::memset((void *)slots_, 0xFF, capacity_ * sizeof(element_t)); - count_ = 0; - } - - void reset() noexcept { - if (slots_) - allocator_t {}.deallocate((byte_t *)slots_, capacity_ * sizeof(element_t)); - slots_ = nullptr; - capacity_ = 0; - count_ = 0; - } - - growing_hash_set_gt(std::size_t capacity) noexcept - : slots_((element_t *)allocator_t {}.allocate(ceil2(capacity) * sizeof(element_t))), - capacity_(slots_ ? 
ceil2(capacity) : 0u), count_(0u) { - clear(); - } - - growing_hash_set_gt(growing_hash_set_gt &&other) noexcept { - slots_ = exchange(other.slots_, nullptr); - capacity_ = exchange(other.capacity_, 0); - count_ = exchange(other.count_, 0); - } - - growing_hash_set_gt &operator=(growing_hash_set_gt &&other) noexcept { - std::swap(slots_, other.slots_); - std::swap(capacity_, other.capacity_); - std::swap(count_, other.count_); - return *this; - } - - growing_hash_set_gt(growing_hash_set_gt const &) = delete; - growing_hash_set_gt &operator=(growing_hash_set_gt const &) = delete; - - inline bool test(element_t const &elem) const noexcept { - std::size_t index = hasher_(elem) & (capacity_ - 1); - while (slots_[index] != default_free_value()) { - if (slots_[index] == elem) - return true; - - index = (index + 1) & (capacity_ - 1); - } - return false; - } - - /** - * - * @return Similar to `bitset_gt`, returns the previous value. - */ - inline bool set(element_t const &elem) noexcept { - std::size_t index = hasher_(elem) & (capacity_ - 1); - while (slots_[index] != default_free_value()) { - // Already exists - if (slots_[index] == elem) - return true; - - index = (index + 1) & (capacity_ - 1); - } - slots_[index] = elem; - ++count_; - return false; - } - - bool reserve(std::size_t new_capacity) noexcept { - new_capacity = (new_capacity * 5u) / 3u; - if (new_capacity <= capacity_) - return true; - - new_capacity = ceil2(new_capacity); - element_t *new_slots = (element_t *)allocator_t {}.allocate(new_capacity * sizeof(element_t)); - if (!new_slots) - return false; - - std::memset((void *)new_slots, 0xFF, new_capacity * sizeof(element_t)); - std::size_t new_count = count_; - if (count_) { - for (std::size_t old_index = 0; old_index != capacity_; ++old_index) { - if (slots_[old_index] == default_free_value()) - continue; - - std::size_t new_index = hasher_(slots_[old_index]) & (new_capacity - 1); - while (new_slots[new_index] != default_free_value()) - new_index = (new_index + 1) & (new_capacity - 1); - new_slots[new_index] = slots_[old_index]; - } - } - - reset(); - slots_ = new_slots; - capacity_ = new_capacity; - count_ = new_count; - return true; - } + growing_hash_set_gt() noexcept {} + ~growing_hash_set_gt() noexcept { reset(); } + + explicit operator bool() const noexcept { return slots_; } + std::size_t size() const noexcept { return count_; } + + void clear() noexcept { + if (slots_) + std::memset((void*)slots_, 0xFF, capacity_ * sizeof(element_t)); + count_ = 0; + } + + void reset() noexcept { + if (slots_) + allocator_t{}.deallocate((byte_t*)slots_, capacity_ * sizeof(element_t)); + slots_ = nullptr; + capacity_ = 0; + count_ = 0; + } + + growing_hash_set_gt(std::size_t capacity) noexcept + : slots_((element_t*)allocator_t{}.allocate(ceil2(capacity) * sizeof(element_t))), + capacity_(slots_ ? 
ceil2(capacity) : 0u), count_(0u) { + clear(); + } + + growing_hash_set_gt(growing_hash_set_gt&& other) noexcept { + slots_ = exchange(other.slots_, nullptr); + capacity_ = exchange(other.capacity_, 0); + count_ = exchange(other.count_, 0); + } + + growing_hash_set_gt& operator=(growing_hash_set_gt&& other) noexcept { + std::swap(slots_, other.slots_); + std::swap(capacity_, other.capacity_); + std::swap(count_, other.count_); + return *this; + } + + growing_hash_set_gt(growing_hash_set_gt const&) = delete; + growing_hash_set_gt& operator=(growing_hash_set_gt const&) = delete; + + inline bool test(element_t const& elem) const noexcept { + std::size_t index = hasher_(elem) & (capacity_ - 1); + while (slots_[index] != default_free_value()) { + if (slots_[index] == elem) + return true; + + index = (index + 1) & (capacity_ - 1); + } + return false; + } + + /** + * + * @return Similar to `bitset_gt`, returns the previous value. + */ + inline bool set(element_t const& elem) noexcept { + std::size_t index = hasher_(elem) & (capacity_ - 1); + while (slots_[index] != default_free_value()) { + // Already exists + if (slots_[index] == elem) + return true; + + index = (index + 1) & (capacity_ - 1); + } + slots_[index] = elem; + ++count_; + return false; + } + + bool reserve(std::size_t new_capacity) noexcept { + new_capacity = (new_capacity * 5u) / 3u; + if (new_capacity <= capacity_) + return true; + + new_capacity = ceil2(new_capacity); + element_t* new_slots = (element_t*)allocator_t{}.allocate(new_capacity * sizeof(element_t)); + if (!new_slots) + return false; + + std::memset((void*)new_slots, 0xFF, new_capacity * sizeof(element_t)); + std::size_t new_count = count_; + if (count_) { + for (std::size_t old_index = 0; old_index != capacity_; ++old_index) { + if (slots_[old_index] == default_free_value()) + continue; + + std::size_t new_index = hasher_(slots_[old_index]) & (new_capacity - 1); + while (new_slots[new_index] != default_free_value()) + new_index = (new_index + 1) & (new_capacity - 1); + new_slots[new_index] = slots_[old_index]; + } + } + + reset(); + slots_ = new_slots; + capacity_ = new_capacity; + count_ = new_count; + return true; + } }; /** - * @brief Basic single-threaded @b ring class, used for all kinds of task queues. - */ +* @brief Basic single-threaded @b ring class, used for all kinds of task queues. 
+*/ template > // class ring_gt { public: - using element_t = element_at; - using allocator_t = allocator_at; + using element_t = element_at; + using allocator_t = allocator_at; - static_assert(std::is_trivially_destructible(), "This heap is designed for trivial structs"); - static_assert(std::is_trivially_copy_constructible(), "This heap is designed for trivial structs"); + static_assert(std::is_trivially_destructible(), "This heap is designed for trivial structs"); + static_assert(std::is_trivially_copy_constructible(), "This heap is designed for trivial structs"); - using value_type = element_t; + using value_type = element_t; private: - element_t *elements_ {}; - std::size_t capacity_ {}; - std::size_t head_ {}; - std::size_t tail_ {}; - bool empty_ {true}; - allocator_t allocator_ {}; + element_t* elements_{}; + std::size_t capacity_{}; + std::size_t head_{}; + std::size_t tail_{}; + bool empty_{true}; + allocator_t allocator_{}; public: - explicit ring_gt(allocator_t const &alloc = allocator_t()) noexcept : allocator_(alloc) { - } - - ring_gt(ring_gt const &) = delete; - ring_gt &operator=(ring_gt const &) = delete; - - ring_gt(ring_gt &&other) noexcept { - swap(other); - } - ring_gt &operator=(ring_gt &&other) noexcept { - swap(other); - return *this; - } - - void swap(ring_gt &other) noexcept { - std::swap(elements_, other.elements_); - std::swap(capacity_, other.capacity_); - std::swap(head_, other.head_); - std::swap(tail_, other.tail_); - std::swap(empty_, other.empty_); - std::swap(allocator_, other.allocator_); - } - - ~ring_gt() noexcept { - reset(); - } - - bool empty() const noexcept { - return empty_; - } - size_t capacity() const noexcept { - return capacity_; - } - size_t size() const noexcept { - if (empty_) - return 0; - else if (head_ >= tail_) - return head_ - tail_; - else - return capacity_ - (tail_ - head_); - } - - void clear() noexcept { - head_ = 0; - tail_ = 0; - empty_ = true; - } - - void reset() noexcept { - if (elements_) - allocator_.deallocate(elements_, capacity_); - elements_ = nullptr; - capacity_ = 0; - head_ = 0; - tail_ = 0; - empty_ = true; - } - - bool reserve(std::size_t n) noexcept { - if (n < size()) - return false; // prevent data loss - if (n <= capacity()) - return true; - n = (std::max)(ceil2(n), 64u); - element_t *elements = allocator_.allocate(n); - if (!elements) - return false; - - std::size_t i = 0; - while (try_pop(elements[i])) - i++; - - reset(); - elements_ = elements; - capacity_ = n; - head_ = i; - tail_ = 0; - empty_ = (i == 0); - return true; - } - - void push(element_t const &value) noexcept { - elements_[head_] = value; - head_ = (head_ + 1) % capacity_; - empty_ = false; - } - - bool try_push(element_t const &value) noexcept { - if (head_ == tail_ && !empty_) - return false; // elements_ is full - - return push(value); - return true; - } - - bool try_pop(element_t &value) noexcept { - if (empty_) - return false; - - value = std::move(elements_[tail_]); - tail_ = (tail_ + 1) % capacity_; - empty_ = head_ == tail_; - return true; - } - - element_t const &operator[](std::size_t i) const noexcept { - return elements_[(tail_ + i) % capacity_]; - } + explicit ring_gt(allocator_t const& alloc = allocator_t()) noexcept : allocator_(alloc) {} + + ring_gt(ring_gt const&) = delete; + ring_gt& operator=(ring_gt const&) = delete; + + ring_gt(ring_gt&& other) noexcept { swap(other); } + ring_gt& operator=(ring_gt&& other) noexcept { + swap(other); + return *this; + } + + void swap(ring_gt& other) noexcept { + std::swap(elements_, 
other.elements_);
+        std::swap(capacity_, other.capacity_);
+        std::swap(head_, other.head_);
+        std::swap(tail_, other.tail_);
+        std::swap(empty_, other.empty_);
+        std::swap(allocator_, other.allocator_);
+    }
+
+    ~ring_gt() noexcept { reset(); }
+
+    bool empty() const noexcept { return empty_; }
+    size_t capacity() const noexcept { return capacity_; }
+    size_t size() const noexcept {
+        if (empty_)
+            return 0;
+        else if (head_ >= tail_)
+            return head_ - tail_;
+        else
+            return capacity_ - (tail_ - head_);
+    }
+
+    void clear() noexcept {
+        head_ = 0;
+        tail_ = 0;
+        empty_ = true;
+    }
+
+    void reset() noexcept {
+        if (elements_)
+            allocator_.deallocate(elements_, capacity_);
+        elements_ = nullptr;
+        capacity_ = 0;
+        head_ = 0;
+        tail_ = 0;
+        empty_ = true;
+    }
+
+    bool reserve(std::size_t n) noexcept {
+        if (n < size())
+            return false; // prevent data loss
+        if (n <= capacity())
+            return true;
+        n = (std::max<std::size_t>)(ceil2(n), 64u);
+        element_t* elements = allocator_.allocate(n);
+        if (!elements)
+            return false;
+
+        std::size_t i = 0;
+        while (try_pop(elements[i]))
+            i++;
+
+        reset();
+        elements_ = elements;
+        capacity_ = n;
+        head_ = i;
+        tail_ = 0;
+        empty_ = (i == 0);
+        return true;
+    }
+
+    void push(element_t const& value) noexcept {
+        elements_[head_] = value;
+        head_ = (head_ + 1) % capacity_;
+        empty_ = false;
+    }
+
+    bool try_push(element_t const& value) noexcept {
+        if (head_ == tail_ && !empty_)
+            return false; // elements_ is full
+
+        push(value);
+        return true;
+    }
+
+    bool try_pop(element_t& value) noexcept {
+        if (empty_)
+            return false;
+
+        value = std::move(elements_[tail_]);
+        tail_ = (tail_ + 1) % capacity_;
+        empty_ = head_ == tail_;
+        return true;
+    }
+
+    element_t const& operator[](std::size_t i) const noexcept { return elements_[(tail_ + i) % capacity_]; }
 };
 
 /// @brief Number of neighbors per graph node.
 ///        Defaults to 32 in FAISS and 16 in hnswlib.
 ///        > It is called `M` in the paper.
-constexpr std::size_t default_connectivity() {
-    return 16;
-}
+constexpr std::size_t default_connectivity() { return 16; }
 
 /// @brief Hyper-parameter controlling the quality of indexing.
 ///        Defaults to 40 in FAISS and 200 in hnswlib.
 ///        > It is called `efConstruction` in the paper.
-constexpr std::size_t default_expansion_add() {
-    return 128;
-}
+constexpr std::size_t default_expansion_add() { return 128; }
 
 /// @brief Hyper-parameter controlling the quality of search.
 ///        Defaults to 16 in FAISS and 10 in hnswlib.
 ///        > It is called `ef` in the paper.
-constexpr std::size_t default_expansion_search() {
-    return 64;
-}
+constexpr std::size_t default_expansion_search() { return 64; }
 
-constexpr std::size_t default_allocator_entry_bytes() {
-    return 64;
-}
+constexpr std::size_t default_allocator_entry_bytes() { return 64; }
 
 /**
- * @brief Configuration settings for the index construction.
- *        Includes the main `::connectivity` parameter (`M` in the paper)
- *        and two expansion factors - for construction and search.
- */
+* @brief Configuration settings for the index construction.
+* Includes the main `::connectivity` parameter (`M` in the paper)
+* and two expansion factors - for construction and search.
+*/
 struct index_config_t {
-    /// @brief Number of neighbors per graph node.
-    ///        Defaults to 32 in FAISS and 16 in hnswlib.
-    ///        > It is called `M` in the paper.
-    std::size_t connectivity = default_connectivity();
-
-    /// @brief Number of neighbors per graph node in base level graph.
-    ///        Defaults to double of the other levels, so 64 in FAISS and 32 in hnswlib.
- /// > It is called `M0` in the paper. - std::size_t connectivity_base = default_connectivity() * 2; - - inline index_config_t() = default; - inline index_config_t(std::size_t c) noexcept - : connectivity(c ? c : default_connectivity()), connectivity_base(c ? c * 2 : default_connectivity() * 2) { - } - inline index_config_t(std::size_t c, std::size_t cb) noexcept - : connectivity(c), connectivity_base((std::max)(c, cb)) { - } + /// @brief Number of neighbors per graph node. + /// Defaults to 32 in FAISS and 16 in hnswlib. + /// > It is called `M` in the paper. + std::size_t connectivity = default_connectivity(); + + /// @brief Number of neighbors per graph node in base level graph. + /// Defaults to double of the other levels, so 64 in FAISS and 32 in hnswlib. + /// > It is called `M0` in the paper. + std::size_t connectivity_base = default_connectivity() * 2; + + inline index_config_t() = default; + inline index_config_t(std::size_t c) noexcept + : connectivity(c ? c : default_connectivity()), connectivity_base(c ? c * 2 : default_connectivity() * 2) {} + inline index_config_t(std::size_t c, std::size_t cb) noexcept + : connectivity(c), connectivity_base((std::max)(c, cb)) {} }; struct index_limits_t { - std::size_t members = 0; - std::size_t threads_add = std::thread::hardware_concurrency(); - std::size_t threads_search = std::thread::hardware_concurrency(); - - inline index_limits_t(std::size_t n, std::size_t t) noexcept : members(n), threads_add(t), threads_search(t) { - } - inline index_limits_t(std::size_t n = 0) noexcept : index_limits_t(n, std::thread::hardware_concurrency()) { - } - inline std::size_t threads() const noexcept { - return (std::max)(threads_add, threads_search); - } - inline std::size_t concurrency() const noexcept { - return (std::min)(threads_add, threads_search); - } + std::size_t members = 0; + std::size_t threads_add = std::thread::hardware_concurrency(); + std::size_t threads_search = std::thread::hardware_concurrency(); + + inline index_limits_t(std::size_t n, std::size_t t) noexcept : members(n), threads_add(t), threads_search(t) {} + inline index_limits_t(std::size_t n = 0) noexcept : index_limits_t(n, std::thread::hardware_concurrency()) {} + inline std::size_t threads() const noexcept { return (std::max)(threads_add, threads_search); } + inline std::size_t concurrency() const noexcept { return (std::min)(threads_add, threads_search); } }; struct index_update_config_t { - /// @brief Hyper-parameter controlling the quality of indexing. - /// Defaults to 40 in FAISS and 200 in hnswlib. - /// > It is called `efConstruction` in the paper. - std::size_t expansion = default_expansion_add(); + /// @brief Hyper-parameter controlling the quality of indexing. + /// Defaults to 40 in FAISS and 200 in hnswlib. + /// > It is called `efConstruction` in the paper. + std::size_t expansion = default_expansion_add(); - /// @brief Optional thread identifier for multi-threaded construction. - std::size_t thread = 0; + /// @brief Optional thread identifier for multi-threaded construction. + std::size_t thread = 0; }; struct index_search_config_t { - /// @brief Hyper-parameter controlling the quality of search. - /// Defaults to 16 in FAISS and 10 in hnswlib. - /// > It is called `ef` in the paper. - std::size_t expansion = default_expansion_search(); + /// @brief Hyper-parameter controlling the quality of search. + /// Defaults to 16 in FAISS and 10 in hnswlib. + /// > It is called `ef` in the paper. 
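+    /// Larger values widen the candidate beam at query time, improving recall at the cost of latency.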
+    std::size_t expansion = default_expansion_search();
 
-    /// @brief Optional thread identifier for multi-threaded construction.
-    std::size_t thread = 0;
+    /// @brief Optional thread identifier for multi-threaded construction.
+    std::size_t thread = 0;
 
-    /// @brief Brute-forces exhaustive search over all entries in the index.
-    bool exact = false;
+    /// @brief Brute-forces exhaustive search over all entries in the index.
+    bool exact = false;
 };
 
 struct index_cluster_config_t {
-    /// @brief Hyper-parameter controlling the quality of search.
-    ///        Defaults to 16 in FAISS and 10 in hnswlib.
-    ///        > It is called `ef` in the paper.
-    std::size_t expansion = default_expansion_search();
+    /// @brief Hyper-parameter controlling the quality of search.
+    ///        Defaults to 16 in FAISS and 10 in hnswlib.
+    ///        > It is called `ef` in the paper.
+    std::size_t expansion = default_expansion_search();
 
-    /// @brief Optional thread identifier for multi-threaded construction.
-    std::size_t thread = 0;
+    /// @brief Optional thread identifier for multi-threaded construction.
+    std::size_t thread = 0;
 };
 
 struct index_copy_config_t {};
 
 struct index_join_config_t {
-    /// @brief Controls the maximum number of proposals per man during stable marriage.
-    std::size_t max_proposals = 0;
+    /// @brief Controls the maximum number of proposals per man during stable marriage.
+    std::size_t max_proposals = 0;
 
-    /// @brief Hyper-parameter controlling the quality of search.
-    ///        Defaults to 16 in FAISS and 10 in hnswlib.
-    ///        > It is called `ef` in the paper.
-    std::size_t expansion = default_expansion_search();
+    /// @brief Hyper-parameter controlling the quality of search.
+    ///        Defaults to 16 in FAISS and 10 in hnswlib.
+    ///        > It is called `ef` in the paper.
+    std::size_t expansion = default_expansion_search();
 
-    /// @brief Brute-forces exhaustive search over all entries in the index.
-    bool exact = false;
+    /// @brief Brute-forces exhaustive search over all entries in the index.
+    bool exact = false;
 };
 
 /// @brief C++17 and newer versions deprecate `std::result_of`.
 template <typename metric_at, typename... args_at>
 using return_type_gt =
 #if defined(USEARCH_DEFINED_CPP17)
-    typename std::invoke_result<metric_at, args_at...>::type;
+    typename std::invoke_result<metric_at, args_at...>::type;
 #else
-    typename std::result_of<metric_at(args_at...)>::type;
+    typename std::result_of<metric_at(args_at...)>::type;
 #endif
 
 /**
- * @brief An example of what a USearch-compatible ad-hoc filter would look like.
- *
- * A similar function object can be passed to search queries to further filter entries
- * on their auxiliary properties, such as some categorical keys stored in an external DBMS.
- */
+* @brief An example of what a USearch-compatible ad-hoc filter would look like.
+*
+* A similar function object can be passed to search queries to further filter entries
+* on their auxiliary properties, such as some categorical keys stored in an external DBMS.
+*/
 struct dummy_predicate_t {
-    template <typename member_at>
-    constexpr bool operator()(member_at &&) const noexcept {
-        return true;
-    }
+    template <typename member_at> constexpr bool operator()(member_at&&) const noexcept { return true; }
 };
 
 /**
- * @brief An example of what a USearch-compatible ad-hoc operation on in-flight entries would look like.
- *
- * This kind of callback is used when the engine is being updated and you want to patch
- * the entries while they are still under locks - limiting concurrent access and providing
- * consistency.
- */
+* @brief An example of what a USearch-compatible ad-hoc operation on in-flight entries would look like.
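+*
+* For instance (an illustrative sketch, not part of the library), a conforming functor
+* could log the key of every member it visits while the node is still locked:
+*
+*     struct log_key_callback_t {
+*         template <typename member_at> void operator()(member_at&& member) const noexcept {
+*             std::printf("updated key: %zu\n", (std::size_t)get_key(member));
+*         }
+*     };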
+*
+* This kind of callback is used when the engine is being updated and you want to patch
+* the entries while they are still under locks - limiting concurrent access and providing
+* consistency.
+*/
 struct dummy_callback_t {
-    template <typename member_at>
-    void operator()(member_at &&) const noexcept {
-    }
+    template <typename member_at> void operator()(member_at&&) const noexcept {}
 };
 
 /**
- * @brief An example of what a USearch-compatible progress-bar should look like.
- *
- * This is particularly helpful when handling long-running tasks, like serialization,
- * saving, and loading from disk, or index-level joins.
- * The reporter checks the return value to decide whether to continue; `false` means stop.
- */
+* @brief An example of what a USearch-compatible progress-bar should look like.
+*
+* This is particularly helpful when handling long-running tasks, like serialization,
+* saving, and loading from disk, or index-level joins.
+* The reporter checks the return value to decide whether to continue; `false` means stop.
+*/
 struct dummy_progress_t {
-    inline bool operator()(std::size_t /*processed*/, std::size_t /*total*/) const noexcept {
-        return true;
-    }
+    inline bool operator()(std::size_t /*processed*/, std::size_t /*total*/) const noexcept { return true; }
 };
 
 /**
- * @brief An example of what a USearch-compatible values prefetching mechanism should look like.
- *
- * USearch is designed to handle very large datasets that may not fit into RAM. Fetching from
- * external memory is very expensive, so we've added a pre-fetching mechanism that accepts
- * multiple objects at once, to cache in RAM ahead of the computation.
- * The received iterators support both `get_slot` and `get_key` operations.
- * An example usage may look like this:
- *
- *      template <typename member_citerator_like_at>
- *      inline void operator()(member_citerator_like_at begin, member_citerator_like_at end) const noexcept {
- *          for (; begin != end; ++begin)
- *              io_uring_prefetch(offset_in_file(get_key(begin)));
- *      }
- */
+* @brief An example of what a USearch-compatible values prefetching mechanism should look like.
+*
+* USearch is designed to handle very large datasets that may not fit into RAM. Fetching from
+* external memory is very expensive, so we've added a pre-fetching mechanism that accepts
+* multiple objects at once, to cache in RAM ahead of the computation.
+* The received iterators support both `get_slot` and `get_key` operations.
+* An example usage may look like this:
+*
+*     template <typename member_citerator_like_at>
+*     inline void operator()(member_citerator_like_at begin, member_citerator_like_at end) const noexcept {
+*         for (; begin != end; ++begin)
+*             io_uring_prefetch(offset_in_file(get_key(begin)));
+*     }
+*/
 struct dummy_prefetch_t {
-    template <typename member_citerator_like_at>
-    inline void operator()(member_citerator_like_at, member_citerator_like_at) const noexcept {
-    }
+    template <typename member_citerator_like_at>
+    inline void operator()(member_citerator_like_at, member_citerator_like_at) const noexcept {}
 };
 
 /**
- * @brief An example of what a USearch-compatible executor (thread-pool) should look like.
- *
- * It's expected to have a `parallel(callback)` API to schedule one task per thread;
- * `fixed(count, callback)` and `dynamic(count, callback)` overloads that also accept
- * the number of tasks and schedule them between threads; as well as `size()` to
- * determine the number of available threads.
- */
+* @brief An example of what a USearch-compatible executor (thread-pool) should look like.
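+*
+* (The dummy implementation below is single-threaded: `size()` reports one thread and all
+* tasks run inline on the calling thread. A real executor would wrap OpenMP or a thread pool.)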
+*
+* It's expected to have a `parallel(callback)` API to schedule one task per thread;
+* `fixed(count, callback)` and `dynamic(count, callback)` overloads that also accept
+* the number of tasks and schedule them between threads; as well as `size()` to
+* determine the number of available threads.
+*/
 struct dummy_executor_t {
-    dummy_executor_t() noexcept {
-    }
-    std::size_t size() const noexcept {
-        return 1;
-    }
-
-    template <typename thread_aware_function_at>
-    void fixed(std::size_t tasks, thread_aware_function_at &&thread_aware_function) noexcept {
-        for (std::size_t task_idx = 0; task_idx != tasks; ++task_idx)
-            thread_aware_function(0, task_idx);
-    }
-
-    template <typename thread_aware_function_at>
-    void dynamic(std::size_t tasks, thread_aware_function_at &&thread_aware_function) noexcept {
-        for (std::size_t task_idx = 0; task_idx != tasks; ++task_idx)
-            if (!thread_aware_function(0, task_idx))
-                break;
-    }
-
-    template <typename thread_aware_function_at>
-    void parallel(thread_aware_function_at &&thread_aware_function) noexcept {
-        thread_aware_function(0);
-    }
+    dummy_executor_t() noexcept {}
+    std::size_t size() const noexcept { return 1; }
+
+    template <typename thread_aware_function_at>
+    void fixed(std::size_t tasks, thread_aware_function_at&& thread_aware_function) noexcept {
+        for (std::size_t task_idx = 0; task_idx != tasks; ++task_idx)
+            thread_aware_function(0, task_idx);
+    }
+
+    template <typename thread_aware_function_at>
+    void dynamic(std::size_t tasks, thread_aware_function_at&& thread_aware_function) noexcept {
+        for (std::size_t task_idx = 0; task_idx != tasks; ++task_idx)
+            if (!thread_aware_function(0, task_idx))
+                break;
+    }
+
+    template <typename thread_aware_function_at>
+    void parallel(thread_aware_function_at&& thread_aware_function) noexcept {
+        thread_aware_function(0);
+    }
 };
 
 /**
- * @brief An example of what a USearch-compatible key-to-key mapping should look like.
- *
- * This is particularly helpful for "Semantic Joins", where we map entries of one collection
- * to entries of another. In asymmetric setups, where A -> B is needed, but B -> A is not,
- * this can be passed to minimize memory usage.
- */
+* @brief An example of what a USearch-compatible key-to-key mapping should look like.
+*
+* This is particularly helpful for "Semantic Joins", where we map entries of one collection
+* to entries of another. In asymmetric setups, where A -> B is needed, but B -> A is not,
+* this can be passed to minimize memory usage.
+*/
 struct dummy_key_to_key_mapping_t {
-    struct member_ref_t {
-        template <typename key_at>
-        member_ref_t &operator=(key_at &&) noexcept {
-            return *this;
-        }
-    };
-    template <typename key_at>
-    member_ref_t operator[](key_at &&) const noexcept {
-        return {};
-    }
+    struct member_ref_t {
+        template <typename key_at> member_ref_t& operator=(key_at&&) noexcept { return *this; }
+    };
+    template <typename key_at> member_ref_t operator[](key_at&&) const noexcept { return {}; }
 };
 
 /**
- * @brief Checks if the provided object has a dummy type, emulating an interface,
- *        but performing no real computation.
- */
-template <typename object_at>
-static constexpr bool is_dummy() {
-    using object_t = typename std::remove_all_extents<object_at>::type;
-    return std::is_same<typename std::decay<object_t>::type, dummy_predicate_t>::value || //
-           std::is_same<typename std::decay<object_t>::type, dummy_callback_t>::value || //
-           std::is_same<typename std::decay<object_t>::type, dummy_progress_t>::value || //
-           std::is_same<typename std::decay<object_t>::type, dummy_prefetch_t>::value || //
-           std::is_same<typename std::decay<object_t>::type, dummy_executor_t>::value || //
-           std::is_same<typename std::decay<object_t>::type, dummy_key_to_key_mapping_t>::value;
+* @brief Checks if the provided object has a dummy type, emulating an interface,
+* but performing no real computation.
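+* For example, `is_dummy<dummy_progress_t>()` evaluates to `true` at compile time, while
+* any functor performing real work yields `false`.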
+*/
+template <typename object_at> static constexpr bool is_dummy() {
+    using object_t = typename std::remove_all_extents<object_at>::type;
+    return std::is_same<typename std::decay<object_t>::type, dummy_predicate_t>::value || //
+           std::is_same<typename std::decay<object_t>::type, dummy_callback_t>::value || //
+           std::is_same<typename std::decay<object_t>::type, dummy_progress_t>::value || //
+           std::is_same<typename std::decay<object_t>::type, dummy_prefetch_t>::value || //
+           std::is_same<typename std::decay<object_t>::type, dummy_executor_t>::value || //
+           std::is_same<typename std::decay<object_t>::type, dummy_key_to_key_mapping_t>::value;
 }
 
-template <typename, typename at>
-struct has_reset_gt {
-    static_assert(std::integral_constant<at, false>::value, "Second template parameter needs to be of function type.");
+template <typename, typename at> struct has_reset_gt {
+    static_assert(std::integral_constant<at, false>::value, "Second template parameter needs to be of function type.");
 };
 
 template <typename check_at, typename return_at, typename... args_at>
 struct has_reset_gt<check_at, return_at(args_at...)> {
 private:
-    template <typename at>
-    static constexpr auto check(at *) ->
-        typename std::is_same<decltype(std::declval<at>().reset(std::declval<args_at>()...)), return_at>::type;
-    template <typename>
-    static constexpr std::false_type check(...);
+    template <typename at>
+    static constexpr auto check(at*) ->
+        typename std::is_same<decltype(std::declval<at>().reset(std::declval<args_at>()...)), return_at>::type;
+    template <typename> static constexpr std::false_type check(...);
 
     typedef decltype(check<check_at>(0)) type;
 
 public:
     static constexpr bool value = type::value;
 };
 
 /**
- * @brief Checks if a certain class has a member function called `reset`.
- */
-template <typename at>
-constexpr bool has_reset() {
-    return has_reset_gt<at, void()>::value;
-}
+* @brief Checks if a certain class has a member function called `reset`.
+*/
+template <typename at> constexpr bool has_reset() { return has_reset_gt<at, void()>::value; }
 
 struct serialization_result_t {
-    error_t error;
-
-    explicit operator bool() const noexcept {
-        return !error;
-    }
-    serialization_result_t failed(error_t message) noexcept {
-        error = std::move(message);
-        return std::move(*this);
-    }
+    error_t error;
+
+    explicit operator bool() const noexcept { return !error; }
+    serialization_result_t failed(error_t message) noexcept {
+        error = std::move(message);
+        return std::move(*this);
+    }
 };
 
 /**
- * @brief Smart-pointer wrapping the LibC @b `FILE` for binary file @b outputs.
- *
- * This class raises no exceptions and reports errors through `serialization_result_t`.
- * The class automatically closes the file when the object is destroyed.
- */
+* @brief Smart-pointer wrapping the LibC @b `FILE` for binary file @b outputs.
+*
+* This class raises no exceptions and reports errors through `serialization_result_t`.
+* The class automatically closes the file when the object is destroyed.
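+*
+* A typical write path might look like this (illustrative sketch; `header` stands in for
+* any POD buffer):
+*
+*     output_file_t file("index.usearch");
+*     serialization_result_t io = file.open_if_not();
+*     if (io)
+*         io = file.write(&header, sizeof(header));
+*     // the destructor closes the underlying `FILE` handle automatically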
+*/
 class output_file_t {
-    char const *path_ = nullptr;
-    std::FILE *file_ = nullptr;
+    char const* path_ = nullptr;
+    std::FILE* file_ = nullptr;
 
 public:
-    output_file_t(char const *path) noexcept : path_(path) {
-    }
-    ~output_file_t() noexcept {
-        close();
-    }
-    output_file_t(output_file_t &&other) noexcept
-        : path_(exchange(other.path_, nullptr)), file_(exchange(other.file_, nullptr)) {
-    }
-    output_file_t &operator=(output_file_t &&other) noexcept {
-        std::swap(path_, other.path_);
-        std::swap(file_, other.file_);
-        return *this;
-    }
-    serialization_result_t open_if_not() noexcept {
-        serialization_result_t result;
-        if (!file_)
-            file_ = std::fopen(path_, "wb");
-        if (!file_)
-            return result.failed(std::strerror(errno));
-        return result;
-    }
-    serialization_result_t write(void const *begin, std::size_t length) noexcept {
-        serialization_result_t result;
-        std::size_t written = std::fwrite(begin, length, 1, file_);
-        if (!written)
-            return result.failed(std::strerror(errno));
-        return result;
-    }
-    void close() noexcept {
-        if (file_)
-            std::fclose(exchange(file_, nullptr));
-    }
+    output_file_t(char const* path) noexcept : path_(path) {}
+    ~output_file_t() noexcept { close(); }
+    output_file_t(output_file_t&& other) noexcept
+        : path_(exchange(other.path_, nullptr)), file_(exchange(other.file_, nullptr)) {}
+    output_file_t& operator=(output_file_t&& other) noexcept {
+        std::swap(path_, other.path_);
+        std::swap(file_, other.file_);
+        return *this;
+    }
+    serialization_result_t open_if_not() noexcept {
+        serialization_result_t result;
+        if (!file_)
+            file_ = std::fopen(path_, "wb");
+        if (!file_)
+            return result.failed(std::strerror(errno));
+        return result;
+    }
+    serialization_result_t write(void const* begin, std::size_t length) noexcept {
+        serialization_result_t result;
+        std::size_t written = std::fwrite(begin, length, 1, file_);
+        if (length && !written)
+            return result.failed(std::strerror(errno));
+        return result;
+    }
+    void close() noexcept {
+        if (file_)
+            std::fclose(exchange(file_, nullptr));
+    }
 };
 
 /**
- * @brief Smart-pointer wrapping the LibC @b `FILE` for binary file @b inputs.
- *
- * This class raises no exceptions and reports errors through `serialization_result_t`.
- * The class automatically closes the file when the object is destroyed.
- */
+* @brief Smart-pointer wrapping the LibC @b `FILE` for binary file @b inputs.
+*
+* This class raises no exceptions and reports errors through `serialization_result_t`.
+* The class automatically closes the file when the object is destroyed.
+*/
 class input_file_t {
-    char const *path_ = nullptr;
-    std::FILE *file_ = nullptr;
+    char const* path_ = nullptr;
+    std::FILE* file_ = nullptr;
 
 public:
-    input_file_t(char const *path) noexcept : path_(path) {
-    }
-    ~input_file_t() noexcept {
-        close();
-    }
-    input_file_t(input_file_t &&other) noexcept
-        : path_(exchange(other.path_, nullptr)), file_(exchange(other.file_, nullptr)) {
-    }
-    input_file_t &operator=(input_file_t &&other) noexcept {
-        std::swap(path_, other.path_);
-        std::swap(file_, other.file_);
-        return *this;
-    }
-
-    serialization_result_t open_if_not() noexcept {
-        serialization_result_t result;
-        if (!file_)
-            file_ = std::fopen(path_, "rb");
-        if (!file_)
-            return result.failed(std::strerror(errno));
-        return result;
-    }
-    serialization_result_t read(void *begin, std::size_t length) noexcept {
-        serialization_result_t result;
-        std::size_t read = std::fread(begin, length, 1, file_);
-        if (!read)
-            return result.failed(std::feof(file_) ?
"End of file reached!" : std::strerror(errno)); - return result; - } - void close() noexcept { - if (file_) - std::fclose(exchange(file_, nullptr)); - } - - explicit operator bool() const noexcept { - return file_; - } - bool seek_to(std::size_t progress) noexcept { - return std::fseek(file_, progress, SEEK_SET) == 0; - } - bool seek_to_end() noexcept { - return std::fseek(file_, 0L, SEEK_END) == 0; - } - bool infer_progress(std::size_t &progress) noexcept { - long int result = std::ftell(file_); - if (result == -1L) - return false; - progress = static_cast(result); - return true; - } + input_file_t(char const* path) noexcept : path_(path) {} + ~input_file_t() noexcept { close(); } + input_file_t(input_file_t&& other) noexcept + : path_(exchange(other.path_, nullptr)), file_(exchange(other.file_, nullptr)) {} + input_file_t& operator=(input_file_t&& other) noexcept { + std::swap(path_, other.path_); + std::swap(file_, other.file_); + return *this; + } + + serialization_result_t open_if_not() noexcept { + serialization_result_t result; + if (!file_) + file_ = std::fopen(path_, "rb"); + if (!file_) + return result.failed(std::strerror(errno)); + return result; + } + serialization_result_t read(void* begin, std::size_t length) noexcept { + serialization_result_t result; + std::size_t read = std::fread(begin, length, 1, file_); + if (length && !read) + return result.failed(std::feof(file_) ? "End of file reached!" : std::strerror(errno)); + return result; + } + void close() noexcept { + if (file_) + std::fclose(exchange(file_, nullptr)); + } + + explicit operator bool() const noexcept { return file_; } + bool seek_to(std::size_t progress) noexcept { return std::fseek(file_, progress, SEEK_SET) == 0; } + bool seek_to_end() noexcept { return std::fseek(file_, 0L, SEEK_END) == 0; } + bool infer_progress(std::size_t& progress) noexcept { + long int result = std::ftell(file_); + if (result == -1L) + return false; + progress = static_cast(result); + return true; + } }; /** - * @brief Represents a memory-mapped file or a pre-allocated anonymous memory region. - * - * This class provides a convenient way to memory-map a file and access its contents as a block of - * memory. The class handles platform-specific memory-mapping operations on Windows, Linux, and MacOS. - * The class automatically closes the file when the object is destroyed. - */ +* @brief Represents a memory-mapped file or a pre-allocated anonymous memory region. +* +* This class provides a convenient way to memory-map a file and access its contents as a block of +* memory. The class handles platform-specific memory-mapping operations on Windows, Linux, and MacOS. +* The class automatically closes the file when the object is destroyed. +*/ class memory_mapped_file_t { - char const *path_ {}; /**< The path to the file to be memory-mapped. */ - void *ptr_ {}; /**< A pointer to the memory-mapping. */ - size_t length_ {}; /**< The length of the memory-mapped file in bytes. */ + char const* path_{}; /**< The path to the file to be memory-mapped. */ + void* ptr_{}; /**< A pointer to the memory-mapping. */ + size_t length_{}; /**< The length of the memory-mapped file in bytes. */ #if defined(USEARCH_DEFINED_WINDOWS) - HANDLE file_handle_ {}; /**< The file handle on Windows. */ - HANDLE mapping_handle_ {}; /**< The mapping handle on Windows. */ + HANDLE file_handle_{}; /**< The file handle on Windows. */ + HANDLE mapping_handle_{}; /**< The mapping handle on Windows. */ #else - int file_descriptor_ {}; /**< The file descriptor on Linux and MacOS. 
*/ + int file_descriptor_{}; /**< The file descriptor on Linux and MacOS. */ #endif public: - explicit operator bool() const noexcept { - return ptr_ != nullptr; - } - byte_t *data() noexcept { - return reinterpret_cast(ptr_); - } - byte_t const *data() const noexcept { - return reinterpret_cast(ptr_); - } - std::size_t size() const noexcept { - return static_cast(length_); - } - - memory_mapped_file_t() noexcept { - } - memory_mapped_file_t(char const *path) noexcept : path_(path) { - } - ~memory_mapped_file_t() noexcept { - close(); - } - memory_mapped_file_t(memory_mapped_file_t &&other) noexcept - : path_(exchange(other.path_, nullptr)), ptr_(exchange(other.ptr_, nullptr)), - length_(exchange(other.length_, 0)), + explicit operator bool() const noexcept { return ptr_ != nullptr; } + byte_t* data() noexcept { return reinterpret_cast(ptr_); } + byte_t const* data() const noexcept { return reinterpret_cast(ptr_); } + std::size_t size() const noexcept { return static_cast(length_); } + + memory_mapped_file_t() noexcept {} + memory_mapped_file_t(char const* path) noexcept : path_(path) {} + ~memory_mapped_file_t() noexcept { close(); } + memory_mapped_file_t(memory_mapped_file_t&& other) noexcept + : path_(exchange(other.path_, nullptr)), ptr_(exchange(other.ptr_, nullptr)), + length_(exchange(other.length_, 0)), #if defined(USEARCH_DEFINED_WINDOWS) - file_handle_(exchange(other.file_handle_, nullptr)), mapping_handle_(exchange(other.mapping_handle_, nullptr)) + file_handle_(exchange(other.file_handle_, nullptr)), mapping_handle_(exchange(other.mapping_handle_, nullptr)) #else - file_descriptor_(exchange(other.file_descriptor_, 0)) + file_descriptor_(exchange(other.file_descriptor_, 0)) #endif - { - } + { + } - memory_mapped_file_t(byte_t *data, std::size_t length) noexcept : ptr_(data), length_(length) { - } + memory_mapped_file_t(byte_t* data, std::size_t length) noexcept : ptr_(data), length_(length) {} - memory_mapped_file_t &operator=(memory_mapped_file_t &&other) noexcept { - std::swap(path_, other.path_); - std::swap(ptr_, other.ptr_); - std::swap(length_, other.length_); + memory_mapped_file_t& operator=(memory_mapped_file_t&& other) noexcept { + std::swap(path_, other.path_); + std::swap(ptr_, other.ptr_); + std::swap(length_, other.length_); #if defined(USEARCH_DEFINED_WINDOWS) - std::swap(file_handle_, other.file_handle_); - std::swap(mapping_handle_, other.mapping_handle_); + std::swap(file_handle_, other.file_handle_); + std::swap(mapping_handle_, other.mapping_handle_); #else - std::swap(file_descriptor_, other.file_descriptor_); + std::swap(file_descriptor_, other.file_descriptor_); #endif - return *this; - } + return *this; + } - serialization_result_t open_if_not() noexcept { - serialization_result_t result; - if (!path_ || ptr_) - return result; + serialization_result_t open_if_not() noexcept { + serialization_result_t result; + if (!path_ || ptr_) + return result; #if defined(USEARCH_DEFINED_WINDOWS) - HANDLE file_handle = - CreateFile(path_, GENERIC_READ, FILE_SHARE_READ, 0, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, 0); - if (file_handle == INVALID_HANDLE_VALUE) - return result.failed("Opening file failed!"); - - std::size_t file_length = GetFileSize(file_handle, 0); - HANDLE mapping_handle = CreateFileMapping(file_handle, 0, PAGE_READONLY, 0, 0, 0); - if (mapping_handle == 0) { - CloseHandle(file_handle); - return result.failed("Mapping file failed!"); - } - - byte_t *file = (byte_t *)MapViewOfFile(mapping_handle, FILE_MAP_READ, 0, 0, file_length); - if (file == 0) { - 
CloseHandle(mapping_handle); - CloseHandle(file_handle); - return result.failed("View the map failed!"); - } - file_handle_ = file_handle; - mapping_handle_ = mapping_handle; - ptr_ = file; - length_ = file_length; + HANDLE file_handle = + CreateFile(path_, GENERIC_READ, FILE_SHARE_READ, 0, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, 0); + if (file_handle == INVALID_HANDLE_VALUE) + return result.failed("Opening file failed!"); + + std::size_t file_length = GetFileSize(file_handle, 0); + HANDLE mapping_handle = CreateFileMapping(file_handle, 0, PAGE_READONLY, 0, 0, 0); + if (mapping_handle == 0) { + CloseHandle(file_handle); + return result.failed("Mapping file failed!"); + } + + byte_t* file = (byte_t*)MapViewOfFile(mapping_handle, FILE_MAP_READ, 0, 0, file_length); + if (file == 0) { + CloseHandle(mapping_handle); + CloseHandle(file_handle); + return result.failed("View the map failed!"); + } + file_handle_ = file_handle; + mapping_handle_ = mapping_handle; + ptr_ = file; + length_ = file_length; #else #if defined(USEARCH_DEFINED_LINUX) - int descriptor = open(path_, O_RDONLY | O_NOATIME); + int descriptor = open(path_, O_RDONLY | O_NOATIME); #else - int descriptor = open(path_, O_RDONLY); + int descriptor = open(path_, O_RDONLY); #endif - if (descriptor < 0) - return result.failed(std::strerror(errno)); - - // Estimate the file size - struct stat file_stat; - int fstat_status = fstat(descriptor, &file_stat); - if (fstat_status < 0) { - ::close(descriptor); - return result.failed(std::strerror(errno)); - } - - // Map the entire file - byte_t *file = (byte_t *)mmap(NULL, file_stat.st_size, PROT_READ, MAP_SHARED, descriptor, 0); - if (file == MAP_FAILED) { - ::close(descriptor); - return result.failed(std::strerror(errno)); - } - file_descriptor_ = descriptor; - ptr_ = file; - length_ = file_stat.st_size; + if (descriptor < 0) + return result.failed(std::strerror(errno)); + + // Estimate the file size + struct stat file_stat; + int fstat_status = fstat(descriptor, &file_stat); + if (fstat_status < 0) { + ::close(descriptor); + return result.failed(std::strerror(errno)); + } + + // Map the entire file + byte_t* file = (byte_t*)mmap(NULL, file_stat.st_size, PROT_READ, MAP_SHARED, descriptor, 0); + if (file == MAP_FAILED) { + ::close(descriptor); + return result.failed(std::strerror(errno)); + } + file_descriptor_ = descriptor; + ptr_ = file; + length_ = file_stat.st_size; #endif // Platform specific code - return result; - } - - void close() noexcept { - if (!path_) { - ptr_ = nullptr; - length_ = 0; - return; - } + return result; + } + + void close() noexcept { + if (!path_) { + ptr_ = nullptr; + length_ = 0; + return; + } #if defined(USEARCH_DEFINED_WINDOWS) - UnmapViewOfFile(ptr_); - CloseHandle(mapping_handle_); - CloseHandle(file_handle_); - mapping_handle_ = nullptr; - file_handle_ = nullptr; + UnmapViewOfFile(ptr_); + CloseHandle(mapping_handle_); + CloseHandle(file_handle_); + mapping_handle_ = nullptr; + file_handle_ = nullptr; #else - munmap(ptr_, length_); - ::close(file_descriptor_); - file_descriptor_ = 0; + munmap(ptr_, length_); + ::close(file_descriptor_); + file_descriptor_ = 0; #endif - ptr_ = nullptr; - length_ = 0; - } + ptr_ = nullptr; + length_ = 0; + } }; struct index_serialized_header_t { - std::uint64_t size = 0; - std::uint64_t connectivity = 0; - std::uint64_t connectivity_base = 0; - std::uint64_t max_level = 0; - std::uint64_t entry_slot = 0; + std::uint64_t size = 0; + std::uint64_t connectivity = 0; + std::uint64_t connectivity_base = 0; + std::uint64_t max_level = 0; 
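+    // Slot of the only node in the top-most level graph - the entry point for every search.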
+ std::uint64_t entry_slot = 0; }; using default_key_t = std::uint64_t; using default_slot_t = std::uint32_t; using default_distance_t = float; -template -struct member_gt { - key_at key; - std::size_t slot; +template struct member_gt { + key_at key; + std::size_t slot; }; -template -inline std::size_t get_slot(member_gt const &m) noexcept { - return m.slot; -} -template -inline key_at get_key(member_gt const &m) noexcept { - return m.key; -} +template inline std::size_t get_slot(member_gt const& m) noexcept { return m.slot; } +template inline key_at get_key(member_gt const& m) noexcept { return m.key; } -template -struct member_cref_gt { - misaligned_ref_gt key; - std::size_t slot; +template struct member_cref_gt { + misaligned_ref_gt key; + std::size_t slot; }; -template -inline std::size_t get_slot(member_cref_gt const &m) noexcept { - return m.slot; -} -template -inline key_at get_key(member_cref_gt const &m) noexcept { - return m.key; -} +template inline std::size_t get_slot(member_cref_gt const& m) noexcept { return m.slot; } +template inline key_at get_key(member_cref_gt const& m) noexcept { return m.key; } -template -struct member_ref_gt { - misaligned_ref_gt key; - std::size_t slot; +template struct member_ref_gt { + misaligned_ref_gt key; + std::size_t slot; - inline operator member_cref_gt() const noexcept { - return {key.ptr(), slot}; - } + inline operator member_cref_gt() const noexcept { return {key.ptr(), slot}; } }; -template -inline std::size_t get_slot(member_ref_gt const &m) noexcept { - return m.slot; -} -template -inline key_at get_key(member_ref_gt const &m) noexcept { - return m.key; -} +template inline std::size_t get_slot(member_ref_gt const& m) noexcept { return m.slot; } +template inline key_at get_key(member_ref_gt const& m) noexcept { return m.key; } /** * @brief Approximate Nearest Neighbors Search @b index-structure using the @@ -1956,2307 +1680,2154 @@ inline key_at get_key(member_ref_gt const &m) noexcept { * */ template , // - typename tape_allocator_at = dynamic_allocator_at> // + typename key_at = default_key_t, // + typename compressed_slot_at = default_slot_t, // + typename dynamic_allocator_at = std::allocator, // + typename tape_allocator_at = dynamic_allocator_at> // class index_gt { public: - using distance_t = distance_at; - using vector_key_t = key_at; - using key_t = vector_key_t; - using compressed_slot_t = compressed_slot_at; - using dynamic_allocator_t = dynamic_allocator_at; - using tape_allocator_t = tape_allocator_at; - static_assert(sizeof(vector_key_t) >= sizeof(compressed_slot_t), "Having tiny keys doesn't make sense."); - - using member_cref_t = member_cref_gt; - using member_ref_t = member_ref_gt; - - template - class member_iterator_gt { - using ref_t = ref_at; - using index_t = index_at; - - friend class index_gt; - member_iterator_gt() noexcept { - } - member_iterator_gt(index_t *index, std::size_t slot) noexcept : index_(index), slot_(slot) { - } - - index_t *index_ {}; - std::size_t slot_ {}; - - public: - using iterator_category = std::random_access_iterator_tag; - using value_type = ref_t; - using difference_type = std::ptrdiff_t; - using pointer = void; - using reference = ref_t; - - reference operator*() const noexcept { - return {index_->node_at_(slot_).key(), slot_}; - } - vector_key_t key() const noexcept { - return index_->node_at_(slot_).key(); - } - - friend inline std::size_t get_slot(member_iterator_gt const &it) noexcept { - return it.slot_; - } - friend inline vector_key_t get_key(member_iterator_gt const &it) 
noexcept { - return it.key(); - } - - member_iterator_gt operator++(int) noexcept { - return member_iterator_gt(index_, slot_ + 1); - } - member_iterator_gt operator--(int) noexcept { - return member_iterator_gt(index_, slot_ - 1); - } - member_iterator_gt operator+(difference_type d) noexcept { - return member_iterator_gt(index_, slot_ + d); - } - member_iterator_gt operator-(difference_type d) noexcept { - return member_iterator_gt(index_, slot_ - d); - } - - // clang-format off + using distance_t = distance_at; + using vector_key_t = key_at; + using key_t = vector_key_t; + using compressed_slot_t = compressed_slot_at; + using dynamic_allocator_t = dynamic_allocator_at; + using tape_allocator_t = tape_allocator_at; + static_assert(sizeof(vector_key_t) >= sizeof(compressed_slot_t), "Having tiny keys doesn't make sense."); + + using member_cref_t = member_cref_gt; + using member_ref_t = member_ref_gt; + + template class member_iterator_gt { + using ref_t = ref_at; + using index_t = index_at; + + friend class index_gt; + member_iterator_gt() noexcept {} + member_iterator_gt(index_t* index, std::size_t slot) noexcept : index_(index), slot_(slot) {} + + index_t* index_{}; + std::size_t slot_{}; + + public: + using iterator_category = std::random_access_iterator_tag; + using value_type = ref_t; + using difference_type = std::ptrdiff_t; + using pointer = void; + using reference = ref_t; + + reference operator*() const noexcept { return {index_->node_at_(slot_).key(), slot_}; } + vector_key_t key() const noexcept { return index_->node_at_(slot_).key(); } + + friend inline std::size_t get_slot(member_iterator_gt const& it) noexcept { return it.slot_; } + friend inline vector_key_t get_key(member_iterator_gt const& it) noexcept { return it.key(); } + + member_iterator_gt operator++(int) noexcept { return member_iterator_gt(index_, slot_ + 1); } + member_iterator_gt operator--(int) noexcept { return member_iterator_gt(index_, slot_ - 1); } + member_iterator_gt operator+(difference_type d) noexcept { return member_iterator_gt(index_, slot_ + d); } + member_iterator_gt operator-(difference_type d) noexcept { return member_iterator_gt(index_, slot_ - d); } + + // clang-format off member_iterator_gt& operator++() noexcept { slot_ += 1; return *this; } member_iterator_gt& operator--() noexcept { slot_ -= 1; return *this; } member_iterator_gt& operator+=(difference_type d) noexcept { slot_ += d; return *this; } member_iterator_gt& operator-=(difference_type d) noexcept { slot_ -= d; return *this; } bool operator==(member_iterator_gt const& other) const noexcept { return index_ == other.index_ && slot_ == other.slot_; } bool operator!=(member_iterator_gt const& other) const noexcept { return index_ != other.index_ || slot_ != other.slot_; } - // clang-format on - }; - - using member_iterator_t = member_iterator_gt; - using member_citerator_t = member_iterator_gt; - - // STL compatibility: - using value_type = vector_key_t; - using allocator_type = dynamic_allocator_t; - using size_type = std::size_t; - using difference_type = std::ptrdiff_t; - using reference = member_ref_t; - using const_reference = member_cref_t; - using pointer = void; - using const_pointer = void; - using iterator = member_iterator_t; - using const_iterator = member_citerator_t; - using reverse_iterator = std::reverse_iterator; - using reverse_const_iterator = std::reverse_iterator; - - using dynamic_allocator_traits_t = std::allocator_traits; - using byte_t = typename dynamic_allocator_t::value_type; - static_assert( // - 
sizeof(byte_t) == 1, // - "Primary allocator must allocate separate addressable bytes"); - - using tape_allocator_traits_t = std::allocator_traits; - static_assert( // - sizeof(typename tape_allocator_traits_t::value_type) == 1, // - "Tape allocator must allocate separate addressable bytes"); + // clang-format on + }; + + using member_iterator_t = member_iterator_gt; + using member_citerator_t = member_iterator_gt; + + // STL compatibility: + using value_type = vector_key_t; + using allocator_type = dynamic_allocator_t; + using size_type = std::size_t; + using difference_type = std::ptrdiff_t; + using reference = member_ref_t; + using const_reference = member_cref_t; + using pointer = void; + using const_pointer = void; + using iterator = member_iterator_t; + using const_iterator = member_citerator_t; + using reverse_iterator = std::reverse_iterator; + using reverse_const_iterator = std::reverse_iterator; + + using dynamic_allocator_traits_t = std::allocator_traits; + using byte_t = typename dynamic_allocator_t::value_type; + static_assert( // + sizeof(byte_t) == 1, // + "Primary allocator must allocate separate addressable bytes"); + + using tape_allocator_traits_t = std::allocator_traits; + static_assert( // + sizeof(typename tape_allocator_traits_t::value_type) == 1, // + "Tape allocator must allocate separate addressable bytes"); private: - /** - * @brief Integer for the number of node neighbors at a specific level of the - * multi-level graph. It's selected to be `std::uint32_t` to improve the - * alignment in most common cases. - */ - using neighbors_count_t = std::uint32_t; - using level_t = std::int16_t; - - /** - * @brief How many bytes of memory are needed to form the "head" of the node. - */ - static constexpr std::size_t node_head_bytes_() { - return sizeof(vector_key_t) + sizeof(level_t); - } - - using nodes_mutexes_t = bitset_gt; - - using visits_hash_set_t = growing_hash_set_gt, dynamic_allocator_t>; - - struct precomputed_constants_t { - double inverse_log_connectivity {}; - std::size_t neighbors_bytes {}; - std::size_t neighbors_base_bytes {}; - }; - /// @brief A space-efficient internal data-structure used in graph traversal queues. - struct candidate_t { - distance_t distance; - compressed_slot_t slot; - inline bool operator<(candidate_t other) const noexcept { - return distance < other.distance; - } - }; - - using candidates_view_t = span_gt; - using candidates_allocator_t = typename dynamic_allocator_traits_t::template rebind_alloc; - using top_candidates_t = sorted_buffer_gt, candidates_allocator_t>; - using next_candidates_t = max_heap_gt, candidates_allocator_t>; - - /** - * @brief A loosely-structured handle for every node. One such node is created for every member. - * To minimize memory usage and maximize the number of entries per cache-line, it only - * stores to pointers. The internal tape starts with a `vector_key_t` @b key, then - * a `level_t` for the number of graph @b levels in which this member appears, - * then the { `neighbors_count_t`, `compressed_slot_t`, `compressed_slot_t` ... } sequences - * for @b each-level. 
- */ - class node_t { - byte_t *tape_ {}; - - public: - explicit node_t(byte_t *tape) noexcept : tape_(tape) { - } - byte_t *tape() const noexcept { - return tape_; - } - byte_t *neighbors_tape() const noexcept { - return tape_ + node_head_bytes_(); - } - explicit operator bool() const noexcept { - return tape_; - } - - node_t() = default; - node_t(node_t const &) = default; - node_t &operator=(node_t const &) = default; - - misaligned_ref_gt ckey() const noexcept { - return {tape_}; - } - misaligned_ref_gt key() const noexcept { - return {tape_}; - } - misaligned_ref_gt level() const noexcept { - return {tape_ + sizeof(vector_key_t)}; - } - - void key(vector_key_t v) noexcept { - return misaligned_store(tape_, v); - } - void level(level_t v) noexcept { - return misaligned_store(tape_ + sizeof(vector_key_t), v); - } - }; - - static_assert(std::is_trivially_copy_constructible::value, "Nodes must be light!"); - static_assert(std::is_trivially_destructible::value, "Nodes must be light!"); - - /** - * @brief A slice of the node's tape, containing a the list of neighbors - * for a node in a single graph level. It's pre-allocated to fit - * as many neighbors "slots", as may be needed at the target level, - * and starts with a single integer `neighbors_count_t` counter. - */ - class neighbors_ref_t { - byte_t *tape_; - - static constexpr std::size_t shift(std::size_t i = 0) { - return sizeof(neighbors_count_t) + sizeof(compressed_slot_t) * i; - } - - public: - neighbors_ref_t(byte_t *tape) noexcept : tape_(tape) { - } - misaligned_ptr_gt begin() noexcept { - return tape_ + shift(); - } - misaligned_ptr_gt end() noexcept { - return begin() + size(); - } - misaligned_ptr_gt begin() const noexcept { - return tape_ + shift(); - } - misaligned_ptr_gt end() const noexcept { - return begin() + size(); - } - compressed_slot_t operator[](std::size_t i) const noexcept { - return misaligned_load(tape_ + shift(i)); - } - std::size_t size() const noexcept { - return misaligned_load(tape_); - } - void clear() noexcept { - neighbors_count_t n = misaligned_load(tape_); - std::memset(tape_, 0, shift(n)); - // misaligned_store(tape_, 0); - } - void push_back(compressed_slot_t slot) noexcept { - neighbors_count_t n = misaligned_load(tape_); - misaligned_store(tape_ + shift(n), slot); - misaligned_store(tape_, n + 1); - } - }; - - /** - * @brief A package of all kinds of temporary data-structures, that the threads - * would reuse to process requests. Similar to having all of those as - * separate `thread_local` global variables. 
- */ - struct usearch_align_m context_t { - top_candidates_t top_candidates {}; - next_candidates_t next_candidates {}; - visits_hash_set_t visits {}; - std::default_random_engine level_generator {}; - std::size_t iteration_cycles {}; - std::size_t computed_distances_count {}; - - template // - inline distance_t measure(value_at const &first, entry_at const &second, metric_at &&metric) noexcept { - static_assert( // - std::is_same::value || std::is_same::value, - "Unexpected type"); - - computed_distances_count++; - return metric(first, second); - } - - template // - inline distance_t measure(entry_at const &first, entry_at const &second, metric_at &&metric) noexcept { - static_assert( // - std::is_same::value || std::is_same::value, - "Unexpected type"); - - computed_distances_count++; - return metric(first, second); - } - }; - - index_config_t config_ {}; - index_limits_t limits_ {}; - - mutable dynamic_allocator_t dynamic_allocator_ {}; - tape_allocator_t tape_allocator_ {}; - - precomputed_constants_t pre_ {}; - memory_mapped_file_t viewed_file_ {}; - - /// @brief Number of "slots" available for `node_t` objects. Equals to @b `limits_.members`. - usearch_align_m mutable std::atomic nodes_capacity_ {}; - - /// @brief Number of "slots" already storing non-null nodes. - usearch_align_m mutable std::atomic nodes_count_ {}; - - /// @brief Controls access to `max_level_` and `entry_slot_`. - /// If any thread is updating those values, no other threads can `add()` or `search()`. - std::mutex global_mutex_ {}; - - /// @brief The level of the top-most graph in the index. Grows as the logarithm of size, starts from zero. - level_t max_level_ {}; - - /// @brief The slot in which the only node of the top-level graph is stored. - std::size_t entry_slot_ {}; - - using nodes_allocator_t = typename dynamic_allocator_traits_t::template rebind_alloc; - - /// @brief C-style array of `node_t` smart-pointers. - buffer_gt nodes_ {}; - - /// @brief Mutex, that limits concurrent access to `nodes_`. - mutable nodes_mutexes_t nodes_mutexes_ {}; - - using contexts_allocator_t = typename dynamic_allocator_traits_t::template rebind_alloc; - - /// @brief Array of thread-specific buffers for temporary data. - mutable buffer_gt contexts_ {}; + /** + * @brief Integer for the number of node neighbors at a specific level of the + * multi-level graph. It's selected to be `std::uint32_t` to improve the + * alignment in most common cases. + */ + using neighbors_count_t = std::uint32_t; + using level_t = std::int16_t; + + /** + * @brief How many bytes of memory are needed to form the "head" of the node. + */ + static constexpr std::size_t node_head_bytes_() { return sizeof(vector_key_t) + sizeof(level_t); } + + using nodes_mutexes_t = bitset_gt; + + using visits_hash_set_t = growing_hash_set_gt, dynamic_allocator_t>; + + struct precomputed_constants_t { + double inverse_log_connectivity{}; + std::size_t neighbors_bytes{}; + std::size_t neighbors_base_bytes{}; + }; + /// @brief A space-efficient internal data-structure used in graph traversal queues. 
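+    /// With the default `float` distance and `std::uint32_t` compressed slot, a candidate
+    /// packs into just 8 bytes, keeping the traversal heaps compact and cache-friendly.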
+ struct candidate_t { + distance_t distance; + compressed_slot_t slot; + inline bool operator<(candidate_t other) const noexcept { return distance < other.distance; } + }; + + using candidates_view_t = span_gt; + using candidates_allocator_t = typename dynamic_allocator_traits_t::template rebind_alloc; + using top_candidates_t = sorted_buffer_gt, candidates_allocator_t>; + using next_candidates_t = max_heap_gt, candidates_allocator_t>; + + /** + * @brief A loosely-structured handle for every node. One such node is created for every member. + * To minimize memory usage and maximize the number of entries per cache-line, it only + * stores to pointers. The internal tape starts with a `vector_key_t` @b key, then + * a `level_t` for the number of graph @b levels in which this member appears, + * then the { `neighbors_count_t`, `compressed_slot_t`, `compressed_slot_t` ... } sequences + * for @b each-level. + */ + class node_t { + byte_t* tape_{}; + + public: + explicit node_t(byte_t* tape) noexcept : tape_(tape) {} + byte_t* tape() const noexcept { return tape_; } + byte_t* neighbors_tape() const noexcept { return tape_ + node_head_bytes_(); } + explicit operator bool() const noexcept { return tape_; } + + node_t() = default; + node_t(node_t const&) = default; + node_t& operator=(node_t const&) = default; + + misaligned_ref_gt ckey() const noexcept { return {tape_}; } + misaligned_ref_gt key() const noexcept { return {tape_}; } + misaligned_ref_gt level() const noexcept { return {tape_ + sizeof(vector_key_t)}; } + + void key(vector_key_t v) noexcept { return misaligned_store(tape_, v); } + void level(level_t v) noexcept { return misaligned_store(tape_ + sizeof(vector_key_t), v); } + }; + + static_assert(std::is_trivially_copy_constructible::value, "Nodes must be light!"); + static_assert(std::is_trivially_destructible::value, "Nodes must be light!"); + + /** + * @brief A slice of the node's tape, containing a the list of neighbors + * for a node in a single graph level. It's pre-allocated to fit + * as many neighbors "slots", as may be needed at the target level, + * and starts with a single integer `neighbors_count_t` counter. + */ + class neighbors_ref_t { + byte_t* tape_; + + static constexpr std::size_t shift(std::size_t i = 0) { + return sizeof(neighbors_count_t) + sizeof(compressed_slot_t) * i; + } + + public: + neighbors_ref_t(byte_t* tape) noexcept : tape_(tape) {} + misaligned_ptr_gt begin() noexcept { return tape_ + shift(); } + misaligned_ptr_gt end() noexcept { return begin() + size(); } + misaligned_ptr_gt begin() const noexcept { return tape_ + shift(); } + misaligned_ptr_gt end() const noexcept { return begin() + size(); } + compressed_slot_t operator[](std::size_t i) const noexcept { + return misaligned_load(tape_ + shift(i)); + } + std::size_t size() const noexcept { return misaligned_load(tape_); } + void clear() noexcept { + neighbors_count_t n = misaligned_load(tape_); + std::memset(tape_, 0, shift(n)); + // misaligned_store(tape_, 0); + } + void push_back(compressed_slot_t slot) noexcept { + neighbors_count_t n = misaligned_load(tape_); + misaligned_store(tape_ + shift(n), slot); + misaligned_store(tape_, n + 1); + } + }; + + /** + * @brief A package of all kinds of temporary data-structures, that the threads + * would reuse to process requests. Similar to having all of those as + * separate `thread_local` global variables. 
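+     *
+     * Keeping these in an array indexed by an explicit thread id, rather than in true
+     * `thread_local` storage, lets `index_limits_t` bound their number and lets the
+     * buffers be reused across queries.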
+ */ + struct usearch_align_m context_t { + top_candidates_t top_candidates{}; + next_candidates_t next_candidates{}; + visits_hash_set_t visits{}; + std::default_random_engine level_generator{}; + std::size_t iteration_cycles{}; + std::size_t computed_distances_count{}; + + template // + inline distance_t measure(value_at const& first, entry_at const& second, metric_at&& metric) noexcept { + static_assert( // + std::is_same::value || std::is_same::value, + "Unexpected type"); + + computed_distances_count++; + return metric(first, second); + } + + template // + inline distance_t measure(entry_at const& first, entry_at const& second, metric_at&& metric) noexcept { + static_assert( // + std::is_same::value || std::is_same::value, + "Unexpected type"); + + computed_distances_count++; + return metric(first, second); + } + }; + + index_config_t config_{}; + index_limits_t limits_{}; + + mutable dynamic_allocator_t dynamic_allocator_{}; + tape_allocator_t tape_allocator_{}; + + precomputed_constants_t pre_{}; + memory_mapped_file_t viewed_file_{}; + + /// @brief Number of "slots" available for `node_t` objects. Equals to @b `limits_.members`. + usearch_align_m mutable std::atomic nodes_capacity_{}; + + /// @brief Number of "slots" already storing non-null nodes. + usearch_align_m mutable std::atomic nodes_count_{}; + + /// @brief Controls access to `max_level_` and `entry_slot_`. + /// If any thread is updating those values, no other threads can `add()` or `search()`. + std::mutex global_mutex_{}; + + /// @brief The level of the top-most graph in the index. Grows as the logarithm of size, starts from zero. + level_t max_level_{}; + + /// @brief The slot in which the only node of the top-level graph is stored. + std::size_t entry_slot_{}; + + using nodes_allocator_t = typename dynamic_allocator_traits_t::template rebind_alloc; + + /// @brief C-style array of `node_t` smart-pointers. + buffer_gt nodes_{}; + + /// @brief Mutex, that limits concurrent access to `nodes_`. + mutable nodes_mutexes_t nodes_mutexes_{}; + + using contexts_allocator_t = typename dynamic_allocator_traits_t::template rebind_alloc; + + /// @brief Array of thread-specific buffers for temporary data. + mutable buffer_gt contexts_{}; public: - std::size_t connectivity() const noexcept { - return config_.connectivity; - } - std::size_t capacity() const noexcept { - return nodes_capacity_; - } - std::size_t size() const noexcept { - return nodes_count_; - } - std::size_t max_level() const noexcept { - return static_cast(max_level_); - } - index_config_t const &config() const noexcept { - return config_; - } - index_limits_t const &limits() const noexcept { - return limits_; - } - bool is_immutable() const noexcept { - return bool(viewed_file_); - } - - /** - * @section Exceptions - * Doesn't throw, unless the ::metric's and ::allocators's throw on copy-construction. - */ - explicit index_gt( // - index_config_t config = {}, dynamic_allocator_t dynamic_allocator = {}, - tape_allocator_t tape_allocator = {}) noexcept - : config_(config), limits_(0, 0), dynamic_allocator_(std::move(dynamic_allocator)), - tape_allocator_(std::move(tape_allocator)), pre_(precompute_(config)), nodes_count_(0u), max_level_(-1), - entry_slot_(0u), nodes_(), nodes_mutexes_(), contexts_() { - } - - /** - * @brief Clones the structure with the same hyper-parameters, but without contents. 
- */ - index_gt fork() noexcept { - return index_gt {config_, dynamic_allocator_, tape_allocator_}; - } - - ~index_gt() noexcept { - reset(); - } - - index_gt(index_gt &&other) noexcept { - swap(other); - } - - index_gt &operator=(index_gt &&other) noexcept { - swap(other); - return *this; - } - - struct copy_result_t { - error_t error; - index_gt index; - - explicit operator bool() const noexcept { - return !error; - } - copy_result_t failed(error_t message) noexcept { - error = std::move(message); - return std::move(*this); - } - }; - - copy_result_t copy(index_copy_config_t config = {}) const noexcept { - copy_result_t result; - index_gt &other = result.index; - other = index_gt(config_, dynamic_allocator_, tape_allocator_); - if (!other.reserve(limits_)) - return result.failed("Failed to reserve the contexts"); - - // Now all is left - is to allocate new `node_t` instances and populate - // the `other.nodes_` array into it. - for (std::size_t i = 0; i != nodes_count_; ++i) - other.nodes_[i] = other.node_make_copy_(node_bytes_(nodes_[i])); - - other.nodes_count_ = nodes_count_.load(); - other.max_level_ = max_level_; - other.entry_slot_ = entry_slot_; - - // This controls nothing for now :) - (void)config; - return result; - } - - member_citerator_t cbegin() const noexcept { - return {this, 0}; - } - member_citerator_t cend() const noexcept { - return {this, size()}; - } - member_citerator_t begin() const noexcept { - return {this, 0}; - } - member_citerator_t end() const noexcept { - return {this, size()}; - } - member_iterator_t begin() noexcept { - return {this, 0}; - } - member_iterator_t end() noexcept { - return {this, size()}; - } - - member_ref_t at(std::size_t slot) noexcept { - return {nodes_[slot].key(), slot}; - } - member_cref_t at(std::size_t slot) const noexcept { - return {nodes_[slot].ckey(), slot}; - } - member_iterator_t iterator_at(std::size_t slot) noexcept { - return {this, slot}; - } - member_citerator_t citerator_at(std::size_t slot) const noexcept { - return {this, slot}; - } - - dynamic_allocator_t const &dynamic_allocator() const noexcept { - return dynamic_allocator_; - } - tape_allocator_t const &tape_allocator() const noexcept { - return tape_allocator_; - } + std::size_t connectivity() const noexcept { return config_.connectivity; } + std::size_t capacity() const noexcept { return nodes_capacity_; } + std::size_t size() const noexcept { return nodes_count_; } + std::size_t max_level() const noexcept { return nodes_count_ ? static_cast(max_level_) : 0; } + index_config_t const& config() const noexcept { return config_; } + index_limits_t const& limits() const noexcept { return limits_; } + bool is_immutable() const noexcept { return bool(viewed_file_); } + + /** + * @section Exceptions + * Doesn't throw, unless the ::metric's and ::allocators's throw on copy-construction. + */ + explicit index_gt( // + index_config_t config = {}, dynamic_allocator_t dynamic_allocator = {}, + tape_allocator_t tape_allocator = {}) noexcept + : config_(config), limits_(0, 0), dynamic_allocator_(std::move(dynamic_allocator)), + tape_allocator_(std::move(tape_allocator)), pre_(precompute_(config)), nodes_count_(0u), max_level_(-1), + entry_slot_(0u), nodes_(), nodes_mutexes_(), contexts_() {} + + /** + * @brief Clones the structure with the same hyper-parameters, but without contents. 
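+     *
+     * Handy for spawning empty twins with identical hyper-parameters, e.g. one shard
+     * per thread, that can later be combined. Illustrative:
+     *
+     *     index_gt shard = original.fork(); // same config and allocators, zero members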
+ */ + index_gt fork() noexcept { return index_gt{config_, dynamic_allocator_, tape_allocator_}; } + + ~index_gt() noexcept { reset(); } + + index_gt(index_gt&& other) noexcept { swap(other); } + + index_gt& operator=(index_gt&& other) noexcept { + swap(other); + return *this; + } + + struct copy_result_t { + error_t error; + index_gt index; + + explicit operator bool() const noexcept { return !error; } + copy_result_t failed(error_t message) noexcept { + error = std::move(message); + return std::move(*this); + } + }; + + copy_result_t copy(index_copy_config_t config = {}) const noexcept { + copy_result_t result; + index_gt& other = result.index; + other = index_gt(config_, dynamic_allocator_, tape_allocator_); + if (!other.reserve(limits_)) + return result.failed("Failed to reserve the contexts"); + + // Now all is left - is to allocate new `node_t` instances and populate + // the `other.nodes_` array into it. + for (std::size_t i = 0; i != nodes_count_; ++i) + other.nodes_[i] = other.node_make_copy_(node_bytes_(nodes_[i])); + + other.nodes_count_ = nodes_count_.load(); + other.max_level_ = max_level_; + other.entry_slot_ = entry_slot_; + + // This controls nothing for now :) + (void)config; + return result; + } + + member_citerator_t cbegin() const noexcept { return {this, 0}; } + member_citerator_t cend() const noexcept { return {this, size()}; } + member_citerator_t begin() const noexcept { return {this, 0}; } + member_citerator_t end() const noexcept { return {this, size()}; } + member_iterator_t begin() noexcept { return {this, 0}; } + member_iterator_t end() noexcept { return {this, size()}; } + + member_ref_t at(std::size_t slot) noexcept { return {nodes_[slot].key(), slot}; } + member_cref_t at(std::size_t slot) const noexcept { return {nodes_[slot].ckey(), slot}; } + member_iterator_t iterator_at(std::size_t slot) noexcept { return {this, slot}; } + member_citerator_t citerator_at(std::size_t slot) const noexcept { return {this, slot}; } + + dynamic_allocator_t const& dynamic_allocator() const noexcept { return dynamic_allocator_; } + tape_allocator_t const& tape_allocator() const noexcept { return tape_allocator_; } #pragma region Adjusting Configuration - /** - * @brief Erases all the vectors from the index. - * - * Will change `size()` to zero, but will keep the same `capacity()`. - * Will keep the number of available threads/contexts the same as it was. - */ - void clear() noexcept { - if (!has_reset()) { - std::size_t n = nodes_count_; - for (std::size_t i = 0; i != n; ++i) - node_free_(i); - } else - tape_allocator_.deallocate(nullptr, 0); - nodes_count_ = 0; - max_level_ = -1; - entry_slot_ = 0u; - } - - /** - * @brief Erases all members from index, closing files, and returning RAM to OS. - * - * Will change both `size()` and `capacity()` to zero. - * Will deallocate all threads/contexts. - * If the index is memory-mapped - releases the mapping and the descriptor. - */ - void reset() noexcept { - clear(); - - nodes_ = {}; - contexts_ = {}; - nodes_mutexes_ = {}; - limits_ = index_limits_t {0, 0}; - nodes_capacity_ = 0; - viewed_file_ = memory_mapped_file_t {}; - tape_allocator_ = {}; - } - - /** - * @brief Swaps the underlying memory buffers and thread contexts. 
- */ - void swap(index_gt &other) noexcept { - std::swap(config_, other.config_); - std::swap(limits_, other.limits_); - std::swap(dynamic_allocator_, other.dynamic_allocator_); - std::swap(tape_allocator_, other.tape_allocator_); - std::swap(pre_, other.pre_); - std::swap(viewed_file_, other.viewed_file_); - std::swap(max_level_, other.max_level_); - std::swap(entry_slot_, other.entry_slot_); - std::swap(nodes_, other.nodes_); - std::swap(nodes_mutexes_, other.nodes_mutexes_); - std::swap(contexts_, other.contexts_); - - // Non-atomic parts. - std::size_t capacity_copy = nodes_capacity_; - std::size_t count_copy = nodes_count_; - nodes_capacity_ = other.nodes_capacity_.load(); - nodes_count_ = other.nodes_count_.load(); - other.nodes_capacity_ = capacity_copy; - other.nodes_count_ = count_copy; - } - - /** - * @brief Increases the `capacity()` of the index to allow adding more vectors. - * @return `true` on success, `false` on memory allocation errors. - */ - bool reserve(index_limits_t limits) usearch_noexcept_m { - - if (limits.threads_add <= limits_.threads_add // - && limits.threads_search <= limits_.threads_search // - && limits.members <= limits_.members) - return true; - - nodes_mutexes_t new_mutexes(limits.members); - buffer_gt new_nodes(limits.members); - buffer_gt new_contexts(limits.threads()); - if (!new_nodes || !new_contexts || !new_mutexes) - return false; - - // Move the nodes info, and deallocate previous buffers. - if (nodes_) - std::memcpy(new_nodes.data(), nodes_.data(), sizeof(node_t) * size()); - - limits_ = limits; - nodes_capacity_ = limits.members; - nodes_ = std::move(new_nodes); - contexts_ = std::move(new_contexts); - nodes_mutexes_ = std::move(new_mutexes); - return true; - } + /** + * @brief Erases all the vectors from the index. + * + * Will change `size()` to zero, but will keep the same `capacity()`. + * Will keep the number of available threads/contexts the same as it was. + */ + void clear() noexcept { + if (!has_reset()) { + std::size_t n = nodes_count_; + for (std::size_t i = 0; i != n; ++i) + node_free_(i); + } else + tape_allocator_.deallocate(nullptr, 0); + nodes_count_ = 0; + max_level_ = -1; + entry_slot_ = 0u; + } + + /** + * @brief Erases all members from index, closing files, and returning RAM to OS. + * + * Will change both `size()` and `capacity()` to zero. + * Will deallocate all threads/contexts. + * If the index is memory-mapped - releases the mapping and the descriptor. + */ + void reset() noexcept { + clear(); + + nodes_ = {}; + contexts_ = {}; + nodes_mutexes_ = {}; + limits_ = index_limits_t{0, 0}; + nodes_capacity_ = 0; + viewed_file_ = memory_mapped_file_t{}; + tape_allocator_ = {}; + } + + /** + * @brief Swaps the underlying memory buffers and thread contexts. + */ + void swap(index_gt& other) noexcept { + std::swap(config_, other.config_); + std::swap(limits_, other.limits_); + std::swap(dynamic_allocator_, other.dynamic_allocator_); + std::swap(tape_allocator_, other.tape_allocator_); + std::swap(pre_, other.pre_); + std::swap(viewed_file_, other.viewed_file_); + std::swap(max_level_, other.max_level_); + std::swap(entry_slot_, other.entry_slot_); + std::swap(nodes_, other.nodes_); + std::swap(nodes_mutexes_, other.nodes_mutexes_); + std::swap(contexts_, other.contexts_); + + // Non-atomic parts. 
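+        // `std::atomic` members are neither copyable nor movable, so they can't
+        // go through `std::swap`; exchange their values by hand instead.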
+ std::size_t capacity_copy = nodes_capacity_; + std::size_t count_copy = nodes_count_; + nodes_capacity_ = other.nodes_capacity_.load(); + nodes_count_ = other.nodes_count_.load(); + other.nodes_capacity_ = capacity_copy; + other.nodes_count_ = count_copy; + } + + /** + * @brief Increases the `capacity()` of the index to allow adding more vectors. + * @return `true` on success, `false` on memory allocation errors. + */ + bool reserve(index_limits_t limits) usearch_noexcept_m { + + if (limits.threads_add <= limits_.threads_add // + && limits.threads_search <= limits_.threads_search // + && limits.members <= limits_.members) + return true; + + nodes_mutexes_t new_mutexes(limits.members); + buffer_gt new_nodes(limits.members); + buffer_gt new_contexts(limits.threads()); + if (!new_nodes || !new_contexts || !new_mutexes) + return false; + + // Move the nodes info, and deallocate previous buffers. + if (nodes_) + std::memcpy(new_nodes.data(), nodes_.data(), sizeof(node_t) * size()); + + limits_ = limits; + nodes_capacity_ = limits.members; + nodes_ = std::move(new_nodes); + contexts_ = std::move(new_contexts); + nodes_mutexes_ = std::move(new_mutexes); + return true; + } #pragma endregion #pragma region Construction and Search - struct add_result_t { - error_t error {}; - std::size_t new_size {}; - std::size_t visited_members {}; - std::size_t computed_distances {}; - std::size_t slot {}; - - explicit operator bool() const noexcept { - return !error; - } - add_result_t failed(error_t message) noexcept { - error = std::move(message); - return std::move(*this); - } - }; - - /// @brief Describes a matched search result, augmenting `member_cref_t` - /// contents with `distance` to the query object. - struct match_t { - member_cref_t member; - distance_t distance; - - inline match_t() noexcept : member({nullptr, 0}), distance(std::numeric_limits::max()) { - } - - inline match_t(member_cref_t member, distance_t distance) noexcept : member(member), distance(distance) { - } - - inline match_t(match_t &&other) noexcept - : member({other.member.key.ptr(), other.member.slot}), distance(other.distance) { - } - - inline match_t(match_t const &other) noexcept - : member({other.member.key.ptr(), other.member.slot}), distance(other.distance) { - } - - inline match_t &operator=(match_t const &other) noexcept { - member.key.reset(other.member.key.ptr()); - member.slot = other.member.slot; - distance = other.distance; - return *this; - } - - inline match_t &operator=(match_t &&other) noexcept { - member.key.reset(other.member.key.ptr()); - member.slot = other.member.slot; - distance = other.distance; - return *this; - } - }; - - class search_result_t { - node_t const *nodes_ {}; - top_candidates_t const *top_ {}; - - friend class index_gt; - inline search_result_t(index_gt const &index, top_candidates_t &top) noexcept - : nodes_(index.nodes_), top_(&top) { - } - - public: - /** @brief Number of search results found. */ - std::size_t count {}; - /** @brief Number of graph nodes traversed. */ - std::size_t visited_members {}; - /** @brief Number of times the distances were computed. 
*/ - std::size_t computed_distances {}; - error_t error {}; - - inline search_result_t() noexcept { - } - inline search_result_t(search_result_t &&) = default; - inline search_result_t &operator=(search_result_t &&) = default; - - explicit operator bool() const noexcept { - return !error; - } - search_result_t failed(error_t message) noexcept { - error = std::move(message); - return std::move(*this); - } - - inline operator std::size_t() const noexcept { - return count; - } - inline std::size_t size() const noexcept { - return count; - } - inline bool empty() const noexcept { - return !count; - } - inline match_t operator[](std::size_t i) const noexcept { - return at(i); - } - inline match_t front() const noexcept { - return at(0); - } - inline match_t back() const noexcept { - return at(count - 1); - } - inline bool contains(vector_key_t key) const noexcept { - for (std::size_t i = 0; i != count; ++i) - if (at(i).member.key == key) - return true; - return false; - } - inline match_t at(std::size_t i) const noexcept { - candidate_t const *top_ordered = top_->data(); - candidate_t candidate = top_ordered[i]; - node_t node = nodes_[candidate.slot]; - return {member_cref_t {node.ckey(), candidate.slot}, candidate.distance}; - } - inline std::size_t merge_into( // - vector_key_t *keys, distance_t *distances, // - std::size_t old_count, std::size_t max_count) const noexcept { - - std::size_t merged_count = old_count; - for (std::size_t i = 0; i != count; ++i) { - match_t result = operator[](i); - distance_t *merged_end = distances + merged_count; - std::size_t offset = std::lower_bound(distances, merged_end, result.distance) - distances; - if (offset == max_count) - continue; - - std::size_t count_worse = merged_count - offset - (max_count == merged_count); - std::memmove(keys + offset + 1, keys + offset, count_worse * sizeof(vector_key_t)); - std::memmove(distances + offset + 1, distances + offset, count_worse * sizeof(distance_t)); - keys[offset] = result.member.key; - distances[offset] = result.distance; - merged_count += merged_count != max_count; - } - return merged_count; - } - inline std::size_t dump_to(vector_key_t *keys, distance_t *distances) const noexcept { - for (std::size_t i = 0; i != count; ++i) { - match_t result = operator[](i); - keys[i] = result.member.key; - distances[i] = result.distance; - } - return count; - } - inline std::size_t dump_to(vector_key_t *keys) const noexcept { - for (std::size_t i = 0; i != count; ++i) { - match_t result = operator[](i); - keys[i] = result.member.key; - } - return count; - } - }; - - struct cluster_result_t { - error_t error {}; - std::size_t visited_members {}; - std::size_t computed_distances {}; - match_t cluster {}; - - explicit operator bool() const noexcept { - return !error; - } - cluster_result_t failed(error_t message) noexcept { - error = std::move(message); - return std::move(*this); - } - }; - - /** - * @brief Inserts a new entry into the index. Thread-safe. Supports @b heterogeneous lookups. - * Expects needed capacity to be reserved ahead of time: `size() < capacity()`. - * - * @tparam metric_at - * A function responsible for computing the distance @b (dis-similarity) between two objects. - * It should be callable into distinctly different scenarios: - * - `distance_t operator() (value_at, entry_at)` - from new object to existing entries. - * - `distance_t operator() (entry_at, entry_at)` - between existing entries. - * Where any possible `entry_at` has both two interfaces: `std::size_t slot()`, `vector_key_t key()`. 
- * - * @param[in] key External identifier/name/descriptor for the new entry. - * @param[in] value Content that will be compared against other entries to index. - * @param[in] metric Callable object measuring distance between ::value and present objects. - * @param[in] config Configuration options for this specific operation. - * @param[in] callback On-success callback, executed while the `member_ref_t` is still under lock. - */ - template < // - typename value_at, // - typename metric_at, // - typename callback_at = dummy_callback_t, // - typename prefetch_at = dummy_prefetch_t // - > - add_result_t add( // - vector_key_t key, value_at &&value, metric_at &&metric, // - index_update_config_t config = {}, // - callback_at &&callback = callback_at {}, // - prefetch_at &&prefetch = prefetch_at {}) usearch_noexcept_m { - - add_result_t result; - if (is_immutable()) - return result.failed("Can't add to an immutable index"); - - // Make sure we have enough local memory to perform this request - context_t &context = contexts_[config.thread]; - top_candidates_t &top = context.top_candidates; - next_candidates_t &next = context.next_candidates; - top.clear(); - next.clear(); - - // The top list needs one more slot than the connectivity of the base level - // for the heuristic, that tries to squeeze one more element into saturated list. - std::size_t connectivity_max = (std::max)(config_.connectivity_base, config_.connectivity); - std::size_t top_limit = (std::max)(connectivity_max + 1, config.expansion); - if (!top.reserve(top_limit)) - return result.failed("Out of memory!"); - if (!next.reserve(config.expansion)) - return result.failed("Out of memory!"); - - // Determining how much memory to allocate for the node depends on the target level - std::unique_lock new_level_lock(global_mutex_); - level_t max_level_copy = max_level_; // Copy under lock - std::size_t entry_idx_copy = entry_slot_; // Copy under lock - level_t target_level = choose_random_level_(context.level_generator); - - // Make sure we are not overflowing - std::size_t capacity = nodes_capacity_.load(); - std::size_t new_slot = nodes_count_.fetch_add(1); - if (new_slot >= capacity) { - nodes_count_.fetch_sub(1); - return result.failed("Reserve capacity ahead of insertions!"); - } - - // Allocate the neighbors - node_t node = node_make_(key, target_level); - if (!node) { - nodes_count_.fetch_sub(1); - return result.failed("Out of memory!"); - } - if (target_level <= max_level_copy) - new_level_lock.unlock(); - - nodes_[new_slot] = node; - result.new_size = new_slot + 1; - result.slot = new_slot; - callback(at(new_slot)); - node_lock_t new_lock = node_lock_(new_slot); - - // Do nothing for the first element - if (!new_slot) { - entry_slot_ = new_slot; - max_level_ = target_level; - return result; - } - - // Pull stats - result.computed_distances = context.computed_distances_count; - result.visited_members = context.iteration_cycles; - - connect_node_across_levels_( // - value, metric, prefetch, // - new_slot, entry_idx_copy, max_level_copy, target_level, // - config, context); - - // Normalize stats - result.computed_distances = context.computed_distances_count - result.computed_distances; - result.visited_members = context.iteration_cycles - result.visited_members; - - // Updating the entry point if needed - if (target_level > max_level_copy) { - entry_slot_ = new_slot; - max_level_ = target_level; - } - return result; - } - - /** - * @brief Update an existing entry. Thread-safe. Supports @b heterogeneous lookups. 
- * - * @tparam metric_at - * A function responsible for computing the distance @b (dis-similarity) between two objects. - * It should be callable into distinctly different scenarios: - * - `distance_t operator() (value_at, entry_at)` - from new object to existing entries. - * - `distance_t operator() (entry_at, entry_at)` - between existing entries. - * For any possible `entry_at` following interfaces will work: - * - `std::size_t get_slot(entry_at const &)` - * - `vector_key_t get_key(entry_at const &)` - * - * @param[in] iterator Iterator pointing to an existing entry to be replaced. - * @param[in] key External identifier/name/descriptor for the entry. - * @param[in] value Content that will be compared against other entries in the index. - * @param[in] metric Callable object measuring distance between ::value and present objects. - * @param[in] config Configuration options for this specific operation. - * @param[in] callback On-success callback, executed while the `member_ref_t` is still under lock. - */ - template < // - typename value_at, // - typename metric_at, // - typename callback_at = dummy_callback_t, // - typename prefetch_at = dummy_prefetch_t // - > - add_result_t update( // - member_iterator_t iterator, // - vector_key_t key, // - value_at &&value, // - metric_at &&metric, // - index_update_config_t config = {}, // - callback_at &&callback = callback_at {}, // - prefetch_at &&prefetch = prefetch_at {}) usearch_noexcept_m { - - usearch_assert_m(!is_immutable(), "Can't add to an immutable index"); - add_result_t result; - std::size_t old_slot = iterator.slot_; - - // Make sure we have enough local memory to perform this request - context_t &context = contexts_[config.thread]; - top_candidates_t &top = context.top_candidates; - next_candidates_t &next = context.next_candidates; - top.clear(); - next.clear(); - - // The top list needs one more slot than the connectivity of the base level - // for the heuristic, that tries to squeeze one more element into saturated list. - std::size_t connectivity_max = (std::max)(config_.connectivity_base, config_.connectivity); - std::size_t top_limit = (std::max)(connectivity_max + 1, config.expansion); - if (!top.reserve(top_limit)) - return result.failed("Out of memory!"); - if (!next.reserve(config.expansion)) - return result.failed("Out of memory!"); - - node_lock_t new_lock = node_lock_(old_slot); - node_t node = node_at_(old_slot); - - level_t node_level = node.level(); - span_bytes_t node_bytes = node_bytes_(node); - std::memset(node_bytes.data(), 0, node_bytes.size()); - node.level(node_level); - - // Pull stats - result.computed_distances = context.computed_distances_count; - result.visited_members = context.iteration_cycles; - - connect_node_across_levels_( // - value, metric, prefetch, // - old_slot, entry_slot_, max_level_, node_level, // - config, context); - node.key(key); - - // Normalize stats - result.computed_distances = context.computed_distances_count - result.computed_distances; - result.visited_members = context.iteration_cycles - result.visited_members; - result.slot = old_slot; - - callback(at(old_slot)); - return result; - } - - /** - * @brief Searches for the closest elements to the given ::query. Thread-safe. - * - * @param[in] query Content that will be compared against other entries in the index. - * @param[in] wanted The upper bound for the number of results to return. - * @param[in] config Configuration options for this specific operation. 
- * @param[in] predicate Optional filtering predicate for `member_cref_t`. - * @return Smart object referencing temporary memory. Valid until next `search()`, `add()`, or `cluster()`. - */ - template < // - typename value_at, // - typename metric_at, // - typename predicate_at = dummy_predicate_t, // - typename prefetch_at = dummy_prefetch_t // - > - search_result_t search( // - value_at &&query, // - std::size_t wanted, // - metric_at &&metric, // - index_search_config_t config = {}, // - predicate_at &&predicate = predicate_at {}, // - prefetch_at &&prefetch = prefetch_at {}) const noexcept { - - context_t &context = contexts_[config.thread]; - top_candidates_t &top = context.top_candidates; - search_result_t result {*this, top}; - if (!nodes_count_) - return result; - - // Go down the level, tracking only the closest match - result.computed_distances = context.computed_distances_count; - result.visited_members = context.iteration_cycles; - - if (config.exact) { - if (!top.reserve(wanted)) - return result.failed("Out of memory!"); - search_exact_(query, metric, predicate, wanted, context); - } else { - next_candidates_t &next = context.next_candidates; - std::size_t expansion = (std::max)(config.expansion, wanted); - if (!next.reserve(expansion)) - return result.failed("Out of memory!"); - if (!top.reserve(expansion)) - return result.failed("Out of memory!"); - - std::size_t closest_slot = search_for_one_(query, metric, prefetch, entry_slot_, max_level_, 0, context); - - // For bottom layer we need a more optimized procedure - if (!search_to_find_in_base_(query, metric, predicate, prefetch, closest_slot, expansion, context)) - return result.failed("Out of memory!"); - } - - top.sort_ascending(); - top.shrink(wanted); - - // Normalize stats - result.computed_distances = context.computed_distances_count - result.computed_distances; - result.visited_members = context.iteration_cycles - result.visited_members; - result.count = top.size(); - return result; - } - - /** - * @brief Identifies the closest cluster to the given ::query. Thread-safe. - * - * @param[in] query Content that will be compared against other entries in the index. - * @param[in] level The index level to target. Higher means lower resolution. - * @param[in] config Configuration options for this specific operation. - * @param[in] predicate Optional filtering predicate for `member_cref_t`. - * @return Smart object referencing temporary memory. Valid until next `search()`, `add()`, or `cluster()`. 
- */ - template < // - typename value_at, // - typename metric_at, // - typename predicate_at = dummy_predicate_t, // - typename prefetch_at = dummy_prefetch_t // - > - cluster_result_t cluster( // - value_at &&query, // - std::size_t level, // - metric_at &&metric, // - index_cluster_config_t config = {}, // - predicate_at &&predicate = predicate_at {}, // - prefetch_at &&prefetch = prefetch_at {}) const noexcept { - - context_t &context = contexts_[config.thread]; - cluster_result_t result; - if (!nodes_count_) - return result.failed("No clusters to identify"); - - // Go down the level, tracking only the closest match - result.computed_distances = context.computed_distances_count; - result.visited_members = context.iteration_cycles; - - next_candidates_t &next = context.next_candidates; - std::size_t expansion = config.expansion; - if (!next.reserve(expansion)) - return result.failed("Out of memory!"); - - result.cluster.member = - at(search_for_one_(query, metric, prefetch, entry_slot_, max_level_, level - 1, context)); - result.cluster.distance = context.measure(query, result.cluster.member, metric); - - // Normalize stats - result.computed_distances = context.computed_distances_count - result.computed_distances; - result.visited_members = context.iteration_cycles - result.visited_members; - - (void)predicate; - return result; - } + struct add_result_t { + error_t error{}; + std::size_t new_size{}; + std::size_t visited_members{}; + std::size_t computed_distances{}; + std::size_t slot{}; + + explicit operator bool() const noexcept { return !error; } + add_result_t failed(error_t message) noexcept { + error = std::move(message); + return std::move(*this); + } + }; + + /// @brief Describes a matched search result, augmenting `member_cref_t` + /// contents with `distance` to the query object. + struct match_t { + member_cref_t member; + distance_t distance; + + inline match_t() noexcept : member({nullptr, 0}), distance(std::numeric_limits::max()) {} + + inline match_t(member_cref_t member, distance_t distance) noexcept : member(member), distance(distance) {} + + inline match_t(match_t&& other) noexcept + : member({other.member.key.ptr(), other.member.slot}), distance(other.distance) {} + + inline match_t(match_t const& other) noexcept + : member({other.member.key.ptr(), other.member.slot}), distance(other.distance) {} + + inline match_t& operator=(match_t const& other) noexcept { + member.key.reset(other.member.key.ptr()); + member.slot = other.member.slot; + distance = other.distance; + return *this; + } + + inline match_t& operator=(match_t&& other) noexcept { + member.key.reset(other.member.key.ptr()); + member.slot = other.member.slot; + distance = other.distance; + return *this; + } + }; + + class search_result_t { + node_t const* nodes_{}; + top_candidates_t const* top_{}; + + friend class index_gt; + inline search_result_t(index_gt const& index, top_candidates_t& top) noexcept + : nodes_(index.nodes_), top_(&top) {} + + public: + /** @brief Number of search results found. */ + std::size_t count{}; + /** @brief Number of graph nodes traversed. */ + std::size_t visited_members{}; + /** @brief Number of times the distances were computed. 
*/ + std::size_t computed_distances{}; + error_t error{}; + + inline search_result_t() noexcept {} + inline search_result_t(search_result_t&&) = default; + inline search_result_t& operator=(search_result_t&&) = default; + + explicit operator bool() const noexcept { return !error; } + search_result_t failed(error_t message) noexcept { + error = std::move(message); + return std::move(*this); + } + + inline operator std::size_t() const noexcept { return count; } + inline std::size_t size() const noexcept { return count; } + inline bool empty() const noexcept { return !count; } + inline match_t operator[](std::size_t i) const noexcept { return at(i); } + inline match_t front() const noexcept { return at(0); } + inline match_t back() const noexcept { return at(count - 1); } + inline bool contains(vector_key_t key) const noexcept { + for (std::size_t i = 0; i != count; ++i) + if (at(i).member.key == key) + return true; + return false; + } + inline match_t at(std::size_t i) const noexcept { + candidate_t const* top_ordered = top_->data(); + candidate_t candidate = top_ordered[i]; + node_t node = nodes_[candidate.slot]; + return {member_cref_t{node.ckey(), candidate.slot}, candidate.distance}; + } + inline std::size_t merge_into( // + vector_key_t* keys, distance_t* distances, // + std::size_t old_count, std::size_t max_count) const noexcept { + + std::size_t merged_count = old_count; + for (std::size_t i = 0; i != count; ++i) { + match_t result = operator[](i); + distance_t* merged_end = distances + merged_count; + std::size_t offset = std::lower_bound(distances, merged_end, result.distance) - distances; + if (offset == max_count) + continue; + + std::size_t count_worse = merged_count - offset - (max_count == merged_count); + std::memmove(keys + offset + 1, keys + offset, count_worse * sizeof(vector_key_t)); + std::memmove(distances + offset + 1, distances + offset, count_worse * sizeof(distance_t)); + keys[offset] = result.member.key; + distances[offset] = result.distance; + merged_count += merged_count != max_count; + } + return merged_count; + } + inline std::size_t dump_to(vector_key_t* keys, distance_t* distances) const noexcept { + for (std::size_t i = 0; i != count; ++i) { + match_t result = operator[](i); + keys[i] = result.member.key; + distances[i] = result.distance; + } + return count; + } + inline std::size_t dump_to(vector_key_t* keys) const noexcept { + for (std::size_t i = 0; i != count; ++i) { + match_t result = operator[](i); + keys[i] = result.member.key; + } + return count; + } + }; + + struct cluster_result_t { + error_t error{}; + std::size_t visited_members{}; + std::size_t computed_distances{}; + match_t cluster{}; + + explicit operator bool() const noexcept { return !error; } + cluster_result_t failed(error_t message) noexcept { + error = std::move(message); + return std::move(*this); + } + }; + + /** + * @brief Inserts a new entry into the index. Thread-safe. Supports @b heterogeneous lookups. + * Expects needed capacity to be reserved ahead of time: `size() < capacity()`. + * + * @tparam metric_at + * A function responsible for computing the distance @b (dis-similarity) between two objects. + * It should be callable into distinctly different scenarios: + * - `distance_t operator() (value_at, entry_at)` - from new object to existing entries. + * - `distance_t operator() (entry_at, entry_at)` - between existing entries. + * Where any possible `entry_at` has both two interfaces: `std::size_t slot()`, `vector_key_t key()`. 
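+     *
+     * A minimal usage sketch; `my_metric_t` is a hypothetical functor that can
+     * resolve both a raw query pointer and a stored entry to a vector:
+     * @code
+     *   index_limits_t limits;
+     *   limits.members = 1024;
+     *   index.reserve(limits); // capacity must be reserved before `add()`
+     *   float vec[3] = {0.1f, 0.2f, 0.3f};
+     *   auto result = index.add(42, &vec[0], my_metric_t{});
+     *   // on failure, `result` converts to `false` and `result.error` explains why
+     * @endcode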
+ * + * @param[in] key External identifier/name/descriptor for the new entry. + * @param[in] value Content that will be compared against other entries to index. + * @param[in] metric Callable object measuring distance between ::value and present objects. + * @param[in] config Configuration options for this specific operation. + * @param[in] callback On-success callback, executed while the `member_ref_t` is still under lock. + */ + template < // + typename value_at, // + typename metric_at, // + typename callback_at = dummy_callback_t, // + typename prefetch_at = dummy_prefetch_t // + > + add_result_t add( // + vector_key_t key, value_at&& value, metric_at&& metric, // + index_update_config_t config = {}, // + callback_at&& callback = callback_at{}, // + prefetch_at&& prefetch = prefetch_at{}) usearch_noexcept_m { + + add_result_t result; + if (is_immutable()) + return result.failed("Can't add to an immutable index"); + + // Make sure we have enough local memory to perform this request + context_t& context = contexts_[config.thread]; + top_candidates_t& top = context.top_candidates; + next_candidates_t& next = context.next_candidates; + top.clear(); + next.clear(); + + // The top list needs one more slot than the connectivity of the base level + // for the heuristic, that tries to squeeze one more element into saturated list. + std::size_t connectivity_max = (std::max)(config_.connectivity_base, config_.connectivity); + std::size_t top_limit = (std::max)(connectivity_max + 1, config.expansion); + if (!top.reserve(top_limit)) + return result.failed("Out of memory!"); + if (!next.reserve(config.expansion)) + return result.failed("Out of memory!"); + + // Determining how much memory to allocate for the node depends on the target level + std::unique_lock new_level_lock(global_mutex_); + level_t max_level_copy = max_level_; // Copy under lock + std::size_t entry_idx_copy = entry_slot_; // Copy under lock + level_t target_level = choose_random_level_(context.level_generator); + + // Make sure we are not overflowing + std::size_t capacity = nodes_capacity_.load(); + std::size_t new_slot = nodes_count_.fetch_add(1); + if (new_slot >= capacity) { + nodes_count_.fetch_sub(1); + return result.failed("Reserve capacity ahead of insertions!"); + } + + // Allocate the neighbors + node_t node = node_make_(key, target_level); + if (!node) { + nodes_count_.fetch_sub(1); + return result.failed("Out of memory!"); + } + if (target_level <= max_level_copy) + new_level_lock.unlock(); + + nodes_[new_slot] = node; + result.new_size = new_slot + 1; + result.slot = new_slot; + callback(at(new_slot)); + node_lock_t new_lock = node_lock_(new_slot); + + // Do nothing for the first element + if (!new_slot) { + entry_slot_ = new_slot; + max_level_ = target_level; + return result; + } + + // Pull stats + result.computed_distances = context.computed_distances_count; + result.visited_members = context.iteration_cycles; + + connect_node_across_levels_( // + value, metric, prefetch, // + new_slot, entry_idx_copy, max_level_copy, target_level, // + config, context); + + // Normalize stats + result.computed_distances = context.computed_distances_count - result.computed_distances; + result.visited_members = context.iteration_cycles - result.visited_members; + + // Updating the entry point if needed + if (target_level > max_level_copy) { + entry_slot_ = new_slot; + max_level_ = target_level; + } + return result; + } + + /** + * @brief Update an existing entry. Thread-safe. Supports @b heterogeneous lookups. 
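+     * Reuses the graph slot of the entry behind ::iterator, wiping its old
+     * neighbors and reconnecting it to match the new ::value. A minimal sketch
+     * (`my_metric_t` is hypothetical, as in `add()`):
+     * @code
+     *   auto it = index.iterator_at(0); // slot of a previously added entry
+     *   index.update(it, 42, &new_vector[0], my_metric_t{});
+     * @endcode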
+ * + * @tparam metric_at + * A function responsible for computing the distance @b (dis-similarity) between two objects. + * It should be callable into distinctly different scenarios: + * - `distance_t operator() (value_at, entry_at)` - from new object to existing entries. + * - `distance_t operator() (entry_at, entry_at)` - between existing entries. + * For any possible `entry_at` following interfaces will work: + * - `std::size_t get_slot(entry_at const &)` + * - `vector_key_t get_key(entry_at const &)` + * + * @param[in] iterator Iterator pointing to an existing entry to be replaced. + * @param[in] key External identifier/name/descriptor for the entry. + * @param[in] value Content that will be compared against other entries in the index. + * @param[in] metric Callable object measuring distance between ::value and present objects. + * @param[in] config Configuration options for this specific operation. + * @param[in] callback On-success callback, executed while the `member_ref_t` is still under lock. + */ + template < // + typename value_at, // + typename metric_at, // + typename callback_at = dummy_callback_t, // + typename prefetch_at = dummy_prefetch_t // + > + add_result_t update( // + member_iterator_t iterator, // + vector_key_t key, // + value_at&& value, // + metric_at&& metric, // + index_update_config_t config = {}, // + callback_at&& callback = callback_at{}, // + prefetch_at&& prefetch = prefetch_at{}) usearch_noexcept_m { + + usearch_assert_m(!is_immutable(), "Can't add to an immutable index"); + add_result_t result; + std::size_t old_slot = iterator.slot_; + + // Make sure we have enough local memory to perform this request + context_t& context = contexts_[config.thread]; + top_candidates_t& top = context.top_candidates; + next_candidates_t& next = context.next_candidates; + top.clear(); + next.clear(); + + // The top list needs one more slot than the connectivity of the base level + // for the heuristic, that tries to squeeze one more element into saturated list. + std::size_t connectivity_max = (std::max)(config_.connectivity_base, config_.connectivity); + std::size_t top_limit = (std::max)(connectivity_max + 1, config.expansion); + if (!top.reserve(top_limit)) + return result.failed("Out of memory!"); + if (!next.reserve(config.expansion)) + return result.failed("Out of memory!"); + + node_lock_t new_lock = node_lock_(old_slot); + node_t node = node_at_(old_slot); + + level_t node_level = node.level(); + span_bytes_t node_bytes = node_bytes_(node); + std::memset(node_bytes.data(), 0, node_bytes.size()); + node.level(node_level); + + // Pull stats + result.computed_distances = context.computed_distances_count; + result.visited_members = context.iteration_cycles; + + connect_node_across_levels_( // + value, metric, prefetch, // + old_slot, entry_slot_, max_level_, node_level, // + config, context); + node.key(key); + + // Normalize stats + result.computed_distances = context.computed_distances_count - result.computed_distances; + result.visited_members = context.iteration_cycles - result.visited_members; + result.slot = old_slot; + + callback(at(old_slot)); + return result; + } + + /** + * @brief Searches for the closest elements to the given ::query. Thread-safe. + * + * @param[in] query Content that will be compared against other entries in the index. + * @param[in] wanted The upper bound for the number of results to return. + * @param[in] config Configuration options for this specific operation. + * @param[in] predicate Optional filtering predicate for `member_cref_t`. 
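+     *
+     * A minimal usage sketch (`my_metric_t` is hypothetical, as in `add()`):
+     * @code
+     *   float query[3] = {0.1f, 0.2f, 0.3f};
+     *   auto results = index.search(&query[0], 10, my_metric_t{});
+     *   for (std::size_t i = 0; i != results.size(); ++i)
+     *       use(results[i].member.key, results[i].distance); // `use` is a placeholder
+     * @endcode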
+ * @return Smart object referencing temporary memory. Valid until next `search()`, `add()`, or `cluster()`. + */ + template < // + typename value_at, // + typename metric_at, // + typename predicate_at = dummy_predicate_t, // + typename prefetch_at = dummy_prefetch_t // + > + search_result_t search( // + value_at&& query, // + std::size_t wanted, // + metric_at&& metric, // + index_search_config_t config = {}, // + predicate_at&& predicate = predicate_at{}, // + prefetch_at&& prefetch = prefetch_at{}) const noexcept { + + context_t& context = contexts_[config.thread]; + top_candidates_t& top = context.top_candidates; + search_result_t result{*this, top}; + if (!nodes_count_) + return result; + + // Go down the level, tracking only the closest match + result.computed_distances = context.computed_distances_count; + result.visited_members = context.iteration_cycles; + + if (config.exact) { + if (!top.reserve(wanted)) + return result.failed("Out of memory!"); + search_exact_(query, metric, predicate, wanted, context); + } else { + next_candidates_t& next = context.next_candidates; + std::size_t expansion = (std::max)(config.expansion, wanted); + if (!next.reserve(expansion)) + return result.failed("Out of memory!"); + if (!top.reserve(expansion)) + return result.failed("Out of memory!"); + + std::size_t closest_slot = search_for_one_(query, metric, prefetch, entry_slot_, max_level_, 0, context); + + // For bottom layer we need a more optimized procedure + if (!search_to_find_in_base_(query, metric, predicate, prefetch, closest_slot, expansion, context)) + return result.failed("Out of memory!"); + } + + top.sort_ascending(); + top.shrink(wanted); + + // Normalize stats + result.computed_distances = context.computed_distances_count - result.computed_distances; + result.visited_members = context.iteration_cycles - result.visited_members; + result.count = top.size(); + return result; + } + + /** + * @brief Identifies the closest cluster to the given ::query. Thread-safe. + * + * @param[in] query Content that will be compared against other entries in the index. + * @param[in] level The index level to target. Higher means lower resolution. + * @param[in] config Configuration options for this specific operation. + * @param[in] predicate Optional filtering predicate for `member_cref_t`. + * @return Smart object referencing temporary memory. Valid until next `search()`, `add()`, or `cluster()`. 
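+     *
+     * A minimal sketch (`my_metric_t` is hypothetical; ::level picks the graph
+     * layer whose nodes act as cluster representatives):
+     * @code
+     *   auto result = index.cluster(&query[0], 1, my_metric_t{});
+     *   if (result)
+     *       use(result.cluster.member.key, result.cluster.distance);
+     * @endcode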
+ */ + template < // + typename value_at, // + typename metric_at, // + typename predicate_at = dummy_predicate_t, // + typename prefetch_at = dummy_prefetch_t // + > + cluster_result_t cluster( // + value_at&& query, // + std::size_t level, // + metric_at&& metric, // + index_cluster_config_t config = {}, // + predicate_at&& predicate = predicate_at{}, // + prefetch_at&& prefetch = prefetch_at{}) const noexcept { + + context_t& context = contexts_[config.thread]; + cluster_result_t result; + if (!nodes_count_) + return result.failed("No clusters to identify"); + + // Go down the level, tracking only the closest match + result.computed_distances = context.computed_distances_count; + result.visited_members = context.iteration_cycles; + + next_candidates_t& next = context.next_candidates; + std::size_t expansion = config.expansion; + if (!next.reserve(expansion)) + return result.failed("Out of memory!"); + + result.cluster.member = at(search_for_one_(query, metric, prefetch, entry_slot_, max_level_, + static_cast(level - 1), context)); + result.cluster.distance = context.measure(query, result.cluster.member, metric); + + // Normalize stats + result.computed_distances = context.computed_distances_count - result.computed_distances; + result.visited_members = context.iteration_cycles - result.visited_members; + + (void)predicate; + return result; + } #pragma endregion #pragma region Metadata - struct stats_t { - std::size_t nodes {}; - std::size_t edges {}; - std::size_t max_edges {}; - std::size_t allocated_bytes {}; - }; - - stats_t stats() const noexcept { - stats_t result {}; - - for (std::size_t i = 0; i != size(); ++i) { - node_t node = node_at_(i); - std::size_t max_edges = node.level() * config_.connectivity + config_.connectivity_base; - std::size_t edges = 0; - for (level_t level = 0; level <= node.level(); ++level) - edges += neighbors_(node, level).size(); - - ++result.nodes; - result.allocated_bytes += node_bytes_(node).size(); - result.edges += edges; - result.max_edges += max_edges; - } - return result; - } - - stats_t stats(std::size_t level) const noexcept { - stats_t result {}; - - std::size_t neighbors_bytes = !level ? pre_.neighbors_base_bytes : pre_.neighbors_bytes; - for (std::size_t i = 0; i != size(); ++i) { - node_t node = node_at_(i); - if (static_cast(node.level()) < level) - continue; - - ++result.nodes; - result.edges += neighbors_(node, level).size(); - result.allocated_bytes += node_head_bytes_() + neighbors_bytes; - } - - std::size_t max_edges_per_node = level ? 
config_.connectivity_base : config_.connectivity; - result.max_edges = result.nodes * max_edges_per_node; - return result; - } - - stats_t stats(stats_t *stats_per_level, std::size_t max_level) const noexcept { - - std::size_t head_bytes = node_head_bytes_(); - for (std::size_t i = 0; i != size(); ++i) { - node_t node = node_at_(i); - - stats_per_level[0].nodes++; - stats_per_level[0].edges += neighbors_(node, 0).size(); - stats_per_level[0].allocated_bytes += pre_.neighbors_base_bytes + head_bytes; - - level_t node_level = static_cast(node.level()); - for (level_t l = 1; l <= (std::min)(node_level, static_cast(max_level)); ++l) { - stats_per_level[l].nodes++; - stats_per_level[l].edges += neighbors_(node, l).size(); - stats_per_level[l].allocated_bytes += pre_.neighbors_bytes; - } - } - - // The `max_edges` parameter can be inferred from `nodes` - stats_per_level[0].max_edges = stats_per_level[0].nodes * config_.connectivity_base; - for (std::size_t l = 1; l <= max_level; ++l) - stats_per_level[l].max_edges = stats_per_level[l].nodes * config_.connectivity; - - // Aggregate stats across levels - stats_t result {}; - for (std::size_t l = 0; l <= max_level; ++l) - result.nodes += stats_per_level[l].nodes, // - result.edges += stats_per_level[l].edges, // - result.allocated_bytes += stats_per_level[l].allocated_bytes, // - result.max_edges += stats_per_level[l].max_edges; // - - return result; - } - - /** - * @brief A relatively accurate lower bound on the amount of memory consumed by the system. - * In practice it's error will be below 10%. - * - * @see `serialized_length` for the length of the binary serialized representation. - */ - std::size_t memory_usage(std::size_t allocator_entry_bytes = default_allocator_entry_bytes()) const noexcept { - std::size_t total = 0; - if (!viewed_file_) { - stats_t s = stats(); - total += s.allocated_bytes; - total += s.nodes * allocator_entry_bytes; - } - - // Temporary data-structures, proportional to the number of nodes: - total += limits_.members * sizeof(node_t) + allocator_entry_bytes; - - // Temporary data-structures, proportional to the number of threads: - total += limits_.threads() * sizeof(context_t) + allocator_entry_bytes * 3; - return total; - } - - std::size_t memory_usage_per_node(level_t level) const noexcept { - return node_bytes_(level); - } + struct stats_t { + std::size_t nodes{}; + std::size_t edges{}; + std::size_t max_edges{}; + std::size_t allocated_bytes{}; + }; + + stats_t stats() const noexcept { + stats_t result{}; + + for (std::size_t i = 0; i != size(); ++i) { + node_t node = node_at_(i); + std::size_t max_edges = node.level() * config_.connectivity + config_.connectivity_base; + std::size_t edges = 0; + for (level_t level = 0; level <= node.level(); ++level) + edges += neighbors_(node, level).size(); + + ++result.nodes; + result.allocated_bytes += node_bytes_(node).size(); + result.edges += edges; + result.max_edges += max_edges; + } + return result; + } + + stats_t stats(std::size_t level) const noexcept { + stats_t result{}; + + std::size_t neighbors_bytes = !level ? pre_.neighbors_base_bytes : pre_.neighbors_bytes; + for (std::size_t i = 0; i != size(); ++i) { + node_t node = node_at_(i); + if (static_cast(node.level()) < level) + continue; + + ++result.nodes; + result.edges += neighbors_(node, level).size(); + result.allocated_bytes += node_head_bytes_() + neighbors_bytes; + } + + std::size_t max_edges_per_node = level ? 
config_.connectivity : config_.connectivity_base;
+        result.max_edges = result.nodes * max_edges_per_node;
+        return result;
+    }
+
+    stats_t stats(stats_t* stats_per_level, std::size_t max_level) const noexcept {
+
+        std::size_t head_bytes = node_head_bytes_();
+        for (std::size_t i = 0; i != size(); ++i) {
+            node_t node = node_at_(i);
+
+            stats_per_level[0].nodes++;
+            stats_per_level[0].edges += neighbors_(node, 0).size();
+            stats_per_level[0].allocated_bytes += pre_.neighbors_base_bytes + head_bytes;
+
+            level_t node_level = static_cast<level_t>(node.level());
+            for (level_t l = 1; l <= (std::min)(node_level, static_cast<level_t>(max_level)); ++l) {
+                stats_per_level[l].nodes++;
+                stats_per_level[l].edges += neighbors_(node, l).size();
+                stats_per_level[l].allocated_bytes += pre_.neighbors_bytes;
+            }
+        }
+
+        // The `max_edges` parameter can be inferred from `nodes`
+        stats_per_level[0].max_edges = stats_per_level[0].nodes * config_.connectivity_base;
+        for (std::size_t l = 1; l <= max_level; ++l)
+            stats_per_level[l].max_edges = stats_per_level[l].nodes * config_.connectivity;
+
+        // Aggregate stats across levels
+        stats_t result{};
+        for (std::size_t l = 0; l <= max_level; ++l)
+            result.nodes += stats_per_level[l].nodes,                         //
+                result.edges += stats_per_level[l].edges,                     //
+                result.allocated_bytes += stats_per_level[l].allocated_bytes, //
+                result.max_edges += stats_per_level[l].max_edges;             //
+
+        return result;
+    }
+
+    /**
+     * @brief A relatively accurate lower bound on the amount of memory consumed by the system.
+     *        In practice its error will be below 10%.
+     *
+     * @see `serialized_length` for the length of the binary serialized representation.
+     */
+    std::size_t memory_usage(std::size_t allocator_entry_bytes = default_allocator_entry_bytes()) const noexcept {
+        std::size_t total = 0;
+        if (!viewed_file_) {
+            stats_t s = stats();
+            total += s.allocated_bytes;
+            total += s.nodes * allocator_entry_bytes;
+        }
+
+        // Temporary data-structures, proportional to the number of nodes:
+        total += limits_.members * sizeof(node_t) + allocator_entry_bytes;
+
+        // Temporary data-structures, proportional to the number of threads:
+        total += limits_.threads() * sizeof(context_t) + allocator_entry_bytes * 3;
+        return total;
+    }
+
+    std::size_t memory_usage_per_node(level_t level) const noexcept { return node_bytes_(level); }

#pragma endregion

#pragma region Serialization

-    /**
-     * @brief Estimate the binary length (in bytes) of the serialized index.
-     */
-    std::size_t serialized_length() const noexcept {
-        std::size_t neighbors_length = 0;
-        for (std::size_t i = 0; i != size(); ++i)
-            neighbors_length += node_bytes_(node_at_(i).level()) + sizeof(level_t);
-        return sizeof(index_serialized_header_t) + neighbors_length;
-    }
-
-    /**
-     * @brief Saves serialized binary index representation to a stream.
- */ - template - serialization_result_t save_to_stream(output_callback_at &&output, progress_at &&progress = {}) const noexcept { - - serialization_result_t result; - - // Export some basic metadata - index_serialized_header_t header; - header.size = nodes_count_; - header.connectivity = config_.connectivity; - header.connectivity_base = config_.connectivity_base; - header.max_level = max_level_; - header.entry_slot = entry_slot_; - if (!output(&header, sizeof(header))) - return result.failed("Failed to serialize the header into stream"); - - // Progress status - std::size_t processed = 0; - std::size_t const total = 2 * header.size; - - // Export the number of levels per node - // That is both enough to estimate the overall memory consumption, - // and to be able to estimate the offsets of every entry in the file. - for (std::size_t i = 0; i != header.size; ++i) { - node_t node = node_at_(i); - level_t level = node.level(); - if (!output(&level, sizeof(level))) - return result.failed("Failed to serialize into stream"); - if (!progress(++processed, total)) - return result.failed("Terminated by user"); - } - - // After that dump the nodes themselves - for (std::size_t i = 0; i != header.size; ++i) { - span_bytes_t node_bytes = node_bytes_(node_at_(i)); - if (!output(node_bytes.data(), node_bytes.size())) - return result.failed("Failed to serialize into stream"); - if (!progress(++processed, total)) - return result.failed("Terminated by user"); - } - - return {}; - } - - /** - * @brief Symmetric to `save_from_stream`, pulls data from a stream. - */ - template - serialization_result_t load_from_stream(input_callback_at &&input, progress_at &&progress = {}) noexcept { - - serialization_result_t result; - - // Remove previously stored objects - reset(); - - // Pull basic metadata - index_serialized_header_t header; - if (!input(&header, sizeof(header))) - return result.failed("Failed to pull the header from the stream"); - - // We are loading an empty index, no more work to do - if (!header.size) { - reset(); - return result; - } - - // Allocate some dynamic memory to read all the levels - using levels_allocator_t = typename dynamic_allocator_traits_t::template rebind_alloc; - buffer_gt levels(header.size); - if (!levels) - return result.failed("Out of memory"); - if (!input(levels, header.size * sizeof(level_t))) - return result.failed("Failed to pull nodes levels from the stream"); - - // Submit metadata - config_.connectivity = header.connectivity; - config_.connectivity_base = header.connectivity_base; - pre_ = precompute_(config_); - index_limits_t limits; - limits.members = header.size; - if (!reserve(limits)) { - reset(); - return result.failed("Out of memory"); - } - nodes_count_ = header.size; - max_level_ = static_cast(header.max_level); - entry_slot_ = static_cast(header.entry_slot); - - // Load the nodes - for (std::size_t i = 0; i != header.size; ++i) { - span_bytes_t node_bytes = node_malloc_(levels[i]); - if (!input(node_bytes.data(), node_bytes.size())) { - reset(); - return result.failed("Failed to pull nodes from the stream"); - } - nodes_[i] = node_t {node_bytes.data()}; - if (!progress(i + 1, header.size)) - return result.failed("Terminated by user"); - } - return {}; - } - - template - serialization_result_t save(char const *file_path, progress_at &&progress = {}) const noexcept { - return save(output_file_t(file_path), std::forward(progress)); - } - - template - serialization_result_t load(char const *file_path, progress_at &&progress = {}) noexcept { - return 
load(input_file_t(file_path), std::forward(progress)); - } - - /** - * @brief Saves serialized binary index representation to a file, generally on disk. - */ - template - serialization_result_t save(output_file_t file, progress_at &&progress = {}) const noexcept { - - serialization_result_t io_result = file.open_if_not(); - if (!io_result) - return io_result; - - serialization_result_t stream_result = save_to_stream( - [&](void *buffer, std::size_t length) { - io_result = file.write(buffer, length); - return !!io_result; - }, - std::forward(progress)); - - if (!stream_result) - return stream_result; - return io_result; - } - - /** - * @brief Memory-maps the serialized binary index representation from disk, - * @b without copying data into RAM, and fetching it on-demand. - */ - template - serialization_result_t save(memory_mapped_file_t file, std::size_t offset = 0, - progress_at &&progress = {}) const noexcept { - - serialization_result_t io_result = file.open_if_not(); - if (!io_result) - return io_result; - - serialization_result_t stream_result = save_to_stream( - [&](void *buffer, std::size_t length) { - if (offset + length > file.size()) - return false; - std::memcpy(file.data() + offset, buffer, length); - offset += length; - return true; - }, - std::forward(progress)); - - return stream_result; - } - - /** - * @brief Loads the serialized binary index representation from disk to RAM. - * Adjusts the configuration properties of the constructed index to - * match the settings in the file. - */ - template - serialization_result_t load(input_file_t file, progress_at &&progress = {}) noexcept { - - serialization_result_t io_result = file.open_if_not(); - if (!io_result) - return io_result; - - serialization_result_t stream_result = load_from_stream( - [&](void *buffer, std::size_t length) { - io_result = file.read(buffer, length); - return !!io_result; - }, - std::forward(progress)); - - if (!stream_result) - return stream_result; - return io_result; - } - - /** - * @brief Loads the serialized binary index representation from disk to RAM. - * Adjusts the configuration properties of the constructed index to - * match the settings in the file. - */ - template - serialization_result_t load(memory_mapped_file_t file, std::size_t offset = 0, - progress_at &&progress = {}) noexcept { - - serialization_result_t io_result = file.open_if_not(); - if (!io_result) - return io_result; - - serialization_result_t stream_result = load_from_stream( - [&](void *buffer, std::size_t length) { - if (offset + length > file.size()) - return false; - std::memcpy(buffer, file.data() + offset, length); - offset += length; - return true; - }, - std::forward(progress)); - - return stream_result; - } - - /** - * @brief Memory-maps the serialized binary index representation from disk, - * @b without copying data into RAM, and fetching it on-demand. 
- */ - template - serialization_result_t view(memory_mapped_file_t file, std::size_t offset = 0, - progress_at &&progress = {}) noexcept { - - // Remove previously stored objects - reset(); - - serialization_result_t result = file.open_if_not(); - if (!result) - return result; - - // Pull basic metadata - index_serialized_header_t header; - if (file.size() - offset < sizeof(header)) - return result.failed("File is corrupted and lacks a header"); - std::memcpy(&header, file.data() + offset, sizeof(header)); - - if (!header.size) { - reset(); - return result; - } - - // Precompute offsets of every node, but before that we need to update the configs - // This could have been done with `std::exclusive_scan`, but it's only available from C++17. - using offsets_allocator_t = typename dynamic_allocator_traits_t::template rebind_alloc; - buffer_gt offsets(header.size); - if (!offsets) - return result.failed("Out of memory"); - - config_.connectivity = header.connectivity; - config_.connectivity_base = header.connectivity_base; - pre_ = precompute_(config_); - misaligned_ptr_gt levels {(byte_t *)file.data() + offset + sizeof(header)}; - offsets[0u] = offset + sizeof(header) + sizeof(level_t) * header.size; - for (std::size_t i = 1; i < header.size; ++i) - offsets[i] = offsets[i - 1] + node_bytes_(levels[i - 1]); - - std::size_t total_bytes = offsets[header.size - 1] + node_bytes_(levels[header.size - 1]); - if (file.size() < total_bytes) { - reset(); - return result.failed("File is corrupted and can't fit all the nodes"); - } - - // Submit metadata and reserve memory - index_limits_t limits; - limits.members = header.size; - if (!reserve(limits)) { - reset(); - return result.failed("Out of memory"); - } - nodes_count_ = header.size; - max_level_ = static_cast(header.max_level); - entry_slot_ = static_cast(header.entry_slot); - - // Rapidly address all the nodes - for (std::size_t i = 0; i != header.size; ++i) { - nodes_[i] = node_t {(byte_t *)file.data() + offsets[i]}; - if (!progress(i + 1, header.size)) - return result.failed("Terminated by user"); - } - viewed_file_ = std::move(file); - return {}; - } + /** + * @brief Estimate the binary length (in bytes) of the serialized index. + */ + std::size_t serialized_length() const noexcept { + std::size_t neighbors_length = 0; + for (std::size_t i = 0; i != size(); ++i) + neighbors_length += node_bytes_(node_at_(i).level()) + sizeof(level_t); + return sizeof(index_serialized_header_t) + neighbors_length; + } + + /** + * @brief Saves serialized binary index representation to a stream. + */ + template + serialization_result_t save_to_stream(output_callback_at&& output, progress_at&& progress = {}) const noexcept { + + serialization_result_t result; + + // Export some basic metadata + index_serialized_header_t header; + header.size = nodes_count_; + header.connectivity = config_.connectivity; + header.connectivity_base = config_.connectivity_base; + header.max_level = max_level_; + header.entry_slot = entry_slot_; + if (!output(&header, sizeof(header))) + return result.failed("Failed to serialize the header into stream"); + + // Progress status + std::size_t processed = 0; + std::size_t const total = 2 * header.size; + + // Export the number of levels per node + // That is both enough to estimate the overall memory consumption, + // and to be able to estimate the offsets of every entry in the file. 
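+        // The resulting on-disk layout is: the header, then `size` level
+        // values, then the variable-length node blobs in slot order.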
+        for (std::size_t i = 0; i != header.size; ++i) {
+            node_t node = node_at_(i);
+            level_t level = node.level();
+            if (!output(&level, sizeof(level)))
+                return result.failed("Failed to serialize into stream");
+            if (!progress(++processed, total))
+                return result.failed("Terminated by user");
+        }
+
+        // After that dump the nodes themselves
+        for (std::size_t i = 0; i != header.size; ++i) {
+            span_bytes_t node_bytes = node_bytes_(node_at_(i));
+            if (!output(node_bytes.data(), node_bytes.size()))
+                return result.failed("Failed to serialize into stream");
+            if (!progress(++processed, total))
+                return result.failed("Terminated by user");
+        }
+
+        return {};
+    }
+
+    /**
+     *  @brief  Symmetric to `save_to_stream`, pulls data from a stream.
+     */
+    template <typename input_callback_at, typename progress_at = dummy_progress_t>
+    serialization_result_t load_from_stream(input_callback_at&& input, progress_at&& progress = {}) noexcept {
+
+        serialization_result_t result;
+
+        // Remove previously stored objects
+        reset();
+
+        // Pull basic metadata
+        index_serialized_header_t header;
+        if (!input(&header, sizeof(header)))
+            return result.failed("Failed to pull the header from the stream");
+
+        // We are loading an empty index, no more work to do
+        if (!header.size) {
+            reset();
+            return result;
+        }
+
+        // Allocate some dynamic memory to read all the levels
+        using levels_allocator_t = typename dynamic_allocator_traits_t::template rebind_alloc<level_t>;
+        buffer_gt<level_t, levels_allocator_t> levels(header.size);
+        if (!levels)
+            return result.failed("Out of memory");
+        if (!input(levels, header.size * sizeof(level_t)))
+            return result.failed("Failed to pull nodes levels from the stream");
+
+        // Submit metadata
+        config_.connectivity = header.connectivity;
+        config_.connectivity_base = header.connectivity_base;
+        pre_ = precompute_(config_);
+        index_limits_t limits;
+        limits.members = header.size;
+        if (!reserve(limits)) {
+            reset();
+            return result.failed("Out of memory");
+        }
+        nodes_count_ = header.size;
+        max_level_ = static_cast<level_t>(header.max_level);
+        entry_slot_ = static_cast<compressed_slot_t>(header.entry_slot);
+
+        // Load the nodes
+        for (std::size_t i = 0; i != header.size; ++i) {
+            span_bytes_t node_bytes = node_malloc_(levels[i]);
+            if (!input(node_bytes.data(), node_bytes.size())) {
+                reset();
+                return result.failed("Failed to pull nodes from the stream");
+            }
+            nodes_[i] = node_t{node_bytes.data()};
+            if (!progress(i + 1, header.size))
+                return result.failed("Terminated by user");
+        }
+        return {};
+    }
+
+    template <typename progress_at = dummy_progress_t>
+    serialization_result_t save(char const* file_path, progress_at&& progress = {}) const noexcept {
+        return save(output_file_t(file_path), std::forward<progress_at>(progress));
+    }
+
+    template <typename progress_at = dummy_progress_t>
+    serialization_result_t load(char const* file_path, progress_at&& progress = {}) noexcept {
+        return load(input_file_t(file_path), std::forward<progress_at>(progress));
+    }
+
+    /**
+     *  @brief  Saves serialized binary index representation to a file, generally on disk.
+     */
+    template <typename progress_at = dummy_progress_t>
+    serialization_result_t save(output_file_t file, progress_at&& progress = {}) const noexcept {
+
+        serialization_result_t io_result = file.open_if_not();
+        if (!io_result)
+            return io_result;
+
+        serialization_result_t stream_result = save_to_stream(
+            [&](void* buffer, std::size_t length) {
+                io_result = file.write(buffer, length);
+                return !!io_result;
+            },
+            std::forward<progress_at>(progress));
+
+        if (!stream_result)
+            return stream_result;
+        return io_result;
+    }
+
+    /**
+     *  @brief  Saves the serialized binary index representation into an already
+     *          memory-mapped file, writing directly into the mapped region at `offset`.
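+     *          Fails once the bounds check below detects that the mapped region
+     *          is too small to hold the serialized index.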
+ */ + template + serialization_result_t save(memory_mapped_file_t file, std::size_t offset = 0, + progress_at&& progress = {}) const noexcept { + + serialization_result_t io_result = file.open_if_not(); + if (!io_result) + return io_result; + + serialization_result_t stream_result = save_to_stream( + [&](void* buffer, std::size_t length) { + if (offset + length > file.size()) + return false; + std::memcpy(file.data() + offset, buffer, length); + offset += length; + return true; + }, + std::forward(progress)); + + return stream_result; + } + + /** + * @brief Loads the serialized binary index representation from disk to RAM. + * Adjusts the configuration properties of the constructed index to + * match the settings in the file. + */ + template + serialization_result_t load(input_file_t file, progress_at&& progress = {}) noexcept { + + serialization_result_t io_result = file.open_if_not(); + if (!io_result) + return io_result; + + serialization_result_t stream_result = load_from_stream( + [&](void* buffer, std::size_t length) { + io_result = file.read(buffer, length); + return !!io_result; + }, + std::forward(progress)); + + if (!stream_result) + return stream_result; + return io_result; + } + + /** + * @brief Loads the serialized binary index representation from disk to RAM. + * Adjusts the configuration properties of the constructed index to + * match the settings in the file. + */ + template + serialization_result_t load(memory_mapped_file_t file, std::size_t offset = 0, + progress_at&& progress = {}) noexcept { + + serialization_result_t io_result = file.open_if_not(); + if (!io_result) + return io_result; + + serialization_result_t stream_result = load_from_stream( + [&](void* buffer, std::size_t length) { + if (offset + length > file.size()) + return false; + std::memcpy(buffer, file.data() + offset, length); + offset += length; + return true; + }, + std::forward(progress)); + + return stream_result; + } + + /** + * @brief Memory-maps the serialized binary index representation from disk, + * @b without copying data into RAM, and fetching it on-demand. + */ + template + serialization_result_t view(memory_mapped_file_t file, std::size_t offset = 0, + progress_at&& progress = {}) noexcept { + + // Remove previously stored objects + reset(); + + serialization_result_t result = file.open_if_not(); + if (!result) + return result; + + // Pull basic metadata + index_serialized_header_t header; + if (file.size() - offset < sizeof(header)) + return result.failed("File is corrupted and lacks a header"); + std::memcpy(&header, file.data() + offset, sizeof(header)); + + if (!header.size) { + reset(); + return result; + } + + // Precompute offsets of every node, but before that we need to update the configs + // This could have been done with `std::exclusive_scan`, but it's only available from C++17. 
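+        // For illustration only, not part of the build: with C++17 the prefix-sum loop
+        // below could collapse into one call, assuming `levels` works as an input iterator:
+        //
+        //   std::transform_exclusive_scan(
+        //       levels, levels + header.size, offsets.begin(),
+        //       offset + sizeof(header) + sizeof(level_t) * header.size,
+        //       std::plus<std::size_t>{}, [&](level_t level) { return node_bytes_(level); });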
+        using offsets_allocator_t = typename dynamic_allocator_traits_t::template rebind_alloc<std::size_t>;
+        buffer_gt<std::size_t, offsets_allocator_t> offsets(header.size);
+        if (!offsets)
+            return result.failed("Out of memory");
+
+        config_.connectivity = header.connectivity;
+        config_.connectivity_base = header.connectivity_base;
+        pre_ = precompute_(config_);
+        misaligned_ptr_gt<level_t> levels{(byte_t*)file.data() + offset + sizeof(header)};
+        offsets[0u] = offset + sizeof(header) + sizeof(level_t) * header.size;
+        for (std::size_t i = 1; i < header.size; ++i)
+            offsets[i] = offsets[i - 1] + node_bytes_(levels[i - 1]);
+
+        std::size_t total_bytes = offsets[header.size - 1] + node_bytes_(levels[header.size - 1]);
+        if (file.size() < total_bytes) {
+            reset();
+            return result.failed("File is corrupted and can't fit all the nodes");
+        }
+
+        // Submit metadata and reserve memory
+        index_limits_t limits;
+        limits.members = header.size;
+        if (!reserve(limits)) {
+            reset();
+            return result.failed("Out of memory");
+        }
+        nodes_count_ = header.size;
+        max_level_ = static_cast<level_t>(header.max_level);
+        entry_slot_ = static_cast<compressed_slot_t>(header.entry_slot);
+
+        // Rapidly address all the nodes
+        for (std::size_t i = 0; i != header.size; ++i) {
+            nodes_[i] = node_t{(byte_t*)file.data() + offsets[i]};
+            if (!progress(i + 1, header.size))
+                return result.failed("Terminated by user");
+        }
+        viewed_file_ = std::move(file);
+        return {};
+    }

#pragma endregion

-    /**
-     * @brief Performs compaction on the whole HNSW index, purging some entries
-     * and links to them, while also generating a more efficient mapping,
-     * putting the more frequently used entries closer together.
-     *
-     *
-     * Scans the whole collection, removing the links leading towards
-     * banned entries. This essentially isolates some nodes from the rest
-     * of the graph, while keeping their outgoing links, in case the node
-     * is structurally relevant and has a crucial role in the index.
-     * It won't reclaim the memory.
-     *
-     * @param[in] allow_member Predicate to mark nodes for isolation.
-     * @param[in] executor Thread-pool to execute the job in parallel.
-     * @param[in] progress Callback to report the execution progress.
-     */
-    template <typename values_at, typename metric_at, typename slot_transition_at,
-              typename executor_at = dummy_executor_t, typename progress_at = dummy_progress_t,
-              typename prefetch_at = dummy_prefetch_t>
-    void compact( //
-        values_at &&values, //
-        metric_at &&metric, //
-        slot_transition_at &&slot_transition, //
-
-        executor_at &&executor = executor_at {}, //
-        progress_at &&progress = progress_at {}, //
-        prefetch_at &&prefetch = prefetch_at {}) noexcept {
-
-        // Export all the keys, slots, and levels.
-        // Partition them with the predicate.
-        // Sort the allowed entries in descending order of their level.
-        // Create a new array mapping old slots to the new ones (INT_MAX for deleted items).
- struct slot_level_t { - compressed_slot_t old_slot; - compressed_slot_t cluster; - level_t level; - }; - using slot_level_allocator_t = typename dynamic_allocator_traits_t::template rebind_alloc; - buffer_gt slots_and_levels(size()); - - // Progress status - std::atomic do_tasks {true}; - std::atomic processed {0}; - std::size_t const total = 3 * slots_and_levels.size(); - - // For every bottom level node, determine its parent cluster - executor.dynamic(slots_and_levels.size(), [&](std::size_t thread_idx, std::size_t old_slot) { - context_t &context = contexts_[thread_idx]; - std::size_t cluster = search_for_one_( // - values[citerator_at(old_slot)], // - metric, prefetch, // - entry_slot_, max_level_, 0, context); - slots_and_levels[old_slot] = { // - static_cast(old_slot), // - static_cast(cluster), // - node_at_(old_slot).level()}; - ++processed; - if (thread_idx == 0) - do_tasks = progress(processed.load(), total); - return do_tasks.load(); - }); - if (!do_tasks.load()) - return; - - // Where the actual permutation happens: - std::sort(slots_and_levels.begin(), slots_and_levels.end(), [](slot_level_t const &a, slot_level_t const &b) { - return a.level == b.level ? a.cluster < b.cluster : a.level > b.level; - }); - - using size_allocator_t = typename dynamic_allocator_traits_t::template rebind_alloc; - buffer_gt old_slot_to_new(slots_and_levels.size()); - for (std::size_t new_slot = 0; new_slot != slots_and_levels.size(); ++new_slot) - old_slot_to_new[slots_and_levels[new_slot].old_slot] = new_slot; - - // Erase all the incoming links - buffer_gt reordered_nodes(slots_and_levels.size()); - tape_allocator_t reordered_tape; - - for (std::size_t new_slot = 0; new_slot != slots_and_levels.size(); ++new_slot) { - std::size_t old_slot = slots_and_levels[new_slot].old_slot; - node_t old_node = node_at_(old_slot); - - std::size_t node_bytes = node_bytes_(old_node.level()); - byte_t *new_data = (byte_t *)reordered_tape.allocate(node_bytes); - node_t new_node {new_data}; - std::memcpy(new_data, old_node.tape(), node_bytes); - - for (level_t level = 0; level <= old_node.level(); ++level) - for (misaligned_ref_gt neighbor : neighbors_(new_node, level)) - neighbor = static_cast(old_slot_to_new[compressed_slot_t(neighbor)]); - - reordered_nodes[new_slot] = new_node; - if (!progress(++processed, total)) - return; - } - - for (std::size_t new_slot = 0; new_slot != slots_and_levels.size(); ++new_slot) { - std::size_t old_slot = slots_and_levels[new_slot].old_slot; - slot_transition(node_at_(old_slot).ckey(), // - static_cast(old_slot), // - static_cast(new_slot)); - if (!progress(++processed, total)) - return; - } - - nodes_ = std::move(reordered_nodes); - tape_allocator_ = std::move(reordered_tape); - entry_slot_ = old_slot_to_new[entry_slot_]; - } - - /** - * @brief Scans the whole collection, removing the links leading towards - * banned entries. This essentially isolates some nodes from the rest - * of the graph, while keeping their outgoing links, in case the node - * is structurally relevant and has a crucial role in the index. - * It won't reclaim the memory. - * - * @param[in] allow_member Predicate to mark nodes for isolation. - * @param[in] executor Thread-pool to execute the job in parallel. - * @param[in] progress Callback to report the execution progress. 
- */ - template < // - typename allow_member_at = dummy_predicate_t, // - typename executor_at = dummy_executor_t, // - typename progress_at = dummy_progress_t // - > - void isolate( // - allow_member_at &&allow_member, // - executor_at &&executor = executor_at {}, // - progress_at &&progress = progress_at {}) noexcept { - - // Progress status - std::atomic do_tasks {true}; - std::atomic processed {0}; - - // Erase all the incoming links - std::size_t nodes_count = size(); - executor.dynamic(nodes_count, [&](std::size_t thread_idx, std::size_t node_idx) { - node_t node = node_at_(node_idx); - for (level_t level = 0; level <= node.level(); ++level) { - neighbors_ref_t neighbors = neighbors_(node, level); - std::size_t old_size = neighbors.size(); - neighbors.clear(); - for (std::size_t i = 0; i != old_size; ++i) { - compressed_slot_t neighbor_slot = neighbors[i]; - node_t neighbor = node_at_(neighbor_slot); - if (allow_member(member_cref_t {neighbor.ckey(), neighbor_slot})) - neighbors.push_back(neighbor_slot); - } - } - ++processed; - if (thread_idx == 0) - do_tasks = progress(processed.load(), nodes_count); - return do_tasks.load(); - }); - - // At the end report the latest numbers, because the reporter thread may be finished earlier - progress(processed.load(), nodes_count); - } + /** + * @brief Performs compaction on the whole HNSW index, purging some entries + * and links to them, while also generating a more efficient mapping, + * putting the more frequently used entries closer together. + * + * + * Scans the whole collection, removing the links leading towards + * banned entries. This essentially isolates some nodes from the rest + * of the graph, while keeping their outgoing links, in case the node + * is structurally relevant and has a crucial role in the index. + * It won't reclaim the memory. + * + * @param[in] allow_member Predicate to mark nodes for isolation. + * @param[in] executor Thread-pool to execute the job in parallel. + * @param[in] progress Callback to report the execution progress. + */ + template + void compact( // + values_at&& values, // + metric_at&& metric, // + slot_transition_at&& slot_transition, // + + executor_at&& executor = executor_at{}, // + progress_at&& progress = progress_at{}, // + prefetch_at&& prefetch = prefetch_at{}) noexcept { + + // Export all the keys, slots, and levels. + // Partition them with the predicate. + // Sort the allowed entries in descending order of their level. + // Create a new array mapping old slots to the new ones (INT_MAX for deleted items). 
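+        // Hypothetical example: 4 nodes with levels {0, 2, 0, 1}, all in cluster A except
+        // node #2 (cluster B), sort into the order [1, 3, 0, 2], so `old_slot_to_new`
+        // would become [2, 0, 3, 1].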
+ struct slot_level_t { + compressed_slot_t old_slot; + compressed_slot_t cluster; + level_t level; + }; + using slot_level_allocator_t = typename dynamic_allocator_traits_t::template rebind_alloc; + buffer_gt slots_and_levels(size()); + + // Progress status + std::atomic do_tasks{true}; + std::atomic processed{0}; + std::size_t const total = 3 * slots_and_levels.size(); + + // For every bottom level node, determine its parent cluster + executor.dynamic(slots_and_levels.size(), [&](std::size_t thread_idx, std::size_t old_slot) { + context_t& context = contexts_[thread_idx]; + std::size_t cluster = search_for_one_( // + values[citerator_at(old_slot)], // + metric, prefetch, // + entry_slot_, max_level_, 0, context); + slots_and_levels[old_slot] = { // + static_cast(old_slot), // + static_cast(cluster), // + node_at_(old_slot).level()}; + ++processed; + if (thread_idx == 0) + do_tasks = progress(processed.load(), total); + return do_tasks.load(); + }); + if (!do_tasks.load()) + return; + + // Where the actual permutation happens: + std::sort(slots_and_levels.begin(), slots_and_levels.end(), [](slot_level_t const& a, slot_level_t const& b) { + return a.level == b.level ? a.cluster < b.cluster : a.level > b.level; + }); + + using size_allocator_t = typename dynamic_allocator_traits_t::template rebind_alloc; + buffer_gt old_slot_to_new(slots_and_levels.size()); + for (std::size_t new_slot = 0; new_slot != slots_and_levels.size(); ++new_slot) + old_slot_to_new[slots_and_levels[new_slot].old_slot] = new_slot; + + // Erase all the incoming links + buffer_gt reordered_nodes(slots_and_levels.size()); + tape_allocator_t reordered_tape; + + for (std::size_t new_slot = 0; new_slot != slots_and_levels.size(); ++new_slot) { + std::size_t old_slot = slots_and_levels[new_slot].old_slot; + node_t old_node = node_at_(old_slot); + + std::size_t node_bytes = node_bytes_(old_node.level()); + byte_t* new_data = (byte_t*)reordered_tape.allocate(node_bytes); + node_t new_node{new_data}; + std::memcpy(new_data, old_node.tape(), node_bytes); + + for (level_t level = 0; level <= old_node.level(); ++level) + for (misaligned_ref_gt neighbor : neighbors_(new_node, level)) + neighbor = static_cast(old_slot_to_new[compressed_slot_t(neighbor)]); + + reordered_nodes[new_slot] = new_node; + if (!progress(++processed, total)) + return; + } + + for (std::size_t new_slot = 0; new_slot != slots_and_levels.size(); ++new_slot) { + std::size_t old_slot = slots_and_levels[new_slot].old_slot; + slot_transition(node_at_(old_slot).ckey(), // + static_cast(old_slot), // + static_cast(new_slot)); + if (!progress(++processed, total)) + return; + } + + nodes_ = std::move(reordered_nodes); + tape_allocator_ = std::move(reordered_tape); + entry_slot_ = old_slot_to_new[entry_slot_]; + } + + /** + * @brief Scans the whole collection, removing the links leading towards + * banned entries. This essentially isolates some nodes from the rest + * of the graph, while keeping their outgoing links, in case the node + * is structurally relevant and has a crucial role in the index. + * It won't reclaim the memory. + * + * @param[in] allow_member Predicate to mark nodes for isolation. + * @param[in] executor Thread-pool to execute the job in parallel. + * @param[in] progress Callback to report the execution progress. 
+ */ + template < // + typename allow_member_at = dummy_predicate_t, // + typename executor_at = dummy_executor_t, // + typename progress_at = dummy_progress_t // + > + void isolate( // + allow_member_at&& allow_member, // + executor_at&& executor = executor_at{}, // + progress_at&& progress = progress_at{}) noexcept { + + // Progress status + std::atomic do_tasks{true}; + std::atomic processed{0}; + + // Erase all the incoming links + std::size_t nodes_count = size(); + executor.dynamic(nodes_count, [&](std::size_t thread_idx, std::size_t node_idx) { + node_t node = node_at_(node_idx); + for (level_t level = 0; level <= node.level(); ++level) { + neighbors_ref_t neighbors = neighbors_(node, level); + std::size_t old_size = neighbors.size(); + neighbors.clear(); + for (std::size_t i = 0; i != old_size; ++i) { + compressed_slot_t neighbor_slot = neighbors[i]; + node_t neighbor = node_at_(neighbor_slot); + if (allow_member(member_cref_t{neighbor.ckey(), neighbor_slot})) + neighbors.push_back(neighbor_slot); + } + } + ++processed; + if (thread_idx == 0) + do_tasks = progress(processed.load(), nodes_count); + return do_tasks.load(); + }); + + // At the end report the latest numbers, because the reporter thread may be finished earlier + progress(processed.load(), nodes_count); + } private: - inline static precomputed_constants_t precompute_(index_config_t const &config) noexcept { - precomputed_constants_t pre; - pre.inverse_log_connectivity = 1.0 / std::log(static_cast(config.connectivity)); - pre.neighbors_bytes = config.connectivity * sizeof(compressed_slot_t) + sizeof(neighbors_count_t); - pre.neighbors_base_bytes = config.connectivity_base * sizeof(compressed_slot_t) + sizeof(neighbors_count_t); - return pre; - } - - using span_bytes_t = span_gt; - - inline span_bytes_t node_bytes_(node_t node) const noexcept { - return {node.tape(), node_bytes_(node.level())}; - } - inline std::size_t node_bytes_(level_t level) const noexcept { - return node_head_bytes_() + node_neighbors_bytes_(level); - } - inline std::size_t node_neighbors_bytes_(node_t node) const noexcept { - return node_neighbors_bytes_(node.level()); - } - inline std::size_t node_neighbors_bytes_(level_t level) const noexcept { - return pre_.neighbors_base_bytes + pre_.neighbors_bytes * level; - } - - span_bytes_t node_malloc_(level_t level) noexcept { - std::size_t node_bytes = node_bytes_(level); - byte_t *data = (byte_t *)tape_allocator_.allocate(node_bytes); - return data ? 
span_bytes_t {data, node_bytes} : span_bytes_t {}; - } - - node_t node_make_(vector_key_t key, level_t level) noexcept { - span_bytes_t node_bytes = node_malloc_(level); - if (!node_bytes) - return {}; - - std::memset(node_bytes.data(), 0, node_bytes.size()); - node_t node {(byte_t *)node_bytes.data()}; - node.key(key); - node.level(level); - return node; - } - - node_t node_make_copy_(span_bytes_t old_bytes) noexcept { - byte_t *data = (byte_t *)tape_allocator_.allocate(old_bytes.size()); - if (!data) - return {}; - std::memcpy(data, old_bytes.data(), old_bytes.size()); - return node_t {data}; - } - - void node_free_(std::size_t idx) noexcept { - if (viewed_file_) - return; - - node_t &node = nodes_[idx]; - tape_allocator_.deallocate(node.tape(), node_bytes_(node).size()); - node = node_t {}; - } - - inline node_t node_at_(std::size_t idx) const noexcept { - return nodes_[idx]; - } - inline neighbors_ref_t neighbors_base_(node_t node) const noexcept { - return {node.neighbors_tape()}; - } - - inline neighbors_ref_t neighbors_non_base_(node_t node, level_t level) const noexcept { - return {node.neighbors_tape() + pre_.neighbors_base_bytes + (level - 1) * pre_.neighbors_bytes}; - } - - inline neighbors_ref_t neighbors_(node_t node, level_t level) const noexcept { - return level ? neighbors_non_base_(node, level) : neighbors_base_(node); - } - - struct node_lock_t { - nodes_mutexes_t &mutexes; - std::size_t slot; - inline ~node_lock_t() noexcept { - mutexes.atomic_reset(slot); - } - }; - - inline node_lock_t node_lock_(std::size_t slot) const noexcept { - while (nodes_mutexes_.atomic_set(slot)) - ; - return {nodes_mutexes_, slot}; - } - - template - void connect_node_across_levels_( // - value_at &&value, metric_at &&metric, prefetch_at &&prefetch, // - std::size_t node_slot, std::size_t entry_slot, level_t max_level, level_t target_level, // - index_update_config_t const &config, context_t &context) usearch_noexcept_m { - - // Go down the level, tracking only the closest match - std::size_t closest_slot = search_for_one_( // - value, metric, prefetch, // - entry_slot, max_level, target_level, context); - - // From `target_level` down perform proper extensive search - for (level_t level = (std::min)(target_level, max_level); level >= 0; --level) { - // TODO: Handle out of memory conditions - search_to_insert_(value, metric, prefetch, closest_slot, node_slot, level, config.expansion, context); - closest_slot = connect_new_node_(metric, node_slot, level, context); - reconnect_neighbor_nodes_(metric, node_slot, value, level, context); - } - } - - template - std::size_t connect_new_node_( // - metric_at &&metric, std::size_t new_slot, level_t level, context_t &context) usearch_noexcept_m { - - node_t new_node = node_at_(new_slot); - top_candidates_t &top = context.top_candidates; - - // Outgoing links from `new_slot`: - neighbors_ref_t new_neighbors = neighbors_(new_node, level); - { - usearch_assert_m(!new_neighbors.size(), "The newly inserted element should have blank link list"); - candidates_view_t top_view = refine_(metric, config_.connectivity, top, context); - - for (std::size_t idx = 0; idx != top_view.size(); idx++) { - usearch_assert_m(!new_neighbors[idx], "Possible memory corruption"); - usearch_assert_m(level <= node_at_(top_view[idx].slot).level(), "Linking to missing level"); - new_neighbors.push_back(top_view[idx].slot); - } - } - - return new_neighbors[0]; - } - - template - void reconnect_neighbor_nodes_( // - metric_at &&metric, std::size_t new_slot, value_at &&value, level_t 
level, - context_t &context) usearch_noexcept_m { - - node_t new_node = node_at_(new_slot); - top_candidates_t &top = context.top_candidates; - neighbors_ref_t new_neighbors = neighbors_(new_node, level); - - // Reverse links from the neighbors: - std::size_t const connectivity_max = level ? config_.connectivity : config_.connectivity_base; - for (compressed_slot_t close_slot : new_neighbors) { - if (close_slot == new_slot) - continue; - node_lock_t close_lock = node_lock_(close_slot); - node_t close_node = node_at_(close_slot); - - neighbors_ref_t close_header = neighbors_(close_node, level); - usearch_assert_m(close_header.size() <= connectivity_max, "Possible corruption"); - usearch_assert_m(close_slot != new_slot, "Self-loops are impossible"); - usearch_assert_m(level <= close_node.level(), "Linking to missing level"); - - // If `new_slot` is already present in the neighboring connections of `close_slot` - // then no need to modify any connections or run the heuristics. - if (close_header.size() < connectivity_max) { - close_header.push_back(static_cast(new_slot)); - continue; - } - - // To fit a new connection we need to drop an existing one. - top.clear(); - usearch_assert_m((top.reserve(close_header.size() + 1)), "The memory must have been reserved in `add`"); - top.insert_reserved( - {context.measure(value, citerator_at(close_slot), metric), static_cast(new_slot)}); - for (compressed_slot_t successor_slot : close_header) - top.insert_reserved( - {context.measure(citerator_at(close_slot), citerator_at(successor_slot), metric), successor_slot}); - - // Export the results: - close_header.clear(); - candidates_view_t top_view = refine_(metric, connectivity_max, top, context); - for (std::size_t idx = 0; idx != top_view.size(); idx++) - close_header.push_back(top_view[idx].slot); - } - } - - level_t choose_random_level_(std::default_random_engine &level_generator) const noexcept { - std::uniform_real_distribution distribution(0.0, 1.0); - double r = -std::log(distribution(level_generator)) * pre_.inverse_log_connectivity; - return (level_t)r; - } - - struct candidates_range_t; - class candidates_iterator_t { - friend struct candidates_range_t; - - index_gt const &index_; - neighbors_ref_t neighbors_; - visits_hash_set_t &visits_; - std::size_t current_; - - candidates_iterator_t &skip_missing() noexcept { - if (!visits_.size()) - return *this; - while (current_ != neighbors_.size()) { - compressed_slot_t neighbor_slot = neighbors_[current_]; - if (visits_.test(neighbor_slot)) - current_++; - else - break; - } - return *this; - } - - public: - using element_t = compressed_slot_t; - using iterator_category = std::forward_iterator_tag; - using value_type = element_t; - using difference_type = std::ptrdiff_t; - using pointer = misaligned_ptr_gt; - using reference = misaligned_ref_gt; - - reference operator*() const noexcept { - return slot(); - } - candidates_iterator_t(index_gt const &index, neighbors_ref_t neighbors, visits_hash_set_t &visits, - std::size_t progress) noexcept - : index_(index), neighbors_(neighbors), visits_(visits), current_(progress) { - } - candidates_iterator_t operator++(int) noexcept { - return candidates_iterator_t(index_, visits_, neighbors_, current_ + 1).skip_missing(); - } - candidates_iterator_t &operator++() noexcept { - ++current_; - skip_missing(); - return *this; - } - bool operator==(candidates_iterator_t const &other) noexcept { - return current_ == other.current_; - } - bool operator!=(candidates_iterator_t const &other) noexcept { - return current_ != 
other.current_; - } - - vector_key_t key() const noexcept { - return index_->node_at_(slot()).key(); - } - compressed_slot_t slot() const noexcept { - return neighbors_[current_]; - } - friend inline std::size_t get_slot(candidates_iterator_t const &it) noexcept { - return it.slot(); - } - friend inline vector_key_t get_key(candidates_iterator_t const &it) noexcept { - return it.key(); - } - }; - - struct candidates_range_t { - index_gt const &index; - neighbors_ref_t neighbors; - visits_hash_set_t &visits; - - candidates_iterator_t begin() const noexcept { - return candidates_iterator_t {index, neighbors, visits, 0}.skip_missing(); - } - candidates_iterator_t end() const noexcept { - return {index, neighbors, visits, neighbors.size()}; - } - }; - - template - std::size_t search_for_one_( // - value_at &&query, metric_at &&metric, prefetch_at &&prefetch, // - std::size_t closest_slot, level_t begin_level, level_t end_level, context_t &context) const noexcept { - - visits_hash_set_t &visits = context.visits; - visits.clear(); - - // Optional prefetching - if (!std::is_same::type, dummy_prefetch_t>::value) - prefetch(citerator_at(closest_slot), citerator_at(closest_slot + 1)); - - distance_t closest_dist = context.measure(query, citerator_at(closest_slot), metric); - for (level_t level = begin_level; level > end_level; --level) { - bool changed; - do { - changed = false; - node_lock_t closest_lock = node_lock_(closest_slot); - neighbors_ref_t closest_neighbors = neighbors_non_base_(node_at_(closest_slot), level); - - // Optional prefetching - if (!std::is_same::type, dummy_prefetch_t>::value) { - candidates_range_t missing_candidates {*this, closest_neighbors, visits}; - prefetch(missing_candidates.begin(), missing_candidates.end()); - } - - // Actual traversal - for (compressed_slot_t candidate_slot : closest_neighbors) { - distance_t candidate_dist = context.measure(query, citerator_at(candidate_slot), metric); - if (candidate_dist < closest_dist) { - closest_dist = candidate_dist; - closest_slot = candidate_slot; - changed = true; - } - } - context.iteration_cycles++; - } while (changed); - } - return closest_slot; - } - - /** - * @brief Traverses a layer of a graph, to find the best place to insert a new node. - * Locks the nodes in the process, assuming other threads are updating neighbors lists. - * @return `true` if procedure succeeded, `false` if run out of memory. 
- */ - template - bool search_to_insert_( // - value_at &&query, metric_at &&metric, prefetch_at &&prefetch, // - std::size_t start_slot, std::size_t new_slot, level_t level, std::size_t top_limit, - context_t &context) noexcept { - - visits_hash_set_t &visits = context.visits; - next_candidates_t &next = context.next_candidates; // pop min, push - top_candidates_t &top = context.top_candidates; // pop max, push - - visits.clear(); - next.clear(); - top.clear(); - if (!visits.reserve(config_.connectivity_base + 1u)) - return false; - - // Optional prefetching - if (!std::is_same::type, dummy_prefetch_t>::value) - prefetch(citerator_at(start_slot), citerator_at(start_slot + 1)); - - distance_t radius = context.measure(query, citerator_at(start_slot), metric); - next.insert_reserved({-radius, static_cast(start_slot)}); - top.insert_reserved({radius, static_cast(start_slot)}); - visits.set(start_slot); - - while (!next.empty()) { - - candidate_t candidacy = next.top(); - if ((-candidacy.distance) > radius && top.size() == top_limit) - break; - - next.pop(); - context.iteration_cycles++; - - compressed_slot_t candidate_slot = candidacy.slot; - if (new_slot == candidate_slot) - continue; - node_t candidate_ref = node_at_(candidate_slot); - node_lock_t candidate_lock = node_lock_(candidate_slot); - neighbors_ref_t candidate_neighbors = neighbors_(candidate_ref, level); - - // Optional prefetching - if (!std::is_same::type, dummy_prefetch_t>::value) { - candidates_range_t missing_candidates {*this, candidate_neighbors, visits}; - prefetch(missing_candidates.begin(), missing_candidates.end()); - } - - // Assume the worst-case when reserving memory - if (!visits.reserve(visits.size() + candidate_neighbors.size())) - return false; - - for (compressed_slot_t successor_slot : candidate_neighbors) { - if (visits.set(successor_slot)) - continue; - - // node_lock_t successor_lock = node_lock_(successor_slot); - distance_t successor_dist = context.measure(query, citerator_at(successor_slot), metric); - if (top.size() < top_limit || successor_dist < radius) { - // This can substantially grow our priority queue: - next.insert({-successor_dist, successor_slot}); - // This will automatically evict poor matches: - top.insert({successor_dist, successor_slot}, top_limit); - radius = top.top().distance; - } - } - } - return true; - } - - /** - * @brief Traverses the @b base layer of a graph, to find a close match. - * Doesn't lock any nodes, assuming read-only simultaneous access. - * @return `true` if procedure succeeded, `false` if run out of memory. 
- */ - template - bool search_to_find_in_base_( // - value_at &&query, metric_at &&metric, predicate_at &&predicate, prefetch_at &&prefetch, // - std::size_t start_slot, std::size_t expansion, context_t &context) const noexcept { - - visits_hash_set_t &visits = context.visits; - next_candidates_t &next = context.next_candidates; // pop min, push - top_candidates_t &top = context.top_candidates; // pop max, push - std::size_t const top_limit = expansion; - - visits.clear(); - next.clear(); - top.clear(); - if (!visits.reserve(config_.connectivity_base + 1u)) - return false; - - // Optional prefetching - if (!std::is_same::type, dummy_prefetch_t>::value) - prefetch(citerator_at(start_slot), citerator_at(start_slot + 1)); - - distance_t radius = context.measure(query, citerator_at(start_slot), metric); - next.insert_reserved({-radius, static_cast(start_slot)}); - top.insert_reserved({radius, static_cast(start_slot)}); - visits.set(start_slot); - - while (!next.empty()) { - - candidate_t candidate = next.top(); - if ((-candidate.distance) > radius) - break; - - next.pop(); - context.iteration_cycles++; - - neighbors_ref_t candidate_neighbors = neighbors_base_(node_at_(candidate.slot)); - - // Optional prefetching - if (!std::is_same::type, dummy_prefetch_t>::value) { - candidates_range_t missing_candidates {*this, candidate_neighbors, visits}; - prefetch(missing_candidates.begin(), missing_candidates.end()); - } - - // Assume the worst-case when reserving memory - if (!visits.reserve(visits.size() + candidate_neighbors.size())) - return false; - - for (compressed_slot_t successor_slot : candidate_neighbors) { - if (visits.set(successor_slot)) - continue; - - distance_t successor_dist = context.measure(query, citerator_at(successor_slot), metric); - if (top.size() < top_limit || successor_dist < radius) { - // This can substantially grow our priority queue: - next.insert({-successor_dist, successor_slot}); - if (!is_dummy()) - if (!predicate(member_cref_t {node_at_(successor_slot).ckey(), successor_slot})) - continue; - - // This will automatically evict poor matches: - top.insert({successor_dist, successor_slot}, top_limit); - radius = top.top().distance; - } - } - } - - return true; - } - - /** - * @brief Iterates through all members, without actually touching the index. - */ - template - void search_exact_( // - value_at &&query, metric_at &&metric, predicate_at &&predicate, // - std::size_t count, context_t &context) const noexcept { - - top_candidates_t &top = context.top_candidates; - top.clear(); - top.reserve(count); - for (std::size_t i = 0; i != size(); ++i) { - if (!is_dummy()) - if (!predicate(at(i))) - continue; - - distance_t distance = context.measure(query, citerator_at(i), metric); - top.insert(candidate_t {distance, static_cast(i)}, count); - } - } - - /** - * @brief This algorithm from the original paper implements a heuristic, - * that massively reduces the number of connections a point has, - * to keep only the neighbors, that are from each other. - */ - template - candidates_view_t refine_( // - metric_at &&metric, // - std::size_t needed, top_candidates_t &top, context_t &context) const noexcept { - - top.sort_ascending(); - candidate_t *top_data = top.data(); - std::size_t const top_count = top.size(); - if (top_count < needed) - return {top_data, top_count}; - - std::size_t submitted_count = 1; - std::size_t consumed_count = 1; /// Always equal or greater than `submitted_count`. 
- while (submitted_count < needed && consumed_count < top_count) { - candidate_t candidate = top_data[consumed_count]; - bool good = true; - for (std::size_t idx = 0; idx < submitted_count; idx++) { - candidate_t submitted = top_data[idx]; - distance_t inter_result_dist = context.measure( // - citerator_at(candidate.slot), // - citerator_at(submitted.slot), // - metric); - if (inter_result_dist < candidate.distance) { - good = false; - break; - } - } - - if (good) { - top_data[submitted_count] = top_data[consumed_count]; - submitted_count++; - } - consumed_count++; - } - - top.shrink(submitted_count); - return {top_data, submitted_count}; - } + inline static precomputed_constants_t precompute_(index_config_t const& config) noexcept { + precomputed_constants_t pre; + pre.inverse_log_connectivity = 1.0 / std::log(static_cast(config.connectivity)); + pre.neighbors_bytes = config.connectivity * sizeof(compressed_slot_t) + sizeof(neighbors_count_t); + pre.neighbors_base_bytes = config.connectivity_base * sizeof(compressed_slot_t) + sizeof(neighbors_count_t); + return pre; + } + + using span_bytes_t = span_gt; + + inline span_bytes_t node_bytes_(node_t node) const noexcept { return {node.tape(), node_bytes_(node.level())}; } + inline std::size_t node_bytes_(level_t level) const noexcept { + return node_head_bytes_() + node_neighbors_bytes_(level); + } + inline std::size_t node_neighbors_bytes_(node_t node) const noexcept { return node_neighbors_bytes_(node.level()); } + inline std::size_t node_neighbors_bytes_(level_t level) const noexcept { + return pre_.neighbors_base_bytes + pre_.neighbors_bytes * level; + } + + span_bytes_t node_malloc_(level_t level) noexcept { + std::size_t node_bytes = node_bytes_(level); + byte_t* data = (byte_t*)tape_allocator_.allocate(node_bytes); + return data ? span_bytes_t{data, node_bytes} : span_bytes_t{}; + } + + node_t node_make_(vector_key_t key, level_t level) noexcept { + span_bytes_t node_bytes = node_malloc_(level); + if (!node_bytes) + return {}; + + std::memset(node_bytes.data(), 0, node_bytes.size()); + node_t node{(byte_t*)node_bytes.data()}; + node.key(key); + node.level(level); + return node; + } + + node_t node_make_copy_(span_bytes_t old_bytes) noexcept { + byte_t* data = (byte_t*)tape_allocator_.allocate(old_bytes.size()); + if (!data) + return {}; + std::memcpy(data, old_bytes.data(), old_bytes.size()); + return node_t{data}; + } + + void node_free_(std::size_t idx) noexcept { + if (viewed_file_) + return; + + node_t& node = nodes_[idx]; + tape_allocator_.deallocate(node.tape(), node_bytes_(node).size()); + node = node_t{}; + } + + inline node_t node_at_(std::size_t idx) const noexcept { return nodes_[idx]; } + inline neighbors_ref_t neighbors_base_(node_t node) const noexcept { return {node.neighbors_tape()}; } + + inline neighbors_ref_t neighbors_non_base_(node_t node, level_t level) const noexcept { + return {node.neighbors_tape() + pre_.neighbors_base_bytes + (level - 1) * pre_.neighbors_bytes}; + } + + inline neighbors_ref_t neighbors_(node_t node, level_t level) const noexcept { + return level ? 
neighbors_non_base_(node, level) : neighbors_base_(node);
+    }
+
+    struct node_lock_t {
+        nodes_mutexes_t& mutexes;
+        std::size_t slot;
+        inline ~node_lock_t() noexcept { mutexes.atomic_reset(slot); }
+    };
+
+    inline node_lock_t node_lock_(std::size_t slot) const noexcept {
+        while (nodes_mutexes_.atomic_set(slot))
+            ;
+        return {nodes_mutexes_, slot};
+    }
+
+    template <typename value_at, typename metric_at, typename prefetch_at>
+    void connect_node_across_levels_(                                                           //
+        value_at&& value, metric_at&& metric, prefetch_at&& prefetch,                           //
+        std::size_t node_slot, std::size_t entry_slot, level_t max_level, level_t target_level, //
+        index_update_config_t const& config, context_t& context) usearch_noexcept_m {

+        // Go down the level, tracking only the closest match
+        std::size_t closest_slot = search_for_one_( //
+            value, metric, prefetch,                //
+            entry_slot, max_level, target_level, context);
+
+        // From `target_level` down perform proper extensive search
+        for (level_t level = (std::min)(target_level, max_level); level >= 0; --level) {
+            // TODO: Handle out of memory conditions
+            search_to_insert_(value, metric, prefetch, closest_slot, node_slot, level, config.expansion, context);
+            closest_slot = connect_new_node_(metric, node_slot, level, context);
+            reconnect_neighbor_nodes_(metric, node_slot, value, level, context);
+        }
+    }
+
+    template <typename metric_at>
+    std::size_t connect_new_node_( //
+        metric_at&& metric, std::size_t new_slot, level_t level, context_t& context) usearch_noexcept_m {
+
+        node_t new_node = node_at_(new_slot);
+        top_candidates_t& top = context.top_candidates;
+
+        // Outgoing links from `new_slot`:
+        neighbors_ref_t new_neighbors = neighbors_(new_node, level);
+        {
+            usearch_assert_m(!new_neighbors.size(), "The newly inserted element should have blank link list");
+            candidates_view_t top_view = refine_(metric, config_.connectivity, top, context);
+
+            for (std::size_t idx = 0; idx != top_view.size(); idx++) {
+                usearch_assert_m(!new_neighbors[idx], "Possible memory corruption");
+                usearch_assert_m(level <= node_at_(top_view[idx].slot).level(), "Linking to missing level");
+                new_neighbors.push_back(top_view[idx].slot);
+            }
+        }
+
+        return new_neighbors[0];
+    }
+
+    template <typename metric_at, typename value_at>
+    void reconnect_neighbor_nodes_( //
+        metric_at&& metric, std::size_t new_slot, value_at&& value, level_t level,
+        context_t& context) usearch_noexcept_m {
+
+        node_t new_node = node_at_(new_slot);
+        top_candidates_t& top = context.top_candidates;
+        neighbors_ref_t new_neighbors = neighbors_(new_node, level);
+
+        // Reverse links from the neighbors:
+        std::size_t const connectivity_max = level ? config_.connectivity : config_.connectivity_base;
+        for (compressed_slot_t close_slot : new_neighbors) {
+            if (close_slot == new_slot)
+                continue;
+            node_lock_t close_lock = node_lock_(close_slot);
+            node_t close_node = node_at_(close_slot);
+
+            neighbors_ref_t close_header = neighbors_(close_node, level);
+            usearch_assert_m(close_header.size() <= connectivity_max, "Possible corruption");
+            usearch_assert_m(close_slot != new_slot, "Self-loops are impossible");
+            usearch_assert_m(level <= close_node.level(), "Linking to missing level");
+
+            // If `new_slot` is already present in the neighboring connections of `close_slot`
+            // then no need to modify any connections or run the heuristics.
+            if (close_header.size() < connectivity_max) {
+                close_header.push_back(static_cast<compressed_slot_t>(new_slot));
+                continue;
+            }
+
+            // To fit a new connection we need to drop an existing one.
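+            // The re-ranking pool is the displaced neighbor list plus the new node;
+            // `refine_` below keeps at most `connectivity_max` mutually distant links.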
+            top.clear();
+            usearch_assert_m((top.reserve(close_header.size() + 1)), "The memory must have been reserved in `add`");
+            top.insert_reserved(
+                {context.measure(value, citerator_at(close_slot), metric), static_cast<compressed_slot_t>(new_slot)});
+            for (compressed_slot_t successor_slot : close_header)
+                top.insert_reserved(
+                    {context.measure(citerator_at(close_slot), citerator_at(successor_slot), metric), successor_slot});
+
+            // Export the results:
+            close_header.clear();
+            candidates_view_t top_view = refine_(metric, connectivity_max, top, context);
+            for (std::size_t idx = 0; idx != top_view.size(); idx++)
+                close_header.push_back(top_view[idx].slot);
+        }
+    }
+
+    level_t choose_random_level_(std::default_random_engine& level_generator) const noexcept {
+        std::uniform_real_distribution<double> distribution(0.0, 1.0);
+        double r = -std::log(distribution(level_generator)) * pre_.inverse_log_connectivity;
+        return (level_t)r;
+    }
+
+    struct candidates_range_t;
+    class candidates_iterator_t {
+        friend struct candidates_range_t;
+
+        index_gt const& index_;
+        neighbors_ref_t neighbors_;
+        visits_hash_set_t& visits_;
+        std::size_t current_;
+
+        candidates_iterator_t& skip_missing() noexcept {
+            if (!visits_.size())
+                return *this;
+            while (current_ != neighbors_.size()) {
+                compressed_slot_t neighbor_slot = neighbors_[current_];
+                if (visits_.test(neighbor_slot))
+                    current_++;
+                else
+                    break;
+            }
+            return *this;
+        }
+
+      public:
+        using element_t = compressed_slot_t;
+        using iterator_category = std::forward_iterator_tag;
+        using value_type = element_t;
+        using difference_type = std::ptrdiff_t;
+        using pointer = misaligned_ptr_gt<element_t>;
+        using reference = misaligned_ref_gt<element_t>;
+
+        reference operator*() const noexcept { return slot(); }
+        candidates_iterator_t(index_gt const& index, neighbors_ref_t neighbors, visits_hash_set_t& visits,
+                              std::size_t progress) noexcept
+            : index_(index), neighbors_(neighbors), visits_(visits), current_(progress) {}
+        candidates_iterator_t operator++(int) noexcept {
+            // Arguments must follow the constructor order: (index, neighbors, visits, progress).
+            return candidates_iterator_t(index_, neighbors_, visits_, current_ + 1).skip_missing();
+        }
+        candidates_iterator_t& operator++() noexcept {
+            ++current_;
+            skip_missing();
+            return *this;
+        }
+        bool operator==(candidates_iterator_t const& other) noexcept { return current_ == other.current_; }
+        bool operator!=(candidates_iterator_t const& other) noexcept { return current_ != other.current_; }
+
+        vector_key_t key() const noexcept { return index_.node_at_(slot()).key(); }
+        compressed_slot_t slot() const noexcept { return neighbors_[current_]; }
+        friend inline std::size_t get_slot(candidates_iterator_t const& it) noexcept { return it.slot(); }
+        friend inline vector_key_t get_key(candidates_iterator_t const& it) noexcept { return it.key(); }
+    };
+
+    struct candidates_range_t {
+        index_gt const& index;
+        neighbors_ref_t neighbors;
+        visits_hash_set_t& visits;
+
+        candidates_iterator_t begin() const noexcept {
+            return candidates_iterator_t{index, neighbors, visits, 0}.skip_missing();
+        }
+        candidates_iterator_t end() const noexcept { return {index, neighbors, visits, neighbors.size()}; }
+    };
+
+    template <typename value_at, typename metric_at, typename prefetch_at>
+    std::size_t search_for_one_(                                      //
+        value_at&& query, metric_at&& metric, prefetch_at&& prefetch, //
+        std::size_t closest_slot, level_t begin_level, level_t end_level, context_t& context) const noexcept {
+
+        visits_hash_set_t& visits = context.visits;
+        visits.clear();
+
+        // Optional prefetching
+        if (!is_dummy<prefetch_at>())
+            prefetch(citerator_at(closest_slot), citerator_at(closest_slot + 1));
+
+        distance_t closest_dist =
context.measure(query, citerator_at(closest_slot), metric); + for (level_t level = begin_level; level > end_level; --level) { + bool changed; + do { + changed = false; + node_lock_t closest_lock = node_lock_(closest_slot); + neighbors_ref_t closest_neighbors = neighbors_non_base_(node_at_(closest_slot), level); + + // Optional prefetching + if (!is_dummy()) { + candidates_range_t missing_candidates{*this, closest_neighbors, visits}; + prefetch(missing_candidates.begin(), missing_candidates.end()); + } + + // Actual traversal + for (compressed_slot_t candidate_slot : closest_neighbors) { + distance_t candidate_dist = context.measure(query, citerator_at(candidate_slot), metric); + if (candidate_dist < closest_dist) { + closest_dist = candidate_dist; + closest_slot = candidate_slot; + changed = true; + } + } + context.iteration_cycles++; + } while (changed); + } + return closest_slot; + } + + /** + * @brief Traverses a layer of a graph, to find the best place to insert a new node. + * Locks the nodes in the process, assuming other threads are updating neighbors lists. + * @return `true` if procedure succeeded, `false` if run out of memory. + */ + template + bool search_to_insert_( // + value_at&& query, metric_at&& metric, prefetch_at&& prefetch, // + std::size_t start_slot, std::size_t new_slot, level_t level, std::size_t top_limit, + context_t& context) noexcept { + + visits_hash_set_t& visits = context.visits; + next_candidates_t& next = context.next_candidates; // pop min, push + top_candidates_t& top = context.top_candidates; // pop max, push + + visits.clear(); + next.clear(); + top.clear(); + if (!visits.reserve(config_.connectivity_base + 1u)) + return false; + + // Optional prefetching + if (!is_dummy()) + prefetch(citerator_at(start_slot), citerator_at(start_slot + 1)); + + distance_t radius = context.measure(query, citerator_at(start_slot), metric); + next.insert_reserved({-radius, static_cast(start_slot)}); + top.insert_reserved({radius, static_cast(start_slot)}); + visits.set(static_cast(start_slot)); + + while (!next.empty()) { + + candidate_t candidacy = next.top(); + if ((-candidacy.distance) > radius && top.size() == top_limit) + break; + + next.pop(); + context.iteration_cycles++; + + compressed_slot_t candidate_slot = candidacy.slot; + if (new_slot == candidate_slot) + continue; + node_t candidate_ref = node_at_(candidate_slot); + node_lock_t candidate_lock = node_lock_(candidate_slot); + neighbors_ref_t candidate_neighbors = neighbors_(candidate_ref, level); + + // Optional prefetching + if (!is_dummy()) { + candidates_range_t missing_candidates{*this, candidate_neighbors, visits}; + prefetch(missing_candidates.begin(), missing_candidates.end()); + } + + // Assume the worst-case when reserving memory + if (!visits.reserve(visits.size() + candidate_neighbors.size())) + return false; + + for (compressed_slot_t successor_slot : candidate_neighbors) { + if (visits.set(successor_slot)) + continue; + + // node_lock_t successor_lock = node_lock_(successor_slot); + distance_t successor_dist = context.measure(query, citerator_at(successor_slot), metric); + if (top.size() < top_limit || successor_dist < radius) { + // This can substantially grow our priority queue: + next.insert({-successor_dist, successor_slot}); + // This will automatically evict poor matches: + top.insert({successor_dist, successor_slot}, top_limit); + radius = top.top().distance; + } + } + } + return true; + } + + /** + * @brief Traverses the @b base layer of a graph, to find a close match. 
+     *          Doesn't lock any nodes, assuming read-only simultaneous access.
+     *  @return `true` if the procedure succeeded, `false` if it ran out of memory.
+     */
+    template <typename value_at, typename metric_at, typename predicate_at, typename prefetch_at>
+    bool search_to_find_in_base_(                                                              //
+        value_at&& query, metric_at&& metric, predicate_at&& predicate, prefetch_at&& prefetch, //
+        std::size_t start_slot, std::size_t expansion, context_t& context) const noexcept {
+
+        visits_hash_set_t& visits = context.visits;
+        next_candidates_t& next = context.next_candidates; // pop min, push
+        top_candidates_t& top = context.top_candidates;    // pop max, push
+        std::size_t const top_limit = expansion;
+
+        visits.clear();
+        next.clear();
+        top.clear();
+        if (!visits.reserve(config_.connectivity_base + 1u))
+            return false;
+
+        // Optional prefetching
+        if (!is_dummy<prefetch_at>())
+            prefetch(citerator_at(start_slot), citerator_at(start_slot + 1));
+
+        distance_t radius = context.measure(query, citerator_at(start_slot), metric);
+        next.insert_reserved({-radius, static_cast<compressed_slot_t>(start_slot)});
+        top.insert_reserved({radius, static_cast<compressed_slot_t>(start_slot)});
+        visits.set(static_cast<compressed_slot_t>(start_slot));
+
+        while (!next.empty()) {
+
+            candidate_t candidate = next.top();
+            if ((-candidate.distance) > radius)
+                break;
+
+            next.pop();
+            context.iteration_cycles++;
+
+            neighbors_ref_t candidate_neighbors = neighbors_base_(node_at_(candidate.slot));
+
+            // Optional prefetching
+            if (!is_dummy<prefetch_at>()) {
+                candidates_range_t missing_candidates{*this, candidate_neighbors, visits};
+                prefetch(missing_candidates.begin(), missing_candidates.end());
+            }
+
+            // Assume the worst-case when reserving memory
+            if (!visits.reserve(visits.size() + candidate_neighbors.size()))
+                return false;
+
+            for (compressed_slot_t successor_slot : candidate_neighbors) {
+                if (visits.set(successor_slot))
+                    continue;
+
+                distance_t successor_dist = context.measure(query, citerator_at(successor_slot), metric);
+                if (top.size() < top_limit || successor_dist < radius) {
+                    // This can substantially grow our priority queue:
+                    next.insert({-successor_dist, successor_slot});
+                    if (!is_dummy<predicate_at>())
+                        if (!predicate(member_cref_t{node_at_(successor_slot).ckey(), successor_slot}))
+                            continue;
+
+                    // This will automatically evict poor matches:
+                    top.insert({successor_dist, successor_slot}, top_limit);
+                    radius = top.top().distance;
+                }
+            }
+        }
+
+        return true;
+    }
+
+    /**
+     *  @brief  Iterates through all members, without actually touching the index.
+     */
+    template <typename value_at, typename metric_at, typename predicate_at>
+    void search_exact_(                                                //
+        value_at&& query, metric_at&& metric, predicate_at&& predicate, //
+        std::size_t count, context_t& context) const noexcept {
+
+        top_candidates_t& top = context.top_candidates;
+        top.clear();
+        top.reserve(count);
+        for (std::size_t i = 0; i != size(); ++i) {
+            if (!is_dummy<predicate_at>())
+                if (!predicate(at(i)))
+                    continue;
+
+            distance_t distance = context.measure(query, citerator_at(i), metric);
+            top.insert(candidate_t{distance, static_cast<compressed_slot_t>(i)}, count);
+        }
+    }
+
+    /**
+     *  @brief  This algorithm from the original paper implements a heuristic,
+     *          that massively reduces the number of connections a point has,
+     *          to keep only the neighbors that are far from each other.
+     */
+    template <typename metric_at>
+    candidates_view_t refine_( //
+        metric_at&& metric,    //
+        std::size_t needed, top_candidates_t& top, context_t& context) const noexcept {
+
+        top.sort_ascending();
+        candidate_t* top_data = top.data();
+        std::size_t const top_count = top.size();
+        if (top_count < needed)
+            return {top_data, top_count};
+
+        std::size_t submitted_count = 1;
+        std::size_t consumed_count = 1; /// Always equal or greater than `submitted_count`.
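+        // Informally: a candidate is kept only if it is closer to the base point than to
+        // every previously kept neighbor, so clusters of near-duplicates collapse into a
+        // single, more diverse link set.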
+ while (submitted_count < needed && consumed_count < top_count) { + candidate_t candidate = top_data[consumed_count]; + bool good = true; + for (std::size_t idx = 0; idx < submitted_count; idx++) { + candidate_t submitted = top_data[idx]; + distance_t inter_result_dist = context.measure( // + citerator_at(candidate.slot), // + citerator_at(submitted.slot), // + metric); + if (inter_result_dist < candidate.distance) { + good = false; + break; + } + } + + if (good) { + top_data[submitted_count] = top_data[consumed_count]; + submitted_count++; + } + consumed_count++; + } + + top.shrink(submitted_count); + return {top_data, submitted_count}; + } }; struct join_result_t { - error_t error {}; - std::size_t intersection_size {}; - std::size_t engagements {}; - std::size_t visited_members {}; - std::size_t computed_distances {}; - - explicit operator bool() const noexcept { - return !error; - } - join_result_t failed(error_t message) noexcept { - error = std::move(message); - return std::move(*this); - } + error_t error{}; + std::size_t intersection_size{}; + std::size_t engagements{}; + std::size_t visited_members{}; + std::size_t computed_distances{}; + + explicit operator bool() const noexcept { return !error; } + join_result_t failed(error_t message) noexcept { + error = std::move(message); + return std::move(*this); + } }; /** - * @brief Adapts the Male-Optimal Stable Marriage algorithm for unequal sets - * to perform fast one-to-one matching between two large collections - * of vectors, using approximate nearest neighbors search. - * - * @param[inout] man_to_woman Container to map ::first keys to ::second. - * @param[inout] woman_to_man Container to map ::second keys to ::first. - * @param[in] executor Thread-pool to execute the job in parallel. - * @param[in] progress Callback to report the execution progress. - */ +* @brief Adapts the Male-Optimal Stable Marriage algorithm for unequal sets +* to perform fast one-to-one matching between two large collections +* of vectors, using approximate nearest neighbors search. +* +* @param[inout] man_to_woman Container to map ::first keys to ::second. +* @param[inout] woman_to_man Container to map ::second keys to ::first. +* @param[in] executor Thread-pool to execute the job in parallel. +* @param[in] progress Callback to report the execution progress. 
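+* Informally, every unmatched element of the smaller set "proposes" to its next
+* closest counterpart, which accepts if it is still free, or if the proposer is
+* strictly closer than its current partner (a sketch of the loop below).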
+*/ template < // - typename men_at, // - typename women_at, // - typename men_values_at, // - typename women_values_at, // - typename men_metric_at, // - typename women_metric_at, // - - typename man_to_woman_at = dummy_key_to_key_mapping_t, // - typename woman_to_man_at = dummy_key_to_key_mapping_t, // - typename executor_at = dummy_executor_t, // - typename progress_at = dummy_progress_t // - > + typename men_at, // + typename women_at, // + typename men_values_at, // + typename women_values_at, // + typename men_metric_at, // + typename women_metric_at, // + + typename man_to_woman_at = dummy_key_to_key_mapping_t, // + typename woman_to_man_at = dummy_key_to_key_mapping_t, // + typename executor_at = dummy_executor_t, // + typename progress_at = dummy_progress_t // + > static join_result_t join( // - men_at const &men, // - women_at const &women, // - men_values_at const &men_values, // - women_values_at const &women_values, // - men_metric_at &&men_metric, // - women_metric_at &&women_metric, // - - index_join_config_t config = {}, // - man_to_woman_at &&man_to_woman = man_to_woman_at {}, // - woman_to_man_at &&woman_to_man = woman_to_man_at {}, // - executor_at &&executor = executor_at {}, // - progress_at &&progress = progress_at {}) noexcept { - - if (women.size() < men.size()) - return unum::usearch::join( // - women, men, // - women_values, men_values, // - std::forward(women_metric), std::forward(men_metric), // - - config, // - std::forward(woman_to_man), // - std::forward(man_to_woman), // - std::forward(executor), // - std::forward(progress)); - - join_result_t result; - - // Sanity checks and argument validation: - if (&men == &women) - return result.failed("Can't join with itself, consider copying"); - - if (config.max_proposals == 0) - config.max_proposals = std::log(men.size()) + executor.size(); - - using proposals_count_t = std::uint16_t; - config.max_proposals = (std::min)(men.size(), config.max_proposals); - - using distance_t = typename men_at::distance_t; - using dynamic_allocator_traits_t = typename men_at::dynamic_allocator_traits_t; - using man_key_t = typename men_at::vector_key_t; - using woman_key_t = typename women_at::vector_key_t; - - // Use the `compressed_slot_t` type of the larger collection - using compressed_slot_t = typename women_at::compressed_slot_t; - using compressed_slot_allocator_t = typename dynamic_allocator_traits_t::template rebind_alloc; - using proposals_count_allocator_t = typename dynamic_allocator_traits_t::template rebind_alloc; - - // Create an atomic queue, as a ring structure, from/to which - // free men will be added/pulled. - std::mutex free_men_mutex {}; - ring_gt free_men; - free_men.reserve(men.size()); - for (std::size_t i = 0; i != men.size(); ++i) - free_men.push(static_cast(i)); - - // We are gonna need some temporary memory. 
- buffer_gt proposal_counts(men.size()); - buffer_gt man_to_woman_slots(men.size()); - buffer_gt woman_to_man_slots(women.size()); - if (!proposal_counts || !man_to_woman_slots || !woman_to_man_slots) - return result.failed("Can't temporary mappings"); - - compressed_slot_t missing_slot; - std::memset((void *)&missing_slot, 0xFF, sizeof(compressed_slot_t)); - std::memset((void *)man_to_woman_slots.data(), 0xFF, sizeof(compressed_slot_t) * men.size()); - std::memset((void *)woman_to_man_slots.data(), 0xFF, sizeof(compressed_slot_t) * women.size()); - std::memset(proposal_counts.data(), 0, sizeof(proposals_count_t) * men.size()); - - // Define locks, to limit concurrent accesses to `man_to_woman_slots` and `woman_to_man_slots`. - bitset_t men_locks(men.size()), women_locks(women.size()); - if (!men_locks || !women_locks) - return result.failed("Can't allocate locks"); - - std::atomic rounds {0}; - std::atomic engagements {0}; - std::atomic computed_distances {0}; - std::atomic visited_members {0}; - std::atomic atomic_error {nullptr}; - - // Concurrently process all the men - executor.parallel([&](std::size_t thread_idx) { - index_search_config_t search_config; - search_config.expansion = config.expansion; - search_config.exact = config.exact; - search_config.thread = thread_idx; - compressed_slot_t free_man_slot; - - // While there exist a free man who still has a woman to propose to. - while (!atomic_error.load(std::memory_order_relaxed)) { - std::size_t passed_rounds = 0; - std::size_t total_rounds = 0; - { - std::unique_lock pop_lock(free_men_mutex); - if (!free_men.try_pop(free_man_slot)) - // Primary exit path, we have exhausted the list of candidates - break; - passed_rounds = ++rounds; - total_rounds = passed_rounds + free_men.size(); - } - if (thread_idx == 0 && !progress(passed_rounds, total_rounds)) { - atomic_error.store("Terminated by user"); - break; - } - while (men_locks.atomic_set(free_man_slot)) - ; - - proposals_count_t &free_man_proposals = proposal_counts[free_man_slot]; - if (free_man_proposals >= config.max_proposals) - continue; - - // Find the closest woman, to whom this man hasn't proposed yet. 
- ++free_man_proposals; - auto candidates = women.search(men_values[free_man_slot], free_man_proposals, women_metric, search_config); - visited_members += candidates.visited_members; - computed_distances += candidates.computed_distances; - if (!candidates) { - atomic_error = candidates.error.release(); - break; - } - - auto match = candidates.back(); - auto woman = match.member; - while (women_locks.atomic_set(woman.slot)) - ; - - compressed_slot_t husband_slot = woman_to_man_slots[woman.slot]; - bool woman_is_free = husband_slot == missing_slot; - if (woman_is_free) { - // Engagement - man_to_woman_slots[free_man_slot] = woman.slot; - woman_to_man_slots[woman.slot] = free_man_slot; - engagements++; - } else { - distance_t distance_from_husband = women_metric(women_values[woman.slot], men_values[husband_slot]); - distance_t distance_from_candidate = match.distance; - if (distance_from_husband > distance_from_candidate) { - // Break-up - while (men_locks.atomic_set(husband_slot)) - ; - man_to_woman_slots[husband_slot] = missing_slot; - men_locks.atomic_reset(husband_slot); - - // New Engagement - man_to_woman_slots[free_man_slot] = woman.slot; - woman_to_man_slots[woman.slot] = free_man_slot; - engagements++; - - std::unique_lock push_lock(free_men_mutex); - free_men.push(husband_slot); - } else { - std::unique_lock push_lock(free_men_mutex); - free_men.push(free_man_slot); - } - } - - men_locks.atomic_reset(free_man_slot); - women_locks.atomic_reset(woman.slot); - } - }); - - if (atomic_error) - return result.failed(atomic_error.load()); - - // Export the "slots" into keys: - std::size_t intersection_size = 0; - for (std::size_t man_slot = 0; man_slot != men.size(); ++man_slot) { - compressed_slot_t woman_slot = man_to_woman_slots[man_slot]; - if (woman_slot != missing_slot) { - man_key_t man = men.at(man_slot).key; - woman_key_t woman = women.at(woman_slot).key; - man_to_woman[man] = woman; - woman_to_man[woman] = man; - intersection_size++; - } - } - - // Export stats - result.engagements = engagements; - result.intersection_size = intersection_size; - result.computed_distances = computed_distances; - result.visited_members = visited_members; - return result; + men_at const& men, // + women_at const& women, // + men_values_at const& men_values, // + women_values_at const& women_values, // + men_metric_at&& men_metric, // + women_metric_at&& women_metric, // + + index_join_config_t config = {}, // + man_to_woman_at&& man_to_woman = man_to_woman_at{}, // + woman_to_man_at&& woman_to_man = woman_to_man_at{}, // + executor_at&& executor = executor_at{}, // + progress_at&& progress = progress_at{}) noexcept { + + if (women.size() < men.size()) + return unum::usearch::join( // + women, men, // + women_values, men_values, // + std::forward(women_metric), std::forward(men_metric), // + + config, // + std::forward(woman_to_man), // + std::forward(man_to_woman), // + std::forward(executor), // + std::forward(progress)); + + join_result_t result; + + // Sanity checks and argument validation: + if (&men == &women) + return result.failed("Can't join with itself, consider copying"); + + if (config.max_proposals == 0) + config.max_proposals = std::log(men.size()) + executor.size(); + + using proposals_count_t = std::uint16_t; + config.max_proposals = (std::min)(men.size(), config.max_proposals); + + using distance_t = typename men_at::distance_t; + using dynamic_allocator_traits_t = typename men_at::dynamic_allocator_traits_t; + using man_key_t = typename men_at::vector_key_t; + using woman_key_t = 
typename women_at::vector_key_t; + + // Use the `compressed_slot_t` type of the larger collection + using compressed_slot_t = typename women_at::compressed_slot_t; + using compressed_slot_allocator_t = typename dynamic_allocator_traits_t::template rebind_alloc; + using proposals_count_allocator_t = typename dynamic_allocator_traits_t::template rebind_alloc; + + // Create an atomic queue, as a ring structure, from/to which + // free men will be added/pulled. + std::mutex free_men_mutex{}; + ring_gt free_men; + free_men.reserve(men.size()); + for (std::size_t i = 0; i != men.size(); ++i) + free_men.push(static_cast(i)); + + // We are gonna need some temporary memory. + buffer_gt proposal_counts(men.size()); + buffer_gt man_to_woman_slots(men.size()); + buffer_gt woman_to_man_slots(women.size()); + if (!proposal_counts || !man_to_woman_slots || !woman_to_man_slots) + return result.failed("Can't temporary mappings"); + + compressed_slot_t missing_slot; + std::memset((void*)&missing_slot, 0xFF, sizeof(compressed_slot_t)); + std::memset((void*)man_to_woman_slots.data(), 0xFF, sizeof(compressed_slot_t) * men.size()); + std::memset((void*)woman_to_man_slots.data(), 0xFF, sizeof(compressed_slot_t) * women.size()); + std::memset(proposal_counts.data(), 0, sizeof(proposals_count_t) * men.size()); + + // Define locks, to limit concurrent accesses to `man_to_woman_slots` and `woman_to_man_slots`. + bitset_t men_locks(men.size()), women_locks(women.size()); + if (!men_locks || !women_locks) + return result.failed("Can't allocate locks"); + + std::atomic rounds{0}; + std::atomic engagements{0}; + std::atomic computed_distances{0}; + std::atomic visited_members{0}; + std::atomic atomic_error{nullptr}; + + // Concurrently process all the men + executor.parallel([&](std::size_t thread_idx) { + index_search_config_t search_config; + search_config.expansion = config.expansion; + search_config.exact = config.exact; + search_config.thread = thread_idx; + compressed_slot_t free_man_slot; + + // While there exist a free man who still has a woman to propose to. + while (!atomic_error.load(std::memory_order_relaxed)) { + std::size_t passed_rounds = 0; + std::size_t total_rounds = 0; + { + std::unique_lock pop_lock(free_men_mutex); + if (!free_men.try_pop(free_man_slot)) + // Primary exit path, we have exhausted the list of candidates + break; + passed_rounds = ++rounds; + total_rounds = passed_rounds + free_men.size(); + } + if (thread_idx == 0 && !progress(passed_rounds, total_rounds)) { + atomic_error.store("Terminated by user"); + break; + } + while (men_locks.atomic_set(free_man_slot)) + ; + + proposals_count_t& free_man_proposals = proposal_counts[free_man_slot]; + if (free_man_proposals >= config.max_proposals) + continue; + + // Find the closest woman, to whom this man hasn't proposed yet. 
+ ++free_man_proposals; + auto candidates = women.search(men_values[free_man_slot], free_man_proposals, women_metric, search_config); + visited_members += candidates.visited_members; + computed_distances += candidates.computed_distances; + if (!candidates) { + atomic_error = candidates.error.release(); + break; + } + + auto match = candidates.back(); + auto woman = match.member; + while (women_locks.atomic_set(woman.slot)) + ; + + compressed_slot_t husband_slot = woman_to_man_slots[woman.slot]; + bool woman_is_free = husband_slot == missing_slot; + if (woman_is_free) { + // Engagement + man_to_woman_slots[free_man_slot] = woman.slot; + woman_to_man_slots[woman.slot] = free_man_slot; + engagements++; + } else { + distance_t distance_from_husband = women_metric(women_values[woman.slot], men_values[husband_slot]); + distance_t distance_from_candidate = match.distance; + if (distance_from_husband > distance_from_candidate) { + // Break-up + while (men_locks.atomic_set(husband_slot)) + ; + man_to_woman_slots[husband_slot] = missing_slot; + men_locks.atomic_reset(husband_slot); + + // New Engagement + man_to_woman_slots[free_man_slot] = woman.slot; + woman_to_man_slots[woman.slot] = free_man_slot; + engagements++; + + std::unique_lock push_lock(free_men_mutex); + free_men.push(husband_slot); + } else { + std::unique_lock push_lock(free_men_mutex); + free_men.push(free_man_slot); + } + } + + men_locks.atomic_reset(free_man_slot); + women_locks.atomic_reset(woman.slot); + } + }); + + if (atomic_error) + return result.failed(atomic_error.load()); + + // Export the "slots" into keys: + std::size_t intersection_size = 0; + for (std::size_t man_slot = 0; man_slot != men.size(); ++man_slot) { + compressed_slot_t woman_slot = man_to_woman_slots[man_slot]; + if (woman_slot != missing_slot) { + man_key_t man = men.at(man_slot).key; + woman_key_t woman = women.at(woman_slot).key; + man_to_woman[man] = woman; + woman_to_man[woman] = man; + intersection_size++; + } + } + + // Export stats + result.engagements = engagements; + result.intersection_size = intersection_size; + result.computed_distances = computed_distances; + result.visited_members = visited_members; + return result; } } // namespace usearch diff --git a/src/include/usearch/index_dense.hpp b/src/include/usearch/index_dense.hpp index 0e5b1ec..e697e82 100644 --- a/src/include/usearch/index_dense.hpp +++ b/src/include/usearch/index_dense.hpp @@ -16,16 +16,13 @@ namespace unum { namespace usearch { -template -class index_dense_gt; +template class index_dense_gt; /** * @brief The "magic" sequence helps infer the type of the file. * USearch indexes start with the "usearch" string. 
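> Note: the `join` above adapts Gale-Shapley stable matching: every free man proposes to his next-closest woman; a free woman accepts, and an engaged woman trades up only if the new suitor is strictly closer, re-queueing the jilted husband. Below is a deliberately simplified, single-threaded sketch of that loop, where a brute-force `kth_nearest` replaces the approximate `women.search(...)` call, and the locks, atomics, and the recursive swap of the smaller collection are all omitted:

```cpp
#include <algorithm>
#include <cstddef>
#include <limits>
#include <queue>
#include <utility>
#include <vector>

// Brute-force stand-in for `women.search(men_values[m], k + 1, ...)` followed
// by `candidates.back()`: the k-th closest woman to man `m` (0-based).
static std::size_t kth_nearest(std::size_t m, std::size_t k, std::size_t women_count,
                               float (*dist)(std::size_t, std::size_t)) {
    std::vector<std::pair<float, std::size_t>> order(women_count);
    for (std::size_t w = 0; w != women_count; ++w)
        order[w] = {dist(m, w), w};
    std::sort(order.begin(), order.end());
    return order[k].second;
}

// Returns the woman matched to each man, or `missing` if he gave up.
// Assumes `max_proposals <= women_count`, as the clamping above guarantees.
std::vector<std::size_t> stable_match(std::size_t men_count, std::size_t women_count,
                                      float (*dist)(std::size_t, std::size_t),
                                      std::size_t max_proposals) {
    constexpr std::size_t missing = std::numeric_limits<std::size_t>::max();
    std::vector<std::size_t> wife(men_count, missing), husband(women_count, missing);
    std::vector<std::size_t> proposals(men_count, 0);
    std::queue<std::size_t> free_men;
    for (std::size_t m = 0; m != men_count; ++m)
        free_men.push(m);

    while (!free_men.empty()) {
        std::size_t m = free_men.front();
        free_men.pop();
        if (proposals[m] >= max_proposals)
            continue; // this man has exhausted his proposals
        std::size_t w = kth_nearest(m, proposals[m]++, women_count, dist);
        if (husband[w] == missing) { // engagement
            wife[m] = w, husband[w] = m;
        } else if (dist(m, w) < dist(husband[w], w)) { // break-up, new engagement
            wife[husband[w]] = missing;
            free_men.push(husband[w]); // the jilted husband becomes free again
            wife[m] = w, husband[w] = m;
        } else {
            free_men.push(m); // rejected; he proposes to the next-closest later
        }
    }
    return wife;
}
```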
*/ -constexpr char const *default_magic() { - return "usearch"; -} +constexpr char const* default_magic() { return "usearch"; } using index_dense_head_buffer_t = byte_t[64]; @@ -47,7 +44,7 @@ struct index_dense_head_t { using version_t = std::uint16_t; // Versioning: 7 + 2 * 3 = 13 bytes - char const *magic; + char const* magic; misaligned_ref_gt version_major; misaligned_ref_gt version_minor; misaligned_ref_gt version_patch; @@ -64,8 +61,8 @@ struct index_dense_head_t { misaligned_ref_gt dimensions; misaligned_ref_gt multi; - index_dense_head_t(byte_t *ptr) noexcept - : magic((char const *)exchange(ptr, ptr + sizeof(magic_t))), // + index_dense_head_t(byte_t* ptr) noexcept + : magic((char const*)exchange(ptr, ptr + sizeof(magic_t))), // version_major(exchange(ptr, ptr + sizeof(version_t))), // version_minor(exchange(ptr, ptr + sizeof(version_t))), // version_patch(exchange(ptr, ptr + sizeof(version_t))), // @@ -76,8 +73,7 @@ struct index_dense_head_t { count_present(exchange(ptr, ptr + sizeof(std::uint64_t))), // count_deleted(exchange(ptr, ptr + sizeof(std::uint64_t))), // dimensions(exchange(ptr, ptr + sizeof(std::uint64_t))), // - multi(exchange(ptr, ptr + sizeof(bool))) { - } + multi(exchange(ptr, ptr + sizeof(bool))) {} }; struct index_dense_head_result_t { @@ -86,9 +82,7 @@ struct index_dense_head_result_t { index_dense_head_t head; error_t error; - explicit operator bool() const noexcept { - return !error; - } + explicit operator bool() const noexcept { return !error; } index_dense_head_result_t failed(error_t message) noexcept { error = std::move(message); return std::move(*this); @@ -102,24 +96,22 @@ struct index_dense_config_t : public index_config_t { bool multi = false; /** - * @brief Allows you to reduce RAM consumption by avoiding - * reverse-indexing keys-to-vectors, and only keeping - * the vectors-to-keys mappings. - * - * ! This configuration parameter doesn't affect the serialized file, - * ! and is not preserved between runs. Makes sense for small vector - * ! representations that fit ina single cache line. + * @brief Allows you to reduce RAM consumption by avoiding + * reverse-indexing keys-to-vectors, and only keeping + * the vectors-to-keys mappings. + * + * ! This configuration parameter doesn't affect the serialized file, + * ! and is not preserved between runs. Makes sense for small vector + * ! representations that fit ina single cache line. */ bool enable_key_lookups = true; - index_dense_config_t(index_config_t base) noexcept : index_config_t(base) { - } + index_dense_config_t(index_config_t base) noexcept : index_config_t(base) {} index_dense_config_t(std::size_t c = default_connectivity(), std::size_t ea = default_expansion_add(), std::size_t es = default_expansion_search()) noexcept : index_config_t(c), expansion_add(ea ? ea : default_expansion_add()), - expansion_search(es ? es : default_expansion_search()) { - } + expansion_search(es ? 
es : default_expansion_search()) {} }; struct index_dense_clustering_config_t { @@ -140,8 +132,7 @@ struct index_dense_copy_config_t : public index_copy_config_t { bool force_vector_copy = true; index_dense_copy_config_t() = default; - index_dense_copy_config_t(index_copy_config_t base) noexcept : index_copy_config_t(base) { - } + index_dense_copy_config_t(index_copy_config_t base) noexcept : index_copy_config_t(base) {} }; struct index_dense_metadata_result_t { @@ -150,24 +141,21 @@ struct index_dense_metadata_result_t { index_dense_head_t head; error_t error; - explicit operator bool() const noexcept { - return !error; - } + explicit operator bool() const noexcept { return !error; } index_dense_metadata_result_t failed(error_t message) noexcept { error = std::move(message); return std::move(*this); } - index_dense_metadata_result_t() noexcept : config(), head_buffer(), head(head_buffer), error() { - } + index_dense_metadata_result_t() noexcept : config(), head_buffer(), head(head_buffer), error() {} - index_dense_metadata_result_t(index_dense_metadata_result_t &&other) noexcept + index_dense_metadata_result_t(index_dense_metadata_result_t&& other) noexcept : config(), head_buffer(), head(head_buffer), error(std::move(other.error)) { std::memcpy(&config, &other.config, sizeof(other.config)); std::memcpy(&head_buffer, &other.head_buffer, sizeof(other.head_buffer)); } - index_dense_metadata_result_t &operator=(index_dense_metadata_result_t &&other) noexcept { + index_dense_metadata_result_t& operator=(index_dense_metadata_result_t&& other) noexcept { std::memcpy(&config, &other.config, sizeof(other.config)); std::memcpy(&head_buffer, &other.head_buffer, sizeof(other.head_buffer)); error = std::move(other.error); @@ -179,9 +167,9 @@ struct index_dense_metadata_result_t { * @brief Extracts metadata from a pre-constructed index on disk, * without loading it or mapping the whole binary file. 
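> Note: a usage sketch for the metadata probe defined just below. The file name is hypothetical, and the snippet assumes the `misaligned_ref_gt` head fields convert implicitly to their underlying integers, as they do elsewhere in this header:

```cpp
#include <cstdint>
#include <cstdio>
#include <usearch/index_dense.hpp>
using namespace unum::usearch;

int main() {
    auto meta = index_dense_metadata_from_path("index.usearch");
    if (!meta) {
        std::printf("not a usearch index: %s\n", meta.error.release());
        return 1;
    }
    std::uint64_t dimensions = meta.head.dimensions;
    std::uint64_t present = meta.head.count_present;
    std::uint64_t deleted = meta.head.count_deleted;
    std::printf("dims %llu, present %llu, deleted %llu\n", //
                (unsigned long long)dimensions, (unsigned long long)present,
                (unsigned long long)deleted);
    return 0;
}
```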
*/ -inline index_dense_metadata_result_t index_dense_metadata_from_path(char const *file_path) noexcept { +inline index_dense_metadata_result_t index_dense_metadata_from_path(char const* file_path) noexcept { index_dense_metadata_result_t result; - std::unique_ptr file(std::fopen(file_path, "rb"), &std::fclose); + std::unique_ptr file(std::fopen(file_path, "rb"), &std::fclose); if (!file) return result.failed(std::strerror(errno)); @@ -201,17 +189,17 @@ inline index_dense_metadata_result_t index_dense_metadata_from_path(char const * // Check if it starts with 32-bit std::size_t const file_size = std::ftell(file.get()); - std::uint32_t dimensions_u32[2] {0}; + std::uint32_t dimensions_u32[2]{0}; std::memcpy(dimensions_u32, result.head_buffer, sizeof(dimensions_u32)); std::size_t offset_if_u32 = std::size_t(dimensions_u32[0]) * dimensions_u32[1] + sizeof(dimensions_u32); - std::uint64_t dimensions_u64[2] {0}; + std::uint64_t dimensions_u64[2]{0}; std::memcpy(dimensions_u64, result.head_buffer, sizeof(dimensions_u64)); std::size_t offset_if_u64 = std::size_t(dimensions_u64[0]) * dimensions_u64[1] + sizeof(dimensions_u64); // Check if it starts with 32-bit if (offset_if_u32 + sizeof(index_dense_head_buffer_t) < file_size) { - if (std::fseek(file.get(), offset_if_u32, SEEK_SET) != 0) + if (std::fseek(file.get(), static_cast(offset_if_u32), SEEK_SET) != 0) return result.failed(std::strerror(errno)); read = std::fread(result.head_buffer, sizeof(index_dense_head_buffer_t), 1, file.get()); if (!read) @@ -225,7 +213,7 @@ inline index_dense_metadata_result_t index_dense_metadata_from_path(char const * // Check if it starts with 64-bit if (offset_if_u64 + sizeof(index_dense_head_buffer_t) < file_size) { - if (std::fseek(file.get(), offset_if_u64, SEEK_SET) != 0) + if (std::fseek(file.get(), static_cast(offset_if_u64), SEEK_SET) != 0) return result.failed(std::strerror(errno)); read = std::fread(result.head_buffer, sizeof(index_dense_head_buffer_t), 1, file.get()); if (!read) @@ -252,7 +240,7 @@ inline index_dense_metadata_result_t index_dense_metadata_from_buffer(memory_map if (offset + sizeof(index_dense_head_buffer_t) >= file.size()) return result.failed("End of file reached!"); - byte_t *const file_data = file.data() + offset; + byte_t* const file_data = file.data() + offset; std::size_t const file_size = file.size() - offset; std::memcpy(&result.head_buffer, file_data, sizeof(index_dense_head_buffer_t)); @@ -262,11 +250,11 @@ inline index_dense_metadata_result_t index_dense_metadata_from_buffer(memory_map return result; // Check if it starts with 32-bit - std::uint32_t dimensions_u32[2] {0}; + std::uint32_t dimensions_u32[2]{0}; std::memcpy(dimensions_u32, result.head_buffer, sizeof(dimensions_u32)); std::size_t offset_if_u32 = std::size_t(dimensions_u32[0]) * dimensions_u32[1] + sizeof(dimensions_u32); - std::uint64_t dimensions_u64[2] {0}; + std::uint64_t dimensions_u64[2]{0}; std::memcpy(dimensions_u64, result.head_buffer, sizeof(dimensions_u64)); std::size_t offset_if_u64 = std::size_t(dimensions_u64[0]) * dimensions_u64[1] + sizeof(dimensions_u64); @@ -330,7 +318,7 @@ class index_dense_gt { private: /// @brief Schema: input buffer, bytes in input buffer, output buffer. - using cast_t = std::function; + using cast_t = std::function; /// @brief Punned index. using index_t = index_gt< // distance_t, vector_key_t, compressed_slot_t, // @@ -342,43 +330,28 @@ class index_dense_gt { /// @brief Punned metric object. 
class metric_proxy_t { - index_dense_gt const *index_ = nullptr; + index_dense_gt const* index_ = nullptr; public: - metric_proxy_t(index_dense_gt const &index) noexcept : index_(&index) { - } + metric_proxy_t(index_dense_gt const& index) noexcept : index_(&index) {} - inline distance_t operator()(byte_t const *a, member_cref_t b) const noexcept { - return f(a, v(b)); - } - inline distance_t operator()(member_cref_t a, member_cref_t b) const noexcept { - return f(v(a), v(b)); - } + inline distance_t operator()(byte_t const* a, member_cref_t b) const noexcept { return f(a, v(b)); } + inline distance_t operator()(member_cref_t a, member_cref_t b) const noexcept { return f(v(a), v(b)); } - inline distance_t operator()(byte_t const *a, member_citerator_t b) const noexcept { - return f(a, v(b)); - } + inline distance_t operator()(byte_t const* a, member_citerator_t b) const noexcept { return f(a, v(b)); } inline distance_t operator()(member_citerator_t a, member_citerator_t b) const noexcept { return f(v(a), v(b)); } - inline distance_t operator()(byte_t const *a, byte_t const *b) const noexcept { - return f(a, b); - } + inline distance_t operator()(byte_t const* a, byte_t const* b) const noexcept { return f(a, b); } - inline byte_t const *v(member_cref_t m) const noexcept { - return index_->vectors_lookup_[get_slot(m)]; - } - inline byte_t const *v(member_citerator_t m) const noexcept { - return index_->vectors_lookup_[get_slot(m)]; - } - inline distance_t f(byte_t const *a, byte_t const *b) const noexcept { - return index_->metric_(a, b); - } + inline byte_t const* v(member_cref_t m) const noexcept { return index_->vectors_lookup_[get_slot(m)]; } + inline byte_t const* v(member_citerator_t m) const noexcept { return index_->vectors_lookup_[get_slot(m)]; } + inline distance_t f(byte_t const* a, byte_t const* b) const noexcept { return index_->metric_(a, b); } }; index_dense_config_t config_; - index_t *typed_ = nullptr; + index_t* typed_ = nullptr; mutable std::vector cast_buffer_; struct casts_t { @@ -403,7 +376,7 @@ class index_dense_gt { vectors_tape_allocator_t vectors_tape_allocator_; /// @brief For every managed `compressed_slot_t` stores a pointer to the allocated vector copy. 
- mutable std::vector vectors_lookup_; + mutable std::vector vectors_lookup_; /// @brief Originally forms and array of integers [0, threads], marking all mutable std::vector available_threads_; @@ -423,35 +396,21 @@ class index_dense_gt { vector_key_t key; compressed_slot_t slot; - bool any_slot() const { - return slot == default_free_value(); - } - static key_and_slot_t any_slot(vector_key_t key) { - return {key, default_free_value()}; - } + bool any_slot() const { return slot == default_free_value(); } + static key_and_slot_t any_slot(vector_key_t key) { return {key, default_free_value()}; } }; struct lookup_key_hash_t { using is_transparent = void; - std::size_t operator()(key_and_slot_t const &k) const noexcept { - return std::hash {}(k.key); - } - std::size_t operator()(vector_key_t const &k) const noexcept { - return std::hash {}(k); - } + std::size_t operator()(key_and_slot_t const& k) const noexcept { return std::hash{}(k.key); } + std::size_t operator()(vector_key_t const& k) const noexcept { return std::hash{}(k); } }; struct lookup_key_same_t { using is_transparent = void; - bool operator()(key_and_slot_t const &a, vector_key_t const &b) const noexcept { - return a.key == b; - } - bool operator()(vector_key_t const &a, key_and_slot_t const &b) const noexcept { - return a == b.key; - } - bool operator()(key_and_slot_t const &a, key_and_slot_t const &b) const noexcept { - return a.key == b.key; - } + bool operator()(key_and_slot_t const& a, vector_key_t const& b) const noexcept { return a.key == b; } + bool operator()(vector_key_t const& a, key_and_slot_t const& b) const noexcept { return a == b.key; } + bool operator()(key_and_slot_t const& a, key_and_slot_t const& b) const noexcept { return a.key == b.key; } }; /// @brief Multi-Map from keys to IDs, and allocated vectors. @@ -477,7 +436,7 @@ class index_dense_gt { using match_t = typename index_t::match_t; index_dense_gt() = default; - index_dense_gt(index_dense_gt &&other) + index_dense_gt(index_dense_gt&& other) : config_(std::move(other.config_)), typed_(exchange(other.typed_, nullptr)), // @@ -491,19 +450,18 @@ class index_dense_gt { available_threads_(std::move(other.available_threads_)), // slot_lookup_(std::move(other.slot_lookup_)), // free_keys_(std::move(other.free_keys_)), // - free_key_(std::move(other.free_key_)) { - } // + free_key_(std::move(other.free_key_)) {} // - index_dense_gt &operator=(index_dense_gt &&other) { + index_dense_gt& operator=(index_dense_gt&& other) { swap(other); return *this; } /** - * @brief Swaps the contents of this index with another index. - * @param other The other index to swap with. + * @brief Swaps the contents of this index with another index. + * @param other The other index to swap with. */ - void swap(index_dense_gt &other) { + void swap(index_dense_gt& other) { std::swap(config_, other.config_); std::swap(typed_, other.typed_); @@ -523,16 +481,16 @@ class index_dense_gt { ~index_dense_gt() { if (typed_) typed_->~index_t(); - index_allocator_t {}.deallocate(typed_, 1); + index_allocator_t{}.deallocate(typed_, 1); typed_ = nullptr; } /** - * @brief Constructs an instance of ::index_dense_gt. - * @param[in] metric One of the provided or an @b ad-hoc metric, type-punned. - * @param[in] config The index configuration (optional). - * @param[in] free_key The key used for freed vectors (optional). - * @return An instance of ::index_dense_gt. + * @brief Constructs an instance of ::index_dense_gt. + * @param[in] metric One of the provided or an @b ad-hoc metric, type-punned. 
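> Note: for context, a hedged construction sketch for the `make` factory documented here. `metric_punned_t`, `metric_kind_t::cos_k`, and `scalar_kind_t::f32_k` come from `index_plugins.hpp`, and the `index_dense_t` alias for the default `index_dense_gt<>` is assumed; neither is visible in this hunk:

```cpp
#include <usearch/index_dense.hpp>
using namespace unum::usearch;

int main() {
    // connectivity, expansion_add, expansion_search; zeros fall back to defaults
    index_dense_config_t config(16, 128, 64);
    config.multi = false; // at most one vector per key

    metric_punned_t metric(128, metric_kind_t::cos_k, scalar_kind_t::f32_k);
    index_dense_t index = index_dense_t::make(metric, config);
    return index ? 0 : 1; // `make` yields an empty index on allocation failure
}
```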
+ * @param[in] config The index configuration (optional). + * @param[in] free_key The key used for freed vectors (optional). + * @return An instance of ::index_dense_gt. */ static index_dense_gt make( // metric_t metric, // @@ -554,13 +512,13 @@ class index_dense_gt { std::iota(result.available_threads_.begin(), result.available_threads_.end(), 0ul); // Available since C11, but only C++17, so we use the C version. - index_t *raw = index_allocator_t {}.allocate(1); + index_t* raw = index_allocator_t{}.allocate(1); new (raw) index_t(config); result.typed_ = raw; return result; } - static index_dense_gt make(char const *path, bool view = false) { + static index_dense_gt make(char const* path, bool view = false) { index_dense_metadata_result_t meta = index_dense_metadata_from_path(path); if (!meta) return {}; @@ -575,107 +533,51 @@ class index_dense_gt { return result; } - explicit operator bool() const { - return typed_; - } - std::size_t connectivity() const { - return typed_->connectivity(); - } - std::size_t size() const { - return typed_->size() - free_keys_.size(); - } - std::size_t capacity() const { - return typed_->capacity(); - } - std::size_t max_level() const noexcept { - return typed_->max_level(); - } - index_dense_config_t const &config() const { - return config_; - } - index_limits_t const &limits() const { - return typed_->limits(); - } - bool multi() const { - return config_.multi; - } + explicit operator bool() const { return typed_; } + std::size_t connectivity() const { return typed_->connectivity(); } + std::size_t size() const { return typed_->size() - free_keys_.size(); } + std::size_t capacity() const { return typed_->capacity(); } + std::size_t max_level() const noexcept { return typed_->max_level(); } + index_dense_config_t const& config() const { return config_; } + index_limits_t const& limits() const { return typed_->limits(); } + bool multi() const { return config_.multi; } // The metric and its properties - metric_t const &metric() const { - return metric_; - } - void change_metric(metric_t metric) { - metric_ = std::move(metric); - } + metric_t const& metric() const { return metric_; } + void change_metric(metric_t metric) { metric_ = std::move(metric); } - scalar_kind_t scalar_kind() const noexcept { - return metric_.scalar_kind(); - } - std::size_t bytes_per_vector() const noexcept { - return metric_.bytes_per_vector(); - } - std::size_t scalar_words() const noexcept { - return metric_.scalar_words(); - } - std::size_t dimensions() const noexcept { - return metric_.dimensions(); - } + scalar_kind_t scalar_kind() const noexcept { return metric_.scalar_kind(); } + std::size_t bytes_per_vector() const noexcept { return metric_.bytes_per_vector(); } + std::size_t scalar_words() const noexcept { return metric_.scalar_words(); } + std::size_t dimensions() const noexcept { return metric_.dimensions(); } // Fetching and changing search criteria - std::size_t expansion_add() const { - return config_.expansion_add; - } - std::size_t expansion_search() const { - return config_.expansion_search; - } - void change_expansion_add(std::size_t n) { - config_.expansion_add = n; - } - void change_expansion_search(std::size_t n) { - config_.expansion_search = n; - } - - member_citerator_t cbegin() const { - return typed_->cbegin(); - } - member_citerator_t cend() const { - return typed_->cend(); - } - member_citerator_t begin() const { - return typed_->begin(); - } - member_citerator_t end() const { - return typed_->end(); - } - member_iterator_t begin() { - return 
typed_->begin(); - } - member_iterator_t end() { - return typed_->end(); - } - - stats_t stats() const { - return typed_->stats(); - } - stats_t stats(std::size_t level) const { - return typed_->stats(level); - } - stats_t stats(stats_t *stats_per_level, std::size_t max_level) const { + std::size_t expansion_add() const { return config_.expansion_add; } + std::size_t expansion_search() const { return config_.expansion_search; } + void change_expansion_add(std::size_t n) { config_.expansion_add = n; } + void change_expansion_search(std::size_t n) { config_.expansion_search = n; } + + member_citerator_t cbegin() const { return typed_->cbegin(); } + member_citerator_t cend() const { return typed_->cend(); } + member_citerator_t begin() const { return typed_->begin(); } + member_citerator_t end() const { return typed_->end(); } + member_iterator_t begin() { return typed_->begin(); } + member_iterator_t end() { return typed_->end(); } + + stats_t stats() const { return typed_->stats(); } + stats_t stats(std::size_t level) const { return typed_->stats(level); } + stats_t stats(stats_t* stats_per_level, std::size_t max_level) const { return typed_->stats(stats_per_level, max_level); } - dynamic_allocator_t const &allocator() const { - return typed_->dynamic_allocator(); - } - vector_key_t const &free_key() const { - return free_key_; - } + dynamic_allocator_t const& allocator() const { return typed_->dynamic_allocator(); } + vector_key_t const& free_key() const { return free_key_; } /** - * @brief A relatively accurate lower bound on the amount of memory consumed by the system. - * In practice it's error will be below 10%. - * - * @see `serialized_length` for the length of the binary serialized representation. + * @brief A relatively accurate lower bound on the amount of memory consumed by the system. + * In practice it's error will be below 10%. + * + * @see `serialized_length` for the length of the binary serialized representation. */ std::size_t memory_usage() const { return // @@ -685,12 +587,8 @@ class index_dense_gt { vectors_tape_allocator_.total_allocated(); } - static constexpr std::size_t any_thread() { - return std::numeric_limits::max(); - } - static constexpr distance_t infinite_distance() { - return std::numeric_limits::max(); - } + static constexpr std::size_t any_thread() { return std::numeric_limits::max(); } + static constexpr distance_t infinite_distance() { return std::numeric_limits::max(); } struct aggregated_distances_t { std::size_t count = 0; @@ -732,9 +630,9 @@ class index_dense_gt { // clang-format on /** - * @brief Computes the distance between two managed entities. - * If either key maps into more than one vector, will aggregate results - * exporting the mean, maximum, and minimum values. + * @brief Computes the distance between two managed entities. + * If either key maps into more than one vector, will aggregate results + * exporting the mean, maximum, and minimum values. 
*/ aggregated_distances_t distance_between(vector_key_t a, vector_key_t b, std::size_t = any_thread()) const { shared_lock_t lock(slot_lookup_mutex_); @@ -748,9 +646,9 @@ class index_dense_gt { return result; key_and_slot_t a_key_and_slot = *a_it; - byte_t const *a_vector = vectors_lookup_[a_key_and_slot.slot]; + byte_t const* a_vector = vectors_lookup_[a_key_and_slot.slot]; key_and_slot_t b_key_and_slot = *b_it; - byte_t const *b_vector = vectors_lookup_[b_key_and_slot.slot]; + byte_t const* b_vector = vectors_lookup_[b_key_and_slot.slot]; distance_t a_b_distance = metric_(a_vector, b_vector); result.mean = result.min = result.max = a_b_distance; @@ -772,10 +670,10 @@ class index_dense_gt { while (a_range.first != a_range.second) { key_and_slot_t a_key_and_slot = *a_range.first; - byte_t const *a_vector = vectors_lookup_[a_key_and_slot.slot]; + byte_t const* a_vector = vectors_lookup_[a_key_and_slot.slot]; while (b_range.first != b_range.second) { key_and_slot_t b_key_and_slot = *b_range.first; - byte_t const *b_vector = vectors_lookup_[b_key_and_slot.slot]; + byte_t const* b_vector = vectors_lookup_[b_key_and_slot.slot]; distance_t a_b_distance = metric_(a_vector, b_vector); result.mean += a_b_distance; @@ -794,7 +692,7 @@ class index_dense_gt { } /** - * @brief Identifies a node in a given `level`, that is the closest to the `key`. + * @brief Identifies a node in a given `level`, that is the closest to the `key`. */ cluster_result_t cluster(vector_key_t key, std::size_t level, std::size_t thread = any_thread()) const { @@ -809,15 +707,13 @@ class index_dense_gt { thread_lock_t lock = thread_lock_(thread); cluster_config.thread = lock.thread_id; cluster_config.expansion = config_.expansion_search; - metric_proxy_t metric {*this}; - auto allow = [=](member_cref_t const &member) noexcept { - return member.key != free_key_; - }; + metric_proxy_t metric{*this}; + auto allow = [=](member_cref_t const& member) noexcept { return member.key != free_key_; }; // Find the closest cluster for any vector under that key. while (key_range.first != key_range.second) { key_and_slot_t key_and_slot = *key_range.first; - byte_t const *vector_data = vectors_lookup_[key_and_slot.slot]; + byte_t const* vector_data = vectors_lookup_[key_and_slot.slot]; cluster_result_t new_result = typed_->cluster(vector_data, level, metric, cluster_config, allow); if (!new_result) return new_result; @@ -830,8 +726,8 @@ class index_dense_gt { } /** - * @brief Reserves memory for the index and the keyed lookup. - * @return `true` if the memory reservation was successful, `false` otherwise. + * @brief Reserves memory for the index and the keyed lookup. + * @return `true` if the memory reservation was successful, `false` otherwise. */ bool reserve(index_limits_t limits) { { @@ -843,10 +739,10 @@ class index_dense_gt { } /** - * @brief Erases all the vectors from the index. - * - * Will change `size()` to zero, but will keep the same `capacity()`. - * Will keep the number of available threads/contexts the same as it was. + * @brief Erases all the vectors from the index. + * + * Will change `size()` to zero, but will keep the same `capacity()`. + * Will keep the number of available threads/contexts the same as it was. */ void clear() { unique_lock_t lookup_lock(slot_lookup_mutex_); @@ -860,11 +756,11 @@ class index_dense_gt { } /** - * @brief Erases all members from index, closing files, and returning RAM to OS. - * - * Will change both `size()` and `capacity()` to zero. - * Will deallocate all threads/contexts. 
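> Note: given the aggregation semantics above (mean, min, and max over all vector pairs when keys map to multiple vectors), a small usage sketch follows; the keys are hypothetical, and `count` stays zero when either key is absent:

```cpp
#include <cstdio>
#include <usearch/index_dense.hpp>
using namespace unum::usearch;

void report_distance(index_dense_t const& index) {
    auto d = index.distance_between(42, 43); // nested aggregated_distances_t
    if (!d.count)
        std::printf("at least one key is missing\n");
    else
        std::printf("mean %f (min %f, max %f) over %zu pairs\n", //
                    double(d.mean), double(d.min), double(d.max), d.count);
}
```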
- * If the index is memory-mapped - releases the mapping and the descriptor. + * @brief Erases all members from index, closing files, and returning RAM to OS. + * + * Will change both `size()` and `capacity()` to zero. + * Will deallocate all threads/contexts. + * If the index is memory-mapped - releases the mapping and the descriptor. */ void reset() { unique_lock_t lookup_lock(slot_lookup_mutex_); @@ -883,12 +779,12 @@ class index_dense_gt { } /** - * @brief Saves serialized binary index representation to a stream. + * @brief Saves serialized binary index representation to a stream. */ template - serialization_result_t save_to_stream(output_callback_at &&output, // + serialization_result_t save_to_stream(output_callback_at&& output, // serialization_config_t config = {}, // - progress_at &&progress = {}) const { + progress_at&& progress = {}) const { serialization_result_t result; std::uint64_t matrix_rows = 0; @@ -917,7 +813,7 @@ class index_dense_gt { // Dump the vectors one after another for (std::uint64_t i = 0; i != matrix_rows; ++i) { - byte_t *vector = vectors_lookup_[i]; + byte_t* vector = vectors_lookup_[i]; if (!output(vector, matrix_cols)) return result.failed("Failed to serialize into stream"); } @@ -927,7 +823,7 @@ class index_dense_gt { { index_dense_head_buffer_t buffer; std::memset(buffer, 0, sizeof(buffer)); - index_dense_head_t head {buffer}; + index_dense_head_t head{buffer}; std::memcpy(buffer, default_magic(), std::strlen(default_magic())); // Describe software version @@ -956,7 +852,7 @@ class index_dense_gt { } /** - * @brief Estimate the binary length (in bytes) of the serialized index. + * @brief Estimate the binary length (in bytes) of the serialized index. */ std::size_t serialized_length(serialization_config_t config = {}) const noexcept { std::size_t dimensions_length = 0; @@ -969,15 +865,15 @@ class index_dense_gt { } /** - * @brief Parses the index from file to RAM. - * @param[in] path The path to the file. - * @param[in] config Configuration parameters for imports. - * @return Outcome descriptor explicitly convertible to boolean. + * @brief Parses the index from file to RAM. + * @param[in] path The path to the file. + * @param[in] config Configuration parameters for imports. + * @return Outcome descriptor explicitly convertible to boolean. 
*/ template - serialization_result_t load_from_stream(input_callback_at &&input, // + serialization_result_t load_from_stream(input_callback_at&& input, // serialization_config_t config = {}, // - progress_at &&progress = {}) { + progress_at&& progress = {}) { // Discard all previous memory allocations of `vectors_tape_allocator_` reset(); @@ -1006,7 +902,7 @@ class index_dense_gt { // Load the vectors one after another vectors_lookup_.resize(matrix_rows); for (std::uint64_t slot = 0; slot != matrix_rows; ++slot) { - byte_t *vector = vectors_tape_allocator_.allocate(matrix_cols); + byte_t* vector = vectors_tape_allocator_.allocate(matrix_cols); if (!input(vector, matrix_cols)) return result.failed("Failed to read vectors"); vectors_lookup_[slot] = vector; @@ -1019,7 +915,7 @@ class index_dense_gt { if (!input(buffer, sizeof(buffer))) return result.failed("Failed to read the index "); - index_dense_head_t head {buffer}; + index_dense_head_t head{buffer}; if (std::memcmp(buffer, default_magic(), std::strlen(default_magic())) != 0) return result.failed("Magic header mismatch - the file isn't an index"); @@ -1033,8 +929,10 @@ class index_dense_gt { if (head.kind_compressed_slot != unum::usearch::scalar_kind()) return result.failed("Slot type doesn't match, consider rebuilding"); - metric_ = metric_t(head.dimensions, head.kind_metric, head.kind_scalar); config_.multi = head.multi; + metric_ = metric_t(head.dimensions, head.kind_metric, head.kind_scalar); + cast_buffer_.resize(available_threads_.size() * metric_.bytes_per_vector()); + casts_ = make_casts_(head.kind_scalar); } // Pull the actual proximity graph @@ -1049,15 +947,15 @@ class index_dense_gt { } /** - * @brief Parses the index from file, without loading it into RAM. - * @param[in] path The path to the file. - * @param[in] config Configuration parameters for imports. - * @return Outcome descriptor explicitly convertible to boolean. + * @brief Parses the index from file, without loading it into RAM. + * @param[in] path The path to the file. + * @param[in] config Configuration parameters for imports. + * @return Outcome descriptor explicitly convertible to boolean. 
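> Note: the stream APIs are symmetric, so a round trip reduces to two callbacks over a C `FILE*`. A hedged sketch matching the lambda shapes used by the file-based wrappers elsewhere in this header (error handling kept minimal):

```cpp
#include <cstdio>
#include <usearch/index_dense.hpp>
using namespace unum::usearch;

serialization_result_t round_trip(index_dense_t& index, char const* path) {
    std::FILE* out = std::fopen(path, "wb");
    if (!out)
        return serialization_result_t{}.failed("Can't open file for writing");
    serialization_result_t saved = index.save_to_stream( //
        [&](void const* buffer, std::size_t length) {
            return std::fwrite(buffer, 1, length, out) == length;
        });
    std::fclose(out);
    if (!saved)
        return saved;

    std::FILE* in = std::fopen(path, "rb");
    if (!in)
        return serialization_result_t{}.failed("Can't open file for reading");
    serialization_result_t loaded = index.load_from_stream( //
        [&](void* buffer, std::size_t length) {
            return std::fread(buffer, 1, length, in) == length;
        });
    std::fclose(in);
    return loaded;
}
```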
*/ template serialization_result_t view(memory_mapped_file_t file, // std::size_t offset = 0, serialization_config_t config = {}, // - progress_at &&progress = {}) { + progress_at&& progress = {}) { // Discard all previous memory allocations of `vectors_tape_allocator_` reset(); @@ -1103,7 +1001,7 @@ class index_dense_gt { std::memcpy(buffer, file.data() + offset, sizeof(buffer)); - index_dense_head_t head {buffer}; + index_dense_head_t head{buffer}; if (std::memcmp(buffer, default_magic(), std::strlen(default_magic())) != 0) return result.failed("Magic header mismatch - the file isn't an index"); @@ -1117,8 +1015,10 @@ class index_dense_gt { if (head.kind_compressed_slot != unum::usearch::scalar_kind()) return result.failed("Slot type doesn't match, consider rebuilding"); - metric_ = metric_t(head.dimensions, head.kind_metric, head.kind_scalar); config_.multi = head.multi; + metric_ = metric_t(head.dimensions, head.kind_metric, head.kind_scalar); + cast_buffer_.resize(available_threads_.size() * metric_.bytes_per_vector()); + casts_ = make_casts_(head.kind_scalar); offset += sizeof(buffer); } @@ -1133,28 +1033,28 @@ class index_dense_gt { vectors_lookup_.resize(matrix_rows); if (!config.exclude_vectors) for (std::uint64_t slot = 0; slot != matrix_rows; ++slot) - vectors_lookup_[slot] = (byte_t *)vectors_buffer.data() + matrix_cols * slot; + vectors_lookup_[slot] = (byte_t*)vectors_buffer.data() + matrix_cols * slot; reindex_keys_(); return result; } /** - * @brief Saves the index to a file. - * @param[in] path The path to the file. - * @param[in] config Configuration parameters for exports. - * @return Outcome descriptor explicitly convertible to boolean. + * @brief Saves the index to a file. + * @param[in] path The path to the file. + * @param[in] config Configuration parameters for exports. + * @return Outcome descriptor explicitly convertible to boolean. */ template serialization_result_t save(output_file_t file, serialization_config_t config = {}, - progress_at &&progress = {}) const { + progress_at&& progress = {}) const { serialization_result_t io_result = file.open_if_not(); if (!io_result) return io_result; serialization_result_t stream_result = save_to_stream( - [&](void const *buffer, std::size_t length) { + [&](void const* buffer, std::size_t length) { io_result = file.write(buffer, length); return !!io_result; }, @@ -1168,21 +1068,21 @@ class index_dense_gt { } /** - * @brief Memory-maps the serialized binary index representation from disk, - * @b without copying data into RAM, and fetching it on-demand. + * @brief Memory-maps the serialized binary index representation from disk, + * @b without copying data into RAM, and fetching it on-demand. */ template serialization_result_t save(memory_mapped_file_t file, // std::size_t offset = 0, // serialization_config_t config = {}, // - progress_at &&progress = {}) const { + progress_at&& progress = {}) const { serialization_result_t io_result = file.open_if_not(); if (!io_result) return io_result; serialization_result_t stream_result = save_to_stream( - [&](void const *buffer, std::size_t length) { + [&](void const* buffer, std::size_t length) { if (offset + length > file.size()) return false; std::memcpy(file.data() + offset, buffer, length); @@ -1195,20 +1095,20 @@ class index_dense_gt { } /** - * @brief Parses the index from file to RAM. - * @param[in] path The path to the file. - * @param[in] config Configuration parameters for imports. - * @return Outcome descriptor explicitly convertible to boolean. 
+ * @brief Parses the index from file to RAM. + * @param[in] path The path to the file. + * @param[in] config Configuration parameters for imports. + * @return Outcome descriptor explicitly convertible to boolean. */ template - serialization_result_t load(input_file_t file, serialization_config_t config = {}, progress_at &&progress = {}) { + serialization_result_t load(input_file_t file, serialization_config_t config = {}, progress_at&& progress = {}) { serialization_result_t io_result = file.open_if_not(); if (!io_result) return io_result; serialization_result_t stream_result = load_from_stream( - [&](void *buffer, std::size_t length) { + [&](void* buffer, std::size_t length) { io_result = file.read(buffer, length); return !!io_result; }, @@ -1222,21 +1122,21 @@ class index_dense_gt { } /** - * @brief Memory-maps the serialized binary index representation from disk, - * @b without copying data into RAM, and fetching it on-demand. + * @brief Memory-maps the serialized binary index representation from disk, + * @b without copying data into RAM, and fetching it on-demand. */ template serialization_result_t load(memory_mapped_file_t file, // std::size_t offset = 0, // serialization_config_t config = {}, // - progress_at &&progress = {}) { + progress_at&& progress = {}) { serialization_result_t io_result = file.open_if_not(); if (!io_result) return io_result; serialization_result_t stream_result = load_from_stream( - [&](void *buffer, std::size_t length) { + [&](void* buffer, std::size_t length) { if (offset + length > file.size()) return false; std::memcpy(buffer, file.data() + offset, length); @@ -1249,22 +1149,22 @@ class index_dense_gt { } template - serialization_result_t save(char const *file_path, // + serialization_result_t save(char const* file_path, // serialization_config_t config = {}, // - progress_at &&progress = {}) const { + progress_at&& progress = {}) const { return save(output_file_t(file_path), config, std::forward(progress)); } template - serialization_result_t load(char const *file_path, // + serialization_result_t load(char const* file_path, // serialization_config_t config = {}, // - progress_at &&progress = {}) { + progress_at&& progress = {}) { return load(input_file_t(file_path), config, std::forward(progress)); } /** - * @brief Checks if a vector with specified key is present. - * @return `true` if the key is present in the index, `false` otherwise. + * @brief Checks if a vector with specified key is present. + * @return `true` if the key is present in the index, `false` otherwise. */ bool contains(vector_key_t key) const { shared_lock_t lock(slot_lookup_mutex_); @@ -1272,8 +1172,8 @@ class index_dense_gt { } /** - * @brief Count the number of vectors with specified key present. - * @return Zero if nothing is found, a positive integer otherwise. + * @brief Count the number of vectors with specified key present. + * @return Zero if nothing is found, a positive integer otherwise. 
*/ std::size_t count(vector_key_t key) const { shared_lock_t lock(slot_lookup_mutex_); @@ -1281,12 +1181,10 @@ class index_dense_gt { } struct labeling_result_t { - error_t error {}; - std::size_t completed {}; + error_t error{}; + std::size_t completed{}; - explicit operator bool() const noexcept { - return !error; - } + explicit operator bool() const noexcept { return !error; } labeling_result_t failed(error_t message) noexcept { error = std::move(message); return std::move(*this); @@ -1294,12 +1192,12 @@ class index_dense_gt { }; /** - * @brief Removes an entry with the specified key from the index. - * @param[in] key The key of the entry to remove. - * @return The ::labeling_result_t indicating the result of the removal operation. - * If the removal was successful, `result.completed` will be `true`. - * If the key was not found in the index, `result.completed` will be `false`. - * If an error occurred during the removal operation, `result.error` will contain an error message. + * @brief Removes an entry with the specified key from the index. + * @param[in] key The key of the entry to remove. + * @return The ::labeling_result_t indicating the result of the removal operation. + * If the removal was successful, `result.completed` will be `true`. + * If the key was not found in the index, `result.completed` will be `false`. + * If an error occurred during the removal operation, `result.error` will contain an error message. */ labeling_result_t remove(vector_key_t key) { labeling_result_t result; @@ -1331,12 +1229,12 @@ class index_dense_gt { } /** - * @brief Removes multiple entries with the specified keys from the index. - * @param[in] keys_begin The beginning of the keys range. - * @param[in] keys_end The ending of the keys range. - * @return The ::labeling_result_t indicating the result of the removal operation. - * `result.completed` will contain the number of keys that were successfully removed. - * `result.error` will contain an error message if an error occurred during the removal operation. + * @brief Removes multiple entries with the specified keys from the index. + * @param[in] keys_begin The beginning of the keys range. + * @param[in] keys_end The ending of the keys range. + * @return The ::labeling_result_t indicating the result of the removal operation. + * `result.completed` will contain the number of keys that were successfully removed. + * `result.error` will contain an error message if an error occurred during the removal operation. */ template labeling_result_t remove(keys_iterator_at keys_begin, keys_iterator_at keys_end) { @@ -1376,12 +1274,12 @@ class index_dense_gt { } /** - * @brief Renames an entry with the specified key to a new key. - * @param[in] from The current key of the entry to rename. - * @param[in] to The new key to assign to the entry. - * @return The ::labeling_result_t indicating the result of the rename operation. - * If the rename was successful, `result.completed` will be `true`. - * If the entry with the current key was not found, `result.completed` will be `false`. + * @brief Renames an entry with the specified key to a new key. + * @param[in] from The current key of the entry to rename. + * @param[in] to The new key to assign to the entry. + * @return The ::labeling_result_t indicating the result of the rename operation. + * If the rename was successful, `result.completed` will be `true`. + * If the entry with the current key was not found, `result.completed` will be `false`. 
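> Note: a usage sketch for single-key removal under the semantics documented above, assuming the public `vector_key_t` alias; the key is hypothetical, and `completed` doubles as a count for the range overload:

```cpp
#include <usearch/index_dense.hpp>
using namespace unum::usearch;

bool try_remove(index_dense_t& index, index_dense_t::vector_key_t key) {
    auto result = index.remove(key); // nested labeling_result_t
    if (!result)
        return false; // result.error carries the message
    return result.completed != 0; // zero when the key was not present
}
```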
*/ labeling_result_t rename(vector_key_t from, vector_key_t to) { labeling_result_t result; @@ -1396,7 +1294,7 @@ class index_dense_gt { if (!slot_lookup_.pop_first(key_and_slot_t::any_slot(from), key_and_slot_removed)) break; - key_and_slot_t key_and_slot_replacing {to, key_and_slot_removed.slot}; + key_and_slot_t key_and_slot_replacing{to, key_and_slot_removed.slot}; slot_lookup_.try_emplace(key_and_slot_replacing); // This can't fail typed_->at(key_and_slot_removed.slot).key = to; ++result.completed; @@ -1406,15 +1304,15 @@ class index_dense_gt { } /** - * @brief Exports a range of keys for the vectors present in the index. - * @param[out] keys Pointer to the array where the keys will be exported. - * @param[in] offset The number of keys to skip. Useful for pagination. - * @param[in] limit The maximum number of keys to export, that can fit in ::keys. + * @brief Exports a range of keys for the vectors present in the index. + * @param[out] keys Pointer to the array where the keys will be exported. + * @param[in] offset The number of keys to skip. Useful for pagination. + * @param[in] limit The maximum number of keys to export, that can fit in ::keys. */ - void export_keys(vector_key_t *keys, std::size_t offset, std::size_t limit) const { + void export_keys(vector_key_t* keys, std::size_t offset, std::size_t limit) const { shared_lock_t lock(slot_lookup_mutex_); offset = (std::min)(offset, slot_lookup_.size()); - slot_lookup_.for_each([&](key_and_slot_t const &key_and_slot) { + slot_lookup_.for_each([&](key_and_slot_t const& key_and_slot) { if (offset) // Skip the first `offset` entries --offset; @@ -1430,9 +1328,7 @@ class index_dense_gt { index_dense_gt index; error_t error; - explicit operator bool() const noexcept { - return !error; - } + explicit operator bool() const noexcept { return !error; } copy_result_t failed(error_t message) noexcept { error = std::move(message); return std::move(*this); @@ -1440,9 +1336,9 @@ class index_dense_gt { }; /** - * @brief Copies the ::index_dense_gt @b with all the data in it. - * @param config The copy configuration (optional). - * @return A copy of the ::index_dense_gt instance. + * @brief Copies the ::index_dense_gt @b with all the data in it. + * @param config The copy configuration (optional). + * @return A copy of the ::index_dense_gt instance. */ copy_result_t copy(index_dense_copy_config_t config = {}) const { copy_result_t result = fork(); @@ -1454,7 +1350,7 @@ class index_dense_gt { return result.failed(std::move(typed_result.error)); // Export the free (removed) slot numbers - index_dense_gt © = result.index; + index_dense_gt& copy = result.index; if (!copy.free_keys_.reserve(free_keys_.size())) return result.failed(std::move(typed_result.error)); for (std::size_t i = 0; i != free_keys_.size(); ++i) @@ -1479,12 +1375,12 @@ class index_dense_gt { } /** - * @brief Copies the ::index_dense_gt model @b without any data. - * @return A similarly configured ::index_dense_gt instance. + * @brief Copies the ::index_dense_gt model @b without any data. + * @return A similarly configured ::index_dense_gt instance. 
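> Note: a paging sketch over `export_keys`, again assuming the public `vector_key_t` alias; the last page may be only partially filled:

```cpp
#include <algorithm>
#include <vector>
#include <usearch/index_dense.hpp>
using namespace unum::usearch;

void for_each_key_page(index_dense_t const& index) {
    std::vector<index_dense_t::vector_key_t> page(128);
    for (std::size_t offset = 0; offset < index.size(); offset += page.size()) {
        index.export_keys(page.data(), offset, page.size());
        std::size_t filled = (std::min)(page.size(), index.size() - offset);
        (void)filled; // consume `filled` keys from `page` here
    }
}
```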
*/ copy_result_t fork() const { copy_result_t result; - index_dense_gt &other = result.index; + index_dense_gt& other = result.index; other.config_ = config_; other.cast_buffer_ = cast_buffer_; @@ -1494,7 +1390,7 @@ class index_dense_gt { other.available_threads_ = available_threads_; other.free_key_ = free_key_; - index_t *raw = index_allocator_t {}.allocate(1); + index_t* raw = index_allocator_t{}.allocate(1); if (!raw) return result.failed("Can't allocate the index"); @@ -1504,12 +1400,10 @@ class index_dense_gt { } struct compaction_result_t { - error_t error {}; - std::size_t pruned_edges {}; + error_t error{}; + std::size_t pruned_edges{}; - explicit operator bool() const noexcept { - return !error; - } + explicit operator bool() const noexcept { return !error; } compaction_result_t failed(error_t message) noexcept { error = std::move(message); return std::move(*this); @@ -1517,18 +1411,18 @@ class index_dense_gt { }; /** - * @brief Performs compaction on the index, pruning links to removed entries. - * @param executor The executor parallel processing. Default ::dummy_executor_t single-threaded. - * @param progress The progress tracker instance to use. Default ::dummy_progress_t reports nothing. - * @return The ::compaction_result_t indicating the result of the compaction operation. - * `result.pruned_edges` will contain the number of edges that were removed. - * `result.error` will contain an error message if an error occurred during the compaction operation. + * @brief Performs compaction on the index, pruning links to removed entries. + * @param executor The executor parallel processing. Default ::dummy_executor_t single-threaded. + * @param progress The progress tracker instance to use. Default ::dummy_progress_t reports nothing. + * @return The ::compaction_result_t indicating the result of the compaction operation. + * `result.pruned_edges` will contain the number of edges that were removed. + * `result.error` will contain an error message if an error occurred during the compaction operation. */ template - compaction_result_t isolate(executor_at &&executor = executor_at {}, progress_at &&progress = progress_at {}) { + compaction_result_t isolate(executor_at&& executor = executor_at{}, progress_at&& progress = progress_at{}) { compaction_result_t result; std::atomic pruned_edges; - auto disallow = [&](member_cref_t const &member) noexcept { + auto disallow = [&](member_cref_t const& member) noexcept { bool freed = member.key == free_key_; pruned_edges += freed; return freed; @@ -1539,41 +1433,36 @@ class index_dense_gt { } class values_proxy_t { - index_dense_gt const *index_; + index_dense_gt const* index_; public: - values_proxy_t(index_dense_gt const &index) noexcept : index_(&index) { - } - byte_t const *operator[](compressed_slot_t slot) const noexcept { - return index_->vectors_lookup_[slot]; - } - byte_t const *operator[](member_citerator_t it) const noexcept { - return index_->vectors_lookup_[get_slot(it)]; - } + values_proxy_t(index_dense_gt const& index) noexcept : index_(&index) {} + byte_t const* operator[](compressed_slot_t slot) const noexcept { return index_->vectors_lookup_[slot]; } + byte_t const* operator[](member_citerator_t it) const noexcept { return index_->vectors_lookup_[get_slot(it)]; } }; /** - * @brief Performs compaction on the index, pruning links to removed entries. - * @param executor The executor parallel processing. Default ::dummy_executor_t single-threaded. - * @param progress The progress tracker instance to use. 
Default ::dummy_progress_t reports nothing. - * @return The ::compaction_result_t indicating the result of the compaction operation. - * `result.pruned_edges` will contain the number of edges that were removed. - * `result.error` will contain an error message if an error occurred during the compaction operation. + * @brief Performs compaction on the index, pruning links to removed entries. + * @param executor The executor parallel processing. Default ::dummy_executor_t single-threaded. + * @param progress The progress tracker instance to use. Default ::dummy_progress_t reports nothing. + * @return The ::compaction_result_t indicating the result of the compaction operation. + * `result.pruned_edges` will contain the number of edges that were removed. + * `result.error` will contain an error message if an error occurred during the compaction operation. */ template - compaction_result_t compact(executor_at &&executor = executor_at {}, progress_at &&progress = progress_at {}) { + compaction_result_t compact(executor_at&& executor = executor_at{}, progress_at&& progress = progress_at{}) { compaction_result_t result; - std::vector new_vectors_lookup(vectors_lookup_.size()); + std::vector new_vectors_lookup(vectors_lookup_.size()); vectors_tape_allocator_t new_vectors_allocator; auto track_slot_change = [&](vector_key_t, compressed_slot_t old_slot, compressed_slot_t new_slot) { - byte_t *new_vector = new_vectors_allocator.allocate(metric_.bytes_per_vector()); - byte_t *old_vector = vectors_lookup_[old_slot]; + byte_t* new_vector = new_vectors_allocator.allocate(metric_.bytes_per_vector()); + byte_t* old_vector = vectors_lookup_[old_slot]; std::memcpy(new_vector, old_vector, metric_.bytes_per_vector()); new_vectors_lookup[new_slot] = new_vector; }; - typed_->compact(values_proxy_t {*this}, metric_proxy_t {*this}, track_slot_change, + typed_->compact(values_proxy_t{*this}, metric_proxy_t{*this}, track_slot_change, std::forward(executor), std::forward(progress)); vectors_lookup_ = std::move(new_vectors_lookup); vectors_tape_allocator_ = std::move(new_vectors_allocator); @@ -1586,35 +1475,33 @@ class index_dense_gt { typename executor_at = dummy_executor_t, // typename progress_at = dummy_progress_t // > - join_result_t join( // - index_dense_gt const &women, // - index_join_config_t config = {}, // - man_to_woman_at &&man_to_woman = man_to_woman_at {}, // - woman_to_man_at &&woman_to_man = woman_to_man_at {}, // - executor_at &&executor = executor_at {}, // - progress_at &&progress = progress_at {}) const { - - index_dense_gt const &men = *this; - return unum::usearch::join( // - *men.typed_, *women.typed_, // - values_proxy_t {men}, values_proxy_t {women}, // - metric_proxy_t {men}, metric_proxy_t {women}, // - config, // - std::forward(man_to_woman), // - std::forward(woman_to_man), // - std::forward(executor), // + join_result_t join( // + index_dense_gt const& women, // + index_join_config_t config = {}, // + man_to_woman_at&& man_to_woman = man_to_woman_at{}, // + woman_to_man_at&& woman_to_man = woman_to_man_at{}, // + executor_at&& executor = executor_at{}, // + progress_at&& progress = progress_at{}) const { + + index_dense_gt const& men = *this; + return unum::usearch::join( // + *men.typed_, *women.typed_, // + values_proxy_t{men}, values_proxy_t{women}, // + metric_proxy_t{men}, metric_proxy_t{women}, // + config, // + std::forward(man_to_woman), // + std::forward(woman_to_man), // + std::forward(executor), // std::forward(progress)); } struct clustering_result_t { - error_t error {}; - 
std::size_t clusters {}; - std::size_t visited_members {}; - std::size_t computed_distances {}; + error_t error{}; + std::size_t clusters{}; + std::size_t visited_members{}; + std::size_t computed_distances{}; - explicit operator bool() const noexcept { - return !error; - } + explicit operator bool() const noexcept { return !error; } clustering_result_t failed(error_t message) noexcept { error = std::move(message); return std::move(*this); @@ -1622,31 +1509,31 @@ class index_dense_gt { }; /** - * @brief Implements clustering, classifying the given objects (vectors of member keys) - * into a given number of clusters. - * - * @param[in] queries_begin Iterator pointing to the first query. - * @param[in] queries_end Iterator pointing to the last query. - * @param[in] executor Thread-pool to execute the job in parallel. - * @param[in] progress Callback to report the execution progress. - * @param[in] config Configuration parameters for clustering. - * - * @param[out] cluster_keys Pointer to the array where the cluster keys will be exported. - * @param[out] cluster_distances Pointer to the array where the distances to those centroids will be exported. + * @brief Implements clustering, classifying the given objects (vectors of member keys) + * into a given number of clusters. + * + * @param[in] queries_begin Iterator pointing to the first query. + * @param[in] queries_end Iterator pointing to the last query. + * @param[in] executor Thread-pool to execute the job in parallel. + * @param[in] progress Callback to report the execution progress. + * @param[in] config Configuration parameters for clustering. + * + * @param[out] cluster_keys Pointer to the array where the cluster keys will be exported. + * @param[out] cluster_distances Pointer to the array where the distances to those centroids will be exported. 
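A caller-side sketch of the clustering API documented above. Per the doc comment, the queries are taken to be member keys; this sketch further assumes the default `std::uint64_t` key type and `float` distances, and only sets `max_clusters`, the one config field the implementation below demonstrably reads.

```cpp
#include <usearch/index_dense.hpp>
#include <cstdint>
#include <vector>

void cluster_members(unum::usearch::index_dense_t& index,
                     std::vector<std::uint64_t> const& query_keys) {
    unum::usearch::index_dense_clustering_config_t config;
    config.max_clusters = 8; // nearby clusters are merged until at most 8 remain

    std::vector<std::uint64_t> keys(query_keys.size());
    std::vector<float> distances(query_keys.size());
    auto result = index.cluster(query_keys.begin(), query_keys.end(), //
                                config, keys.data(), distances.data());
    // On success, keys[i] is the centroid key assigned to query_keys[i],
    // and distances[i] is the distance to that centroid.
    (void)result;
}
```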
*/ template < // typename queries_iterator_at, // typename executor_at = dummy_executor_t, // typename progress_at = dummy_progress_t // > - clustering_result_t cluster( // - queries_iterator_at queries_begin, // - queries_iterator_at queries_end, // - index_dense_clustering_config_t config, // - vector_key_t *cluster_keys, // - distance_t *cluster_distances, // - executor_at &&executor = executor_at {}, // - progress_at &&progress = progress_at {}) { + clustering_result_t cluster( // + queries_iterator_at queries_begin, // + queries_iterator_at queries_end, // + index_dense_clustering_config_t config, // + vector_key_t* cluster_keys, // + distance_t* cluster_distances, // + executor_at&& executor = executor_at{}, // + progress_at&& progress = progress_at{}) { std::size_t const queries_count = queries_end - queries_begin; @@ -1669,19 +1556,15 @@ class index_dense_gt { vector_key_t centroid; vector_key_t merged_into; std::size_t popularity; - byte_t *vector; + byte_t* vector; }; - auto centroid_id = [](cluster_t const &a, cluster_t const &b) { - return a.centroid < b.centroid; - }; - auto higher_popularity = [](cluster_t const &a, cluster_t const &b) { - return a.popularity > b.popularity; - }; + auto centroid_id = [](cluster_t const& a, cluster_t const& b) { return a.centroid < b.centroid; }; + auto higher_popularity = [](cluster_t const& a, cluster_t const& b) { return a.popularity > b.popularity; }; std::atomic visited_members(0); std::atomic computed_distances(0); - std::atomic atomic_error {nullptr}; + std::atomic atomic_error{nullptr}; using dynamic_allocator_traits_t = std::allocator_traits; using clusters_allocator_t = typename dynamic_allocator_traits_t::template rebind_alloc; @@ -1748,7 +1631,7 @@ class index_dense_gt { merge_nearby_clusters: if (unique_clusters > config.max_clusters) { - cluster_t &merge_source = clusters[unique_clusters - 1]; + cluster_t& merge_source = clusters[unique_clusters - 1]; std::size_t merge_target_idx = 0; distance_t merge_distance = std::numeric_limits::max(); @@ -1780,8 +1663,8 @@ class index_dense_gt { std::sort(clusters.data(), clusters_end, centroid_id); executor.dynamic(queries_count, [&](std::size_t thread_idx, std::size_t query_idx) { - vector_key_t &cluster_key = cluster_keys[query_idx]; - distance_t &cluster_distance = cluster_distances[query_idx]; + vector_key_t& cluster_key = cluster_keys[query_idx]; + distance_t& cluster_distance = cluster_distances[query_idx]; // Recursively trace replacements of that cluster while (true) { @@ -1808,7 +1691,7 @@ class index_dense_gt { private: struct thread_lock_t { - index_dense_gt const &parent; + index_dense_gt const& parent; std::size_t thread_id; bool engaged; @@ -1837,18 +1720,18 @@ class index_dense_gt { template add_result_t add_( // - vector_key_t key, scalar_at const *vector, // - std::size_t thread, bool force_vector_copy, cast_t const &cast) { + vector_key_t key, scalar_at const* vector, // + std::size_t thread, bool force_vector_copy, cast_t const& cast) { if (!multi() && contains(key)) - return add_result_t {}.failed("Duplicate keys not allowed in high-level wrappers"); + return add_result_t{}.failed("Duplicate keys not allowed in high-level wrappers"); // Cast the vector, if needed for compatibility with `metric_` thread_lock_t lock = thread_lock_(thread); bool copy_vector = !config_.exclude_vectors || force_vector_copy; - byte_t const *vector_data = reinterpret_cast(vector); + byte_t const* vector_data = reinterpret_cast(vector); { - byte_t *casted_data = cast_buffer_.data() + 
metric_.bytes_per_vector() * lock.thread_id; + byte_t* casted_data = cast_buffer_.data() + metric_.bytes_per_vector() * lock.thread_id; bool casted = cast(vector_data, dimensions(), casted_data); if (casted) vector_data = casted_data, copy_vector = true; @@ -1865,20 +1748,20 @@ class index_dense_gt { bool reuse_node = free_slot != default_free_value(); auto on_success = [&](member_ref_t member) { unique_lock_t slot_lock(slot_lookup_mutex_); - slot_lookup_.try_emplace(key_and_slot_t {key, static_cast(member.slot)}); + slot_lookup_.try_emplace(key_and_slot_t{key, static_cast(member.slot)}); if (copy_vector) { if (!reuse_node) vectors_lookup_[member.slot] = vectors_tape_allocator_.allocate(metric_.bytes_per_vector()); std::memcpy(vectors_lookup_[member.slot], vector_data, metric_.bytes_per_vector()); } else - vectors_lookup_[member.slot] = (byte_t *)vector_data; + vectors_lookup_[member.slot] = (byte_t*)vector_data; }; index_update_config_t update_config; update_config.thread = lock.thread_id; update_config.expansion = config_.expansion_add; - metric_proxy_t metric {*this}; + metric_proxy_t metric{*this}; return reuse_node // ? typed_->update(typed_->iterator_at(free_slot), key, vector_data, metric, update_config, on_success) : typed_->add(key, vector_data, metric, update_config, on_success); @@ -1886,14 +1769,14 @@ class index_dense_gt { template search_result_t search_( // - scalar_at const *vector, std::size_t wanted, // - std::size_t thread, bool exact, cast_t const &cast) const { + scalar_at const* vector, std::size_t wanted, // + std::size_t thread, bool exact, cast_t const& cast) const { // Cast the vector, if needed for compatibility with `metric_` thread_lock_t lock = thread_lock_(thread); - byte_t const *vector_data = reinterpret_cast(vector); + byte_t const* vector_data = reinterpret_cast(vector); { - byte_t *casted_data = cast_buffer_.data() + metric_.bytes_per_vector() * lock.thread_id; + byte_t* casted_data = cast_buffer_.data() + metric_.bytes_per_vector() * lock.thread_id; bool casted = cast(vector_data, dimensions(), casted_data); if (casted) vector_data = casted_data; @@ -1904,22 +1787,20 @@ class index_dense_gt { search_config.expansion = config_.expansion_search; search_config.exact = exact; - auto allow = [=](member_cref_t const &member) noexcept { - return member.key != free_key_; - }; - return typed_->search(vector_data, wanted, metric_proxy_t {*this}, search_config, allow); + auto allow = [=](member_cref_t const& member) noexcept { return member.key != free_key_; }; + return typed_->search(vector_data, wanted, metric_proxy_t{*this}, search_config, allow); } template cluster_result_t cluster_( // - scalar_at const *vector, std::size_t level, // - std::size_t thread, cast_t const &cast) const { + scalar_at const* vector, std::size_t level, // + std::size_t thread, cast_t const& cast) const { // Cast the vector, if needed for compatibility with `metric_` thread_lock_t lock = thread_lock_(thread); - byte_t const *vector_data = reinterpret_cast(vector); + byte_t const* vector_data = reinterpret_cast(vector); { - byte_t *casted_data = cast_buffer_.data() + metric_.bytes_per_vector() * lock.thread_id; + byte_t* casted_data = cast_buffer_.data() + metric_.bytes_per_vector() * lock.thread_id; bool casted = cast(vector_data, dimensions(), casted_data); if (casted) vector_data = casted_data; @@ -1929,22 +1810,20 @@ class index_dense_gt { cluster_config.thread = lock.thread_id; cluster_config.expansion = config_.expansion_search; - auto allow = [=](member_cref_t const &member) 
noexcept { - return member.key != free_key_; - }; - return typed_->cluster(vector_data, level, metric_proxy_t {*this}, cluster_config, allow); + auto allow = [=](member_cref_t const& member) noexcept { return member.key != free_key_; }; + return typed_->cluster(vector_data, level, metric_proxy_t{*this}, cluster_config, allow); } template aggregated_distances_t distance_between_( // - vector_key_t key, scalar_at const *vector, // - std::size_t thread, cast_t const &cast) const { + vector_key_t key, scalar_at const* vector, // + std::size_t thread, cast_t const& cast) const { // Cast the vector, if needed for compatibility with `metric_` thread_lock_t lock = thread_lock_(thread); - byte_t const *vector_data = reinterpret_cast(vector); + byte_t const* vector_data = reinterpret_cast(vector); { - byte_t *casted_data = cast_buffer_.data() + metric_.bytes_per_vector() * lock.thread_id; + byte_t* casted_data = cast_buffer_.data() + metric_.bytes_per_vector() * lock.thread_id; bool casted = cast(vector_data, dimensions(), casted_data); if (casted) vector_data = casted_data; @@ -1964,8 +1843,8 @@ class index_dense_gt { while (key_range.first != key_range.second) { key_and_slot_t key_and_slot = *key_range.first; - byte_t const *a_vector = vectors_lookup_[key_and_slot.slot]; - byte_t const *b_vector = vector_data; + byte_t const* a_vector = vectors_lookup_[key_and_slot.slot]; + byte_t const* b_vector = vector_data; distance_t a_b_distance = metric_(a_vector, b_vector); result.mean += a_b_distance; @@ -2007,12 +1886,12 @@ class index_dense_gt { if (member.key == free_key_) free_keys_.push(static_cast(i)); else if (config_.enable_key_lookups) - slot_lookup_.try_emplace(key_and_slot_t {vector_key_t(member.key), static_cast(i)}); + slot_lookup_.try_emplace(key_and_slot_t{vector_key_t(member.key), static_cast(i)}); } } template - std::size_t get_(vector_key_t key, scalar_at *reconstructed, std::size_t vectors_limit, cast_t const &cast) const { + std::size_t get_(vector_key_t key, scalar_at* reconstructed, std::size_t vectors_limit, cast_t const& cast) const { if (!multi()) { compressed_slot_t slot; @@ -2025,8 +1904,8 @@ class index_dense_gt { slot = (*it).slot; } // Export the entry - byte_t const *punned_vector = reinterpret_cast(vectors_lookup_[slot]); - bool casted = cast(punned_vector, dimensions(), (byte_t *)reconstructed); + byte_t const* punned_vector = reinterpret_cast(vectors_lookup_[slot]); + bool casted = cast(punned_vector, dimensions(), (byte_t*)reconstructed); if (!casted) std::memcpy(reconstructed, punned_vector, metric_.bytes_per_vector()); return true; @@ -2038,8 +1917,8 @@ class index_dense_gt { begin != equal_range_pair.second && count_exported != vectors_limit; ++begin, ++count_exported) { // compressed_slot_t slot = (*begin).slot; - byte_t const *punned_vector = reinterpret_cast(vectors_lookup_[slot]); - byte_t *reconstructed_vector = (byte_t *)reconstructed + metric_.bytes_per_vector() * count_exported; + byte_t const* punned_vector = reinterpret_cast(vectors_lookup_[slot]); + byte_t* reconstructed_vector = (byte_t*)reconstructed + metric_.bytes_per_vector() * count_exported; bool casted = cast(punned_vector, dimensions(), reconstructed_vector); if (!casted) std::memcpy(reconstructed_vector, punned_vector, metric_.bytes_per_vector()); @@ -2048,39 +1927,32 @@ class index_dense_gt { } } - template - static casts_t make_casts_() { + template static casts_t make_casts_() { casts_t result; - result.from_b1x8 = cast_gt {}; - result.from_i8 = cast_gt {}; - result.from_f16 = cast_gt {}; - 
result.from_f32 = cast_gt {}; - result.from_f64 = cast_gt {}; + result.from_b1x8 = cast_gt{}; + result.from_i8 = cast_gt{}; + result.from_f16 = cast_gt{}; + result.from_f32 = cast_gt{}; + result.from_f64 = cast_gt{}; - result.to_b1x8 = cast_gt {}; - result.to_i8 = cast_gt {}; - result.to_f16 = cast_gt {}; - result.to_f32 = cast_gt {}; - result.to_f64 = cast_gt {}; + result.to_b1x8 = cast_gt{}; + result.to_i8 = cast_gt{}; + result.to_f16 = cast_gt{}; + result.to_f32 = cast_gt{}; + result.to_f64 = cast_gt{}; return result; } static casts_t make_casts_(scalar_kind_t scalar_kind) { switch (scalar_kind) { - case scalar_kind_t::f64_k: - return make_casts_(); - case scalar_kind_t::f32_k: - return make_casts_(); - case scalar_kind_t::f16_k: - return make_casts_(); - case scalar_kind_t::i8_k: - return make_casts_(); - case scalar_kind_t::b1x8_k: - return make_casts_(); - default: - return {}; + case scalar_kind_t::f64_k: return make_casts_(); + case scalar_kind_t::f32_k: return make_casts_(); + case scalar_kind_t::f16_k: return make_casts_(); + case scalar_kind_t::i8_k: return make_casts_(); + case scalar_kind_t::b1x8_k: return make_casts_(); + default: return {}; } } }; @@ -2111,14 +1983,14 @@ template < // typename progress_at = dummy_progress_t // > static join_result_t join( // - index_dense_gt const &men, // - index_dense_gt const &women, // - - index_join_config_t config = {}, // - man_to_woman_at &&man_to_woman = man_to_woman_at {}, // - woman_to_man_at &&woman_to_man = woman_to_man_at {}, // - executor_at &&executor = executor_at {}, // - progress_at &&progress = progress_at {}) noexcept { + index_dense_gt const& men, // + index_dense_gt const& women, // + + index_join_config_t config = {}, // + man_to_woman_at&& man_to_woman = man_to_woman_at{}, // + woman_to_man_at&& woman_to_man = woman_to_man_at{}, // + executor_at&& executor = executor_at{}, // + progress_at&& progress = progress_at{}) noexcept { return men.join( // women, config, // diff --git a/src/include/usearch/index_plugins.hpp b/src/include/usearch/index_plugins.hpp index 2c95d7f..57e79a8 100644 --- a/src/include/usearch/index_plugins.hpp +++ b/src/include/usearch/index_plugins.hpp @@ -153,8 +153,7 @@ enum class prefetching_kind_t { io_uring_k, }; -template -scalar_kind_t scalar_kind() noexcept { +template scalar_kind_t scalar_kind() noexcept { if (std::is_same()) return scalar_kind_t::b1x8_k; if (std::is_same()) @@ -188,109 +187,72 @@ scalar_kind_t scalar_kind() noexcept { return scalar_kind_t::unknown_k; } -template -at angle_to_radians(at angle) noexcept { - return angle * at(3.14159265358979323846) / at(180); -} +template at angle_to_radians(at angle) noexcept { return angle * at(3.14159265358979323846) / at(180); } -template -at square(at value) noexcept { - return value * value; -} +template at square(at value) noexcept { return value * value; } -template -inline at clamp(at v, at lo, at hi, compare_at comp) noexcept { +template inline at clamp(at v, at lo, at hi, compare_at comp) noexcept { return comp(v, lo) ? lo : comp(hi, v) ? 
hi : v; } -template -inline at clamp(at v, at lo, at hi) noexcept { - return usearch::clamp(v, lo, hi, std::less {}); +template inline at clamp(at v, at lo, at hi) noexcept { + return usearch::clamp(v, lo, hi, std::less{}); } -inline bool str_equals(char const *begin, std::size_t len, char const *other_begin) noexcept { +inline bool str_equals(char const* begin, std::size_t len, char const* other_begin) noexcept { std::size_t other_len = std::strlen(other_begin); return len == other_len && std::strncmp(begin, other_begin, len) == 0; } inline std::size_t bits_per_scalar(scalar_kind_t scalar_kind) noexcept { switch (scalar_kind) { - case scalar_kind_t::f64_k: - return 64; - case scalar_kind_t::f32_k: - return 32; - case scalar_kind_t::f16_k: - return 16; - case scalar_kind_t::i8_k: - return 8; - case scalar_kind_t::b1x8_k: - return 1; - default: - return 0; + case scalar_kind_t::f64_k: return 64; + case scalar_kind_t::f32_k: return 32; + case scalar_kind_t::f16_k: return 16; + case scalar_kind_t::i8_k: return 8; + case scalar_kind_t::b1x8_k: return 1; + default: return 0; } } inline std::size_t bits_per_scalar_word(scalar_kind_t scalar_kind) noexcept { switch (scalar_kind) { - case scalar_kind_t::f64_k: - return 64; - case scalar_kind_t::f32_k: - return 32; - case scalar_kind_t::f16_k: - return 16; - case scalar_kind_t::i8_k: - return 8; - case scalar_kind_t::b1x8_k: - return 8; - default: - return 0; + case scalar_kind_t::f64_k: return 64; + case scalar_kind_t::f32_k: return 32; + case scalar_kind_t::f16_k: return 16; + case scalar_kind_t::i8_k: return 8; + case scalar_kind_t::b1x8_k: return 8; + default: return 0; } } -inline char const *scalar_kind_name(scalar_kind_t scalar_kind) noexcept { +inline char const* scalar_kind_name(scalar_kind_t scalar_kind) noexcept { switch (scalar_kind) { - case scalar_kind_t::f32_k: - return "f32"; - case scalar_kind_t::f16_k: - return "f16"; - case scalar_kind_t::f64_k: - return "f64"; - case scalar_kind_t::i8_k: - return "i8"; - case scalar_kind_t::b1x8_k: - return "b1x8"; - default: - return ""; + case scalar_kind_t::f32_k: return "f32"; + case scalar_kind_t::f16_k: return "f16"; + case scalar_kind_t::f64_k: return "f64"; + case scalar_kind_t::i8_k: return "i8"; + case scalar_kind_t::b1x8_k: return "b1x8"; + default: return ""; } } -inline char const *metric_kind_name(metric_kind_t metric) noexcept { +inline char const* metric_kind_name(metric_kind_t metric) noexcept { switch (metric) { - case metric_kind_t::unknown_k: - return "unknown"; - case metric_kind_t::ip_k: - return "ip"; - case metric_kind_t::cos_k: - return "cos"; - case metric_kind_t::l2sq_k: - return "l2sq"; - case metric_kind_t::pearson_k: - return "pearson"; - case metric_kind_t::haversine_k: - return "haversine"; - case metric_kind_t::divergence_k: - return "divergence"; - case metric_kind_t::jaccard_k: - return "jaccard"; - case metric_kind_t::hamming_k: - return "hamming"; - case metric_kind_t::tanimoto_k: - return "tanimoto"; - case metric_kind_t::sorensen_k: - return "sorensen"; + case metric_kind_t::unknown_k: return "unknown"; + case metric_kind_t::ip_k: return "ip"; + case metric_kind_t::cos_k: return "cos"; + case metric_kind_t::l2sq_k: return "l2sq"; + case metric_kind_t::pearson_k: return "pearson"; + case metric_kind_t::haversine_k: return "haversine"; + case metric_kind_t::divergence_k: return "divergence"; + case metric_kind_t::jaccard_k: return "jaccard"; + case metric_kind_t::hamming_k: return "hamming"; + case metric_kind_t::tanimoto_k: return "tanimoto"; + case 
metric_kind_t::sorensen_k: return "sorensen"; } return ""; } -inline expected_gt scalar_kind_from_name(char const *name, std::size_t len) { +inline expected_gt scalar_kind_from_name(char const* name, std::size_t len) { expected_gt parsed; if (str_equals(name, len, "f32")) parsed.result = scalar_kind_t::f32_k; @@ -305,11 +267,11 @@ inline expected_gt scalar_kind_from_name(char const *name, std::s return parsed; } -inline expected_gt scalar_kind_from_name(char const *name) { +inline expected_gt scalar_kind_from_name(char const* name) { return scalar_kind_from_name(name, std::strlen(name)); } -inline expected_gt metric_from_name(char const *name, std::size_t len) { +inline expected_gt metric_from_name(char const* name, std::size_t len) { expected_gt parsed; if (str_equals(name, len, "l2sq") || str_equals(name, len, "euclidean_sq")) { parsed.result = metric_kind_t::l2sq_k; @@ -335,7 +297,7 @@ inline expected_gt metric_from_name(char const *name, std::size_t return parsed; } -inline expected_gt metric_from_name(char const *name) { +inline expected_gt metric_from_name(char const* name) { return metric_from_name(name, std::strlen(name)); } @@ -366,84 +328,52 @@ inline std::uint16_t f32_to_f16(float f32) noexcept { * agnostic in-software implementation. */ class f16_bits_t { - std::uint16_t uint16_ {}; + std::uint16_t uint16_{}; public: - inline f16_bits_t() noexcept : uint16_(0) { - } - inline f16_bits_t(f16_bits_t &&) = default; - inline f16_bits_t &operator=(f16_bits_t &&) = default; - inline f16_bits_t(f16_bits_t const &) = default; - inline f16_bits_t &operator=(f16_bits_t const &) = default; + inline f16_bits_t() noexcept : uint16_(0) {} + inline f16_bits_t(f16_bits_t&&) = default; + inline f16_bits_t& operator=(f16_bits_t&&) = default; + inline f16_bits_t(f16_bits_t const&) = default; + inline f16_bits_t& operator=(f16_bits_t const&) = default; - inline operator float() const noexcept { - return f16_to_f32(uint16_); - } - inline explicit operator bool() const noexcept { - return f16_to_f32(uint16_) > 0.5f; - } + inline operator float() const noexcept { return f16_to_f32(uint16_); } + inline explicit operator bool() const noexcept { return f16_to_f32(uint16_) > 0.5f; } inline f16_bits_t(i8_converted_t) noexcept; - inline f16_bits_t(bool v) noexcept : uint16_(f32_to_f16(v)) { - } - inline f16_bits_t(float v) noexcept : uint16_(f32_to_f16(v)) { - } - inline f16_bits_t(double v) noexcept : uint16_(f32_to_f16(v)) { - } - - inline f16_bits_t operator+(f16_bits_t other) const noexcept { - return {float(*this) + float(other)}; - } - inline f16_bits_t operator-(f16_bits_t other) const noexcept { - return {float(*this) - float(other)}; - } - inline f16_bits_t operator*(f16_bits_t other) const noexcept { - return {float(*this) * float(other)}; - } - inline f16_bits_t operator/(f16_bits_t other) const noexcept { - return {float(*this) / float(other)}; - } - inline f16_bits_t operator+(float other) const noexcept { - return {float(*this) + other}; - } - inline f16_bits_t operator-(float other) const noexcept { - return {float(*this) - other}; - } - inline f16_bits_t operator*(float other) const noexcept { - return {float(*this) * other}; - } - inline f16_bits_t operator/(float other) const noexcept { - return {float(*this) / other}; - } - inline f16_bits_t operator+(double other) const noexcept { - return {float(*this) + other}; - } - inline f16_bits_t operator-(double other) const noexcept { - return {float(*this) - other}; - } - inline f16_bits_t operator*(double other) const noexcept { - return 
{float(*this) * other}; - } - inline f16_bits_t operator/(double other) const noexcept { - return {float(*this) / other}; - } - - inline f16_bits_t &operator+=(float v) noexcept { + inline f16_bits_t(bool v) noexcept : uint16_(f32_to_f16(v)) {} + inline f16_bits_t(float v) noexcept : uint16_(f32_to_f16(v)) {} + inline f16_bits_t(double v) noexcept : uint16_(f32_to_f16(v)) {} + + inline f16_bits_t operator+(f16_bits_t other) const noexcept { return {float(*this) + float(other)}; } + inline f16_bits_t operator-(f16_bits_t other) const noexcept { return {float(*this) - float(other)}; } + inline f16_bits_t operator*(f16_bits_t other) const noexcept { return {float(*this) * float(other)}; } + inline f16_bits_t operator/(f16_bits_t other) const noexcept { return {float(*this) / float(other)}; } + inline f16_bits_t operator+(float other) const noexcept { return {float(*this) + other}; } + inline f16_bits_t operator-(float other) const noexcept { return {float(*this) - other}; } + inline f16_bits_t operator*(float other) const noexcept { return {float(*this) * other}; } + inline f16_bits_t operator/(float other) const noexcept { return {float(*this) / other}; } + inline f16_bits_t operator+(double other) const noexcept { return {float(*this) + other}; } + inline f16_bits_t operator-(double other) const noexcept { return {float(*this) - other}; } + inline f16_bits_t operator*(double other) const noexcept { return {float(*this) * other}; } + inline f16_bits_t operator/(double other) const noexcept { return {float(*this) / other}; } + + inline f16_bits_t& operator+=(float v) noexcept { uint16_ = f32_to_f16(v + f16_to_f32(uint16_)); return *this; } - inline f16_bits_t &operator-=(float v) noexcept { + inline f16_bits_t& operator-=(float v) noexcept { uint16_ = f32_to_f16(v - f16_to_f32(uint16_)); return *this; } - inline f16_bits_t &operator*=(float v) noexcept { + inline f16_bits_t& operator*=(float v) noexcept { uint16_ = f32_to_f16(v * f16_to_f32(uint16_)); return *this; } - inline f16_bits_t &operator/=(float v) noexcept { + inline f16_bits_t& operator/=(float v) noexcept { uint16_ = f32_to_f16(v / f16_to_f32(uint16_)); return *this; } @@ -454,17 +384,15 @@ class f16_bits_t { * Isn't efficient for small batches, as it recreates the threads on every call. */ class executor_stl_t { - std::size_t threads_count_ {}; + std::size_t threads_count_{}; struct jthread_t { std::thread native_; jthread_t() = default; - jthread_t(jthread_t &&) = default; - jthread_t(jthread_t const &) = delete; - template - jthread_t(callable_at &&func) : native_([=]() { func(); }) { - } + jthread_t(jthread_t&&) = default; + jthread_t(jthread_t const&) = delete; + template jthread_t(callable_at&& func) : native_([=]() { func(); }) {} ~jthread_t() { if (native_.joinable()) @@ -474,27 +402,24 @@ class executor_stl_t { public: /** - * @param threads_count The number of threads to be used for parallel execution. + * @param threads_count The number of threads to be used for parallel execution. */ executor_stl_t(std::size_t threads_count = 0) noexcept - : threads_count_(threads_count ? threads_count : std::thread::hardware_concurrency()) { - } + : threads_count_(threads_count ? threads_count : std::thread::hardware_concurrency()) {} /** - * @return Maximum number of threads available to the executor. + * @return Maximum number of threads available to the executor. 
*/ - std::size_t size() const noexcept { - return threads_count_; - } + std::size_t size() const noexcept { return threads_count_; } /** - * @brief Executes a fixed number of tasks using the specified thread-aware function. - * @param tasks The total number of tasks to be executed. - * @param thread_aware_function The thread-aware function to be called for each thread index and task index. - * @throws If an exception occurs during execution of the thread-aware function. + * @brief Executes a fixed number of tasks using the specified thread-aware function. + * @param tasks The total number of tasks to be executed. + * @param thread_aware_function The thread-aware function to be called for each thread index and task index. + * @throws If an exception occurs during execution of the thread-aware function. */ template - void fixed(std::size_t tasks, thread_aware_function_at &&thread_aware_function) noexcept(false) { + void fixed(std::size_t tasks, thread_aware_function_at&& thread_aware_function) noexcept(false) { std::vector threads_pool; std::size_t tasks_per_thread = tasks; std::size_t threads_count = (std::min)(threads_count_, tasks); @@ -513,17 +438,17 @@ class executor_stl_t { } /** - * @brief Executes limited number of tasks using the specified thread-aware function. - * @param tasks The upper bound on the number of tasks. - * @param thread_aware_function The thread-aware function to be called for each thread index and task index. - * @throws If an exception occurs during execution of the thread-aware function. + * @brief Executes limited number of tasks using the specified thread-aware function. + * @param tasks The upper bound on the number of tasks. + * @param thread_aware_function The thread-aware function to be called for each thread index and task index. + * @throws If an exception occurs during execution of the thread-aware function. */ template - void dynamic(std::size_t tasks, thread_aware_function_at &&thread_aware_function) noexcept(false) { + void dynamic(std::size_t tasks, thread_aware_function_at&& thread_aware_function) noexcept(false) { std::vector threads_pool; std::size_t tasks_per_thread = tasks; std::size_t threads_count = (std::min)(threads_count_, tasks); - std::atomic_bool stop {false}; + std::atomic_bool stop{false}; if (threads_count > 1) { tasks_per_thread = (tasks / threads_count) + ((tasks % threads_count) != 0); for (std::size_t thread_idx = 1; thread_idx < threads_count; ++thread_idx) { @@ -544,12 +469,12 @@ class executor_stl_t { } /** - * @brief Saturates every available thread with the given workload, until they finish. - * @param thread_aware_function The thread-aware function to be called for each thread index. - * @throws If an exception occurs during execution of the thread-aware function. + * @brief Saturates every available thread with the given workload, until they finish. + * @param thread_aware_function The thread-aware function to be called for each thread index. + * @throws If an exception occurs during execution of the thread-aware function. */ template - void parallel(thread_aware_function_at &&thread_aware_function) noexcept(false) { + void parallel(thread_aware_function_at&& thread_aware_function) noexcept(false) { if (threads_count_ == 1) return thread_aware_function(0); std::vector threads_pool; @@ -568,27 +493,25 @@ class executor_stl_t { class executor_openmp_t { public: /** - * @param threads_count The number of threads to be used for parallel execution. + * @param threads_count The number of threads to be used for parallel execution. 
*/ executor_openmp_t(std::size_t threads_count = 0) noexcept { omp_set_num_threads(threads_count ? threads_count : std::thread::hardware_concurrency()); } /** - * @return Maximum number of threads available to the executor. + * @return Maximum number of threads available to the executor. */ - std::size_t size() const noexcept { - return omp_get_num_threads(); - } + std::size_t size() const noexcept { return omp_get_num_threads(); } /** - * @brief Executes tasks in bulk using the specified thread-aware function. - * @param tasks The total number of tasks to be executed. - * @param thread_aware_function The thread-aware function to be called for each thread index and task index. - * @throws If an exception occurs during execution of the thread-aware function. + * @brief Executes tasks in bulk using the specified thread-aware function. + * @param tasks The total number of tasks to be executed. + * @param thread_aware_function The thread-aware function to be called for each thread index and task index. + * @throws If an exception occurs during execution of the thread-aware function. */ template - void fixed(std::size_t tasks, thread_aware_function_at &&thread_aware_function) noexcept(false) { + void fixed(std::size_t tasks, thread_aware_function_at&& thread_aware_function) noexcept(false) { #pragma omp parallel for schedule(dynamic, 1) for (std::size_t i = 0; i != tasks; ++i) { thread_aware_function(omp_get_thread_num(), i); @@ -596,13 +519,13 @@ class executor_openmp_t { } /** - * @brief Executes tasks in bulk using the specified thread-aware function. - * @param tasks The total number of tasks to be executed. - * @param thread_aware_function The thread-aware function to be called for each thread index and task index. - * @throws If an exception occurs during execution of the thread-aware function. + * @brief Executes tasks in bulk using the specified thread-aware function. + * @param tasks The total number of tasks to be executed. + * @param thread_aware_function The thread-aware function to be called for each thread index and task index. + * @throws If an exception occurs during execution of the thread-aware function. */ template - void dynamic(std::size_t tasks, thread_aware_function_at &&thread_aware_function) noexcept(false) { + void dynamic(std::size_t tasks, thread_aware_function_at&& thread_aware_function) noexcept(false) { // OpenMP cancellation points are not yet available on most platforms, and require // the `OMP_CANCELLATION` environment variable to be set. // http://jakascorner.com/blog/2016/08/omp-cancel.html @@ -615,7 +538,7 @@ class executor_openmp_t { // } // } // } - std::atomic_bool stop {false}; + std::atomic_bool stop{false}; #pragma omp parallel for schedule(dynamic, 1) shared(stop) for (std::size_t i = 0; i != tasks; ++i) { if (!stop.load(std::memory_order_relaxed) && !thread_aware_function(omp_get_thread_num(), i)) @@ -624,12 +547,12 @@ class executor_openmp_t { } /** - * @brief Saturates every available thread with the given workload, until they finish. - * @param thread_aware_function The thread-aware function to be called for each thread index. - * @throws If an exception occurs during execution of the thread-aware function. + * @brief Saturates every available thread with the given workload, until they finish. + * @param thread_aware_function The thread-aware function to be called for each thread index. + * @throws If an exception occurs during execution of the thread-aware function. 
*/ template - void parallel(thread_aware_function_at &&thread_aware_function) noexcept(false) { + void parallel(thread_aware_function_at&& thread_aware_function) noexcept(false) { #pragma omp parallel { thread_aware_function(omp_get_thread_num()); } } @@ -651,16 +574,13 @@ class aligned_allocator_gt { public: using value_type = element_at; using size_type = std::size_t; - using pointer = element_at *; - using const_pointer = element_at const *; - template - struct rebind { + using pointer = element_at*; + using const_pointer = element_at const*; + template struct rebind { using other = aligned_allocator_gt; }; - constexpr std::size_t alignment() const { - return alignment_ak; - } + constexpr std::size_t alignment() const { return alignment_ak; } pointer allocate(size_type length) const { std::size_t length_bytes = alignment_ak * divide_round_up(length * sizeof(value_type)); @@ -688,25 +608,23 @@ using aligned_allocator_t = aligned_allocator_gt<>; class page_allocator_t { public: - static constexpr std::size_t page_size() { - return 4096; - } + static constexpr std::size_t page_size() { return 4096; } /** - * @brief Allocates an @b uninitialized block of memory of the specified size. - * @param count_bytes The number of bytes to allocate. - * @return A pointer to the allocated memory block, or `nullptr` if allocation fails. + * @brief Allocates an @b uninitialized block of memory of the specified size. + * @param count_bytes The number of bytes to allocate. + * @return A pointer to the allocated memory block, or `nullptr` if allocation fails. */ - byte_t *allocate(std::size_t count_bytes) const noexcept { + byte_t* allocate(std::size_t count_bytes) const noexcept { count_bytes = divide_round_up(count_bytes, page_size()) * page_size(); #if defined(USEARCH_DEFINED_WINDOWS) - return (byte_t *)(::VirtualAlloc(NULL, count_bytes, MEM_COMMIT | MEM_RESERVE, PAGE_READWRITE)); + return (byte_t*)(::VirtualAlloc(NULL, count_bytes, MEM_COMMIT | MEM_RESERVE, PAGE_READWRITE)); #else - return (byte_t *)mmap(NULL, count_bytes, PROT_WRITE | PROT_READ, MAP_PRIVATE | MAP_ANONYMOUS, 0, 0); + return (byte_t*)mmap(NULL, count_bytes, PROT_WRITE | PROT_READ, MAP_PRIVATE | MAP_ANONYMOUS, 0, 0); #endif } - void deallocate(byte_t *page_pointer, std::size_t count_bytes) const noexcept { + void deallocate(byte_t* page_pointer, std::size_t count_bytes) const noexcept { #if defined(USEARCH_DEFINED_WINDOWS) ::VirtualFree(page_pointer, 0, MEM_RELEASE); #else @@ -723,22 +641,17 @@ class page_allocator_t { * Using this memory allocator won't affect your overall speed much, as that is not the bottleneck. * However, it can drastically improve memory usage especially for huge indexes of small vectors. */ -template -class memory_mapping_allocator_gt { +template class memory_mapping_allocator_gt { - static constexpr std::size_t min_capacity() { - return 1024 * 1024 * 4; - } - static constexpr std::size_t capacity_multiplier() { - return 2; - } + static constexpr std::size_t min_capacity() { return 1024 * 1024 * 4; } + static constexpr std::size_t capacity_multiplier() { return 2; } static constexpr std::size_t head_size() { /// Pointer to the the previous arena and the size of the current one. 
- return divide_round_up(sizeof(byte_t *) + sizeof(std::size_t)) * alignment_ak; + return divide_round_up(sizeof(byte_t*) + sizeof(std::size_t)) * alignment_ak; } std::mutex mutex_; - byte_t *last_arena_ = nullptr; + byte_t* last_arena_ = nullptr; std::size_t last_usage_ = head_size(); std::size_t last_capacity_ = min_capacity(); std::size_t wasted_space_ = 0; @@ -746,16 +659,15 @@ class memory_mapping_allocator_gt { public: using value_type = byte_t; using size_type = std::size_t; - using pointer = byte_t *; - using const_pointer = byte_t const *; + using pointer = byte_t*; + using const_pointer = byte_t const*; memory_mapping_allocator_gt() = default; - memory_mapping_allocator_gt(memory_mapping_allocator_gt &&other) noexcept + memory_mapping_allocator_gt(memory_mapping_allocator_gt&& other) noexcept : last_arena_(exchange(other.last_arena_, nullptr)), last_usage_(exchange(other.last_usage_, 0)), - last_capacity_(exchange(other.last_capacity_, 0)), wasted_space_(exchange(other.wasted_space_, 0)) { - } + last_capacity_(exchange(other.last_capacity_, 0)), wasted_space_(exchange(other.wasted_space_, 0)) {} - memory_mapping_allocator_gt &operator=(memory_mapping_allocator_gt &&other) noexcept { + memory_mapping_allocator_gt& operator=(memory_mapping_allocator_gt&& other) noexcept { std::swap(last_arena_, other.last_arena_); std::swap(last_usage_, other.last_usage_); std::swap(last_capacity_, other.last_capacity_); @@ -763,21 +675,19 @@ class memory_mapping_allocator_gt { return *this; } - ~memory_mapping_allocator_gt() noexcept { - reset(); - } + ~memory_mapping_allocator_gt() noexcept { reset(); } /** - * @brief Discards all previously allocated memory buffers. + * @brief Discards all previously allocated memory buffers. */ void reset() noexcept { - byte_t *last_arena = last_arena_; + byte_t* last_arena = last_arena_; while (last_arena) { - byte_t *previous_arena = nullptr; - std::memcpy(&previous_arena, last_arena, sizeof(byte_t *)); + byte_t* previous_arena = nullptr; + std::memcpy(&previous_arena, last_arena, sizeof(byte_t*)); std::size_t last_cap = 0; - std::memcpy(&last_cap, last_arena + sizeof(byte_t *), sizeof(std::size_t)); - page_allocator_t {}.deallocate(last_arena, last_cap); + std::memcpy(&last_cap, last_arena + sizeof(byte_t*), sizeof(std::size_t)); + page_allocator_t{}.deallocate(last_arena, last_cap); last_arena = previous_arena; } @@ -789,37 +699,36 @@ class memory_mapping_allocator_gt { } /** - * @brief Copy constructor. - * @note This is a no-op copy constructor since the allocator is not copyable. + * @brief Copy constructor. + * @note This is a no-op copy constructor since the allocator is not copyable. */ - memory_mapping_allocator_gt(memory_mapping_allocator_gt const &) noexcept { - } + memory_mapping_allocator_gt(memory_mapping_allocator_gt const&) noexcept {} /** - * @brief Copy assignment operator. - * @note This is a no-op copy assignment operator since the allocator is not copyable. - * @return Reference to the allocator after the assignment. + * @brief Copy assignment operator. + * @note This is a no-op copy assignment operator since the allocator is not copyable. + * @return Reference to the allocator after the assignment. */ - memory_mapping_allocator_gt &operator=(memory_mapping_allocator_gt const &) noexcept { + memory_mapping_allocator_gt& operator=(memory_mapping_allocator_gt const&) noexcept { reset(); return *this; } /** - * @brief Allocates an @b uninitialized block of memory of the specified size. - * @param count_bytes The number of bytes to allocate. 
- * @return A pointer to the allocated memory block, or `nullptr` if allocation fails. + * @brief Allocates an @b uninitialized block of memory of the specified size. + * @param count_bytes The number of bytes to allocate. + * @return A pointer to the allocated memory block, or `nullptr` if allocation fails. */ - inline byte_t *allocate(std::size_t count_bytes) noexcept { + inline byte_t* allocate(std::size_t count_bytes) noexcept { std::size_t extended_bytes = divide_round_up(count_bytes) * alignment_ak; std::unique_lock lock(mutex_); if (!last_arena_ || (last_usage_ + extended_bytes >= last_capacity_)) { std::size_t new_cap = (std::max)(last_capacity_, ceil2(extended_bytes)) * capacity_multiplier(); - byte_t *new_arena = page_allocator_t {}.allocate(new_cap); + byte_t* new_arena = page_allocator_t{}.allocate(new_cap); if (!new_arena) return nullptr; - std::memcpy(new_arena, &last_arena_, sizeof(byte_t *)); - std::memcpy(new_arena + sizeof(byte_t *), &new_cap, sizeof(std::size_t)); + std::memcpy(new_arena, &last_arena_, sizeof(byte_t*)); + std::memcpy(new_arena + sizeof(byte_t*), &new_cap, sizeof(std::size_t)); wasted_space_ += total_reserved(); last_arena_ = new_arena; @@ -832,8 +741,8 @@ class memory_mapping_allocator_gt { } /** - * @brief Returns the amount of memory used by the allocator across all arenas. - * @return The amount of space in bytes. + * @brief Returns the amount of memory used by the allocator across all arenas. + * @return The amount of space in bytes. */ std::size_t total_allocated() const noexcept { if (!last_arena_) @@ -848,27 +757,21 @@ class memory_mapping_allocator_gt { } /** - * @brief Returns the amount of wasted space due to alignment. - * @return The amount of wasted space in bytes. + * @brief Returns the amount of wasted space due to alignment. + * @return The amount of wasted space in bytes. */ - std::size_t total_wasted() const noexcept { - return wasted_space_; - } + std::size_t total_wasted() const noexcept { return wasted_space_; } /** - * @brief Returns the amount of remaining memory already reserved but not yet used. - * @return The amount of reserved memory in bytes. + * @brief Returns the amount of remaining memory already reserved but not yet used. + * @return The amount of reserved memory in bytes. */ - std::size_t total_reserved() const noexcept { - return last_arena_ ? last_capacity_ - last_usage_ : 0; - } + std::size_t total_reserved() const noexcept { return last_arena_ ? last_capacity_ - last_usage_ : 0; } /** - * @warning The very first memory de-allocation discards all the arenas! + * @warning The very first memory de-allocation discards all the arenas! 
*/ - void deallocate(byte_t * = nullptr, std::size_t = 0) noexcept { - reset(); - } + void deallocate(byte_t* = nullptr, std::size_t = 0) noexcept { reset(); } }; using memory_mapping_allocator_t = memory_mapping_allocator_gt<>; @@ -884,7 +787,7 @@ class unfair_shared_mutex_t { idle_k = 0, writing_k = -1, }; - std::atomic state_ {idle_k}; + std::atomic state_{idle_k}; public: inline void lock() noexcept { @@ -897,9 +800,7 @@ class unfair_shared_mutex_t { } } - inline void unlock() noexcept { - state_.store(idle_k, std::memory_order_release); - } + inline void unlock() noexcept { state_.store(idle_k, std::memory_order_release); } inline void lock_shared() noexcept { std::int32_t raw; @@ -917,12 +818,10 @@ class unfair_shared_mutex_t { } } - inline void unlock_shared() noexcept { - state_.fetch_sub(1, std::memory_order_release); - } + inline void unlock_shared() noexcept { state_.fetch_sub(1, std::memory_order_release); } /** - * @brief Try upgrades the current `lock_shared()` to a unique `lock()` state. + * @brief Try upgrades the current `lock_shared()` to a unique `lock()` state. */ inline bool try_escalate() noexcept { std::int32_t one_read = 1; @@ -930,8 +829,8 @@ class unfair_shared_mutex_t { } /** - * @brief Escalates current lock potentially loosing control in the middle. - * It's a shortcut for `try_escalate`-`unlock_shared`-`lock` trio. + * @brief Escalates current lock potentially loosing control in the middle. + * It's a shortcut for `try_escalate`-`unlock_shared`-`lock` trio. */ inline void unsafe_escalate() noexcept { if (!try_escalate()) { @@ -941,7 +840,7 @@ class unfair_shared_mutex_t { } /** - * @brief Upgrades the current `lock_shared()` to a unique `lock()` state. + * @brief Upgrades the current `lock_shared()` to a unique `lock()` state. */ inline void escalate() noexcept { while (!try_escalate()) @@ -949,7 +848,7 @@ class unfair_shared_mutex_t { } /** - * @brief De-escalation of a previously escalated state. + * @brief De-escalation of a previously escalated state. */ inline void de_escalate() noexcept { std::int32_t one_read = 1; @@ -957,87 +856,62 @@ class unfair_shared_mutex_t { } }; -template -class shared_lock_gt { - mutex_at &mutex_; +template class shared_lock_gt { + mutex_at& mutex_; public: - inline explicit shared_lock_gt(mutex_at &m) noexcept : mutex_(m) { - mutex_.lock_shared(); - } - inline ~shared_lock_gt() noexcept { - mutex_.unlock_shared(); - } + inline explicit shared_lock_gt(mutex_at& m) noexcept : mutex_(m) { mutex_.lock_shared(); } + inline ~shared_lock_gt() noexcept { mutex_.unlock_shared(); } }; /** * @brief Utility class used to cast arrays of one scalar type to another, * avoiding unnecessary conversions. 
*/ -template -struct cast_gt { - inline bool operator()(byte_t const *input, std::size_t dim, byte_t *output) const { - from_scalar_at const *typed_input = reinterpret_cast(input); - to_scalar_at *typed_output = reinterpret_cast(output); - auto converter = [](from_scalar_at from) { - return to_scalar_at(from); - }; +template struct cast_gt { + inline bool operator()(byte_t const* input, std::size_t dim, byte_t* output) const { + from_scalar_at const* typed_input = reinterpret_cast(input); + to_scalar_at* typed_output = reinterpret_cast(output); + auto converter = [](from_scalar_at from) { return to_scalar_at(from); }; std::transform(typed_input, typed_input + dim, typed_output, converter); return true; } }; -template <> -struct cast_gt { - bool operator()(byte_t const *, std::size_t, byte_t *) const { - return false; - } +template <> struct cast_gt { + bool operator()(byte_t const*, std::size_t, byte_t*) const { return false; } }; -template <> -struct cast_gt { - bool operator()(byte_t const *, std::size_t, byte_t *) const { - return false; - } +template <> struct cast_gt { + bool operator()(byte_t const*, std::size_t, byte_t*) const { return false; } }; -template <> -struct cast_gt { - bool operator()(byte_t const *, std::size_t, byte_t *) const { - return false; - } +template <> struct cast_gt { + bool operator()(byte_t const*, std::size_t, byte_t*) const { return false; } }; -template <> -struct cast_gt { - bool operator()(byte_t const *, std::size_t, byte_t *) const { - return false; - } +template <> struct cast_gt { + bool operator()(byte_t const*, std::size_t, byte_t*) const { return false; } }; -template <> -struct cast_gt { - bool operator()(byte_t const *, std::size_t, byte_t *) const { - return false; - } +template <> struct cast_gt { + bool operator()(byte_t const*, std::size_t, byte_t*) const { return false; } }; -template -struct cast_gt { - inline bool operator()(byte_t const *input, std::size_t dim, byte_t *output) const { - from_scalar_at const *typed_input = reinterpret_cast(input); - unsigned char *typed_output = reinterpret_cast(output); +template struct cast_gt { + inline bool operator()(byte_t const* input, std::size_t dim, byte_t* output) const { + from_scalar_at const* typed_input = reinterpret_cast(input); + unsigned char* typed_output = reinterpret_cast(output); for (std::size_t i = 0; i != dim; ++i) typed_output[i / CHAR_BIT] |= bool(typed_input[i]) ? (128 >> (i & (CHAR_BIT - 1))) : 0; return true; } }; -template -struct cast_gt { - inline bool operator()(byte_t const *input, std::size_t dim, byte_t *output) const { - unsigned char const *typed_input = reinterpret_cast(input); - to_scalar_at *typed_output = reinterpret_cast(output); +template struct cast_gt { + inline bool operator()(byte_t const* input, std::size_t dim, byte_t* output) const { + unsigned char const* typed_input = reinterpret_cast(input); + to_scalar_at* typed_output = reinterpret_cast(output); for (std::size_t i = 0; i != dim; ++i) typed_output[i] = bool(typed_input[i / CHAR_BIT] & (128 >> (i & (CHAR_BIT - 1)))); return true; @@ -1049,86 +923,57 @@ struct cast_gt { * values within [-1,1] range, quantized to integers [-100,100]. */ class i8_converted_t { - std::int8_t int8_ {}; + std::int8_t int8_{}; public: - constexpr static float divisor_k = 100.f; + constexpr static f32_t divisor_k = 100.f; constexpr static std::int8_t min_k = -100; constexpr static std::int8_t max_k = 100; - inline i8_converted_t() noexcept : int8_(0) { - } - inline i8_converted_t(bool v) noexcept : int8_(v ? 
max_k : 0) { - } + inline i8_converted_t() noexcept : int8_(0) {} + inline i8_converted_t(bool v) noexcept : int8_(v ? max_k : 0) {} - inline i8_converted_t(i8_converted_t &&) = default; - inline i8_converted_t &operator=(i8_converted_t &&) = default; - inline i8_converted_t(i8_converted_t const &) = default; - inline i8_converted_t &operator=(i8_converted_t const &) = default; + inline i8_converted_t(i8_converted_t&&) = default; + inline i8_converted_t& operator=(i8_converted_t&&) = default; + inline i8_converted_t(i8_converted_t const&) = default; + inline i8_converted_t& operator=(i8_converted_t const&) = default; - inline operator float() const noexcept { - return float(int8_) / divisor_k; - } - inline operator f16_t() const noexcept { - return float(int8_) / divisor_k; - } - inline operator double() const noexcept { - return double(int8_) / divisor_k; - } - inline explicit operator bool() const noexcept { - return int8_ > (max_k / 2); - } - inline explicit operator std::int8_t() const noexcept { - return int8_; - } - inline explicit operator std::int16_t() const noexcept { - return int8_; - } - inline explicit operator std::int32_t() const noexcept { - return int8_; - } - inline explicit operator std::int64_t() const noexcept { - return int8_; - } + inline operator f16_t() const noexcept { return static_cast(f32_t(int8_) / divisor_k); } + inline operator f32_t() const noexcept { return f32_t(int8_) / divisor_k; } + inline operator f64_t() const noexcept { return f64_t(int8_) / divisor_k; } + inline explicit operator bool() const noexcept { return int8_ > (max_k / 2); } + inline explicit operator std::int8_t() const noexcept { return int8_; } + inline explicit operator std::int16_t() const noexcept { return int8_; } + inline explicit operator std::int32_t() const noexcept { return int8_; } + inline explicit operator std::int64_t() const noexcept { return int8_; } inline i8_converted_t(f16_t v) - : int8_(usearch::clamp(static_cast(v * divisor_k), min_k, max_k)) { - } - inline i8_converted_t(float v) - : int8_(usearch::clamp(static_cast(v * divisor_k), min_k, max_k)) { - } - inline i8_converted_t(double v) - : int8_(usearch::clamp(static_cast(v * divisor_k), min_k, max_k)) { - } + : int8_(usearch::clamp(static_cast(v * divisor_k), min_k, max_k)) {} + inline i8_converted_t(f32_t v) + : int8_(usearch::clamp(static_cast(v * divisor_k), min_k, max_k)) {} + inline i8_converted_t(f64_t v) + : int8_(usearch::clamp(static_cast(v * divisor_k), min_k, max_k)) {} }; -f16_bits_t::f16_bits_t(i8_converted_t v) noexcept : uint16_(f32_to_f16(v)) { -} +f16_bits_t::f16_bits_t(i8_converted_t v) noexcept : uint16_(f32_to_f16(v)) {} -template <> -struct cast_gt : public cast_gt {}; -template <> -struct cast_gt : public cast_gt {}; -template <> -struct cast_gt : public cast_gt {}; +template <> struct cast_gt : public cast_gt {}; +template <> struct cast_gt : public cast_gt {}; +template <> struct cast_gt : public cast_gt {}; -template <> -struct cast_gt : public cast_gt {}; -template <> -struct cast_gt : public cast_gt {}; -template <> -struct cast_gt : public cast_gt {}; +template <> struct cast_gt : public cast_gt {}; +template <> struct cast_gt : public cast_gt {}; +template <> struct cast_gt : public cast_gt {}; /** * @brief Inner (Dot) Product distance. 
*/ -template -struct metric_ip_gt { +template struct metric_ip_gt { using scalar_t = scalar_at; using result_t = result_at; - inline result_t operator()(scalar_t const *a, scalar_t const *b, std::size_t dim) const noexcept { - result_t ab {}; + inline result_t operator()(scalar_t const* a, scalar_t const* b, std::size_t dim) const noexcept { + result_t ab{}; #if USEARCH_USE_OPENMP #pragma omp simd reduction(+ : ab) #elif defined(USEARCH_DEFINED_CLANG) @@ -1148,13 +993,12 @@ struct metric_ip_gt { * Unless you are running on an tiny embedded platform, this metric * is recommended over `::metric_ip_gt` for low-precision scalars. */ -template -struct metric_cos_gt { +template struct metric_cos_gt { using scalar_t = scalar_at; using result_t = result_at; - inline result_t operator()(scalar_t const *a, scalar_t const *b, std::size_t dim) const noexcept { - result_t ab {}, a2 {}, b2 {}; + inline result_t operator()(scalar_t const* a, scalar_t const* b, std::size_t dim) const noexcept { + result_t ab{}, a2{}, b2{}; #if USEARCH_USE_OPENMP #pragma omp simd reduction(+ : ab, a2, b2) #elif defined(USEARCH_DEFINED_CLANG) @@ -1162,10 +1006,11 @@ struct metric_cos_gt { #elif defined(USEARCH_DEFINED_GCC) #pragma GCC ivdep #endif - for (std::size_t i = 0; i != dim; ++i) - ab += result_t(a[i]) * result_t(b[i]), // - a2 += square(a[i]), // - b2 += square(b[i]); + for (std::size_t i = 0; i != dim; ++i) { + result_t ai = static_cast(a[i]); + result_t bi = static_cast(b[i]); + ab += ai * bi, a2 += square(ai), b2 += square(bi); + } result_t result_if_zero[2][2]; result_if_zero[0][0] = 1 - ab / (std::sqrt(a2) * std::sqrt(b2)); @@ -1179,13 +1024,12 @@ struct metric_cos_gt { * @brief Squared Euclidean (L2) distance. * Square root is avoided at the end, as it won't affect the ordering. */ -template -struct metric_l2sq_gt { +template struct metric_l2sq_gt { using scalar_t = scalar_at; using result_t = result_at; - inline result_t operator()(scalar_t const *a, scalar_t const *b, std::size_t dim) const noexcept { - result_t ab_deltas_sq {}; + inline result_t operator()(scalar_t const* a, scalar_t const* b, std::size_t dim) const noexcept { + result_t ab_deltas_sq{}; #if USEARCH_USE_OPENMP #pragma omp simd reduction(+ : ab_deltas_sq) #elif defined(USEARCH_DEFINED_CLANG) @@ -1193,8 +1037,11 @@ struct metric_l2sq_gt { #elif defined(USEARCH_DEFINED_GCC) #pragma GCC ivdep #endif - for (std::size_t i = 0; i != dim; ++i) - ab_deltas_sq += square(result_t(a[i]) - result_t(b[i])); + for (std::size_t i = 0; i != dim; ++i) { + result_t ai = static_cast(a[i]); + result_t bi = static_cast(b[i]); + ab_deltas_sq += square(ai - bi); + } return ab_deltas_sq; } }; @@ -1204,8 +1051,7 @@ struct metric_l2sq_gt { * two arrays of integers. An example would be a textual document, * tokenized and hashed into a fixed-capacity bitset. 
*/ -template -struct metric_hamming_gt { +template struct metric_hamming_gt { using scalar_t = scalar_at; using result_t = result_at; static_assert( // @@ -1213,9 +1059,9 @@ struct metric_hamming_gt { (std::is_enum::value && std::is_unsigned::type>::value), "Hamming distance requires unsigned integral words"); - inline result_t operator()(scalar_t const *a, scalar_t const *b, std::size_t words) const noexcept { + inline result_t operator()(scalar_t const* a, scalar_t const* b, std::size_t words) const noexcept { constexpr std::size_t bits_per_word_k = sizeof(scalar_t) * CHAR_BIT; - result_t matches {}; + result_t matches{}; #if USEARCH_USE_OPENMP #pragma omp simd reduction(+ : matches) #elif defined(USEARCH_DEFINED_CLANG) @@ -1233,8 +1079,7 @@ struct metric_hamming_gt { * @brief Tanimoto distance is the intersection over bitwise union. * Often used in chemistry and biology to compare molecular fingerprints. */ -template -struct metric_tanimoto_gt { +template struct metric_tanimoto_gt { using scalar_t = scalar_at; using result_t = result_at; static_assert( // @@ -1243,10 +1088,10 @@ struct metric_tanimoto_gt { "Tanimoto distance requires unsigned integral words"); static_assert(std::is_floating_point::value, "Tanimoto distance will be a fraction"); - inline result_t operator()(scalar_t const *a, scalar_t const *b, std::size_t words) const noexcept { + inline result_t operator()(scalar_t const* a, scalar_t const* b, std::size_t words) const noexcept { constexpr std::size_t bits_per_word_k = sizeof(scalar_t) * CHAR_BIT; - result_t and_count {}; - result_t or_count {}; + result_t and_count{}; + result_t or_count{}; #if USEARCH_USE_OPENMP #pragma omp simd reduction(+ : and_count, or_count) #elif defined(USEARCH_DEFINED_CLANG) @@ -1254,9 +1099,10 @@ struct metric_tanimoto_gt { #elif defined(USEARCH_DEFINED_GCC) #pragma GCC ivdep #endif - for (std::size_t i = 0; i != words; ++i) - and_count += std::bitset(a[i] & b[i]).count(), - or_count += std::bitset(a[i] | b[i]).count(); + for (std::size_t i = 0; i != words; ++i) { + and_count += std::bitset(a[i] & b[i]).count(); + or_count += std::bitset(a[i] | b[i]).count(); + } return 1 - result_t(and_count) / or_count; } }; @@ -1265,8 +1111,7 @@ struct metric_tanimoto_gt { * @brief Sorensen-Dice or F1 distance is the intersection over bitwise union. * Often used in chemistry and biology to compare molecular fingerprints. 
 */
-template <typename scalar_at = std::uint64_t, typename result_at = float>
-struct metric_sorensen_gt {
+template <typename scalar_at = std::uint64_t, typename result_at = float> struct metric_sorensen_gt {
     using scalar_t = scalar_at;
     using result_t = result_at;
     static_assert( //
@@ -1275,10 +1120,10 @@ struct metric_sorensen_gt {
         "Sorensen-Dice distance requires unsigned integral words");
     static_assert(std::is_floating_point<result_t>::value, "Sorensen-Dice distance will be a fraction");

-    inline result_t operator()(scalar_t const *a, scalar_t const *b, std::size_t words) const noexcept {
+    inline result_t operator()(scalar_t const* a, scalar_t const* b, std::size_t words) const noexcept {
         constexpr std::size_t bits_per_word_k = sizeof(scalar_t) * CHAR_BIT;
-        result_t and_count {};
-        result_t any_count {};
+        result_t and_count{};
+        result_t any_count{};
 #if USEARCH_USE_OPENMP
 #pragma omp simd reduction(+ : and_count, any_count)
 #elif defined(USEARCH_DEFINED_CLANG)
@@ -1286,9 +1131,10 @@ struct metric_sorensen_gt {
 #elif defined(USEARCH_DEFINED_GCC)
 #pragma GCC ivdep
 #endif
-        for (std::size_t i = 0; i != words; ++i)
-            and_count += std::bitset<bits_per_word_k>(a[i] & b[i]).count(),
-            any_count += std::bitset<bits_per_word_k>(a[i]).count() + std::bitset<bits_per_word_k>(b[i]).count();
+        for (std::size_t i = 0; i != words; ++i) {
+            and_count += std::bitset<bits_per_word_k>(a[i] & b[i]).count();
+            any_count += std::bitset<bits_per_word_k>(a[i]).count() + std::bitset<bits_per_word_k>(b[i]).count();
+        }
         return 1 - 2 * result_t(and_count) / any_count;
     }
 };
@@ -1299,17 +1145,16 @@
  * using the IDs of tokens present in them.
 * Similar to `metric_tanimoto_gt` for dense representations.
 */
-template <typename scalar_at = std::int32_t, typename result_at = float>
-struct metric_jaccard_gt {
+template <typename scalar_at = std::int32_t, typename result_at = float> struct metric_jaccard_gt {
     using scalar_t = scalar_at;
     using result_t = result_at;
     static_assert(!std::is_floating_point<scalar_t>::value, "Jaccard distance requires integral scalars");

     inline result_t operator()( //
-        scalar_t const *a, scalar_t const *b, std::size_t a_length, std::size_t b_length) const noexcept {
-        result_t intersection {};
-        std::size_t i {};
-        std::size_t j {};
+        scalar_t const* a, scalar_t const* b, std::size_t a_length, std::size_t b_length) const noexcept {
+        result_t intersection{};
+        std::size_t i{};
+        std::size_t j{};
         while (i != a_length && j != b_length) {
             intersection += a[i] == b[j];
             i += a[i] < b[j];
@@ -1320,16 +1165,22 @@
 };

 /**
- * @brief Measures Pearson Correlation between two sequences.
+ * @brief Measures Pearson Correlation between two sequences in a single pass.
 */
-template <typename scalar_at = float, typename result_at = scalar_at>
-struct metric_pearson_gt {
+template <typename scalar_at = float, typename result_at = scalar_at> struct metric_pearson_gt {
     using scalar_t = scalar_at;
     using result_t = result_at;

-    inline result_t operator()(scalar_t const *a, scalar_t const *b, std::size_t dim) const noexcept {
-        result_t a_sum {}, b_sum {}, ab_sum {};
-        result_t a_sq_sum {}, b_sq_sum {};
+    inline result_t operator()(scalar_t const* a, scalar_t const* b, std::size_t dim) const noexcept {
+        // The correlation coefficient can't be defined for one or zero-dimensional data.
+        if (dim <= 1)
+            return 0;
+        // The conventional Pearson Correlation Coefficient definition subtracts the mean value of each
+        // sequence from each element before normalizing. The Wikipedia article suggests a convenient
+        // single-pass algorithm for calculating sample correlations, though depending on the numbers
+        // involved, it can sometimes be numerically unstable.
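+        // In symbols, accumulating all five sums in a single pass over i = 0..dim-1:
+        //   corr = (dim * sum(a_i * b_i) - sum(a_i) * sum(b_i)) /
+        //          sqrt((dim * sum(a_i^2) - sum(a_i)^2) * (dim * sum(b_i^2) - sum(b_i)^2))
+        // The code below returns -corr, guarding against a zero denominator, so more
+        // strongly correlated vectors rank as closer.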
+        result_t a_sum{}, b_sum{}, ab_sum{};
+        result_t a_sq_sum{}, b_sq_sum{};
 #if USEARCH_USE_OPENMP
 #pragma omp simd reduction(+ : a_sum, b_sum, ab_sum, a_sq_sum, b_sq_sum)
 #elif defined(USEARCH_DEFINED_CLANG)
@@ -1338,29 +1189,33 @@
 #pragma GCC ivdep
 #endif
         for (std::size_t i = 0; i != dim; ++i) {
-            a_sum += result_t(a[i]);
-            b_sum += result_t(b[i]);
-            ab_sum += result_t(a[i]) * result_t(b[i]);
-            a_sq_sum += result_t(a[i]) * result_t(a[i]);
-            b_sq_sum += result_t(b[i]) * result_t(b[i]);
+            result_t ai = static_cast<result_t>(a[i]);
+            result_t bi = static_cast<result_t>(b[i]);
+            a_sum += ai;
+            b_sum += bi;
+            ab_sum += ai * bi;
+            a_sq_sum += ai * ai;
+            b_sq_sum += bi * bi;
         }
-        result_t denom = std::sqrt((dim * a_sq_sum - a_sum * a_sum) * (dim * b_sq_sum - b_sum * b_sum));
-        result_t corr = (dim * ab_sum - a_sum * b_sum) / denom;
-        return -corr;
+        result_t denom = (dim * a_sq_sum - a_sum * a_sum) * (dim * b_sq_sum - b_sum * b_sum);
+        if (denom == 0)
+            return 0;
+        result_t corr = dim * ab_sum - a_sum * b_sum;
+        denom = std::sqrt(denom);
+        return -corr / denom;
     }
 };

 /**
  * @brief Measures Jensen-Shannon Divergence between two probability distributions.
 */
-template <typename scalar_at = float, typename result_at = scalar_at>
-struct metric_divergence_gt {
+template <typename scalar_at = float, typename result_at = scalar_at> struct metric_divergence_gt {
     using scalar_t = scalar_at;
     using result_t = result_at;

-    inline result_t operator()(scalar_t const *p, scalar_t const *q, std::size_t dim) const noexcept {
-        result_t kld_pm {}, kld_qm {};
-        scalar_t epsilon = std::numeric_limits<scalar_t>::epsilon();
+    inline result_t operator()(scalar_t const* p, scalar_t const* q, std::size_t dim) const noexcept {
+        result_t kld_pm{}, kld_qm{};
+        result_t epsilon = std::numeric_limits<result_t>::epsilon();
 #if USEARCH_USE_OPENMP
 #pragma omp simd reduction(+ : kld_pm, kld_qm)
 #elif defined(USEARCH_DEFINED_CLANG)
@@ -1369,9 +1224,11 @@
 #pragma GCC ivdep
 #endif
         for (std::size_t i = 0; i != dim; ++i) {
-            result_t mi = result_t(p[i] + q[i]) / 2 + epsilon;
-            kld_pm += p[i] * std::log((p[i] + epsilon) / mi);
-            kld_qm += q[i] * std::log((q[i] + epsilon) / mi);
+            result_t pi = static_cast<result_t>(p[i]);
+            result_t qi = static_cast<result_t>(q[i]);
+            result_t mi = (pi + qi) / 2 + epsilon;
+            kld_pm += pi * std::log((pi + epsilon) / mi);
+            kld_qm += qi * std::log((qi + epsilon) / mi);
         }
         return (kld_pm + kld_qm) / 2;
     }
 };
@@ -1381,8 +1238,8 @@ struct cos_i8_t {
     using scalar_t = i8_t;
     using result_t = f32_t;

-    inline result_t operator()(i8_t const *a, i8_t const *b, std::size_t dim) const noexcept {
-        std::int32_t ab {}, a2 {}, b2 {};
+    inline result_t operator()(i8_t const* a, i8_t const* b, std::size_t dim) const noexcept {
+        std::int32_t ab{}, a2{}, b2{};
 #if USEARCH_USE_OPENMP
 #pragma omp simd reduction(+ : ab, a2, b2)
 #elif defined(USEARCH_DEFINED_CLANG)
@@ -1391,8 +1248,8 @@ struct cos_i8_t {
 #pragma GCC ivdep
 #endif
         for (std::size_t i = 0; i != dim; i++) {
-            std::int16_t ai {a[i]};
-            std::int16_t bi {b[i]};
+            std::int16_t ai{a[i]};
+            std::int16_t bi{b[i]};
             ab += ai * bi;
             a2 += square(ai);
             b2 += square(bi);
@@ -1405,8 +1262,8 @@ struct l2sq_i8_t {
     using scalar_t = i8_t;
     using result_t = f32_t;

-    inline result_t operator()(i8_t const *a, i8_t const *b, std::size_t dim) const noexcept {
-        std::int32_t ab_deltas_sq {};
+    inline result_t operator()(i8_t const* a, i8_t const* b, std::size_t dim) const noexcept {
+        std::int32_t ab_deltas_sq{};
 #if USEARCH_USE_OPENMP
 #pragma omp simd reduction(+ : ab_deltas_sq)
 #elif defined(USEARCH_DEFINED_CLANG)
@@ -1416,7 +1273,7 @@
 #endif
         for (std::size_t i = 0; i != dim; i++)
            ab_deltas_sq +=
                square(std::int16_t(a[i]) - std::int16_t(b[i]));
-        return ab_deltas_sq;
+        return static_cast<result_t>(ab_deltas_sq);
     }
 };

@@ -1424,13 +1281,12 @@
  * @brief Haversine distance for the shortest distance between two points on
 *        the surface of a 3D sphere, defined with latitude and longitude.
 */
-template <typename scalar_at = float, typename result_at = scalar_at>
-struct metric_haversine_gt {
+template <typename scalar_at = float, typename result_at = scalar_at> struct metric_haversine_gt {
     using scalar_t = scalar_at;
     using result_t = result_at;
     static_assert(!std::is_integral<scalar_t>::value, "Latitude and longitude must be floating-point");

-    inline result_t operator()(scalar_t const *a, scalar_t const *b, std::size_t = 2) const noexcept {
+    inline result_t operator()(scalar_t const* a, scalar_t const* b, std::size_t = 2) const noexcept {
         result_t lat_a = a[0], lon_a = a[1];
         result_t lat_b = b[0], lon_b = b[1];
@@ -1488,17 +1344,17 @@ class metric_punned_t {

   public:
     /**
-     * @brief Computes the distance between two vectors of fixed length.
-     *
-     * ! This is the only relevant function in the object. Everything else is just dynamic dispatch logic.
+     *  @brief Computes the distance between two vectors of fixed length.
+     *
+     *  ! This is the only relevant function in the object. Everything else is just dynamic dispatch logic.
      */
-    inline result_t operator()(byte_t const *a, byte_t const *b) const noexcept {
+    inline result_t operator()(byte_t const* a, byte_t const* b) const noexcept {
         return raw_ptr_(reinterpret_cast<punned_arg_t>(a), reinterpret_cast<punned_arg_t>(b), raw_arg3_, raw_arg4_);
     }

     inline metric_punned_t() noexcept = default;
-    inline metric_punned_t(metric_punned_t const &) noexcept = default;
-    inline metric_punned_t &operator=(metric_punned_t const &) noexcept = default;
+    inline metric_punned_t(metric_punned_t const&) noexcept = default;
+    inline metric_punned_t& operator=(metric_punned_t const&) noexcept = default;

     inline metric_punned_t(                                  //
         std::size_t dimensions,                              //
         metric_kind_t metric_kind = metric_kind_t::l2sq_k,   //
@@ -1509,8 +1365,9 @@ class metric_punned_t {
 #if USEARCH_USE_SIMSIMD
         if (!configure_with_simsimd())
             configure_with_auto_vectorized();
-#endif
+#else
         configure_with_auto_vectorized();
+#endif

         if (scalar_kind == scalar_kind_t::b1x8_k)
             raw_arg3_ = raw_arg4_ = divide_round_up<CHAR_BIT>(dimensions_);
@@ -1528,37 +1385,22 @@ class metric_punned_t {
         (void)signature;
     }

-    inline std::size_t dimensions() const noexcept {
-        return dimensions_;
-    }
-    inline metric_kind_t metric_kind() const noexcept {
-        return metric_kind_;
-    }
-    inline scalar_kind_t scalar_kind() const noexcept {
-        return scalar_kind_;
-    }
+    inline std::size_t dimensions() const noexcept { return dimensions_; }
+    inline metric_kind_t metric_kind() const noexcept { return metric_kind_; }
+    inline scalar_kind_t scalar_kind() const noexcept { return scalar_kind_; }

-    inline char const *isa_name() const noexcept {
+    inline char const* isa_name() const noexcept {
 #if USEARCH_USE_SIMSIMD
         switch (isa_kind_) {
-        case simsimd_cap_serial_k:
-            return "serial";
-        case simsimd_cap_arm_neon_k:
-            return "neon";
-        case simsimd_cap_arm_sve_k:
-            return "sve";
-        case simsimd_cap_x86_avx2_k:
-            return "avx2";
-        case simsimd_cap_x86_avx512_k:
-            return "avx512";
-        case simsimd_cap_x86_avx2fp16_k:
-            return "avx2+f16";
-        case simsimd_cap_x86_avx512fp16_k:
-            return "avx512+f16";
-        case simsimd_cap_x86_avx512vpopcntdq_k:
-            return "avx512+popcnt";
-        default:
-            return "unknown";
+        case simsimd_cap_serial_k: return "serial";
+        case simsimd_cap_arm_neon_k: return "neon";
+        case simsimd_cap_arm_sve_k: return "sve";
+        case simsimd_cap_x86_avx2_k: return "avx2";
+        case simsimd_cap_x86_avx512_k:
return "avx512"; + case simsimd_cap_x86_avx2fp16_k: return "avx2+f16"; + case simsimd_cap_x86_avx512fp16_k: return "avx512+f16"; + case simsimd_cap_x86_avx512vpopcntdq_k: return "avx512+popcnt"; + default: return "unknown"; } #endif return "serial"; @@ -1579,45 +1421,21 @@ class metric_punned_t { simsimd_datatype_t datatype = simsimd_datatype_unknown_k; simsimd_capability_t allowed = simsimd_cap_any_k; switch (metric_kind_) { - case metric_kind_t::ip_k: - kind = simsimd_metric_ip_k; - break; - case metric_kind_t::cos_k: - kind = simsimd_metric_cos_k; - break; - case metric_kind_t::l2sq_k: - kind = simsimd_metric_l2sq_k; - break; - case metric_kind_t::hamming_k: - kind = simsimd_metric_hamming_k; - break; - case metric_kind_t::tanimoto_k: - kind = simsimd_metric_jaccard_k; - break; - case metric_kind_t::jaccard_k: - kind = simsimd_metric_jaccard_k; - break; - default: - break; + case metric_kind_t::ip_k: kind = simsimd_metric_ip_k; break; + case metric_kind_t::cos_k: kind = simsimd_metric_cos_k; break; + case metric_kind_t::l2sq_k: kind = simsimd_metric_l2sq_k; break; + case metric_kind_t::hamming_k: kind = simsimd_metric_hamming_k; break; + case metric_kind_t::tanimoto_k: kind = simsimd_metric_jaccard_k; break; + case metric_kind_t::jaccard_k: kind = simsimd_metric_jaccard_k; break; + default: break; } switch (scalar_kind_) { - case scalar_kind_t::f32_k: - datatype = simsimd_datatype_f32_k; - break; - case scalar_kind_t::f64_k: - datatype = simsimd_datatype_f64_k; - break; - case scalar_kind_t::f16_k: - datatype = simsimd_datatype_f16_k; - break; - case scalar_kind_t::i8_k: - datatype = simsimd_datatype_i8_k; - break; - case scalar_kind_t::b1x8_k: - datatype = simsimd_datatype_b8_k; - break; - default: - break; + case scalar_kind_t::f32_k: datatype = simsimd_datatype_f32_k; break; + case scalar_kind_t::f64_k: datatype = simsimd_datatype_f64_k; break; + case scalar_kind_t::f16_k: datatype = simsimd_datatype_f16_k; break; + case scalar_kind_t::i8_k: datatype = simsimd_datatype_i8_k; break; + case scalar_kind_t::b1x8_k: datatype = simsimd_datatype_b8_k; break; + default: break; } simsimd_metric_punned_t simd_metric = NULL; simsimd_capability_t simd_kind = simsimd_cap_any_k; @@ -1634,90 +1452,50 @@ class metric_punned_t { return configure_with_simsimd(static_capabilities); } #else - bool configure_with_simsimd() noexcept { - return false; - } + bool configure_with_simsimd() noexcept { return false; } #endif void configure_with_auto_vectorized() noexcept { switch (metric_kind_) { case metric_kind_t::ip_k: { switch (scalar_kind_) { - case scalar_kind_t::f32_k: - raw_ptr_ = (punned_ptr_t)&equidimensional_>; - break; - case scalar_kind_t::f16_k: - raw_ptr_ = (punned_ptr_t)&equidimensional_>; - break; - case scalar_kind_t::i8_k: - raw_ptr_ = (punned_ptr_t)&equidimensional_>; - break; - case scalar_kind_t::f64_k: - raw_ptr_ = (punned_ptr_t)&equidimensional_>; - break; - default: - raw_ptr_ = nullptr; - break; + case scalar_kind_t::f32_k: raw_ptr_ = (punned_ptr_t)&equidimensional_>; break; + case scalar_kind_t::f16_k: raw_ptr_ = (punned_ptr_t)&equidimensional_>; break; + case scalar_kind_t::i8_k: raw_ptr_ = (punned_ptr_t)&equidimensional_>; break; + case scalar_kind_t::f64_k: raw_ptr_ = (punned_ptr_t)&equidimensional_>; break; + default: raw_ptr_ = nullptr; break; } break; } case metric_kind_t::cos_k: { switch (scalar_kind_) { - case scalar_kind_t::f32_k: - raw_ptr_ = (punned_ptr_t)&equidimensional_>; - break; - case scalar_kind_t::f16_k: - raw_ptr_ = (punned_ptr_t)&equidimensional_>; - break; - case 
scalar_kind_t::i8_k: - raw_ptr_ = (punned_ptr_t)&equidimensional_>; - break; - case scalar_kind_t::f64_k: - raw_ptr_ = (punned_ptr_t)&equidimensional_>; - break; - default: - raw_ptr_ = nullptr; - break; + case scalar_kind_t::f32_k: raw_ptr_ = (punned_ptr_t)&equidimensional_>; break; + case scalar_kind_t::f16_k: raw_ptr_ = (punned_ptr_t)&equidimensional_>; break; + case scalar_kind_t::i8_k: raw_ptr_ = (punned_ptr_t)&equidimensional_>; break; + case scalar_kind_t::f64_k: raw_ptr_ = (punned_ptr_t)&equidimensional_>; break; + default: raw_ptr_ = nullptr; break; } break; } case metric_kind_t::l2sq_k: { switch (scalar_kind_) { - case scalar_kind_t::f32_k: - raw_ptr_ = (punned_ptr_t)&equidimensional_>; - break; - case scalar_kind_t::f16_k: - raw_ptr_ = (punned_ptr_t)&equidimensional_>; - break; - case scalar_kind_t::i8_k: - raw_ptr_ = (punned_ptr_t)&equidimensional_>; - break; - case scalar_kind_t::f64_k: - raw_ptr_ = (punned_ptr_t)&equidimensional_>; - break; - default: - raw_ptr_ = nullptr; - break; + case scalar_kind_t::f32_k: raw_ptr_ = (punned_ptr_t)&equidimensional_>; break; + case scalar_kind_t::f16_k: raw_ptr_ = (punned_ptr_t)&equidimensional_>; break; + case scalar_kind_t::i8_k: raw_ptr_ = (punned_ptr_t)&equidimensional_>; break; + case scalar_kind_t::f64_k: raw_ptr_ = (punned_ptr_t)&equidimensional_>; break; + default: raw_ptr_ = nullptr; break; } break; } case metric_kind_t::pearson_k: { switch (scalar_kind_) { - case scalar_kind_t::i8_k: - raw_ptr_ = (punned_ptr_t)&equidimensional_>; - break; + case scalar_kind_t::i8_k: raw_ptr_ = (punned_ptr_t)&equidimensional_>; break; case scalar_kind_t::f16_k: raw_ptr_ = (punned_ptr_t)&equidimensional_>; break; - case scalar_kind_t::f32_k: - raw_ptr_ = (punned_ptr_t)&equidimensional_>; - break; - case scalar_kind_t::f64_k: - raw_ptr_ = (punned_ptr_t)&equidimensional_>; - break; - default: - raw_ptr_ = nullptr; - break; + case scalar_kind_t::f32_k: raw_ptr_ = (punned_ptr_t)&equidimensional_>; break; + case scalar_kind_t::f64_k: raw_ptr_ = (punned_ptr_t)&equidimensional_>; break; + default: raw_ptr_ = nullptr; break; } break; } @@ -1726,15 +1504,9 @@ class metric_punned_t { case scalar_kind_t::f16_k: raw_ptr_ = (punned_ptr_t)&equidimensional_>; break; - case scalar_kind_t::f32_k: - raw_ptr_ = (punned_ptr_t)&equidimensional_>; - break; - case scalar_kind_t::f64_k: - raw_ptr_ = (punned_ptr_t)&equidimensional_>; - break; - default: - raw_ptr_ = nullptr; - break; + case scalar_kind_t::f32_k: raw_ptr_ = (punned_ptr_t)&equidimensional_>; break; + case scalar_kind_t::f64_k: raw_ptr_ = (punned_ptr_t)&equidimensional_>; break; + default: raw_ptr_ = nullptr; break; } break; } @@ -1743,30 +1515,17 @@ class metric_punned_t { case scalar_kind_t::f16_k: raw_ptr_ = (punned_ptr_t)&equidimensional_>; break; - case scalar_kind_t::f32_k: - raw_ptr_ = (punned_ptr_t)&equidimensional_>; - break; - case scalar_kind_t::f64_k: - raw_ptr_ = (punned_ptr_t)&equidimensional_>; - break; - default: - raw_ptr_ = nullptr; - break; + case scalar_kind_t::f32_k: raw_ptr_ = (punned_ptr_t)&equidimensional_>; break; + case scalar_kind_t::f64_k: raw_ptr_ = (punned_ptr_t)&equidimensional_>; break; + default: raw_ptr_ = nullptr; break; } break; } case metric_kind_t::jaccard_k: // Equivalent to Tanimoto - case metric_kind_t::tanimoto_k: - raw_ptr_ = (punned_ptr_t)&equidimensional_>; - break; - case metric_kind_t::hamming_k: - raw_ptr_ = (punned_ptr_t)&equidimensional_>; - break; - case metric_kind_t::sorensen_k: - raw_ptr_ = (punned_ptr_t)&equidimensional_>; - break; - default: - return; + 
case metric_kind_t::tanimoto_k: raw_ptr_ = (punned_ptr_t)&equidimensional_>; break; + case metric_kind_t::hamming_k: raw_ptr_ = (punned_ptr_t)&equidimensional_>; break; + case metric_kind_t::sorensen_k: raw_ptr_ = (punned_ptr_t)&equidimensional_>; break; + default: return; } } @@ -1776,7 +1535,7 @@ class metric_punned_t { punned_arg_t a_dimensions, punned_arg_t b_dimensions) noexcept { using scalar_t = typename typed_at::scalar_t; (void)b_dimensions; - return typed_at {}((scalar_t const *)a, (scalar_t const *)b, a_dimensions); + return typed_at{}((scalar_t const*)a, (scalar_t const*)b, a_dimensions); } }; @@ -1787,41 +1546,29 @@ template // class vectors_view_gt { using scalar_t = scalar_at; - scalar_t const *begin_ {}; - std::size_t dimensions_ {}; - std::size_t count_ {}; - std::size_t stride_bytes_ {}; + scalar_t const* begin_{}; + std::size_t dimensions_{}; + std::size_t count_{}; + std::size_t stride_bytes_{}; public: vectors_view_gt() noexcept = default; - vectors_view_gt(vectors_view_gt const &) noexcept = default; - vectors_view_gt &operator=(vectors_view_gt const &) noexcept = default; + vectors_view_gt(vectors_view_gt const&) noexcept = default; + vectors_view_gt& operator=(vectors_view_gt const&) noexcept = default; - vectors_view_gt(scalar_t const *begin, std::size_t dimensions, std::size_t count = 1) noexcept - : vectors_view_gt(begin, dimensions, count, dimensions * sizeof(scalar_at)) { - } + vectors_view_gt(scalar_t const* begin, std::size_t dimensions, std::size_t count = 1) noexcept + : vectors_view_gt(begin, dimensions, count, dimensions * sizeof(scalar_at)) {} - vectors_view_gt(scalar_t const *begin, std::size_t dimensions, std::size_t count, std::size_t stride_bytes) noexcept - : begin_(begin), dimensions_(dimensions), count_(count), stride_bytes_(stride_bytes) { - } + vectors_view_gt(scalar_t const* begin, std::size_t dimensions, std::size_t count, std::size_t stride_bytes) noexcept + : begin_(begin), dimensions_(dimensions), count_(count), stride_bytes_(stride_bytes) {} - explicit operator bool() const noexcept { - return begin_; - } - std::size_t size() const noexcept { - return count_; - } - std::size_t dimensions() const noexcept { - return dimensions_; - } - std::size_t stride() const noexcept { - return stride_bytes_; - } - scalar_t const *data() const noexcept { - return begin_; - } - scalar_t const *at(std::size_t i) const noexcept { - return reinterpret_cast(reinterpret_cast(begin_) + i * stride_bytes_); + explicit operator bool() const noexcept { return begin_; } + std::size_t size() const noexcept { return count_; } + std::size_t dimensions() const noexcept { return dimensions_; } + std::size_t stride() const noexcept { return stride_bytes_; } + scalar_t const* data() const noexcept { return begin_; } + scalar_t const* at(std::size_t i) const noexcept { + return reinterpret_cast(reinterpret_cast(begin_) + i * stride_bytes_); } }; @@ -1853,21 +1600,21 @@ class exact_search_t { template exact_search_results_t operator()( // vectors_view_gt dataset, vectors_view_gt queries, // - std::size_t wanted, metric_punned_t const &metric, // - executor_at &&executor = executor_at {}, progress_at &&progress = progress_at {}) { - return operator()( // - metric, // - reinterpret_cast(dataset.data()), dataset.size(), dataset.stride(), // - reinterpret_cast(queries.data()), queries.size(), queries.stride(), // + std::size_t wanted, metric_punned_t const& metric, // + executor_at&& executor = executor_at{}, progress_at&& progress = progress_at{}) { + return operator()( // + 
+            metric,                                                                             //
+            reinterpret_cast<byte_t const*>(dataset.data()), dataset.size(), dataset.stride(), //
+            reinterpret_cast<byte_t const*>(queries.data()), queries.size(), queries.stride(), //
             wanted, executor, progress);
     }

     template <typename executor_at = dummy_executor_t, typename progress_at = dummy_progress_t>
     exact_search_results_t operator()( //
-        byte_t const *dataset_data, std::size_t dataset_count, std::size_t dataset_stride, //
-        byte_t const *queries_data, std::size_t queries_count, std::size_t queries_stride, //
-        std::size_t wanted, metric_punned_t const &metric, executor_at &&executor = executor_at {},
-        progress_at &&progress = progress_at {}) {
+        byte_t const* dataset_data, std::size_t dataset_count, std::size_t dataset_stride, //
+        byte_t const* queries_data, std::size_t queries_count, std::size_t queries_stride, //
+        std::size_t wanted, metric_punned_t const& metric, executor_at&& executor = executor_at{},
+        progress_at&& progress = progress_at{}) {

         // Allocate temporary memory to store the distance matrix.
         // The previous version didn't need temporary memory, but its performance was much lower.
@@ -1879,15 +1626,15 @@ class exact_search_t {
         if (keys_and_distances.size() < tasks_count * 2)
             return {};

-        exact_offset_and_distance_t *keys_and_distances_per_dataset = keys_and_distances.data();
-        exact_offset_and_distance_t *keys_and_distances_per_query = keys_and_distances_per_dataset + tasks_count;
+        exact_offset_and_distance_t* keys_and_distances_per_dataset = keys_and_distances.data();
+        exact_offset_and_distance_t* keys_and_distances_per_query = keys_and_distances_per_dataset + tasks_count;

         // §1. Compute distances in a data-parallel fashion
-        std::atomic<std::size_t> processed {0};
+        std::atomic<std::size_t> processed{0};
         executor.dynamic(dataset_count, [&](std::size_t thread_idx, std::size_t dataset_idx) {
-            byte_t const *dataset = dataset_data + dataset_idx * dataset_stride;
+            byte_t const* dataset = dataset_data + dataset_idx * dataset_stride;
             for (std::size_t query_idx = 0; query_idx != queries_count; ++query_idx) {
-                byte_t const *query = queries_data + query_idx * queries_stride;
+                byte_t const* query = queries_data + query_idx * queries_stride;
                 auto distance = metric(dataset, query);
                 std::size_t task_idx = queries_count * dataset_idx + query_idx;
                 keys_and_distances_per_dataset[task_idx].offset = static_cast<std::uint32_t>(dataset_idx);
@@ -1956,47 +1703,43 @@ class flat_hash_multi_set_gt {
     using equals_t = equals_at;
     using allocator_t = allocator_at;

-    static constexpr std::size_t slots_per_bucket() {
-        return 64;
-    }
+    static constexpr std::size_t slots_per_bucket() { return 64; }
     static constexpr std::size_t bytes_per_bucket() {
         return slots_per_bucket() * sizeof(element_t) + sizeof(bucket_header_t);
     }

   private:
     struct bucket_header_t {
-        std::uint64_t populated {};
-        std::uint64_t deleted {};
+        std::uint64_t populated{};
+        std::uint64_t deleted{};
     };
-    char *data_ = nullptr;
+    char* data_ = nullptr;
     std::size_t buckets_ = 0;
     std::size_t populated_slots_ = 0;
     /// @brief Number of slots
     std::size_t capacity_slots_ = 0;

     struct slot_ref_t {
-        bucket_header_t &header;
+        bucket_header_t& header;
         std::uint64_t mask;
-        element_t &element;
+        element_t& element;
     };

-    slot_ref_t slot_ref(char *data, std::size_t slot_index) const noexcept {
+    slot_ref_t slot_ref(char* data, std::size_t slot_index) const noexcept {
         std::size_t bucket_index = slot_index / slots_per_bucket();
         std::size_t in_bucket_index = slot_index % slots_per_bucket();
         auto bucket_pointer = data + bytes_per_bucket() * bucket_index;
         auto slot_pointer = bucket_pointer + sizeof(bucket_header_t) + sizeof(element_t) * in_bucket_index;
         return {
*reinterpret_cast(bucket_pointer), + *reinterpret_cast(bucket_pointer), static_cast(1ull) << in_bucket_index, - *reinterpret_cast(slot_pointer), + *reinterpret_cast(slot_pointer), }; } - slot_ref_t slot_ref(std::size_t slot_index) const noexcept { - return slot_ref(data_, slot_index); - } + slot_ref_t slot_ref(std::size_t slot_index) const noexcept { return slot_ref(data_, slot_index); } - bool populate_slot(slot_ref_t slot, element_t const &new_element) { + bool populate_slot(slot_ref_t slot, element_t const& new_element) { if (slot.header.populated & slot.mask) { slot.element = new_element; slot.header.deleted &= ~slot.mask; @@ -2009,20 +1752,13 @@ class flat_hash_multi_set_gt { } public: - std::size_t size() const noexcept { - return populated_slots_; - } - std::size_t capacity() const noexcept { - return capacity_slots_; - } + std::size_t size() const noexcept { return populated_slots_; } + std::size_t capacity() const noexcept { return capacity_slots_; } - flat_hash_multi_set_gt() noexcept { - } - ~flat_hash_multi_set_gt() noexcept { - reset(); - } + flat_hash_multi_set_gt() noexcept {} + ~flat_hash_multi_set_gt() noexcept { reset(); } - flat_hash_multi_set_gt(flat_hash_multi_set_gt const &other) { + flat_hash_multi_set_gt(flat_hash_multi_set_gt const& other) { // On Windows allocating a zero-size array would fail if (!other.buckets_) { @@ -2031,7 +1767,7 @@ class flat_hash_multi_set_gt { } // Allocate new memory - data_ = (char *)allocator_t {}.allocate(other.buckets_ * bytes_per_bucket()); + data_ = (char*)allocator_t{}.allocate(other.buckets_ * bytes_per_bucket()); if (!data_) throw std::bad_alloc(); @@ -2053,7 +1789,7 @@ class flat_hash_multi_set_gt { } } - flat_hash_multi_set_gt &operator=(flat_hash_multi_set_gt const &other) { + flat_hash_multi_set_gt& operator=(flat_hash_multi_set_gt const& other) { // On Windows allocating a zero-size array would fail if (!other.buckets_) { @@ -2068,10 +1804,10 @@ class flat_hash_multi_set_gt { // Clear existing data clear(); if (data_) - allocator_t {}.deallocate(data_, buckets_ * bytes_per_bucket()); + allocator_t{}.deallocate(data_, buckets_ * bytes_per_bucket()); // Allocate new memory - data_ = (char *)allocator_t {}.allocate(other.buckets_ * bytes_per_bucket()); + data_ = (char*)allocator_t{}.allocate(other.buckets_ * bytes_per_bucket()); if (!data_) throw std::bad_alloc(); @@ -2104,14 +1840,15 @@ class flat_hash_multi_set_gt { } // Reset populated slots count - std::memset(data_, 0, buckets_ * bytes_per_bucket()); + if (data_) + std::memset(data_, 0, buckets_ * bytes_per_bucket()); populated_slots_ = 0; } void reset() noexcept { clear(); // Clear all elements if (data_) - allocator_t {}.deallocate(data_, buckets_ * bytes_per_bucket()); + allocator_t{}.deallocate(data_, buckets_ * bytes_per_bucket()); buckets_ = 0; populated_slots_ = 0; capacity_slots_ = 0; @@ -2128,7 +1865,7 @@ class flat_hash_multi_set_gt { std::size_t new_bytes = new_buckets * bytes_per_bucket(); // Allocate new memory - char *new_data = (char *)allocator_t {}.allocate(new_bytes); + char* new_data = (char*)allocator_t{}.allocate(new_bytes); if (!new_data) return false; @@ -2160,7 +1897,7 @@ class flat_hash_multi_set_gt { // Deallocate old data and update pointers and sizes if (data_) - allocator_t {}.deallocate(data_, buckets_ * bytes_per_bucket()); + allocator_t{}.deallocate(data_, buckets_ * bytes_per_bucket()); data_ = new_data; buckets_ = new_buckets; capacity_slots_ = new_slots; @@ -2168,22 +1905,20 @@ class flat_hash_multi_set_gt { return true; } - template - 
class equal_iterator_gt { + template class equal_iterator_gt { public: using iterator_category = std::forward_iterator_tag; using value_type = element_t; using difference_type = std::ptrdiff_t; - using pointer = element_t *; - using reference = element_t &; + using pointer = element_t*; + using reference = element_t&; - equal_iterator_gt(std::size_t index, flat_hash_multi_set_gt *parent, query_at const &query, - equals_t const &equals) - : index_(index), parent_(parent), query_(query), equals_(equals) { - } + equal_iterator_gt(std::size_t index, flat_hash_multi_set_gt* parent, query_at const& query, + equals_t const& equals) + : index_(index), parent_(parent), query_(query), equals_(equals) {} // Pre-increment - equal_iterator_gt &operator++() { + equal_iterator_gt& operator++() { do { index_ = (index_ + 1) & (parent_->capacity_slots_ - 1); } while (!equals_(parent_->slot_ref(index_).element, query_) && @@ -2197,38 +1932,32 @@ class flat_hash_multi_set_gt { return temp; } - reference operator*() { - return parent_->slot_ref(index_).element; - } - pointer operator->() { - return &parent_->slot_ref(index_).element; - } - bool operator!=(equal_iterator_gt const &other) const { - return !(*this == other); - } - bool operator==(equal_iterator_gt const &other) const { + reference operator*() { return parent_->slot_ref(index_).element; } + pointer operator->() { return &parent_->slot_ref(index_).element; } + bool operator!=(equal_iterator_gt const& other) const { return !(*this == other); } + bool operator==(equal_iterator_gt const& other) const { return index_ == other.index_ && parent_ == other.parent_; } private: std::size_t index_; - flat_hash_multi_set_gt *parent_; + flat_hash_multi_set_gt* parent_; query_at query_; // Store the query object equals_t equals_; // Store the equals functor }; /** - * @brief Returns an iterator range of all elements matching the given query. - * - * Technically, the second iterator points to the first empty slot after a - * range of equal values and non-equal values with similar hashes. + * @brief Returns an iterator range of all elements matching the given query. + * + * Technically, the second iterator points to the first empty slot after a + * range of equal values and non-equal values with similar hashes. 
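 *
 * A minimal usage sketch (key type and values hypothetical; assumes the
 * default hash and equality template arguments):
 *
 *     flat_hash_multi_set_gt<std::uint64_t> ids;
 *     ids.try_reserve(64);
 *     ids.try_emplace(42);
 *     ids.try_emplace(42);
 *     auto range = ids.equal_range(42);
 *     for (auto it = range.first; it != range.second; ++it)
 *         consume(*it); // `consume` is a stand-in for any per-element action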
*/ template std::pair, equal_iterator_gt> - equal_range(query_at const &query) const noexcept { + equal_range(query_at const& query) const noexcept { equals_t equals; - auto this_ptr = const_cast(this); + auto this_ptr = const_cast(this); auto end = equal_iterator_gt(capacity_slots_, this_ptr, query, equals); if (!capacity_slots_) return {end, end}; @@ -2272,8 +2001,7 @@ class flat_hash_multi_set_gt { equal_iterator_gt(first_empty_index, this_ptr, query, equals)}; } - template - bool pop_first(similar_at &&query, element_t &popped_value) noexcept { + template bool pop_first(similar_at&& query, element_t& popped_value) noexcept { if (!capacity_slots_) return false; @@ -2307,8 +2035,7 @@ class flat_hash_multi_set_gt { return false; // No match found } - template - std::size_t erase(similar_at &&query) noexcept { + template std::size_t erase(similar_at&& query) noexcept { if (!capacity_slots_) return 0; @@ -2342,8 +2069,7 @@ class flat_hash_multi_set_gt { return count; // Return the number of elements removed } - template - element_t const *find(similar_at &&query) const noexcept { + template element_t const* find(similar_at&& query) const noexcept { if (!capacity_slots_) return nullptr; @@ -2372,15 +2098,12 @@ class flat_hash_multi_set_gt { return nullptr; // No match found } - element_t const *end() const noexcept { - return nullptr; - } + element_t const* end() const noexcept { return nullptr; } - template - void for_each(func_at &&func) const { + template void for_each(func_at&& func) const { for (std::size_t bucket_index = 0; bucket_index < buckets_; ++bucket_index) { auto bucket_pointer = data_ + bytes_per_bucket() * bucket_index; - bucket_header_t &header = *reinterpret_cast(bucket_pointer); + bucket_header_t& header = *reinterpret_cast(bucket_pointer); std::uint64_t populated = header.populated; std::uint64_t deleted = header.deleted; @@ -2391,15 +2114,14 @@ class flat_hash_multi_set_gt { // Check if the slot is populated and not deleted if ((populated & ~deleted) & mask) { auto slot_pointer = bucket_pointer + sizeof(bucket_header_t) + sizeof(element_t) * in_bucket_index; - element_t const &element = *reinterpret_cast(slot_pointer); + element_t const& element = *reinterpret_cast(slot_pointer); func(element); } } } } - template - std::size_t count(similar_at &&query) const noexcept { + template std::size_t count(similar_at&& query) const noexcept { if (!capacity_slots_) return 0; @@ -2429,8 +2151,7 @@ class flat_hash_multi_set_gt { return count; } - template - bool contains(similar_at &&query) const noexcept { + template bool contains(similar_at&& query) const noexcept { if (!capacity_slots_) return false; @@ -2463,7 +2184,7 @@ class flat_hash_multi_set_gt { throw std::bad_alloc(); } - bool try_emplace(element_t const &element) noexcept { + bool try_emplace(element_t const& element) noexcept { // Check if we need to resize if (populated_slots_ * 3u >= capacity_slots_ * 2u) if (!try_reserve(populated_slots_ + 1))