Skip to content

Commit

Permalink
Indexes are now templated by their content types
Browse files Browse the repository at this point in the history
  • Loading branch information
lerouxrgd committed Jun 9, 2023
1 parent cddfc42 commit 8dd09ac
Show file tree
Hide file tree
Showing 15 changed files with 574 additions and 407 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ readme = "README.md"

[dependencies]
half = "2"
ngt-sys = { path = "ngt-sys", version = "2.0.11" }
ngt-sys = { path = "ngt-sys", version = "2.0.12" }
num_enum = "0.5"
scopeguard = "1"

Expand Down
15 changes: 7 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ ten to several thousand dimensions).

This crate provides the following indexes:
* `NgtIndex`: Graph and tree-based index[^1]
* `QqIndex`: Quantized graph-based index[^2]
* `QgIndex`: Quantized graph-based index[^2]
* `QbgIndex`: Quantized blob graph-based index

The quantized indexes are available through the `quantized` Cargo feature. Note that
Expand All @@ -37,16 +37,15 @@ features are available through the features `shared_mem` and `large_data` respec
Defining the properties of a new index:

```rust,ignore
use ngt::{NgtProperties, NgtDistance, NgtObject};
use ngt::{NgtProperties, NgtDistance};
// Default properties with vectors of dimension 3
let prop = NgtProperties::dimension(3)?;
let prop = NgtProperties::<f32>::dimension(3)?;
// Or customize values (here are the defaults)
let prop = NgtProperties::dimension(3)?
let prop = NgtProperties::<f32>::dimension(3)?
.creation_edge_size(10)?
.search_edge_size(40)?
.object_type(NgtObject::Float)?
.distance_type(NgtDistance::L2)?;
```

Expand All @@ -57,7 +56,7 @@ use ngt::{NgtIndex, NgtProperties, EPSILON};
// Create a new index
let prop = NgtProperties::dimension(3)?;
let index = NgtIndex::create("target/path/to/index/dir", prop)?;
let index: NgtIndex<f32> = NgtIndex::create("target/path/to/index/dir", prop)?;
// Open an existing index
let mut index = NgtIndex::open("target/path/to/index/dir")?;
Expand All @@ -68,7 +67,7 @@ let vec2 = vec![4.0, 5.0, 6.0];
let id1 = index.insert(vec1)?;
let id2 = index.insert(vec2)?;
// Actually build the index (not yet persisted on disk)
// Build the index in RAM (not yet persisted on disk)
// This is required in order to be able to search vectors
index.build(2)?;
Expand All @@ -80,7 +79,7 @@ assert_eq!(index.get_vec(id1)?, vec![1.0, 2.0, 3.0]);
// Remove a vector and check that it is not present anymore
index.remove(id1)?;
let res = index.get_vec(id1);
assert!(matches!(res, Result::Err(_)));
assert!(res.is_err());
// Verify that now our search result is different
let res = index.search(&vec![1.1, 2.1, 3.1], 1, EPSILON)?;
Expand Down
2 changes: 1 addition & 1 deletion ngt-sys/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "ngt-sys"
version = "2.0.11"
version = "2.0.12"
authors = ["Romain Leroux <[email protected]>"]
edition = "2021"
links = "ngt"
Expand Down
2 changes: 1 addition & 1 deletion ngt-sys/NGT
3 changes: 3 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,6 @@ pub const EPSILON: f32 = 0.1;

pub use crate::error::{Error, Result};
pub use crate::ngt::{optim, NgtDistance, NgtIndex, NgtObject, NgtProperties};

#[doc(inline)]
pub use half;
78 changes: 48 additions & 30 deletions src/ngt/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,25 +9,28 @@ use std::ptr;
use ngt_sys as sys;
use scopeguard::defer;

use super::{NgtObject, NgtProperties};
use super::{NgtObject, NgtObjectType, NgtProperties};
use crate::error::{make_err, Error, Result};
use crate::{SearchResult, VecId};

#[derive(Debug)]
pub struct NgtIndex {
pub struct NgtIndex<T> {
pub(crate) path: CString,
pub(crate) prop: NgtProperties,
pub(crate) prop: NgtProperties<T>,
pub(crate) index: sys::NGTIndex,
ospace: sys::NGTObjectSpace,
ebuf: sys::NGTError,
}

unsafe impl Send for NgtIndex {}
unsafe impl Sync for NgtIndex {}
unsafe impl<T> Send for NgtIndex<T> {}
unsafe impl<T> Sync for NgtIndex<T> {}

impl NgtIndex {
impl<T> NgtIndex<T>
where
T: NgtObjectType,
{
/// Creates an empty ANNG index with the given [`NgtProperties`].
pub fn create<P: AsRef<Path>>(path: P, prop: NgtProperties) -> Result<Self> {
pub fn create<P: AsRef<Path>>(path: P, prop: NgtProperties<T>) -> Result<Self> {
if cfg!(feature = "shared_mem") && path.as_ref().exists() {
Err(Error(format!("Path {:?} already exists", path.as_ref())))?
}
Expand Down Expand Up @@ -190,19 +193,33 @@ impl NgtIndex {
/// discoverable yet.
///
/// **The method [`build`](NgtIndex::build) must be called after inserting vectors**.
pub fn insert(&mut self, mut vec: Vec<f32>) -> Result<VecId> {
pub fn insert(&mut self, mut vec: Vec<T>) -> Result<VecId> {
unsafe {
let id = sys::ngt_insert_index_as_float(
self.index,
vec.as_mut_ptr(),
self.prop.dimension as u32,
self.ebuf,
);
let id = match self.prop.object_type {
NgtObject::Float => sys::ngt_insert_index_as_float(
self.index,
vec.as_mut_ptr() as *mut _,
self.prop.dimension as u32,
self.ebuf,
),
NgtObject::Uint8 => sys::ngt_insert_index_as_uint8(
self.index,
vec.as_mut_ptr() as *mut _,
self.prop.dimension as u32,
self.ebuf,
),
NgtObject::Float16 => sys::ngt_insert_index_as_float16(
self.index,
vec.as_mut_ptr() as *mut _,
self.prop.dimension as u32,
self.ebuf,
),
};
if id == 0 {
Err(make_err(self.ebuf))?
} else {
Ok(id)
}

Ok(id)
}
}

Expand Down Expand Up @@ -265,9 +282,9 @@ impl NgtIndex {
}

/// Get the specified vector.
pub fn get_vec(&self, id: VecId) -> Result<Vec<f32>> {
pub fn get_vec(&self, id: VecId) -> Result<Vec<T>> {
unsafe {
let results = match self.prop.object_type {
match self.prop.object_type {
NgtObject::Float => {
let results = sys::ngt_get_object_as_float(self.ospace, id, self.ebuf);
if results.is_null() {
Expand All @@ -281,7 +298,8 @@ impl NgtIndex {
);
let results = mem::ManuallyDrop::new(results);

results.iter().copied().collect::<Vec<_>>()
let results = results.iter().copied().collect::<Vec<_>>();
Ok(mem::transmute::<_, Vec<T>>(results))
}
NgtObject::Float16 => {
let results = sys::ngt_get_object(self.ospace, id, self.ebuf);
Expand All @@ -296,7 +314,8 @@ impl NgtIndex {
);
let results = mem::ManuallyDrop::new(results);

results.iter().map(|f16| f16.to_f32()).collect::<Vec<_>>()
let results = results.iter().copied().collect::<Vec<_>>();
Ok(mem::transmute::<_, Vec<T>>(results))
}
NgtObject::Uint8 => {
let results = sys::ngt_get_object_as_integer(self.ospace, id, self.ebuf);
Expand All @@ -311,11 +330,10 @@ impl NgtIndex {
);
let results = mem::ManuallyDrop::new(results);

results.iter().map(|&byte| byte as f32).collect::<Vec<_>>()
let results = results.iter().copied().collect::<Vec<_>>();
Ok(mem::transmute::<_, Vec<T>>(results))
}
};

Ok(results)
}
}
}

Expand All @@ -330,7 +348,7 @@ impl NgtIndex {
}
}

impl Drop for NgtIndex {
impl<T> Drop for NgtIndex<T> {
fn drop(&mut self) {
if !self.index.is_null() {
unsafe { sys::ngt_close_index(self.index) };
Expand Down Expand Up @@ -364,7 +382,7 @@ mod tests {
}

// Create an index for vectors of dimension 3
let prop = NgtProperties::dimension(3)?;
let prop = NgtProperties::<f32>::dimension(3)?;
let mut index = NgtIndex::create(dir.path(), prop)?;

// Insert two vectors and get their id
Expand Down Expand Up @@ -431,7 +449,7 @@ mod tests {
}

// Create an index for vectors of dimension 3
let prop = NgtProperties::dimension(3)?;
let prop = NgtProperties::<f32>::dimension(3)?;
let mut index = NgtIndex::create(dir.path(), prop)?;

// Batch insert 2 vectors, build and persist the index
Expand All @@ -456,7 +474,7 @@ mod tests {
}

// Create an index for vectors of dimension 3
let prop = NgtProperties::dimension(3)?.object_type(NgtObject::Uint8)?;
let prop = NgtProperties::<u8>::dimension(3)?;
let mut index = NgtIndex::create(dir.path(), prop)?;

// Batch insert 2 vectors, build and persist the index
Expand All @@ -481,7 +499,7 @@ mod tests {
}

// Create an index for vectors of dimension 3
let prop = NgtProperties::dimension(3)?.object_type(NgtObject::Float16)?;
let prop = NgtProperties::<half::f16>::dimension(3)?;
let mut index = NgtIndex::create(dir.path(), prop)?;

// Batch insert 2 vectors, build and persist the index
Expand All @@ -506,7 +524,7 @@ mod tests {
}

// Create an index for vectors of dimension 3
let prop = NgtProperties::dimension(3)?;
let prop = NgtProperties::<f32>::dimension(3)?;
let mut index = NgtIndex::create(dir.path(), prop)?;

let vecs = vec![
Expand Down
15 changes: 7 additions & 8 deletions src/ngt/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,15 @@
//!
//! ```rust
//! # fn main() -> Result<(), ngt::Error> {
//! use ngt::{NgtProperties, NgtDistance, NgtObject};
//! use ngt::{NgtProperties, NgtDistance};
//!
//! // Default properties with vectors of dimension 3
//! let prop = NgtProperties::dimension(3)?;
//! let prop = NgtProperties::<f32>::dimension(3)?;
//!
//! // Or customize values (here are the defaults)
//! let prop = NgtProperties::dimension(3)?
//! let prop = NgtProperties::<f32>::dimension(3)?
//! .creation_edge_size(10)?
//! .search_edge_size(40)?
//! .object_type(NgtObject::Float)?
//! .distance_type(NgtDistance::L2)?;
//!
//! # Ok(())
Expand All @@ -26,7 +25,7 @@
//!
//! // Create a new index
//! let prop = NgtProperties::dimension(3)?;
//! let index = NgtIndex::create("target/path/to/index/dir", prop)?;
//! let index: NgtIndex<f32> = NgtIndex::create("target/path/to/index/dir", prop)?;
//!
//! // Open an existing index
//! let mut index = NgtIndex::open("target/path/to/index/dir")?;
Expand All @@ -37,7 +36,7 @@
//! let id1 = index.insert(vec1)?;
//! let id2 = index.insert(vec2)?;
//!
//! // Actually build the index (not yet persisted on disk)
//! // Build the index in RAM (not yet persisted on disk)
//! // This is required in order to be able to search vectors
//! index.build(2)?;
//!
Expand All @@ -49,7 +48,7 @@
//! // Remove a vector and check that it is not present anymore
//! index.remove(id1)?;
//! let res = index.get_vec(id1);
//! assert!(matches!(res, Result::Err(_)));
//! assert!(res.is_err());
//!
//! // Verify that now our search result is different
//! let res = index.search(&vec![1.1, 2.1, 3.1], 1, EPSILON)?;
Expand All @@ -69,4 +68,4 @@ pub mod optim;
mod properties;

pub use self::index::NgtIndex;
pub use self::properties::{NgtDistance, NgtObject, NgtProperties};
pub use self::properties::{NgtDistance, NgtObject, NgtObjectType, NgtProperties};
Loading

0 comments on commit 8dd09ac

Please sign in to comment.