From 1e33910de96e913761e2ed06e37b7a542f610fb5 Mon Sep 17 00:00:00 2001 From: Tejas Shah Date: Wed, 17 Jul 2024 07:52:09 -0700 Subject: [PATCH] Delete filter --- jni/include/faiss_index_service.h | 11 + .../org_opensearch_knn_jni_FaissService.h | 9 + jni/src/faiss_index_service.cpp | 356 ++++++++++++------ jni/src/jni_util.cpp | 6 + .../org_opensearch_knn_jni_FaissService.cpp | 14 + .../opensearch/knn/index/query/KNNWeight.java | 66 ++-- .../org/opensearch/knn/jni/FaissService.java | 10 + .../org/opensearch/knn/jni/JNIService.java | 16 + 8 files changed, 323 insertions(+), 165 deletions(-) diff --git a/jni/include/faiss_index_service.h b/jni/include/faiss_index_service.h index 59f15fda9..8db4c3dd5 100644 --- a/jni/include/faiss_index_service.h +++ b/jni/include/faiss_index_service.h @@ -61,6 +61,17 @@ class IndexService { std::vector ids, std::string indexPath, std::unordered_map parameters); + + virtual jobjectArray searchIndex( + JNIUtilInterface * jniUtil, + JNIEnv * env, + jlong indexPointerJ, + jfloatArray queryVectorJ, + jint kJ, + jobject methodParamsJ, + jobject acceptedDocs, + jintArray parentIdsJ); + virtual ~IndexService() = default; protected: std::unique_ptr faissMethods; diff --git a/jni/include/org_opensearch_knn_jni_FaissService.h b/jni/include/org_opensearch_knn_jni_FaissService.h index 7cc071ff3..f20c1fa1a 100644 --- a/jni/include/org_opensearch_knn_jni_FaissService.h +++ b/jni/include/org_opensearch_knn_jni_FaissService.h @@ -107,6 +107,15 @@ JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryInd JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryBinaryIndexWithFilter (JNIEnv *, jclass, jlong, jbyteArray, jint, jobject, jlongArray, jint, jintArray); + + /* + * Class: org_opensearch_knn_jni_FaissService + * Method: searchIndex + * Signature: (J[FILjava/util/Map;Lorg/apache/lucene/util/Bits;[I)[Lorg/opensearch/knn/index/query/KNNQueryResult; + */ + JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_searchIndex + (JNIEnv *, jclass, jlong, jfloatArray, jint, jobject, jobject, jintArray); + /* * Class: org_opensearch_knn_jni_FaissService * Method: free diff --git a/jni/src/faiss_index_service.cpp b/jni/src/faiss_index_service.cpp index 8c5ba36af..2f4fba14a 100644 --- a/jni/src/faiss_index_service.cpp +++ b/jni/src/faiss_index_service.cpp @@ -19,146 +19,252 @@ #include "faiss/IndexIDMap.h" #include "faiss/index_io.h" #include +#include "commons.h" +#include "faiss_util.h" #include #include #include #include +namespace faiss { + // Using jlong to do Bitmap selector, jlong[] equals to lucene FixedBitSet#bits + struct IDSelectorBits : IDSelector { + knn_jni::JNIUtilInterface *jni_util; + JNIEnv *env; + jclass clazz; + jmethodID getMethodId; + jobject acceptedDocs; + + /** Construct with a binary mask like Lucene FixedBitSet + * + * @param n size of the bitmap array + * @param bitmap id like Lucene FixedBitSet bits + */ + IDSelectorBits(JNIEnv *jni_env, knn_jni::JNIUtilInterface *jni_util, jobject acceptedDocs) : env(jni_env), + jni_util(jni_util), + clazz(jni_util->FindClass( + jni_env, "Lorg/apache/lucene/util/Bits;")), + getMethodId( + jni_util->FindMethod( + jni_env, "Lorg/apache/lucene/util/Bits;", + "get")), + acceptedDocs(acceptedDocs) {}; + + bool is_member(idx_t id) const final { + return env->CallBooleanMethod(acceptedDocs, getMethodId, id); + } + + ~IDSelectorBits() override { + } + }; +} + + namespace knn_jni { -namespace faiss_wrapper { - -template -void SetExtraParameters(knn_jni::JNIUtilInterface * jniUtil, JNIEnv *env, - const std::unordered_map& parametersCpp, INDEX * index) { - std::unordered_map::const_iterator value; - if (auto * indexIvf = dynamic_cast(index)) { - if ((value = parametersCpp.find(knn_jni::NPROBES)) != parametersCpp.end()) { - indexIvf->nprobe = jniUtil->ConvertJavaObjectToCppInteger(env, value->second); + namespace faiss_wrapper { + template + void SetExtraParameters(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, + const std::unordered_map ¶metersCpp, INDEX *index) { + std::unordered_map::const_iterator value; + if (auto *indexIvf = dynamic_cast(index)) { + if ((value = parametersCpp.find(knn_jni::NPROBES)) != parametersCpp.end()) { + indexIvf->nprobe = jniUtil->ConvertJavaObjectToCppInteger(env, value->second); + } + + if ((value = parametersCpp.find(knn_jni::COARSE_QUANTIZER)) != parametersCpp.end() + && indexIvf->quantizer != nullptr) { + auto subParametersCpp = jniUtil->ConvertJavaMapToCppMap(env, value->second); + SetExtraParameters(jniUtil, env, subParametersCpp, indexIvf->quantizer); + } + } + + if (auto *indexHnsw = dynamic_cast(index)) { + if ((value = parametersCpp.find(knn_jni::EF_CONSTRUCTION)) != parametersCpp.end()) { + indexHnsw->hnsw.efConstruction = jniUtil->ConvertJavaObjectToCppInteger(env, value->second); + } + + if ((value = parametersCpp.find(knn_jni::EF_SEARCH)) != parametersCpp.end()) { + indexHnsw->hnsw.efSearch = jniUtil->ConvertJavaObjectToCppInteger(env, value->second); + } + } + } + + IndexService::IndexService(std::unique_ptr faissMethods) : faissMethods(std::move(faissMethods)) { } - if ((value = parametersCpp.find(knn_jni::COARSE_QUANTIZER)) != parametersCpp.end() - && indexIvf->quantizer != nullptr) { - auto subParametersCpp = jniUtil->ConvertJavaMapToCppMap(env, value->second); - SetExtraParameters(jniUtil, env, subParametersCpp, indexIvf->quantizer); + void IndexService::createIndex( + knn_jni::JNIUtilInterface *jniUtil, + JNIEnv *env, + faiss::MetricType metric, + std::string indexDescription, + int dim, + int numIds, + int threadCount, + int64_t vectorsAddress, + std::vector ids, + std::string indexPath, + std::unordered_map parameters + ) { + // Read vectors from memory address + auto *inputVectors = reinterpret_cast *>(vectorsAddress); + + // The number of vectors can be int here because a lucene segment number of total docs never crosses INT_MAX value + int numVectors = (int) (inputVectors->size() / (uint64_t) dim); + if (numVectors == 0) { + throw std::runtime_error("Number of vectors cannot be 0"); + } + + if (numIds != numVectors) { + throw std::runtime_error("Number of IDs does not match number of vectors"); + } + + std::unique_ptr + indexWriter(faissMethods->indexFactory(dim, indexDescription.c_str(), metric)); + + // Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread + if (threadCount != 0) { + omp_set_num_threads(threadCount); + } + + // Add extra parameters that cant be configured with the index factory + SetExtraParameters( + jniUtil, env, parameters, indexWriter.get()); + + // Check that the index does not need to be trained + if (!indexWriter->is_trained) { + throw std::runtime_error("Index is not trained"); + } + + // Add vectors + std::unique_ptr idMap(faissMethods->indexIdMap(indexWriter.get())); + idMap->add_with_ids(numVectors, inputVectors->data(), ids.data()); + + // Write the index to disk + faissMethods->writeIndex(idMap.get(), indexPath.c_str()); } - } - if (auto * indexHnsw = dynamic_cast(index)) { + jobjectArray IndexService::searchIndex( + JNIUtilInterface *jniUtil, + JNIEnv *env, + jlong indexPointerJ, + jfloatArray queryVectorJ, + jint kJ, + jobject methodParamsJ, + jobject acceptedDocs, + jintArray parentIdsJ) { + auto *indexReader = reinterpret_cast(indexPointerJ); + std::unordered_map methodParams; + + if (methodParamsJ != nullptr) { + methodParams = jniUtil->ConvertJavaMapToCppMap(env, methodParamsJ); + } + + std::vector dis(kJ); + std::vector ids(kJ); + float *rawQueryvector = jniUtil->GetFloatArrayElements(env, queryVectorJ, nullptr); + /* + Setting the omp_set_num_threads to 1 to make sure that no new OMP threads are getting created. + */ + omp_set_num_threads(1); + std::unique_ptr idSelector; + + faiss::SearchParameters *searchParameters; + faiss::SearchParametersHNSW hnswParams; + std::unique_ptr idGrouper; + std::vector idGrouperBitmap; + auto hnswReader = dynamic_cast(indexReader->index); + if(hnswReader) { + // Query param efsearch supersedes ef_search provided during index setting. + hnswParams.efSearch = knn_jni::commons::getIntegerMethodParameter(env, jniUtil, methodParams, EF_SEARCH, hnswReader->hnsw.efSearch); + if (acceptedDocs != nullptr) { + idSelector.reset(new faiss::IDSelectorBits(env, jniUtil, acceptedDocs)); + hnswParams.sel = idSelector.get(); + } + searchParameters = &hnswParams; + } - if ((value = parametersCpp.find(knn_jni::EF_CONSTRUCTION)) != parametersCpp.end()) { - indexHnsw->hnsw.efConstruction = jniUtil->ConvertJavaObjectToCppInteger(env, value->second); + try { + indexReader->search(1, rawQueryvector, kJ, dis.data(), ids.data(), searchParameters); + } catch (...) { + jniUtil->ReleaseFloatArrayElements(env, queryVectorJ, rawQueryvector, JNI_ABORT); + throw; + } + + int resultSize = kJ; + auto it = std::find(ids.begin(), ids.end(), -1); + if (it != ids.end()) { + resultSize = it - ids.begin(); + } + + jclass resultClass = jniUtil->FindClass(env,"org/opensearch/knn/index/query/KNNQueryResult"); + jmethodID allArgs = jniUtil->FindMethod(env, "org/opensearch/knn/index/query/KNNQueryResult", ""); + + jobjectArray results = jniUtil->NewObjectArray(env, resultSize, resultClass, nullptr); + + jobject result; + for(int i = 0; i < resultSize; ++i) { + result = jniUtil->NewObject(env, resultClass, allArgs, ids[i], dis[i]); + jniUtil->SetObjectArrayElement(env, results, i, result); + } + return results; } - if ((value = parametersCpp.find(knn_jni::EF_SEARCH)) != parametersCpp.end()) { - indexHnsw->hnsw.efSearch = jniUtil->ConvertJavaObjectToCppInteger(env, value->second); + BinaryIndexService::BinaryIndexService(std::unique_ptr faissMethods) : IndexService( + std::move(faissMethods)) { } - } -} -IndexService::IndexService(std::unique_ptr faissMethods) : faissMethods(std::move(faissMethods)) {} - -void IndexService::createIndex( - knn_jni::JNIUtilInterface * jniUtil, - JNIEnv * env, - faiss::MetricType metric, - std::string indexDescription, - int dim, - int numIds, - int threadCount, - int64_t vectorsAddress, - std::vector ids, - std::string indexPath, - std::unordered_map parameters - ) { - // Read vectors from memory address - auto *inputVectors = reinterpret_cast*>(vectorsAddress); - - // The number of vectors can be int here because a lucene segment number of total docs never crosses INT_MAX value - int numVectors = (int) (inputVectors->size() / (uint64_t) dim); - if(numVectors == 0) { - throw std::runtime_error("Number of vectors cannot be 0"); - } - - if (numIds != numVectors) { - throw std::runtime_error("Number of IDs does not match number of vectors"); - } - - std::unique_ptr indexWriter(faissMethods->indexFactory(dim, indexDescription.c_str(), metric)); - - // Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread - if(threadCount != 0) { - omp_set_num_threads(threadCount); - } - - // Add extra parameters that cant be configured with the index factory - SetExtraParameters(jniUtil, env, parameters, indexWriter.get()); - - // Check that the index does not need to be trained - if(!indexWriter->is_trained) { - throw std::runtime_error("Index is not trained"); - } - - // Add vectors - std::unique_ptr idMap(faissMethods->indexIdMap(indexWriter.get())); - idMap->add_with_ids(numVectors, inputVectors->data(), ids.data()); - - // Write the index to disk - faissMethods->writeIndex(idMap.get(), indexPath.c_str()); -} + void BinaryIndexService::createIndex( + knn_jni::JNIUtilInterface *jniUtil, + JNIEnv *env, + faiss::MetricType metric, + std::string indexDescription, + int dim, + int numIds, + int threadCount, + int64_t vectorsAddress, + std::vector ids, + std::string indexPath, + std::unordered_map parameters + ) { + // Read vectors from memory address + auto *inputVectors = reinterpret_cast *>(vectorsAddress); -BinaryIndexService::BinaryIndexService(std::unique_ptr faissMethods) : IndexService(std::move(faissMethods)) {} - -void BinaryIndexService::createIndex( - knn_jni::JNIUtilInterface * jniUtil, - JNIEnv * env, - faiss::MetricType metric, - std::string indexDescription, - int dim, - int numIds, - int threadCount, - int64_t vectorsAddress, - std::vector ids, - std::string indexPath, - std::unordered_map parameters - ) { - // Read vectors from memory address - auto *inputVectors = reinterpret_cast*>(vectorsAddress); - - if (dim % 8 != 0) { - throw std::runtime_error("Dimensions should be multiply of 8"); - } - // The number of vectors can be int here because a lucene segment number of total docs never crosses INT_MAX value - int numVectors = (int) (inputVectors->size() / (uint64_t) (dim / 8)); - if(numVectors == 0) { - throw std::runtime_error("Number of vectors cannot be 0"); - } - - if (numIds != numVectors) { - throw std::runtime_error("Number of IDs does not match number of vectors"); - } - - std::unique_ptr indexWriter(faissMethods->indexBinaryFactory(dim, indexDescription.c_str())); - - // Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread - if(threadCount != 0) { - omp_set_num_threads(threadCount); - } - - // Add extra parameters that cant be configured with the index factory - SetExtraParameters(jniUtil, env, parameters, indexWriter.get()); - - // Check that the index does not need to be trained - if(!indexWriter->is_trained) { - throw std::runtime_error("Index is not trained"); - } - - // Add vectors - std::unique_ptr idMap(faissMethods->indexBinaryIdMap(indexWriter.get())); - idMap->add_with_ids(numVectors, inputVectors->data(), ids.data()); - - // Write the index to disk - faissMethods->writeIndexBinary(idMap.get(), indexPath.c_str()); -} + if (dim % 8 != 0) { + throw std::runtime_error("Dimensions should be multiply of 8"); + } + // The number of vectors can be int here because a lucene segment number of total docs never crosses INT_MAX value + int numVectors = (int) (inputVectors->size() / (uint64_t) (dim / 8)); + if (numVectors == 0) { + throw std::runtime_error("Number of vectors cannot be 0"); + } + + if (numIds != numVectors) { + throw std::runtime_error("Number of IDs does not match number of vectors"); + } + + std::unique_ptr indexWriter( + faissMethods->indexBinaryFactory(dim, indexDescription.c_str())); + + // Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread + if (threadCount != 0) { + omp_set_num_threads(threadCount); + } -} // namespace faiss_wrapper + // Add extra parameters that cant be configured with the index factory + SetExtraParameters( + jniUtil, env, parameters, indexWriter.get()); + + // Check that the index does not need to be trained + if (!indexWriter->is_trained) { + throw std::runtime_error("Index is not trained"); + } + + // Add vectors + std::unique_ptr idMap(faissMethods->indexBinaryIdMap(indexWriter.get())); + idMap->add_with_ids(numVectors, inputVectors->data(), ids.data()); + + // Write the index to disk + faissMethods->writeIndexBinary(idMap.get(), indexPath.c_str()); + } + } // namespace faiss_wrapper } // namesapce knn_jni diff --git a/jni/src/jni_util.cpp b/jni/src/jni_util.cpp index 919191596..410fc1497 100644 --- a/jni/src/jni_util.cpp +++ b/jni/src/jni_util.cpp @@ -65,6 +65,12 @@ void knn_jni::JNIUtil::Initialize(JNIEnv *env) { this->cachedClasses["org/opensearch/knn/index/query/KNNQueryResult"] = (jclass) env->NewGlobalRef(tempLocalClassRef); this->cachedMethods["org/opensearch/knn/index/query/KNNQueryResult:"] = env->GetMethodID(tempLocalClassRef, "", "(IF)V"); env->DeleteLocalRef(tempLocalClassRef); + + tempLocalClassRef = env->FindClass("Lorg/apache/lucene/util/Bits;"); + this->cachedClasses["Lorg/apache/lucene/util/Bits;"] = (jclass) env->NewGlobalRef(tempLocalClassRef); + this->cachedMethods["Lorg/apache/lucene/util/Bits;:get"] = env->GetMethodID(tempLocalClassRef, "get", "(I)Z"); + this->cachedMethods["Lorg/apache/lucene/util/Bits;:length"] = env->GetMethodID(tempLocalClassRef, "length", "()I"); + env->DeleteLocalRef(tempLocalClassRef); } void knn_jni::JNIUtil::Uninitialize(JNIEnv* env) { diff --git a/jni/src/org_opensearch_knn_jni_FaissService.cpp b/jni/src/org_opensearch_knn_jni_FaissService.cpp index 6e447b034..853f0f37c 100644 --- a/jni/src/org_opensearch_knn_jni_FaissService.cpp +++ b/jni/src/org_opensearch_knn_jni_FaissService.cpp @@ -167,6 +167,20 @@ JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryInd } +JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_searchIndex + (JNIEnv * env, jclass cls, jlong indexPointerJ, jfloatArray queryVectorJ, jint kJ, jobject methodParamsJ, jobject acceptedDocs, jintArray parentIdsJ) { + + try { + std::unique_ptr faissMethods(new knn_jni::faiss_wrapper::FaissMethods()); + knn_jni::faiss_wrapper::IndexService indexService(std::move(faissMethods)); + return indexService.searchIndex(&jniUtil, env, indexPointerJ, queryVectorJ, kJ, methodParamsJ, acceptedDocs, parentIdsJ); + } catch (...) { + jniUtil.CatchCppExceptionAndThrowJava(env); + } + return nullptr; + +} + JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryBinaryIndexWithFilter (JNIEnv * env, jclass cls, jlong indexPointerJ, jbyteArray queryVectorJ, jint kJ, jobject methodParamsJ, jlongArray filteredIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ) { diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java index a061e740e..3394b0d31 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java @@ -26,7 +26,6 @@ import org.apache.lucene.util.BitSetIterator; import org.apache.lucene.util.Bits; import org.apache.lucene.util.DocIdSetBuilder; -import org.apache.lucene.util.FixedBitSet; import org.opensearch.common.io.PathUtils; import org.opensearch.common.lucene.Lucene; import org.opensearch.knn.common.KNNConstants; @@ -108,28 +107,36 @@ public Explanation explain(LeafReaderContext context, int doc) { @Override public Scorer scorer(LeafReaderContext context) throws IOException { + final Bits liveDocs = context.reader().getLiveDocs(); + final int maxDoc = context.reader().maxDoc(); + if (filterWeight == null) { + // Always do approximate search if there are no filters + final Map annResults = doANNSearch(context, liveDocs); + return annResults == null ? KNNScorer.emptyScorer(this) : convertSearchResponseToScorer(annResults); + } - final BitSet filterBitSet = getFilteredDocsBitSet(context); - int cardinality = filterBitSet.cardinality(); - // We don't need to go to JNI layer if no documents are found which satisfy the filters - // We should give this condition a deeper look that where it should be placed. For now I feel this is a good - // place, - if (filterWeight != null && cardinality == 0) { + Scorer scorer = filterWeight.scorer(context); + if (scorer == null) { + // If scorer is not present, there will be no top k that will match; return an empty scorer here + // and indirectly an empty response return KNNScorer.emptyScorer(this); } - final Map docIdsToScoreMap = new HashMap<>(); + + final BitSet filterBitSet = createBitSet(scorer.iterator(), liveDocs, maxDoc); + int cardinality = filterBitSet.cardinality(); /* * The idea for this optimization is to get K results, we need to atleast look at K vectors in the HNSW graph * . Hence, if filtered results are less than K and filter query is present we should shift to exact search. * This improves the recall. */ - if (filterWeight != null && canDoExactSearch(cardinality)) { + final Map docIdsToScoreMap = new HashMap<>(); + if (canDoExactSearch(cardinality)) { docIdsToScoreMap.putAll(doExactSearch(context, filterBitSet, cardinality)); } else { - Map annResults = doANNSearch(context, filterBitSet, cardinality); + Map annResults = doANNSearch(context, filterBitSet); if (annResults == null) { - return null; + return KNNScorer.emptyScorer(this); } if (canDoExactSearchAfterANNSearch(cardinality, annResults.size())) { log.debug( @@ -149,22 +156,6 @@ public Scorer scorer(LeafReaderContext context) throws IOException { return convertSearchResponseToScorer(docIdsToScoreMap); } - private BitSet getFilteredDocsBitSet(final LeafReaderContext ctx) throws IOException { - if (this.filterWeight == null) { - return new FixedBitSet(0); - } - - final Bits liveDocs = ctx.reader().getLiveDocs(); - final int maxDoc = ctx.reader().maxDoc(); - - final Scorer scorer = filterWeight.scorer(ctx); - if (scorer == null) { - return new FixedBitSet(0); - } - - return createBitSet(scorer.iterator(), liveDocs, maxDoc); - } - private BitSet createBitSet(final DocIdSetIterator filteredDocIdsIterator, final Bits liveDocs, int maxDoc) throws IOException { if (liveDocs == null && filteredDocIdsIterator instanceof BitSetIterator) { // If we already have a BitSet and no deletions, reuse the BitSet @@ -201,13 +192,11 @@ private int[] bitSetToIntArray(final BitSet bitSet) { return intArray; } - private Map doANNSearch(final LeafReaderContext context, final BitSet filterIdsBitSet, final int cardinality) - throws IOException { + private Map doANNSearch(final LeafReaderContext context, final Bits acceptedDocs) throws IOException { final SegmentReader reader = Lucene.segmentReader(context.reader()); String directory = ((FSDirectory) FilterDirectory.unwrap(reader.directory())).getDirectory().toString(); FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField()); - if (fieldInfo == null) { log.debug("[KNN] Field info not found for {}:{}", knnQuery.getField(), reader.getSegmentName()); return null; @@ -268,9 +257,6 @@ private Map doANNSearch(final LeafReaderContext context, final B } // From cardinality select different filterIds type - FilterIdsSelector filterIdsSelector = FilterIdsSelector.getFilterIdSelector(filterIdsBitSet, cardinality); - long[] filterIds = filterIdsSelector.getFilterIds(); - FilterIdsSelector.FilterIdsSelectorType filterType = filterIdsSelector.getFilterType(); // Now that we have the allocation, we need to readLock it indexAllocation.readLock(); try { @@ -286,22 +272,22 @@ private Map doANNSearch(final LeafReaderContext context, final B knnQuery.getK(), knnQuery.getMethodParameters(), knnEngine, - filterIds, - filterType.getValue(), + null, + 0, parentIds ); } else { - results = JNIService.queryIndex( + results = JNIService.searchIndex( indexAllocation.getMemoryAddress(), knnQuery.getQueryVector(), knnQuery.getK(), knnQuery.getMethodParameters(), knnEngine, - filterIds, - filterType.getValue(), + acceptedDocs, parentIds ); } + } else { results = JNIService.radiusQueryIndex( indexAllocation.getMemoryAddress(), @@ -310,8 +296,8 @@ private Map doANNSearch(final LeafReaderContext context, final B knnQuery.getMethodParameters(), knnEngine, knnQuery.getContext().getMaxResultWindow(), - filterIds, - filterType.getValue(), + null, + 0, parentIds ); } diff --git a/src/main/java/org/opensearch/knn/jni/FaissService.java b/src/main/java/org/opensearch/knn/jni/FaissService.java index 21de90765..e3badc1f5 100644 --- a/src/main/java/org/opensearch/knn/jni/FaissService.java +++ b/src/main/java/org/opensearch/knn/jni/FaissService.java @@ -11,6 +11,7 @@ package org.opensearch.knn.jni; +import org.apache.lucene.util.Bits; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.query.KNNQueryResult; import org.opensearch.knn.index.util.KNNEngine; @@ -202,6 +203,15 @@ public static native KNNQueryResult[] queryBinaryIndexWithFilter( int[] parentIds ); + public static native KNNQueryResult[] searchIndex( + long indexPointer, + float[] queryVector, + int k, + Map methodParameters, + Bits acceptedDocs, + int[] parentIds + ); + /** * Query a binary index with filter * diff --git a/src/main/java/org/opensearch/knn/jni/JNIService.java b/src/main/java/org/opensearch/knn/jni/JNIService.java index cefd0af53..a51e72654 100644 --- a/src/main/java/org/opensearch/knn/jni/JNIService.java +++ b/src/main/java/org/opensearch/knn/jni/JNIService.java @@ -12,6 +12,7 @@ package org.opensearch.knn.jni; import org.apache.commons.lang.ArrayUtils; +import org.apache.lucene.util.Bits; import org.opensearch.common.Nullable; import org.opensearch.knn.index.IndexUtil; import org.opensearch.knn.index.query.KNNQueryResult; @@ -167,6 +168,21 @@ public static void setSharedIndexState(long indexAddr, long shareIndexStateAddr, ); } + public static KNNQueryResult[] searchIndex( + long indexPointer, + float[] queryVector, + int k, + @Nullable Map methodParameters, + KNNEngine knnEngine, + Bits acceptedDocs, + int[] parentIds + ) { + if (KNNEngine.FAISS == knnEngine) { + return FaissService.searchIndex(indexPointer, queryVector, k, methodParameters, acceptedDocs, parentIds); + } + throw new IllegalArgumentException(); + } + /** * Query an index *