
Commit 3955bfc
Add query phase searcher and basic tests
Signed-off-by: Martin Gaievski <[email protected]>
martin-gaievski committed Jun 28, 2023
1 parent 2de63dd commit 3955bfc
Showing 13 changed files with 793 additions and 95 deletions.
NeuralSearch.java

@@ -10,6 +10,7 @@
 import java.util.Collections;
 import java.util.List;
 import java.util.Map;
+import java.util.Optional;
 import java.util.function.Supplier;
 
 import org.opensearch.client.Client;
@@ -26,13 +27,15 @@
 import org.opensearch.neuralsearch.processor.factory.TextEmbeddingProcessorFactory;
 import org.opensearch.neuralsearch.query.HybridQueryBuilder;
 import org.opensearch.neuralsearch.query.NeuralQueryBuilder;
+import org.opensearch.neuralsearch.search.query.HybridQueryPhaseSearcher;
 import org.opensearch.plugins.ActionPlugin;
 import org.opensearch.plugins.ExtensiblePlugin;
 import org.opensearch.plugins.IngestPlugin;
 import org.opensearch.plugins.Plugin;
 import org.opensearch.plugins.SearchPlugin;
 import org.opensearch.repositories.RepositoriesService;
 import org.opensearch.script.ScriptService;
+import org.opensearch.search.query.QueryPhaseSearcher;
 import org.opensearch.threadpool.ThreadPool;
 import org.opensearch.watcher.ResourceWatcherService;
@@ -74,4 +77,9 @@ public Map<String, Processor.Factory> getProcessors(Processor.Parameters paramet
         clientAccessor = new MLCommonsClientAccessor(new MachineLearningNodeClient(parameters.client));
         return Collections.singletonMap(TextEmbeddingProcessor.TYPE, new TextEmbeddingProcessorFactory(clientAccessor, parameters.env));
     }
+
+    @Override
+    public Optional<QueryPhaseSearcher> getQueryPhaseSearcher() {
+        return Optional.of(new HybridQueryPhaseSearcher());
+    }
 }
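
For reference, the extension point being overridden here is SearchPlugin#getQueryPhaseSearcher. A paraphrased sketch of the upstream hook (not the verbatim OpenSearch source; the Optional.empty() default is an assumption based on the upstream contract):

// Paraphrased sketch of the upstream SearchPlugin hook. When a plugin returns
// a non-empty Optional, that QueryPhaseSearcher is used for the query phase
// instead of the default searcher.
public interface SearchPlugin {
    default Optional<QueryPhaseSearcher> getQueryPhaseSearcher() {
        return Optional.empty(); // default: keep the standard query phase searcher
    }
}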
HybridQueryBuilder.java

@@ -189,6 +189,14 @@ public static HybridQueryBuilder fromXContent(XContentParser parser) throws IOEx
             } else {
                 if (AbstractQueryBuilder.BOOST_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                     boost = parser.floatValue();
+                    // regular boost functionality is not supported; the user should rely on score normalization methods to manipulate scores
+                    if (boost != DEFAULT_BOOST) {
+                        log.error(String.format(Locale.ROOT, "[%s] query does not support [%s]", NAME, BOOST_FIELD));
+                        throw new ParsingException(
+                            parser.getTokenLocation(),
+                            String.format(Locale.ROOT, "[%s] query does not support [%s]", NAME, BOOST_FIELD)
+                        );
+                    }
                 } else if (AbstractQueryBuilder.NAME_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                     queryName = parser.text();
                 } else {
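
To illustrate the new check, a minimal test-style sketch (hypothetical code: it assumes the query's registered NAME is "hybrid" and uses the standard OpenSearchTestCase parser helpers):

// Hypothetical sketch: a non-default "boost" on a hybrid query should now fail to parse.
XContentParser parser = createParser(JsonXContent.jsonXContent, "{\"boost\": 2.0}");
parser.nextToken(); // parser positioning simplified for illustration
ParsingException e = expectThrows(ParsingException.class, () -> HybridQueryBuilder.fromXContent(parser));
// expected message along the lines of: "[hybrid] query does not support [boost]"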
CompoundTopDocs.java

@@ -6,6 +6,7 @@
 package org.opensearch.neuralsearch.search;
 
 import java.util.Arrays;
+import java.util.List;
 
 import lombok.Getter;
 import lombok.ToString;
@@ -21,23 +22,23 @@
 public class CompoundTopDocs extends TopDocs {
 
     @Getter
-    private TopDocs[] compoundTopDocs;
+    private List<TopDocs> compoundTopDocs;
 
     public CompoundTopDocs(TotalHits totalHits, ScoreDoc[] scoreDocs) {
         super(totalHits, scoreDocs);
     }
 
-    public CompoundTopDocs(TotalHits totalHits, TopDocs[] docs) {
+    public CompoundTopDocs(TotalHits totalHits, List<TopDocs> docs) {
         // we pass a clone of the score docs from the sub-query that has the most hits
         super(totalHits, cloneLargestScoreDocs(docs));
         this.compoundTopDocs = docs;
     }
 
-    private static ScoreDoc[] cloneLargestScoreDocs(TopDocs[] docs) {
+    private static ScoreDoc[] cloneLargestScoreDocs(List<TopDocs> docs) {
         if (docs == null) {
             return null;
         }
-        ScoreDoc[] maxScoreDocs = null;
+        ScoreDoc[] maxScoreDocs = new ScoreDoc[0];
         int maxLength = -1;
         for (TopDocs topDoc : docs) {
             if (topDoc == null || topDoc.scoreDocs == null) {
@@ -48,9 +49,6 @@ private static ScoreDoc[] cloneLargestScoreDocs(TopDocs[] docs) {
                 maxScoreDocs = topDoc.scoreDocs;
             }
         }
-        if (maxScoreDocs == null) {
-            return null;
-        }
         // do a deep copy
         return Arrays.stream(maxScoreDocs).map(doc -> new ScoreDoc(doc.doc, doc.score, doc.shardIndex)).toArray(ScoreDoc[]::new);
     }
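
A small usage sketch of the changed constructor (hypothetical values): the compound object keeps the per-sub-query results, while the inherited scoreDocs mirror the largest sub-result.

// Hypothetical sub-query results: "keyword" has two hits, "neural" has one.
TopDocs keyword = new TopDocs(
    new TotalHits(2, TotalHits.Relation.EQUAL_TO),
    new ScoreDoc[] { new ScoreDoc(1, 0.7f), new ScoreDoc(4, 0.3f) }
);
TopDocs neural = new TopDocs(
    new TotalHits(1, TotalHits.Relation.EQUAL_TO),
    new ScoreDoc[] { new ScoreDoc(2, 0.9f) }
);

CompoundTopDocs compound = new CompoundTopDocs(
    new TotalHits(2, TotalHits.Relation.EQUAL_TO),
    List.of(keyword, neural)
);
// compound.scoreDocs is a deep copy of the larger sub-result ("keyword"),
// while compound.getCompoundTopDocs() still exposes both sub-query results.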
HybridTopScoreDocCollector.java

@@ -6,6 +6,8 @@
 package org.opensearch.neuralsearch.search;
 
 import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
 import java.util.Locale;
 
 import lombok.Getter;
@@ -31,9 +33,7 @@
 public class HybridTopScoreDocCollector implements Collector {
     private static final TopDocs EMPTY_TOPDOCS = new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
     private int docBase;
-    private float minCompetitiveScore;
     private final HitsThresholdChecker hitsThresholdChecker;
-    private ScoreDoc pqTop;
     private TotalHits.Relation totalHitsRelation = TotalHits.Relation.EQUAL_TO;
     private int[] totalHits;
     private final int numOfHits;
@@ -48,15 +48,13 @@ public HybridTopScoreDocCollector(int numHits, HitsThresholdChecker hitsThreshol
     @Override
     public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException {
         docBase = context.docBase;
-        minCompetitiveScore = 0f;
 
         return new TopScoreDocCollector.ScorerLeafCollector() {
             HybridQueryScorer compoundQueryScorer;
 
             @Override
             public void setScorer(Scorable scorer) throws IOException {
                 super.setScorer(scorer);
-                updateMinCompetitiveScore(scorer);
                 compoundQueryScorer = (HybridQueryScorer) scorer;
             }
 
@@ -93,29 +91,19 @@ public ScoreMode scoreMode() {
         return hitsThresholdChecker.scoreMode();
     }
 
-    protected void updateMinCompetitiveScore(Scorable scorer) throws IOException {
-        if (hitsThresholdChecker.isThresholdReached() && pqTop != null && pqTop.score != Float.NEGATIVE_INFINITY) { // -Infinity is the boundary score
-            // we have multiple identical doc ids and collect in doc id order, so we need the next float
-            float localMinScore = Math.nextUp(pqTop.score);
-            if (localMinScore > minCompetitiveScore) {
-                scorer.setMinCompetitiveScore(localMinScore);
-                totalHitsRelation = TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO;
-                minCompetitiveScore = localMinScore;
-            }
-        }
-    }
-
     /**
      * Get the resulting collection of TopDocs for a hybrid query after the search has run for each of its sub-queries
      * @return a list of TopDocs, one entry per sub-query
      */
-    public TopDocs[] topDocs() {
-        TopDocs[] topDocs = new TopDocs[compoundScores.length];
+    public List<TopDocs> topDocs() {
+        List<TopDocs> topDocs;
+        if (compoundScores == null) {
+            return new ArrayList<>();
+        }
+        topDocs = new ArrayList<>(compoundScores.length);
         for (int i = 0; i < compoundScores.length; i++) {
             int qTopSize = totalHits[i];
-            TopDocs topDocsPerQuery = topDocsPerQuery(0, Math.min(qTopSize, compoundScores[i].size()), compoundScores[i], qTopSize);
-            topDocs[i] = topDocsPerQuery;
+            topDocs.add(topDocsPerQuery(0, Math.min(qTopSize, compoundScores[i].size()), compoundScores[i], qTopSize));
         }
         return topDocs;
     }
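
A minimal usage sketch of the collector after this change (hypothetical snippet; searcher and hybridQuery are assumed to be set up elsewhere):

// Collect hits for a hybrid query, then read one TopDocs per sub-query back out.
HybridTopScoreDocCollector collector = new HybridTopScoreDocCollector(
    10, // number of hits to keep per sub-query
    new HitsThresholdChecker(10)
);
searcher.search(hybridQuery, collector); // searcher and hybridQuery are assumptions
List<TopDocs> perSubQuery = collector.topDocs(); // empty list if nothing was collected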
HybridQueryPhaseSearcher.java (new file)

@@ -0,0 +1,148 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.neuralsearch.search.query;

import static org.opensearch.search.query.TopDocsCollectorContext.createTopDocsCollectorContext;

import java.io.IOException;
import java.util.LinkedList;
import java.util.List;
import java.util.Locale;
import java.util.function.Function;

import lombok.extern.log4j.Log4j2;

import org.apache.lucene.index.IndexReader;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHitCountCollector;
import org.apache.lucene.search.TotalHits;
import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
import org.opensearch.neuralsearch.query.HybridQuery;
import org.opensearch.neuralsearch.search.CompoundTopDocs;
import org.opensearch.neuralsearch.search.HitsThresholdChecker;
import org.opensearch.neuralsearch.search.HybridTopScoreDocCollector;
import org.opensearch.search.DocValueFormat;
import org.opensearch.search.internal.ContextIndexSearcher;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.query.QueryCollectorContext;
import org.opensearch.search.query.QueryPhase;
import org.opensearch.search.query.QuerySearchResult;
import org.opensearch.search.query.TopDocsCollectorContext;
import org.opensearch.search.rescore.RescoreContext;
import org.opensearch.search.sort.SortAndFormats;

/**
 * Custom search implementation to be used at {@link QueryPhase} for hybrid query search. For any query other than a
 * hybrid query, the standard upstream searcher implementation is called.
 */
@Log4j2
public class HybridQueryPhaseSearcher extends QueryPhase.DefaultQueryPhaseSearcher {

private Function<List<TopDocs>, TotalHits> totalHitsSupplier;
private Function<List<TopDocs>, Float> maxScoreSupplier;
protected SortAndFormats sortAndFormats;

public boolean searchWith(
SearchContext searchContext,
ContextIndexSearcher searcher,
Query query,
LinkedList<QueryCollectorContext> collectors,
boolean hasFilterCollector,
boolean hasTimeout
) throws IOException {
if (query instanceof HybridQuery) {
return searchWithCollector(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout);
}
return super.searchWithCollector(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout);
}

protected boolean searchWithCollector(
SearchContext searchContext,
ContextIndexSearcher searcher,
Query query,
LinkedList<QueryCollectorContext> collectors,
boolean hasFilterCollector,
boolean hasTimeout
) throws IOException {
log.debug(String.format(Locale.ROOT, "searching with custom doc collector, shard %s", searchContext.shardTarget().getShardId()));

final TopDocsCollectorContext topDocsFactory = createTopDocsCollectorContext(searchContext, hasFilterCollector);
collectors.addFirst(topDocsFactory);

final IndexReader reader = searchContext.searcher().getIndexReader();
int totalNumDocs = Math.max(0, reader.numDocs());
if (searchContext.size() == 0) {
final TotalHitCountCollector collector = new TotalHitCountCollector();
searcher.search(query, collector);
return false;
}
int numDocs = Math.min(searchContext.from() + searchContext.size(), totalNumDocs);
final boolean rescore = !searchContext.rescore().isEmpty();
if (rescore) {
assert searchContext.sort() == null;
for (RescoreContext rescoreContext : searchContext.rescore()) {
numDocs = Math.max(numDocs, rescoreContext.getWindowSize());
}
}

final QuerySearchResult queryResult = searchContext.queryResult();

final HybridTopScoreDocCollector collector = new HybridTopScoreDocCollector(
numDocs,
new HitsThresholdChecker(Math.max(numDocs, searchContext.trackTotalHitsUpTo()))
);
totalHitsSupplier = topDocs -> {
int trackTotalHitsUpTo = searchContext.trackTotalHitsUpTo();
final TotalHits.Relation relation = trackTotalHitsUpTo == SearchContext.TRACK_TOTAL_HITS_DISABLED
? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO
: TotalHits.Relation.EQUAL_TO;
if (topDocs == null || topDocs.size() == 0) {
return new TotalHits(0, relation);
}
long maxTotalHits = topDocs.get(0).totalHits.value;
for (TopDocs topDoc : topDocs) {
maxTotalHits = Math.max(maxTotalHits, topDoc.totalHits.value);
}
return new TotalHits(maxTotalHits, relation);
};
maxScoreSupplier = topDocs -> {
if (topDocs.size() == 0) {
return Float.NaN;
} else {
return topDocs.stream()
.map(docs -> docs.scoreDocs.length == 0 ? new ScoreDoc(-1, 0.0f) : docs.scoreDocs[0])
.map(scoreDoc -> scoreDoc.score)
.max(Float::compare)
.get();
}
};
sortAndFormats = searchContext.sort();

searcher.search(query, collector);

if (searchContext.terminateAfter() != SearchContext.DEFAULT_TERMINATE_AFTER && queryResult.terminatedEarly() == null) {
queryResult.terminatedEarly(false);
}

setTopDocsInQueryResult(queryResult, collector);

return rescore;
}

void setTopDocsInQueryResult(final QuerySearchResult queryResult, final HybridTopScoreDocCollector collector) {
final List<TopDocs> topDocs = collector.topDocs();
float maxScore = maxScoreSupplier.apply(topDocs);
final TopDocs newTopDocs = new CompoundTopDocs(totalHitsSupplier.apply(topDocs), topDocs);
final TopDocsAndMaxScore topDocsAndMaxScore = new TopDocsAndMaxScore(newTopDocs, maxScore);
queryResult.topDocs(topDocsAndMaxScore, getSortValueFormats());
}

private DocValueFormat[] getSortValueFormats() {
return sortAndFormats == null ? null : sortAndFormats.formats;
}
}
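
To make the two suppliers concrete: the compound result's total hits is the maximum across sub-queries, and its max score is the best top-ranked hit. A sketch with hypothetical numbers:

// Hypothetical inputs: sub-query A has 5 total hits (best score 0.8f),
// sub-query B has 9 total hits (best score 0.6f).
List<TopDocs> topDocs = List.of(
    new TopDocs(new TotalHits(5, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(3, 0.8f) }),
    new TopDocs(new TotalHits(9, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(7, 0.6f) })
);
long totalHits = topDocs.stream().mapToLong(t -> t.totalHits.value).max().orElse(0); // 9
float maxScore = (float) topDocs.stream()
    .mapToDouble(t -> t.scoreDocs.length == 0 ? 0.0f : t.scoreDocs[0].score)
    .max()
    .orElse(Float.NaN); // 0.8f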
BaseNeuralSearchIT.java

@@ -62,6 +62,12 @@ public abstract class BaseNeuralSearchIT extends OpenSearchSecureRestTestCase {
 
     @Before
     public void setupSettings() {
+        if (isUpdateClusterSettings()) {
+            updateClusterSettings();
+        }
+    }
+
+    protected void updateClusterSettings() {
         updateClusterSettings("plugins.ml_commons.only_run_on_ml_node", false);
         // the default threshold for the native circuit breaker is 90, which may not be enough on a test runner machine
         updateClusterSettings("plugins.ml_commons.native_memory_threshold", 100);
@@ -514,4 +520,8 @@ protected void deleteModel(String modelId) {
             ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT))
         );
     }
+
+    public boolean isUpdateClusterSettings() {
+        return true;
+    }
 }
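
A hypothetical subclass showing the intended use of the new hook: an integration test can opt out of the shared cluster-settings setup by overriding isUpdateClusterSettings().

// Hypothetical example: skip the default cluster settings, e.g. for a test
// that manages its own settings.
public class CustomSettingsIT extends BaseNeuralSearchIT {
    @Override
    public boolean isUpdateClusterSettings() {
        return false; // setupSettings() will then skip updateClusterSettings()
    }
}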
NeuralSearchTests.java

@@ -6,12 +6,20 @@
 package org.opensearch.neuralsearch.plugin;
 
 import java.util.List;
+import java.util.Map;
+import java.util.Optional;
 
+import org.opensearch.ingest.Processor;
 import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor;
 import org.opensearch.neuralsearch.query.HybridQueryBuilder;
 import org.opensearch.neuralsearch.query.NeuralQueryBuilder;
+import org.opensearch.neuralsearch.search.query.HybridQueryPhaseSearcher;
 import org.opensearch.plugins.SearchPlugin;
+import org.opensearch.search.query.QueryPhaseSearcher;
 import org.opensearch.test.OpenSearchTestCase;
 
+import static org.mockito.Mockito.mock;
+
 public class NeuralSearchTests extends OpenSearchTestCase {
 
     public void testQuerySpecs() {
@@ -23,4 +31,21 @@ public void testQuerySpecs() {
         assertTrue(querySpecs.stream().anyMatch(spec -> NeuralQueryBuilder.NAME.equals(spec.getName().getPreferredName())));
         assertTrue(querySpecs.stream().anyMatch(spec -> HybridQueryBuilder.NAME.equals(spec.getName().getPreferredName())));
     }
+
+    public void testQueryPhaseSearcher() {
+        NeuralSearch plugin = new NeuralSearch();
+        Optional<QueryPhaseSearcher> queryPhaseSearcher = plugin.getQueryPhaseSearcher();
+
+        assertNotNull(queryPhaseSearcher);
+        assertFalse(queryPhaseSearcher.isEmpty());
+        assertTrue(queryPhaseSearcher.get() instanceof HybridQueryPhaseSearcher);
+    }
+
+    public void testProcessors() {
+        NeuralSearch plugin = new NeuralSearch();
+        Processor.Parameters processorParams = mock(Processor.Parameters.class);
+        Map<String, Processor.Factory> processors = plugin.getProcessors(processorParams);
+        assertNotNull(processors);
+        assertNotNull(processors.get(TextEmbeddingProcessor.TYPE));
+    }
 }