Skip to content

Commit

Permalink
Improve typed API
Browse files Browse the repository at this point in the history
  • Loading branch information
lerouxrgd committed Aug 16, 2023
1 parent 8dd09ac commit 182747c
Show file tree
Hide file tree
Showing 11 changed files with 450 additions and 84 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,12 @@ jobs:
- shared_mem
- large_data
- quantized
- quantized,qg_optim
- large_data,shared_mem
- large_data,quantized
- static
- static,quantized
- static,quantized,qg_optim
- static,shared_mem,large_data
steps:
- uses: actions/checkout@v3
Expand Down
5 changes: 3 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ readme = "README.md"

[dependencies]
half = "2"
ngt-sys = { path = "ngt-sys", version = "2.0.12" }
ngt-sys = { path = "ngt-sys", version = "2.1.2" }
num_enum = "0.5"
scopeguard = "1"

Expand All @@ -22,8 +22,9 @@ rayon = "1"
tempfile = "3"

[features]
default = ["quantized"] # TODO: should not be default
default = ["quantized", "qg_optim"] # TODO: should not be default
static = ["ngt-sys/static"]
shared_mem = ["ngt-sys/shared_mem"]
large_data = ["ngt-sys/large_data"]
quantized = ["ngt-sys/quantized"]
qg_optim = ["quantized", "ngt-sys/qg_optim"]
6 changes: 4 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@ This crate provides the following indexes:
* `QgIndex`: Quantized graph-based index[^2]
* `QbgIndex`: Quantized blob graph-based index

The quantized indexes are available through the `quantized` Cargo feature. Note that
Both quantized indexes are available through the `quantized` Cargo feature. Note that
they rely on `BLAS` and `LAPACK` which thus have to be installed locally. The CPU
running the code must also support `AVX2` instructions.
running the code must also support `AVX2` instructions. Furthermore, `QgIndex`
performances can be [improved][qg-optim] by using the `qg_optim` Cargo feature.

The `NgtIndex` default implementation is an ANNG, it can be optimized[^3] or converted
to an ONNG through the [`optim`][ngt-optim] module.
Expand Down Expand Up @@ -95,6 +96,7 @@ index.persist()?;
[ngt-largedata]: https://github.com/yahoojapan/NGT#large-scale-data-use
[ngt-ci]: https://github.com/lerouxrgd/ngt-rs/blob/master/.github/workflows/ci.yaml
[ngt-optim]: https://docs.rs/ngt/latest/ngt/optim/index.html
[qg-optim]: https://github.com/yahoojapan/NGT#build-parameters-1

[^1]: https://opensource.com/article/19/10/ngt-open-source-library
[^2]: https://medium.com/@masajiro.iwasaki/fusion-of-graph-based-indexing-and-product-quantization-for-ann-search-7d1f0336d0d0
Expand Down
3 changes: 2 additions & 1 deletion ngt-sys/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "ngt-sys"
version = "2.0.12"
version = "2.1.2"
authors = ["Romain Leroux <[email protected]>"]
edition = "2021"
links = "ngt"
Expand All @@ -20,3 +20,4 @@ static = ["dep:cpp_build"]
shared_mem = []
large_data = []
quantized = []
qg_optim = []
2 changes: 1 addition & 1 deletion ngt-sys/NGT
Submodule NGT updated 55 files
+3 −3 CMakeLists.txt
+31 −20 README-jp.md
+31 −19 README.md
+1 −1 VERSION
+1 −1 bin/ngt/CMakeLists.txt
+12 −12 lib/NGT/ArrayFile.h
+10 −8 lib/NGT/CMakeLists.txt
+244 −5 lib/NGT/Capi.cpp
+42 −0 lib/NGT/Capi.h
+65 −32 lib/NGT/Clustering.h
+20 −18 lib/NGT/Command.cpp
+135 −123 lib/NGT/Common.h
+76 −76 lib/NGT/Graph.cpp
+12 −12 lib/NGT/Graph.h
+11 −9 lib/NGT/GraphOptimizer.h
+44 −44 lib/NGT/GraphReconstructor.h
+1 −1 lib/NGT/HashBasedBooleanSet.h
+60 −60 lib/NGT/Index.cpp
+81 −50 lib/NGT/Index.h
+13 −13 lib/NGT/MmapManager.cpp
+20 −20 lib/NGT/MmapManager.h
+30 −30 lib/NGT/MmapManagerDefs.h
+22 −22 lib/NGT/MmapManagerImpl.hpp
+354 −44 lib/NGT/NGTQ/Capi.cpp
+72 −1 lib/NGT/NGTQ/Capi.h
+641 −0 lib/NGT/NGTQ/HierarchicalKmeans.cpp
+459 −434 lib/NGT/NGTQ/HierarchicalKmeans.h
+22 −22 lib/NGT/NGTQ/Matrix.h
+26 −16 lib/NGT/NGTQ/ObjectFile.h
+182 −86 lib/NGT/NGTQ/Optimizer.cpp
+42 −55 lib/NGT/NGTQ/Optimizer.h
+334 −176 lib/NGT/NGTQ/QbgCli.cpp
+19 −17 lib/NGT/NGTQ/QbgCli.h
+910 −544 lib/NGT/NGTQ/QuantizedBlobGraph.h
+11 −11 lib/NGT/NGTQ/QuantizedGraph.cpp
+50 −45 lib/NGT/NGTQ/QuantizedGraph.h
+857 −350 lib/NGT/NGTQ/Quantizer.h
+1 −1 lib/NGT/Node.cpp
+11 −11 lib/NGT/Node.h
+16 −16 lib/NGT/ObjectRepository.h
+98 −35 lib/NGT/ObjectSpace.h
+20 −16 lib/NGT/ObjectSpaceRepository.h
+54 −54 lib/NGT/Optimizer.h
+24 −24 lib/NGT/PrimitiveComparator.h
+4 −4 lib/NGT/SharedMemoryAllocator.cpp
+9 −9 lib/NGT/SharedMemoryAllocator.h
+3 −3 lib/NGT/Tree.cpp
+3 −3 lib/NGT/Tree.h
+17 −17 lib/NGT/Version.cpp
+1 −0 lib/NGT/defines.h.in
+3,202 −3,202 lib/NGT/half.hpp
+155 −127 python/src/ngtpy.cpp
+5 −5 samples/jaccard-sparse/jaccard-sparse.cpp
+2 −2 samples/qbg-capi/qbg-capi.cpp
+6 −3 samples/qg-capi/qg-capi.cpp
5 changes: 4 additions & 1 deletion ngt-sys/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,11 @@ fn main() {
if env::var("CARGO_FEATURE_QUANTIZED").is_err() {
config.define("NGT_QBG_DISABLED", "ON");
} else {
config.define("NGT_AVX2", "ON");
config.define("CMAKE_BUILD_TYPE", "Release");
if env::var("CARGO_FEATURE_QG_OPTIM").is_ok() {
config.define("NGTQG_NO_ROTATION", "ON");
config.define("NGTQG_ZERO_GLOBAL", "ON");
}
}
let dst = config.build();

Expand Down
71 changes: 53 additions & 18 deletions src/ngt/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ where
let rsize = sys::ngt_get_result_size(results, self.ebuf);
let mut ret = Vec::with_capacity(rsize as usize);

for i in 0..rsize as u32 {
for i in 0..rsize {
let d = sys::ngt_get_result(results, i, self.ebuf);
if d.id == 0 && d.distance == 0.0 {
Err(make_err(self.ebuf))?
Expand Down Expand Up @@ -173,7 +173,7 @@ where
let rsize = sys::ngt_get_result_size(results, self.ebuf);
let mut ret = Vec::with_capacity(rsize as usize);

for i in 0..rsize as u32 {
for i in 0..rsize {
let d = sys::ngt_get_result(results, i, self.ebuf);
if d.id == 0 && d.distance == 0.0 {
Err(make_err(self.ebuf))?
Expand All @@ -198,13 +198,13 @@ where
let id = match self.prop.object_type {
NgtObject::Float => sys::ngt_insert_index_as_float(
self.index,
vec.as_mut_ptr() as *mut _,
vec.as_mut_ptr() as *mut f32,
self.prop.dimension as u32,
self.ebuf,
),
NgtObject::Uint8 => sys::ngt_insert_index_as_uint8(
self.index,
vec.as_mut_ptr() as *mut _,
vec.as_mut_ptr() as *mut u8,
self.prop.dimension as u32,
self.ebuf,
),
Expand All @@ -227,7 +227,7 @@ where
/// discoverable yet.
///
/// **The method [`build`](NgtIndex::build) must be called after inserting vectors**.
pub fn insert_batch(&mut self, batch: Vec<Vec<f32>>) -> Result<()> {
pub fn insert_batch(&mut self, batch: Vec<Vec<T>>) -> Result<()> {
let batch_size = u32::try_from(batch.len())?;

if batch_size > 0 {
Expand All @@ -243,9 +243,38 @@ where
}

unsafe {
let mut batch = batch.into_iter().flatten().collect::<Vec<f32>>();
if !sys::ngt_batch_append_index(self.index, batch.as_mut_ptr(), batch_size, self.ebuf) {
Err(make_err(self.ebuf))?
let mut batch = batch.into_iter().flatten().collect::<Vec<T>>();
match self.prop.object_type {
NgtObject::Float => {
if !sys::ngt_batch_append_index(
self.index,
batch.as_mut_ptr() as *mut f32,
batch_size,
self.ebuf,
) {
Err(make_err(self.ebuf))?
}
}
NgtObject::Uint8 => {
if !sys::ngt_batch_append_index_as_uint8(
self.index,
batch.as_mut_ptr() as *mut u8,
batch_size,
self.ebuf,
) {
Err(make_err(self.ebuf))?
}
}
NgtObject::Float16 => {
if !sys::ngt_batch_append_index_as_float16(
self.index,
batch.as_mut_ptr() as *mut _,
batch_size,
self.ebuf,
) {
Err(make_err(self.ebuf))?
}
}
}
Ok(())
}
Expand Down Expand Up @@ -367,14 +396,15 @@ mod tests {
use std::iter;
use std::result::Result as StdResult;

use half::f16;
use rayon::prelude::*;
use tempfile::tempdir;

use super::*;
use crate::EPSILON;

#[test]
fn test_basics() -> StdResult<(), Box<dyn StdError>> {
fn test_ngt_f32_basics() -> StdResult<(), Box<dyn StdError>> {
// Get a temporary directory to store the index
let dir = tempdir()?;
if cfg!(feature = "shared_mem") {
Expand Down Expand Up @@ -441,7 +471,7 @@ mod tests {
}

#[test]
fn test_batch() -> StdResult<(), Box<dyn StdError>> {
fn test_ngt_batch() -> StdResult<(), Box<dyn StdError>> {
// Get a temporary directory to store the index
let dir = tempdir()?;
if cfg!(feature = "shared_mem") {
Expand All @@ -466,7 +496,7 @@ mod tests {
}

#[test]
fn test_u8() -> StdResult<(), Box<dyn StdError>> {
fn test_ngt_u8() -> StdResult<(), Box<dyn StdError>> {
// Get a temporary directory to store the index
let dir = tempdir()?;
if cfg!(feature = "shared_mem") {
Expand All @@ -477,8 +507,9 @@ mod tests {
let prop = NgtProperties::<u8>::dimension(3)?;
let mut index = NgtIndex::create(dir.path(), prop)?;

// Batch insert 2 vectors, build and persist the index
index.insert_batch(vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]])?;
// Insert 3 vectors, build and persist the index
index.insert_batch(vec![vec![1, 2, 3], vec![4, 5, 6]])?;
index.insert(vec![7, 8, 9])?;
index.build(2)?;
index.persist()?;

Expand All @@ -491,19 +522,23 @@ mod tests {
}

#[test]
fn test_f16() -> StdResult<(), Box<dyn StdError>> {
fn test_ngt_f16() -> StdResult<(), Box<dyn StdError>> {
// Get a temporary directory to store the index
let dir = tempdir()?;
if cfg!(feature = "shared_mem") {
std::fs::remove_dir(dir.path())?;
}

// Create an index for vectors of dimension 3
let prop = NgtProperties::<half::f16>::dimension(3)?;
let prop = NgtProperties::<f16>::dimension(3)?;
let mut index = NgtIndex::create(dir.path(), prop)?;

// Batch insert 2 vectors, build and persist the index
index.insert_batch(vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]])?;
// Insert 3 vectors, build and persist the index
index.insert_batch(vec![
vec![1.0, 2.0, 3.0].into_iter().map(f16::from_f32).collect(),
vec![4.0, 5.0, 6.0].into_iter().map(f16::from_f32).collect(),
])?;
index.insert(vec![7.0, 8.0, 9.0].into_iter().map(f16::from_f32).collect())?;
index.build(2)?;
index.persist()?;

Expand All @@ -516,7 +551,7 @@ mod tests {
}

#[test]
fn test_multithreaded() -> StdResult<(), Box<dyn StdError>> {
fn test_ngt_multithreaded() -> StdResult<(), Box<dyn StdError>> {
// Get a temporary directory to store the index
let dir = tempdir()?;
if cfg!(feature = "shared_mem") {
Expand Down
Loading

0 comments on commit 182747c

Please sign in to comment.