diff --git a/CHANGELOG.md b/CHANGELOG.md index 528231b07..8dcdc721b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Features ### Enhancements ### Bug Fixes +- Fix runtime exceptions in hybrid query for case when sub-query scorer return TwoPhase iterator that is incompatible with DISI iterator ([#624](https://github.com/opensearch-project/neural-search/pull/624)) ### Infrastructure ### Documentation ### Maintenance diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridQuery.java b/src/main/java/org/opensearch/neuralsearch/query/HybridQuery.java index 8846f6977..01d271cdd 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/HybridQuery.java +++ b/src/main/java/org/opensearch/neuralsearch/query/HybridQuery.java @@ -12,7 +12,6 @@ import java.util.List; import java.util.Objects; -import org.apache.lucene.index.IndexReader; import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.IndexSearcher; @@ -77,12 +76,12 @@ public String toString(String field) { /** * Re-writes queries into primitive queries. Callers are expected to call rewrite multiple times if necessary, * until the rewritten query is the same as the original query. - * @param reader + * @param indexSearcher * @return * @throws IOException */ @Override - public Query rewrite(IndexReader reader) throws IOException { + public Query rewrite(IndexSearcher indexSearcher) throws IOException { if (subQueries.isEmpty()) { return new MatchNoDocsQuery("empty HybridQuery"); } @@ -90,7 +89,7 @@ public Query rewrite(IndexReader reader) throws IOException { boolean actuallyRewritten = false; List rewrittenSubQueries = new ArrayList<>(); for (Query subQuery : subQueries) { - Query rewrittenSub = subQuery.rewrite(reader); + Query rewrittenSub = subQuery.rewrite(indexSearcher); /* we keep rewrite sub-query unless it's not equal to itself, it may take multiple levels of recursive calls queries need to be rewritten from high-level clauses into lower-level clauses because low-level clauses perform better. For hybrid query we need to track progress of re-write for all sub-queries */ @@ -102,7 +101,7 @@ public Query rewrite(IndexReader reader) throws IOException { return new HybridQuery(rewrittenSubQueries); } - return super.rewrite(reader); + return super.rewrite(indexSearcher); } /** diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java index aa4242c2e..60d9fd639 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java @@ -27,7 +27,6 @@ import org.opensearch.index.query.QueryRewriteContext; import org.opensearch.index.query.QueryShardContext; import org.opensearch.index.query.QueryShardException; -import org.opensearch.index.query.Rewriteable; import org.opensearch.index.query.QueryBuilderVisitor; import lombok.Getter; @@ -54,7 +53,7 @@ public final class HybridQueryBuilder extends AbstractQueryBuilder queries private Collection toQueries(Collection queryBuilders, QueryShardContext context) throws QueryShardException { List queries = queryBuilders.stream().map(qb -> { try { - return Rewriteable.rewrite(qb, context).toQuery(context); + return qb.rewrite(context).toQuery(context); } catch (IOException e) { throw new RuntimeException(e); } diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java index 5abfd0b5e..f31d0abd9 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java +++ b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java @@ -6,6 +6,7 @@ import java.io.IOException; import java.util.ArrayList; +import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.List; @@ -18,10 +19,13 @@ import org.apache.lucene.search.DisjunctionDISIApproximation; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.Query; +import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.TwoPhaseIterator; import org.apache.lucene.search.Weight; import lombok.Getter; +import org.apache.lucene.util.PriorityQueue; /** * Class abstracts functionality of Scorer for hybrid query. When iterating over documents in increasing @@ -40,12 +44,56 @@ public final class HybridQueryScorer extends Scorer { private final Map> queryToIndex; - public HybridQueryScorer(Weight weight, List subScorers) throws IOException { + private final DocIdSetIterator approximation; + private final HybridScoreBlockBoundaryPropagator disjunctionBlockPropagator; + private final TwoPhase twoPhase; + + public HybridQueryScorer(final Weight weight, final List subScorers) throws IOException { + this(weight, subScorers, ScoreMode.TOP_SCORES); + } + + HybridQueryScorer(final Weight weight, final List subScorers, final ScoreMode scoreMode) throws IOException { super(weight); this.subScorers = Collections.unmodifiableList(subScorers); subScores = new float[subScorers.size()]; this.queryToIndex = mapQueryToIndex(); this.subScorersPQ = initializeSubScorersPQ(); + boolean needsScores = scoreMode != ScoreMode.COMPLETE_NO_SCORES; + + this.approximation = new HybridSubqueriesDISIApproximation(this.subScorersPQ); + if (scoreMode == ScoreMode.TOP_SCORES) { + this.disjunctionBlockPropagator = new HybridScoreBlockBoundaryPropagator(subScorers); + } else { + this.disjunctionBlockPropagator = null; + } + + boolean hasApproximation = false; + float sumMatchCost = 0; + long sumApproxCost = 0; + // Compute matchCost as the average over the matchCost of the subScorers. + // This is weighted by the cost, which is an expected number of matching documents. + for (DisiWrapper w : subScorersPQ) { + long costWeight = (w.cost <= 1) ? 1 : w.cost; + sumApproxCost += costWeight; + if (w.twoPhaseView != null) { + hasApproximation = true; + sumMatchCost += w.matchCost * costWeight; + } + } + if (!hasApproximation) { // no sub scorer supports approximations + twoPhase = null; + } else { + final float matchCost = sumMatchCost / sumApproxCost; + twoPhase = new TwoPhase(approximation, matchCost, subScorersPQ, needsScores); + } + } + + @Override + public int advanceShallow(int target) throws IOException { + if (disjunctionBlockPropagator != null) { + return disjunctionBlockPropagator.advanceShallow(target); + } + return super.advanceShallow(target); } /** @@ -55,7 +103,10 @@ public HybridQueryScorer(Weight weight, List subScorers) throws IOExcept */ @Override public float score() throws IOException { - DisiWrapper topList = subScorersPQ.topList(); + return score(getSubMatches()); + } + + private float score(DisiWrapper topList) throws IOException { float totalScore = 0.0f; for (DisiWrapper disiWrapper = topList; disiWrapper != null; disiWrapper = disiWrapper.next) { // check if this doc has match in the subQuery. If not, add score as 0.0 and continue @@ -67,13 +118,30 @@ public float score() throws IOException { return totalScore; } + DisiWrapper getSubMatches() throws IOException { + if (twoPhase == null) { + return subScorersPQ.topList(); + } else { + return twoPhase.getSubMatches(); + } + } + /** * Return a DocIdSetIterator over matching documents. * @return DocIdSetIterator object */ @Override public DocIdSetIterator iterator() { - return new DisjunctionDISIApproximation(this.subScorersPQ); + if (twoPhase != null) { + return TwoPhaseIterator.asDocIdSetIterator(twoPhase); + } else { + return approximation; + } + } + + @Override + public TwoPhaseIterator twoPhaseIterator() { + return twoPhase; } /** @@ -93,12 +161,28 @@ public float getMaxScore(int upTo) throws IOException { }).max(Float::compare).orElse(0.0f); } + @Override + public void setMinCompetitiveScore(float minScore) throws IOException { + if (disjunctionBlockPropagator != null) { + disjunctionBlockPropagator.setMinCompetitiveScore(minScore); + } + + for (Scorer scorer : subScorers) { + if (Objects.nonNull(scorer)) { + scorer.setMinCompetitiveScore(minScore); + } + } + } + /** * Returns the doc ID that is currently being scored. * @return document id */ @Override public int docID() { + if (subScorersPQ.size() == 0) { + return DocIdSetIterator.NO_MORE_DOCS; + } return subScorersPQ.top().doc; } @@ -169,4 +253,142 @@ private DisiPriorityQueue initializeSubScorersPQ() { } return subScorersPQ; } + + @Override + public Collection getChildren() throws IOException { + ArrayList children = new ArrayList<>(); + for (DisiWrapper scorer = getSubMatches(); scorer != null; scorer = scorer.next) { + children.add(new ChildScorable(scorer.scorer, "SHOULD")); + } + return children; + } + + /** + * Object returned by {@link Scorer#twoPhaseIterator()} to provide an approximation of a {@link DocIdSetIterator}. + * After calling {@link DocIdSetIterator#nextDoc()} or {@link DocIdSetIterator#advance(int)} on the iterator + * returned by approximation(), you need to check {@link TwoPhaseIterator#matches()} to confirm if the retrieved + * document ID is a match. Implementation inspired by identical class for + * DisjunctionScorer + */ + static class TwoPhase extends TwoPhaseIterator { + private final float matchCost; + // list of verified matches on the current doc + DisiWrapper verifiedMatches; + // priority queue of approximations on the current doc that have not been verified yet + final PriorityQueue unverifiedMatches; + DisiPriorityQueue subScorers; + boolean needsScores; + + private TwoPhase(DocIdSetIterator approximation, float matchCost, DisiPriorityQueue subScorers, boolean needsScores) { + super(approximation); + this.matchCost = matchCost; + this.subScorers = subScorers; + unverifiedMatches = new PriorityQueue<>(subScorers.size()) { + @Override + protected boolean lessThan(DisiWrapper a, DisiWrapper b) { + return a.matchCost < b.matchCost; + } + }; + this.needsScores = needsScores; + } + + DisiWrapper getSubMatches() throws IOException { + for (DisiWrapper wrapper : unverifiedMatches) { + if (wrapper.twoPhaseView.matches()) { + wrapper.next = verifiedMatches; + verifiedMatches = wrapper; + } + } + unverifiedMatches.clear(); + return verifiedMatches; + } + + @Override + public boolean matches() throws IOException { + verifiedMatches = null; + unverifiedMatches.clear(); + + for (DisiWrapper wrapper = subScorers.topList(); wrapper != null;) { + DisiWrapper next = wrapper.next; + + if (Objects.isNull(wrapper.twoPhaseView)) { + // implicitly verified, move it to verifiedMatches + wrapper.next = verifiedMatches; + verifiedMatches = wrapper; + + if (!needsScores) { + // we can stop here + return true; + } + } else { + unverifiedMatches.add(wrapper); + } + wrapper = next; + } + + if (Objects.nonNull(verifiedMatches)) { + return true; + } + + // verify subs that have an two-phase iterator + // least-costly ones first + while (unverifiedMatches.size() > 0) { + DisiWrapper wrapper = unverifiedMatches.pop(); + if (wrapper.twoPhaseView.matches()) { + wrapper.next = null; + verifiedMatches = wrapper; + return true; + } + } + return false; + } + + @Override + public float matchCost() { + return matchCost; + } + } + + /** + * A DocIdSetIterator which is a disjunction of the approximations of the provided iterators and supports + * sub iterators that return empty results + */ + static class HybridSubqueriesDISIApproximation extends DocIdSetIterator { + final DocIdSetIterator docIdSetIterator; + final DisiPriorityQueue subIterators; + + public HybridSubqueriesDISIApproximation(final DisiPriorityQueue subIterators) { + docIdSetIterator = new DisjunctionDISIApproximation(subIterators); + this.subIterators = subIterators; + } + + @Override + public long cost() { + return docIdSetIterator.cost(); + } + + @Override + public int docID() { + if (subIterators.size() == 0) { + return NO_MORE_DOCS; + } + return docIdSetIterator.docID(); + } + + @Override + public int nextDoc() throws IOException { + if (subIterators.size() == 0) { + return NO_MORE_DOCS; + } + return docIdSetIterator.nextDoc(); + } + + @Override + public int advance(final int target) throws IOException { + if (subIterators.size() == 0) { + return NO_MORE_DOCS; + } + return docIdSetIterator.advance(target); + } + } } diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryWeight.java b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryWeight.java index 69ee5015f..facb79694 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryWeight.java +++ b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryWeight.java @@ -5,10 +5,12 @@ package org.opensearch.neuralsearch.query; import java.io.IOException; +import java.util.ArrayList; import java.util.List; import java.util.Objects; import java.util.stream.Collectors; +import lombok.RequiredArgsConstructor; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.Explanation; import org.apache.lucene.search.IndexSearcher; @@ -16,14 +18,16 @@ import org.apache.lucene.search.MatchesUtils; import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.ScorerSupplier; import org.apache.lucene.search.Weight; +import static org.opensearch.neuralsearch.query.HybridQueryBuilder.MAX_NUMBER_OF_SUB_QUERIES; + /** * Calculates query weights and build query scorers for hybrid query. */ public final class HybridQueryWeight extends Weight { - private final HybridQuery queries; // The Weights for our subqueries, in 1-1 correspondence private final List weights; @@ -34,7 +38,6 @@ public final class HybridQueryWeight extends Weight { */ public HybridQueryWeight(HybridQuery hybridQuery, IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { super(hybridQuery); - this.queries = hybridQuery; weights = hybridQuery.getSubQueries().stream().map(q -> { try { return searcher.createWeight(q, scoreMode, boost); @@ -65,6 +68,20 @@ public Matches matches(LeafReaderContext context, int doc) throws IOException { return MatchesUtils.fromSubMatches(mis); } + @Override + public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException { + List scorerSuppliers = new ArrayList<>(); + for (Weight w : weights) { + ScorerSupplier ss = w.scorerSupplier(context); + scorerSuppliers.add(ss); + } + + if (scorerSuppliers.isEmpty()) { + return null; + } + return new HybridScorerSupplier(scorerSuppliers, this, scoreMode); + } + /** * Create the scorer used to score our associated Query * @@ -75,19 +92,12 @@ public Matches matches(LeafReaderContext context, int doc) throws IOException { */ @Override public Scorer scorer(LeafReaderContext context) throws IOException { - List scorers = weights.stream().map(w -> { - try { - return w.scorer(context); - } catch (IOException e) { - throw new RuntimeException(e); - } - }).collect(Collectors.toList()); - // if there are no matches in any of the scorers (sub-queries) we need to return - // scorer as null to avoid problems with disi result iterators - if (scorers.stream().allMatch(Objects::isNull)) { + ScorerSupplier supplier = scorerSupplier(context); + if (supplier == null) { return null; } - return new HybridQueryScorer(this, scorers); + supplier.setTopLevelScoringClause(); + return supplier.get(Long.MAX_VALUE); } /** @@ -98,6 +108,10 @@ public Scorer scorer(LeafReaderContext context) throws IOException { */ @Override public boolean isCacheable(LeafReaderContext ctx) { + if (weights.size() > MAX_NUMBER_OF_SUB_QUERIES) { + // this situation should never happen, but in case it do such query will not be cached + return false; + } return weights.stream().allMatch(w -> w.isCacheable(ctx)); } @@ -113,4 +127,50 @@ public boolean isCacheable(LeafReaderContext ctx) { public Explanation explain(LeafReaderContext context, int doc) throws IOException { throw new UnsupportedOperationException("Explain is not supported"); } + + @RequiredArgsConstructor + static class HybridScorerSupplier extends ScorerSupplier { + private long cost = -1; + private final List scorerSuppliers; + private final Weight weight; + private final ScoreMode scoreMode; + + @Override + public Scorer get(long leadCost) throws IOException { + List tScorers = new ArrayList<>(); + for (ScorerSupplier ss : scorerSuppliers) { + if (Objects.nonNull(ss)) { + tScorers.add(ss.get(leadCost)); + } else { + tScorers.add(null); + } + } + return new HybridQueryScorer(weight, tScorers, scoreMode); + } + + @Override + public long cost() { + if (cost == -1) { + long cost = 0; + for (ScorerSupplier ss : scorerSuppliers) { + if (Objects.nonNull(ss)) { + cost += ss.cost(); + } + } + this.cost = cost; + } + return cost; + } + + @Override + public void setTopLevelScoringClause() throws IOException { + for (ScorerSupplier ss : scorerSuppliers) { + // sub scorers need to be able to skip too as calls to setMinCompetitiveScore get + // propagated + if (Objects.nonNull(ss)) { + ss.setTopLevelScoringClause(); + } + } + } + }; } diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridScoreBlockBoundaryPropagator.java b/src/main/java/org/opensearch/neuralsearch/query/HybridScoreBlockBoundaryPropagator.java new file mode 100644 index 000000000..6b47a098d --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/query/HybridScoreBlockBoundaryPropagator.java @@ -0,0 +1,99 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.query; + +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.Scorer; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collection; +import java.util.Comparator; +import java.util.Objects; + +/** + * This class functions as a utility for propagating block boundaries within disjunctions. + * In disjunctions, where a match occurs if any subclause matches, a common approach might involve returning + * the minimum block boundary across all clauses. However, this method can introduce performance challenges, + * particularly when dealing with high minimum competitive scores and clauses with low scores that no longer + * significantly contribute to the iteration process. Therefore, this class computes block boundaries solely for clauses + * with a maximum score equal to or exceeding the minimum competitive score, or for the clause with the maximum + * score if such a clause is absent. + */ +public class HybridScoreBlockBoundaryPropagator { + + private static final Comparator MAX_SCORE_COMPARATOR = Comparator.comparing((Scorer s) -> { + try { + return s.getMaxScore(DocIdSetIterator.NO_MORE_DOCS); + } catch (IOException e) { + throw new RuntimeException(e); + } + }).thenComparing(s -> s.iterator().cost()); + + private final Scorer[] scorers; + private final float[] maxScores; + private int leadIndex = 0; + + HybridScoreBlockBoundaryPropagator(final Collection scorers) throws IOException { + this.scorers = scorers.stream().filter(Objects::nonNull).toArray(Scorer[]::new); + for (Scorer scorer : this.scorers) { + scorer.advanceShallow(0); + } + Arrays.sort(this.scorers, MAX_SCORE_COMPARATOR); + + maxScores = new float[this.scorers.length]; + for (int i = 0; i < this.scorers.length; ++i) { + maxScores[i] = this.scorers[i].getMaxScore(DocIdSetIterator.NO_MORE_DOCS); + } + } + + /** See {@link Scorer#advanceShallow(int)}. */ + int advanceShallow(int target) throws IOException { + // For scorers that are below the lead index, just propagate. + for (int i = 0; i < leadIndex; ++i) { + Scorer s = scorers[i]; + if (s.docID() < target) { + s.advanceShallow(target); + } + } + + // For scorers above the lead index, we take the minimum + // boundary. + Scorer leadScorer = scorers[leadIndex]; + int upTo = leadScorer.advanceShallow(Math.max(leadScorer.docID(), target)); + + for (int i = leadIndex + 1; i < scorers.length; ++i) { + Scorer scorer = scorers[i]; + if (scorer.docID() <= target) { + upTo = Math.min(scorer.advanceShallow(target), upTo); + } + } + + // If the maximum scoring clauses are beyond `target`, then we use their + // docID as a boundary. It helps not consider them when computing the + // maximum score and get a lower score upper bound. + for (int i = scorers.length - 1; i > leadIndex; --i) { + Scorer scorer = scorers[i]; + if (scorer.docID() > target) { + upTo = Math.min(upTo, scorer.docID() - 1); + } else { + break; + } + } + return upTo; + } + + /** + * Set the minimum competitive score to filter out clauses that score less than this threshold. + * + * @see Scorer#setMinCompetitiveScore + */ + void setMinCompetitiveScore(float minScore) throws IOException { + // Update the lead index if necessary + while (leadIndex < maxScores.length - 1 && minScore > maxScores[leadIndex]) { + leadIndex++; + } + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollector.java b/src/main/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollector.java index 8b7a12d29..4418841f4 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollector.java +++ b/src/main/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollector.java @@ -8,6 +8,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Locale; +import java.util.Objects; import java.util.stream.Collectors; import java.util.stream.IntStream; @@ -19,7 +20,6 @@ import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.TopDocs; -import org.apache.lucene.search.TopScoreDocCollector; import org.apache.lucene.search.TotalHits; import org.apache.lucene.util.PriorityQueue; import org.opensearch.neuralsearch.query.HybridQueryScorer; @@ -47,20 +47,55 @@ public HybridTopScoreDocCollector(int numHits, HitsThresholdChecker hitsThreshol } @Override - public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException { + public LeafCollector getLeafCollector(LeafReaderContext context) { docBase = context.docBase; - return new TopScoreDocCollector.ScorerLeafCollector() { + return new LeafCollector() { HybridQueryScorer compoundQueryScorer; @Override public void setScorer(Scorable scorer) throws IOException { - super.setScorer(scorer); - compoundQueryScorer = (HybridQueryScorer) scorer; + if (scorer instanceof HybridQueryScorer) { + log.debug("passed scorer is of type HybridQueryScorer, saving it for collecting documents and scores"); + compoundQueryScorer = (HybridQueryScorer) scorer; + } else { + compoundQueryScorer = getHybridQueryScorer(scorer); + if (Objects.isNull(compoundQueryScorer)) { + log.error( + String.format(Locale.ROOT, "cannot find scorer of type HybridQueryScorer in a hierarchy of scorer %s", scorer) + ); + } + } + } + + private HybridQueryScorer getHybridQueryScorer(final Scorable scorer) throws IOException { + if (scorer == null) { + return null; + } + if (scorer instanceof HybridQueryScorer) { + return (HybridQueryScorer) scorer; + } + for (Scorable.ChildScorable childScorable : scorer.getChildren()) { + HybridQueryScorer hybridQueryScorer = getHybridQueryScorer(childScorable.child); + if (Objects.nonNull(hybridQueryScorer)) { + log.debug( + String.format( + Locale.ROOT, + "found hybrid query scorer, it's child of scorer %s", + childScorable.child.getClass().getSimpleName() + ) + ); + return hybridQueryScorer; + } + } + return null; } @Override public void collect(int doc) throws IOException { + if (Objects.isNull(compoundQueryScorer)) { + throw new IllegalArgumentException("scorers are null for all sub-queries in hybrid query"); + } float[] subScoresByQuery = compoundQueryScorer.hybridScores(); // iterate over results for each query if (compoundScores == null) { diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessor.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessor.java new file mode 100644 index 000000000..4e9070748 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessor.java @@ -0,0 +1,82 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.search.query; + +import lombok.AllArgsConstructor; +import org.apache.lucene.search.CollectorManager; +import org.opensearch.search.aggregations.AggregationInitializationException; +import org.opensearch.search.aggregations.AggregationProcessor; +import org.opensearch.search.internal.SearchContext; +import org.opensearch.search.query.QueryPhaseExecutionException; +import org.opensearch.search.query.QuerySearchResult; +import org.opensearch.search.query.ReduceableSearchResult; + +import java.io.IOException; +import java.util.List; + +import static org.opensearch.neuralsearch.search.query.HybridQueryPhaseSearcher.isHybridQuery; + +/** + * Defines logic for pre- and post-phases of document scores collection. Responsible for registering custom + * collector manager for hybris query (pre phase) and reducing results (post phase) + */ +@AllArgsConstructor +public class HybridAggregationProcessor implements AggregationProcessor { + + private final AggregationProcessor delegateAggsProcessor; + + @Override + public void preProcess(SearchContext context) { + delegateAggsProcessor.preProcess(context); + + if (isHybridQuery(context.query(), context)) { + // adding collector manager for hybrid query + CollectorManager collectorManager; + try { + collectorManager = HybridCollectorManager.createHybridCollectorManager(context); + } catch (IOException exception) { + throw new AggregationInitializationException("could not initialize hybrid aggregation processor", exception); + } + context.queryCollectorManagers().put(HybridCollectorManager.class, collectorManager); + } + } + + @Override + public void postProcess(SearchContext context) { + if (isHybridQuery(context.query(), context)) { + // for case when concurrent search is not enabled (default as of 2.12 release) reduce for collector + // managers is not called + // (https://github.com/opensearch-project/OpenSearch/blob/2.12/server/src/main/java/org/opensearch/search/query/QueryPhase.java#L333-L373) + // and we have to call it manually. This is required as we format final + // result of hybrid query in {@link HybridTopScoreCollector#reduce} + // when concurrent search is enabled then reduce method is called as part of the search {@see + // ConcurrentQueryPhaseSearcher#searchWithCollectorManager} + // corresponding call in Lucene + // https://github.com/apache/lucene/blob/branch_9_10/lucene/core/src/java/org/apache/lucene/search/IndexSearcher.java#L700 + if (!context.shouldUseConcurrentSearch()) { + reduceCollectorResults(context); + } + updateQueryResult(context.queryResult(), context); + } + + delegateAggsProcessor.postProcess(context); + } + + private void reduceCollectorResults(SearchContext context) { + CollectorManager collectorManager = context.queryCollectorManagers().get(HybridCollectorManager.class); + try { + collectorManager.reduce(List.of()).reduce(context.queryResult()); + } catch (IOException e) { + throw new QueryPhaseExecutionException(context.shardTarget(), "failed to execute hybrid query aggregation processor", e); + } + } + + private void updateQueryResult(final QuerySearchResult queryResult, final SearchContext searchContext) { + boolean isSingleShard = searchContext.numberOfShards() == 1; + if (isSingleShard) { + searchContext.size(queryResult.queryResult().topDocs().topDocs.scoreDocs.length); + } + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java new file mode 100644 index 000000000..a5de898ab --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java @@ -0,0 +1,253 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.search.query; + +import lombok.RequiredArgsConstructor; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.search.Collector; +import org.apache.lucene.search.CollectorManager; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; +import org.opensearch.common.lucene.search.TopDocsAndMaxScore; +import org.opensearch.neuralsearch.search.HitsThresholdChecker; +import org.opensearch.neuralsearch.search.HybridTopScoreDocCollector; +import org.opensearch.search.DocValueFormat; +import org.opensearch.search.internal.SearchContext; +import org.opensearch.search.query.MultiCollectorWrapper; +import org.opensearch.search.query.QuerySearchResult; +import org.opensearch.search.query.ReduceableSearchResult; +import org.opensearch.search.sort.SortAndFormats; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.HashSet; +import java.util.List; +import java.util.Objects; +import java.util.Set; +import java.util.stream.Collectors; + +import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createDelimiterElementForHybridSearchResults; +import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createStartStopElementForHybridSearchResults; + +/** + * Collector manager based on HybridTopScoreDocCollector that allows users to parallelize counting the number of hits. + * In most cases it will be wrapped in MultiCollectorManager. + */ +@RequiredArgsConstructor +public abstract class HybridCollectorManager implements CollectorManager { + + private final int numHits; + private final HitsThresholdChecker hitsThresholdChecker; + private final boolean isSingleShard; + private final int trackTotalHitsUpTo; + private final SortAndFormats sortAndFormats; + + /** + * Create new instance of HybridCollectorManager depending on the concurrent search beeing enabled or disabled. + * @param searchContext + * @return + * @throws IOException + */ + public static CollectorManager createHybridCollectorManager(final SearchContext searchContext) throws IOException { + final IndexReader reader = searchContext.searcher().getIndexReader(); + final int totalNumDocs = Math.max(0, reader.numDocs()); + boolean isSingleShard = searchContext.numberOfShards() == 1; + int numDocs = Math.min(searchContext.from() + searchContext.size(), totalNumDocs); + int trackTotalHitsUpTo = searchContext.trackTotalHitsUpTo(); + + return searchContext.shouldUseConcurrentSearch() + ? new HybridCollectorConcurrentSearchManager( + numDocs, + new HitsThresholdChecker(Math.max(numDocs, searchContext.trackTotalHitsUpTo())), + isSingleShard, + trackTotalHitsUpTo, + searchContext.sort() + ) + : new HybridCollectorNonConcurrentManager( + numDocs, + new HitsThresholdChecker(Math.max(numDocs, searchContext.trackTotalHitsUpTo())), + isSingleShard, + trackTotalHitsUpTo, + searchContext.sort() + ); + } + + @Override + public Collector newCollector() { + Collector hybridcollector = new HybridTopScoreDocCollector(numHits, hitsThresholdChecker); + return hybridcollector; + } + + /** + * Reduce the results from hybrid scores collector into a format specific for hybrid search query: + * - start + * - sub-query-delimiter + * - scores + * - stop + * Ignore other collectors if they are present in the context + * @param collectors collection of collectors after they has been executed and collected documents and scores + * @return search results that can be reduced be the caller + */ + @Override + public ReduceableSearchResult reduce(Collection collectors) { + final List hybridTopScoreDocCollectors = new ArrayList<>(); + // check if collector for hybrid query scores is part of this search context. It can be wrapped into MultiCollectorWrapper + // in case multiple collector managers are registered. We use hybrid scores collector to format scores into + // format specific for hybrid search query: start, sub-query-delimiter, scores, stop + for (final Collector collector : collectors) { + if (collector instanceof MultiCollectorWrapper) { + for (final Collector sub : (((MultiCollectorWrapper) collector).getCollectors())) { + if (sub instanceof HybridTopScoreDocCollector) { + hybridTopScoreDocCollectors.add((HybridTopScoreDocCollector) sub); + } + } + } else if (collector instanceof HybridTopScoreDocCollector) { + hybridTopScoreDocCollectors.add((HybridTopScoreDocCollector) collector); + } + } + + if (!hybridTopScoreDocCollectors.isEmpty()) { + HybridTopScoreDocCollector hybridTopScoreDocCollector = hybridTopScoreDocCollectors.stream() + .findFirst() + .orElseThrow(() -> new IllegalStateException("cannot collect results of hybrid search query")); + List topDocs = hybridTopScoreDocCollector.topDocs(); + TopDocs newTopDocs = getNewTopDocs(getTotalHits(this.trackTotalHitsUpTo, topDocs, isSingleShard), topDocs); + float maxScore = getMaxScore(topDocs); + TopDocsAndMaxScore topDocsAndMaxScore = new TopDocsAndMaxScore(newTopDocs, maxScore); + return (QuerySearchResult result) -> { result.topDocs(topDocsAndMaxScore, getSortValueFormats(sortAndFormats)); }; + } + throw new IllegalStateException("cannot collect results of hybrid search query, there are no proper score collectors"); + } + + private TopDocs getNewTopDocs(final TotalHits totalHits, final List topDocs) { + ScoreDoc[] scoreDocs = new ScoreDoc[0]; + if (Objects.nonNull(topDocs)) { + // for a single shard case we need to do score processing at coordinator level. + // this is workaround for current core behaviour, for single shard fetch phase is executed + // right after query phase and processors are called after actual fetch is done + // find any valid doc Id, or set it to -1 if there is not a single match + int delimiterDocId = topDocs.stream() + .filter(Objects::nonNull) + .filter(topDoc -> Objects.nonNull(topDoc.scoreDocs)) + .map(topDoc -> topDoc.scoreDocs) + .filter(scoreDoc -> scoreDoc.length > 0) + .map(scoreDoc -> scoreDoc[0].doc) + .findFirst() + .orElse(-1); + if (delimiterDocId == -1) { + return new TopDocs(totalHits, scoreDocs); + } + // format scores using following template: + // doc_id | magic_number_1 + // doc_id | magic_number_2 + // ... + // doc_id | magic_number_2 + // ... + // doc_id | magic_number_2 + // ... + // doc_id | magic_number_1 + List result = new ArrayList<>(); + result.add(createStartStopElementForHybridSearchResults(delimiterDocId)); + for (TopDocs topDoc : topDocs) { + if (Objects.isNull(topDoc) || Objects.isNull(topDoc.scoreDocs)) { + result.add(createDelimiterElementForHybridSearchResults(delimiterDocId)); + continue; + } + result.add(createDelimiterElementForHybridSearchResults(delimiterDocId)); + result.addAll(Arrays.asList(topDoc.scoreDocs)); + } + result.add(createStartStopElementForHybridSearchResults(delimiterDocId)); + scoreDocs = result.stream().map(doc -> new ScoreDoc(doc.doc, doc.score, doc.shardIndex)).toArray(ScoreDoc[]::new); + } + return new TopDocs(totalHits, scoreDocs); + } + + private TotalHits getTotalHits(int trackTotalHitsUpTo, final List topDocs, final boolean isSingleShard) { + final TotalHits.Relation relation = trackTotalHitsUpTo == SearchContext.TRACK_TOTAL_HITS_DISABLED + ? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO + : TotalHits.Relation.EQUAL_TO; + if (topDocs == null || topDocs.isEmpty()) { + return new TotalHits(0, relation); + } + + List scoreDocs = topDocs.stream() + .map(topdDoc -> topdDoc.scoreDocs) + .filter(Objects::nonNull) + .collect(Collectors.toList()); + Set uniqueDocIds = new HashSet<>(); + for (ScoreDoc[] scoreDocsArray : scoreDocs) { + uniqueDocIds.addAll(Arrays.stream(scoreDocsArray).map(scoreDoc -> scoreDoc.doc).collect(Collectors.toList())); + } + long maxTotalHits = uniqueDocIds.size(); + + return new TotalHits(maxTotalHits, relation); + } + + private float getMaxScore(final List topDocs) { + if (topDocs.isEmpty()) { + return 0.0f; + } else { + return topDocs.stream() + .map(docs -> docs.scoreDocs.length == 0 ? new ScoreDoc(-1, 0.0f) : docs.scoreDocs[0]) + .map(scoreDoc -> scoreDoc.score) + .max(Float::compare) + .get(); + } + } + + private DocValueFormat[] getSortValueFormats(final SortAndFormats sortAndFormats) { + return sortAndFormats == null ? null : sortAndFormats.formats; + } + + /** + * Implementation of the HybridCollector that reuses instance of collector on each even call. This allows caller to + * use saved state of collector + */ + static class HybridCollectorNonConcurrentManager extends HybridCollectorManager { + private final Collector scoreCollector; + + public HybridCollectorNonConcurrentManager( + int numHits, + HitsThresholdChecker hitsThresholdChecker, + boolean isSingleShard, + int trackTotalHitsUpTo, + SortAndFormats sortAndFormats + ) { + super(numHits, hitsThresholdChecker, isSingleShard, trackTotalHitsUpTo, sortAndFormats); + scoreCollector = Objects.requireNonNull(super.newCollector(), "collector for hybrid query cannot be null"); + } + + @Override + public Collector newCollector() { + return scoreCollector; + } + + @Override + public ReduceableSearchResult reduce(Collection collectors) { + assert collectors.isEmpty() : "reduce on HybridCollectorNonConcurrentManager called with non-empty collectors"; + return super.reduce(List.of(scoreCollector)); + } + } + + /** + * Implementation of the HybridCollector that doesn't save collector's state and return new instance of every + * call of newCollector + */ + static class HybridCollectorConcurrentSearchManager extends HybridCollectorManager { + + public HybridCollectorConcurrentSearchManager( + int numHits, + HitsThresholdChecker hitsThresholdChecker, + boolean isSingleShard, + int trackTotalHitsUpTo, + SortAndFormats sortAndFormats + ) { + super(numHits, hitsThresholdChecker, isSingleShard, trackTotalHitsUpTo, sortAndFormats); + } + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java index bf05fdc9d..6461c698e 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java +++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java @@ -4,46 +4,26 @@ */ package org.opensearch.neuralsearch.search.query; -import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createDelimiterElementForHybridSearchResults; -import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createStartStopElementForHybridSearchResults; -import static org.opensearch.search.query.TopDocsCollectorContext.createTopDocsCollectorContext; - import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; import java.util.LinkedList; import java.util.List; -import java.util.Objects; -import org.apache.lucene.index.IndexReader; +import com.google.common.annotations.VisibleForTesting; import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.FieldExistsQuery; import org.apache.lucene.search.Query; -import org.apache.lucene.search.ScoreDoc; -import org.apache.lucene.search.TopDocs; -import org.apache.lucene.search.TotalHitCountCollector; -import org.apache.lucene.search.TotalHits; -import org.opensearch.common.lucene.search.TopDocsAndMaxScore; import org.opensearch.common.settings.Settings; import org.opensearch.index.mapper.MapperService; import org.opensearch.index.mapper.SeqNoFieldMapper; import org.opensearch.index.search.NestedHelper; import org.opensearch.neuralsearch.query.HybridQuery; -import org.opensearch.neuralsearch.search.HitsThresholdChecker; -import org.opensearch.neuralsearch.search.HybridTopScoreDocCollector; -import org.opensearch.search.DocValueFormat; +import org.opensearch.search.aggregations.AggregationProcessor; import org.opensearch.search.internal.ContextIndexSearcher; import org.opensearch.search.internal.SearchContext; import org.opensearch.search.query.QueryCollectorContext; import org.opensearch.search.query.QueryPhase; import org.opensearch.search.query.QueryPhaseSearcherWrapper; -import org.opensearch.search.query.QuerySearchResult; -import org.opensearch.search.query.TopDocsCollectorContext; -import org.opensearch.search.rescore.RescoreContext; -import org.opensearch.search.sort.SortAndFormats; - -import com.google.common.annotations.VisibleForTesting; import lombok.extern.log4j.Log4j2; @@ -66,15 +46,17 @@ public boolean searchWith( final boolean hasFilterCollector, final boolean hasTimeout ) throws IOException { - if (isHybridQuery(query, searchContext)) { + if (!isHybridQuery(query, searchContext)) { + validateQuery(searchContext, query); + return super.searchWith(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout); + } else { Query hybridQuery = extractHybridQuery(searchContext, query); - return searchWithCollector(searchContext, searcher, hybridQuery, collectors, hasFilterCollector, hasTimeout); + return super.searchWith(searchContext, searcher, hybridQuery, collectors, hasFilterCollector, hasTimeout); } - validateQuery(searchContext, query); - return super.searchWith(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout); } - private boolean isHybridQuery(final Query query, final SearchContext searchContext) { + @VisibleForTesting + static boolean isHybridQuery(final Query query, final SearchContext searchContext) { if (query instanceof HybridQuery) { return true; } else if (isWrappedHybridQuery(query) && hasNestedFieldOrNestedDocs(query, searchContext)) { @@ -103,7 +85,7 @@ private boolean isHybridQuery(final Query query, final SearchContext searchConte // we have already checked if query in instance of Boolean in higher level else if condition return ((BooleanQuery) query).clauses() .stream() - .filter(clause -> clause.getQuery() instanceof HybridQuery == false) + .filter(clause -> !(clause.getQuery() instanceof HybridQuery)) .allMatch(clause -> { return clause.getOccur() == BooleanClause.Occur.FILTER && clause.getQuery() instanceof FieldExistsQuery @@ -113,16 +95,17 @@ private boolean isHybridQuery(final Query query, final SearchContext searchConte return false; } - private boolean hasNestedFieldOrNestedDocs(final Query query, final SearchContext searchContext) { + private static boolean hasNestedFieldOrNestedDocs(final Query query, final SearchContext searchContext) { return searchContext.mapperService().hasNested() && new NestedHelper(searchContext.mapperService()).mightMatchNestedDocs(query); } - private boolean isWrappedHybridQuery(final Query query) { + private static boolean isWrappedHybridQuery(final Query query) { return query instanceof BooleanQuery && ((BooleanQuery) query).clauses().stream().anyMatch(clauseQuery -> clauseQuery.getQuery() instanceof HybridQuery); } - private Query extractHybridQuery(final SearchContext searchContext, final Query query) { + @VisibleForTesting + protected Query extractHybridQuery(final SearchContext searchContext, final Query query) { if (hasNestedFieldOrNestedDocs(query, searchContext) && isWrappedHybridQuery(query) && ((BooleanQuery) query).clauses().size() > 0) { @@ -180,152 +163,14 @@ private void validateNestedBooleanQuery(final Query query, final int level) { } } - @VisibleForTesting - protected boolean searchWithCollector( - final SearchContext searchContext, - final ContextIndexSearcher searcher, - final Query query, - final LinkedList collectors, - final boolean hasFilterCollector, - final boolean hasTimeout - ) throws IOException { - log.debug("searching with custom doc collector, shard {}", searchContext.shardTarget().getShardId()); - - final TopDocsCollectorContext topDocsFactory = createTopDocsCollectorContext(searchContext, hasFilterCollector); - collectors.addFirst(topDocsFactory); - if (searchContext.size() == 0) { - final TotalHitCountCollector collector = new TotalHitCountCollector(); - searcher.search(query, collector); - return false; - } - final IndexReader reader = searchContext.searcher().getIndexReader(); - int totalNumDocs = Math.max(0, reader.numDocs()); - int numDocs = Math.min(searchContext.from() + searchContext.size(), totalNumDocs); - final boolean shouldRescore = !searchContext.rescore().isEmpty(); - if (shouldRescore) { - for (RescoreContext rescoreContext : searchContext.rescore()) { - numDocs = Math.max(numDocs, rescoreContext.getWindowSize()); - } - } - - final QuerySearchResult queryResult = searchContext.queryResult(); - - final HybridTopScoreDocCollector collector = new HybridTopScoreDocCollector( - numDocs, - new HitsThresholdChecker(Math.max(numDocs, searchContext.trackTotalHitsUpTo())) - ); - - searcher.search(query, collector); - - if (searchContext.terminateAfter() != SearchContext.DEFAULT_TERMINATE_AFTER && queryResult.terminatedEarly() == null) { - queryResult.terminatedEarly(false); - } - - setTopDocsInQueryResult(queryResult, collector, searchContext); - - return shouldRescore; - } - - private void setTopDocsInQueryResult( - final QuerySearchResult queryResult, - final HybridTopScoreDocCollector collector, - final SearchContext searchContext - ) { - final List topDocs = collector.topDocs(); - final float maxScore = getMaxScore(topDocs); - final boolean isSingleShard = searchContext.numberOfShards() == 1; - final TopDocs newTopDocs = getNewTopDocs(getTotalHits(searchContext, topDocs, isSingleShard), topDocs); - final TopDocsAndMaxScore topDocsAndMaxScore = new TopDocsAndMaxScore(newTopDocs, maxScore); - queryResult.topDocs(topDocsAndMaxScore, getSortValueFormats(searchContext.sort())); - } - - private TopDocs getNewTopDocs(final TotalHits totalHits, final List topDocs) { - ScoreDoc[] scoreDocs = new ScoreDoc[0]; - if (Objects.nonNull(topDocs)) { - // for a single shard case we need to do score processing at coordinator level. - // this is workaround for current core behaviour, for single shard fetch phase is executed - // right after query phase and processors are called after actual fetch is done - // find any valid doc Id, or set it to -1 if there is not a single match - int delimiterDocId = topDocs.stream() - .filter(Objects::nonNull) - .filter(topDoc -> Objects.nonNull(topDoc.scoreDocs)) - .map(topDoc -> topDoc.scoreDocs) - .filter(scoreDoc -> scoreDoc.length > 0) - .map(scoreDoc -> scoreDoc[0].doc) - .findFirst() - .orElse(-1); - if (delimiterDocId == -1) { - return new TopDocs(totalHits, scoreDocs); - } - // format scores using following template: - // doc_id | magic_number_1 - // doc_id | magic_number_2 - // ... - // doc_id | magic_number_2 - // ... - // doc_id | magic_number_2 - // ... - // doc_id | magic_number_1 - List result = new ArrayList<>(); - result.add(createStartStopElementForHybridSearchResults(delimiterDocId)); - for (TopDocs topDoc : topDocs) { - if (Objects.isNull(topDoc) || Objects.isNull(topDoc.scoreDocs)) { - result.add(createDelimiterElementForHybridSearchResults(delimiterDocId)); - continue; - } - result.add(createDelimiterElementForHybridSearchResults(delimiterDocId)); - result.addAll(Arrays.asList(topDoc.scoreDocs)); - } - result.add(createStartStopElementForHybridSearchResults(delimiterDocId)); - scoreDocs = result.stream().map(doc -> new ScoreDoc(doc.doc, doc.score, doc.shardIndex)).toArray(ScoreDoc[]::new); - } - return new TopDocs(totalHits, scoreDocs); - } - - private TotalHits getTotalHits(final SearchContext searchContext, final List topDocs, final boolean isSingleShard) { - int trackTotalHitsUpTo = searchContext.trackTotalHitsUpTo(); - final TotalHits.Relation relation = trackTotalHitsUpTo == SearchContext.TRACK_TOTAL_HITS_DISABLED - ? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO - : TotalHits.Relation.EQUAL_TO; - if (topDocs == null || topDocs.isEmpty()) { - return new TotalHits(0, relation); - } - long maxTotalHits = topDocs.get(0).totalHits.value; - int totalSize = 0; - for (TopDocs topDoc : topDocs) { - maxTotalHits = Math.max(maxTotalHits, topDoc.totalHits.value); - if (isSingleShard) { - totalSize += topDoc.totalHits.value + 1; - } - } - // add 1 qty per each sub-query and + 2 for start and stop delimiters - totalSize += 2; - if (isSingleShard) { - // for single shard we need to update total size as this is how many docs are fetched in Fetch phase - searchContext.size(totalSize); - } - - return new TotalHits(maxTotalHits, relation); - } - - private float getMaxScore(final List topDocs) { - if (topDocs.isEmpty()) { - return 0.0f; - } else { - return topDocs.stream() - .map(docs -> docs.scoreDocs.length == 0 ? new ScoreDoc(-1, 0.0f) : docs.scoreDocs[0]) - .map(scoreDoc -> scoreDoc.score) - .max(Float::compare) - .get(); - } - } - - private DocValueFormat[] getSortValueFormats(final SortAndFormats sortAndFormats) { - return sortAndFormats == null ? null : sortAndFormats.formats; - } - private int getMaxDepthLimit(final SearchContext searchContext) { Settings indexSettings = searchContext.getQueryShardContext().getIndexSettings().getSettings(); return MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(indexSettings).intValue(); } + + @Override + public AggregationProcessor aggregationProcessor(SearchContext searchContext) { + AggregationProcessor coreAggProcessor = super.aggregationProcessor(searchContext); + return new HybridAggregationProcessor(coreAggProcessor); + } } diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryScorerTests.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryScorerTests.java index 1a2e3f26e..a0a4c8ca3 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryScorerTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryScorerTests.java @@ -21,6 +21,7 @@ import org.apache.lucene.search.MatchAllDocsQuery; import org.apache.lucene.search.MatchNoDocsQuery; import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.TwoPhaseIterator; import org.apache.lucene.search.Weight; import org.apache.lucene.tests.util.TestUtil; @@ -219,13 +220,101 @@ public void testMaxScoreFailures_whenScorerThrowsException_thenFail() { when(scorer.iterator()).thenReturn(iterator(docs)); when(scorer.getMaxScore(anyInt())).thenThrow(new IOException("Test exception")); - HybridQueryScorer hybridQueryScorerWithAllNonNullSubScorers = new HybridQueryScorer(weight, Arrays.asList(scorer)); + IOException runtimeException = expectThrows(IOException.class, () -> new HybridQueryScorer(weight, Arrays.asList(scorer))); + assertTrue(runtimeException.getMessage().contains("Test exception")); + } + + @SneakyThrows + public void testApproximationIterator_whenSubScorerSupportsApproximation_thenSuccessful() { + final int maxDoc = TestUtil.nextInt(random(), 10, 1_000); + final int numDocs = TestUtil.nextInt(random(), 1, maxDoc / 2); + final Set uniqueDocs = new HashSet<>(); + while (uniqueDocs.size() < numDocs) { + uniqueDocs.add(random().nextInt(maxDoc)); + } + final int[] docs = new int[numDocs]; + int i = 0; + for (int doc : uniqueDocs) { + docs[i++] = doc; + } + Arrays.sort(docs); + final float[] scores1 = new float[numDocs]; + for (i = 0; i < numDocs; ++i) { + scores1[i] = random().nextFloat(); + } + final float[] scores2 = new float[numDocs]; + for (i = 0; i < numDocs; ++i) { + scores2[i] = random().nextFloat(); + } + + Weight weight = mock(Weight.class); - RuntimeException runtimeException = expectThrows( - RuntimeException.class, - () -> hybridQueryScorerWithAllNonNullSubScorers.getMaxScore(Integer.MAX_VALUE) + HybridQueryScorer queryScorer = new HybridQueryScorer( + weight, + Arrays.asList( + scorerWithTwoPhaseIterator(docs, scores1, fakeWeight(new MatchAllDocsQuery()), maxDoc), + scorerWithTwoPhaseIterator(docs, scores2, fakeWeight(new MatchNoDocsQuery()), maxDoc) + ) ); - assertTrue(runtimeException.getMessage().contains("Test exception")); + + int doc = -1; + int idx = 0; + while (doc != DocIdSetIterator.NO_MORE_DOCS) { + doc = queryScorer.iterator().nextDoc(); + if (idx == docs.length) { + assertEquals(DocIdSetIterator.NO_MORE_DOCS, doc); + } else { + assertEquals(docs[idx], doc); + assertEquals(scores1[idx] + scores2[idx], queryScorer.score(), 0.001f); + } + idx++; + } + } + + protected static Scorer scorerWithTwoPhaseIterator(final int[] docs, final float[] scores, Weight weight, int maxDoc) { + final DocIdSetIterator iterator = DocIdSetIterator.all(maxDoc); + return new Scorer(weight) { + + int lastScoredDoc = -1; + + public DocIdSetIterator iterator() { + return TwoPhaseIterator.asDocIdSetIterator(twoPhaseIterator()); + } + + @Override + public int docID() { + return iterator.docID(); + } + + @Override + public float score() { + assertNotEquals("score() called twice on doc " + docID(), lastScoredDoc, docID()); + lastScoredDoc = docID(); + final int idx = Arrays.binarySearch(docs, docID()); + return scores[idx]; + } + + @Override + public float getMaxScore(int upTo) { + return Float.MAX_VALUE; + } + + @Override + public TwoPhaseIterator twoPhaseIterator() { + return new TwoPhaseIterator(iterator) { + + @Override + public boolean matches() { + return Arrays.binarySearch(docs, iterator.docID()) >= 0; + } + + @Override + public float matchCost() { + return 10; + } + }; + } + }; } private Pair generateDocuments(int maxDocId) { diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridScoreBlockBoundaryPropagatorTests.java b/src/test/java/org/opensearch/neuralsearch/query/HybridScoreBlockBoundaryPropagatorTests.java new file mode 100644 index 000000000..5bf0948ea --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridScoreBlockBoundaryPropagatorTests.java @@ -0,0 +1,113 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.query; + +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.Explanation; +import org.apache.lucene.search.MatchNoDocsQuery; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.Weight; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +public class HybridScoreBlockBoundaryPropagatorTests extends OpenSearchQueryTestCase { + + public void testAdvanceShallow_whenMinCompetitiveScoreSet_thenSuccessful() throws IOException { + Scorer scorer1 = new MockScorer(10, 0.6f); + Scorer scorer2 = new MockScorer(40, 1.5f); + Scorer scorer3 = new MockScorer(30, 2f); + Scorer scorer4 = new MockScorer(120, 4f); + + List scorers = Arrays.asList(scorer1, scorer2, scorer3, scorer4); + Collections.shuffle(scorers, random()); + HybridScoreBlockBoundaryPropagator propagator = new HybridScoreBlockBoundaryPropagator(scorers); + assertEquals(10, propagator.advanceShallow(0)); + + propagator.setMinCompetitiveScore(0.1f); + assertEquals(10, propagator.advanceShallow(0)); + + propagator.setMinCompetitiveScore(0.8f); + assertEquals(30, propagator.advanceShallow(0)); + + propagator.setMinCompetitiveScore(1.4f); + assertEquals(30, propagator.advanceShallow(0)); + + propagator.setMinCompetitiveScore(1.9f); + assertEquals(30, propagator.advanceShallow(0)); + + propagator.setMinCompetitiveScore(2.5f); + assertEquals(120, propagator.advanceShallow(0)); + + propagator.setMinCompetitiveScore(7f); + assertEquals(120, propagator.advanceShallow(0)); + } + + private static class MockWeight extends Weight { + + MockWeight() { + super(new MatchNoDocsQuery()); + } + + @Override + public Explanation explain(LeafReaderContext context, int doc) { + return null; + } + + @Override + public Scorer scorer(LeafReaderContext context) { + return null; + } + + @Override + public boolean isCacheable(LeafReaderContext ctx) { + return false; + } + } + + private static class MockScorer extends Scorer { + + final int boundary; + final float maxScore; + + MockScorer(int boundary, float maxScore) throws IOException { + super(new MockWeight()); + this.boundary = boundary; + this.maxScore = maxScore; + } + + @Override + public int docID() { + return 0; + } + + @Override + public float score() { + throw new UnsupportedOperationException(); + } + + @Override + public DocIdSetIterator iterator() { + throw new UnsupportedOperationException(); + } + + @Override + public void setMinCompetitiveScore(float minCompetitiveScore) {} + + @Override + public float getMaxScore(int upTo) throws IOException { + return maxScore; + } + + @Override + public int advanceShallow(int target) { + assert target <= boundary; + return boundary; + } + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollectorTests.java b/src/test/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollectorTests.java index b67a1ee05..ad5a955c4 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollectorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollectorTests.java @@ -11,6 +11,7 @@ import java.util.ArrayList; import java.util.Arrays; +import java.util.Collection; import java.util.List; import java.util.stream.Collectors; import java.util.stream.IntStream; @@ -27,12 +28,15 @@ import org.apache.lucene.search.LeafCollector; import org.apache.lucene.search.MatchAllDocsQuery; import org.apache.lucene.search.MatchNoDocsQuery; +import org.apache.lucene.search.Scorable; import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.Scorer; import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.Weight; import org.apache.lucene.store.Directory; import org.apache.lucene.tests.analysis.MockAnalyzer; +import org.apache.lucene.util.PriorityQueue; import org.opensearch.index.mapper.TextFieldMapper; import org.opensearch.index.query.QueryBuilders; import org.opensearch.index.query.QueryShardContext; @@ -399,4 +403,109 @@ public void testTrackTotalHits_whenTotalHitsSetIntegerMaxValue_thenSuccessful() reader.close(); directory.close(); } + + @SneakyThrows + public void testCompoundScorer_whenHybridScorerIsChildScorer_thenSuccessful() { + HybridTopScoreDocCollector hybridTopScoreDocCollector = new HybridTopScoreDocCollector( + NUM_DOCS, + new HitsThresholdChecker(Integer.MAX_VALUE) + ); + + final Directory directory = newDirectory(); + final IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random()))); + FieldType ft = new FieldType(TextField.TYPE_NOT_STORED); + ft.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS); + ft.setOmitNorms(random().nextBoolean()); + ft.freeze(); + + w.addDocument(getDocument(TEXT_FIELD_NAME, DOC_ID_1, FIELD_1_VALUE, ft)); + w.addDocument(getDocument(TEXT_FIELD_NAME, DOC_ID_2, FIELD_2_VALUE, ft)); + w.addDocument(getDocument(TEXT_FIELD_NAME, DOC_ID_3, FIELD_3_VALUE, ft)); + w.commit(); + + DirectoryReader reader = DirectoryReader.open(w); + + LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0); + LeafCollector leafCollector = hybridTopScoreDocCollector.getLeafCollector(leafReaderContext); + + assertNotNull(leafCollector); + + Weight weight = mock(Weight.class); + Weight subQueryWeight = mock(Weight.class); + Scorer subQueryScorer = mock(Scorer.class); + when(subQueryScorer.getWeight()).thenReturn(subQueryWeight); + DocIdSetIterator iterator = mock(DocIdSetIterator.class); + when(subQueryScorer.iterator()).thenReturn(iterator); + + HybridQueryScorer hybridQueryScorer = new HybridQueryScorer(weight, Arrays.asList(subQueryScorer)); + + Scorer scorer = mock(Scorer.class); + Collection childrenCollectors = List.of(new Scorable.ChildScorable(hybridQueryScorer, "MUST")); + when(scorer.getChildren()).thenReturn(childrenCollectors); + leafCollector.setScorer(scorer); + int nextDoc = hybridQueryScorer.iterator().nextDoc(); + leafCollector.collect(nextDoc); + + assertNotNull(hybridTopScoreDocCollector.getCompoundScores()); + PriorityQueue[] compoundScoresPQ = hybridTopScoreDocCollector.getCompoundScores(); + assertEquals(1, compoundScoresPQ.length); + PriorityQueue scoreDoc = compoundScoresPQ[0]; + assertNotNull(scoreDoc); + assertNotNull(scoreDoc.top()); + + w.close(); + reader.close(); + directory.close(); + } + + @SneakyThrows + public void testCompoundScorer_whenHybridScorerIsTopLevelScorer_thenSuccessful() { + HybridTopScoreDocCollector hybridTopScoreDocCollector = new HybridTopScoreDocCollector( + NUM_DOCS, + new HitsThresholdChecker(Integer.MAX_VALUE) + ); + + final Directory directory = newDirectory(); + final IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random()))); + FieldType ft = new FieldType(TextField.TYPE_NOT_STORED); + ft.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS); + ft.setOmitNorms(random().nextBoolean()); + ft.freeze(); + + w.addDocument(getDocument(TEXT_FIELD_NAME, DOC_ID_1, FIELD_1_VALUE, ft)); + w.addDocument(getDocument(TEXT_FIELD_NAME, DOC_ID_2, FIELD_2_VALUE, ft)); + w.addDocument(getDocument(TEXT_FIELD_NAME, DOC_ID_3, FIELD_3_VALUE, ft)); + w.commit(); + + DirectoryReader reader = DirectoryReader.open(w); + + LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0); + LeafCollector leafCollector = hybridTopScoreDocCollector.getLeafCollector(leafReaderContext); + + assertNotNull(leafCollector); + + Weight weight = mock(Weight.class); + Weight subQueryWeight = mock(Weight.class); + Scorer subQueryScorer = mock(Scorer.class); + when(subQueryScorer.getWeight()).thenReturn(subQueryWeight); + DocIdSetIterator iterator = mock(DocIdSetIterator.class); + when(subQueryScorer.iterator()).thenReturn(iterator); + + HybridQueryScorer hybridQueryScorer = new HybridQueryScorer(weight, Arrays.asList(subQueryScorer)); + + leafCollector.setScorer(hybridQueryScorer); + int nextDoc = hybridQueryScorer.iterator().nextDoc(); + leafCollector.collect(nextDoc); + + assertNotNull(hybridTopScoreDocCollector.getCompoundScores()); + PriorityQueue[] compoundScoresPQ = hybridTopScoreDocCollector.getCompoundScores(); + assertEquals(1, compoundScoresPQ.length); + PriorityQueue scoreDoc = compoundScoresPQ[0]; + assertNotNull(scoreDoc); + assertNotNull(scoreDoc.top()); + + w.close(); + reader.close(); + directory.close(); + } } diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessorTests.java new file mode 100644 index 000000000..f44e762f0 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessorTests.java @@ -0,0 +1,233 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.search.query; + +import lombok.SneakyThrows; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.search.Collector; +import org.apache.lucene.search.CollectorManager; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; +import org.opensearch.action.OriginalIndices; +import org.opensearch.common.lucene.search.TopDocsAndMaxScore; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.index.mapper.TextFieldMapper; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.QueryShardContext; +import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.neuralsearch.query.HybridQuery; +import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; +import org.opensearch.search.DocValueFormat; +import org.opensearch.search.SearchShardTarget; +import org.opensearch.search.aggregations.AggregationProcessor; +import org.opensearch.search.internal.ContextIndexSearcher; +import org.opensearch.search.internal.SearchContext; +import org.opensearch.search.query.QuerySearchResult; +import org.opensearch.search.query.ReduceableSearchResult; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.when; + +public class HybridAggregationProcessorTests extends OpenSearchQueryTestCase { + + static final String TEXT_FIELD_NAME = "field"; + static final String TERM_QUERY_TEXT = "keyword"; + + @SneakyThrows + public void testAggregationProcessorDelegate_whenPreAndPostAreCalled_thenSuccessful() { + AggregationProcessor mockAggsProcessorDelegate = mock(AggregationProcessor.class); + HybridAggregationProcessor hybridAggregationProcessor = new HybridAggregationProcessor(mockAggsProcessorDelegate); + + SearchContext searchContext = mock(SearchContext.class); + hybridAggregationProcessor.preProcess(searchContext); + verify(mockAggsProcessorDelegate).preProcess(any()); + + hybridAggregationProcessor.postProcess(searchContext); + verify(mockAggsProcessorDelegate).postProcess(any()); + } + + @SneakyThrows + public void testCollectorManager_whenHybridQueryAndNotConcurrentSearch_thenSuccessful() { + AggregationProcessor mockAggsProcessorDelegate = mock(AggregationProcessor.class); + HybridAggregationProcessor hybridAggregationProcessor = new HybridAggregationProcessor(mockAggsProcessorDelegate); + + SearchContext searchContext = mock(SearchContext.class); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT); + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext))); + + when(searchContext.query()).thenReturn(hybridQuery); + ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); + IndexReader indexReader = mock(IndexReader.class); + when(indexSearcher.getIndexReader()).thenReturn(indexReader); + when(searchContext.searcher()).thenReturn(indexSearcher); + + Map, CollectorManager> classCollectorManagerMap = new HashMap<>(); + when(searchContext.queryCollectorManagers()).thenReturn(classCollectorManagerMap); + when(searchContext.shouldUseConcurrentSearch()).thenReturn(false); + + hybridAggregationProcessor.preProcess(searchContext); + + assertEquals(1, classCollectorManagerMap.size()); + assertTrue(classCollectorManagerMap.containsKey(HybridCollectorManager.class)); + CollectorManager hybridCollectorManager = classCollectorManagerMap.get( + HybridCollectorManager.class + ); + assertTrue(hybridCollectorManager instanceof HybridCollectorManager.HybridCollectorNonConcurrentManager); + + // setup query result for post processing + int shardId = 0; + SearchShardTarget searchShardTarget = new SearchShardTarget( + "node", + new ShardId("index", "uuid", shardId), + null, + OriginalIndices.NONE + ); + QuerySearchResult querySearchResult = new QuerySearchResult(); + TopDocs topDocs = new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + + new ScoreDoc[] { new ScoreDoc(0, 0.5f), new ScoreDoc(2, 0.3f) } + + ); + querySearchResult.topDocs(new TopDocsAndMaxScore(topDocs, 0.5f), new DocValueFormat[0]); + querySearchResult.setSearchShardTarget(searchShardTarget); + querySearchResult.setShardIndex(shardId); + when(searchContext.queryResult()).thenReturn(querySearchResult); + + // set captor on collector manager to track if reduce has been called + CollectorManager hybridCollectorManagerSpy = spy(hybridCollectorManager); + classCollectorManagerMap.put(HybridCollectorManager.class, hybridCollectorManagerSpy); + + hybridAggregationProcessor.postProcess(searchContext); + + verify(hybridCollectorManagerSpy).reduce(any()); + } + + @SneakyThrows + public void testCollectorManager_whenHybridQueryAndConcurrentSearch_thenSuccessful() { + AggregationProcessor mockAggsProcessorDelegate = mock(AggregationProcessor.class); + HybridAggregationProcessor hybridAggregationProcessor = new HybridAggregationProcessor(mockAggsProcessorDelegate); + + SearchContext searchContext = mock(SearchContext.class); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT); + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext))); + + when(searchContext.query()).thenReturn(hybridQuery); + ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); + IndexReader indexReader = mock(IndexReader.class); + when(indexSearcher.getIndexReader()).thenReturn(indexReader); + when(searchContext.searcher()).thenReturn(indexSearcher); + + Map, CollectorManager> classCollectorManagerMap = new HashMap<>(); + when(searchContext.queryCollectorManagers()).thenReturn(classCollectorManagerMap); + when(searchContext.shouldUseConcurrentSearch()).thenReturn(true); + + hybridAggregationProcessor.preProcess(searchContext); + + assertEquals(1, classCollectorManagerMap.size()); + assertTrue(classCollectorManagerMap.containsKey(HybridCollectorManager.class)); + CollectorManager hybridCollectorManager = classCollectorManagerMap.get( + HybridCollectorManager.class + ); + assertTrue(hybridCollectorManager instanceof HybridCollectorManager.HybridCollectorConcurrentSearchManager); + + // setup query result for post processing + int shardId = 0; + SearchShardTarget searchShardTarget = new SearchShardTarget( + "node", + new ShardId("index", "uuid", shardId), + null, + OriginalIndices.NONE + ); + QuerySearchResult querySearchResult = new QuerySearchResult(); + TopDocs topDocs = new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + + new ScoreDoc[] { new ScoreDoc(0, 0.5f), new ScoreDoc(2, 0.3f) } + + ); + querySearchResult.topDocs(new TopDocsAndMaxScore(topDocs, 0.5f), new DocValueFormat[0]); + querySearchResult.setSearchShardTarget(searchShardTarget); + querySearchResult.setShardIndex(shardId); + when(searchContext.queryResult()).thenReturn(querySearchResult); + + // set captor on collector manager to track if reduce has been called + CollectorManager hybridCollectorManagerSpy = spy(hybridCollectorManager); + classCollectorManagerMap.put(HybridCollectorManager.class, hybridCollectorManagerSpy); + + hybridAggregationProcessor.postProcess(searchContext); + + verifyNoInteractions(hybridCollectorManagerSpy); + } + + @SneakyThrows + public void testCollectorManager_whenNotHybridQueryAndNotConcurrentSearch_thenSuccessful() { + AggregationProcessor mockAggsProcessorDelegate = mock(AggregationProcessor.class); + HybridAggregationProcessor hybridAggregationProcessor = new HybridAggregationProcessor(mockAggsProcessorDelegate); + + SearchContext searchContext = mock(SearchContext.class); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + + Query termQuery = termSubQuery.toQuery(mockQueryShardContext); + + when(searchContext.query()).thenReturn(termQuery); + ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); + IndexReader indexReader = mock(IndexReader.class); + when(indexSearcher.getIndexReader()).thenReturn(indexReader); + when(searchContext.searcher()).thenReturn(indexSearcher); + + Map, CollectorManager> classCollectorManagerMap = new HashMap<>(); + when(searchContext.queryCollectorManagers()).thenReturn(classCollectorManagerMap); + when(searchContext.shouldUseConcurrentSearch()).thenReturn(true); + + hybridAggregationProcessor.preProcess(searchContext); + + assertTrue(classCollectorManagerMap.isEmpty()); + + // setup query result for post processing + int shardId = 0; + SearchShardTarget searchShardTarget = new SearchShardTarget( + "node", + new ShardId("index", "uuid", shardId), + null, + OriginalIndices.NONE + ); + QuerySearchResult querySearchResult = new QuerySearchResult(); + TopDocs topDocs = new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + + new ScoreDoc[] { new ScoreDoc(0, 0.5f), new ScoreDoc(2, 0.3f) } + + ); + querySearchResult.topDocs(new TopDocsAndMaxScore(topDocs, 0.5f), new DocValueFormat[0]); + querySearchResult.setSearchShardTarget(searchShardTarget); + querySearchResult.setShardIndex(shardId); + when(searchContext.queryResult()).thenReturn(querySearchResult); + + hybridAggregationProcessor.postProcess(searchContext); + + assertTrue(classCollectorManagerMap.isEmpty()); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java new file mode 100644 index 000000000..65d6f3d8a --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java @@ -0,0 +1,201 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.search.query; + +import com.carrotsearch.randomizedtesting.RandomizedTest; +import lombok.SneakyThrows; +import org.apache.lucene.document.FieldType; +import org.apache.lucene.document.TextField; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexOptions; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.BulkScorer; +import org.apache.lucene.search.Collector; +import org.apache.lucene.search.CollectorManager; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.LeafCollector; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.TotalHits; +import org.apache.lucene.search.Weight; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.analysis.MockAnalyzer; +import org.opensearch.common.lucene.search.TopDocsAndMaxScore; +import org.opensearch.index.mapper.TextFieldMapper; +import org.opensearch.index.query.BoostingQueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.QueryShardContext; +import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.neuralsearch.query.HybridQuery; +import org.opensearch.neuralsearch.query.HybridQueryWeight; +import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; +import org.opensearch.neuralsearch.search.HybridTopScoreDocCollector; +import org.opensearch.search.internal.ContextIndexSearcher; +import org.opensearch.search.internal.SearchContext; +import org.opensearch.search.query.QuerySearchResult; +import org.opensearch.search.query.ReduceableSearchResult; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.MAGIC_NUMBER_DELIMITER; +import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.MAGIC_NUMBER_START_STOP; + +public class HybridCollectorManagerTests extends OpenSearchQueryTestCase { + + private static final String TEXT_FIELD_NAME = "field"; + private static final String TEST_DOC_TEXT1 = "Hello world"; + private static final String TEST_DOC_TEXT2 = "Hi to this place"; + private static final String TEST_DOC_TEXT3 = "We would like to welcome everyone"; + private static final String QUERY1 = "hello"; + private static final float DELTA_FOR_ASSERTION = 0.001f; + + @SneakyThrows + public void testNewCollector_whenNotConcurrentSearch_thenSuccessful() { + SearchContext searchContext = mock(SearchContext.class); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1); + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext))); + + when(searchContext.query()).thenReturn(hybridQuery); + ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); + IndexReader indexReader = mock(IndexReader.class); + when(indexSearcher.getIndexReader()).thenReturn(indexReader); + when(searchContext.searcher()).thenReturn(indexSearcher); + + Map, CollectorManager> classCollectorManagerMap = new HashMap<>(); + when(searchContext.queryCollectorManagers()).thenReturn(classCollectorManagerMap); + when(searchContext.shouldUseConcurrentSearch()).thenReturn(false); + + CollectorManager hybridCollectorManager = HybridCollectorManager.createHybridCollectorManager(searchContext); + assertNotNull(hybridCollectorManager); + assertTrue(hybridCollectorManager instanceof HybridCollectorManager.HybridCollectorNonConcurrentManager); + + Collector collector = hybridCollectorManager.newCollector(); + assertNotNull(collector); + assertTrue(collector instanceof HybridTopScoreDocCollector); + + Collector secondCollector = hybridCollectorManager.newCollector(); + assertSame(collector, secondCollector); + } + + @SneakyThrows + public void testNewCollector_whenConcurrentSearch_thenSuccessful() { + SearchContext searchContext = mock(SearchContext.class); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1); + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext))); + + when(searchContext.query()).thenReturn(hybridQuery); + ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); + IndexReader indexReader = mock(IndexReader.class); + when(indexSearcher.getIndexReader()).thenReturn(indexReader); + when(searchContext.searcher()).thenReturn(indexSearcher); + + Map, CollectorManager> classCollectorManagerMap = new HashMap<>(); + when(searchContext.queryCollectorManagers()).thenReturn(classCollectorManagerMap); + when(searchContext.shouldUseConcurrentSearch()).thenReturn(true); + + CollectorManager hybridCollectorManager = HybridCollectorManager.createHybridCollectorManager(searchContext); + assertNotNull(hybridCollectorManager); + assertTrue(hybridCollectorManager instanceof HybridCollectorManager.HybridCollectorConcurrentSearchManager); + + Collector collector = hybridCollectorManager.newCollector(); + assertNotNull(collector); + assertTrue(collector instanceof HybridTopScoreDocCollector); + + Collector secondCollector = hybridCollectorManager.newCollector(); + assertNotSame(collector, secondCollector); + } + + @SneakyThrows + public void testReduce_whenMatchedDocs_thenSuccessful() { + SearchContext searchContext = mock(SearchContext.class); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + + HybridQuery hybridQueryWithTerm = new HybridQuery( + List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1).toQuery(mockQueryShardContext)) + ); + when(searchContext.query()).thenReturn(hybridQueryWithTerm); + ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); + IndexReader indexReader = mock(IndexReader.class); + when(indexReader.numDocs()).thenReturn(3); + when(indexSearcher.getIndexReader()).thenReturn(indexReader); + when(searchContext.searcher()).thenReturn(indexSearcher); + when(searchContext.size()).thenReturn(1); + + Map, CollectorManager> classCollectorManagerMap = new HashMap<>(); + when(searchContext.queryCollectorManagers()).thenReturn(classCollectorManagerMap); + when(searchContext.shouldUseConcurrentSearch()).thenReturn(false); + + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + + Directory directory = newDirectory(); + final IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random()))); + FieldType ft = new FieldType(TextField.TYPE_NOT_STORED); + ft.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS); + ft.setOmitNorms(random().nextBoolean()); + ft.freeze(); + + int docId1 = RandomizedTest.randomInt(); + int docId2 = RandomizedTest.randomInt(); + int docId3 = RandomizedTest.randomInt(); + w.addDocument(getDocument(TEXT_FIELD_NAME, docId1, TEST_DOC_TEXT1, ft)); + w.addDocument(getDocument(TEXT_FIELD_NAME, docId2, TEST_DOC_TEXT2, ft)); + w.addDocument(getDocument(TEXT_FIELD_NAME, docId3, TEST_DOC_TEXT3, ft)); + w.flush(); + w.commit(); + + IndexReader reader = DirectoryReader.open(w); + IndexSearcher searcher = newSearcher(reader); + + CollectorManager hybridCollectorManager = HybridCollectorManager.createHybridCollectorManager(searchContext); + HybridTopScoreDocCollector collector = (HybridTopScoreDocCollector) hybridCollectorManager.newCollector(); + + Weight weight = new HybridQueryWeight(hybridQueryWithTerm, searcher, ScoreMode.TOP_SCORES, BoostingQueryBuilder.DEFAULT_BOOST); + collector.setWeight(weight); + LeafReaderContext leafReaderContext = searcher.getIndexReader().leaves().get(0); + LeafCollector leafCollector = collector.getLeafCollector(leafReaderContext); + BulkScorer scorer = weight.bulkScorer(leafReaderContext); + scorer.score(leafCollector, leafReaderContext.reader().getLiveDocs()); + leafCollector.finish(); + + Object results = hybridCollectorManager.reduce(List.of()); + + assertNotNull(results); + ReduceableSearchResult reduceableSearchResult = ((ReduceableSearchResult) results); + QuerySearchResult querySearchResult = new QuerySearchResult(); + reduceableSearchResult.reduce(querySearchResult); + TopDocsAndMaxScore topDocsAndMaxScore = querySearchResult.topDocs(); + + assertNotNull(topDocsAndMaxScore); + assertEquals(1, topDocsAndMaxScore.topDocs.totalHits.value); + assertEquals(TotalHits.Relation.EQUAL_TO, topDocsAndMaxScore.topDocs.totalHits.relation); + float maxScore = topDocsAndMaxScore.maxScore; + assertTrue(maxScore > 0); + ScoreDoc[] scoreDocs = topDocsAndMaxScore.topDocs.scoreDocs; + assertEquals(4, scoreDocs.length); + assertEquals(MAGIC_NUMBER_START_STOP, scoreDocs[0].score, DELTA_FOR_ASSERTION); + assertEquals(MAGIC_NUMBER_DELIMITER, scoreDocs[1].score, DELTA_FOR_ASSERTION); + assertEquals(maxScore, scoreDocs[2].score, DELTA_FOR_ASSERTION); + assertEquals(MAGIC_NUMBER_START_STOP, scoreDocs[3].score, DELTA_FOR_ASSERTION); + + w.close(); + reader.close(); + directory.close(); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java index e609eec05..2aebbb5d8 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java @@ -20,10 +20,12 @@ import java.io.IOException; import java.util.ArrayList; +import java.util.Arrays; import java.util.LinkedList; import java.util.List; import java.util.Set; import java.util.UUID; +import java.util.stream.Collectors; import org.apache.lucene.document.FieldType; import org.apache.lucene.document.TextField; @@ -61,6 +63,7 @@ import org.opensearch.neuralsearch.query.HybridQueryBuilder; import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; import org.opensearch.search.SearchShardTarget; +import org.opensearch.search.aggregations.AggregationProcessor; import org.opensearch.search.internal.ContextIndexSearcher; import org.opensearch.search.internal.SearchContext; import org.opensearch.search.query.QueryCollectorContext; @@ -159,7 +162,7 @@ public void testQueryType_whenQueryIsHybrid_thenCallHybridDocCollector() { releaseResources(directory, w, reader); - verify(hybridQueryPhaseSearcher, atLeastOnce()).searchWithCollector(any(), any(), any(), any(), anyBoolean(), anyBoolean()); + verify(hybridQueryPhaseSearcher, atLeastOnce()).searchWith(any(), any(), any(), any(), anyBoolean(), anyBoolean()); } @SneakyThrows @@ -226,7 +229,7 @@ public void testQueryType_whenQueryIsNotHybrid_thenDoNotCallHybridDocCollector() releaseResources(directory, w, reader); - verify(hybridQueryPhaseSearcher, never()).searchWithCollector(any(), any(), any(), any(), anyBoolean(), anyBoolean()); + verify(hybridQueryPhaseSearcher, never()).extractHybridQuery(any(), any()); } @SneakyThrows @@ -305,17 +308,8 @@ public void testQueryResult_whenOneSubQueryWithHits_thenHybridResultsAreSet() { assertEquals(1, topDocs.totalHits.value); ScoreDoc[] scoreDocs = topDocs.scoreDocs; assertNotNull(scoreDocs); - assertEquals(4, scoreDocs.length); - assertTrue(isHybridQueryStartStopElement(scoreDocs[0])); - assertTrue(isHybridQueryStartStopElement(scoreDocs[scoreDocs.length - 1])); - List compoundTopDocs = getSubQueryResultsForSingleShard(topDocs); - assertNotNull(compoundTopDocs); - assertEquals(1, compoundTopDocs.size()); - TopDocs subQueryTopDocs = compoundTopDocs.get(0); - assertEquals(1, subQueryTopDocs.totalHits.value); - assertNotNull(subQueryTopDocs.scoreDocs); - assertEquals(1, subQueryTopDocs.scoreDocs.length); - ScoreDoc scoreDoc = subQueryTopDocs.scoreDocs[0]; + assertEquals(1, scoreDocs.length); + ScoreDoc scoreDoc = scoreDocs[0]; assertNotNull(scoreDoc); int actualDocId = Integer.parseInt(reader.document(scoreDoc.doc).getField("id").stringValue()); assertEquals(docId1, actualDocId); @@ -403,24 +397,10 @@ public void testQueryResult_whenMultipleTextSubQueriesWithSomeHits_thenHybridRes assertEquals(4, topDocs.totalHits.value); ScoreDoc[] scoreDocs = topDocs.scoreDocs; assertNotNull(scoreDocs); - assertEquals(10, scoreDocs.length); - assertTrue(isHybridQueryStartStopElement(scoreDocs[0])); - assertTrue(isHybridQueryStartStopElement(scoreDocs[scoreDocs.length - 1])); - List compoundTopDocs = getSubQueryResultsForSingleShard(topDocs); - assertNotNull(compoundTopDocs); - assertEquals(3, compoundTopDocs.size()); - - TopDocs subQueryTopDocs1 = compoundTopDocs.get(0); - List expectedIds1 = List.of(docId1); - assertQueryResults(subQueryTopDocs1, expectedIds1, reader); - - TopDocs subQueryTopDocs2 = compoundTopDocs.get(1); - List expectedIds2 = List.of(); - assertQueryResults(subQueryTopDocs2, expectedIds2, reader); - - TopDocs subQueryTopDocs3 = compoundTopDocs.get(2); - List expectedIds3 = List.of(docId1, docId2, docId3, docId4); - assertQueryResults(subQueryTopDocs3, expectedIds3, reader); + assertEquals(4, scoreDocs.length); + List expectedIds = List.of(0, 1, 2, 3); + List actualDocIds = Arrays.stream(scoreDocs).map(sd -> sd.doc).collect(Collectors.toList()); + assertEquals(expectedIds, actualDocIds); releaseResources(directory, w, reader); } @@ -726,20 +706,10 @@ public void testWrappedHybridQuery_whenHybridWrappedIntoBoolBecauseOfNested_then assertTrue(topDocs.totalHits.value > 0); ScoreDoc[] scoreDocs = topDocs.scoreDocs; assertNotNull(scoreDocs); - assertTrue(scoreDocs.length > 0); - assertTrue(isHybridQueryStartStopElement(scoreDocs[0])); - assertTrue(isHybridQueryStartStopElement(scoreDocs[scoreDocs.length - 1])); - List compoundTopDocs = getSubQueryResultsForSingleShard(topDocs); - assertNotNull(compoundTopDocs); - assertEquals(2, compoundTopDocs.size()); - - TopDocs subQueryTopDocs1 = compoundTopDocs.get(0); - List expectedIds1 = List.of(docId1); - assertQueryResults(subQueryTopDocs1, expectedIds1, reader); - - TopDocs subQueryTopDocs2 = compoundTopDocs.get(1); - List expectedIds2 = List.of(); - assertQueryResults(subQueryTopDocs2, expectedIds2, reader); + assertEquals(1, scoreDocs.length); + ScoreDoc scoreDoc = scoreDocs[0]; + assertTrue(scoreDoc.score > 0); + assertEquals(0, scoreDoc.doc); releaseResources(directory, w, reader); } @@ -831,6 +801,15 @@ public void testBoolQuery_whenTooManyNestedLevels_thenSuccess() { releaseResources(directory, w, reader); } + @SneakyThrows + public void testAggsProcessor_whenGettingAggsProcessor_thenSuccess() { + HybridQueryPhaseSearcher hybridQueryPhaseSearcher = new HybridQueryPhaseSearcher(); + SearchContext searchContext = mock(SearchContext.class); + AggregationProcessor aggregationProcessor = hybridQueryPhaseSearcher.aggregationProcessor(searchContext); + assertNotNull(aggregationProcessor); + assertTrue(aggregationProcessor instanceof HybridAggregationProcessor); + } + @SneakyThrows private void assertQueryResults(TopDocs subQueryTopDocs, List expectedDocIds, IndexReader reader) { assertEquals(expectedDocIds.size(), subQueryTopDocs.totalHits.value);