diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index d48d305..7d27b72 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -17,11 +17,15 @@ jobs: matrix: os: [ubuntu-latest] feature: - - default - shared_mem - large_data - - shared_mem,large_data + - quantized + - quantized,qg_optim + - large_data,shared_mem + - large_data,quantized - static + - static,quantized + - static,quantized,qg_optim - static,shared_mem,large_data steps: - uses: actions/checkout@v3 diff --git a/Cargo.toml b/Cargo.toml index 90e2f1d..52d02c9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ngt" -version = "0.4.5" +version = "0.5.0" authors = ["Romain Leroux "] edition = "2021" description = "Rust wrappers for NGT nearest neighbor search." @@ -11,17 +11,23 @@ license = "Apache-2.0" readme = "README.md" [dependencies] -ngt-sys = { path = "ngt-sys", version = "1.14.8-static" } -num_enum = "0.5" -openmp-sys = { version="1.2.3", features=["static"] } +half = "2" +ngt-sys = { path = "ngt-sys", version = "2.1.2" } +num_enum = "0.7" scopeguard = "1" [dev-dependencies] +rand = "0.8" rayon = "1" tempfile = "3" [features] -default = [] static = ["ngt-sys/static"] shared_mem = ["ngt-sys/shared_mem"] large_data = ["ngt-sys/large_data"] +quantized = ["ngt-sys/quantized"] +qg_optim = ["quantized", "ngt-sys/qg_optim"] + +[package.metadata.docs.rs] +features = ["quantized"] +rustdoc-args = ["--cfg", "docsrs"] diff --git a/README.md b/README.md index 31e9cc7..3503d44 100644 --- a/README.md +++ b/README.md @@ -1,79 +1,55 @@ -# ngt-rs   [![Latest Version]][crates.io] [![Latest Doc]][docs.rs] +# ngt-rs -[Latest Version]: https://img.shields.io/crates/v/ngt.svg -[crates.io]: https://crates.io/crates/ngt -[Latest Doc]: https://docs.rs/ngt/badge.svg -[docs.rs]: https://docs.rs/ngt +[![crate]][crate-ngt] [![doc]][doc-ngt] -Rust wrappers for [NGT][], which provides high-speed approximate nearest neighbor -searches against a large volume of data. - -Building NGT requires `CMake`. By default `ngt-rs` will be built dynamically, which -means that you'll need to make the build artifact `libngt.so` available to your final -binary. You'll also need to have `OpenMP` installed on the system where it will run. If -you want to build `ngt-rs` statically, then use the `static` Cargo feature, note that in -this case `OpenMP` will be disabled when building NGT. - -Furthermore, NGT's shared memory and large dataset features are available through Cargo -features `shared_mem` and `large_data` respectively. - -## Usage - -Defining the properties of a new index: +[crate]: https://img.shields.io/crates/v/ngt.svg +[crate-ngt]: https://crates.io/crates/ngt +[doc]: https://docs.rs/ngt/badge.svg +[doc-ngt]: https://docs.rs/ngt -```rust -use ngt::{Properties, DistanceType, ObjectType}; - -// Defaut properties with vectors of dimension 3 -let prop = Properties::dimension(3)?; - -// Or customize values (here are the defaults) -let prop = Properties::dimension(3)? - .creation_edge_size(10)? - .search_edge_size(40)? - .object_type(ObjectType::Float)? - .distance_type(DistanceType::L2)?; -``` - -Creating/Opening an index and using it: - -```rust -use ngt::{Index, Properties, EPSILON}; +Rust wrappers for [NGT][], which provides high-speed approximate nearest neighbor +searches against a large volume of data in high dimensional vector data space (several +ten to several thousand dimensions). The vector data can be `f32`, `u8`, or [f16][]. -// Create a new index -let prop = Properties::dimension(3)?; -let index = Index::create("target/path/to/index/dir", prop)?; +This crate provides the following indexes: +* [`NgtIndex`][index-ngt]: Graph and tree based index[^1] +* [`QgIndex`][index-qg]: Quantized graph based index[^2] +* [`QbgIndex`][index-qbg]: Quantized blob graph based index -// Open an existing index -let mut index = Index::open("target/path/to/index/dir")?; +Both quantized indexes are available through the `quantized` Cargo feature. Note that +they rely on `BLAS` and `LAPACK` which thus have to be installed locally. Furthermore, +`QgIndex` performances can be [improved][qg-optim] by using the `qg_optim` Cargo +feature. -// Insert two vectors and get their id -let vec1 = vec![1.0, 2.0, 3.0]; -let vec2 = vec![4.0, 5.0, 6.0]; -let id1 = index.insert(vec1)?; -let id2 = index.insert(vec2)?; +The `NgtIndex` default implementation is an ANNG. It can be optimized[^3] or converted +to an ONNG through the [`optim`][ngt-optim] module. -// Actually build the index (not yet persisted on disk) -// This is required in order to be able to search vectors -index.build(2)?; +By default `ngt-rs` will be built dynamically, which requires `CMake` to build NGT. This +means that you'll have to make the build artifact `libngt.so` available to your final +binary (see an example in the [CI][ngt-ci]). However the `static` feature will build and +link NGT statically. Note that `OpenMP` will also be linked statically. If the +`quantized` feature is used, then `BLAS` and `LAPACK` libraries will also be linked +statically. -// Perform a vector search (with 1 result) -let res = index.search(&vec![1.1, 2.1, 3.1], 1, EPSILON)?; -assert_eq!(res[0].id, id1); -assert_eq!(index.get_vec(id1)?, vec![1.0, 2.0, 3.0]); +NGT's [shared memory][ngt-sharedmem] and [large dataset][ngt-largedata] features are +available through the Cargo features `shared_mem` and `large_data` respectively. -// Remove a vector and check that it is not present anymore -index.remove(id1)?; -let res = index.get_vec(id1); -assert!(matches!(res, Result::Err(_))); +[^1]: [Graph and tree based method explanation][ngt-desc] -// Verify that now our search result is different -let res = index.search(&vec![1.1, 2.1, 3.1], 1, EPSILON)?; -assert_eq!(res[0].id, id2); -assert_eq!(index.get_vec(id2)?, vec![4.0, 5.0, 6.0]); +[^2]: [Quantized graph based method explanation][qg-desc] -// Persist index on disk -index.persist()?; -``` +[^3]: [NGT index optimizations in Python][ngt-optim-py] [ngt]: https://github.com/yahoojapan/NGT +[ngt-desc]: https://opensource.com/article/19/10/ngt-open-source-library +[ngt-sharedmem]: https://github.com/yahoojapan/NGT#shared-memory-use +[ngt-largedata]: https://github.com/yahoojapan/NGT#large-scale-data-use +[ngt-ci]: https://github.com/lerouxrgd/ngt-rs/blob/master/.github/workflows/ci.yaml +[ngt-optim]: https://docs.rs/ngt/latest/ngt/optim/index.html +[ngt-optim-py]: https://github.com/yahoojapan/NGT/wiki/Optimization-Examples-Using-Python +[qg-desc]: https://medium.com/@masajiro.iwasaki/fusion-of-graph-based-indexing-and-product-quantization-for-ann-search-7d1f0336d0d0 +[qg-optim]: https://github.com/yahoojapan/NGT#build-parameters-1 +[f16]: https://docs.rs/half/latest/half/struct.f16.html +[index-ngt]: https://docs.rs/ngt/latest/ngt/#usage +[index-qg]: https://docs.rs/ngt/latest/ngt/qg/ +[index-qbg]: https://docs.rs/ngt/latest/ngt/qgb/ diff --git a/ngt-sys/Cargo.toml b/ngt-sys/Cargo.toml index 694c094..3a628a5 100644 --- a/ngt-sys/Cargo.toml +++ b/ngt-sys/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ngt-sys" -version = "1.14.8-static" +version = "2.1.2" authors = ["Romain Leroux "] edition = "2021" links = "ngt" @@ -18,4 +18,6 @@ cpp_build = { version = "0.5", optional = true } [features] static = ["dep:cpp_build"] shared_mem = [] -large_data = [] \ No newline at end of file +large_data = [] +quantized = [] +qg_optim = [] diff --git a/ngt-sys/NGT b/ngt-sys/NGT index b843128..eaee806 160000 --- a/ngt-sys/NGT +++ b/ngt-sys/NGT @@ -1 +1 @@ -Subproject commit b84312870d496d0f4e3ead449bc6424545d9f896 +Subproject commit eaee8063c8bc5145670985f89341f94f9f6d5349 diff --git a/ngt-sys/build.rs b/ngt-sys/build.rs index 8e83626..0387d8f 100644 --- a/ngt-sys/build.rs +++ b/ngt-sys/build.rs @@ -5,35 +5,54 @@ fn main() { let out_dir = env::var("OUT_DIR").unwrap(); let mut config = cmake::Config::new("NGT"); - if env::var("CARGO_FEATURE_SHARED_MEM").is_ok() { config.define("NGT_SHARED_MEMORY_ALLOCATOR", "ON"); } - if env::var("CARGO_FEATURE_LARGE_DATA").is_ok() { config.define("NGT_LARGE_DATASET", "ON"); } - + if env::var("CARGO_FEATURE_QUANTIZED").is_err() { + config.define("NGT_QBG_DISABLED", "ON"); + } else { + config.define("CMAKE_BUILD_TYPE", "Release"); + if env::var("CARGO_FEATURE_QG_OPTIM").is_ok() { + config.define("NGTQG_NO_ROTATION", "ON"); + config.define("NGTQG_ZERO_GLOBAL", "ON"); + } + } let dst = config.build(); - #[cfg(feature = "static")] - cpp_build::Config::new() - .include(format!("{}/lib", out_dir)) - .build("src/lib.rs"); - println!("cargo:rustc-link-search=native={}/lib", dst.display()); - #[cfg(feature = "static")] - println!("cargo:rustc-link-lib=static=ngt"); #[cfg(not(feature = "static"))] - println!("cargo:rustc-link-lib=dylib=ngt"); + { + println!("cargo:rustc-link-lib=dylib=ngt"); + } + #[cfg(feature = "static")] + { + cpp_build::Config::new() + .include(format!("{}/lib", out_dir)) + .build("src/lib.rs"); + println!("cargo:rustc-link-lib=static=ngt"); + println!("cargo:rustc-link-lib=gomp"); + + if env::var("CARGO_FEATURE_QUANTIZED").is_ok() { + println!("cargo:rustc-link-lib=blas"); + println!("cargo:rustc-link-lib=lapack"); + } + } + + let capi_header = if cfg!(feature = "quantized") { + format!("{}/include/NGT/NGTQ/Capi.h", dst.display()) + } else { + format!("{}/include/NGT/Capi.h", dst.display()) + }; + let out_path = PathBuf::from(out_dir); let bindings = bindgen::Builder::default() .clang_arg(format!("-I{}/include", dst.display())) - .header(format!("{}/include/NGT/NGTQ/Capi.h", dst.display())) + .header(capi_header) .generate() .expect("Unable to generate bindings"); - - let out_path = PathBuf::from(out_dir); bindings .write_to_file(out_path.join("bindings.rs")) .expect("Couldn't write bindings"); diff --git a/src/error.rs b/src/error.rs index 33f83ee..b648d1f 100644 --- a/src/error.rs +++ b/src/error.rs @@ -3,8 +3,6 @@ use std::fmt; use ngt_sys as sys; -use crate::properties::{DistanceType, ObjectType}; - pub type Result = std::result::Result; #[derive(Debug)] @@ -25,6 +23,12 @@ pub(crate) fn make_err(err: sys::NGTError) -> Error { Error(err_msg) } +impl From for Error { + fn from(err: String) -> Self { + Self(err) + } +} + impl From for Error { fn from(source: std::io::Error) -> Self { Self(source.to_string()) @@ -43,14 +47,48 @@ impl From for Error { } } -impl From> for Error { - fn from(source: num_enum::TryFromPrimitiveError) -> Self { +impl From for Error { + fn from(source: std::ffi::IntoStringError) -> Self { + Self(source.to_string()) + } +} + +impl From> for Error { + fn from(source: num_enum::TryFromPrimitiveError) -> Self { + Self(source.to_string()) + } +} + +impl From> for Error { + fn from(source: num_enum::TryFromPrimitiveError) -> Self { + Self(source.to_string()) + } +} + +#[cfg(feature = "quantized")] +impl From> for Error { + fn from(source: num_enum::TryFromPrimitiveError) -> Self { + Self(source.to_string()) + } +} + +#[cfg(feature = "quantized")] +impl From> for Error { + fn from(source: num_enum::TryFromPrimitiveError) -> Self { + Self(source.to_string()) + } +} + +#[cfg(feature = "quantized")] +impl From> for Error { + fn from(source: num_enum::TryFromPrimitiveError) -> Self { Self(source.to_string()) } } -impl From> for Error { - fn from(source: num_enum::TryFromPrimitiveError) -> Self { +#[cfg(feature = "quantized")] +impl From> for Error { + fn from(source: num_enum::TryFromPrimitiveError) -> Self { Self(source.to_string()) } } diff --git a/src/lib.rs b/src/lib.rs index 5b5e46a..d3f9b37 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,49 +1,41 @@ -//! Rust wrappers for [NGT][], which provides high-speed approximate nearest neighbor -//! searches against a large volume of data. +#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![doc = include_str!("../README.md")] //! -//! Building NGT requires `CMake`. By default `ngt-rs` will be built dynamically, which -//! means that you'll need to make the build artifact `libngt.so` available to your final -//! binary. You'll also need to have `OpenMP` installed on the system where it will run. If -//! you want to build `ngt-rs` statically, then use the `static` Cargo feature, note that in -//! this case `OpenMP` will be disabled when building NGT. +//! # Usage //! -//! Furthermore, NGT's shared memory and large dataset features are available through Cargo -//! features `shared_mem` and `large_data` respectively. +//! Graph and tree based index (NGT Index) //! -//! ## Usage -//! -//! Defining the properties of a new index: +//! ## Defining the properties of a new NGT index: //! //! ```rust //! # fn main() -> Result<(), ngt::Error> { -//! use ngt::{Properties, DistanceType, ObjectType}; +//! use ngt::{NgtProperties, NgtDistance}; //! //! // Defaut properties with vectors of dimension 3 -//! let prop = Properties::dimension(3)?; +//! let prop = NgtProperties::::dimension(3)?; //! //! // Or customize values (here are the defaults) -//! let prop = Properties::dimension(3)? +//! let prop = NgtProperties::::dimension(3)? //! .creation_edge_size(10)? //! .search_edge_size(40)? -//! .object_type(ObjectType::Float)? -//! .distance_type(DistanceType::L2)?; +//! .distance_type(NgtDistance::L2)?; //! //! # Ok(()) //! # } //! ``` //! -//! Creating/Opening an index and using it: +//! ## Creating/Opening a NGT index and using it: //! //! ```rust //! # fn main() -> Result<(), ngt::Error> { -//! use ngt::{Index, Properties, EPSILON}; +//! use ngt::{NgtIndex, NgtProperties}; //! //! // Create a new index -//! let prop = Properties::dimension(3)?; -//! let index = Index::create("target/path/to/index/dir", prop)?; +//! let prop = NgtProperties::dimension(3)?; +//! let index: NgtIndex = NgtIndex::create("target/path/to/ngt_index/dir", prop)?; //! //! // Open an existing index -//! let mut index = Index::open("target/path/to/index/dir")?; +//! let mut index = NgtIndex::open("target/path/to/ngt_index/dir")?; //! //! // Insert two vectors and get their id //! let vec1 = vec![1.0, 2.0, 3.0]; @@ -51,46 +43,54 @@ //! let id1 = index.insert(vec1)?; //! let id2 = index.insert(vec2)?; //! -//! // Actually build the index (not yet persisted on disk) +//! // Build the index in RAM (not yet persisted on disk) //! // This is required in order to be able to search vectors //! index.build(2)?; //! //! // Perform a vector search (with 1 result) -//! let res = index.search(&vec![1.1, 2.1, 3.1], 1, EPSILON)?; +//! let res = index.search(&vec![1.1, 2.1, 3.1], 1, ngt::EPSILON)?; //! assert_eq!(res[0].id, id1); //! assert_eq!(index.get_vec(id1)?, vec![1.0, 2.0, 3.0]); //! //! // Remove a vector and check that it is not present anymore //! index.remove(id1)?; //! let res = index.get_vec(id1); -//! assert!(matches!(res, Result::Err(_))); +//! assert!(res.is_err()); //! //! // Verify that now our search result is different -//! let res = index.search(&vec![1.1, 2.1, 3.1], 1, EPSILON)?; +//! let res = index.search(&vec![1.1, 2.1, 3.1], 1, ngt::EPSILON)?; //! assert_eq!(res[0].id, id2); //! assert_eq!(index.get_vec(id2)?, vec![4.0, 5.0, 6.0]); //! //! // Persist index on disk //! index.persist()?; //! -//! # std::fs::remove_dir_all("target/path/to/index/dir").unwrap(); +//! # std::fs::remove_dir_all("target/path/to/ngt_index/dir").unwrap(); //! # Ok(()) //! # } //! ``` -//! -//! [ngt]: https://github.com/yahoojapan/NGT -// See: https://gitlab.com/kornelski/openmp-rs#1-adding-rust-dependency -extern crate openmp_sys; +#[cfg(all(feature = "quantized", feature = "shared_mem"))] +compile_error!(r#"only one of ["quantized", "shared_mem"] can be enabled"#); mod error; -mod index; -pub mod optim; -mod properties; +mod ngt; +#[cfg(feature = "quantized")] +pub mod qbg; +#[cfg(feature = "quantized")] +pub mod qg; + +pub type VecId = u32; + +#[derive(Debug, Clone, PartialEq)] +pub struct SearchResult { + pub id: VecId, + pub distance: f32, +} + +pub const EPSILON: f32 = 0.1; -pub use crate::error::Error; -pub use crate::index::{Index, SearchResult, VecId, EPSILON}; -pub use crate::properties::{DistanceType, ObjectType, Properties}; +pub use crate::error::{Error, Result}; +pub use crate::ngt::{optim, NgtDistance, NgtIndex, NgtObject, NgtProperties}; -#[cfg(not(feature = "shared_mem"))] -pub use crate::index::{QGIndex, QGQuantizationParams, QGQuery}; +pub use half; diff --git a/src/index.rs b/src/ngt/index.rs similarity index 57% rename from src/index.rs rename to src/ngt/index.rs index 92655d5..1bbc830 100644 --- a/src/index.rs +++ b/src/ngt/index.rs @@ -9,34 +9,28 @@ use std::ptr; use ngt_sys as sys; use scopeguard::defer; +use super::{NgtObject, NgtObjectType, NgtProperties}; use crate::error::{make_err, Error, Result}; -use crate::properties::{ObjectType, Properties}; - -pub const EPSILON: f32 = 0.1; - -pub type VecId = u32; - -#[derive(Debug, Clone, PartialEq)] -pub struct SearchResult { - pub id: VecId, - pub distance: f32, -} +use crate::{SearchResult, VecId}; #[derive(Debug)] -pub struct Index { +pub struct NgtIndex { pub(crate) path: CString, - pub(crate) prop: Properties, + pub(crate) prop: NgtProperties, pub(crate) index: sys::NGTIndex, ospace: sys::NGTObjectSpace, ebuf: sys::NGTError, } -unsafe impl Send for Index {} -unsafe impl Sync for Index {} +unsafe impl Send for NgtIndex {} +unsafe impl Sync for NgtIndex {} -impl Index { - /// Creates an empty ANNG index with the given [`Properties`](). - pub fn create>(path: P, prop: Properties) -> Result { +impl NgtIndex +where + T: NgtObjectType, +{ + /// Creates an empty ANNG index with the given [`NgtProperties`](). + pub fn create>(path: P, prop: NgtProperties) -> Result { if cfg!(feature = "shared_mem") && path.as_ref().exists() { Err(Error(format!("Path {:?} already exists", path.as_ref())))? } @@ -67,7 +61,7 @@ impl Index { Err(make_err(ebuf))? } - Ok(Index { + Ok(NgtIndex { path, prop, index, @@ -99,9 +93,9 @@ impl Index { Err(make_err(ebuf))? } - let prop = Properties::from(index)?; + let prop = NgtProperties::from(index)?; - Ok(Index { + Ok(NgtIndex { path, prop, index, @@ -113,8 +107,8 @@ impl Index { /// Search the nearest vectors to the specified query vector. /// - /// **The index must have been [`built`](Index::build) beforehand**. - pub fn search(&self, vec: &[f64], res_size: u64, epsilon: f32) -> Result> { + /// **The index must have been [`built`](NgtIndex::build) beforehand**. + pub fn search(&self, vec: &[f32], res_size: u64, epsilon: f32) -> Result> { unsafe { let results = sys::ngt_create_empty_results(self.ebuf); if results.is_null() { @@ -122,9 +116,9 @@ impl Index { } defer! { sys::ngt_destroy_results(results); } - if !sys::ngt_search_index( + if !sys::ngt_search_index_as_float( self.index, - vec.as_ptr() as *mut f64, + vec.as_ptr() as *mut f32, self.prop.dimension, res_size, epsilon, @@ -138,7 +132,7 @@ impl Index { let rsize = sys::ngt_get_result_size(results, self.ebuf); let mut ret = Vec::with_capacity(rsize as usize); - for i in 0..rsize as u32 { + for i in 0..rsize { let d = sys::ngt_get_result(results, i, self.ebuf); if d.id == 0 && d.distance == 0.0 { Err(make_err(self.ebuf))? @@ -156,8 +150,8 @@ impl Index { /// Search linearly the nearest vectors to the specified query vector. /// - /// **The index must have been [`built`](Index::build) beforehand**. - pub fn linear_search(&self, vec: &[f64], res_size: u64) -> Result> { + /// **The index must have been [`built`](NgtIndex::build) beforehand**. + pub fn linear_search(&self, vec: &[f32], res_size: u64) -> Result> { unsafe { let results = sys::ngt_create_empty_results(self.ebuf); if results.is_null() { @@ -165,9 +159,9 @@ impl Index { } defer! { sys::ngt_destroy_results(results); } - if !sys::ngt_linear_search_index( + if !sys::ngt_linear_search_index_as_float( self.index, - vec.as_ptr() as *mut f64, + vec.as_ptr() as *mut f32, self.prop.dimension, res_size, results, @@ -179,7 +173,7 @@ impl Index { let rsize = sys::ngt_get_result_size(results, self.ebuf); let mut ret = Vec::with_capacity(rsize as usize); - for i in 0..rsize as u32 { + for i in 0..rsize { let d = sys::ngt_get_result(results, i, self.ebuf); if d.id == 0 && d.distance == 0.0 { Err(make_err(self.ebuf))? @@ -198,30 +192,42 @@ impl Index { /// Insert the specified vector into the index. However note that it is not /// discoverable yet. /// - /// **The method [`build`](Index::build) must be called after inserting vectors**. - pub fn insert>(&mut self, vec: Vec) -> Result { + /// **The method [`build`](NgtIndex::build) must be called after inserting vectors**. + pub fn insert(&mut self, mut vec: Vec) -> Result { unsafe { - let mut vec = vec.into_iter().map(Into::into).collect::>(); - - let id = sys::ngt_insert_index( - self.index, - vec.as_mut_ptr(), - self.prop.dimension as u32, - self.ebuf, - ); + let id = match self.prop.object_type { + NgtObject::Float => sys::ngt_insert_index_as_float( + self.index, + vec.as_mut_ptr() as *mut f32, + self.prop.dimension as u32, + self.ebuf, + ), + NgtObject::Uint8 => sys::ngt_insert_index_as_uint8( + self.index, + vec.as_mut_ptr() as *mut u8, + self.prop.dimension as u32, + self.ebuf, + ), + NgtObject::Float16 => sys::ngt_insert_index_as_float16( + self.index, + vec.as_mut_ptr() as *mut _, + self.prop.dimension as u32, + self.ebuf, + ), + }; if id == 0 { Err(make_err(self.ebuf))? + } else { + Ok(id) } - - Ok(id) } } /// Insert the multiple vectors into the index. However note that they are not /// discoverable yet. /// - /// **The method [`build`](Index::build) must be called after inserting vectors**. - pub fn insert_batch>(&mut self, batch: Vec>) -> Result<()> { + /// **The method [`build`](NgtIndex::build) must be called after inserting vectors**. + pub fn insert_batch(&mut self, batch: Vec>) -> Result<()> { let batch_size = u32::try_from(batch.len())?; if batch_size > 0 { @@ -237,16 +243,39 @@ impl Index { } unsafe { - let mut batch = batch - .into_iter() - .flatten() - .map(|v| v.into() as f32) - .collect::>(); - - if !sys::ngt_batch_append_index(self.index, batch.as_mut_ptr(), batch_size, self.ebuf) { - Err(make_err(self.ebuf))? + let mut batch = batch.into_iter().flatten().collect::>(); + match self.prop.object_type { + NgtObject::Float => { + if !sys::ngt_batch_append_index( + self.index, + batch.as_mut_ptr() as *mut f32, + batch_size, + self.ebuf, + ) { + Err(make_err(self.ebuf))? + } + } + NgtObject::Uint8 => { + if !sys::ngt_batch_append_index_as_uint8( + self.index, + batch.as_mut_ptr() as *mut u8, + batch_size, + self.ebuf, + ) { + Err(make_err(self.ebuf))? + } + } + NgtObject::Float16 => { + if !sys::ngt_batch_append_index_as_float16( + self.index, + batch.as_mut_ptr() as *mut _, + batch_size, + self.ebuf, + ) { + Err(make_err(self.ebuf))? + } + } } - Ok(()) } } @@ -282,209 +311,76 @@ impl Index { } /// Get the specified vector. - pub fn get_vec(&self, id: VecId) -> Result> { + pub fn get_vec(&self, id: VecId) -> Result> { unsafe { - let results = match self.prop.object_type { - ObjectType::Float => { + match self.prop.object_type { + NgtObject::Float => { let results = sys::ngt_get_object_as_float(self.ospace, id, self.ebuf); if results.is_null() { Err(make_err(self.ebuf))? } let results = Vec::from_raw_parts( - results as *mut f32, + results, self.prop.dimension as usize, self.prop.dimension as usize, ); let results = mem::ManuallyDrop::new(results); - results.iter().copied().collect::>() + let results = results.iter().copied().collect::>(); + Ok(mem::transmute::<_, Vec>(results)) } - ObjectType::Uint8 => { - let results = sys::ngt_get_object_as_integer(self.ospace, id, self.ebuf); + NgtObject::Float16 => { + let results = sys::ngt_get_object(self.ospace, id, self.ebuf); if results.is_null() { Err(make_err(self.ebuf))? } let results = Vec::from_raw_parts( - results as *mut u8, + results as *mut half::f16, self.prop.dimension as usize, self.prop.dimension as usize, ); let results = mem::ManuallyDrop::new(results); - results.iter().map(|byte| *byte as f32).collect::>() + let results = results.iter().copied().collect::>(); + Ok(mem::transmute::<_, Vec>(results)) } - }; - - Ok(results) - } - } - - /// The number of vectors inserted (but not necessarily indexed). - pub fn nb_inserted(&self) -> u32 { - unsafe { sys::ngt_get_number_of_objects(self.index, self.ebuf) } - } - - /// The number of indexed vectors, available after [`build`](Index::build). - pub fn nb_indexed(&self) -> u32 { - unsafe { sys::ngt_get_number_of_indexed_objects(self.index, self.ebuf) } - } -} - -impl Drop for Index { - fn drop(&mut self) { - if !self.index.is_null() { - unsafe { sys::ngt_close_index(self.index) }; - self.index = ptr::null_mut(); - } - if !self.ebuf.is_null() { - unsafe { sys::ngt_destroy_error_object(self.ebuf) }; - self.ebuf = ptr::null_mut(); - } - } -} - -#[cfg(not(feature = "shared_mem"))] -#[derive(Debug)] -pub struct QGIndex { - pub(crate) prop: Properties, - pub(crate) index: sys::NGTQGIndex, - ebuf: sys::NGTError, -} - -#[cfg(not(feature = "shared_mem"))] -impl QGIndex { - pub fn quantize(index: Index, params: QGQuantizationParams) -> Result { - unsafe { - let ebuf = sys::ngt_create_error_object(); - defer! { sys::ngt_destroy_error_object(ebuf); } - - let path = index.path.clone(); - - drop(index); // Close the index - if !sys::ngtqg_quantize(path.as_ptr(), params.into_raw(), ebuf) { - Err(make_err(ebuf))? - } - - QGIndex::open(path.to_str().unwrap()) - } - } - - /// Open the already existing quantized index at the specified path. - pub fn open>(path: P) -> Result { - if !path.as_ref().exists() { - Err(Error(format!("Path {:?} does not exist", path.as_ref())))? - } - - unsafe { - let ebuf = sys::ngt_create_error_object(); - defer! { sys::ngt_destroy_error_object(ebuf); } - - let path = CString::new(path.as_ref().as_os_str().as_bytes())?; - - let index = sys::ngtqg_open_index(path.as_ptr(), ebuf); - if index.is_null() { - Err(make_err(ebuf))? - } - - let prop = Properties::from(index)?; - - Ok(QGIndex { - prop, - index, - ebuf: sys::ngt_create_error_object(), - }) - } - } - - pub fn search(&self, query: QGQuery) -> Result> { - unsafe { - let results = sys::ngt_create_empty_results(self.ebuf); - if results.is_null() { - Err(make_err(self.ebuf))? - } - defer! { sys::ngt_destroy_results(results); } - - if !sys::ngtqg_search_index(self.index, query.into_raw(), results, self.ebuf) { - Err(make_err(self.ebuf))? - } - - let rsize = sys::ngt_get_result_size(results, self.ebuf); - let mut ret = Vec::with_capacity(rsize as usize); - - for i in 0..rsize as u32 { - let d = sys::ngt_get_result(results, i, self.ebuf); - if d.id == 0 && d.distance == 0.0 { - Err(make_err(self.ebuf))? - } else { - ret.push(SearchResult { - id: d.id, - distance: d.distance, - }); - } - } - - Ok(ret) - } - } - - /// Get the specified vector. - pub fn get_vec(&self, id: VecId) -> Result> { - unsafe { - let results = match self.prop.object_type { - ObjectType::Float => { - let ospace = sys::ngt_get_object_space(self.index, self.ebuf); - if ospace.is_null() { - Err(make_err(self.ebuf))? - } - - let results = sys::ngt_get_object_as_float(ospace, id, self.ebuf); + NgtObject::Uint8 => { + let results = sys::ngt_get_object_as_integer(self.ospace, id, self.ebuf); if results.is_null() { Err(make_err(self.ebuf))? } let results = Vec::from_raw_parts( - results as *mut f32, + results, self.prop.dimension as usize, self.prop.dimension as usize, ); let results = mem::ManuallyDrop::new(results); - results.iter().copied().collect::>() + let results = results.iter().copied().collect::>(); + Ok(mem::transmute::<_, Vec>(results)) } - ObjectType::Uint8 => { - let ospace = sys::ngt_get_object_space(self.index, self.ebuf); - if ospace.is_null() { - Err(make_err(self.ebuf))? - } - - let results = sys::ngt_get_object_as_integer(ospace, id, self.ebuf); - if results.is_null() { - Err(make_err(self.ebuf))? - } - - let results = Vec::from_raw_parts( - results as *mut u8, - self.prop.dimension as usize, - self.prop.dimension as usize, - ); - let results = mem::ManuallyDrop::new(results); + } + } + } - results.iter().map(|byte| *byte as f32).collect::>() - } - }; + /// The number of vectors inserted (but not necessarily indexed). + pub fn nb_inserted(&self) -> u32 { + unsafe { sys::ngt_get_number_of_objects(self.index, self.ebuf) } + } - Ok(results) - } + /// The number of indexed vectors, available after [`build`](NgtIndex::build). + pub fn nb_indexed(&self) -> u32 { + unsafe { sys::ngt_get_number_of_indexed_objects(self.index, self.ebuf) } } } -#[cfg(not(feature = "shared_mem"))] -impl Drop for QGIndex { +impl Drop for NgtIndex { fn drop(&mut self) { if !self.index.is_null() { - unsafe { sys::ngtqg_close_index(self.index) }; + unsafe { sys::ngt_close_index(self.index) }; self.index = ptr::null_mut(); } if !self.ebuf.is_null() { @@ -494,99 +390,21 @@ impl Drop for QGIndex { } } -#[cfg(not(feature = "shared_mem"))] -#[derive(Debug, Clone, PartialEq)] -pub struct QGQuantizationParams { - pub dimension_of_subvector: f32, - pub max_number_of_edges: u64, -} - -#[cfg(not(feature = "shared_mem"))] -impl Default for QGQuantizationParams { - fn default() -> Self { - Self { - dimension_of_subvector: 0.0, - max_number_of_edges: 128, - } - } -} - -#[cfg(not(feature = "shared_mem"))] -impl QGQuantizationParams { - unsafe fn into_raw(self) -> sys::NGTQGQuantizationParameters { - sys::NGTQGQuantizationParameters { - dimension_of_subvector: self.dimension_of_subvector, - max_number_of_edges: self.max_number_of_edges, - } - } -} - -#[cfg(not(feature = "shared_mem"))] -#[derive(Debug, Clone, PartialEq)] -pub struct QGQuery<'a> { - query: &'a [f32], - pub size: u64, - pub epsilon: f32, - pub result_expansion: f32, - pub radius: f32, -} - -#[cfg(not(feature = "shared_mem"))] -impl<'a> QGQuery<'a> { - pub fn new(query: &'a [f32]) -> Self { - Self { - query, - size: 20, - epsilon: 0.03, - result_expansion: 3.0, - radius: f32::MAX, - } - } - - pub fn size(mut self, size: u64) -> Self { - self.size = size; - self - } - - pub fn epsilon(mut self, epsilon: f32) -> Self { - self.epsilon = epsilon; - self - } - - pub fn result_expansion(mut self, result_expansion: f32) -> Self { - self.result_expansion = result_expansion; - self - } - - pub fn radius(mut self, radius: f32) -> Self { - self.radius = radius; - self - } - - unsafe fn into_raw(self) -> sys::NGTQGQuery { - sys::NGTQGQuery { - query: self.query.as_ptr() as *mut f32, - size: self.size, - epsilon: self.epsilon, - result_expansion: self.result_expansion, - radius: self.radius, - } - } -} - #[cfg(test)] mod tests { use std::error::Error as StdError; use std::iter; use std::result::Result as StdResult; + use half::f16; use rayon::prelude::*; use tempfile::tempdir; use super::*; + use crate::EPSILON; #[test] - fn test_basics() -> StdResult<(), Box> { + fn test_ngt_f32_basics() -> StdResult<(), Box> { // Get a temporary directory to store the index let dir = tempdir()?; if cfg!(feature = "shared_mem") { @@ -594,8 +412,8 @@ mod tests { } // Create an index for vectors of dimension 3 - let prop = Properties::dimension(3)?; - let mut index = Index::create(dir.path(), prop)?; + let prop = NgtProperties::::dimension(3)?; + let mut index = NgtIndex::create(dir.path(), prop)?; // Insert two vectors and get their id let vec1 = vec![1.0, 2.0, 3.0]; @@ -635,7 +453,7 @@ mod tests { // Persist index on disk, and open it again index.persist()?; - index = Index::open(dir.path())?; + index = NgtIndex::open(dir.path())?; assert!(index.nb_inserted() == 1); assert!(index.nb_indexed() == 1); @@ -653,7 +471,7 @@ mod tests { } #[test] - fn test_batch() -> StdResult<(), Box> { + fn test_ngt_batch() -> StdResult<(), Box> { // Get a temporary directory to store the index let dir = tempdir()?; if cfg!(feature = "shared_mem") { @@ -661,8 +479,8 @@ mod tests { } // Create an index for vectors of dimension 3 - let prop = Properties::dimension(3)?; - let mut index = Index::create(dir.path(), prop)?; + let prop = NgtProperties::::dimension(3)?; + let mut index = NgtIndex::create(dir.path(), prop)?; // Batch insert 2 vectors, build and persist the index index.insert_batch(vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]])?; @@ -678,7 +496,33 @@ mod tests { } #[test] - fn test_multithreaded() -> StdResult<(), Box> { + fn test_ngt_u8() -> StdResult<(), Box> { + // Get a temporary directory to store the index + let dir = tempdir()?; + if cfg!(feature = "shared_mem") { + std::fs::remove_dir(dir.path())?; + } + + // Create an index for vectors of dimension 3 + let prop = NgtProperties::::dimension(3)?; + let mut index = NgtIndex::create(dir.path(), prop)?; + + // Insert 3 vectors, build and persist the index + index.insert_batch(vec![vec![1, 2, 3], vec![4, 5, 6]])?; + index.insert(vec![7, 8, 9])?; + index.build(2)?; + index.persist()?; + + // Verify that the index was built correctly with a vector search + let res = index.search(&vec![1.1, 2.1, 3.1], 1, EPSILON)?; + assert_eq!(1, res[0].id); + + dir.close()?; + Ok(()) + } + + #[test] + fn test_ngt_f16() -> StdResult<(), Box> { // Get a temporary directory to store the index let dir = tempdir()?; if cfg!(feature = "shared_mem") { @@ -686,8 +530,37 @@ mod tests { } // Create an index for vectors of dimension 3 - let prop = Properties::dimension(3)?; - let mut index = Index::create(dir.path(), prop)?; + let prop = NgtProperties::::dimension(3)?; + let mut index = NgtIndex::create(dir.path(), prop)?; + + // Insert 3 vectors, build and persist the index + index.insert_batch(vec![ + vec![1.0, 2.0, 3.0].into_iter().map(f16::from_f32).collect(), + vec![4.0, 5.0, 6.0].into_iter().map(f16::from_f32).collect(), + ])?; + index.insert(vec![7.0, 8.0, 9.0].into_iter().map(f16::from_f32).collect())?; + index.build(2)?; + index.persist()?; + + // Verify that the index was built correctly with a vector search + let res = index.search(&vec![1.1, 2.1, 3.1], 1, EPSILON)?; + assert_eq!(1, res[0].id); + + dir.close()?; + Ok(()) + } + + #[test] + fn test_ngt_multithreaded() -> StdResult<(), Box> { + // Get a temporary directory to store the index + let dir = tempdir()?; + if cfg!(feature = "shared_mem") { + std::fs::remove_dir(dir.path())?; + } + + // Create an index for vectors of dimension 3 + let prop = NgtProperties::::dimension(3)?; + let mut index = NgtIndex::create(dir.path(), prop)?; let vecs = vec![ vec![1.0, 2.0, 3.0], @@ -719,41 +592,4 @@ mod tests { dir.close()?; Ok(()) } - - #[cfg(not(feature = "shared_mem"))] - #[test] - fn test_quantize() -> StdResult<(), Box> { - // Get a temporary directory to store the index - let dir = tempdir()?; - if cfg!(feature = "shared_mem") { - std::fs::remove_dir(dir.path())?; - } - - // Create an index for vectors of dimension 3 - let prop = Properties::dimension(3)?; - let mut index = Index::create(dir.path(), prop)?; - - // Insert two vectors and get their id - let vec1 = vec![1.0, 2.0, 3.0]; - let vec2 = vec![4.0, 5.0, 6.0]; - let id1 = index.insert(vec1.clone())?; - let _id2 = index.insert(vec2.clone())?; - - // Build and persist the index - index.build(1)?; - index.persist()?; - - let params = QGQuantizationParams::default(); - let index = QGIndex::quantize(index, params)?; - - // Perform a vector search (with 1 result) - let vec = vec![1.1, 2.1, 3.1]; - let query = QGQuery::new(&vec).size(2); - let res = index.search(query)?; - assert_eq!(id1, res[0].id); - assert_eq!(vec1, index.get_vec(id1)?); - - dir.close()?; - Ok(()) - } } diff --git a/src/ngt/mod.rs b/src/ngt/mod.rs new file mode 100644 index 0000000..6403bd4 --- /dev/null +++ b/src/ngt/mod.rs @@ -0,0 +1,6 @@ +mod index; +pub mod optim; +mod properties; + +pub use self::index::NgtIndex; +pub use self::properties::{NgtDistance, NgtObject, NgtObjectType, NgtProperties}; diff --git a/src/optim.rs b/src/ngt/optim.rs similarity index 79% rename from src/optim.rs rename to src/ngt/optim.rs index ddde83b..4aa1156 100644 --- a/src/optim.rs +++ b/src/ngt/optim.rs @@ -1,3 +1,7 @@ +#![cfg_attr(feature = "shared_mem", allow(unused_imports))] + +//! Functions aimed at optimizing [`NgtIndex`](NgtIndex) + use std::ffi::CString; use std::os::unix::ffi::OsStrExt; use std::path::Path; @@ -6,19 +10,20 @@ use std::ptr; use ngt_sys as sys; use scopeguard::defer; +use super::NgtObjectType; use crate::error::{make_err, Result}; -use crate::index::Index; +use crate::ngt::index::NgtIndex; /// Optimizes the number of initial edges of an ANNG index. /// /// The default number of initial edges for each node in a default graph (ANNG) is a /// fixed value 10. To optimize this number, follow these steps: -/// 1. [`insert`](Index::insert) vectors in the ANNG index at `index_path`, don't -/// [`build`](Index::build) the index yet. -/// 2. When all vectors are inserted, [`persist`](Index::persist) the index. +/// 1. [`insert`](NgtIndex::insert) vectors in the ANNG index at `index_path`, don't +/// [`build`](NgtIndex::build) the index yet. +/// 2. When all vectors are inserted, [`persist`](NgtIndex::persist) the index. /// 3. Call this function with the same `index_path`. -/// 4. [`open`](Index::open) the index at `index_path` again, and now -/// [`build`](Index::build) it. +/// 4. [`open`](NgtIndex::open) the index at `index_path` again, and now +/// [`build`](NgtIndex::build) it. #[cfg(not(feature = "shared_mem"))] pub fn optimize_anng_edges_number>( index_path: P, @@ -42,10 +47,14 @@ pub fn optimize_anng_edges_number>( /// /// Optimizes the search parameters about the explored edges and memory prefetch for the /// existing indexes. Does not modify the index data structure. -pub fn optimize_anng_search_parameters>(index_path: P) -> Result<()> { +pub fn optimize_anng_search_parameters(index_path: P) -> Result<()> +where + T: NgtObjectType, + P: AsRef, +{ let mut optimizer = GraphOptimizer::new(GraphOptimParams::default())?; optimizer.set_processing_modes(true, true, true)?; - optimizer.adjust_search_coefficients(index_path)?; + optimizer.adjust_search_coefficients::(index_path)?; Ok(()) } @@ -53,9 +62,12 @@ pub fn optimize_anng_search_parameters>(index_path: P) -> Result< /// /// Improves accuracy of neighboring nodes for each node by searching with each /// node. Note that refinement takes a long processing time. An ANNG index can be -/// refined only after it has been [`built`](Index::build). +/// refined only after it has been [`built`](NgtIndex::build). #[cfg(not(feature = "shared_mem"))] -pub fn refine_anng(index: &mut Index, params: AnngRefineParams) -> Result<()> { +pub fn refine_anng( + index: &mut NgtIndex, + params: AnngRefineParams, +) -> Result<()> { unsafe { let ebuf = sys::ngt_create_error_object(); defer! { sys::ngt_destroy_error_object(ebuf); } @@ -84,19 +96,23 @@ pub fn refine_anng(index: &mut Index, params: AnngRefineParams) -> Result<()> { /// [`optimize_anng_edges_number`](optimize_anng_edges_number). /// /// If more performance is needed, a larger `creation_edge_size` can be set through -/// [`Properties`](crate::Properties::creation_edge_size) at ANNG index -/// [`create`](Index::create) time. +/// [`Properties`](crate::NgtProperties::creation_edge_size) at ANNG index +/// [`create`](NgtIndex::create) time. /// /// Important [`GraphOptimParams`](GraphOptimParams) parameters are `nb_outgoing` edges /// and `nb_incoming` edges. The latter can be set to an even higher number than the /// `creation_edge_size` of the original ANNG. -pub fn convert_anng_to_onng>( +pub fn convert_anng_to_onng( index_anng_in: P, index_onng_out: P, params: GraphOptimParams, -) -> Result<()> { +) -> Result<()> +where + T: NgtObjectType, + P: AsRef, +{ let mut optimizer = GraphOptimizer::new(params)?; - optimizer.convert_anng_to_onng(index_anng_in, index_onng_out)?; + optimizer.convert_anng_to_onng::(index_anng_in, index_onng_out)?; Ok(()) } @@ -253,8 +269,12 @@ impl GraphOptimizer { } /// Optimize for the search parameters of an ANNG. - fn adjust_search_coefficients>(&mut self, index_path: P) -> Result<()> { - let _ = Index::open(&index_path)?; + fn adjust_search_coefficients(&mut self, index_path: P) -> Result<()> + where + P: AsRef, + T: NgtObjectType, + { + let _ = NgtIndex::::open(&index_path)?; unsafe { let ebuf = sys::ngt_create_error_object(); @@ -271,12 +291,12 @@ impl GraphOptimizer { } /// Converts the `index_in` ANNG to an ONNG at `index_out`. - fn convert_anng_to_onng>( - &mut self, - index_anng_in: P, - index_onng_out: P, - ) -> Result<()> { - let _ = Index::open(&index_anng_in)?; + fn convert_anng_to_onng(&mut self, index_anng_in: P, index_onng_out: P) -> Result<()> + where + T: NgtObjectType, + P: AsRef, + { + let _ = NgtIndex::::open(&index_anng_in)?; unsafe { let ebuf = sys::ngt_create_error_object(); @@ -308,11 +328,10 @@ mod tests { use std::error::Error as StdError; use std::result::Result as StdResult; + use rand::Rng; use tempfile::tempdir; - use crate::{DistanceType, Index, Properties}; - - use super::*; + use crate::{ngt::optim::*, ngt::*}; #[ignore] #[test] @@ -322,12 +341,13 @@ mod tests { let dir = tempdir()?; // Create an index for vectors of dimension 3 with cosine distance - let prop = Properties::dimension(3)?.distance_type(DistanceType::Cosine)?; - let mut index = Index::create(dir.path(), prop)?; + let prop = NgtProperties::::dimension(3)?.distance_type(NgtDistance::Cosine)?; + let mut index = NgtIndex::create(dir.path(), prop)?; // Populate the index, but don't build it yet - for i in 0..1_000_000 { - let _ = index.insert(vec![i, i + 1, i + 2])?; + let mut rng = rand::thread_rng(); + for _ in 0..25_000 { + index.insert(vec![rng.gen(); 3])?; } index.persist()?; @@ -335,12 +355,12 @@ mod tests { optimize_anng_edges_number(dir.path(), AnngEdgeOptimParams::default())?; // Now build and persist again the optimized index - let mut index = Index::open(dir.path())?; + let mut index = NgtIndex::::open(dir.path())?; index.build(4)?; index.persist()?; // Further optimize the index - optimize_anng_search_parameters(dir.path())?; + optimize_anng_search_parameters::(dir.path())?; dir.close()?; Ok(()) @@ -353,12 +373,13 @@ mod tests { let dir = tempdir()?; // Create an index for vectors of dimension 3 with cosine distance - let prop = Properties::dimension(3)?.distance_type(DistanceType::Cosine)?; - let mut index = Index::create(dir.path(), prop)?; + let prop = NgtProperties::::dimension(3)?.distance_type(NgtDistance::Cosine)?; + let mut index = NgtIndex::create(dir.path(), prop)?; // Populate and build the index - for i in 0..1000 { - let _ = index.insert(vec![i, i + 1, i + 2])?; + let mut rng = rand::thread_rng(); + for _ in 0..1000 { + index.insert(vec![rng.gen(); 3])?; } index.build(4)?; @@ -377,15 +398,16 @@ mod tests { let dir_in = tempdir()?; // Create an index for vectors of dimension 3 with cosine distance - let prop = Properties::dimension(3)? - .distance_type(DistanceType::Cosine)? + let prop = NgtProperties::::dimension(3)? + .distance_type(NgtDistance::Cosine)? .creation_edge_size(100)?; // More than default value, improves the final ONNG - let mut index = Index::create(dir_in.path(), prop)?; + let mut index = NgtIndex::create(dir_in.path(), prop)?; // Populate and persist (but don't build yet) the index - for i in 0..1000 { - let _ = index.insert(vec![i, i + 1, i + 2])?; + let mut rng = rand::thread_rng(); + for _ in 0..25_000 { + index.insert(vec![rng.gen(); 3])?; } index.persist()?; @@ -393,7 +415,7 @@ mod tests { optimize_anng_edges_number(dir_in.path(), AnngEdgeOptimParams::default())?; // Now build and persist again the optimized index - let mut index = Index::open(dir_in.path())?; + let mut index = NgtIndex::::open(dir_in.path())?; index.build(4)?; index.persist()?; @@ -405,7 +427,7 @@ mod tests { let mut params = GraphOptimParams::default(); params.nb_outgoing = 10; params.nb_incoming = 100; // An even larger number of incoming edges can be specified - convert_anng_to_onng(dir_in.path(), dir_out.path(), params)?; + convert_anng_to_onng::(dir_in.path(), dir_out.path(), params)?; dir_out.close()?; dir_in.close()?; diff --git a/src/properties.rs b/src/ngt/properties.rs similarity index 80% rename from src/properties.rs rename to src/ngt/properties.rs index c97413d..8198a93 100644 --- a/src/properties.rs +++ b/src/ngt/properties.rs @@ -1,6 +1,7 @@ -use std::convert::TryFrom; use std::ptr; +use std::{convert::TryFrom, marker::PhantomData}; +use half::f16; use ngt_sys as sys; use num_enum::TryFromPrimitive; use scopeguard::defer; @@ -9,14 +10,44 @@ use crate::error::{make_err, Result}; #[derive(Debug, Clone, Copy, PartialEq, Eq, TryFromPrimitive)] #[repr(i32)] -pub enum ObjectType { +pub enum NgtObject { Uint8 = 1, Float = 2, + Float16 = 3, +} + +mod private { + pub trait Sealed {} +} + +pub trait NgtObjectType: private::Sealed { + fn as_obj() -> NgtObject; +} + +impl private::Sealed for f32 {} +impl NgtObjectType for f32 { + fn as_obj() -> NgtObject { + NgtObject::Float + } +} + +impl private::Sealed for u8 {} +impl NgtObjectType for u8 { + fn as_obj() -> NgtObject { + NgtObject::Uint8 + } +} + +impl private::Sealed for f16 {} +impl NgtObjectType for f16 { + fn as_obj() -> NgtObject { + NgtObject::Float16 + } } #[derive(Debug, Clone, Copy, PartialEq, Eq, TryFromPrimitive)] #[repr(i32)] -pub enum DistanceType { +pub enum NgtDistance { L1 = 0, L2 = 1, Angle = 2, @@ -32,25 +63,29 @@ pub enum DistanceType { } #[derive(Debug)] -pub struct Properties { +pub struct NgtProperties { pub(crate) dimension: i32, pub(crate) creation_edge_size: i16, pub(crate) search_edge_size: i16, - pub(crate) object_type: ObjectType, - pub(crate) distance_type: DistanceType, + pub(crate) object_type: NgtObject, + pub(crate) distance_type: NgtDistance, pub(crate) raw_prop: sys::NGTProperty, + _marker: PhantomData, } -unsafe impl Send for Properties {} -unsafe impl Sync for Properties {} +unsafe impl Send for NgtProperties {} +unsafe impl Sync for NgtProperties {} -impl Properties { +impl NgtProperties +where + T: NgtObjectType, +{ pub fn dimension(dimension: usize) -> Result { let dimension = i32::try_from(dimension)?; let creation_edge_size = 10; let search_edge_size = 40; - let object_type = ObjectType::Float; - let distance_type = DistanceType::L2; + let object_type = T::as_obj(); + let distance_type = NgtDistance::L2; unsafe { let ebuf = sys::ngt_create_error_object(); @@ -74,6 +109,7 @@ impl Properties { object_type, distance_type, raw_prop, + _marker: PhantomData, }) } } @@ -101,6 +137,7 @@ impl Properties { object_type: self.object_type, distance_type: self.distance_type, raw_prop, + _marker: PhantomData, }) } } @@ -138,13 +175,13 @@ impl Properties { if object_type < 0 { Err(make_err(ebuf))? } - let object_type = ObjectType::try_from(object_type)?; + let object_type = NgtObject::try_from(object_type)?; let distance_type = sys::ngt_get_property_distance_type(raw_prop, ebuf); if distance_type < 0 { Err(make_err(ebuf))? } - let distance_type = DistanceType::try_from(distance_type)?; + let distance_type = NgtDistance::try_from(distance_type)?; Ok(Self { dimension, @@ -153,6 +190,7 @@ impl Properties { object_type, distance_type, raw_prop, + _marker: PhantomData, }) } } @@ -204,33 +242,32 @@ impl Properties { Ok(()) } - pub fn object_type(mut self, object_type: ObjectType) -> Result { - self.object_type = object_type; - unsafe { Self::set_object_type(self.raw_prop, object_type)? }; - Ok(self) - } - - unsafe fn set_object_type(raw_prop: sys::NGTProperty, object_type: ObjectType) -> Result<()> { + unsafe fn set_object_type(raw_prop: sys::NGTProperty, object_type: NgtObject) -> Result<()> { let ebuf = sys::ngt_create_error_object(); defer! { sys::ngt_destroy_error_object(ebuf); } match object_type { - ObjectType::Uint8 => { + NgtObject::Uint8 => { if !sys::ngt_set_property_object_type_integer(raw_prop, ebuf) { Err(make_err(ebuf))? } } - ObjectType::Float => { + NgtObject::Float => { if !sys::ngt_set_property_object_type_float(raw_prop, ebuf) { Err(make_err(ebuf))? } } + NgtObject::Float16 => { + if !sys::ngt_set_property_object_type_float16(raw_prop, ebuf) { + Err(make_err(ebuf))? + } + } } Ok(()) } - pub fn distance_type(mut self, distance_type: DistanceType) -> Result { + pub fn distance_type(mut self, distance_type: NgtDistance) -> Result { self.distance_type = distance_type; unsafe { Self::set_distance_type(self.raw_prop, distance_type)? }; Ok(self) @@ -238,68 +275,68 @@ impl Properties { unsafe fn set_distance_type( raw_prop: sys::NGTProperty, - distance_type: DistanceType, + distance_type: NgtDistance, ) -> Result<()> { let ebuf = sys::ngt_create_error_object(); defer! { sys::ngt_destroy_error_object(ebuf); } match distance_type { - DistanceType::L1 => { + NgtDistance::L1 => { if !sys::ngt_set_property_distance_type_l1(raw_prop, ebuf) { Err(make_err(ebuf))? } } - DistanceType::L2 => { + NgtDistance::L2 => { if !sys::ngt_set_property_distance_type_l2(raw_prop, ebuf) { Err(make_err(ebuf))? } } - DistanceType::Angle => { + NgtDistance::Angle => { if !sys::ngt_set_property_distance_type_angle(raw_prop, ebuf) { Err(make_err(ebuf))? } } - DistanceType::Hamming => { + NgtDistance::Hamming => { if !sys::ngt_set_property_distance_type_hamming(raw_prop, ebuf) { Err(make_err(ebuf))? } } - DistanceType::Cosine => { + NgtDistance::Cosine => { if !sys::ngt_set_property_distance_type_cosine(raw_prop, ebuf) { Err(make_err(ebuf))? } } - DistanceType::NormalizedAngle => { + NgtDistance::NormalizedAngle => { if !sys::ngt_set_property_distance_type_normalized_angle(raw_prop, ebuf) { Err(make_err(ebuf))? } } - DistanceType::NormalizedCosine => { + NgtDistance::NormalizedCosine => { if !sys::ngt_set_property_distance_type_normalized_cosine(raw_prop, ebuf) { Err(make_err(ebuf))? } } - DistanceType::Jaccard => { + NgtDistance::Jaccard => { if !sys::ngt_set_property_distance_type_jaccard(raw_prop, ebuf) { Err(make_err(ebuf))? } } - DistanceType::SparseJaccard => { + NgtDistance::SparseJaccard => { if !sys::ngt_set_property_distance_type_sparse_jaccard(raw_prop, ebuf) { Err(make_err(ebuf))? } } - DistanceType::NormalizedL2 => { + NgtDistance::NormalizedL2 => { if !sys::ngt_set_property_distance_type_normalized_l2(raw_prop, ebuf) { Err(make_err(ebuf))? } } - DistanceType::Poincare => { + NgtDistance::Poincare => { if !sys::ngt_set_property_distance_type_poincare(raw_prop, ebuf) { Err(make_err(ebuf))? } } - DistanceType::Lorentz => { + NgtDistance::Lorentz => { if !sys::ngt_set_property_distance_type_lorentz(raw_prop, ebuf) { Err(make_err(ebuf))? } @@ -310,7 +347,7 @@ impl Properties { } } -impl Drop for Properties { +impl Drop for NgtProperties { fn drop(&mut self) { if !self.raw_prop.is_null() { unsafe { sys::ngt_destroy_property(self.raw_prop) }; diff --git a/src/qbg/index.rs b/src/qbg/index.rs new file mode 100644 index 0000000..a895288 --- /dev/null +++ b/src/qbg/index.rs @@ -0,0 +1,539 @@ +use std::ffi::CString; +use std::marker::PhantomData; +use std::os::unix::ffi::OsStrExt; +use std::path::Path; +use std::{mem, ptr}; + +use half::f16; +use ngt_sys as sys; +use scopeguard::defer; + +use crate::error::{make_err, Error, Result}; +use crate::{SearchResult, VecId}; + +use super::{QbgBuildParams, QbgConstructParams, QbgObject, QbgObjectType}; + +#[derive(Debug)] +pub struct QbgIndex { + pub(crate) index: sys::QBGIndex, + path: CString, + _mode: M, + dimension: u32, + ebuf: sys::NGTError, + _marker: PhantomData, +} + +impl QbgIndex +where + T: QbgObjectType, +{ + pub fn create

(path: P, create_params: QbgConstructParams) -> Result + where + P: AsRef, + { + unsafe { + let ebuf = sys::ngt_create_error_object(); + defer! { sys::ngt_destroy_error_object(ebuf); } + + let path = CString::new(path.as_ref().as_os_str().as_bytes())?; + + if !sys::qbg_create(path.as_ptr(), &mut create_params.into_raw() as *mut _, ebuf) { + Err(make_err(ebuf))? + } + + let index = sys::qbg_open_index(path.as_ptr(), false, ebuf); + if index.is_null() { + Err(make_err(ebuf))? + } + + let dimension = sys::qbg_get_dimension(index, ebuf) as u32; + if dimension == 0 { + Err(make_err(ebuf))? + } + + Ok(QbgIndex { + index, + path, + _mode: ModeWrite, + dimension, + ebuf: sys::ngt_create_error_object(), + _marker: PhantomData, + }) + } + } + + pub fn insert(&mut self, mut vec: Vec) -> Result { + unsafe { + let id = match T::as_obj() { + QbgObject::Float => sys::qbg_append_object( + self.index, + vec.as_mut_ptr() as *mut _, + self.dimension, + self.ebuf, + ), + QbgObject::Uint8 => sys::qbg_append_object_as_uint8( + self.index, + vec.as_mut_ptr() as *mut _, + self.dimension, + self.ebuf, + ), + QbgObject::Float16 => sys::qbg_append_object_as_float16( + self.index, + vec.as_mut_ptr() as *mut _, + self.dimension, + self.ebuf, + ), + }; + if id == 0 { + Err(make_err(self.ebuf))? + } else { + Ok(id) + } + } + } + + pub fn build(&mut self, build_params: QbgBuildParams) -> Result<()> { + unsafe { + if !sys::qbg_build_index( + self.path.as_ptr(), + &mut build_params.into_raw() as *mut _, + self.ebuf, + ) { + Err(make_err(self.ebuf))? + } + Ok(()) + } + } + + pub fn persist(&mut self) -> Result<()> { + unsafe { + if !sys::qbg_save_index(self.index, self.ebuf) { + Err(make_err(self.ebuf))? + } + Ok(()) + } + } + + pub fn into_readable(self) -> Result> { + let path = self.path.clone(); + drop(self); + QbgIndex::open(path.into_string()?) + } +} + +impl QbgIndex +where + T: QbgObjectType, +{ + pub fn open>(path: P) -> Result { + if !path.as_ref().exists() { + Err(Error(format!("Path {:?} does not exist", path.as_ref())))? + } + + unsafe { + let ebuf = sys::ngt_create_error_object(); + defer! { sys::ngt_destroy_error_object(ebuf); } + + let path = CString::new(path.as_ref().as_os_str().as_bytes())?; + let index = sys::qbg_open_index(path.as_ptr(), true, ebuf); + if index.is_null() { + Err(make_err(ebuf))? + } + + let dimension = sys::qbg_get_dimension(index, ebuf) as u32; + if dimension == 0 { + Err(make_err(ebuf))? + } + + Ok(QbgIndex { + index, + path, + _mode: ModeRead, + dimension, + ebuf: sys::ngt_create_error_object(), + _marker: PhantomData, + }) + } + } + + pub fn search(&self, query: QbgQuery) -> Result> { + unsafe { + let results = sys::ngt_create_empty_results(self.ebuf); + if results.is_null() { + Err(make_err(self.ebuf))? + } + defer! { sys::qbg_destroy_results(results); } + + match T::as_obj() { + QbgObject::Float => { + let q = sys::QBGQueryFloat { + query: query.query.as_ptr() as *mut f32, + params: query.params(), + }; + if !sys::qbg_search_index_float(self.index, q, results, self.ebuf) { + Err(make_err(self.ebuf))? + } + } + QbgObject::Uint8 => { + let q = sys::QBGQueryUint8 { + query: query.query.as_ptr() as *mut u8, + params: query.params(), + }; + if !sys::qbg_search_index_uint8(self.index, q, results, self.ebuf) { + Err(make_err(self.ebuf))? + } + } + QbgObject::Float16 => { + let q = sys::QBGQueryFloat16 { + query: query.query.as_ptr() as *mut _, + params: query.params(), + }; + if !sys::qbg_search_index_float16(self.index, q, results, self.ebuf) { + Err(make_err(self.ebuf))? + } + } + } + + let rsize = sys::qbg_get_result_size(results, self.ebuf); + let mut ret = Vec::with_capacity(rsize as usize); + + for i in 0..rsize { + let d = sys::qbg_get_result(results, i, self.ebuf); + if d.id == 0 && d.distance == 0.0 { + Err(make_err(self.ebuf))? + } else { + ret.push(SearchResult { + id: d.id, + distance: d.distance, + }); + } + } + + Ok(ret) + } + } + + pub fn into_writable(self) -> Result> { + unsafe { + let ebuf = sys::ngt_create_error_object(); + defer! { sys::ngt_destroy_error_object(ebuf); } + + let path = self.path.clone(); + drop(self); + + let index = sys::qbg_open_index(path.as_ptr(), false, ebuf); + if index.is_null() { + Err(make_err(ebuf))? + } + + let dimension = sys::qbg_get_dimension(index, ebuf) as u32; + if dimension == 0 { + Err(make_err(ebuf))? + } + + Ok(QbgIndex { + index, + path, + _mode: ModeWrite, + dimension, + ebuf: sys::ngt_create_error_object(), + _marker: PhantomData, + }) + } + } +} + +impl QbgIndex +where + T: QbgObjectType, + M: IndexMode, +{ + /// Get the specified vector. + pub fn get_vec(&self, id: VecId) -> Result> { + unsafe { + match T::as_obj() { + QbgObject::Float => { + let results = sys::qbg_get_object(self.index, id, self.ebuf); + if results.is_null() { + Err(make_err(self.ebuf))? + } + let results = Vec::from_raw_parts( + results, + self.dimension as usize, + self.dimension as usize, + ); + Ok(mem::transmute::<_, Vec>(results)) + } + QbgObject::Uint8 => { + let results = sys::qbg_get_object_as_uint8(self.index, id, self.ebuf); + if results.is_null() { + Err(make_err(self.ebuf))? + } + let results = Vec::from_raw_parts( + results, + self.dimension as usize, + self.dimension as usize, + ); + Ok(mem::transmute::<_, Vec>(results)) + } + QbgObject::Float16 => { + let results = sys::qbg_get_object_as_float16(self.index, id, self.ebuf); + if results.is_null() { + Err(make_err(self.ebuf))? + } + let results = Vec::from_raw_parts( + results as *mut f16, + self.dimension as usize, + self.dimension as usize, + ); + Ok(mem::transmute::<_, Vec>(results)) + } + } + } + } +} + +impl Drop for QbgIndex { + fn drop(&mut self) { + if !self.index.is_null() { + unsafe { sys::qbg_close_index(self.index) }; + self.index = ptr::null_mut(); + } + if !self.ebuf.is_null() { + unsafe { sys::ngt_destroy_error_object(self.ebuf) }; + self.ebuf = ptr::null_mut(); + } + } +} + +mod private { + pub trait Sealed {} +} + +pub trait IndexMode: private::Sealed {} + +#[derive(Debug, Clone, Copy)] +pub struct ModeRead; + +impl private::Sealed for ModeRead {} +impl IndexMode for ModeRead {} + +#[derive(Debug, Clone, Copy)] +pub struct ModeWrite; + +impl private::Sealed for ModeWrite {} +impl IndexMode for ModeWrite {} + +#[derive(Debug, Clone, PartialEq)] +pub struct QbgQuery<'a, T> { + query: &'a [T], + pub size: u64, + pub epsilon: f32, + pub blob_epsilon: f32, + pub result_expansion: f32, + pub number_of_explored_blobs: u64, + pub number_of_edges: u64, + pub radius: f32, +} + +impl<'a, T> QbgQuery<'a, T> +where + T: QbgObjectType, +{ + pub fn new(query: &'a [T]) -> Self { + Self { + query, + size: 20, + epsilon: 0.1, + blob_epsilon: 0.0, + result_expansion: 3.0, + number_of_explored_blobs: 256, + number_of_edges: 0, + radius: 0.0, + } + } + + pub fn size(mut self, size: u64) -> Self { + self.size = size; + self + } + + pub fn epsilon(mut self, epsilon: f32) -> Self { + self.epsilon = epsilon; + self + } + + pub fn blob_epsilon(mut self, blob_epsilon: f32) -> Self { + self.blob_epsilon = blob_epsilon; + self + } + + pub fn result_expansion(mut self, result_expansion: f32) -> Self { + self.result_expansion = result_expansion; + self + } + + pub fn number_of_explored_blobs(mut self, number_of_explored_blobs: u64) -> Self { + self.number_of_explored_blobs = number_of_explored_blobs; + self + } + + pub fn number_of_edges(mut self, number_of_edges: u64) -> Self { + self.number_of_edges = number_of_edges; + self + } + + pub fn radius(mut self, radius: f32) -> Self { + self.radius = radius; + self + } + + unsafe fn params(&self) -> sys::QBGQueryParameters { + sys::QBGQueryParameters { + number_of_results: self.size, + epsilon: self.epsilon, + blob_epsilon: self.blob_epsilon, + result_expansion: self.result_expansion, + number_of_explored_blobs: self.number_of_explored_blobs, + number_of_edges: self.number_of_edges, + radius: self.radius, + } + } +} + +#[cfg(test)] +mod tests { + use std::error::Error as StdError; + use std::iter::repeat; + use std::result::Result as StdResult; + + use tempfile::tempdir; + + use super::*; + + #[test] + fn test_qbg_f32() -> StdResult<(), Box> { + // Get a temporary directory to store the index + let dir = tempdir()?; + std::fs::remove_dir(dir.path())?; + + // Create a QGB index + let ndims = 3; + let mut index = QbgIndex::create(dir.path(), QbgConstructParams::dimension(ndims))?; + + // Insert vectors and get their ids + let nvecs = 64; + let ids = (1..ndims * nvecs) + .step_by(ndims as usize) + .map(|i| i as f32) + .map(|i| { + repeat(i) + .zip((0..ndims).map(|j| j as f32)) + .map(|(i, j)| i + j) + .collect() + }) + .map(|vector| index.insert(vector)) + .collect::>>()?; + + // Build and persist the index + index.build(QbgBuildParams::default())?; + index.persist()?; + + let index = index.into_readable()?; + + // Perform a vector search (with 2 results) + let v: Vec = (1..=ndims).into_iter().map(|x| x as f32).collect(); + let query = QbgQuery::new(&v).size(2); + let res = index.search(query)?; + assert_eq!(ids[0], res[0].id); + assert_eq!(v, index.get_vec(ids[0])?); + + dir.close()?; + Ok(()) + } + + #[test] + fn test_qbg_f16() -> StdResult<(), Box> { + // Get a temporary directory to store the index + let dir = tempdir()?; + std::fs::remove_dir(dir.path())?; + + // Create a QGB index + let ndims = 3; + let mut index = QbgIndex::create(dir.path(), QbgConstructParams::dimension(ndims))?; + + // Insert vectors and get their ids + let nvecs = 64; + let ids = (1..ndims * nvecs) + .step_by(ndims as usize) + .map(|i| f16::from_f32(i as f32)) + .map(|i| { + repeat(i) + .zip((0..ndims).map(|j| f16::from_f32(j as f32))) + .map(|(i, j)| i + j) + .collect() + }) + .map(|vector| index.insert(vector)) + .collect::>>()?; + + // Build and persist the index + index.build(QbgBuildParams::default())?; + index.persist()?; + + let index = index.into_readable()?; + + // Perform a vector search (with 2 results) + let v: Vec = (1..=ndims) + .into_iter() + .map(|x| f16::from_f32(x as f32)) + .collect(); + let query = QbgQuery::new(&v).size(2); + let res = index.search(query)?; + assert_eq!(ids[0], res[0].id); + assert_eq!(v, index.get_vec(ids[0])?); + + dir.close()?; + Ok(()) + } + + #[test] + fn test_qbg_u8() -> StdResult<(), Box> { + // Get a temporary directory to store the index + let dir = tempdir()?; + std::fs::remove_dir(dir.path())?; + + // Create a QGB index + let ndims = 3; + let mut index = QbgIndex::create(dir.path(), QbgConstructParams::dimension(ndims))?; + + // Insert vectors and get their ids + let nvecs = 64; + let ids = (1..ndims * nvecs) + .step_by(ndims as usize) + .map(|i| i as u8) + .map(|i| { + repeat(i) + .zip((0..ndims).map(|j| j as u8)) + .map(|(i, j)| i + j) + .collect() + }) + .map(|vector| index.insert(vector)) + .collect::>>()?; + + // Build and persist the index + index.build(QbgBuildParams::default())?; + index.persist()?; + + let index = index.into_readable()?; + + // Perform a vector search (with 3 results) + let v: Vec = (1..=ndims).into_iter().map(|x| x as u8).collect(); + let query = QbgQuery::new(&v).size(3); + let res = index.search(query)?; + assert!(Vec::from_iter(res[0..3].iter().map(|r| r.id)).contains(&ids[0])); + assert!(v == index.get_vec(ids[0])?); + + dir.close()?; + Ok(()) + } +} diff --git a/src/qbg/mod.rs b/src/qbg/mod.rs new file mode 100644 index 0000000..56a7aad --- /dev/null +++ b/src/qbg/mod.rs @@ -0,0 +1,71 @@ +//! Quantized blob graph index (QBG Index) +//! +//! ## Defining the properties of a new QBG index: +//! +//! ```rust +//! # fn main() -> Result<(), ngt::Error> { +//! use ngt::qbg::{QbgConstructParams, QbgDistance}; +//! +//! // Defaut parameters with vectors of dimension 3 +//! let params = QbgConstructParams::::dimension(3); +//! +//! // Or customize values (here are the defaults) +//! let params = QbgConstructParams::::dimension(3) +//! .extended_dimension(16)? // next multiple of 16 after 3 +//! .number_of_subvectors(1) +//! .number_of_subvectors(0) +//! .distance_type(QbgDistance::L2); +//! +//! # Ok(()) +//! # } +//! ``` +//! +//! ## Creating/Opening a QBG index and using it: +//! +//! ```rust +//! # fn main() -> Result<(), ngt::Error> { +//! # std::fs::create_dir_all("target/path/to/qbg_index").unwrap(); +//! use ngt::qbg::{ +//! ModeRead, ModeWrite, QbgBuildParams, QbgConstructParams, QbgDistance, QbgIndex, QbgQuery, +//! }; +//! +//! // Create a new index +//! let params = QbgConstructParams::dimension(3); +//! let mut index: QbgIndex = +//! QbgIndex::create("target/path/to/qbg_index/dir", params)?; +//! +//! // Insert vectors and get their id +//! let vec1 = vec![1.0, 2.0, 3.0]; +//! let vec2 = vec![4.0, 5.0, 6.0]; +//! let id1 = index.insert(vec1)?; +//! let id2 = index.insert(vec2)?; +//! +//! // Add enough dummy vectors to build an index +//! for i in 0..64 { +//! index.insert(vec![100. + i as f32; 3])?; +//! } +//! // Build the index in RAM and persist it on disk +//! index.build(QbgBuildParams::default())?; +//! index.persist()?; +//! +//! // Open an existing index +//! let index: QbgIndex = QbgIndex::open("target/path/to/qbg_index/dir")?; +//! +//! // Perform a vector search (with 1 result) +//! let query = vec![1.1, 2.1, 3.1]; +//! let res = index.search(QbgQuery::new(&query).size(1))?; +//! assert_eq!(res[0].id, id1); +//! assert_eq!(index.get_vec(id1)?, vec![1.0, 2.0, 3.0]); +//! +//! # std::fs::remove_dir_all("target/path/to/qbg_index").unwrap(); +//! # Ok(()) +//! # } +//! ``` + +mod index; +mod properties; + +pub use self::index::{IndexMode, ModeRead, ModeWrite, QbgIndex, QbgQuery}; +pub use self::properties::{ + QbgBuildParams, QbgConstructParams, QbgDistance, QbgObject, QbgObjectType, +}; diff --git a/src/qbg/properties.rs b/src/qbg/properties.rs new file mode 100644 index 0000000..9d358b5 --- /dev/null +++ b/src/qbg/properties.rs @@ -0,0 +1,299 @@ +use std::marker::PhantomData; + +use half::f16; +use ngt_sys as sys; +use num_enum::TryFromPrimitive; + +use crate::error::Error; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, TryFromPrimitive)] +#[repr(i32)] +pub enum QbgObject { + Uint8 = 0, + Float = 1, + Float16 = 2, +} + +mod private { + pub trait Sealed {} +} + +pub trait QbgObjectType: private::Sealed { + fn as_obj() -> QbgObject; +} + +impl private::Sealed for f32 {} +impl QbgObjectType for f32 { + fn as_obj() -> QbgObject { + QbgObject::Float + } +} + +impl private::Sealed for u8 {} +impl QbgObjectType for u8 { + fn as_obj() -> QbgObject { + QbgObject::Uint8 + } +} + +impl private::Sealed for f16 {} +impl QbgObjectType for f16 { + fn as_obj() -> QbgObject { + QbgObject::Float16 + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, TryFromPrimitive)] +#[repr(i32)] +pub enum QbgDistance { + L2 = 1, +} + +#[derive(Debug, Clone, PartialEq)] +pub struct QbgConstructParams { + extended_dimension: u64, + dimension: u64, + number_of_subvectors: u64, + number_of_blobs: u64, + internal_data_type: QbgObject, + data_type: QbgObject, + distance_type: QbgDistance, + _marker: PhantomData, +} + +impl QbgConstructParams +where + T: QbgObjectType, +{ + pub fn dimension(dimension: u64) -> Self { + let extended_dimension = next_multiple_of_16(dimension); + let number_of_subvectors = 1; + let number_of_blobs = 0; + let internal_data_type = T::as_obj(); + let data_type = T::as_obj(); + let distance_type = QbgDistance::L2; + + Self { + extended_dimension, + dimension, + number_of_subvectors, + number_of_blobs, + internal_data_type, + data_type, + distance_type, + _marker: PhantomData, + } + } + + pub fn extended_dimension(mut self, extended_dimension: u64) -> Result { + if extended_dimension % 16 == 0 && extended_dimension >= self.dimension { + self.extended_dimension = extended_dimension; + Ok(self) + } else { + Err(Error(format!( + "Invalid extended_dimension: {}, must be a multiple of 16 greater or equal to dimension", + extended_dimension + ))) + } + } + + pub fn number_of_subvectors(mut self, number_of_subvectors: u64) -> Self { + self.number_of_subvectors = number_of_subvectors; + self + } + + pub fn number_of_blobs(mut self, number_of_blobs: u64) -> Self { + self.number_of_blobs = number_of_blobs; + self + } + + pub fn internal_data_type(mut self, internal_data_type: QbgObject) -> Self { + self.internal_data_type = internal_data_type; + self + } + + pub fn distance_type(mut self, distance_type: QbgDistance) -> Self { + self.distance_type = distance_type; + self + } + + pub(crate) unsafe fn into_raw(self) -> sys::QBGConstructionParameters { + sys::QBGConstructionParameters { + extended_dimension: self.extended_dimension, + dimension: self.dimension, + number_of_subvectors: self.number_of_subvectors, + number_of_blobs: self.number_of_blobs, + internal_data_type: self.internal_data_type as i32, + data_type: self.data_type as i32, + distance_type: self.distance_type as i32, + } + } +} + +fn next_multiple_of_16(x: u64) -> u64 { + ((x + 15) / 16) * 16 +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, TryFromPrimitive)] +#[repr(i32)] +pub enum QbgClusteringInitMode { + Head = 0, + Random = 1, + KmeansPlusPlus = 2, + RandomFixedSeed = 3, + KmeansPlusPlusFixedSeed = 4, + Best = 5, +} + +#[derive(Debug, Clone)] +pub struct QbgBuildParams { + // hierarchical kmeans + hierarchical_clustering_init_mode: QbgClusteringInitMode, + number_of_first_objects: u64, + number_of_first_clusters: u64, + number_of_second_objects: u64, + number_of_second_clusters: u64, + number_of_third_clusters: u64, + // optimization + number_of_objects: u64, + number_of_subvectors: u64, + optimization_clustering_init_mode: QbgClusteringInitMode, + rotation_iteration: u64, + subvector_iteration: u64, + number_of_matrices: u64, + rotation: bool, + repositioning: bool, +} + +impl Default for QbgBuildParams { + fn default() -> Self { + Self { + hierarchical_clustering_init_mode: QbgClusteringInitMode::KmeansPlusPlus, + number_of_first_objects: 0, + number_of_first_clusters: 0, + number_of_second_objects: 0, + number_of_second_clusters: 0, + number_of_third_clusters: 0, + number_of_objects: 1000, + number_of_subvectors: 1, + optimization_clustering_init_mode: QbgClusteringInitMode::KmeansPlusPlus, + rotation_iteration: 2000, + subvector_iteration: 400, + number_of_matrices: 3, + rotation: true, + repositioning: false, + } + } +} + +impl QbgBuildParams { + pub fn hierarchical_clustering_init_mode( + mut self, + clustering_init_mode: QbgClusteringInitMode, + ) -> Self { + self.hierarchical_clustering_init_mode = clustering_init_mode; + self + } + + pub fn number_of_first_objects(mut self, number_of_first_objects: u64) -> Self { + self.number_of_first_objects = number_of_first_objects; + self + } + + pub fn number_of_first_clusters(mut self, number_of_first_clusters: u64) -> Self { + self.number_of_first_clusters = number_of_first_clusters; + self + } + + pub fn number_of_second_objects(mut self, number_of_second_objects: u64) -> Self { + self.number_of_second_objects = number_of_second_objects; + self + } + + pub fn number_of_second_clusters(mut self, number_of_second_clusters: u64) -> Self { + self.number_of_second_clusters = number_of_second_clusters; + self + } + + pub fn number_of_third_clusters(mut self, number_of_third_clusters: u64) -> Self { + self.number_of_third_clusters = number_of_third_clusters; + self + } + + pub fn number_of_objects(mut self, number_of_objects: u64) -> Self { + self.number_of_objects = number_of_objects; + self + } + pub fn number_of_subvectors(mut self, number_of_subvectors: u64) -> Self { + self.number_of_subvectors = number_of_subvectors; + self + } + pub fn optimization_clustering_init_mode( + mut self, + clustering_init_mode: QbgClusteringInitMode, + ) -> Self { + self.optimization_clustering_init_mode = clustering_init_mode; + self + } + + pub fn rotation_iteration(mut self, rotation_iteration: u64) -> Self { + self.rotation_iteration = rotation_iteration; + self + } + + pub fn subvector_iteration(mut self, subvector_iteration: u64) -> Self { + self.subvector_iteration = subvector_iteration; + self + } + + pub fn number_of_matrices(mut self, number_of_matrices: u64) -> Self { + self.number_of_matrices = number_of_matrices; + self + } + + pub fn rotation(mut self, rotation: bool) -> Self { + self.rotation = rotation; + self + } + + pub fn repositioning(mut self, repositioning: bool) -> Self { + self.repositioning = repositioning; + self + } + + pub(crate) unsafe fn into_raw(self) -> sys::QBGBuildParameters { + sys::QBGBuildParameters { + hierarchical_clustering_init_mode: self.hierarchical_clustering_init_mode as i32, + number_of_first_objects: self.number_of_first_objects, + number_of_first_clusters: self.number_of_first_clusters, + number_of_second_objects: self.number_of_second_objects, + number_of_second_clusters: self.number_of_second_clusters, + number_of_third_clusters: self.number_of_third_clusters, + number_of_objects: self.number_of_objects, + number_of_subvectors: self.number_of_subvectors, + optimization_clustering_init_mode: self.optimization_clustering_init_mode as i32, + rotation_iteration: self.rotation_iteration, + subvector_iteration: self.subvector_iteration, + number_of_matrices: self.number_of_matrices, + rotation: self.rotation, + repositioning: self.repositioning, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_qbg_params() { + let params = QbgConstructParams::::dimension(3); + assert_eq!(params.extended_dimension, 16); + + let params = QbgConstructParams::::dimension(16); + assert_eq!(params.extended_dimension, 16); + + let params = QbgConstructParams::::dimension(513); + assert_eq!(params.extended_dimension, 528); + } +} diff --git a/src/qg/index.rs b/src/qg/index.rs new file mode 100644 index 0000000..be68cb3 --- /dev/null +++ b/src/qg/index.rs @@ -0,0 +1,420 @@ +use std::ffi::CString; +use std::mem; +use std::os::unix::ffi::OsStrExt; +use std::path::Path; +use std::ptr; + +use half::f16; +use ngt_sys as sys; +use scopeguard::defer; + +use super::{QgObject, QgObjectType, QgProperties, QgQuantizationParams}; +use crate::error::{make_err, Error, Result}; +use crate::ngt::NgtIndex; +use crate::qg::QgDistance; +use crate::{SearchResult, VecId}; + +#[derive(Debug)] +pub struct QgIndex { + pub(crate) prop: QgProperties, + pub(crate) index: sys::NGTQGIndex, + ebuf: sys::NGTError, +} + +impl QgIndex +where + T: QgObjectType, +{ + /// Quantize an NGT index + pub fn quantize(index: NgtIndex, params: QgQuantizationParams) -> Result { + QgDistance::try_from(index.prop.distance_type)?; + + unsafe { + let ebuf = sys::ngt_create_error_object(); + defer! { sys::ngt_destroy_error_object(ebuf); } + + let path = index.path.clone(); + drop(index); // Close the index + if !sys::ngtqg_quantize(path.as_ptr(), params.into_raw(), ebuf) { + Err(make_err(ebuf))? + } + + QgIndex::open(path.into_string()?) + } + } + + /// Open the already existing quantized index at the specified path. + pub fn open>(path: P) -> Result { + if !path.as_ref().exists() { + Err(Error(format!("Path {:?} does not exist", path.as_ref())))? + } + + unsafe { + let ebuf = sys::ngt_create_error_object(); + defer! { sys::ngt_destroy_error_object(ebuf); } + + let path = CString::new(path.as_ref().as_os_str().as_bytes())?; + + let index = sys::ngtqg_open_index(path.as_ptr(), ebuf); + if index.is_null() { + Err(make_err(ebuf))? + } + + let prop = QgProperties::from(index)?; + + Ok(QgIndex { + prop, + index, + ebuf: sys::ngt_create_error_object(), + }) + } + } + + pub fn search(&self, query: QgQuery) -> Result> { + unsafe { + let results = sys::ngt_create_empty_results(self.ebuf); + if results.is_null() { + Err(make_err(self.ebuf))? + } + defer! { sys::ngt_destroy_results(results); } + + match T::as_obj() { + QgObject::Float => { + let q = sys::NGTQGQueryFloat { + query: query.query.as_ptr() as *mut f32, + params: query.params(), + }; + if !sys::ngtqg_search_index_float(self.index, q, results, self.ebuf) { + Err(make_err(self.ebuf))? + } + } + QgObject::Uint8 => { + let q = sys::NGTQGQueryUint8 { + query: query.query.as_ptr() as *mut u8, + params: query.params(), + }; + if !sys::ngtqg_search_index_uint8(self.index, q, results, self.ebuf) { + Err(make_err(self.ebuf))? + } + } + QgObject::Float16 => { + let q = sys::NGTQGQueryFloat16 { + query: query.query.as_ptr() as *mut _, + params: query.params(), + }; + if !sys::ngtqg_search_index_float16(self.index, q, results, self.ebuf) { + Err(make_err(self.ebuf))? + } + } + } + + let rsize = sys::ngt_get_result_size(results, self.ebuf); + let mut ret = Vec::with_capacity(rsize as usize); + + for i in 0..rsize { + let d = sys::ngt_get_result(results, i, self.ebuf); + if d.id == 0 && d.distance == 0.0 { + Err(make_err(self.ebuf))? + } else { + ret.push(SearchResult { + id: d.id, + distance: d.distance, + }); + } + } + + Ok(ret) + } + } + + /// Get the specified vector. + pub fn get_vec(&self, id: VecId) -> Result> { + unsafe { + match self.prop.object_type { + QgObject::Float => { + let ospace = sys::ngt_get_object_space(self.index, self.ebuf); + if ospace.is_null() { + Err(make_err(self.ebuf))? + } + + let results = sys::ngt_get_object_as_float(ospace, id, self.ebuf); + if results.is_null() { + Err(make_err(self.ebuf))? + } + + let results = Vec::from_raw_parts( + results, + self.prop.dimension as usize, + self.prop.dimension as usize, + ); + let results = mem::ManuallyDrop::new(results); + + let results = results.iter().copied().collect::>(); + Ok(mem::transmute::<_, Vec>(results)) + } + QgObject::Uint8 => { + let ospace = sys::ngt_get_object_space(self.index, self.ebuf); + if ospace.is_null() { + Err(make_err(self.ebuf))? + } + + let results = sys::ngt_get_object_as_integer(ospace, id, self.ebuf); + if results.is_null() { + Err(make_err(self.ebuf))? + } + + let results = Vec::from_raw_parts( + results, + self.prop.dimension as usize, + self.prop.dimension as usize, + ); + let results = mem::ManuallyDrop::new(results); + + let results = results.iter().copied().collect::>(); + Ok(mem::transmute::<_, Vec>(results)) + } + QgObject::Float16 => { + let ospace = sys::ngt_get_object_space(self.index, self.ebuf); + if ospace.is_null() { + Err(make_err(self.ebuf))? + } + + let results = sys::ngt_get_object_as_float16(ospace, id, self.ebuf); + if results.is_null() { + Err(make_err(self.ebuf))? + } + + let results = Vec::from_raw_parts( + results as *mut f16, + self.prop.dimension as usize, + self.prop.dimension as usize, + ); + let results = mem::ManuallyDrop::new(results); + + let results = results.iter().copied().collect::>(); + Ok(mem::transmute::<_, Vec>(results)) + } + } + } + } +} + +impl Drop for QgIndex { + fn drop(&mut self) { + if !self.index.is_null() { + unsafe { sys::ngtqg_close_index(self.index) }; + self.index = ptr::null_mut(); + } + if !self.ebuf.is_null() { + unsafe { sys::ngt_destroy_error_object(self.ebuf) }; + self.ebuf = ptr::null_mut(); + } + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct QgQuery<'a, T> { + query: &'a [T], + pub size: u64, + pub epsilon: f32, + pub result_expansion: f32, + pub radius: f32, +} + +impl<'a, T> QgQuery<'a, T> +where + T: QgObjectType, +{ + pub fn new(query: &'a [T]) -> Self { + Self { + query, + size: 20, + epsilon: 0.03, + result_expansion: 3.0, + radius: f32::MAX, + } + } + + pub fn size(mut self, size: u64) -> Self { + self.size = size; + self + } + + pub fn epsilon(mut self, epsilon: f32) -> Self { + self.epsilon = epsilon; + self + } + + pub fn result_expansion(mut self, result_expansion: f32) -> Self { + self.result_expansion = result_expansion; + self + } + + pub fn radius(mut self, radius: f32) -> Self { + self.radius = radius; + self + } + + unsafe fn params(&self) -> sys::NGTQGQueryParameters { + sys::NGTQGQueryParameters { + size: self.size, + epsilon: self.epsilon, + result_expansion: self.result_expansion, + radius: self.radius, + } + } +} + +#[cfg(test)] +mod tests { + use std::error::Error as StdError; + use std::iter::repeat; + use std::result::Result as StdResult; + + use tempfile::tempdir; + + use crate::qg::QgDistance; + + use super::*; + + #[test] + fn test_qg_f32() -> StdResult<(), Box> { + // Get a temporary directory to store the index + let dir = tempdir()?; + + // Create an NGT index for vectors + let ndims = 3; + let props = QgProperties::::dimension(ndims)?.distance_type(QgDistance::L2)?; + let mut index = NgtIndex::create(dir.path(), props.try_into()?)?; + + // Insert vectors and get their ids + let nvecs = 64; + let ids = (1..ndims * nvecs) + .step_by(ndims) + .map(|i| i as f32) + .map(|i| { + repeat(i) + .zip((0..ndims).map(|j| j as f32)) + .map(|(i, j)| i + j) + .collect() + }) + .map(|vector| index.insert(vector)) + .collect::>>()?; + + // Build and persist the index + index.build(1)?; + index.persist()?; + + // Quantize the index + let params = QgQuantizationParams { + dimension_of_subvector: 1., + max_number_of_edges: 50, + }; + let index = QgIndex::quantize(index, params)?; + + // Perform a vector search (with 3 results) + let v: Vec = (1..=ndims).into_iter().map(|x| x as f32).collect(); + let query = QgQuery::new(&v).size(3); + let res = index.search(query)?; + assert!(ids[0] == res[0].id); + assert!(v == index.get_vec(ids[0])?); + + dir.close()?; + Ok(()) + } + + #[test] + fn test_qg_f16() -> StdResult<(), Box> { + // Get a temporary directory to store the index + let dir = tempdir()?; + + // Create an NGT index for vectors + let ndims = 3; + let props = QgProperties::::dimension(ndims)?.distance_type(QgDistance::L2)?; + let mut index = NgtIndex::create(dir.path(), props.try_into()?)?; + + // Insert vectors and get their ids + let nvecs = 64; + let ids = (1..ndims * nvecs) + .step_by(ndims) + .map(|i| f16::from_f32(i as f32)) + .map(|i| { + repeat(i) + .zip((0..ndims).map(|j| f16::from_f32(j as f32))) + .map(|(i, j)| i + j) + .collect() + }) + .map(|vector| index.insert(vector)) + .collect::>>()?; + + // Build and persist the index + index.build(1)?; + index.persist()?; + + // Quantize the index + let params = QgQuantizationParams { + dimension_of_subvector: 1., + max_number_of_edges: 50, + }; + let index = QgIndex::quantize(index, params)?; + + // Perform a vector search (with 3 results) + let v: Vec = (1..=ndims) + .into_iter() + .map(|x| f16::from_f32(x as f32)) + .collect(); + let query = QgQuery::new(&v).size(3); + let res = index.search(query)?; + assert!(ids[0] == res[0].id); + assert!(v == index.get_vec(ids[0])?); + + dir.close()?; + Ok(()) + } + + #[test] + fn test_qg_u8() -> StdResult<(), Box> { + // Get a temporary directory to store the index + let dir = tempdir()?; + + // Create an NGT index for vectors + let ndims = 3; + let props = QgProperties::::dimension(ndims)?.distance_type(QgDistance::L2)?; + let mut index = NgtIndex::create(dir.path(), props.try_into()?)?; + + // Insert vectors and get their ids + let nvecs = 64; + let ids = (1..ndims * nvecs) + .step_by(ndims) + .map(|i| i as u8) + .map(|i| { + repeat(i) + .zip((0..ndims).map(|j| j as u8)) + .map(|(i, j)| i + j) + .collect() + }) + .map(|vector| index.insert(vector)) + .collect::>>()?; + + // Build and persist the index + index.build(1)?; + index.persist()?; + + // Quantize the index + let params = QgQuantizationParams { + dimension_of_subvector: 1., + max_number_of_edges: 50, + }; + let index = QgIndex::quantize(index, params)?; + + // Perform a vector search (with 3 results) + let v: Vec = (1..=ndims).into_iter().map(|x| x as u8).collect(); + let query = QgQuery::new(&v).size(3); + let res = &index.search(query)?; + assert!(Vec::from_iter(res[0..3].iter().map(|r| r.id)).contains(&ids[0])); + assert!(v == index.get_vec(ids[0])?); + + dir.close()?; + Ok(()) + } +} diff --git a/src/qg/mod.rs b/src/qg/mod.rs new file mode 100644 index 0000000..9c22c13 --- /dev/null +++ b/src/qg/mod.rs @@ -0,0 +1,75 @@ +//! Quantized graph index (QG Index) +//! +//! ## Defining the properties of a new QG index: +//! +//! ```rust +//! # fn main() -> Result<(), ngt::Error> { +//! use ngt::qg::{QgProperties, QgDistance}; +//! +//! // Defaut properties with vectors of dimension 3 +//! let prop = QgProperties::::dimension(3)?; +//! +//! // Or customize values (here are the defaults) +//! let prop = QgProperties::::dimension(3)? +//! .creation_edge_size(10)? +//! .search_edge_size(40)? +//! .distance_type(QgDistance::L2)?; +//! +//! # Ok(()) +//! # } +//! ``` +//! +//! ## Creating/Opening a QG index and using it: +//! +//! ```rust +//! # fn main() -> Result<(), ngt::Error> { +//! use ngt::NgtIndex; +//! use ngt::qg::{QgDistance, QgIndex, QgProperties, QgQuantizationParams, QgQuery}; +//! +//! // Create a new quantizable NGT index +//! let prop = QgProperties::dimension(3)?; +//! let mut index: NgtIndex = +//! NgtIndex::create("target/path/to/qg_index/dir", prop.try_into()?)?; +//! +//! // Insert two vectors and get their id +//! let vec1 = vec![1.0, 2.0, 3.0]; +//! let vec2 = vec![4.0, 5.0, 6.0]; +//! let id1 = index.insert(vec1)?; +//! let id2 = index.insert(vec2)?; +//! +//! // Add enough dummy vectors to build an index +//! for i in 0..64 { +//! index.insert(vec![100. + i as f32; 3])?; +//! } +//! // Build the index in RAM and persist it on disk +//! index.build(1)?; +//! index.persist()?; +//! +//! // Quantize the NGT index +//! let params = QgQuantizationParams { +//! dimension_of_subvector: 1., +//! max_number_of_edges: 50, +//! }; +//! let index = QgIndex::quantize(index, params)?; +//! +//! // Open an existing QG index +//! let index = QgIndex::open("target/path/to/qg_index/dir")?; +//! +//! // Perform a vector search (with 1 result) +//! let query = vec![1.1, 2.1, 3.1]; +//! let res = index.search(QgQuery::new(&query).size(1))?; +//! assert_eq!(res[0].id, id1); +//! assert_eq!(index.get_vec(id1)?, vec![1.0, 2.0, 3.0]); +//! +//! # std::fs::remove_dir_all("target/path/to/qg_index/dir").unwrap(); +//! # Ok(()) +//! # } +//! ``` + +mod index; +mod properties; + +pub use self::index::{QgIndex, QgQuery}; +pub use self::properties::{ + QgDistance, QgObject, QgObjectType, QgProperties, QgQuantizationParams, +}; diff --git a/src/qg/properties.rs b/src/qg/properties.rs new file mode 100644 index 0000000..53abda3 --- /dev/null +++ b/src/qg/properties.rs @@ -0,0 +1,359 @@ +use std::marker::PhantomData; +use std::ptr; + +use half::f16; +use ngt_sys as sys; +use num_enum::TryFromPrimitive; +use scopeguard::defer; + +use crate::error::{make_err, Result}; +use crate::ngt::NgtObjectType; +use crate::{NgtDistance, NgtProperties}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, TryFromPrimitive)] +#[repr(i32)] +pub enum QgObject { + Uint8 = 1, + Float = 2, + Float16 = 3, +} + +mod private { + pub trait Sealed {} +} + +pub trait QgObjectType: private::Sealed { + fn as_obj() -> QgObject; +} + +impl private::Sealed for f32 {} +impl QgObjectType for f32 { + fn as_obj() -> QgObject { + QgObject::Float + } +} + +impl private::Sealed for u8 {} +impl QgObjectType for u8 { + fn as_obj() -> QgObject { + QgObject::Uint8 + } +} + +impl private::Sealed for f16 {} +impl QgObjectType for f16 { + fn as_obj() -> QgObject { + QgObject::Float16 + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, TryFromPrimitive)] +#[repr(i32)] +pub enum QgDistance { + L2 = 1, + Cosine = 4, +} + +impl From for NgtDistance { + fn from(d: QgDistance) -> Self { + match d { + QgDistance::L2 => NgtDistance::L2, + QgDistance::Cosine => NgtDistance::Cosine, + } + } +} + +impl TryFrom for QgDistance { + type Error = crate::Error; + + fn try_from(d: NgtDistance) -> Result { + match d { + NgtDistance::L2 => Ok(QgDistance::L2), + NgtDistance::Cosine => Ok(QgDistance::Cosine), + _ => Err(format!("Invalid distance {d:?} isn't supported for QG").into()), + } + } +} + +#[derive(Debug)] +pub struct QgProperties { + pub(crate) dimension: i32, + pub(crate) creation_edge_size: i16, + pub(crate) search_edge_size: i16, + pub(crate) object_type: QgObject, + pub(crate) distance_type: QgDistance, + pub(crate) raw_prop: sys::NGTProperty, + _marker: PhantomData, +} + +unsafe impl Send for QgProperties {} +unsafe impl Sync for QgProperties {} + +impl QgProperties +where + T: QgObjectType, +{ + pub fn dimension(dimension: usize) -> Result { + let dimension = i32::try_from(dimension)?; + let creation_edge_size = 10; + let search_edge_size = 40; + let object_type = T::as_obj(); + let distance_type = QgDistance::L2; + + unsafe { + let ebuf = sys::ngt_create_error_object(); + defer! { sys::ngt_destroy_error_object(ebuf); } + + let raw_prop = sys::ngt_create_property(ebuf); + if raw_prop.is_null() { + Err(make_err(ebuf))? + } + + Self::set_dimension(raw_prop, dimension)?; + Self::set_creation_edge_size(raw_prop, creation_edge_size)?; + Self::set_search_edge_size(raw_prop, search_edge_size)?; + Self::set_object_type(raw_prop, object_type)?; + Self::set_distance_type(raw_prop, distance_type)?; + + Ok(Self { + dimension, + creation_edge_size, + search_edge_size, + object_type, + distance_type, + raw_prop, + _marker: PhantomData, + }) + } + } + + pub fn try_clone(&self) -> Result { + unsafe { + let ebuf = sys::ngt_create_error_object(); + defer! { sys::ngt_destroy_error_object(ebuf); } + + let raw_prop = sys::ngt_create_property(ebuf); + if raw_prop.is_null() { + Err(make_err(ebuf))? + } + + Self::set_dimension(raw_prop, self.dimension)?; + Self::set_creation_edge_size(raw_prop, self.creation_edge_size)?; + Self::set_search_edge_size(raw_prop, self.search_edge_size)?; + Self::set_object_type(raw_prop, self.object_type)?; + Self::set_distance_type(raw_prop, self.distance_type)?; + + Ok(Self { + dimension: self.dimension, + creation_edge_size: self.creation_edge_size, + search_edge_size: self.search_edge_size, + object_type: self.object_type, + distance_type: self.distance_type, + raw_prop, + _marker: PhantomData, + }) + } + } + + pub(crate) fn from(index: sys::NGTIndex) -> Result { + unsafe { + let ebuf = sys::ngt_create_error_object(); + defer! { sys::ngt_destroy_error_object(ebuf); } + + let raw_prop = sys::ngt_create_property(ebuf); + if raw_prop.is_null() { + Err(make_err(ebuf))? + } + + if !sys::ngt_get_property(index, raw_prop, ebuf) { + Err(make_err(ebuf))? + } + + let dimension = sys::ngt_get_property_dimension(raw_prop, ebuf); + if dimension < 0 { + Err(make_err(ebuf))? + } + + let creation_edge_size = sys::ngt_get_property_edge_size_for_creation(raw_prop, ebuf); + if creation_edge_size < 0 { + Err(make_err(ebuf))? + } + + let search_edge_size = sys::ngt_get_property_edge_size_for_search(raw_prop, ebuf); + if search_edge_size < 0 { + Err(make_err(ebuf))? + } + + let object_type = sys::ngt_get_property_object_type(raw_prop, ebuf); + if object_type < 0 { + Err(make_err(ebuf))? + } + let object_type = QgObject::try_from(object_type)?; + + let distance_type = sys::ngt_get_property_distance_type(raw_prop, ebuf); + if distance_type < 0 { + Err(make_err(ebuf))? + } + let distance_type = QgDistance::try_from(distance_type)?; + + Ok(Self { + dimension, + creation_edge_size, + search_edge_size, + object_type, + distance_type, + raw_prop, + _marker: PhantomData, + }) + } + } + + unsafe fn set_dimension(raw_prop: sys::NGTProperty, dimension: i32) -> Result<()> { + let ebuf = sys::ngt_create_error_object(); + defer! { sys::ngt_destroy_error_object(ebuf); } + + if !sys::ngt_set_property_dimension(raw_prop, dimension, ebuf) { + Err(make_err(ebuf))? + } + + Ok(()) + } + + pub fn creation_edge_size(mut self, size: usize) -> Result { + let size = i16::try_from(size)?; + self.creation_edge_size = size; + unsafe { Self::set_creation_edge_size(self.raw_prop, size)? }; + Ok(self) + } + + unsafe fn set_creation_edge_size(raw_prop: sys::NGTProperty, size: i16) -> Result<()> { + let ebuf = sys::ngt_create_error_object(); + defer! { sys::ngt_destroy_error_object(ebuf); } + + if !sys::ngt_set_property_edge_size_for_creation(raw_prop, size, ebuf) { + Err(make_err(ebuf))? + } + + Ok(()) + } + + pub fn search_edge_size(mut self, size: usize) -> Result { + let size = i16::try_from(size)?; + self.search_edge_size = size; + unsafe { Self::set_search_edge_size(self.raw_prop, size)? }; + Ok(self) + } + + unsafe fn set_search_edge_size(raw_prop: sys::NGTProperty, size: i16) -> Result<()> { + let ebuf = sys::ngt_create_error_object(); + defer! { sys::ngt_destroy_error_object(ebuf); } + + if !sys::ngt_set_property_edge_size_for_search(raw_prop, size, ebuf) { + Err(make_err(ebuf))? + } + + Ok(()) + } + + unsafe fn set_object_type(raw_prop: sys::NGTProperty, object_type: QgObject) -> Result<()> { + let ebuf = sys::ngt_create_error_object(); + defer! { sys::ngt_destroy_error_object(ebuf); } + + match object_type { + QgObject::Uint8 => { + if !sys::ngt_set_property_object_type_integer(raw_prop, ebuf) { + Err(make_err(ebuf))? + } + } + QgObject::Float => { + if !sys::ngt_set_property_object_type_float(raw_prop, ebuf) { + Err(make_err(ebuf))? + } + } + QgObject::Float16 => { + if !sys::ngt_set_property_object_type_float16(raw_prop, ebuf) { + Err(make_err(ebuf))? + } + } + } + + Ok(()) + } + + pub fn distance_type(mut self, distance_type: QgDistance) -> Result { + self.distance_type = distance_type; + unsafe { Self::set_distance_type(self.raw_prop, distance_type)? }; + Ok(self) + } + + unsafe fn set_distance_type( + raw_prop: sys::NGTProperty, + distance_type: QgDistance, + ) -> Result<()> { + let ebuf = sys::ngt_create_error_object(); + defer! { sys::ngt_destroy_error_object(ebuf); } + + match distance_type { + QgDistance::L2 => { + if !sys::ngt_set_property_distance_type_l2(raw_prop, ebuf) { + Err(make_err(ebuf))? + } + } + QgDistance::Cosine => { + if !sys::ngt_set_property_distance_type_cosine(raw_prop, ebuf) { + Err(make_err(ebuf))? + } + } + } + + Ok(()) + } +} + +impl Drop for QgProperties { + fn drop(&mut self) { + if !self.raw_prop.is_null() { + unsafe { sys::ngt_destroy_property(self.raw_prop) }; + self.raw_prop = ptr::null_mut(); + } + } +} + +impl TryFrom> for NgtProperties +where + T: QgObjectType, + T: NgtObjectType, +{ + type Error = crate::Error; + + fn try_from(prop: QgProperties) -> Result { + NgtProperties::dimension(prop.dimension as usize)? + .creation_edge_size(prop.creation_edge_size as usize)? + .search_edge_size(prop.search_edge_size as usize)? + .distance_type(prop.distance_type.into()) + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct QgQuantizationParams { + pub dimension_of_subvector: f32, + pub max_number_of_edges: u64, +} + +impl Default for QgQuantizationParams { + fn default() -> Self { + Self { + dimension_of_subvector: 0.0, + max_number_of_edges: 128, + } + } +} + +impl QgQuantizationParams { + pub(crate) fn into_raw(self) -> sys::NGTQGQuantizationParameters { + sys::NGTQGQuantizationParameters { + dimension_of_subvector: self.dimension_of_subvector, + max_number_of_edges: self.max_number_of_edges, + } + } +}