Bug fix for total hits counts mismatch in hybrid query (#757)
vibrantvarun committed May 24, 2024
1 parent 940a7ea commit 70d0975
Showing 8 changed files with 151 additions and 60 deletions.
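
A minimal standalone sketch of the mismatch being fixed (plain Java with hypothetical doc ids and size, not the plugin classes): before this change, the total hit count was derived from the size-capped per-sub-query top docs, which under-reports the number of matching documents on a shard; after it, every unique document seen by the collector is counted.

import java.util.HashSet;
import java.util.List;
import java.util.Set;

public class TotalHitsMismatchSketch {
    public static void main(String[] args) {
        int requestedSize = 1; // the "size" asked for in the search request
        // Hypothetical shard: sub-query 1 matches doc 0, sub-query 2 matches docs 1 and 2
        List<List<Integer>> matchedDocIdsPerSubQuery = List.of(List.of(0), List.of(1, 2));

        // Before the fix: unique doc ids taken from the size-capped top docs of each sub-query
        Set<Integer> uniqueInTopDocs = new HashSet<>();
        for (List<Integer> docs : matchedDocIdsPerSubQuery) {
            docs.stream().limit(requestedSize).forEach(uniqueInTopDocs::add);
        }
        System.out.println("total reported before the fix: " + uniqueInTopDocs.size()); // 2

        // After the fix: every unique doc collected on the shard is counted, regardless of size
        Set<Integer> uniqueOnShard = new HashSet<>();
        matchedDocIdsPerSubQuery.forEach(uniqueOnShard::addAll);
        System.out.println("total reported after the fix: " + uniqueOnShard.size()); // 3
    }
}
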
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -19,6 +19,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Optimize parameter parsing in text chunking processor ([#733](https://github.com/opensearch-project/neural-search/pull/733))
- Use lazy initialization for priority queue of hits and scores to improve latencies by 20% ([#746](https://github.com/opensearch-project/neural-search/pull/746))
### Bug Fixes
- Total hit count fix in Hybrid Query ([#756](https://github.com/opensearch-project/neural-search/pull/756))
### Infrastructure
### Documentation
### Maintenance
@@ -5,12 +5,10 @@
package org.opensearch.neuralsearch.processor.combination;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;

import org.apache.lucene.search.ScoreDoc;
@@ -80,16 +78,15 @@ private List<ScoreDoc> getCombinedScoreDocs(
final CompoundTopDocs compoundQueryTopDocs,
final Map<Integer, Float> combinedNormalizedScoresByDocId,
final List<Integer> sortedScores,
final int maxHits
final long maxHits
) {
ScoreDoc[] finalScoreDocs = new ScoreDoc[maxHits];

List<ScoreDoc> scoreDocs = new ArrayList<>();
int shardId = compoundQueryTopDocs.getScoreDocs().get(0).shardIndex;
for (int j = 0; j < maxHits && j < sortedScores.size(); j++) {
int docId = sortedScores.get(j);
finalScoreDocs[j] = new ScoreDoc(docId, combinedNormalizedScoresByDocId.get(docId), shardId);
scoreDocs.add(new ScoreDoc(docId, combinedNormalizedScoresByDocId.get(docId), shardId));
}
return Arrays.stream(finalScoreDocs).collect(Collectors.toList());
return scoreDocs;
}

public Map<Integer, float[]> getNormalizedScoresPerDocument(final List<TopDocs> topDocsPerSubQuery) {
@@ -123,30 +120,16 @@ private void updateQueryTopDocsWithCombinedScores(
final Map<Integer, Float> combinedNormalizedScoresByDocId,
final List<Integer> sortedScores
) {
// - count max number of hits among sub-queries
int maxHits = getMaxHits(topDocsPerSubQuery);
// - max number of hits will be the same as the value passed from the QueryPhase
long maxHits = compoundQueryTopDocs.getTotalHits().value;
// - update query search results with normalized scores
compoundQueryTopDocs.setScoreDocs(
getCombinedScoreDocs(compoundQueryTopDocs, combinedNormalizedScoresByDocId, sortedScores, maxHits)
);
compoundQueryTopDocs.setTotalHits(getTotalHits(topDocsPerSubQuery, maxHits));
}

/**
* Get max hits as number of unique doc ids from results of all sub-queries
* @param topDocsPerSubQuery list of topDocs objects for one shard
* @return number of unique doc ids
*/
protected int getMaxHits(final List<TopDocs> topDocsPerSubQuery) {
Set<Integer> docIds = topDocsPerSubQuery.stream()
.filter(topDocs -> Objects.nonNull(topDocs.scoreDocs))
.flatMap(topDocs -> Arrays.stream(topDocs.scoreDocs))
.map(scoreDoc -> scoreDoc.doc)
.collect(Collectors.toSet());
return docIds.size();
}

private TotalHits getTotalHits(final List<TopDocs> topDocsPerSubQuery, int maxHits) {
private TotalHits getTotalHits(final List<TopDocs> topDocsPerSubQuery, final long maxHits) {
TotalHits.Relation totalHits = TotalHits.Relation.EQUAL_TO;
if (topDocsPerSubQuery.stream().anyMatch(topDocs -> topDocs.totalHits.relation == TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO)) {
totalHits = TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO;
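
The two hunks above change the meaning of maxHits: it is no longer the number of unique doc ids found in the merged results but the shard-level total handed down from the query phase, which can exceed how many documents actually made the combined top results. That is why the pre-sized array was replaced with a list that only grows for real documents; a small sketch (plain Java, hypothetical doc ids, not the plugin types) of the difference:

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

public class CombinedDocsSketch {
    public static void main(String[] args) {
        long maxHits = 5;                            // shard-level total hits (new meaning)
        List<Integer> sortedDocIds = List.of(7, 3);  // only 2 docs made the combined top results

        // Pre-sized array: leaves trailing nulls whenever maxHits > sortedDocIds.size()
        Integer[] preSized = new Integer[(int) maxHits];
        for (int j = 0; j < maxHits && j < sortedDocIds.size(); j++) {
            preSized[j] = sortedDocIds.get(j);
        }
        System.out.println(Arrays.stream(preSized).collect(Collectors.toList())); // [7, 3, null, null, null]

        // Growable list: holds only the documents that were actually scored
        List<Integer> compact = new ArrayList<>();
        for (int j = 0; j < maxHits && j < sortedDocIds.size(); j++) {
            compact.add(sortedDocIds.get(j));
        }
        System.out.println(compact); // [7, 3]
    }
}
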
@@ -9,9 +9,8 @@
import java.util.List;
import java.util.Locale;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import lombok.Getter;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.HitQueue;
@@ -35,7 +34,9 @@ public class HybridTopScoreDocCollector implements Collector {
private int docBase;
private final HitsThresholdChecker hitsThresholdChecker;
private TotalHits.Relation totalHitsRelation = TotalHits.Relation.EQUAL_TO;
private int[] totalHits;
@Getter
private int totalHits;
private int[] collectedHitsPerSubQuery;
private final int numOfHits;
private PriorityQueue<ScoreDoc>[] compoundScores;

@@ -94,23 +95,24 @@ public void collect(int doc) throws IOException {
if (Objects.isNull(compoundQueryScorer)) {
throw new IllegalArgumentException("scorers are null for all sub-queries in hybrid query");
}

float[] subScoresByQuery = compoundQueryScorer.hybridScores();
// iterate over results for each query
if (compoundScores == null) {
compoundScores = new PriorityQueue[subScoresByQuery.length];
for (int i = 0; i < subScoresByQuery.length; i++) {
compoundScores[i] = new HitQueue(numOfHits, false);
}
totalHits = new int[subScoresByQuery.length];
collectedHitsPerSubQuery = new int[subScoresByQuery.length];
}
// Increment the total hit count, which represents the number of unique docs found on the shard
totalHits++;
for (int i = 0; i < subScoresByQuery.length; i++) {
float score = subScoresByQuery[i];
// if score is 0.0 there are no hits for that sub-query
if (score == 0) {
continue;
}
totalHits[i]++;
collectedHitsPerSubQuery[i]++;
PriorityQueue<ScoreDoc> pq = compoundScores[i];
ScoreDoc currentDoc = new ScoreDoc(doc + docBase, score);
// this way we're inserting into the heap and doing nothing else unless we reach capacity
@@ -134,9 +136,17 @@ public List<TopDocs> topDocs() {
if (compoundScores == null) {
return new ArrayList<>();
}
final List<TopDocs> topDocs = IntStream.range(0, compoundScores.length)
.mapToObj(i -> topDocsPerQuery(0, Math.min(totalHits[i], compoundScores[i].size()), compoundScores[i], totalHits[i]))
.collect(Collectors.toList());
final List<TopDocs> topDocs = new ArrayList<>();
for (int i = 0; i < compoundScores.length; i++) {
topDocs.add(
topDocsPerQuery(
0,
Math.min(collectedHitsPerSubQuery[i], compoundScores[i].size()),
compoundScores[i],
collectedHitsPerSubQuery[i]
)
);
}
return topDocs;
}

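
To summarize the collector changes above: totalHits is now incremented once per call to collect, so it tracks every unique document matched on the shard, while collectedHitsPerSubQuery only counts documents with a non-zero score for a given sub-query and is used to size that sub-query's TopDocs. A self-contained sketch of just the counting logic (hypothetical score arrays, not the real scorer):

import java.util.Arrays;

public class CollectorCountingSketch {
    public static void main(String[] args) {
        // Each row is the per-sub-query score array for one collected doc;
        // a 0.0 score means that doc did not match that sub-query.
        float[][] hybridScoresPerDoc = { { 0.9f, 0.0f }, { 0.0f, 0.4f }, { 0.0f, 0.7f } };

        int totalHits = 0;                           // unique docs seen on the shard
        int[] collectedHitsPerSubQuery = new int[2]; // docs with a real score, per sub-query

        for (float[] subScores : hybridScoresPerDoc) {
            totalHits++;                             // collect() runs once per unique doc
            for (int i = 0; i < subScores.length; i++) {
                if (subScores[i] == 0.0f) {
                    continue;                        // no hit for this sub-query
                }
                collectedHitsPerSubQuery[i]++;
            }
        }
        System.out.println(totalHits);                                  // 3
        System.out.println(Arrays.toString(collectedHitsPerSubQuery));  // [1, 2]
    }
}
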
@@ -31,11 +31,8 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;

import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createDelimiterElementForHybridSearchResults;
import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createStartStopElementForHybridSearchResults;
@@ -145,7 +142,10 @@ public ReduceableSearchResult reduce(Collection<Collector> collectors) {
.findFirst()
.orElseThrow(() -> new IllegalStateException("cannot collect results of hybrid search query"));
List<TopDocs> topDocs = hybridTopScoreDocCollector.topDocs();
TopDocs newTopDocs = getNewTopDocs(getTotalHits(this.trackTotalHitsUpTo, topDocs, isSingleShard), topDocs);
TopDocs newTopDocs = getNewTopDocs(
getTotalHits(this.trackTotalHitsUpTo, topDocs, isSingleShard, hybridTopScoreDocCollector.getTotalHits()),
topDocs
);
float maxScore = getMaxScore(topDocs);
TopDocsAndMaxScore topDocsAndMaxScore = new TopDocsAndMaxScore(newTopDocs, maxScore);
return (QuerySearchResult result) -> { result.topDocs(topDocsAndMaxScore, getSortValueFormats(sortAndFormats)); };
@@ -196,24 +196,19 @@ private TopDocs getNewTopDocs(final TotalHits totalHits, final List<TopDocs> top
return new TopDocs(totalHits, scoreDocs);
}

private TotalHits getTotalHits(int trackTotalHitsUpTo, final List<TopDocs> topDocs, final boolean isSingleShard) {
private TotalHits getTotalHits(
int trackTotalHitsUpTo,
final List<TopDocs> topDocs,
final boolean isSingleShard,
final long maxTotalHits
) {
final TotalHits.Relation relation = trackTotalHitsUpTo == SearchContext.TRACK_TOTAL_HITS_DISABLED
? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO
: TotalHits.Relation.EQUAL_TO;
if (topDocs == null || topDocs.isEmpty()) {
return new TotalHits(0, relation);
}

List<ScoreDoc[]> scoreDocs = topDocs.stream()
.map(topdDoc -> topdDoc.scoreDocs)
.filter(Objects::nonNull)
.collect(Collectors.toList());
Set<Integer> uniqueDocIds = new HashSet<>();
for (ScoreDoc[] scoreDocsArray : scoreDocs) {
uniqueDocIds.addAll(Arrays.stream(scoreDocsArray).map(scoreDoc -> scoreDoc.doc).collect(Collectors.toList()));
}
long maxTotalHits = uniqueDocIds.size();

return new TotalHits(maxTotalHits, relation);
}

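
Downstream of the collector, the manager no longer re-derives the count from unique doc ids in the merged top docs; it reuses the collector's shard-level count and only picks the relation from the track_total_hits setting. A hedged sketch of the resulting method shape (TRACK_TOTAL_HITS_DISABLED is a stand-in constant here, not a reference to the real SearchContext field):

import org.apache.lucene.search.TotalHits;

public class TotalHitsRelationSketch {
    // Stand-in for SearchContext.TRACK_TOTAL_HITS_DISABLED, used only in this sketch
    static final int TRACK_TOTAL_HITS_DISABLED = -1;

    static TotalHits totalHits(int trackTotalHitsUpTo, long collectedOnShard, boolean hasTopDocs) {
        TotalHits.Relation relation = trackTotalHitsUpTo == TRACK_TOTAL_HITS_DISABLED
            ? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO
            : TotalHits.Relation.EQUAL_TO;
        if (!hasTopDocs) {
            return new TotalHits(0, relation);
        }
        return new TotalHits(collectedOnShard, relation);
    }

    public static void main(String[] args) {
        // With default hit tracking, three docs collected on the shard report as an exact count
        System.out.println(totalHits(10_000, 3, true)); // 3 hits
    }
}
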
@@ -27,7 +27,7 @@ public void testCombination_whenMultipleSubqueriesResultsAndDefaultMethod_thenSc

final List<CompoundTopDocs> queryTopDocs = List.of(
new CompoundTopDocs(
new TotalHits(3, TotalHits.Relation.EQUAL_TO),
new TotalHits(5, TotalHits.Relation.EQUAL_TO),
List.of(
new TopDocs(
new TotalHits(3, TotalHits.Relation.EQUAL_TO),
29 changes: 29 additions & 0 deletions src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java
@@ -177,6 +177,35 @@ public void testComplexQuery_whenMultipleSubqueries_thenSuccessful() {
}
}

@SneakyThrows
public void testTotalHits_whenResultSizeIsLessThenDefaultSize_thenSuccessful() {
initializeIndexIfNotExist(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME);
createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE);
TermQueryBuilder termQueryBuilder1 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3);
TermQueryBuilder termQueryBuilder2 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT4);
TermQueryBuilder termQueryBuilder3 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT5);
BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
boolQueryBuilder.should(termQueryBuilder2).should(termQueryBuilder3);

HybridQueryBuilder hybridQueryBuilderNeuralThenTerm = new HybridQueryBuilder();
hybridQueryBuilderNeuralThenTerm.add(termQueryBuilder1);
hybridQueryBuilderNeuralThenTerm.add(boolQueryBuilder);
Map<String, Object> searchResponseAsMap = search(
TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME,
hybridQueryBuilderNeuralThenTerm,
null,
1,
Map.of("search_pipeline", SEARCH_PIPELINE)
);

assertEquals(1, getHitCount(searchResponseAsMap));
Map<String, Object> total = getTotalHits(searchResponseAsMap);
assertNotNull(total.get("value"));
assertEquals(3, total.get("value"));
assertNotNull(total.get("relation"));
assertEquals(RELATION_EQUAL_TO, total.get("relation"));
}

/**
* Tests complex query with multiple nested sub-queries, where some sub-queries are same
* {
@@ -4,6 +4,16 @@
*/
package org.opensearch.neuralsearch.search;

import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.lucene.search.Scorable;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.Weight;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
@@ -13,6 +23,9 @@
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.HashSet;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
@@ -24,16 +37,6 @@
import org.apache.lucene.index.IndexOptions;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.lucene.search.Scorable;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.Weight;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.analysis.MockAnalyzer;
import org.opensearch.index.mapper.TextFieldMapper;
@@ -50,6 +53,7 @@ public class HybridTopScoreDocCollectorTests extends OpenSearchQueryTestCase {
private static final String TEST_QUERY_TEXT = "greeting";
private static final String TEST_QUERY_TEXT2 = "salute";
private static final int NUM_DOCS = 4;
private static final int NUM_HITS = 1;
private static final int TOTAL_HITS_UP_TO = 1000;

private static final int DOC_ID_1 = RandomUtils.nextInt(0, 100_000);
@@ -493,4 +497,71 @@ public void testCompoundScorer_whenHybridScorerIsTopLevelScorer_thenSuccessful()
reader.close();
directory.close();
}

@SneakyThrows
public void testTotalHitsCountValidation_whenTotalHitsCollectedAtTopLevelInCollector_thenSuccessful() {
final Directory directory = newDirectory();
final IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random())));
FieldType ft = new FieldType(TextField.TYPE_NOT_STORED);
ft.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS);
ft.setOmitNorms(random().nextBoolean());
ft.freeze();

w.addDocument(getDocument(TEXT_FIELD_NAME, DOC_ID_1, FIELD_1_VALUE, ft));
w.addDocument(getDocument(TEXT_FIELD_NAME, DOC_ID_2, FIELD_2_VALUE, ft));
w.addDocument(getDocument(TEXT_FIELD_NAME, DOC_ID_3, FIELD_3_VALUE, ft));
w.addDocument(getDocument(TEXT_FIELD_NAME, DOC_ID_4, FIELD_4_VALUE, ft));
w.commit();

DirectoryReader reader = DirectoryReader.open(w);

LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0);

HybridTopScoreDocCollector hybridTopScoreDocCollector = new HybridTopScoreDocCollector(
NUM_HITS,
new HitsThresholdChecker(Integer.MAX_VALUE)
);
LeafCollector leafCollector = hybridTopScoreDocCollector.getLeafCollector(leafReaderContext);
assertNotNull(leafCollector);

Weight weight = mock(Weight.class);
int[] docIdsForQuery1 = new int[] { DOC_ID_1, DOC_ID_2 };
Arrays.sort(docIdsForQuery1);
int[] docIdsForQuery2 = new int[] { DOC_ID_3, DOC_ID_4 };
Arrays.sort(docIdsForQuery2);
final List<Float> scores = Stream.generate(() -> random().nextFloat()).limit(NUM_DOCS).collect(Collectors.toList());
HybridQueryScorer hybridQueryScorer = new HybridQueryScorer(
weight,
Arrays.asList(
scorer(docIdsForQuery1, scores, fakeWeight(new MatchAllDocsQuery())),
scorer(docIdsForQuery2, scores, fakeWeight(new MatchAllDocsQuery()))
)
);

leafCollector.setScorer(hybridQueryScorer);
DocIdSetIterator iterator = hybridQueryScorer.iterator();
int nextDoc = iterator.nextDoc();
while (nextDoc != NO_MORE_DOCS) {
leafCollector.collect(nextDoc);
nextDoc = iterator.nextDoc();
}

List<TopDocs> topDocs = hybridTopScoreDocCollector.topDocs();
long totalHits = hybridTopScoreDocCollector.getTotalHits();
List<ScoreDoc[]> scoreDocs = topDocs.stream()
.map(topdDoc -> topdDoc.scoreDocs)
.filter(Objects::nonNull)
.collect(Collectors.toList());
Set<Integer> uniqueDocIds = new HashSet<>();
for (ScoreDoc[] scoreDocsArray : scoreDocs) {
uniqueDocIds.addAll(Arrays.stream(scoreDocsArray).map(scoreDoc -> scoreDoc.doc).collect(Collectors.toList()));
}
long maxTotalHits = uniqueDocIds.size();
assertEquals(4, totalHits);
// Unique docs taken from the size-capped top docs will be 2, i.e. 1 per sub-query since NUM_HITS is 1
assertEquals(2, maxTotalHits);
w.close();
reader.close();
directory.close();
}
}
@@ -1025,7 +1025,9 @@ private static IndexMetadata getIndexMetadata() {
RemoteStoreEnums.PathType.NAME,
HASHED_PREFIX.name(),
RemoteStoreEnums.PathHashAlgorithm.NAME,
RemoteStoreEnums.PathHashAlgorithm.FNV_1A_BASE64.name()
RemoteStoreEnums.PathHashAlgorithm.FNV_1A_BASE64.name(),
IndexMetadata.TRANSLOG_METADATA_KEY,
"false"
);
Settings idxSettings = Settings.builder()
.put(IndexMetadata.SETTING_VERSION_CREATED, Version.CURRENT)
