Skip to content

Commit

Permalink
Delete filter
Browse files Browse the repository at this point in the history
  • Loading branch information
shatejas committed Jul 17, 2024
1 parent ee4b37b commit 1e33910
Show file tree
Hide file tree
Showing 8 changed files with 323 additions and 165 deletions.
11 changes: 11 additions & 0 deletions jni/include/faiss_index_service.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,17 @@ class IndexService {
std::vector<int64_t> ids,
std::string indexPath,
std::unordered_map<std::string, jobject> parameters);

/**
 * Search entry point of the index service that takes the accepted documents
 * as a Java object instead of a pre-built id array.
 *
 * The parameter list mirrors the native FaissService.searchIndex method
 * (minus the jclass), with a JNIUtilInterface prepended.
 *
 * @param jniUtil JNI utility interface used by the implementation
 * @param env JNI environment of the calling thread
 * @param indexPointerJ index handle received from Java as a jlong
 *        (presumably the address of a loaded index -- confirm in the implementation)
 * @param queryVectorJ query vector as a Java float[]
 * @param kJ number of results requested (NOTE(review): inferred from the name "k")
 * @param methodParamsJ Java object carrying method parameters (a java.util.Map per the JNI signature)
 * @param acceptedDocs Java object of type org.apache.lucene.util.Bits per the JNI signature;
 *        presumably marks which doc ids may be returned -- confirm in the implementation
 * @param parentIdsJ Java int[] of parent ids
 * @return Java array of org.opensearch.knn.index.query.KNNQueryResult per the JNI signature
 */
virtual jobjectArray searchIndex(
JNIUtilInterface * jniUtil,
JNIEnv * env,
jlong indexPointerJ,
jfloatArray queryVectorJ,
jint kJ,
jobject methodParamsJ,
jobject acceptedDocs,
jintArray parentIdsJ);

virtual ~IndexService() = default;
protected:
std::unique_ptr<FaissMethods> faissMethods;
Expand Down
9 changes: 9 additions & 0 deletions jni/include/org_opensearch_knn_jni_FaissService.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,15 @@ JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryInd
JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryBinaryIndexWithFilter
(JNIEnv *, jclass, jlong, jbyteArray, jint, jobject, jlongArray, jint, jintArray);


/*
* Class: org_opensearch_knn_jni_FaissService
* Method: searchIndex
* Signature: (J[FILjava/util/Map;Lorg/apache/lucene/util/Bits;[I)[Lorg/opensearch/knn/index/query/KNNQueryResult;
*
* Descriptor-to-native mapping for the arguments that follow (JNIEnv *, jclass):
*   J                             -> jlong
*   [F                            -> jfloatArray
*   I                             -> jint
*   Ljava/util/Map;               -> jobject
*   Lorg/apache/lucene/util/Bits; -> jobject
*   [I                            -> jintArray
* The return descriptor is an array of KNNQueryResult, surfaced as jobjectArray.
*/
JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_searchIndex
(JNIEnv *, jclass, jlong, jfloatArray, jint, jobject, jobject, jintArray);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: free
Expand Down
356 changes: 231 additions & 125 deletions jni/src/faiss_index_service.cpp

Large diffs are not rendered by default.

6 changes: 6 additions & 0 deletions jni/src/jni_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,12 @@ void knn_jni::JNIUtil::Initialize(JNIEnv *env) {
this->cachedClasses["org/opensearch/knn/index/query/KNNQueryResult"] = (jclass) env->NewGlobalRef(tempLocalClassRef);
this->cachedMethods["org/opensearch/knn/index/query/KNNQueryResult:<init>"] = env->GetMethodID(tempLocalClassRef, "<init>", "(IF)V");
env->DeleteLocalRef(tempLocalClassRef);

// Cache the org.apache.lucene.util.Bits class and its get(int)/length() method ids.
// JNI FindClass takes the fully-qualified class name in slash-separated form
// ("org/apache/lucene/util/Bits"); the "L...;" form is a field descriptor and is
// not a valid FindClass argument per the JNI specification.
// The cache keys below keep their original spelling so that any lookup using
// those exact strings continues to resolve.
tempLocalClassRef = env->FindClass("org/apache/lucene/util/Bits");
this->cachedClasses["Lorg/apache/lucene/util/Bits;"] = (jclass) env->NewGlobalRef(tempLocalClassRef);
this->cachedMethods["Lorg/apache/lucene/util/Bits;:get"] = env->GetMethodID(tempLocalClassRef, "get", "(I)Z");
this->cachedMethods["Lorg/apache/lucene/util/Bits;:length"] = env->GetMethodID(tempLocalClassRef, "length", "()I");
// The global ref above keeps the class alive; the local ref is no longer needed.
env->DeleteLocalRef(tempLocalClassRef);
}

void knn_jni::JNIUtil::Uninitialize(JNIEnv* env) {
Expand Down
14 changes: 14 additions & 0 deletions jni/src/org_opensearch_knn_jni_FaissService.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,20 @@ JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryInd

}

/*
 * JNI entry point for FaissService.searchIndex.
 *
 * Builds an IndexService around a freshly created FaissMethods instance and
 * forwards all arguments (plus the file-level jniUtil) to
 * IndexService::searchIndex. Any C++ exception is handed to
 * jniUtil.CatchCppExceptionAndThrowJava(env); in that case nullptr is returned.
 */
JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_searchIndex
(JNIEnv * env, jclass cls, jlong indexPointerJ, jfloatArray queryVectorJ, jint kJ, jobject methodParamsJ, jobject acceptedDocs, jintArray parentIdsJ) {
    // Stays nullptr unless the search completes without throwing.
    jobjectArray searchResults = nullptr;

    try {
        knn_jni::faiss_wrapper::IndexService service(
            std::unique_ptr<knn_jni::faiss_wrapper::FaissMethods>(new knn_jni::faiss_wrapper::FaissMethods()));
        searchResults = service.searchIndex(
            &jniUtil, env, indexPointerJ, queryVectorJ, kJ, methodParamsJ, acceptedDocs, parentIdsJ);
    } catch (...) {
        jniUtil.CatchCppExceptionAndThrowJava(env);
    }

    return searchResults;
}

JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryBinaryIndexWithFilter
(JNIEnv * env, jclass cls, jlong indexPointerJ, jbyteArray queryVectorJ, jint kJ, jobject methodParamsJ, jlongArray filteredIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ) {

Expand Down
66 changes: 26 additions & 40 deletions src/main/java/org/opensearch/knn/index/query/KNNWeight.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import org.apache.lucene.util.BitSetIterator;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.DocIdSetBuilder;
import org.apache.lucene.util.FixedBitSet;
import org.opensearch.common.io.PathUtils;
import org.opensearch.common.lucene.Lucene;
import org.opensearch.knn.common.KNNConstants;
Expand Down Expand Up @@ -108,28 +107,36 @@ public Explanation explain(LeafReaderContext context, int doc) {

@Override
public Scorer scorer(LeafReaderContext context) throws IOException {
final Bits liveDocs = context.reader().getLiveDocs();
final int maxDoc = context.reader().maxDoc();
if (filterWeight == null) {
// Always do approximate search if there are no filters
final Map<Integer, Float> annResults = doANNSearch(context, liveDocs);
return annResults == null ? KNNScorer.emptyScorer(this) : convertSearchResponseToScorer(annResults);
}

final BitSet filterBitSet = getFilteredDocsBitSet(context);
int cardinality = filterBitSet.cardinality();
// We don't need to go to JNI layer if no documents are found which satisfy the filters
// We should give this condition a deeper look that where it should be placed. For now I feel this is a good
// place,
if (filterWeight != null && cardinality == 0) {
Scorer scorer = filterWeight.scorer(context);
if (scorer == null) {
// If scorer is not present, there will be no top k that will match; return an empty scorer here
// and indirectly an empty response
return KNNScorer.emptyScorer(this);
}
final Map<Integer, Float> docIdsToScoreMap = new HashMap<>();

final BitSet filterBitSet = createBitSet(scorer.iterator(), liveDocs, maxDoc);
int cardinality = filterBitSet.cardinality();

/*
* The idea behind this optimization: to get K results, we need to look at at least K vectors in the HNSW
* graph. Hence, if the filtered results are fewer than K and a filter query is present, we should shift to
* exact search. This improves the recall.
*/
if (filterWeight != null && canDoExactSearch(cardinality)) {
final Map<Integer, Float> docIdsToScoreMap = new HashMap<>();
if (canDoExactSearch(cardinality)) {
docIdsToScoreMap.putAll(doExactSearch(context, filterBitSet, cardinality));
} else {
Map<Integer, Float> annResults = doANNSearch(context, filterBitSet, cardinality);
Map<Integer, Float> annResults = doANNSearch(context, filterBitSet);
if (annResults == null) {
return null;
return KNNScorer.emptyScorer(this);
}
if (canDoExactSearchAfterANNSearch(cardinality, annResults.size())) {
log.debug(
Expand All @@ -149,22 +156,6 @@ public Scorer scorer(LeafReaderContext context) throws IOException {
return convertSearchResponseToScorer(docIdsToScoreMap);
}

private BitSet getFilteredDocsBitSet(final LeafReaderContext ctx) throws IOException {
if (this.filterWeight == null) {
return new FixedBitSet(0);
}

final Bits liveDocs = ctx.reader().getLiveDocs();
final int maxDoc = ctx.reader().maxDoc();

final Scorer scorer = filterWeight.scorer(ctx);
if (scorer == null) {
return new FixedBitSet(0);
}

return createBitSet(scorer.iterator(), liveDocs, maxDoc);
}

private BitSet createBitSet(final DocIdSetIterator filteredDocIdsIterator, final Bits liveDocs, int maxDoc) throws IOException {
if (liveDocs == null && filteredDocIdsIterator instanceof BitSetIterator) {
// If we already have a BitSet and no deletions, reuse the BitSet
Expand Down Expand Up @@ -201,13 +192,11 @@ private int[] bitSetToIntArray(final BitSet bitSet) {
return intArray;
}

private Map<Integer, Float> doANNSearch(final LeafReaderContext context, final BitSet filterIdsBitSet, final int cardinality)
throws IOException {
private Map<Integer, Float> doANNSearch(final LeafReaderContext context, final Bits acceptedDocs) throws IOException {
final SegmentReader reader = Lucene.segmentReader(context.reader());
String directory = ((FSDirectory) FilterDirectory.unwrap(reader.directory())).getDirectory().toString();

FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField());

if (fieldInfo == null) {
log.debug("[KNN] Field info not found for {}:{}", knnQuery.getField(), reader.getSegmentName());
return null;
Expand Down Expand Up @@ -268,9 +257,6 @@ private Map<Integer, Float> doANNSearch(final LeafReaderContext context, final B
}

// From cardinality select different filterIds type
FilterIdsSelector filterIdsSelector = FilterIdsSelector.getFilterIdSelector(filterIdsBitSet, cardinality);
long[] filterIds = filterIdsSelector.getFilterIds();
FilterIdsSelector.FilterIdsSelectorType filterType = filterIdsSelector.getFilterType();
// Now that we have the allocation, we need to readLock it
indexAllocation.readLock();
try {
Expand All @@ -286,22 +272,22 @@ private Map<Integer, Float> doANNSearch(final LeafReaderContext context, final B
knnQuery.getK(),
knnQuery.getMethodParameters(),
knnEngine,
filterIds,
filterType.getValue(),
null,
0,
parentIds
);
} else {
results = JNIService.queryIndex(
results = JNIService.searchIndex(
indexAllocation.getMemoryAddress(),
knnQuery.getQueryVector(),
knnQuery.getK(),
knnQuery.getMethodParameters(),
knnEngine,
filterIds,
filterType.getValue(),
acceptedDocs,
parentIds
);
}

} else {
results = JNIService.radiusQueryIndex(
indexAllocation.getMemoryAddress(),
Expand All @@ -310,8 +296,8 @@ private Map<Integer, Float> doANNSearch(final LeafReaderContext context, final B
knnQuery.getMethodParameters(),
knnEngine,
knnQuery.getContext().getMaxResultWindow(),
filterIds,
filterType.getValue(),
null,
0,
parentIds
);
}
Expand Down
10 changes: 10 additions & 0 deletions src/main/java/org/opensearch/knn/jni/FaissService.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

package org.opensearch.knn.jni;

import org.apache.lucene.util.Bits;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.query.KNNQueryResult;
import org.opensearch.knn.index.util.KNNEngine;
Expand Down Expand Up @@ -202,6 +203,15 @@ public static native KNNQueryResult[] queryBinaryIndexWithFilter(
int[] parentIds
);

/**
 * Native search that receives the accepted documents as a Lucene {@link Bits} object.
 *
 * @param indexPointer handle of the index to search, passed to native code as a long
 * @param queryVector query vector
 * @param k number of results requested (NOTE(review): inferred from the parameter name)
 * @param methodParameters method parameters forwarded to native code
 * @param acceptedDocs {@link Bits} forwarded to native code; presumably marks which doc ids may be
 *                     returned -- confirm against the native implementation
 * @param parentIds parent ids forwarded to native code
 * @return array of {@link KNNQueryResult} produced by the native call
 */
public static native KNNQueryResult[] searchIndex(
long indexPointer,
float[] queryVector,
int k,
Map<String, ?> methodParameters,
Bits acceptedDocs,
int[] parentIds
);

/**
* Query a binary index with filter
*
Expand Down
16 changes: 16 additions & 0 deletions src/main/java/org/opensearch/knn/jni/JNIService.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
package org.opensearch.knn.jni;

import org.apache.commons.lang.ArrayUtils;
import org.apache.lucene.util.Bits;
import org.opensearch.common.Nullable;
import org.opensearch.knn.index.IndexUtil;
import org.opensearch.knn.index.query.KNNQueryResult;
Expand Down Expand Up @@ -167,6 +168,21 @@ public static void setSharedIndexState(long indexAddr, long shareIndexStateAddr,
);
}

/**
 * Search an index, passing the accepted documents as a Lucene {@link Bits} object.
 * Only {@link KNNEngine#FAISS} is handled; the call is delegated to
 * {@link FaissService#searchIndex}.
 *
 * @param indexPointer handle of the index to search
 * @param queryVector query vector
 * @param k number of results requested (NOTE(review): inferred from the parameter name)
 * @param methodParameters method parameters forwarded to the engine; may be null
 * @param knnEngine engine the index was built with
 * @param acceptedDocs {@link Bits} forwarded to the engine; presumably marks which doc ids may be
 *                     returned -- confirm against the native implementation
 * @param parentIds parent ids forwarded to the engine
 * @return results returned by the engine
 * @throws IllegalArgumentException if {@code knnEngine} is anything other than FAISS (including null)
 */
public static KNNQueryResult[] searchIndex(
    long indexPointer,
    float[] queryVector,
    int k,
    @Nullable Map<String, ?> methodParameters,
    KNNEngine knnEngine,
    Bits acceptedDocs,
    int[] parentIds
) {
    if (KNNEngine.FAISS == knnEngine) {
        return FaissService.searchIndex(indexPointer, queryVector, k, methodParameters, acceptedDocs, parentIds);
    }
    // Name the offending engine so an unsupported call can be diagnosed from the stack trace alone.
    throw new IllegalArgumentException("searchIndex is not supported for provided engine: " + knnEngine);
}

/**
* Query an index
*
Expand Down

0 comments on commit 1e33910

Please sign in to comment.