Skip to content

Commit

Permalink
Add Query phase searcher (#204)
Browse files Browse the repository at this point in the history
* Add query phase searcher and basic tests

Signed-off-by: Martin Gaievski <[email protected]>

---------

Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski committed Jun 29, 2023
1 parent 2de63dd commit 7dda0c5
Show file tree
Hide file tree
Showing 13 changed files with 811 additions and 132 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Supplier;

import org.opensearch.client.Client;
Expand All @@ -26,13 +27,15 @@
import org.opensearch.neuralsearch.processor.factory.TextEmbeddingProcessorFactory;
import org.opensearch.neuralsearch.query.HybridQueryBuilder;
import org.opensearch.neuralsearch.query.NeuralQueryBuilder;
import org.opensearch.neuralsearch.search.query.HybridQueryPhaseSearcher;
import org.opensearch.plugins.ActionPlugin;
import org.opensearch.plugins.ExtensiblePlugin;
import org.opensearch.plugins.IngestPlugin;
import org.opensearch.plugins.Plugin;
import org.opensearch.plugins.SearchPlugin;
import org.opensearch.repositories.RepositoriesService;
import org.opensearch.script.ScriptService;
import org.opensearch.search.query.QueryPhaseSearcher;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.watcher.ResourceWatcherService;

Expand Down Expand Up @@ -74,4 +77,9 @@ public Map<String, Processor.Factory> getProcessors(Processor.Parameters paramet
clientAccessor = new MLCommonsClientAccessor(new MachineLearningNodeClient(parameters.client));
return Collections.singletonMap(TextEmbeddingProcessor.TYPE, new TextEmbeddingProcessorFactory(clientAccessor, parameters.env));
}

/**
 * Registers the plugin's custom query phase searcher so hybrid queries are executed
 * with the dedicated collector-based implementation.
 *
 * @return the {@link HybridQueryPhaseSearcher} wrapped in an {@link Optional}
 */
@Override
public Optional<QueryPhaseSearcher> getQueryPhaseSearcher() {
    final QueryPhaseSearcher hybridQueryPhaseSearcher = new HybridQueryPhaseSearcher();
    return Optional.of(hybridQueryPhaseSearcher);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,11 @@ public static HybridQueryBuilder fromXContent(XContentParser parser) throws IOEx
} else {
if (AbstractQueryBuilder.BOOST_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
boost = parser.floatValue();
// regular boost functionality is not supported, user should use score normalization methods to manipulate with scores
if (boost != DEFAULT_BOOST) {
log.error("[{}] query does not support provided value {} for [{}]", NAME, boost, BOOST_FIELD);
throw new ParsingException(parser.getTokenLocation(), "[{}] query does not support [{}]", NAME, BOOST_FIELD);
}
} else if (AbstractQueryBuilder.NAME_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
queryName = parser.text();
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.neuralsearch.search;

import java.util.Arrays;
import java.util.List;

import lombok.Getter;
import lombok.ToString;
Expand All @@ -21,23 +22,23 @@
public class CompoundTopDocs extends TopDocs {

@Getter
private TopDocs[] compoundTopDocs;
private List<TopDocs> compoundTopDocs;

public CompoundTopDocs(TotalHits totalHits, ScoreDoc[] scoreDocs) {
super(totalHits, scoreDocs);
}

public CompoundTopDocs(TotalHits totalHits, TopDocs[] docs) {
public CompoundTopDocs(TotalHits totalHits, List<TopDocs> docs) {
// we pass clone of score docs from the sub-query that has most hits
super(totalHits, cloneLargestScoreDocs(docs));
this.compoundTopDocs = docs;
}

private static ScoreDoc[] cloneLargestScoreDocs(TopDocs[] docs) {
private static ScoreDoc[] cloneLargestScoreDocs(List<TopDocs> docs) {
if (docs == null) {
return null;
}
ScoreDoc[] maxScoreDocs = null;
ScoreDoc[] maxScoreDocs = new ScoreDoc[0];
int maxLength = -1;
for (TopDocs topDoc : docs) {
if (topDoc == null || topDoc.scoreDocs == null) {
Expand All @@ -48,9 +49,6 @@ private static ScoreDoc[] cloneLargestScoreDocs(TopDocs[] docs) {
maxScoreDocs = topDoc.scoreDocs;
}
}
if (maxScoreDocs == null) {
return null;
}
// do deep copy
return Arrays.stream(maxScoreDocs).map(doc -> new ScoreDoc(doc.doc, doc.score, doc.shardIndex)).toArray(ScoreDoc[]::new);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@
package org.opensearch.neuralsearch.search;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import lombok.Getter;
import lombok.extern.log4j.Log4j2;
Expand All @@ -31,9 +35,7 @@
public class HybridTopScoreDocCollector implements Collector {
private static final TopDocs EMPTY_TOPDOCS = new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
private int docBase;
private float minCompetitiveScore;
private final HitsThresholdChecker hitsThresholdChecker;
private ScoreDoc pqTop;
private TotalHits.Relation totalHitsRelation = TotalHits.Relation.EQUAL_TO;
private int[] totalHits;
private final int numOfHits;
Expand All @@ -48,15 +50,13 @@ public HybridTopScoreDocCollector(int numHits, HitsThresholdChecker hitsThreshol
@Override
public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException {
docBase = context.docBase;
minCompetitiveScore = 0f;

return new TopScoreDocCollector.ScorerLeafCollector() {
HybridQueryScorer compoundQueryScorer;

@Override
public void setScorer(Scorable scorer) throws IOException {
super.setScorer(scorer);
updateMinCompetitiveScore(scorer);
compoundQueryScorer = (HybridQueryScorer) scorer;
}

Expand Down Expand Up @@ -93,30 +93,17 @@ public ScoreMode scoreMode() {
return hitsThresholdChecker.scoreMode();
}

protected void updateMinCompetitiveScore(Scorable scorer) throws IOException {
if (hitsThresholdChecker.isThresholdReached() && pqTop != null && pqTop.score != Float.NEGATIVE_INFINITY) { // -Infinity is the
// boundary score
// we have multiple identical doc id and collect in doc id order, we need next float
float localMinScore = Math.nextUp(pqTop.score);
if (localMinScore > minCompetitiveScore) {
scorer.setMinCompetitiveScore(localMinScore);
totalHitsRelation = TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO;
minCompetitiveScore = localMinScore;
}
}
}

/**
* Get resulting collection of TopDocs for hybrid query after we ran search for each of its sub query
* @return
*/
public TopDocs[] topDocs() {
TopDocs[] topDocs = new TopDocs[compoundScores.length];
for (int i = 0; i < compoundScores.length; i++) {
int qTopSize = totalHits[i];
TopDocs topDocsPerQuery = topDocsPerQuery(0, Math.min(qTopSize, compoundScores[i].size()), compoundScores[i], qTopSize);
topDocs[i] = topDocsPerQuery;
/**
 * Builds the per-sub-query top docs collected for the hybrid query.
 * One {@link TopDocs} entry is produced per sub-query, in sub-query order.
 *
 * @return mutable list of top docs, one per sub-query; empty list when nothing was collected
 */
public List<TopDocs> topDocs() {
    if (compoundScores == null) {
        return new ArrayList<>();
    }
    final List<TopDocs> results = new ArrayList<>(compoundScores.length);
    for (int subQueryIndex = 0; subQueryIndex < compoundScores.length; subQueryIndex++) {
        final int hitCount = totalHits[subQueryIndex];
        final int resultSize = Math.min(hitCount, compoundScores[subQueryIndex].size());
        results.add(topDocsPerQuery(0, resultSize, compoundScores[subQueryIndex], hitCount));
    }
    return results;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.neuralsearch.search.query;

import static org.opensearch.search.query.TopDocsCollectorContext.createTopDocsCollectorContext;

import java.io.IOException;
import java.util.LinkedList;
import java.util.List;

import lombok.extern.log4j.Log4j2;

import org.apache.lucene.index.IndexReader;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHitCountCollector;
import org.apache.lucene.search.TotalHits;
import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
import org.opensearch.neuralsearch.query.HybridQuery;
import org.opensearch.neuralsearch.search.CompoundTopDocs;
import org.opensearch.neuralsearch.search.HitsThresholdChecker;
import org.opensearch.neuralsearch.search.HybridTopScoreDocCollector;
import org.opensearch.search.DocValueFormat;
import org.opensearch.search.internal.ContextIndexSearcher;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.query.QueryCollectorContext;
import org.opensearch.search.query.QueryPhase;
import org.opensearch.search.query.QuerySearchResult;
import org.opensearch.search.query.TopDocsCollectorContext;
import org.opensearch.search.rescore.RescoreContext;
import org.opensearch.search.sort.SortAndFormats;

import com.google.common.annotations.VisibleForTesting;

/**
* Custom search implementation to be used at {@link QueryPhase} for Hybrid Query search. For queries other than Hybrid the
* upstream standard implementation of searcher is called.
*/
@Log4j2
public class HybridQueryPhaseSearcher extends QueryPhase.DefaultQueryPhaseSearcher {

public boolean searchWith(
final SearchContext searchContext,
final ContextIndexSearcher searcher,
final Query query,
final LinkedList<QueryCollectorContext> collectors,
final boolean hasFilterCollector,
final boolean hasTimeout
) throws IOException {
if (query instanceof HybridQuery) {
return searchWithCollector(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout);
}
return super.searchWithCollector(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout);
}

@VisibleForTesting
protected boolean searchWithCollector(
final SearchContext searchContext,
final ContextIndexSearcher searcher,
final Query query,
final LinkedList<QueryCollectorContext> collectors,
final boolean hasFilterCollector,
final boolean hasTimeout
) throws IOException {
log.debug("searching with custom doc collector, shard {}", searchContext.shardTarget().getShardId());

final TopDocsCollectorContext topDocsFactory = createTopDocsCollectorContext(searchContext, hasFilterCollector);
collectors.addFirst(topDocsFactory);
if (searchContext.size() == 0) {
final TotalHitCountCollector collector = new TotalHitCountCollector();
searcher.search(query, collector);
return false;
}
final IndexReader reader = searchContext.searcher().getIndexReader();
int totalNumDocs = Math.max(0, reader.numDocs());
int numDocs = Math.min(searchContext.from() + searchContext.size(), totalNumDocs);
final boolean shouldRescore = !searchContext.rescore().isEmpty();
if (shouldRescore) {
for (RescoreContext rescoreContext : searchContext.rescore()) {
numDocs = Math.max(numDocs, rescoreContext.getWindowSize());
}
}

final QuerySearchResult queryResult = searchContext.queryResult();

final HybridTopScoreDocCollector collector = new HybridTopScoreDocCollector(
numDocs,
new HitsThresholdChecker(Math.max(numDocs, searchContext.trackTotalHitsUpTo()))
);

searcher.search(query, collector);

if (searchContext.terminateAfter() != SearchContext.DEFAULT_TERMINATE_AFTER && queryResult.terminatedEarly() == null) {
queryResult.terminatedEarly(false);
}

setTopDocsInQueryResult(queryResult, collector, searchContext);

return shouldRescore;
}

private void setTopDocsInQueryResult(
final QuerySearchResult queryResult,
final HybridTopScoreDocCollector collector,
final SearchContext searchContext
) {
final List<TopDocs> topDocs = collector.topDocs();
final float maxScore = getMaxScore(topDocs);
final TopDocs newTopDocs = new CompoundTopDocs(getTotalHits(searchContext, topDocs), topDocs);
final TopDocsAndMaxScore topDocsAndMaxScore = new TopDocsAndMaxScore(newTopDocs, maxScore);
queryResult.topDocs(topDocsAndMaxScore, getSortValueFormats(searchContext.sort()));
}

private TotalHits getTotalHits(final SearchContext searchContext, final List<TopDocs> topDocs) {
int trackTotalHitsUpTo = searchContext.trackTotalHitsUpTo();
final TotalHits.Relation relation = trackTotalHitsUpTo == SearchContext.TRACK_TOTAL_HITS_DISABLED
? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO
: TotalHits.Relation.EQUAL_TO;
if (topDocs == null || topDocs.size() == 0) {
return new TotalHits(0, relation);
}
long maxTotalHits = topDocs.get(0).totalHits.value;
for (TopDocs topDoc : topDocs) {
maxTotalHits = Math.max(maxTotalHits, topDoc.totalHits.value);
}
return new TotalHits(maxTotalHits, relation);
}

private float getMaxScore(final List<TopDocs> topDocs) {
if (topDocs.size() == 0) {
return Float.NaN;
} else {
return topDocs.stream()
.map(docs -> docs.scoreDocs.length == 0 ? new ScoreDoc(-1, 0.0f) : docs.scoreDocs[0])
.map(scoreDoc -> scoreDoc.score)
.max(Float::compare)
.get();
}
}

private DocValueFormat[] getSortValueFormats(final SortAndFormats sortAndFormats) {
return sortAndFormats == null ? null : sortAndFormats.formats;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,12 @@ public abstract class BaseNeuralSearchIT extends OpenSearchSecureRestTestCase {

@Before
public void setupSettings() {
if (isUpdateClusterSettings()) {
updateClusterSettings();
}
}

protected void updateClusterSettings() {
updateClusterSettings("plugins.ml_commons.only_run_on_ml_node", false);
// default threshold for native circuit breaker is 90, it may be not enough on test runner machine
updateClusterSettings("plugins.ml_commons.native_memory_threshold", 100);
Expand Down Expand Up @@ -514,4 +520,8 @@ protected void deleteModel(String modelId) {
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT))
);
}

/**
 * Hook allowing test subclasses to opt out of the cluster settings update performed in
 * {@code setupSettings()}. Returns {@code true} by default, so settings are applied before each test.
 *
 * @return true when cluster settings should be updated during test setup
 */
public boolean isUpdateClusterSettings() {
return true;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,19 @@

package org.opensearch.neuralsearch.plugin;

import static org.mockito.Mockito.mock;

import java.util.List;
import java.util.Map;
import java.util.Optional;

import org.opensearch.ingest.Processor;
import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor;
import org.opensearch.neuralsearch.query.HybridQueryBuilder;
import org.opensearch.neuralsearch.query.NeuralQueryBuilder;
import org.opensearch.neuralsearch.search.query.HybridQueryPhaseSearcher;
import org.opensearch.plugins.SearchPlugin;
import org.opensearch.search.query.QueryPhaseSearcher;
import org.opensearch.test.OpenSearchTestCase;

public class NeuralSearchTests extends OpenSearchTestCase {
Expand All @@ -23,4 +31,21 @@ public void testQuerySpecs() {
assertTrue(querySpecs.stream().anyMatch(spec -> NeuralQueryBuilder.NAME.equals(spec.getName().getPreferredName())));
assertTrue(querySpecs.stream().anyMatch(spec -> HybridQueryBuilder.NAME.equals(spec.getName().getPreferredName())));
}

/**
 * Verifies the plugin supplies the hybrid query phase searcher.
 */
public void testQueryPhaseSearcher() {
    final NeuralSearch plugin = new NeuralSearch();
    final Optional<QueryPhaseSearcher> queryPhaseSearcher = plugin.getQueryPhaseSearcher();

    assertNotNull(queryPhaseSearcher);
    assertTrue(queryPhaseSearcher.isPresent());
    assertTrue(queryPhaseSearcher.get() instanceof HybridQueryPhaseSearcher);
}

/**
 * Verifies the plugin registers the text embedding ingest processor factory.
 */
public void testProcessors() {
    final NeuralSearch plugin = new NeuralSearch();
    final Processor.Parameters parameters = mock(Processor.Parameters.class);

    final Map<String, Processor.Factory> factories = plugin.getProcessors(parameters);

    assertNotNull(factories);
    assertNotNull(factories.get(TextEmbeddingProcessor.TYPE));
}
}
Loading

0 comments on commit 7dda0c5

Please sign in to comment.