-
Notifications
You must be signed in to change notification settings - Fork 64
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add query phase searcher and basic tests Signed-off-by: Martin Gaievski <[email protected]> --------- Signed-off-by: Martin Gaievski <[email protected]>
- Loading branch information
1 parent
2de63dd
commit 7dda0c5
Showing
13 changed files
with
811 additions
and
132 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
148 changes: 148 additions & 0 deletions
148
src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,148 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.neuralsearch.search.query; | ||
|
||
import static org.opensearch.search.query.TopDocsCollectorContext.createTopDocsCollectorContext; | ||
|
||
import java.io.IOException; | ||
import java.util.LinkedList; | ||
import java.util.List; | ||
|
||
import lombok.extern.log4j.Log4j2; | ||
|
||
import org.apache.lucene.index.IndexReader; | ||
import org.apache.lucene.search.Query; | ||
import org.apache.lucene.search.ScoreDoc; | ||
import org.apache.lucene.search.TopDocs; | ||
import org.apache.lucene.search.TotalHitCountCollector; | ||
import org.apache.lucene.search.TotalHits; | ||
import org.opensearch.common.lucene.search.TopDocsAndMaxScore; | ||
import org.opensearch.neuralsearch.query.HybridQuery; | ||
import org.opensearch.neuralsearch.search.CompoundTopDocs; | ||
import org.opensearch.neuralsearch.search.HitsThresholdChecker; | ||
import org.opensearch.neuralsearch.search.HybridTopScoreDocCollector; | ||
import org.opensearch.search.DocValueFormat; | ||
import org.opensearch.search.internal.ContextIndexSearcher; | ||
import org.opensearch.search.internal.SearchContext; | ||
import org.opensearch.search.query.QueryCollectorContext; | ||
import org.opensearch.search.query.QueryPhase; | ||
import org.opensearch.search.query.QuerySearchResult; | ||
import org.opensearch.search.query.TopDocsCollectorContext; | ||
import org.opensearch.search.rescore.RescoreContext; | ||
import org.opensearch.search.sort.SortAndFormats; | ||
|
||
import com.google.common.annotations.VisibleForTesting; | ||
|
||
/** | ||
* Custom search implementation to be used at {@link QueryPhase} for Hybrid Query search. For queries other than Hybrid the | ||
* upstream standard implementation of searcher is called. | ||
*/ | ||
@Log4j2 | ||
public class HybridQueryPhaseSearcher extends QueryPhase.DefaultQueryPhaseSearcher { | ||
|
||
public boolean searchWith( | ||
final SearchContext searchContext, | ||
final ContextIndexSearcher searcher, | ||
final Query query, | ||
final LinkedList<QueryCollectorContext> collectors, | ||
final boolean hasFilterCollector, | ||
final boolean hasTimeout | ||
) throws IOException { | ||
if (query instanceof HybridQuery) { | ||
return searchWithCollector(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout); | ||
} | ||
return super.searchWithCollector(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout); | ||
} | ||
|
||
@VisibleForTesting | ||
protected boolean searchWithCollector( | ||
final SearchContext searchContext, | ||
final ContextIndexSearcher searcher, | ||
final Query query, | ||
final LinkedList<QueryCollectorContext> collectors, | ||
final boolean hasFilterCollector, | ||
final boolean hasTimeout | ||
) throws IOException { | ||
log.debug("searching with custom doc collector, shard {}", searchContext.shardTarget().getShardId()); | ||
|
||
final TopDocsCollectorContext topDocsFactory = createTopDocsCollectorContext(searchContext, hasFilterCollector); | ||
collectors.addFirst(topDocsFactory); | ||
if (searchContext.size() == 0) { | ||
final TotalHitCountCollector collector = new TotalHitCountCollector(); | ||
searcher.search(query, collector); | ||
return false; | ||
} | ||
final IndexReader reader = searchContext.searcher().getIndexReader(); | ||
int totalNumDocs = Math.max(0, reader.numDocs()); | ||
int numDocs = Math.min(searchContext.from() + searchContext.size(), totalNumDocs); | ||
final boolean shouldRescore = !searchContext.rescore().isEmpty(); | ||
if (shouldRescore) { | ||
for (RescoreContext rescoreContext : searchContext.rescore()) { | ||
numDocs = Math.max(numDocs, rescoreContext.getWindowSize()); | ||
} | ||
} | ||
|
||
final QuerySearchResult queryResult = searchContext.queryResult(); | ||
|
||
final HybridTopScoreDocCollector collector = new HybridTopScoreDocCollector( | ||
numDocs, | ||
new HitsThresholdChecker(Math.max(numDocs, searchContext.trackTotalHitsUpTo())) | ||
); | ||
|
||
searcher.search(query, collector); | ||
|
||
if (searchContext.terminateAfter() != SearchContext.DEFAULT_TERMINATE_AFTER && queryResult.terminatedEarly() == null) { | ||
queryResult.terminatedEarly(false); | ||
} | ||
|
||
setTopDocsInQueryResult(queryResult, collector, searchContext); | ||
|
||
return shouldRescore; | ||
} | ||
|
||
private void setTopDocsInQueryResult( | ||
final QuerySearchResult queryResult, | ||
final HybridTopScoreDocCollector collector, | ||
final SearchContext searchContext | ||
) { | ||
final List<TopDocs> topDocs = collector.topDocs(); | ||
final float maxScore = getMaxScore(topDocs); | ||
final TopDocs newTopDocs = new CompoundTopDocs(getTotalHits(searchContext, topDocs), topDocs); | ||
final TopDocsAndMaxScore topDocsAndMaxScore = new TopDocsAndMaxScore(newTopDocs, maxScore); | ||
queryResult.topDocs(topDocsAndMaxScore, getSortValueFormats(searchContext.sort())); | ||
} | ||
|
||
private TotalHits getTotalHits(final SearchContext searchContext, final List<TopDocs> topDocs) { | ||
int trackTotalHitsUpTo = searchContext.trackTotalHitsUpTo(); | ||
final TotalHits.Relation relation = trackTotalHitsUpTo == SearchContext.TRACK_TOTAL_HITS_DISABLED | ||
? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO | ||
: TotalHits.Relation.EQUAL_TO; | ||
if (topDocs == null || topDocs.size() == 0) { | ||
return new TotalHits(0, relation); | ||
} | ||
long maxTotalHits = topDocs.get(0).totalHits.value; | ||
for (TopDocs topDoc : topDocs) { | ||
maxTotalHits = Math.max(maxTotalHits, topDoc.totalHits.value); | ||
} | ||
return new TotalHits(maxTotalHits, relation); | ||
} | ||
|
||
private float getMaxScore(final List<TopDocs> topDocs) { | ||
if (topDocs.size() == 0) { | ||
return Float.NaN; | ||
} else { | ||
return topDocs.stream() | ||
.map(docs -> docs.scoreDocs.length == 0 ? new ScoreDoc(-1, 0.0f) : docs.scoreDocs[0]) | ||
.map(scoreDoc -> scoreDoc.score) | ||
.max(Float::compare) | ||
.get(); | ||
} | ||
} | ||
|
||
private DocValueFormat[] getSortValueFormats(final SortAndFormats sortAndFormats) { | ||
return sortAndFormats == null ? null : sortAndFormats.formats; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.