diff --git a/CHANGELOG.md b/CHANGELOG.md index 65d84cd5a..ccd1e6f94 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.15...2.x) ### Features +- Enable sorting and search_after features in Hybrid Search [#827](https://github.com/opensearch-project/neural-search/pull/827) ### Enhancements - Adds dynamic knn query parameters efsearch and nprobes [#814](https://github.com/opensearch-project/neural-search/pull/814/) - Enable '.' for nested field in text embedding processor ([#811](https://github.com/opensearch-project/neural-search/pull/811)) diff --git a/src/main/java/org/opensearch/neuralsearch/processor/CompoundTopDocs.java b/src/main/java/org/opensearch/neuralsearch/processor/CompoundTopDocs.java index 9d3c3adb5..44666cc43 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/CompoundTopDocs.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/CompoundTopDocs.java @@ -4,18 +4,17 @@ */ package org.opensearch.neuralsearch.processor; +import org.apache.lucene.search.FieldDoc; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TotalHits; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TopFieldDocs; import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.isHybridQueryDelimiterElement; import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.isHybridQueryStartStopElement; import java.util.ArrayList; -import java.util.Arrays; import java.util.List; import java.util.Objects; -import java.util.stream.Collectors; - -import org.apache.lucene.search.ScoreDoc; -import org.apache.lucene.search.TopDocs; -import org.apache.lucene.search.TotalHits; import lombok.AllArgsConstructor; import lombok.Getter; @@ -39,14 +38,14 @@ public class CompoundTopDocs { @Setter private List scoreDocs; - public CompoundTopDocs(final TotalHits totalHits, final List topDocs) { - initialize(totalHits, topDocs); + public CompoundTopDocs(final TotalHits totalHits, final List topDocs, final boolean isSortEnabled) { + initialize(totalHits, topDocs, isSortEnabled); } - private void initialize(TotalHits totalHits, List topDocs) { + private void initialize(TotalHits totalHits, List topDocs, boolean isSortEnabled) { this.totalHits = totalHits; this.topDocs = topDocs; - scoreDocs = cloneLargestScoreDocs(topDocs); + scoreDocs = cloneLargestScoreDocs(topDocs, isSortEnabled); } /** @@ -74,9 +73,13 @@ private void initialize(TotalHits totalHits, List topDocs) { * 0, 9549511920.4881596047 */ public CompoundTopDocs(final TopDocs topDocs) { + boolean isSortEnabled = false; + if (topDocs instanceof TopFieldDocs) { + isSortEnabled = true; + } ScoreDoc[] scoreDocs = topDocs.scoreDocs; if (Objects.isNull(scoreDocs) || scoreDocs.length < 2) { - initialize(topDocs.totalHits, new ArrayList<>()); + initialize(topDocs.totalHits, new ArrayList<>(), isSortEnabled); return; } // skipping first two elements, it's a start-stop element and delimiter for first series @@ -88,17 +91,22 @@ public CompoundTopDocs(final TopDocs topDocs) { if (isHybridQueryDelimiterElement(scoreDoc) || isHybridQueryStartStopElement(scoreDoc)) { ScoreDoc[] subQueryScores = scoreDocList.toArray(new ScoreDoc[0]); TotalHits totalHits = new TotalHits(subQueryScores.length, TotalHits.Relation.EQUAL_TO); - TopDocs subQueryTopDocs = new TopDocs(totalHits, subQueryScores); + TopDocs 
subQueryTopDocs; + if (isSortEnabled) { + subQueryTopDocs = new TopFieldDocs(totalHits, subQueryScores, ((TopFieldDocs) topDocs).fields); + } else { + subQueryTopDocs = new TopDocs(totalHits, subQueryScores); + } topDocsList.add(subQueryTopDocs); scoreDocList.clear(); } else { scoreDocList.add(scoreDoc); } } - initialize(topDocs.totalHits, topDocsList); + initialize(topDocs.totalHits, topDocsList, isSortEnabled); } - private List cloneLargestScoreDocs(final List docs) { + private List cloneLargestScoreDocs(final List docs, boolean isSortEnabled) { if (docs == null) { return null; } @@ -113,7 +121,20 @@ private List cloneLargestScoreDocs(final List docs) { maxScoreDocs = topDoc.scoreDocs; } } + // do deep copy - return Arrays.stream(maxScoreDocs).map(doc -> new ScoreDoc(doc.doc, doc.score, doc.shardIndex)).collect(Collectors.toList()); + List scoreDocs = new ArrayList<>(); + for (ScoreDoc scoreDoc : maxScoreDocs) { + scoreDocs.add(deepCopyScoreDoc(scoreDoc, isSortEnabled)); + } + return scoreDocs; + } + + private ScoreDoc deepCopyScoreDoc(final ScoreDoc scoreDoc, final boolean isSortEnabled) { + if (!isSortEnabled) { + return new ScoreDoc(scoreDoc.doc, scoreDoc.score, scoreDoc.shardIndex); + } + FieldDoc fieldDoc = (FieldDoc) scoreDoc; + return new FieldDoc(fieldDoc.doc, fieldDoc.score, fieldDoc.fields, fieldDoc.shardIndex); } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java index f317a9e12..c64f1c1f4 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java @@ -15,7 +15,11 @@ import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.Sort; +import org.apache.lucene.search.TopFieldDocs; +import org.apache.lucene.search.FieldDoc; import org.opensearch.common.lucene.search.TopDocsAndMaxScore; +import org.opensearch.neuralsearch.processor.combination.CombineScoresDto; import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique; import org.opensearch.neuralsearch.processor.combination.ScoreCombiner; import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique; @@ -27,6 +31,8 @@ import lombok.AllArgsConstructor; import lombok.extern.log4j.Log4j2; +import static org.opensearch.neuralsearch.processor.combination.ScoreCombiner.MAX_SCORE_WHEN_NO_HITS_FOUND; +import static org.opensearch.neuralsearch.search.util.HybridSearchSortUtil.evaluateSortCriteria; /** * Class abstracts steps required for score normalization and combination, this includes pre-processing of incoming data @@ -62,13 +68,20 @@ public void execute( log.debug("Do score normalization"); scoreNormalizer.normalizeScores(queryTopDocs, normalizationTechnique); + CombineScoresDto combineScoresDTO = CombineScoresDto.builder() + .queryTopDocs(queryTopDocs) + .scoreCombinationTechnique(combinationTechnique) + .querySearchResults(querySearchResults) + .sort(evaluateSortCriteria(querySearchResults, queryTopDocs)) + .build(); + // combine log.debug("Do score combination"); - scoreCombiner.combineScores(queryTopDocs, combinationTechnique); + scoreCombiner.combineScores(combineScoresDTO); // post-process data log.debug("Post-process query results after score normalization and combination"); - updateOriginalQueryResults(querySearchResults, queryTopDocs); + 
updateOriginalQueryResults(combineScoresDTO);
         updateOriginalFetchResults(querySearchResults, fetchSearchResultOptional, unprocessedDocIds);
     }
@@ -96,7 +109,23 @@ private List getQueryTopDocs(final List quer
         return queryTopDocs;
     }

-    private void updateOriginalQueryResults(final List<QuerySearchResult> querySearchResults, final List<CompoundTopDocs> queryTopDocs) {
+    private void updateOriginalQueryResults(final CombineScoresDto combineScoresDTO) {
+        final List<QuerySearchResult> querySearchResults = combineScoresDTO.getQuerySearchResults();
+        final List<CompoundTopDocs> queryTopDocs = getCompoundTopDocs(combineScoresDTO, querySearchResults);
+        final Sort sort = combineScoresDTO.getSort();
+        for (int index = 0; index < querySearchResults.size(); index++) {
+            QuerySearchResult querySearchResult = querySearchResults.get(index);
+            CompoundTopDocs updatedTopDocs = queryTopDocs.get(index);
+            TopDocsAndMaxScore updatedTopDocsAndMaxScore = new TopDocsAndMaxScore(
+                buildTopDocs(updatedTopDocs, sort),
+                maxScoreForShard(updatedTopDocs, sort != null)
+            );
+            querySearchResult.topDocs(updatedTopDocsAndMaxScore, querySearchResult.sortValueFormats());
+        }
+    }
+
+    private List<CompoundTopDocs> getCompoundTopDocs(CombineScoresDto combineScoresDTO, List<QuerySearchResult> querySearchResults) {
+        final List<CompoundTopDocs> queryTopDocs = combineScoresDTO.getQueryTopDocs();
         if (querySearchResults.size() != queryTopDocs.size()) {
             throw new IllegalStateException(
                 String.format(
@@ -107,17 +136,42 @@ private void updateOriginalQueryResults(final List querySearc
                 )
             );
         }
-        for (int index = 0; index < querySearchResults.size(); index++) {
-            QuerySearchResult querySearchResult = querySearchResults.get(index);
-            CompoundTopDocs updatedTopDocs = queryTopDocs.get(index);
-            float maxScore = updatedTopDocs.getTotalHits().value > 0 ? updatedTopDocs.getScoreDocs().get(0).score : 0.0f;
+        return queryTopDocs;
+    }

-        // create final version of top docs with all updated values
-        TopDocs topDocs = new TopDocs(updatedTopDocs.getTotalHits(), updatedTopDocs.getScoreDocs().toArray(new ScoreDoc[0]));
+    /**
+     * Get max score on a shard
+     * @param updatedTopDocs compound top docs on a shard
+     * @param isSortEnabled whether sort is enabled
+     * @return max score
+     */
+    private float maxScoreForShard(CompoundTopDocs updatedTopDocs, boolean isSortEnabled) {
+        if (updatedTopDocs.getTotalHits().value == 0 || updatedTopDocs.getScoreDocs().isEmpty()) {
+            return MAX_SCORE_WHEN_NO_HITS_FOUND;
+        }
+        if (isSortEnabled) {
+            float maxScore = MAX_SCORE_WHEN_NO_HITS_FOUND;
+            // In case of sorting, iterate over the score docs and deduce the max score
+            for (ScoreDoc scoreDoc : updatedTopDocs.getScoreDocs()) {
+                maxScore = Math.max(maxScore, scoreDoc.score);
+            }
+            return maxScore;
+        }
+        // For a normal hybrid query the first score doc entry holds the max score
+        return updatedTopDocs.getScoreDocs().get(0).score;
    }

-        TopDocsAndMaxScore updatedTopDocsAndMaxScore = new TopDocsAndMaxScore(topDocs, maxScore);
-        querySearchResult.topDocs(updatedTopDocsAndMaxScore, null);
+    /**
+     * Get top docs on a shard
+     * @param updatedTopDocs compound top docs on a shard
+     * @param sort sort criteria
+     * @return TopDocs, which will be an instance of TopFieldDocs if sort is enabled.
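+     * A hedged illustration with hypothetical values: with a sort on a numeric field the result resembles
+     *   new TopFieldDocs(totalHits, new FieldDoc[] { new FieldDoc(3, 0.92f, new Object[] { 100L }, 0) }, sort.getSort());
+     * without sort it is a plain
+     *   new TopDocs(totalHits, new ScoreDoc[] { new ScoreDoc(3, 0.92f, 0) });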
+ */ + private TopDocs buildTopDocs(CompoundTopDocs updatedTopDocs, Sort sort) { + if (sort != null) { + return new TopFieldDocs(updatedTopDocs.getTotalHits(), updatedTopDocs.getScoreDocs().toArray(new FieldDoc[0]), sort.getSort()); } + return new TopDocs(updatedTopDocs.getTotalHits(), updatedTopDocs.getScoreDocs().toArray(new ScoreDoc[0])); } /** diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/CombineScoresDto.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/CombineScoresDto.java new file mode 100644 index 000000000..c4783969b --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/CombineScoresDto.java @@ -0,0 +1,32 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.combination; + +import java.util.List; +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Getter; +import lombok.NonNull; +import org.apache.lucene.search.Sort; +import org.opensearch.common.Nullable; +import org.opensearch.neuralsearch.processor.CompoundTopDocs; +import org.opensearch.search.query.QuerySearchResult; + +/** + * DTO object to hold data required for Score Combination. + */ +@AllArgsConstructor +@Builder +@Getter +public class CombineScoresDto { + @NonNull + private List queryTopDocs; + @NonNull + private ScoreCombinationTechnique scoreCombinationTechnique; + @NonNull + private List querySearchResults; + @Nullable + private Sort sort; +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java index 09d9e83f2..a4e39f448 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java @@ -5,15 +5,24 @@ package org.opensearch.neuralsearch.processor.combination; import java.util.ArrayList; +import java.util.Collection; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.Objects; +import java.util.Comparator; +import java.util.LinkedHashSet; + import java.util.stream.Collectors; -import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TopFieldDocs; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.FieldDoc; import org.apache.lucene.search.TotalHits; +import org.apache.lucene.search.Sort; +import org.apache.lucene.search.SortField; import org.opensearch.neuralsearch.processor.CompoundTopDocs; import lombok.extern.log4j.Log4j2; @@ -23,33 +32,59 @@ */ @Log4j2 public class ScoreCombiner { + public static final Float MAX_SCORE_WHEN_NO_HITS_FOUND = 0.0f; + // Tie-breaker to merge multiple top docs + private static final Comparator SORTING_TIE_BREAKER = (o1, o2) -> { + int scoreComparison = Double.compare(o1.score, o2.score); + if (scoreComparison != 0) { + return scoreComparison; + } - private static final Float ZERO_SCORE = 0.0f; + int docIdComparison = Integer.compare(o1.doc, o2.doc); + if (docIdComparison != 0) { + return docIdComparison; + } + + // When duplicate result found then both score and doc ID are equal (o1.score == o2.score && o1.doc == o2.doc) then return 1 + return 1; + }; /** * Performs score combination based on input combination technique. 
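 * (For example, with an arithmetic-mean combination technique, normalized sub-query scores 0.5 and 0.2 for the same doc id combine to 0.35; the numbers are hypothetical.)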
Mutates input object by updating combined scores * Main steps we're doing for combination: - * - create map of normalized scores per doc id - * - using normalized scores create another map of combined scores per doc id - * - count max number of hits among sub-queries - * - sort documents by scores and take first "max number" of docs - * - update query search results with normalized scores - * Different score combination techniques are different in step 2, where we create map of "doc id" - "combined score", - * other steps are same for all techniques. - * @param queryTopDocs query results that need to be normalized, mutated by method execution - * @param scoreCombinationTechnique exact combination method that should be applied + * - create map of normalized scores per doc id + * - using normalized scores create another map of combined scores per doc id + * - count max number of hits among sub-queries + * - sort documents by scores and take first "max number" of docs + * - update query search results with normalized scores + * Different score combination techniques are different in step 2, where we create map of "doc id" - "combined score", + * other steps are same for all techniques. + * + * @param combineScoresDTO contains details of query top docs, score combination technique and sort is enabled or disabled. */ - public void combineScores(final List queryTopDocs, final ScoreCombinationTechnique scoreCombinationTechnique) { + public void combineScores(final CombineScoresDto combineScoresDTO) { // iterate over results from each shard. Every CompoundTopDocs object has results from // multiple sub queries, doc ids may repeat for each sub query results - queryTopDocs.forEach(compoundQueryTopDocs -> combineShardScores(scoreCombinationTechnique, compoundQueryTopDocs)); + combineScoresDTO.getQueryTopDocs() + .forEach( + compoundQueryTopDocs -> combineShardScores( + combineScoresDTO.getScoreCombinationTechnique(), + compoundQueryTopDocs, + combineScoresDTO.getSort() + ) + ); } - private void combineShardScores(final ScoreCombinationTechnique scoreCombinationTechnique, final CompoundTopDocs compoundQueryTopDocs) { + private void combineShardScores( + final ScoreCombinationTechnique scoreCombinationTechnique, + final CompoundTopDocs compoundQueryTopDocs, + final Sort sort + ) { if (Objects.isNull(compoundQueryTopDocs) || compoundQueryTopDocs.getTotalHits().value == 0) { return; } List topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs(); + // - create map of normalized scores results returned from the single shard Map normalizedScoresPerDoc = getNormalizedScoresPerDocument(topDocsPerSubQuery); @@ -61,10 +96,92 @@ private void combineShardScores(final ScoreCombinationTechnique scoreCombination // - sort documents by scores and take first "max number" of docs // create a collection of doc ids that are sorted by their combined scores - List sortedDocsIds = getSortedDocIds(combinedNormalizedScoresByDocId); + Collection sortedDocsIds; + if (sort != null) { + sortedDocsIds = getSortedDocIdsBySortCriteria(getTopFieldDocs(sort, topDocsPerSubQuery), sort); + } else { + sortedDocsIds = getSortedDocIds(combinedNormalizedScoresByDocId); + } // - update query search results with normalized scores - updateQueryTopDocsWithCombinedScores(compoundQueryTopDocs, topDocsPerSubQuery, combinedNormalizedScoresByDocId, sortedDocsIds); + updateQueryTopDocsWithCombinedScores( + compoundQueryTopDocs, + topDocsPerSubQuery, + combinedNormalizedScoresByDocId, + sortedDocsIds, + getDocIdSortFieldsMap(compoundQueryTopDocs, 
combinedNormalizedScoresByDocId, sort),
+            sort != null
+        );
+    }
+
+    private boolean isSortOrderByScore(Sort sort) {
+        if (sort == null) {
+            return false;
+        }
+
+        for (SortField sortField : sort.getSort()) {
+            if (SortField.Type.SCORE.equals(sortField.getType())) {
+                return true;
+            }
+        }
+
+        return false;
+    }
+
+    /**
+     * @param sort sort criteria
+     * @param topDocsPerSubQuery top docs per sub-query
+     * @return list of top field docs, deduced by typecasting the top docs to top field docs.
+     */
+    private List<TopFieldDocs> getTopFieldDocs(final Sort sort, final List<TopDocs> topDocsPerSubQuery) {
+        if (sort == null) {
+            return null;
+        }
+        List<TopFieldDocs> topFieldDocs = new ArrayList<>();
+        for (TopDocs topDocs : topDocsPerSubQuery) {
+            // Check the scoreDocs length.
+            // If scoreDocs length == 0 then no results were found for that particular sub-query.
+            if (topDocs.scoreDocs.length != 0) {
+                topFieldDocs.add((TopFieldDocs) topDocs);
+            }
+        }
+        return topFieldDocs;
+    }
+
+    /**
+     * @param compoundTopDocs top docs that represent one shard
+     * @param combinedNormalizedScoresByDocId docId to normalized score map
+     * @param sort sort criteria
+     * @return map of docId to sort fields if sorting is enabled.
+     */
+    private Map<Integer, Object[]> getDocIdSortFieldsMap(
+        final CompoundTopDocs compoundTopDocs,
+        final Map<Integer, Float> combinedNormalizedScoresByDocId,
+        final Sort sort
+    ) {
+        // If sort is null then no sort fields are present, therefore return null.
+        if (sort == null) {
+            return null;
+        }
+        // we're merging docs with normalized and combined scores; we only need maxHits results
+        Map<Integer, Object[]> docIdSortFieldMap = new HashMap<>();
+        final List<TopDocs> topFieldDocs = compoundTopDocs.getTopDocs();
+        final boolean isSortByScore = isSortOrderByScore(sort);
+        for (TopDocs topDocs : topFieldDocs) {
+            for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
+                FieldDoc fieldDoc = (FieldDoc) scoreDoc;
+
+                if (docIdSortFieldMap.get(fieldDoc.doc) == null) {
+                    // If sorting by score then replace the sort field value with the normalized score.
+                    if (isSortByScore) {
+                        docIdSortFieldMap.put(fieldDoc.doc, new Object[] { combinedNormalizedScoresByDocId.get(fieldDoc.doc) });
+                    } else {
+                        docIdSortFieldMap.put(fieldDoc.doc, fieldDoc.fields);
+                    }
+                }
+            }
+        }
+        return docIdSortFieldMap;
    }

    private List<Integer> getSortedDocIds(final Map<Integer, Float> combinedNormalizedScoresByDocId) {
@@ -74,21 +191,77 @@ private List getSortedDocIds(final Map combinedNormaliz
        return sortedDocsIds;
    }

+    private Set<Integer> getSortedDocIdsBySortCriteria(final List<TopFieldDocs> topFieldDocs, final Sort sort) {
+        if (Objects.isNull(topFieldDocs)) {
+            throw new IllegalArgumentException("topFieldDocs cannot be null when sorting is enabled.");
+        }
+        // size will be equal to the number of score docs
+        int size = 0;
+        for (TopFieldDocs topFieldDoc : topFieldDocs) {
+            size += topFieldDoc.scoreDocs.length;
+        }
+
+        // Merge the sorted results of the individual queries to form one final sorted result per shard.
+        // Input
+        // < 0, 0.7, shardId, [90]>  // Query 1 result scoreDoc
+        // < 1, 0.7, shardId, [70]>  // Query 1 result scoreDoc
+        // < 2, 0.3, shardId, [100]> // Query 2 result scoreDoc
+        // < 1, 0.3, shardId, [70]>  // Query 2 result scoreDoc
+
+        // Output
+        // < 2, 0.3, shardId, [100]>
+        // < 0, 0.7, shardId, [90]>
+        // < 1, 0.7, shardId, [70]>
+        // < 1, 0.3, shardId, [70]>
+        final TopDocs sortedTopDocs = TopDocs.merge(sort, 0, size, topFieldDocs.toArray(new TopFieldDocs[0]), SORTING_TIE_BREAKER);
+
+        // Remove duplicates from the sorted top docs.
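+        // For the example above (hypothetical values), the merged entries for docs [2, 0, 1, 1] collapse to the unique doc ids
+        // {2, 0, 1}; the LinkedHashSet below preserves the merged sort order while dropping the duplicate.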
+        Set<Integer> uniqueDocIds = new LinkedHashSet<>();
+        for (ScoreDoc scoreDoc : sortedTopDocs.scoreDocs) {
+            uniqueDocIds.add(scoreDoc.doc);
+        }
+        return uniqueDocIds;
+    }
+
    private List<ScoreDoc> getCombinedScoreDocs(
        final CompoundTopDocs compoundQueryTopDocs,
        final Map<Integer, Float> combinedNormalizedScoresByDocId,
-        final List<Integer> sortedScores,
-        final long maxHits
+        final Collection<Integer> sortedScores,
+        final long maxHits,
+        final Map<Integer, Object[]> docIdSortFieldMap,
+        boolean isSortingEnabled
    ) {
+        // ShardId will be -1 when the index has multiple shards
+        int shardId = -1;
+        // ShardId will not be -1 when the index has a single shard, because the fetch phase gets executed before normalization
+        if (!compoundQueryTopDocs.getScoreDocs().isEmpty()) {
+            shardId = compoundQueryTopDocs.getScoreDocs().get(0).shardIndex;
+        }
        List<ScoreDoc> scoreDocs = new ArrayList<>();
-        int shardId = compoundQueryTopDocs.getScoreDocs().get(0).shardIndex;
-        for (int j = 0; j < maxHits && j < sortedScores.size(); j++) {
-            int docId = sortedScores.get(j);
-            scoreDocs.add(new ScoreDoc(docId, combinedNormalizedScoresByDocId.get(docId), shardId));
+        int hitCount = 0;
+        for (Integer docId : sortedScores) {
+            if (hitCount == maxHits) {
+                break;
+            }
+            scoreDocs.add(getScoreDoc(isSortingEnabled, docId, shardId, combinedNormalizedScoresByDocId, docIdSortFieldMap));
+            hitCount++;
        }
        return scoreDocs;
    }

+    private ScoreDoc getScoreDoc(
+        final boolean isSortEnabled,
+        final int docId,
+        final int shardId,
+        final Map<Integer, Float> combinedNormalizedScoresByDocId,
+        final Map<Integer, Object[]> docIdSortFieldMap
+    ) {
+        if (isSortEnabled && docIdSortFieldMap != null) {
+            return new FieldDoc(docId, combinedNormalizedScoresByDocId.get(docId), docIdSortFieldMap.get(docId), shardId);
+        }
+        return new ScoreDoc(docId, combinedNormalizedScoresByDocId.get(docId), shardId);
+    }
+
    public Map<Integer, float[]> getNormalizedScoresPerDocument(final List<TopDocs> topDocsPerSubQuery) {
        Map<Integer, float[]> normalizedScoresPerDoc = new HashMap<>();
        for (int j = 0; j < topDocsPerSubQuery.size(); j++) {
@@ -118,13 +291,22 @@ private void updateQueryTopDocsWithCombinedScores(
        final CompoundTopDocs compoundQueryTopDocs,
        final List<TopDocs> topDocsPerSubQuery,
        final Map<Integer, Float> combinedNormalizedScoresByDocId,
-        final List<Integer> sortedScores
+        final Collection<Integer> sortedScores,
+        Map<Integer, Object[]> docIdSortFieldMap,
+        boolean isSortingEnabled
    ) {
        // - max number of hits will be the same which are passed from QueryPhase
        long maxHits = compoundQueryTopDocs.getTotalHits().value;
        // - update query search results with normalized scores
        compoundQueryTopDocs.setScoreDocs(
-            getCombinedScoreDocs(compoundQueryTopDocs, combinedNormalizedScoresByDocId, sortedScores, maxHits)
+            getCombinedScoreDocs(
+                compoundQueryTopDocs,
+                combinedNormalizedScoresByDocId,
+                sortedScores,
+                maxHits,
+                docIdSortFieldMap,
+                isSortingEnabled
+            )
        );
        compoundQueryTopDocs.setTotalHits(getTotalHits(topDocsPerSubQuery, maxHits));
    }
diff --git a/src/main/java/org/opensearch/neuralsearch/search/HitsThresholdChecker.java b/src/main/java/org/opensearch/neuralsearch/search/HitsThresholdChecker.java
index 1299537bb..76822ee73 100644
--- a/src/main/java/org/opensearch/neuralsearch/search/HitsThresholdChecker.java
+++ b/src/main/java/org/opensearch/neuralsearch/search/HitsThresholdChecker.java
@@ -26,15 +26,15 @@ public HitsThresholdChecker(int totalHitsThreshold) {
        this.totalHitsThreshold = totalHitsThreshold;
    }

-    protected void incrementHitCount() {
+    public void incrementHitCount() {
        ++hitCount;
    }

-    protected boolean isThresholdReached() {
+    public boolean isThresholdReached() {
        return hitCount >= getTotalHitsThreshold();
    }

-    protected ScoreMode scoreMode() {
+    public
ScoreMode scoreMode() { return ScoreMode.TOP_SCORES; } } diff --git a/src/main/java/org/opensearch/neuralsearch/search/collector/HybridSearchCollector.java b/src/main/java/org/opensearch/neuralsearch/search/collector/HybridSearchCollector.java new file mode 100644 index 000000000..c1702996d --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/search/collector/HybridSearchCollector.java @@ -0,0 +1,29 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.search.collector; + +import java.util.List; +import org.apache.lucene.search.Collector; +import org.apache.lucene.search.TopDocs; + +/** + * Common interface class for Hybrid search collectors + */ +public interface HybridSearchCollector extends Collector { + /** + * @return List of topDocs which contains topDocs of individual subqueries. + */ + List topDocs(); + + /** + * @return count of total hits per shard + */ + int getTotalHits(); + + /** + * @return maxScore found on a shard + */ + float getMaxScore(); +} diff --git a/src/main/java/org/opensearch/neuralsearch/search/collector/HybridTopFieldDocSortCollector.java b/src/main/java/org/opensearch/neuralsearch/search/collector/HybridTopFieldDocSortCollector.java new file mode 100644 index 000000000..2e268d37b --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/search/collector/HybridTopFieldDocSortCollector.java @@ -0,0 +1,416 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.search.collector; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; +import java.util.Locale; +import java.util.ArrayList; +import lombok.Getter; +import lombok.Setter; +import lombok.extern.log4j.Log4j2; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.Sort; +import org.apache.lucene.search.FieldValueHitQueue; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.TotalHits; +import org.apache.lucene.search.LeafCollector; +import org.apache.lucene.search.LeafFieldComparator; +import org.apache.lucene.search.SortField; +import org.apache.lucene.search.TopFieldDocs; +import org.apache.lucene.search.Scorable; +import org.apache.lucene.search.FieldComparator; +import org.apache.lucene.search.CollectionTerminatedException; +import org.apache.lucene.search.FieldDoc; +import org.apache.lucene.util.PriorityQueue; +import org.opensearch.neuralsearch.query.HybridQueryScorer; +import org.opensearch.common.Nullable; +import org.opensearch.neuralsearch.search.HitsThresholdChecker; +import org.opensearch.neuralsearch.search.lucene.MultiLeafFieldComparator; + +/* + Collects the TopFieldDocs after executing hybrid query. Uses HybridQueryTopDocs as DTO to handle each sub query results. + The individual query results are sorted as per the sort criteria sent in the search request. + */ +@Log4j2 +public abstract class HybridTopFieldDocSortCollector implements HybridSearchCollector { + private final int numHits; + private final HitsThresholdChecker hitsThresholdChecker; + private final Sort sort; + @Nullable + private FieldDoc after; + private FieldComparator firstComparator; + // bottom would be set to null per shard. 
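+    // (bottom is the "weakest" entry currently in a queue; incoming hits are compared against it to decide competitiveness.)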
+ private FieldValueHitQueue.Entry bottom; + @Getter + private int totalHits; + protected int docBase; + protected LeafFieldComparator comparators[]; + @Getter + @Setter + private TotalHits.Relation totalHitsRelation = TotalHits.Relation.EQUAL_TO; + /* + reverseMul is used to set the direction of the sorting when creating comparators. + In threshold check reverseMul is used in comparison logic. + It modifies the comparison of either reverse or maintain the natural order depending on its value. + This ensures that the compareBottom method adjusts the order based on whether you want ascending or descending sorting. + */ + protected int reverseMul; + protected FieldValueHitQueue[] compoundScores; + protected boolean queueFull[]; + @Getter + protected float maxScore = 0.0f; + protected int[] collectedHits; + + // searchSortPartOfIndexSort is used to evaluate whether to perform index sort or not. + private Boolean searchSortPartOfIndexSort = null; + + private static final TopFieldDocs EMPTY_TOP_FIELD_DOCS = new TopFieldDocs( + new TotalHits(0, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[0], + new SortField[0] + ); + + // Declaring the constructor private prevents extending this class by anyone + // else. Note that the class cannot be final since it's extended by the + // internal versions. If someone will define a constructor with any other + // visibility, then anyone will be able to extend the class, which is not what + // we want. + HybridTopFieldDocSortCollector( + final int numHits, + final HitsThresholdChecker hitsThresholdChecker, + final Sort sort, + final FieldDoc after + ) { + this.numHits = numHits; + this.hitsThresholdChecker = hitsThresholdChecker; + this.sort = sort; + this.after = after; + } + + /** + * HybridCollectorManager fetches the topDocs in the reduce method. + * @return List of TopFieldDocs which represents results of Top Docs of individual subquery. + */ + public List topDocs() { + if (compoundScores == null) { + return new ArrayList<>(); + } + + List topFieldDocs = new ArrayList<>(); + for (int subQueryNumber = 0; subQueryNumber < compoundScores.length; subQueryNumber++) { + topFieldDocs.add( + topDocsPerQuery( + 0, + Math.min(collectedHits[subQueryNumber], compoundScores[subQueryNumber].size()), + compoundScores[subQueryNumber], + collectedHits[subQueryNumber], + sort.getSort() + ) + ); + } + return topFieldDocs; + } + + @Override + public ScoreMode scoreMode() { + return hitsThresholdChecker.scoreMode(); + } + + protected abstract class HybridTopDocSortLeafCollector implements LeafCollector { + protected HybridQueryScorer compoundQueryScorer; + private boolean collectedAllCompetitiveHits = false; + + /** + 1. initializeComparators method needs to be initialized once per shard. + 2. Also, after initializing for every segment the comparators has to be refreshed. + Therefore, to do the above two things lazily we have to use a flag initializeLeafComparatorsPerSegmentOnce which is set to true when a leafCollector is initialized per segment. + Later, in the collect method when number of sub-queries has been found then initialize the comparators(1) or (2) refresh the comparators and set the flag to false. 
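+     (For example, on a shard with three segments the priority queues are created once, while the per-segment leaf
+     comparators are initialized or refreshed for each of the three segments.)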
+ */ + private boolean initializeLeafComparatorsPerSegmentOnce; + + public HybridTopDocSortLeafCollector() { + this.initializeLeafComparatorsPerSegmentOnce = true; + } + + @Override + public void setScorer(final Scorable scorer) throws IOException { + if (scorer instanceof HybridQueryScorer) { + log.debug("passed scorer is of type HybridQueryScorer, saving it for collecting documents and scores"); + compoundQueryScorer = (HybridQueryScorer) scorer; + } else { + compoundQueryScorer = getHybridQueryScorer(scorer); + if (Objects.isNull(compoundQueryScorer)) { + log.error( + String.format(Locale.ROOT, "cannot find scorer of type HybridQueryScorer in a hierarchy of scorer %s", scorer) + ); + } + } + } + + private HybridQueryScorer getHybridQueryScorer(final Scorable scorer) throws IOException { + if (scorer == null) { + return null; + } + if (scorer instanceof HybridQueryScorer) { + return (HybridQueryScorer) scorer; + } + for (Scorable.ChildScorable childScorable : scorer.getChildren()) { + HybridQueryScorer hybridQueryScorer = getHybridQueryScorer(childScorable.child); + if (Objects.nonNull(hybridQueryScorer)) { + log.debug( + String.format( + Locale.ROOT, + "found hybrid query scorer, it's child of scorer %s", + childScorable.child.getClass().getSimpleName() + ) + ); + return hybridQueryScorer; + } + } + return null; + } + + /* + Increment total hit count and validate if threshold is reached. + */ + protected void incrementTotalHitCount() throws IOException { + totalHits++; + hitsThresholdChecker.incrementHitCount(); + if (scoreMode().isExhaustive() == false + && getTotalHitsRelation() == TotalHits.Relation.EQUAL_TO + && hitsThresholdChecker.isThresholdReached()) { + // for the first time hitsThreshold is reached, notify all comparators about this + for (LeafFieldComparator comparator : comparators) { + comparator.setHitsThresholdReached(); + } + setTotalHitsRelation(TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO); + } + } + + /* + Collect hit and add the value of the sort field in the comparator. + */ + protected void collectHit(int doc, int hitsCollected, int subQueryNumber, float score) throws IOException { + // Startup transient: queue hasn't gathered numHits yet + int slot = hitsCollected - 1; + // Copy hit into queue + if (numHits > 0) { + comparators[subQueryNumber].copy(slot, doc); + add(slot, doc, compoundScores[subQueryNumber], subQueryNumber, score); + if (queueFull[subQueryNumber]) { + comparators[subQueryNumber].setBottom(bottom.slot); + } + } else { + queueFull[subQueryNumber] = true; + } + } + + /* + * This hit is competitive - replace bottom element in queue & adjustTop + */ + protected void collectCompetitiveHit(int doc, int subQueryNumber) throws IOException { + // This hit is competitive - replace bottom element in queue & adjustTop + if (numHits > 0) { + comparators[subQueryNumber].copy(bottom.slot, doc); + updateBottom(doc, compoundScores[subQueryNumber]); + comparators[subQueryNumber].setBottom(bottom.slot); + } + } + + protected boolean thresholdCheck(int doc, int subQueryNumber) throws IOException { + if (collectedAllCompetitiveHits || reverseMul * comparators[subQueryNumber].compareBottom(doc) <= 0) { + // since docs are visited in doc Id order, if compare is 0, it means + // this document is larger than anything else in the queue, and + // therefore not competitive. 
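+                // In other words, reverseMul * compareBottom(doc) <= 0 means this doc would sort at or below the current
+                // queue bottom for this sub-query, so it cannot enter the queue.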
+                if (searchSortPartOfIndexSort) {
+                    if (hitsThresholdChecker.isThresholdReached()) {
+                        setTotalHitsRelation(TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO);
+                        log.info("Terminating collection as hits threshold is reached");
+                        throw new CollectionTerminatedException();
+                    } else {
+                        collectedAllCompetitiveHits = true;
+                    }
+                }
+                return true;
+            }
+            return false;
+        }
+
+        /*
+        This method initializes the priority queues and comparators once per search request.
+        */
+        protected void initializePriorityQueuesWithComparators(LeafReaderContext context, int numberOfSubQueries) throws IOException {
+            if (compoundScores == null) {
+                compoundScores = new FieldValueHitQueue[numberOfSubQueries];
+                comparators = new LeafFieldComparator[numberOfSubQueries];
+                queueFull = new boolean[numberOfSubQueries];
+                collectedHits = new int[numberOfSubQueries];
+                for (int i = 0; i < numberOfSubQueries; i++) {
+                    initializeLeafFieldComparators(context, i);
+                }
+            }
+            if (initializeLeafComparatorsPerSegmentOnce) {
+                for (int i = 0; i < numberOfSubQueries; i++) {
+                    initializeComparators(context, i);
+                }
+                initializeLeafComparatorsPerSegmentOnce = false;
+            }
+        }
+
+        private void initializeLeafFieldComparators(LeafReaderContext context, int subQueryNumber) throws IOException {
+            compoundScores[subQueryNumber] = FieldValueHitQueue.create(sort.getSort(), numHits);
+            firstComparator = compoundScores[subQueryNumber].getComparators()[0];
+
+            // Optimize the sort
+            if (compoundScores[subQueryNumber].getComparators().length == 1) {
+                firstComparator.setSingleSort();
+            }
+
+            if (after != null) {
+                setAfterFieldValueInFieldComparator(subQueryNumber);
+            }
+        }
+
+        /* This method initializes the comparators per segment
+        */
+        private void initializeComparators(LeafReaderContext context, int subQueryNumber) throws IOException {
+            // as all segments are sorted the same way, it is enough to check only the 1st segment for indexSort
+            if (searchSortPartOfIndexSort == null) {
+                Sort indexSort = context.reader().getMetaData().getSort();
+                searchSortPartOfIndexSort = canEarlyTerminate(sort, indexSort);
+                if (searchSortPartOfIndexSort) {
+                    firstComparator.disableSkipping();
+                }
+            }
+
+            LeafFieldComparator[] leafFieldComparators = compoundScores[subQueryNumber].getComparators(context);
+            int[] reverseMuls = compoundScores[subQueryNumber].getReverseMul();
+            if (leafFieldComparators.length == 1) {
+                reverseMul = reverseMuls[0];
+                comparators[subQueryNumber] = leafFieldComparators[0];
+            } else {
+                reverseMul = 1;
+                comparators[subQueryNumber] = new MultiLeafFieldComparator(leafFieldComparators, reverseMuls);
+            }
+            comparators[subQueryNumber].setScorer(compoundQueryScorer);
+        }
+
+        private void setAfterFieldValueInFieldComparator(int subQueryNumber) {
+            FieldComparator<?>[] fieldComparators = compoundScores[subQueryNumber].getComparators();
+            for (int k = 0; k < fieldComparators.length; k++) {
+                @SuppressWarnings("unchecked")
+                FieldComparator<Object> fieldComparator = (FieldComparator<Object>) fieldComparators[k];
+                fieldComparator.setTopValue(after.fields[k]);
+            }
+        }
+    }
+
+    /*
+    TopFieldDocs per sub-query
+    */
+    private TopFieldDocs topDocsPerQuery(
+        int start,
+        int howMany,
+        PriorityQueue<FieldValueHitQueue.Entry> pq,
+        int totalHits,
+        SortField[] sortFields
+    ) {
+        if (howMany < 0) {
+            throw new IllegalArgumentException(
+                String.format(Locale.ROOT, "Number of hits requested must not be negative, but the value was %d", howMany)
+            );
+        }
+
+        if (start < 0) {
+            throw new IllegalArgumentException(
+                String.format(Locale.ROOT, "Expected value of starting position is between 0 and %d, got %d", howMany, start)
+            );
+        }
+
+        if (start >=
howMany || howMany == 0) { + return EMPTY_TOP_FIELD_DOCS; + } + + int size = howMany - start; + ScoreDoc[] results = new ScoreDoc[size]; + + // Get the requested results from pq. + populateResults(results, size, pq); + + return new TopFieldDocs(new TotalHits(totalHits, totalHitsRelation), results, sortFields); + } + + /* + Results are converted in the FieldDocs and the value of the field on which the sorting is applied has been added in the FieldDoc. + */ + private void populateResults(ScoreDoc[] results, int howMany, PriorityQueue pq) { + FieldValueHitQueue queue = (FieldValueHitQueue) pq; + for (int i = howMany - 1; i >= 0 && pq.size() > 0; i--) { + // adding to array if index is within [0..array_length - 1] + if (i < results.length) { + FieldValueHitQueue.Entry entry = queue.pop(); + final int n = queue.getComparators().length; + final Object[] fields = new Object[n]; + for (int j = 0; j < n; ++j) { + fields[j] = queue.getComparators()[j].value(entry.slot); + } + + results[i] = new FieldDoc(entry.doc, entry.score, fields); + } + } + } + + // Add the entry in the Priority queue + private void add(int slot, int doc, FieldValueHitQueue compoundScore, int subQueryNumber, float score) { + FieldValueHitQueue.Entry bottomEntry = new FieldValueHitQueue.Entry(slot, docBase + doc); + bottomEntry.score = score; + bottom = compoundScore.add(bottomEntry); + // The queue is full either when totalHits == numHits (in SimpleFieldCollector), in which case + // slot = totalHits - 1, or when hitsCollected == numHits (in PagingFieldCollector this is hits + // on the current page) and slot = hitsCollected - 1. + assert slot < numHits; + boolean isQueueFull = false; + if (slot == (numHits - 1)) { + isQueueFull = true; + } + queueFull[subQueryNumber] = isQueueFull; + } + + private void updateBottom(int doc, FieldValueHitQueue compoundScore) { + bottom.doc = docBase + doc; + bottom = compoundScore.updateTop(); + } + + private boolean canEarlyTerminate(Sort searchSort, Sort indexSort) { + return canEarlyTerminateOnDocId(searchSort) || canEarlyTerminateOnPrefix(searchSort, indexSort); + } + + private boolean canEarlyTerminateOnDocId(Sort searchSort) { + final SortField[] fields1 = searchSort.getSort(); + return SortField.FIELD_DOC.equals(fields1[0]); + } + + private boolean canEarlyTerminateOnPrefix(Sort searchSort, Sort indexSort) { + if (indexSort != null) { + final SortField[] searchSortField = searchSort.getSort(); + final SortField[] indexSortField = indexSort.getSort(); + // early termination is possible if fields1 is a prefix of fields2 + if (searchSortField.length > indexSortField.length) { + return false; + } + // Compare fields1 and the corresponding prefix of fields2 + for (int i = 0; i < searchSortField.length; i++) { + if (!searchSortField[i].equals(indexSortField[i])) { + return false; + } + } + return true; + } + return false; + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollector.java b/src/main/java/org/opensearch/neuralsearch/search/collector/HybridTopScoreDocCollector.java similarity index 97% rename from src/main/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollector.java rename to src/main/java/org/opensearch/neuralsearch/search/collector/HybridTopScoreDocCollector.java index 85fc15bf4..01a4cdfff 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollector.java +++ b/src/main/java/org/opensearch/neuralsearch/search/collector/HybridTopScoreDocCollector.java @@ -2,7 +2,7 @@ * Copyright OpenSearch Contributors * 
SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.neuralsearch.search; +package org.opensearch.neuralsearch.search.collector; import java.io.IOException; import java.util.ArrayList; @@ -12,7 +12,6 @@ import lombok.Getter; import org.apache.lucene.index.LeafReaderContext; -import org.apache.lucene.search.Collector; import org.apache.lucene.search.HitQueue; import org.apache.lucene.search.LeafCollector; import org.apache.lucene.search.Scorable; @@ -24,12 +23,13 @@ import lombok.extern.log4j.Log4j2; import org.opensearch.neuralsearch.query.HybridQueryScorer; +import org.opensearch.neuralsearch.search.HitsThresholdChecker; /** * Collects the TopDocs after executing hybrid query. Uses HybridQueryTopDocs as DTO to handle each sub query results */ @Log4j2 -public class HybridTopScoreDocCollector implements Collector { +public class HybridTopScoreDocCollector implements HybridSearchCollector { private static final TopDocs EMPTY_TOPDOCS = new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]); private int docBase; private final HitsThresholdChecker hitsThresholdChecker; diff --git a/src/main/java/org/opensearch/neuralsearch/search/collector/PagingFieldCollector.java b/src/main/java/org/opensearch/neuralsearch/search/collector/PagingFieldCollector.java new file mode 100644 index 000000000..d571685cb --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/search/collector/PagingFieldCollector.java @@ -0,0 +1,98 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.search.collector; + +import java.io.IOException; +import java.util.Objects; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.FieldDoc; +import org.apache.lucene.search.LeafCollector; +import org.apache.lucene.search.Sort; +import org.opensearch.common.Nullable; +import org.opensearch.neuralsearch.search.HitsThresholdChecker; + +/** + * PagingFieldCollector collects the sorted results at the shard level for every individual query + * as per search_after criteria applied in the search request. + * It collects the list of TopFieldDocs. + */ +public final class PagingFieldCollector extends HybridTopFieldDocSortCollector { + private final FieldDoc after; + + public PagingFieldCollector(int numHits, HitsThresholdChecker hitsThresholdChecker, Sort sort, @Nullable FieldDoc after) { + super(numHits, hitsThresholdChecker, sort, after); + this.after = after; + } + + @Override + public LeafCollector getLeafCollector(LeafReaderContext context) { + docBase = context.docBase; + final int afterDoc = after.doc - docBase; + return new HybridTopDocSortLeafCollector() { + @Override + public void collect(int doc) throws IOException { + if (Objects.isNull(compoundQueryScorer)) { + throw new IllegalArgumentException("scorers are null for all sub-queries in hybrid query"); + } + float[] subScoresByQuery = compoundQueryScorer.hybridScores(); + initializePriorityQueuesWithComparators(context, subScoresByQuery.length); + incrementTotalHitCount(); + for (int i = 0; i < subScoresByQuery.length; i++) { + float score = subScoresByQuery[i]; + // if score is 0.0 there is no hits for that sub-query + if (score == 0) { + continue; + } + + // if queueFull[i] is true then it indicates + // that we have found the results equal to the size sent in the search request. + if (queueFull[i]) { + // If threshold is reached then return. Default value of threshold is 10000. 
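+                        // (10000 is the default track_total_hits_up_to in OpenSearch; the checker's threshold is assumed
+                        // to derive from it.)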
+                        if (thresholdCheck(doc, i)) {
+                            return;
+                        }
+                    }
+
+                    // Let's understand the logic below with an example.
+                    // Consider there are 30 results without applying `search_after`,
+                    // and out of those 30, the 10 results that come after the `search_after` position are what the user is seeking.
+                    // Hits at or before that position were already collected on a previous page, so they are skipped here.
+                    // The search_after parameter retrieves the next page of hits using a set of sort values from the previous page.
+                    // https://opensearch.org/docs/latest/search-plugins/searching-data/paginate/#the-search_after-parameter
+                    boolean resultsFoundOnPreviousPage = checkIfSearchAfterResultsAreFound(i, doc);
+                    if (resultsFoundOnPreviousPage) {
+                        return;
+                    }
+                    maxScore = Math.max(score, maxScore);
+                    if (queueFull[i]) {
+                        collectCompetitiveHit(doc, i);
+                    } else {
+                        collectedHits[i]++;
+                        collectHit(doc, collectedHits[i], i, score);
+                    }
+
+                }
+
+            }
+
+            /**
+             * Compares the document against the topValue (the search_after position) in the comparator, applying
+             * reverseMul for the sort direction, to determine whether the document was already returned on a previous page.
+             * @param subQueryNumber index of the sub-query whose comparator is consulted
+             * @param doc segment-relative document id
+             * @return true if the document was already collected on a previous page
+             * @throws IOException if the comparator fails to read values for the document
+             */
+            private boolean checkIfSearchAfterResultsAreFound(int subQueryNumber, int doc) throws IOException {
+                final int topComparison = reverseMul * comparators[subQueryNumber].compareTop(doc);
+                if (topComparison > 0 || (topComparison == 0 && doc <= afterDoc)) {
+                    // Already collected on a previous page
+                    return true;
+                }
+                return false;
+            }
+        };
+    }
+}
diff --git a/src/main/java/org/opensearch/neuralsearch/search/collector/SimpleFieldCollector.java b/src/main/java/org/opensearch/neuralsearch/search/collector/SimpleFieldCollector.java
new file mode 100644
index 000000000..16cc6f0d7
--- /dev/null
+++ b/src/main/java/org/opensearch/neuralsearch/search/collector/SimpleFieldCollector.java
@@ -0,0 +1,57 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+package org.opensearch.neuralsearch.search.collector;
+
+import java.io.IOException;
+import java.util.Objects;
+import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.search.LeafCollector;
+import org.apache.lucene.search.Sort;
+import org.opensearch.neuralsearch.search.HitsThresholdChecker;
+
+/*
+ SimpleFieldCollector collects the sorted results at the shard level for every individual query.
+ It collects the list of TopFieldDocs.
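+ A hedged usage sketch with illustrative values:
+   new SimpleFieldCollector(10, new HitsThresholdChecker(10000), new Sort(new SortField("price", SortField.Type.LONG)));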
+ */
+public final class SimpleFieldCollector extends HybridTopFieldDocSortCollector {
+
+    public SimpleFieldCollector(int numHits, HitsThresholdChecker hitsThresholdChecker, Sort sort) {
+        super(numHits, hitsThresholdChecker, sort, null);
+    }
+
+    @Override
+    public LeafCollector getLeafCollector(LeafReaderContext context) {
+        docBase = context.docBase;
+
+        return new HybridTopDocSortLeafCollector() {
+            @Override
+            public void collect(int doc) throws IOException {
+                if (Objects.isNull(compoundQueryScorer)) {
+                    throw new IllegalArgumentException("scorers are null for all sub-queries in hybrid query");
+                }
+                float[] subScoresByQuery = compoundQueryScorer.hybridScores();
+                initializePriorityQueuesWithComparators(context, subScoresByQuery.length);
+                incrementTotalHitCount();
+                for (int i = 0; i < subScoresByQuery.length; i++) {
+                    float score = subScoresByQuery[i];
+                    // if score is 0.0 there are no hits for that sub-query
+                    if (score == 0) {
+                        continue;
+                    }
+                    maxScore = Math.max(score, maxScore);
+                    if (queueFull[i]) {
+                        if (thresholdCheck(doc, i)) {
+                            return;
+                        }
+                        collectCompetitiveHit(doc, i);
+                    } else {
+                        collectedHits[i]++;
+                        collectHit(doc, collectedHits[i], i, score);
+                    }
+                }
+            }
+        };
+    }
+}
diff --git a/src/main/java/org/opensearch/neuralsearch/search/lucene/MultiLeafFieldComparator.java b/src/main/java/org/opensearch/neuralsearch/search/lucene/MultiLeafFieldComparator.java
new file mode 100644
index 000000000..7480ca2e0
--- /dev/null
+++ b/src/main/java/org/opensearch/neuralsearch/search/lucene/MultiLeafFieldComparator.java
@@ -0,0 +1,125 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+package org.opensearch.neuralsearch.search.lucene;
+
+import java.io.IOException;
+import java.util.Locale;
+import org.apache.lucene.search.DocIdSetIterator;
+import org.apache.lucene.search.LeafFieldComparator;
+import org.apache.lucene.search.Scorable;
+
+/*
+MultiLeafFieldComparator is used when sort criteria are applied on more than one field.
+This class is the same as Lucene's MultiLeafFieldComparator; because the Lucene implementation does not have public access, we have to add it to the neural-search plugin.
+https://github.com/apache/lucene/blob/main/lucene/core/src/java/org/apache/lucene/search/MultiLeafFieldComparator.java
+ */
+public final class MultiLeafFieldComparator implements LeafFieldComparator {
+
+    private final LeafFieldComparator[] comparators;
+    private final int[] reverseMul;
+    // we extract the first comparator to avoid array access in the common case
+    // that the first comparator compares worse than the bottom entry in the queue
+    private final LeafFieldComparator firstComparator;
+    private final int firstReverseMul;
+
+    public MultiLeafFieldComparator(LeafFieldComparator[] comparators, int[] reverseMul) {
+        if (comparators.length != reverseMul.length) {
+            throw new IllegalArgumentException(
+                String.format(
+                    Locale.ROOT,
+                    "Must have the same number of comparators and reverseMul, got %s and %s",
+                    comparators.length,
+                    reverseMul.length
+                )
+            );
+        }
+        this.comparators = comparators;
+        this.reverseMul = reverseMul;
+        this.firstComparator = comparators[0];
+        this.firstReverseMul = reverseMul[0];
+    }
+
+    // Set the bottom slot, i.e. the "weakest" (sorted last) entry in the queue. When compareBottom is called, you should compare against this
+    // slot. This will always be called before compareBottom.
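+    // (All sub-comparators share the same slot layout, so the bottom slot is simply propagated to each of them below.)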
+    @Override
+    public void setBottom(int slot) throws IOException {
+        for (LeafFieldComparator comparator : comparators) {
+            comparator.setBottom(slot);
+        }
+    }
+
+    // Compare the bottom of the queue with this doc. This will only be invoked after setBottom has been called. This should return the same
+    // result as FieldComparator.compare(int, int) as if bottom were slot1 and the new document were slot 2.
+    // For a search that hits many results, this method will be the hotspot (invoked by far the most frequently).
+    @Override
+    public int compareBottom(int doc) throws IOException {
+        // Compare the first comparator's result with reverse multiplier
+        int comparison = firstReverseMul * firstComparator.compareBottom(doc);
+        if (comparison != 0) {
+            return comparison;
+        }
+        // Loop through remaining comparators and compare
+        for (int i = 1; i < comparators.length; ++i) {
+            comparison = reverseMul[i] * comparators[i].compareBottom(doc);
+            if (comparison != 0) {
+                return comparison;
+            }
+        }
+        return 0;
+    }
+
+    // Compare the top value with this doc. This will only be invoked after setTopValue has been called.
+    // This should return the same result as FieldComparator.compare(int, int) as if topValue were slot1 and the new document were slot 2.
+    // This is only called for searches that use searchAfter (deep paging).
+    @Override
+    public int compareTop(int doc) throws IOException {
+        // Compare the first comparator's result with reverse multiplier
+        int comparison = firstReverseMul * firstComparator.compareTop(doc);
+        if (comparison != 0) {
+            return comparison;
+        }
+        for (int i = 1; i < comparators.length; ++i) {
+            comparison = reverseMul[i] * comparators[i].compareTop(doc);
+            if (comparison != 0) {
+                return comparison;
+            }
+        }
+        return 0;
+    }
+
+    // This method is called when a new hit is competitive.
+    // You should copy any state associated with this document that will be required for future comparisons, into the specified slot.
+    @Override
+    public void copy(int slot, int doc) throws IOException {
+        for (LeafFieldComparator comparator : comparators) {
+            comparator.copy(slot, doc);
+        }
+    }
+
+    // Sets the Scorer to use in case a document's score is needed.
+    @Override
+    public void setScorer(final Scorable scorer) throws IOException {
+        for (LeafFieldComparator comparator : comparators) {
+            comparator.setScorer(scorer);
+        }
+    }
+
+    // Informs this leaf comparator that the hits threshold is reached.
+    // This method is called from a collector when the hits threshold is reached.
+    @Override
+    public void setHitsThresholdReached() throws IOException {
+        // this is needed for skipping functionality that is only relevant for the 1st comparator
+        firstComparator.setHitsThresholdReached();
+    }
+
+    // Returns a competitive iterator.
+    // Returns: an iterator over competitive docs that are stronger than already collected docs, or null if such an iterator is not available
+    // for the current comparator or segment.
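+    // (Only the first comparator may expose a competitive iterator; secondary sort keys do not narrow the candidate doc set.)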
+ @Override + public DocIdSetIterator competitiveIterator() throws IOException { + // this is needed for skipping functionality that is only relevant for the 1st comparator + return firstComparator.competitiveIterator(); + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java index 08e7bf657..4eb49e845 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java +++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java @@ -4,21 +4,29 @@ */ package org.opensearch.neuralsearch.search.query; +import java.util.Locale; import lombok.RequiredArgsConstructor; import org.apache.lucene.index.IndexReader; import org.apache.lucene.search.Collector; import org.apache.lucene.search.CollectorManager; +import org.apache.lucene.search.Weight; import org.apache.lucene.search.ScoreDoc; -import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.TotalHits; -import org.apache.lucene.search.Weight; import org.apache.lucene.search.Query; -import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.SortField; +import org.apache.lucene.search.TopFieldDocs; +import org.apache.lucene.search.FieldDoc; import org.opensearch.common.Nullable; import org.opensearch.common.lucene.search.FilteredCollector; import org.opensearch.common.lucene.search.TopDocsAndMaxScore; import org.opensearch.neuralsearch.search.HitsThresholdChecker; -import org.opensearch.neuralsearch.search.HybridTopScoreDocCollector; +import org.opensearch.neuralsearch.search.collector.HybridSearchCollector; +import org.opensearch.neuralsearch.search.collector.HybridTopFieldDocSortCollector; +import org.opensearch.neuralsearch.search.collector.HybridTopScoreDocCollector; +import org.opensearch.neuralsearch.search.collector.SimpleFieldCollector; +import org.opensearch.neuralsearch.search.collector.PagingFieldCollector; import org.opensearch.search.DocValueFormat; import org.opensearch.search.internal.ContextIndexSearcher; import org.opensearch.search.internal.SearchContext; @@ -35,9 +43,12 @@ import java.util.Objects; import static org.apache.lucene.search.TotalHits.Relation; -import static org.opensearch.neuralsearch.search.query.TopDocsMerger.TOP_DOCS_MERGER_TOP_SCORES; + import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createDelimiterElementForHybridSearchResults; import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createStartStopElementForHybridSearchResults; +import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createFieldDocStartStopElementForHybridSearchResults; +import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createFieldDocDelimiterElementForHybridSearchResults; +import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createSortFieldsForDelimiterResults; /** * Collector manager based on HybridTopScoreDocCollector that allows users to parallelize counting the number of hits. 
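 * A hedged summary of collector selection in this change: when sort criteria are present the manager is assumed to create a
 * SimpleFieldCollector, or a PagingFieldCollector when search_after is also set; otherwise it keeps using HybridTopScoreDocCollector.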
@@ -52,8 +63,10 @@ public abstract class HybridCollectorManager implements CollectorManager<Collector, ReduceableSearchResult> {
    public ReduceableSearchResult reduce(Collection<Collector> collectors) {
-        final List<HybridTopScoreDocCollector> hybridTopScoreDocCollectors = getHybridScoreDocCollectors(collectors);
-        if (hybridTopScoreDocCollectors.isEmpty()) {
-            throw new IllegalStateException("cannot collect results of hybrid search query, there are no proper score collectors");
+        final List<HybridSearchCollector> hybridSearchCollectors = getHybridSearchCollectors(collectors);
+        if (hybridSearchCollectors.isEmpty()) {
+            throw new IllegalStateException("cannot collect results of hybrid search query, there are no proper collectors");
        }
+        return reduceSearchResults(getSearchResults(hybridSearchCollectors));
+    }

+    private List<ReduceableSearchResult> getSearchResults(final List<HybridSearchCollector> hybridSearchCollectors) {
        List<ReduceableSearchResult> results = new ArrayList<>();
        DocValueFormat[] docValueFormats = getSortValueFormats(sortAndFormats);
-        for (HybridTopScoreDocCollector hybridTopScoreDocCollector : hybridTopScoreDocCollectors) {
-            List<TopDocs> topDocs = hybridTopScoreDocCollector.topDocs();
-            TopDocs newTopDocs = getNewTopDocs(
-                getTotalHits(this.trackTotalHitsUpTo, topDocs, hybridTopScoreDocCollector.getTotalHits()),
-                topDocs
-            );
-            TopDocsAndMaxScore topDocsAndMaxScore = new TopDocsAndMaxScore(newTopDocs, hybridTopScoreDocCollector.getMaxScore());
+        for (HybridSearchCollector collector : hybridSearchCollectors) {
+            TopDocsAndMaxScore topDocsAndMaxScore = getTopDocsAndMaxScore(collector, docValueFormats);
+            results.add((QuerySearchResult result) -> reduceCollectorResults(result, topDocsAndMaxScore, docValueFormats));
+        }
+        return results;
+    }

-            results.add((QuerySearchResult result) -> reduceCollectorResults(result, topDocsAndMaxScore, docValueFormats, newTopDocs));
+    private TopDocsAndMaxScore getTopDocsAndMaxScore(
+        final HybridSearchCollector hybridSearchCollector,
+        final DocValueFormat[] docValueFormats
+    ) {
+        TopDocs newTopDocs;
+        List<TopDocs> topDocs = hybridSearchCollector.topDocs();
+        if (docValueFormats != null) {
+            newTopDocs = getNewTopFieldDocs(
+                getTotalHits(this.trackTotalHitsUpTo, topDocs, hybridSearchCollector.getTotalHits()),
+                topDocs,
+                sortAndFormats.sort.getSort()
+            );
+        } else {
+            newTopDocs = getNewTopDocs(getTotalHits(this.trackTotalHitsUpTo, topDocs, hybridSearchCollector.getTotalHits()), topDocs);
        }
-        return reduceSearchResults(results);
+        return new TopDocsAndMaxScore(newTopDocs, hybridSearchCollector.getMaxScore());
    }

-    private List<HybridTopScoreDocCollector> getHybridScoreDocCollectors(Collection<Collector> collectors) {
-        final List<HybridTopScoreDocCollector> hybridTopScoreDocCollectors = new ArrayList<>();
-        // check if collector for hybrid query scores is part of this search context. It can be wrapped into MultiCollectorWrapper
-        // in case multiple collector managers are registered. We use hybrid scores collector to format scores into
-        // format specific for hybrid search query: start, sub-query-delimiter, scores, stop
+    private List<HybridSearchCollector> getHybridSearchCollectors(final Collection<Collector> collectors) {
+        final List<HybridSearchCollector> hybridSearchCollectors = new ArrayList<>();
        for (final Collector collector : collectors) {
            if (collector instanceof MultiCollectorWrapper) {
                for (final Collector sub : (((MultiCollectorWrapper) collector).getCollectors())) {
-                    if (sub instanceof HybridTopScoreDocCollector) {
-                        hybridTopScoreDocCollectors.add((HybridTopScoreDocCollector) sub);
+                    if (sub instanceof HybridTopScoreDocCollector || sub instanceof HybridTopFieldDocSortCollector) {
+                        hybridSearchCollectors.add((HybridSearchCollector) sub);
                    }
                }
-            } else if (collector instanceof HybridTopScoreDocCollector) {
-                hybridTopScoreDocCollectors.add((HybridTopScoreDocCollector) collector);
+            } else if (collector instanceof HybridTopScoreDocCollector || collector instanceof HybridTopFieldDocSortCollector) {
+                hybridSearchCollectors.add((HybridSearchCollector) collector);
            } else if (collector instanceof FilteredCollector
-                && ((FilteredCollector) collector).getCollector() instanceof HybridTopScoreDocCollector) {
-                    hybridTopScoreDocCollectors.add((HybridTopScoreDocCollector) ((FilteredCollector) collector).getCollector());
-            }
+                && (((FilteredCollector) collector).getCollector() instanceof HybridTopScoreDocCollector
+                    || ((FilteredCollector) collector).getCollector() instanceof HybridTopFieldDocSortCollector)) {
+                        hybridSearchCollectors.add((HybridSearchCollector) ((FilteredCollector) collector).getCollector());
+            }
+        }
+        return hybridSearchCollectors;
+    }
+
+    private static void validateSortCriteria(SearchContext searchContext, boolean trackScores) {
+        SortField[] sortFields = searchContext.sort().sort.getSort();
+        boolean hasFieldSort = false;
+        boolean hasScoreSort = false;
+        for (SortField sortField : sortFields) {
+            SortField.Type type = sortField.getType();
+            if (type.equals(SortField.Type.SCORE)) {
+                hasScoreSort = true;
+            } else {
+                hasFieldSort = true;
+            }
+            if (hasScoreSort && hasFieldSort) {
+                break;
+            }
+        }
+        if (hasScoreSort && hasFieldSort) {
+            throw new IllegalArgumentException(
+                "_score sort cannot be combined with other sort criteria. Please select a single sort criterion."
+            );
+        }
+        if (trackScores && hasFieldSort) {
+            throw new IllegalArgumentException(
+                "When hybrid search results are sorted by a field, docId or _id, track_scores must be set to false."
+ ); + } + if (trackScores && hasScoreSort) { + throw new IllegalArgumentException("Hybrid search results are by default sorted by _score, track_scores must be set to false."); + } + } + + private void validateSearchAfterFieldAndSortFormats() { + if (after.fields == null) { + throw new IllegalArgumentException("after.fields wasn't set; you must pass fillFields=true for the previous search"); + } + + if (after.fields.length != sortAndFormats.sort.getSort().length) { + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "after.fields has %s values but sort has %s", + after.fields.length, + sortAndFormats.sort.getSort().length + ) + ); } - } } private TopDocs getNewTopDocs(final TotalHits totalHits, final List topDocs) { @@ -202,10 +295,11 @@ private TopDocs getNewTopDocs(final TotalHits totalHits, final List top return new TopDocs(totalHits, scoreDocs); } - private TotalHits getTotalHits(int trackTotalHitsUpTo, final List topDocs, final long maxTotalHits) { + private TotalHits getTotalHits(int trackTotalHitsUpTo, final List topDocs, final long maxTotalHits) { final Relation relation = trackTotalHitsUpTo == SearchContext.TRACK_TOTAL_HITS_DISABLED ? Relation.GREATER_THAN_OR_EQUAL_TO : Relation.EQUAL_TO; + if (topDocs == null || topDocs.isEmpty()) { return new TotalHits(0, relation); } @@ -213,15 +307,69 @@ private TotalHits getTotalHits(int trackTotalHitsUpTo, final List topDo return new TotalHits(maxTotalHits, relation); } + private TopDocs getNewTopFieldDocs(final TotalHits totalHits, final List topFieldDocs, final SortField[] sortFields) { + if (Objects.isNull(topFieldDocs)) { + return new TopFieldDocs(totalHits, new FieldDoc[0], sortFields); + } + + // for the single shard case we need to do score processing at coordinator level. + // this is a workaround for current core behaviour: for a single shard the fetch phase is executed + // right after the query phase, and processors are called after the actual fetch is done + // find any valid doc Id, or set it to -1 if there is not a single match + int delimiterDocId = topFieldDocs.stream() + .filter(Objects::nonNull) + .filter(topDoc -> Objects.nonNull(topDoc.scoreDocs)) + .map(topFieldDoc -> topFieldDoc.scoreDocs) + .filter(scoreDoc -> scoreDoc.length > 0) + .map(scoreDoc -> scoreDoc[0].doc) + .findFirst() + .orElse(-1); + if (delimiterDocId == -1) { + return new TopFieldDocs(totalHits, new FieldDoc[0], sortFields); + } + + // format scores using the following template: + // assume the sort is applied on two fields, where field1 is of integer type and field2 is of float type. + // doc_id | magic_number_1 | [1,1.0f] + // doc_id | magic_number_2 | [1,1.0f] + // ... + // doc_id | magic_number_2 | [1,1.0f] + // ... + // doc_id | magic_number_2 | [1,1.0f] + // ... + // doc_id | magic_number_1 | [1,1.0f]
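// Editorial illustration with hypothetical values: two sub-queries sorted on an integer field
// descending; the first sub-query matched docs 3 and 7 (field values 25 and 22), the second
// matched doc 5 (field value 20). The array assembled below would then be:
//   FieldDoc(delimiterDocId, MAGIC_NUMBER_START_STOP, [1])
//   FieldDoc(delimiterDocId, MAGIC_NUMBER_DELIMITER, [1])
//   FieldDoc(doc=3, [25]), FieldDoc(doc=7, [22])
//   FieldDoc(delimiterDocId, MAGIC_NUMBER_DELIMITER, [1])
//   FieldDoc(doc=5, [20])
//   FieldDoc(delimiterDocId, MAGIC_NUMBER_START_STOP, [1])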
+ final Object[] sortFieldsForDelimiterResults = createSortFieldsForDelimiterResults(sortFields); + List result = new ArrayList<>(); + result.add(createFieldDocStartStopElementForHybridSearchResults(delimiterDocId, sortFieldsForDelimiterResults)); + for (TopFieldDocs topFieldDoc : topFieldDocs) { + if (Objects.isNull(topFieldDoc) || Objects.isNull(topFieldDoc.scoreDocs)) { + result.add(createFieldDocDelimiterElementForHybridSearchResults(delimiterDocId, sortFieldsForDelimiterResults)); + continue; + } + + List fieldDocsPerQuery = new ArrayList<>(); + for (ScoreDoc scoreDoc : topFieldDoc.scoreDocs) { + fieldDocsPerQuery.add((FieldDoc) scoreDoc); + } + result.add(createFieldDocDelimiterElementForHybridSearchResults(delimiterDocId, sortFieldsForDelimiterResults)); + result.addAll(fieldDocsPerQuery); + } + result.add(createFieldDocStartStopElementForHybridSearchResults(delimiterDocId, sortFieldsForDelimiterResults)); + + FieldDoc[] fieldDocs = result.toArray(new FieldDoc[0]); + + return new TopFieldDocs(totalHits, fieldDocs, sortFields); + } + private DocValueFormat[] getSortValueFormats(final SortAndFormats sortAndFormats) { + return sortAndFormats == null ? null : sortAndFormats.formats; + } private void reduceCollectorResults( - QuerySearchResult result, - TopDocsAndMaxScore topDocsAndMaxScore, - DocValueFormat[] docValueFormats, - TopDocs newTopDocs + final QuerySearchResult result, + final TopDocsAndMaxScore topDocsAndMaxScore, + final DocValueFormat[] docValueFormats ) { // this is the case of the first collector: the query result object doesn't have any top docs set, so we can // just set new top docs without merging @@ -233,7 +381,7 @@ private void reduceCollectorResults( } // in this case top docs are already present in result, and we need to merge the next result object with what we have.
// if collector doesn't have any hits we can just skip it and save some cycles by not doing merge - if (newTopDocs.totalHits.value == 0) { + if (topDocsAndMaxScore.topDocs.totalHits.value == 0) { return; } // we need to do actual merge because query result and current collector both have some score hits @@ -247,7 +395,7 @@ private void reduceCollectorResults( * @param results collection of search results * @return single search result that represents all results as one object */ - private ReduceableSearchResult reduceSearchResults(List results) { + private ReduceableSearchResult reduceSearchResults(final List results) { return (result) -> { for (ReduceableSearchResult r : results) { // call reduce for results of each single collector, this will update top docs in query result @@ -268,9 +416,18 @@ public HybridCollectorNonConcurrentManager( HitsThresholdChecker hitsThresholdChecker, int trackTotalHitsUpTo, SortAndFormats sortAndFormats, - Weight filteringWeight + Weight filteringWeight, + ScoreDoc searchAfter ) { - super(numHits, hitsThresholdChecker, trackTotalHitsUpTo, sortAndFormats, filteringWeight, TOP_DOCS_MERGER_TOP_SCORES); + super( + numHits, + hitsThresholdChecker, + trackTotalHitsUpTo, + sortAndFormats, + filteringWeight, + new TopDocsMerger(sortAndFormats), + (FieldDoc) searchAfter + ); scoreCollector = Objects.requireNonNull(super.newCollector(), "collector for hybrid query cannot be null"); } @@ -297,9 +454,18 @@ public HybridCollectorConcurrentSearchManager( HitsThresholdChecker hitsThresholdChecker, int trackTotalHitsUpTo, SortAndFormats sortAndFormats, - Weight filteringWeight + Weight filteringWeight, + ScoreDoc searchAfter ) { - super(numHits, hitsThresholdChecker, trackTotalHitsUpTo, sortAndFormats, filteringWeight, TOP_DOCS_MERGER_TOP_SCORES); + super( + numHits, + hitsThresholdChecker, + trackTotalHitsUpTo, + sortAndFormats, + filteringWeight, + new TopDocsMerger(sortAndFormats), + (FieldDoc) searchAfter + ); } } } diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryFieldDocComparator.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryFieldDocComparator.java new file mode 100644 index 000000000..d09750dfb --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryFieldDocComparator.java @@ -0,0 +1,57 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.search.query; + +import java.util.Comparator; +import lombok.AccessLevel; +import lombok.RequiredArgsConstructor; +import org.apache.lucene.search.FieldComparator; +import org.apache.lucene.search.FieldDoc; +import org.apache.lucene.search.Pruning; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.SortField; + +/** + * Comparator class that compares two field docs as per the sorting criteria + */ +@RequiredArgsConstructor(access = AccessLevel.PACKAGE) +class HybridQueryFieldDocComparator implements Comparator { + final SortField[] sortFields; + final FieldComparator[] comparators; + final int[] reverseMul; + final Comparator tieBreaker; + + public HybridQueryFieldDocComparator(SortField[] sortFields, Comparator tieBreaker) { + this.sortFields = sortFields; + this.tieBreaker = tieBreaker; + comparators = new FieldComparator[sortFields.length]; + reverseMul = new int[sortFields.length]; + for (int compIDX = 0; compIDX < sortFields.length; compIDX++) { + final SortField sortField = sortFields[compIDX]; + comparators[compIDX] = 
sortField.getComparator(1, Pruning.NONE); + reverseMul[compIDX] = sortField.getReverse() ? -1 : 1; + } + } + + @Override + public int compare(final FieldDoc firstFD, final FieldDoc secondFD) { + for (int compIDX = 0; compIDX < comparators.length; compIDX++) { + final FieldComparator comp = comparators[compIDX]; + + final int cmp = reverseMul[compIDX] * comp.compareValues(firstFD.fields[compIDX], secondFD.fields[compIDX]); + + if (cmp != 0) { + return cmp; + } + } + return tieBreakCompare(firstFD, secondFD, tieBreaker); + } + + private int tieBreakCompare(ScoreDoc firstDoc, ScoreDoc secondDoc, Comparator tieBreaker) { + assert tieBreaker != null; + return tieBreaker.compare(firstDoc, secondDoc); + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryScoreDocsMerger.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryScoreDocsMerger.java index 7eb6e2b55..1895d1d79 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryScoreDocsMerger.java +++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryScoreDocsMerger.java @@ -6,6 +6,7 @@ import lombok.AccessLevel; import lombok.NoArgsConstructor; +import org.apache.lucene.search.FieldDoc; import org.apache.lucene.search.ScoreDoc; import java.util.ArrayList; @@ -31,9 +32,11 @@ class HybridQueryScoreDocsMerger { * Method returns new object and doesn't mutate original ScoreDocs arrays. * @param sourceScoreDocs original score docs from query result * @param newScoreDocs new score docs that we need to merge into existing scores + * @param comparator comparator to compare the score docs + * @param isSortEnabled flag that shows whether sort is enabled or disabled * @return merged array of ScoreDocs objects */ - public T[] merge(final T[] sourceScoreDocs, final T[] newScoreDocs, final Comparator comparator) { + public T[] merge(final T[] sourceScoreDocs, final T[] newScoreDocs, final Comparator comparator, final boolean isSortEnabled) { if (Objects.requireNonNull(sourceScoreDocs, "score docs cannot be null").length < MIN_NUMBER_OF_ELEMENTS_IN_SCORE_DOC || Objects.requireNonNull(newScoreDocs, "score docs cannot be null").length < MIN_NUMBER_OF_ELEMENTS_IN_SCORE_DOC) { throw new IllegalArgumentException("cannot merge top docs because it does not have enough elements"); @@ -58,7 +61,7 @@ public T[] merge(final T[] sourceScoreDocs, final T[] newScoreDocs, final Compar && isHybridQueryScoreDocElement(sourceScoreDocs[sourcePointer]) && newPointer < newScoreDocs.length && isHybridQueryScoreDocElement(newScoreDocs[newPointer])) { - if (comparator.compare(sourceScoreDocs[sourcePointer], newScoreDocs[newPointer]) >= 0) { + if (compareCondition(sourceScoreDocs[sourcePointer], newScoreDocs[newPointer], comparator, isSortEnabled)) { mergedScoreDocs.add(sourceScoreDocs[sourcePointer]); sourcePointer++; } else { @@ -78,6 +81,23 @@ && isHybridQueryScoreDocElement(newScoreDocs[newPointer])) { } // mark end of hybrid query results by end element mergedScoreDocs.add(sourceScoreDocs[sourceScoreDocs.length - 1]); + if (isSortEnabled) { + return mergedScoreDocs.toArray((T[]) new FieldDoc[0]); + } return mergedScoreDocs.toArray((T[]) new ScoreDoc[0]); } + + private boolean compareCondition( + final ScoreDoc oldScoreDoc, + final ScoreDoc secondScoreDoc, + final Comparator comparator, + final boolean isSortEnabled + ) { + // If sorting is enabled, the compare condition differs from the one used for a normal HybridQuery + if (isSortEnabled) { + return comparator.compare((T) oldScoreDoc, (T) secondScoreDoc) < 0; + } else { + return comparator.compare((T) oldScoreDoc, (T) secondScoreDoc) >= 0; + } + } }
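Editorial note: to make the direction of compareCondition concrete, here is a minimal, self-contained sketch (not part of the PR; all names are illustrative, and the magic-number delimiter handling of the real merge(...) is ignored) of merging two sorted runs with the same pick rule. With sorting enabled, the comparator orders by sort fields, so a source element is taken only while it strictly precedes the new one (compare < 0); without sorting, the comparator orders by score and the source element also wins ties (compare >= 0):

import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

final class MergeSketch {
    // Two-pointer merge of two already-sorted runs, mirroring compareCondition.
    static <T> List<T> mergeRuns(List<T> source, List<T> added, Comparator<T> comparator, boolean isSortEnabled) {
        List<T> merged = new ArrayList<>();
        int i = 0, j = 0;
        while (i < source.size() && j < added.size()) {
            int cmp = comparator.compare(source.get(i), added.get(j));
            boolean takeSource = isSortEnabled ? cmp < 0 : cmp >= 0;
            merged.add(takeSource ? source.get(i++) : added.get(j++));
        }
        while (i < source.size()) merged.add(source.get(i++)); // drain the remaining source run
        while (j < added.size()) merged.add(added.get(j++));   // drain the remaining new run
        return merged;
    }
}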
diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/TopDocsMerger.java b/src/main/java/org/opensearch/neuralsearch/search/query/TopDocsMerger.java index 0e6adfb1a..a77ff458e 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/query/TopDocsMerger.java +++ b/src/main/java/org/opensearch/neuralsearch/search/query/TopDocsMerger.java @@ -7,27 +7,46 @@ import com.google.common.annotations.VisibleForTesting; import lombok.AccessLevel; import lombok.RequiredArgsConstructor; +import org.apache.lucene.search.FieldDoc; import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TopFieldDocs; import org.apache.lucene.search.TotalHits; import org.opensearch.common.lucene.search.TopDocsAndMaxScore; import java.util.Comparator; import java.util.Objects; +import org.opensearch.search.sort.SortAndFormats; /** * Utility class for merging TopDocs and MaxScore across multiple search queries */ @RequiredArgsConstructor(access = AccessLevel.PACKAGE) class TopDocsMerger { - - private final HybridQueryScoreDocsMerger scoreDocsMerger; + private HybridQueryScoreDocsMerger docsMerger; + private SortAndFormats sortAndFormats; + @VisibleForTesting + protected Comparator SCORE_DOC_BY_SCORE_COMPARATOR; @VisibleForTesting - protected static final Comparator SCORE_DOC_BY_SCORE_COMPARATOR = Comparator.comparing((scoreDoc) -> scoreDoc.score); + protected HybridQueryFieldDocComparator FIELD_DOC_BY_SORT_CRITERIA_COMPARATOR; + private final Comparator MERGING_TIE_BREAKER = (o1, o2) -> Integer.compare(o1.doc, o2.doc); + /** * Uses hybrid query score docs merger to merge internal score docs; the merger and comparator are chosen based on whether sorting is enabled */ - static final TopDocsMerger TOP_DOCS_MERGER_TOP_SCORES = new TopDocsMerger(new HybridQueryScoreDocsMerger<>()); + TopDocsMerger(final SortAndFormats sortAndFormats) { + this.sortAndFormats = sortAndFormats; + if (isSortingEnabled()) { + docsMerger = new HybridQueryScoreDocsMerger(); + FIELD_DOC_BY_SORT_CRITERIA_COMPARATOR = new HybridQueryFieldDocComparator(sortAndFormats.sort.getSort(), MERGING_TIE_BREAKER); + } else { + docsMerger = new HybridQueryScoreDocsMerger<>(); + SCORE_DOC_BY_SCORE_COMPARATOR = Comparator.comparing((scoreDoc) -> scoreDoc.score); + } + } /** * Merge TopDocs and MaxScore from multiple search queries into a single TopDocsAndMaxScore object. @@ -35,34 +54,19 @@ class TopDocsMerger { * @param newTopDocs TopDocsAndMaxScore for the new query * @return merged TopDocsAndMaxScore object */ - public TopDocsAndMaxScore merge(TopDocsAndMaxScore source, TopDocsAndMaxScore newTopDocs) { + public TopDocsAndMaxScore merge(final TopDocsAndMaxScore source, final TopDocsAndMaxScore newTopDocs) { if (Objects.isNull(newTopDocs) || Objects.isNull(newTopDocs.topDocs) || newTopDocs.topDocs.totalHits.value == 0) { return source; } - // we need to merge hits per individual sub-query - // format of results in both new and source TopDocs is following - // doc_id | magic_number_1 - // doc_id | magic_number_2 - // ... - // doc_id | magic_number_2 - // ... - // doc_id | magic_number_2 - // ...
- // doc_id | magic_number_1 - ScoreDoc[] mergedScoreDocs = scoreDocsMerger.merge( - source.topDocs.scoreDocs, - newTopDocs.topDocs.scoreDocs, - SCORE_DOC_BY_SCORE_COMPARATOR - ); TotalHits mergedTotalHits = getMergedTotalHits(source, newTopDocs); TopDocsAndMaxScore result = new TopDocsAndMaxScore( - new TopDocs(mergedTotalHits, mergedScoreDocs), + getTopDocs(getMergedScoreDocs(source.topDocs.scoreDocs, newTopDocs.topDocs.scoreDocs), mergedTotalHits), Math.max(source.maxScore, newTopDocs.maxScore) ); return result; } - private TotalHits getMergedTotalHits(TopDocsAndMaxScore source, TopDocsAndMaxScore newTopDocs) { + private TotalHits getMergedTotalHits(final TopDocsAndMaxScore source, final TopDocsAndMaxScore newTopDocs) { // merged value is a lower bound - if both are equal_to then merged will also be equal_to, // otherwise assign greater_than_or_equal TotalHits.Relation mergedHitsRelation = source.topDocs.totalHits.relation == TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO @@ -71,4 +75,46 @@ private TotalHits getMergedTotalHits(TopDocsAndMaxSco : TotalHits.Relation.EQUAL_TO; return new TotalHits(source.topDocs.totalHits.value + newTopDocs.topDocs.totalHits.value, mergedHitsRelation); } + + private TopDocs getTopDocs(ScoreDoc[] mergedScoreDocs, TotalHits mergedTotalHits) { + if (isSortingEnabled()) { + return new TopFieldDocs(mergedTotalHits, mergedScoreDocs, sortAndFormats.sort.getSort()); + } + return new TopDocs(mergedTotalHits, mergedScoreDocs); + } + + private ScoreDoc[] getMergedScoreDocs(ScoreDoc[] source, ScoreDoc[] newScoreDocs) { + // Case 1: when sorting is enabled, hits are merged per individual sub-query and the format of + // results in both the new and source TopDocs is the following + // doc_id | magic_number_1 | [1] + // doc_id | magic_number_2 | [1] + // ... + // doc_id | magic_number_2 | [1] + // ... + // doc_id | magic_number_2 | [1] + // ... + // doc_id | magic_number_1 | [1] + + // Case 2: when sorting is disabled, hits are merged per individual sub-query and the format of + // results in both the new and source TopDocs is the following + // doc_id | magic_number_1 + // doc_id | magic_number_2 + // ... + // doc_id | magic_number_2 + // ... + // doc_id | magic_number_2 + // ... + // doc_id | magic_number_1 + return docsMerger.merge(source, newScoreDocs, comparator(), isSortingEnabled()); + } +
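// Editorial worked example (hypothetical values): sorting on an integer field descending,
// the source run holds FieldDocs with sort values [25] and [20] while the new run holds [22];
// docsMerger.merge(...) above interleaves them into [25], [22], [20], and the surrounding
// magic-number start/stop and delimiter elements stay in place.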
+ private Comparator comparator() { + return sortAndFormats != null ? FIELD_DOC_BY_SORT_CRITERIA_COMPARATOR : SCORE_DOC_BY_SCORE_COMPARATOR; + } + + private boolean isSortingEnabled() { + return sortAndFormats != null; + } } diff --git a/src/main/java/org/opensearch/neuralsearch/search/util/HybridSearchResultFormatUtil.java b/src/main/java/org/opensearch/neuralsearch/search/util/HybridSearchResultFormatUtil.java index 8fc71056a..fa196a533 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/util/HybridSearchResultFormatUtil.java +++ b/src/main/java/org/opensearch/neuralsearch/search/util/HybridSearchResultFormatUtil.java @@ -7,6 +7,10 @@ import java.util.Objects; import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.FieldDoc; +import org.apache.lucene.search.SortField; +import org.apache.lucene.search.SortedNumericSortField; +import org.apache.lucene.util.BytesRef; /** * Utility class for handling format of Hybrid Search query results @@ -53,6 +57,19 @@ public static boolean isHybridQueryDelimiterElement(final ScoreDoc scoreDoc) { return Objects.nonNull(scoreDoc) && scoreDoc.doc >= 0 && Float.compare(scoreDoc.score, MAGIC_NUMBER_DELIMITER) == 0; } + public static FieldDoc createFieldDocStartStopElementForHybridSearchResults(final int docId, final Object[] fields) { + return new FieldDoc(docId, MAGIC_NUMBER_START_STOP, fields); + } + + /** + * Create a FieldDoc object that is a delimiter element between sub-query results in hybrid search query results + * @param docId id of one of docs from actual result object, or -1 if there are no matches + * @param fields sort field values to set on the delimiter element + * @return delimiter FieldDoc element + */ + public static FieldDoc createFieldDocDelimiterElementForHybridSearchResults(final int docId, final Object[] fields) { + return new FieldDoc(docId, MAGIC_NUMBER_DELIMITER, fields); + } + /** * Checking if passed scoreDocs object is a special element (start/stop or delimiter) in the list of hybrid query result scores * @param scoreDoc score doc object to check on @@ -76,4 +93,49 @@ public static boolean isHybridQueryScoreDocElement(final ScoreDoc scoreDoc) { } return !isHybridQuerySpecialElement(scoreDoc); } + + /** + * This method creates dummy sort fields for the field docs that carry magic number scores and act as delimiters. + * Each dummy value must be of the same type as the field on which the sorting criteria is applied. + * @param fields contains the information about the object type of the field on which sorting criteria is applied + * @return dummy sort field values for a delimiter element + */ + public static Object[] createSortFieldsForDelimiterResults(final Object[] fields) { + final Object[] sortFields = new Object[fields.length]; + for (int i = 0; i < fields.length; i++) { + SortField sortField = (SortField) fields[i]; + SortField.Type type = sortField.getType(); + if (sortField instanceof SortedNumericSortField) { + type = ((SortedNumericSortField) sortField).getNumericType(); + } + // Example: Let's consider there are 2 sort fields on which the sort criteria has to be applied, + // where the first field is of integer type and the second of float type; the dummy sort + // fields created below for a delimiter element will then be [1, 1.0f].
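// Editorial illustration: for two sort fields (a SortedNumericSortField with numeric type INT
// and a SortField of type FLOAT) the switch below produces [1, 1.0f]; any unsupported type,
// e.g. a STRING sort, falls through to the default branch and gets an empty BytesRef.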
+ Object sortFieldForDelimiter; + switch (type) { + case DOC: + case INT: + sortFieldForDelimiter = 1; + break; + case LONG: + sortFieldForDelimiter = 1L; + break; + case SCORE: + case FLOAT: + sortFieldForDelimiter = 1.0f; + break; + case DOUBLE: + sortFieldForDelimiter = 1.0; + break; + default: + sortFieldForDelimiter = new BytesRef(); + } + + sortFields[i] = sortFieldForDelimiter; + } + return sortFields; + } } diff --git a/src/main/java/org/opensearch/neuralsearch/search/util/HybridSearchSortUtil.java b/src/main/java/org/opensearch/neuralsearch/search/util/HybridSearchSortUtil.java new file mode 100644 index 000000000..fb7ca53a6 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/search/util/HybridSearchSortUtil.java @@ -0,0 +1,116 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.search.util; + +import java.util.List; +import org.apache.lucene.search.Sort; +import org.apache.lucene.search.SortField; +import org.apache.lucene.search.SortedNumericSortField; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TopFieldDocs; +import org.opensearch.neuralsearch.processor.CompoundTopDocs; +import org.opensearch.search.query.QuerySearchResult; +import org.opensearch.search.sort.SortedWiderNumericSortField; + +/** + * Utility class for evaluating and creating sort criteria + */ +public class HybridSearchSortUtil { + + /** + * @param querySearchResults list of query search results where each search result represents a result from the shard. + * @param queryTopDocs list of top docs which have results with top scores. + * @return sort criteria + */ + public static Sort evaluateSortCriteria(final List querySearchResults, final List queryTopDocs) { + if (!checkIfSortEnabled(querySearchResults)) { + return null; + } + return createSort(getTopFieldDocs(queryTopDocs)); + } + + // Check if sort is enabled by checking docValueFormats Object + private static boolean checkIfSortEnabled(final List querySearchResults) { + if (querySearchResults == null || querySearchResults.isEmpty() || querySearchResults.get(0) == null) { + throw new IllegalArgumentException("shard results cannot be null in the normalization process."); + } + return querySearchResults.get(0).sortValueFormats() != null; + } + + // Get the topFieldDocs array from the first shard result + private static TopFieldDocs[] getTopFieldDocs(final List queryTopDocs) { + // loop over queryTopDocs and return the first set of topFieldDocs found; + // the topDocs can be empty if no result is found on a shard, therefore we need to iterate over all the shards + for (CompoundTopDocs compoundTopDocs : queryTopDocs) { + if (compoundTopDocs == null) { + throw new IllegalArgumentException("CompoundTopDocs cannot be null in the normalization process"); + } + if (containsTopFieldDocs(compoundTopDocs.getTopDocs())) { + return compoundTopDocs.getTopDocs().toArray(new TopFieldDocs[0]); + } + } + return new TopFieldDocs[0]; + } + + private static boolean containsTopFieldDocs(List topDocs) { + // topDocs can be empty if no results were found on the shard + if (topDocs == null || topDocs.isEmpty()) { + return false; + } + for (TopDocs topDoc : topDocs) { + if (topDoc instanceof TopFieldDocs) { + return true; + } + } + return false; + } +
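// Editorial illustration: if three shards were queried and only shard 2 returned any hits,
// getTopFieldDocs(...) above skips the empty CompoundTopDocs of shards 1 and 3, and the Sort
// created below is built from shard 2's TopFieldDocs fields.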
+ /** + * Creates Sort object from topFieldsDocs fields. + * It is necessary to widen the SortField.Type to the maximum byte size for merging sorted docs. + * Different indices might have different types. This avoids forcing the user to re-index data + * when the field mapping changes for newly indexed data. + * Widening is supported from Int to Long and from Float to Double. + * Earlier, type widening was handled in IndexNumericFieldData, but since we now want to + * support sort optimization, type widening was removed there and is taken care of here during merging. + * More details here https://github.com/opensearch-project/OpenSearch/issues/6326 + */ + private static Sort createSort(TopFieldDocs[] topFieldDocs) { + final SortField[] firstTopDocFields = topFieldDocs[0].fields; + final SortField[] newFields = new SortField[firstTopDocFields.length]; + + for (int i = 0; i < firstTopDocFields.length; i++) { + final SortField delegate = firstTopDocFields[i]; + final SortField.Type sortFieldType = delegate instanceof SortedNumericSortField + ? ((SortedNumericSortField) delegate).getNumericType() + : delegate.getType(); + + if (SortedWiderNumericSortField.isTypeSupported(sortFieldType) && isSortWideningRequired(topFieldDocs, i)) { + newFields[i] = new SortedWiderNumericSortField(delegate.getField(), sortFieldType, delegate.getReverse()); + } else { + newFields[i] = firstTopDocFields[i]; + } + } + return new Sort(newFields); + } + + /** + * Compares the respective SortField between shards to see if any shard results have a different + * field mapping type; based on that it decides whether the sort fields need to be widened. + */ + private static boolean isSortWideningRequired(TopFieldDocs[] topFieldDocs, int sortFieldIndex) { + for (int i = 0; i < topFieldDocs.length - 1; i++) { + TopFieldDocs currentTopFieldDoc = topFieldDocs[i]; + TopFieldDocs nextTopFieldDoc = topFieldDocs[i + 1]; + if (currentTopFieldDoc == null || nextTopFieldDoc == null) { + throw new IllegalArgumentException("topFieldDocs cannot be null when sorting is applied"); + } + if (!currentTopFieldDoc.fields[sortFieldIndex].equals(nextTopFieldDoc.fields[sortFieldIndex])) { + return true; + } + } + return false; + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/CompoundTopDocsTests.java b/src/test/java/org/opensearch/neuralsearch/processor/CompoundTopDocsTests.java index 0096f7f94..3b2f64063 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/CompoundTopDocsTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/CompoundTopDocsTests.java @@ -28,7 +28,7 @@ public void testBasics_whenCreateWithTopDocsArray_thenSuccessful() { new ScoreDoc(5, RandomUtils.nextFloat()) } ); List topDocs = List.of(topDocs1, topDocs2); - CompoundTopDocs compoundTopDocs = new CompoundTopDocs(new TotalHits(3, TotalHits.Relation.EQUAL_TO), topDocs); + CompoundTopDocs compoundTopDocs = new CompoundTopDocs(new TotalHits(3, TotalHits.Relation.EQUAL_TO), topDocs, false); assertNotNull(compoundTopDocs); assertEquals(topDocs, compoundTopDocs.getTopDocs()); } @@ -44,7 +44,8 @@ public void testBasics_whenCreateWithoutTopDocs_thenTopDocsIsNull() { new ScoreDoc(4, RandomUtils.nextFloat()), new ScoreDoc(5, RandomUtils.nextFloat()) } ) - ) + ), + false ); assertNotNull(hybridQueryScoreTopDocs); assertNotNull(hybridQueryScoreTopDocs.getScoreDocs()); @@ -58,20 +59,21 @@ public void testBasics_whenMultipleTopDocsOfDifferentLength_thenReturnTopDocsWit new ScoreDoc[] { new ScoreDoc(2, RandomUtils.nextFloat()), new ScoreDoc(4, RandomUtils.nextFloat()) } ); List topDocs = List.of(topDocs1, topDocs2); - CompoundTopDocs
compoundTopDocs = new CompoundTopDocs(new TotalHits(2, TotalHits.Relation.EQUAL_TO), topDocs); + CompoundTopDocs compoundTopDocs = new CompoundTopDocs(new TotalHits(2, TotalHits.Relation.EQUAL_TO), topDocs, false); assertNotNull(compoundTopDocs); assertNotNull(compoundTopDocs.getScoreDocs()); assertEquals(2, compoundTopDocs.getScoreDocs().size()); } public void testBasics_whenMultipleTopDocsIsNull_thenScoreDocsIsNull() { - CompoundTopDocs compoundTopDocs = new CompoundTopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), (List) null); + CompoundTopDocs compoundTopDocs = new CompoundTopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), (List) null, false); assertNotNull(compoundTopDocs); assertNull(compoundTopDocs.getScoreDocs()); CompoundTopDocs compoundTopDocsWithNullArray = new CompoundTopDocs( new TotalHits(0, TotalHits.Relation.EQUAL_TO), - Arrays.asList(null, null) + Arrays.asList(null, null), + false ); assertNotNull(compoundTopDocsWithNullArray); assertNotNull(compoundTopDocsWithNullArray.getScoreDocs()); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java index 7c443a825..e93c9b9ec 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java @@ -45,7 +45,6 @@ import org.opensearch.neuralsearch.processor.combination.ScoreCombiner; import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationFactory; import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer; -import org.opensearch.search.DocValueFormat; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; import org.opensearch.search.SearchShardTarget; @@ -172,7 +171,7 @@ public void testSearchResultTypes_whenCompoundDocs_thenDoNormalizationCombinatio createStartStopElementForHybridSearchResults(4) } ); - querySearchResult.topDocs(new TopDocsAndMaxScore(topDocs, 0.5f), new DocValueFormat[0]); + querySearchResult.topDocs(new TopDocsAndMaxScore(topDocs, 0.5f), null); querySearchResult.setSearchShardTarget(searchShardTarget); querySearchResult.setShardIndex(shardId); @@ -240,7 +239,7 @@ public void testScoreCorrectness_whenCompoundDocs_thenDoNormalizationCombination createStartStopElementForHybridSearchResults(10) } ); - querySearchResult.topDocs(new TopDocsAndMaxScore(topDocs, 25.438505f), new DocValueFormat[0]); + querySearchResult.topDocs(new TopDocsAndMaxScore(topDocs, 25.438505f), null); querySearchResult.setSearchShardTarget(searchShardTarget); querySearchResult.setShardIndex(shardId); @@ -318,7 +317,7 @@ public void testNotHybridSearchResult_whenResultsNotEmptyAndNotHybridSearchResul new TotalHits(4, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(0, 0.5f), new ScoreDoc(2, 0.3f), new ScoreDoc(4, 0.25f), new ScoreDoc(10, 0.2f) } ); - querySearchResult.topDocs(new TopDocsAndMaxScore(topDocs, 0.5f), new DocValueFormat[0]); + querySearchResult.topDocs(new TopDocsAndMaxScore(topDocs, 0.5f), null); querySearchResult.setSearchShardTarget(searchShardTarget); querySearchResult.setShardIndex(shardId); @@ -382,7 +381,7 @@ public void testResultTypes_whenQueryAndFetchPresentAndSizeSame_thenCallNormaliz createStartStopElementForHybridSearchResults(4) } ); - querySearchResult.topDocs(new TopDocsAndMaxScore(topDocs, 0.5f), new DocValueFormat[0]); + querySearchResult.topDocs(new TopDocsAndMaxScore(topDocs, 0.5f), 
null); querySearchResult.setSearchShardTarget(searchShardTarget); querySearchResult.setShardIndex(shardId); @@ -471,7 +470,7 @@ public void testResultTypes_whenQueryAndFetchPresentButSizeDifferent_thenFail() createStartStopElementForHybridSearchResults(4) } ); - querySearchResult.topDocs(new TopDocsAndMaxScore(topDocs, 0.5f), new DocValueFormat[0]); + querySearchResult.topDocs(new TopDocsAndMaxScore(topDocs, 0.5f), null); querySearchResult.setSearchShardTarget(searchShardTarget); querySearchResult.setShardIndex(shardId); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java index 5d88ffed9..59fb51563 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java @@ -26,7 +26,6 @@ import org.opensearch.neuralsearch.processor.combination.ScoreCombiner; import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationFactory; import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer; -import org.opensearch.search.DocValueFormat; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; import org.opensearch.search.SearchShardTarget; @@ -66,7 +65,7 @@ public void testSearchResultTypes_whenResultsOfHybridSearch_thenDoNormalizationC ), 0.5f ), - new DocValueFormat[0] + null ); querySearchResult.setSearchShardTarget(searchShardTarget); querySearchResult.setShardIndex(shardId); @@ -108,7 +107,7 @@ public void testSearchResultTypes_whenNoMatches_thenReturnZeroResults() { ), 0.0f ), - new DocValueFormat[0] + null ); querySearchResult.setSearchShardTarget(searchShardTarget); querySearchResult.setShardIndex(shardId); @@ -155,7 +154,7 @@ public void testFetchResults_whenOneShardAndQueryAndFetchResultsPresent_thenDoNo ), 0.5f ), - new DocValueFormat[0] + null ); querySearchResult.setSearchShardTarget(searchShardTarget); querySearchResult.setShardIndex(shardId); @@ -215,7 +214,7 @@ public void testFetchResults_whenOneShardAndMultipleNodes_thenDoNormalizationCom ), 0.5f ), - new DocValueFormat[0] + null ); querySearchResult.setSearchShardTarget(searchShardTarget); querySearchResult.setShardIndex(shardId); @@ -275,7 +274,7 @@ public void testFetchResultsAndNoCache_whenOneShardAndMultipleNodesAndMismatchRe ), 0.5f ), - new DocValueFormat[0] + null ); querySearchResult.setSearchShardTarget(searchShardTarget); querySearchResult.setShardIndex(shardId); @@ -327,7 +326,7 @@ public void testFetchResultsAndCache_whenOneShardAndMultipleNodesAndMismatchResu ), 0.5f ), - new DocValueFormat[0] + null ); querySearchResult.setSearchShardTarget(searchShardTarget); querySearchResult.setShardIndex(shardId); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationTechniqueTests.java index c97abe1a4..918f3f45b 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationTechniqueTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationTechniqueTests.java @@ -4,6 +4,8 @@ */ package org.opensearch.neuralsearch.processor; +import java.util.Collections; +import org.opensearch.neuralsearch.processor.combination.CombineScoresDto; import static org.opensearch.neuralsearch.util.TestUtils.DELTA_FOR_SCORE_ASSERTION; import java.util.List; 
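Editorial note: the hunks below migrate these tests from positional arguments to the DTO-based combineScores entry point introduced by this PR. A minimal invocation mirroring the migrated test code (empty inputs, no sort) would look like:

import java.util.Collections;
import java.util.List;
import org.opensearch.neuralsearch.processor.combination.CombineScoresDto;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory;
import org.opensearch.neuralsearch.processor.combination.ScoreCombiner;

class CombineScoresSketch {
    static void combineNothing() {
        // Build the DTO exactly as the migrated tests do; the sort field is optional
        // and omitted here, which corresponds to score-based (non-sorted) combination.
        CombineScoresDto combineScoresDTO = CombineScoresDto.builder()
            .queryTopDocs(List.of())
            .scoreCombinationTechnique(ScoreCombinationFactory.DEFAULT_METHOD)
            .querySearchResults(Collections.emptyList())
            .build();
        new ScoreCombiner().combineScores(combineScoresDTO);
    }
}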
@@ -19,7 +21,13 @@ public class ScoreCombinationTechniqueTests extends OpenSearchTestCase { public void testEmptyResults_whenEmptyResultsAndDefaultMethod_thenNoProcessing() { ScoreCombiner scoreCombiner = new ScoreCombiner(); - scoreCombiner.combineScores(List.of(), ScoreCombinationFactory.DEFAULT_METHOD); + scoreCombiner.combineScores( + CombineScoresDto.builder() + .queryTopDocs(List.of()) + .scoreCombinationTechnique(ScoreCombinationFactory.DEFAULT_METHOD) + .querySearchResults(Collections.emptyList()) + .build() + ); } public void testCombination_whenMultipleSubqueriesResultsAndDefaultMethod_thenScoresCombined() { @@ -37,7 +45,8 @@ public void testCombination_whenMultipleSubqueriesResultsAndDefaultMethod_thenSc new TotalHits(2, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(3, 1.0f), new ScoreDoc(5, 0.001f) } ) - ) + ), + false ), new CompoundTopDocs( new TotalHits(4, TotalHits.Relation.EQUAL_TO), @@ -47,18 +56,26 @@ public void testCombination_whenMultipleSubqueriesResultsAndDefaultMethod_thenSc new TotalHits(4, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(2, 0.9f), new ScoreDoc(4, 0.6f), new ScoreDoc(7, 0.5f), new ScoreDoc(9, 0.01f) } ) - ) + ), + false ), new CompoundTopDocs( new TotalHits(0, TotalHits.Relation.EQUAL_TO), List.of( new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]), new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]) - ) + ), + false ) ); - scoreCombiner.combineScores(queryTopDocs, ScoreCombinationFactory.DEFAULT_METHOD); + scoreCombiner.combineScores( + CombineScoresDto.builder() + .queryTopDocs(queryTopDocs) + .scoreCombinationTechnique(ScoreCombinationFactory.DEFAULT_METHOD) + .querySearchResults(Collections.emptyList()) + .build() + ); assertNotNull(queryTopDocs); assertEquals(3, queryTopDocs.size()); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationTechniqueTests.java index 34f0af48c..67abd552f 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationTechniqueTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationTechniqueTests.java @@ -29,7 +29,8 @@ public void testNormalization_whenOneSubqueryAndOneShardAndDefaultMethod_thenSco final List queryTopDocs = List.of( new CompoundTopDocs( new TotalHits(1, TotalHits.Relation.EQUAL_TO), - List.of(new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(1, 2.0f) })) + List.of(new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(1, 2.0f) })), + false ) ); scoreNormalizationMethod.normalizeScores(queryTopDocs, ScoreNormalizationFactory.DEFAULT_METHOD); @@ -59,7 +60,8 @@ public void testNormalization_whenOneSubqueryMultipleHitsAndOneShardAndDefaultMe new TotalHits(3, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(1, 10.0f), new ScoreDoc(2, 2.5f), new ScoreDoc(4, 0.1f) } ) - ) + ), + false ) ); scoreNormalizationMethod.normalizeScores(queryTopDocs, ScoreNormalizationFactory.DEFAULT_METHOD); @@ -95,7 +97,8 @@ public void testNormalization_whenMultipleSubqueriesMultipleHitsAndOneShardAndDe new TotalHits(2, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(3, 0.8f), new ScoreDoc(5, 0.5f) } ) - ) + ), + false ) ); scoreNormalizationMethod.normalizeScores(queryTopDocs, ScoreNormalizationFactory.DEFAULT_METHOD); @@ -143,7 +146,8 @@ public void 
testNormalization_whenMultipleSubqueriesMultipleHitsMultipleShardsAn new TotalHits(2, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(3, 0.8f), new ScoreDoc(5, 0.5f) } ) - ) + ), + false ), new CompoundTopDocs( new TotalHits(4, TotalHits.Relation.EQUAL_TO), @@ -153,14 +157,16 @@ public void testNormalization_whenMultipleSubqueriesMultipleHitsMultipleShardsAn new TotalHits(4, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(2, 2.2f), new ScoreDoc(4, 1.8f), new ScoreDoc(7, 0.9f), new ScoreDoc(9, 0.01f) } ) - ) + ), + false ), new CompoundTopDocs( new TotalHits(0, TotalHits.Relation.EQUAL_TO), List.of( new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]), new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]) - ) + ), + false ) ); scoreNormalizationMethod.normalizeScores(queryTopDocs, ScoreNormalizationFactory.DEFAULT_METHOD); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechniqueTests.java index bccc0820a..ba4bfee0d 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechniqueTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechniqueTests.java @@ -30,7 +30,8 @@ public void testNormalization_whenResultFromOneShardOneSubQuery_thenSuccessful() new TotalHits(2, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(2, scores[0]), new ScoreDoc(4, scores[1]) } ) - ) + ), + false ) ); normalizationTechnique.normalize(compoundTopDocs); @@ -44,7 +45,8 @@ public void testNormalization_whenResultFromOneShardOneSubQuery_thenSuccessful() new ScoreDoc(2, l2Norm(scores[0], Arrays.asList(scores))), new ScoreDoc(4, l2Norm(scores[1], Arrays.asList(scores))) } ) - ) + ), + false ); assertNotNull(compoundTopDocs); assertEquals(1, compoundTopDocs.size()); @@ -75,7 +77,8 @@ public void testNormalization_whenResultFromOneShardMultipleSubQueries_thenSucce new ScoreDoc(4, scoresQuery2[1]), new ScoreDoc(2, scoresQuery2[2]) } ) - ) + ), + false ) ); normalizationTechnique.normalize(compoundTopDocs); @@ -97,7 +100,8 @@ public void testNormalization_whenResultFromOneShardMultipleSubQueries_thenSucce new ScoreDoc(4, l2Norm(scoresQuery2[1], Arrays.asList(scoresQuery2))), new ScoreDoc(2, l2Norm(scoresQuery2[2], Arrays.asList(scoresQuery2))) } ) - ) + ), + false ); assertNotNull(compoundTopDocs); assertEquals(1, compoundTopDocs.size()); @@ -128,7 +132,8 @@ public void testNormalization_whenResultFromMultipleShardsMultipleSubQueries_the new ScoreDoc(4, scoresShard1and2Query3[1]), new ScoreDoc(2, scoresShard1and2Query3[2]) } ) - ) + ), + false ), new CompoundTopDocs( new TotalHits(4, TotalHits.Relation.EQUAL_TO), @@ -146,7 +151,8 @@ public void testNormalization_whenResultFromMultipleShardsMultipleSubQueries_the new ScoreDoc(10, scoresShard1and2Query3[5]), new ScoreDoc(15, scoresShard1and2Query3[6]) } ) - ) + ), + false ) ); normalizationTechnique.normalize(compoundTopDocs); @@ -168,7 +174,8 @@ public void testNormalization_whenResultFromMultipleShardsMultipleSubQueries_the new ScoreDoc(4, l2Norm(scoresShard1and2Query3[1], Arrays.asList(scoresShard1and2Query3))), new ScoreDoc(2, l2Norm(scoresShard1and2Query3[2], Arrays.asList(scoresShard1and2Query3))) } ) - ) + ), + false ); CompoundTopDocs expectedCompoundDocsShard2 = new CompoundTopDocs( @@ -189,7 +196,8 @@ public void 
testNormalization_whenResultFromMultipleShardsMultipleSubQueries_the new ScoreDoc(10, l2Norm(scoresShard1and2Query3[5], Arrays.asList(scoresShard1and2Query3))), new ScoreDoc(15, l2Norm(scoresShard1and2Query3[6], Arrays.asList(scoresShard1and2Query3))) } ) - ) + ), + false ); assertNotNull(compoundTopDocs); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechniqueTests.java index 1fe1edf25..d0445f0ca 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechniqueTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechniqueTests.java @@ -28,7 +28,8 @@ public void testNormalization_whenResultFromOneShardOneSubQuery_thenSuccessful() new TotalHits(2, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(2, 0.5f), new ScoreDoc(4, 0.2f) } ) - ) + ), + false ) ); normalizationTechnique.normalize(compoundTopDocs); @@ -40,7 +41,8 @@ public void testNormalization_whenResultFromOneShardOneSubQuery_thenSuccessful() new TotalHits(2, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(2, 1.0f), new ScoreDoc(4, 0.001f) } ) - ) + ), + false ); assertNotNull(compoundTopDocs); assertEquals(1, compoundTopDocs.size()); @@ -66,7 +68,8 @@ public void testNormalization_whenResultFromOneShardMultipleSubQueries_thenSucce new TotalHits(3, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(3, 0.9f), new ScoreDoc(4, 0.7f), new ScoreDoc(2, 0.1f) } ) - ) + ), + false ) ); normalizationTechnique.normalize(compoundTopDocs); @@ -83,7 +86,8 @@ public void testNormalization_whenResultFromOneShardMultipleSubQueries_thenSucce new TotalHits(3, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(3, 1.0f), new ScoreDoc(4, 0.75f), new ScoreDoc(2, 0.001f) } ) - ) + ), + false ); assertNotNull(compoundTopDocs); assertEquals(1, compoundTopDocs.size()); @@ -108,7 +112,8 @@ public void testNormalization_whenResultFromMultipleShardsMultipleSubQueries_the new TotalHits(3, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(3, 0.9f), new ScoreDoc(4, 0.7f), new ScoreDoc(2, 0.1f) } ) - ) + ), + false ), new CompoundTopDocs( new TotalHits(2, TotalHits.Relation.EQUAL_TO), @@ -118,7 +123,8 @@ public void testNormalization_whenResultFromMultipleShardsMultipleSubQueries_the new TotalHits(2, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(7, 2.9f), new ScoreDoc(9, 0.7f) } ) - ) + ), + false ) ); normalizationTechnique.normalize(compoundTopDocs); @@ -135,7 +141,8 @@ public void testNormalization_whenResultFromMultipleShardsMultipleSubQueries_the new TotalHits(3, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(3, 1.0f), new ScoreDoc(4, 0.75f), new ScoreDoc(2, 0.001f) } ) - ) + ), + false ); CompoundTopDocs expectedCompoundDocsShard2 = new CompoundTopDocs( @@ -146,7 +153,8 @@ public void testNormalization_whenResultFromMultipleShardsMultipleSubQueries_the new TotalHits(2, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(7, 1.0f), new ScoreDoc(9, 0.001f) } ) - ) + ), + false ); assertNotNull(compoundTopDocs); diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryAggregationsIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryAggregationsIT.java index 4bc40add8..9df106156 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryAggregationsIT.java +++ 
b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryAggregationsIT.java @@ -212,7 +212,10 @@ private void testPostFilterWithSimpleHybridQuery(boolean isSingleShard, boolean 10, Map.of("search_pipeline", SEARCH_PIPELINE), null, - rangeFilterQuery + rangeFilterQuery, + null, + false, + null ); assertHitResultsFromQuery(1, searchResponseAsMap); @@ -224,6 +227,9 @@ private void testPostFilterWithSimpleHybridQuery(boolean isSingleShard, boolean 10, Map.of("search_pipeline", SEARCH_PIPELINE), null, + null, + null, + false, null ); assertHitResultsFromQuery(2, searchResponseAsMap); @@ -235,7 +241,10 @@ private void testPostFilterWithSimpleHybridQuery(boolean isSingleShard, boolean 10, Map.of("search_pipeline", SEARCH_PIPELINE), null, - rangeFilterQuery + rangeFilterQuery, + null, + false, + null ); assertHitResultsFromQuery(2, searchResponseAsMap); } else { @@ -246,6 +255,9 @@ private void testPostFilterWithSimpleHybridQuery(boolean isSingleShard, boolean 10, Map.of("search_pipeline", SEARCH_PIPELINE), null, + null, + null, + false, null ); assertHitResultsFromQuery(3, searchResponseAsMap); @@ -304,7 +316,10 @@ private void testPostFilterWithComplexHybridQuery(boolean isSingleShard, boolean 10, Map.of("search_pipeline", SEARCH_PIPELINE), null, - rangeFilterQuery + rangeFilterQuery, + null, + false, + null ); assertHitResultsFromQuery(1, searchResponseAsMap); @@ -316,6 +331,9 @@ private void testPostFilterWithComplexHybridQuery(boolean isSingleShard, boolean 10, Map.of("search_pipeline", SEARCH_PIPELINE), null, + null, + null, + false, null ); assertHitResultsFromQuery(2, searchResponseAsMap); @@ -327,7 +345,10 @@ private void testPostFilterWithComplexHybridQuery(boolean isSingleShard, boolean 10, Map.of("search_pipeline", SEARCH_PIPELINE), null, - rangeFilterQuery + rangeFilterQuery, + null, + false, + null ); assertHitResultsFromQuery(4, searchResponseAsMap); } else { @@ -338,6 +359,9 @@ private void testPostFilterWithComplexHybridQuery(boolean isSingleShard, boolean 10, Map.of("search_pipeline", SEARCH_PIPELINE), null, + null, + null, + false, null ); assertHitResultsFromQuery(3, searchResponseAsMap); diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryPostFilterIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryPostFilterIT.java index 8f8ae8cc4..ea951b65f 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryPostFilterIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryPostFilterIT.java @@ -174,7 +174,10 @@ private void testPostFilterRangeQuery(String indexName) { 10, Map.of("search_pipeline", SEARCH_PIPELINE), null, - postFilterQuery + postFilterQuery, + null, + false, + null ); assertHybridQueryResults(searchResponseAsMap, 1, 0, GTE_OF_RANGE_IN_POST_FILTER_QUERY, LTE_OF_RANGE_IN_POST_FILTER_QUERY); } @@ -256,7 +259,10 @@ private void testPostFilterBoolQuery(String indexName) { 10, Map.of("search_pipeline", SEARCH_PIPELINE), null, - postFilterQuery + postFilterQuery, + null, + false, + null ); assertHybridQueryResults(searchResponseAsMap, 2, 1, GTE_OF_RANGE_IN_POST_FILTER_QUERY, LTE_OF_RANGE_IN_POST_FILTER_QUERY); // Case 2 A Query with a combination of hybrid query (Match Query, Term Query, Range Query), aggregation (Average stock price @@ -269,7 +275,10 @@ private void testPostFilterBoolQuery(String indexName) { 10, Map.of("search_pipeline", SEARCH_PIPELINE), List.of(aggsBuilder), - postFilterQuery + postFilterQuery, + null, + false, + null ); assertHybridQueryResults(searchResponseAsMap, 2, 1, 
GTE_OF_RANGE_IN_POST_FILTER_QUERY, LTE_OF_RANGE_IN_POST_FILTER_QUERY); Map aggregations = getAggregations(searchResponseAsMap); @@ -291,7 +300,10 @@ private void testPostFilterBoolQuery(String indexName) { 10, Map.of("search_pipeline", SEARCH_PIPELINE), null, - postFilterQuery + postFilterQuery, + null, + false, + null ); assertHybridQueryResults(searchResponseAsMap, 0, 0, GTE_OF_RANGE_IN_POST_FILTER_QUERY, LTE_OF_RANGE_IN_POST_FILTER_QUERY); // Case 4 A Query with a combination of hybrid query (Match Query, Range Query) and a post filter query (Bool Query with a should @@ -309,7 +321,10 @@ private void testPostFilterBoolQuery(String indexName) { 10, Map.of("search_pipeline", SEARCH_PIPELINE), null, - postFilterQuery + postFilterQuery, + null, + false, + null ); assertHybridQueryResults(searchResponseAsMap, 0, 0, GTE_OF_RANGE_IN_POST_FILTER_QUERY, LTE_OF_RANGE_IN_POST_FILTER_QUERY); } @@ -364,7 +379,10 @@ private void testPostFilterMatchAllAndMatchNoneQueries(String indexName) { 10, Map.of("search_pipeline", SEARCH_PIPELINE), null, - postFilterQuery + postFilterQuery, + null, + false, + null ); assertHybridQueryResults(searchResponseAsMap, 4, 3, GTE_OF_RANGE_IN_POST_FILTER_QUERY, LTE_OF_RANGE_IN_POST_FILTER_QUERY); @@ -378,7 +396,10 @@ private void testPostFilterMatchAllAndMatchNoneQueries(String indexName) { 10, Map.of("search_pipeline", SEARCH_PIPELINE), null, - postFilterQuery + postFilterQuery, + null, + false, + null ); assertHybridQueryResults(searchResponseAsMap, 0, 0, GTE_OF_RANGE_IN_POST_FILTER_QUERY, LTE_OF_RANGE_IN_POST_FILTER_QUERY); } diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQuerySortIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQuerySortIT.java new file mode 100644 index 000000000..f3615b991 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQuerySortIT.java @@ -0,0 +1,696 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.query; + +import java.util.List; +import java.util.Map; +import java.util.HashMap; +import java.util.Collections; +import java.util.ArrayList; +import java.util.LinkedHashMap; +import lombok.SneakyThrows; +import org.junit.BeforeClass; +import org.opensearch.client.ResponseException; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.MatchQueryBuilder; +import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.index.query.RangeQueryBuilder; +import org.opensearch.neuralsearch.BaseNeuralSearchIT; +import static org.opensearch.neuralsearch.util.AggregationsTestUtils.getNestedHits; +import static org.opensearch.neuralsearch.util.TestUtils.assertHitResultsFromQueryWhenSortIsEnabled; +import org.opensearch.search.sort.SortOrder; +import org.opensearch.search.sort.SortBuilder; +import org.opensearch.search.sort.SortBuilders; +import org.opensearch.search.sort.ScoreSortBuilder; +import org.opensearch.search.sort.FieldSortBuilder; + +public class HybridQuerySortIT extends BaseNeuralSearchIT { + private static final String TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS = "test-hybrid-sort-multi-doc-index-multiple-shards"; + private static final String TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD = "test-hybrid-sort-multi-doc-index-single-shard"; + private static final String SEARCH_PIPELINE = "phase-results-hybrid-sort-pipeline"; + private static final String INTEGER_FIELD_1_STOCK = "stock"; + private static final String TEXT_FIELD_1_NAME = "name"; + private 
static final String KEYWORD_FIELD_2_CATEGORY = "category"; + private static final String TEXT_FIELD_VALUE_1_DUNES = "Dunes part 1"; + private static final String TEXT_FIELD_VALUE_2_DUNES = "Dunes part 2"; + private static final String TEXT_FIELD_VALUE_3_MI_1 = "Mission Impossible 1"; + private static final String TEXT_FIELD_VALUE_4_MI_2 = "Mission Impossible 2"; + private static final String TEXT_FIELD_VALUE_5_TERMINAL = "The Terminal"; + private static final String TEXT_FIELD_VALUE_6_AVENGERS = "Avengers"; + private static final int INTEGER_FIELD_STOCK_1_25 = 25; + private static final int INTEGER_FIELD_STOCK_2_22 = 22; + private static final int INTEGER_FIELD_STOCK_3_256 = 256; + private static final int INTEGER_FIELD_STOCK_4_25 = 25; + private static final int INTEGER_FIELD_STOCK_5_20 = 20; + private static final String KEYWORD_FIELD_CATEGORY_1_DRAMA = "Drama"; + private static final String KEYWORD_FIELD_CATEGORY_2_ACTION = "Action"; + private static final String KEYWORD_FIELD_CATEGORY_3_SCI_FI = "Sci-fi"; + private static final int SHARDS_COUNT_IN_SINGLE_NODE_CLUSTER = 1; + private static final int SHARDS_COUNT_IN_MULTI_NODE_CLUSTER = 3; + private static final int LTE_OF_RANGE_IN_HYBRID_QUERY = 400; + private static final int GTE_OF_RANGE_IN_HYBRID_QUERY = 20; + private static final int SMALLEST_STOCK_VALUE_IN_QUERY_RESULT = 20; + private static final int LARGEST_STOCK_VALUE_IN_QUERY_RESULT = 400; + private static final int LARGEST_STOCK_VALUE_IN_SEARCH_AFTER_MULTIPLE_FIELD_QUERY_RESULT = 25; + private static final int LARGEST_STOCK_VALUE_IN_SEARCH_AFTER_SINGLE_FIELD_QUERY_RESULT = 22; + + @BeforeClass + @SneakyThrows + public static void setUpCluster() { + // we need a new instance because we're calling non-static methods from a static method; + // the main purpose is to minimize network calls, as initialization is only needed once + HybridQuerySortIT instance = new HybridQuerySortIT(); + instance.initClient(); + instance.updateClusterSettings(); + } + + @SneakyThrows + public void testSortOnSingleShard_whenConcurrentSearchEnabled_thenSuccessful() { + try { + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); + prepareResourcesBeforeTestExecution(SHARDS_COUNT_IN_SINGLE_NODE_CLUSTER); + testSingleFieldSort_whenMultipleSubQueries_thenSuccessful(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD); + testMultipleFieldSort_whenMultipleSubQueries_thenSuccessful(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD); + } finally { + wipeOfTestResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD, null, null, SEARCH_PIPELINE); + } + } + + @SneakyThrows + public void testSortOnSingleShard_whenConcurrentSearchDisabled_thenSuccessful() { + try { + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); + prepareResourcesBeforeTestExecution(SHARDS_COUNT_IN_SINGLE_NODE_CLUSTER); + testSingleFieldSort_whenMultipleSubQueries_thenSuccessful(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD); + testMultipleFieldSort_whenMultipleSubQueries_thenSuccessful(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD); + testScoreSort_whenSingleFieldSort_thenSuccessful(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD); + } finally { + wipeOfTestResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD, null, null, SEARCH_PIPELINE); + } + } + + @SneakyThrows + public void testSortOnMultipleShard_whenConcurrentSearchEnabled_thenSuccessful() { + try { + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); + prepareResourcesBeforeTestExecution(SHARDS_COUNT_IN_MULTI_NODE_CLUSTER);
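// Editorial note: the multi-field check below relies on a LinkedHashMap in
// testMultipleFieldSort_whenMultipleSubQueries_thenSuccessful because sort priority follows
// insertion order: "stock" desc is the primary criterion, "_doc" asc the tie-breaker.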
+            testSingleFieldSort_whenMultipleSubQueries_thenSuccessful(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS);
+            testMultipleFieldSort_whenMultipleSubQueries_thenSuccessful(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS);
+            testScoreSort_whenSingleFieldSort_thenSuccessful(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS);
+        } finally {
+            wipeOfTestResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, null, null, SEARCH_PIPELINE);
+        }
+    }
+
+    @SneakyThrows
+    public void testSortOnMultipleShard_whenConcurrentSearchDisabled_thenSuccessful() {
+        try {
+            updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false);
+            prepareResourcesBeforeTestExecution(SHARDS_COUNT_IN_MULTI_NODE_CLUSTER);
+            testSingleFieldSort_whenMultipleSubQueries_thenSuccessful(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS);
+            testMultipleFieldSort_whenMultipleSubQueries_thenSuccessful(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS);
+            testScoreSort_whenSingleFieldSort_thenSuccessful(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS);
+        } finally {
+            wipeOfTestResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, null, null, SEARCH_PIPELINE);
+        }
+    }
+
+    @SneakyThrows
+    private void testSingleFieldSort_whenMultipleSubQueries_thenSuccessful(String indexName) {
+        HybridQueryBuilder hybridQueryBuilder = createHybridQueryBuilderWithMatchTermAndRangeQuery(
+            "mission",
+            "part",
+            LTE_OF_RANGE_IN_HYBRID_QUERY,
+            GTE_OF_RANGE_IN_HYBRID_QUERY
+        );
+
+        Map<String, SortOrder> fieldSortOrderMap = new HashMap<>();
+        fieldSortOrderMap.put("stock", SortOrder.DESC);
+
+        Map<String, Object> searchResponseAsMap = search(
+            indexName,
+            hybridQueryBuilder,
+            null,
+            10,
+            Map.of("search_pipeline", SEARCH_PIPELINE),
+            null,
+            null,
+            createSortBuilders(fieldSortOrderMap, false),
+            false,
+            null
+        );
+        List<Map<String, Object>> nestedHits = validateHitsCountAndFetchNestedHits(searchResponseAsMap, 6, 6);
+        assertStockValueWithSortOrderInHybridQueryResults(nestedHits, SortOrder.DESC, LARGEST_STOCK_VALUE_IN_QUERY_RESULT, true, true);
+    }
+
+    @SneakyThrows
+    private void testMultipleFieldSort_whenMultipleSubQueries_thenSuccessful(String indexName) {
+        HybridQueryBuilder hybridQueryBuilder = createHybridQueryBuilderWithMatchTermAndRangeQuery(
+            "mission",
+            "part",
+            LTE_OF_RANGE_IN_HYBRID_QUERY,
+            GTE_OF_RANGE_IN_HYBRID_QUERY
+        );
+
+        Map<String, SortOrder> fieldSortOrderMap = new LinkedHashMap<>();
+        fieldSortOrderMap.put("stock", SortOrder.DESC);
+        fieldSortOrderMap.put("_doc", SortOrder.ASC);
+
+        Map<String, Object> searchResponseAsMap = search(
+            indexName,
+            hybridQueryBuilder,
+            null,
+            10,
+            Map.of("search_pipeline", SEARCH_PIPELINE),
+            null,
+            null,
+            createSortBuilders(fieldSortOrderMap, false),
+            false,
+            null
+        );
+        List<Map<String, Object>> nestedHits = validateHitsCountAndFetchNestedHits(searchResponseAsMap, 6, 6);
+        assertStockValueWithSortOrderInHybridQueryResults(nestedHits, SortOrder.DESC, LARGEST_STOCK_VALUE_IN_QUERY_RESULT, true, false);
+        assertDocValueWithSortOrderInHybridQueryResults(nestedHits, SortOrder.ASC, 0, false, false);
+    }
+
+    @SneakyThrows
+    public void testSingleFieldSort_whenTrackScoresIsEnabled_thenFail() {
+        try {
+            prepareResourcesBeforeTestExecution(SHARDS_COUNT_IN_MULTI_NODE_CLUSTER);
+            HybridQueryBuilder hybridQueryBuilder = createHybridQueryBuilderWithMatchTermAndRangeQuery(
+                "mission",
+                "part",
+                LTE_OF_RANGE_IN_HYBRID_QUERY,
+                GTE_OF_RANGE_IN_HYBRID_QUERY
+            );
+            Map<String, SortOrder> fieldSortOrderMap = new HashMap<>();
+            fieldSortOrderMap.put("stock", SortOrder.DESC);
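+            // sorting cannot be combined with track_scores in hybrid search, so this request is expected to be rejected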
+            assertThrows(
+                "Hybrid search results when sorted by any field, docId or _id, track_scores must be set to false.",
+                ResponseException.class,
+                () -> search(
+                    TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS,
+                    hybridQueryBuilder,
+                    null,
+                    10,
+                    Map.of("search_pipeline", SEARCH_PIPELINE),
+                    null,
+                    null,
+                    createSortBuilders(fieldSortOrderMap, false),
+                    true,
+                    null
+                )
+            );
+        } finally {
+            wipeOfTestResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, null, null, SEARCH_PIPELINE);
+        }
+    }
+
+    @SneakyThrows
+    public void testSingleFieldSort_whenSortCriteriaIsByScoreAndField_thenFail() {
+        try {
+            prepareResourcesBeforeTestExecution(SHARDS_COUNT_IN_MULTI_NODE_CLUSTER);
+            HybridQueryBuilder hybridQueryBuilder = createHybridQueryBuilderWithMatchTermAndRangeQuery(
+                "mission",
+                "part",
+                LTE_OF_RANGE_IN_HYBRID_QUERY,
+                GTE_OF_RANGE_IN_HYBRID_QUERY
+            );
+            Map<String, SortOrder> fieldSortOrderMap = new LinkedHashMap<>();
+            fieldSortOrderMap.put("stock", SortOrder.DESC);
+            fieldSortOrderMap.put("_score", SortOrder.DESC);
+            assertThrows(
+                "_score sort criteria cannot be applied with any other criteria. Please select one sort criteria out of them.",
+                ResponseException.class,
+                () -> search(
+                    TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS,
+                    hybridQueryBuilder,
+                    null,
+                    10,
+                    Map.of("search_pipeline", SEARCH_PIPELINE),
+                    null,
+                    null,
+                    createSortBuilders(fieldSortOrderMap, false),
+                    true,
+                    null
+                )
+            );
+        } finally {
+            wipeOfTestResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, null, null, SEARCH_PIPELINE);
+        }
+    }
+
+    @SneakyThrows
+    public void testSearchAfterWithSortOnSingleShard_whenConcurrentSearchEnabled_thenSuccessful() {
+        try {
+            updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true);
+            prepareResourcesBeforeTestExecution(SHARDS_COUNT_IN_SINGLE_NODE_CLUSTER);
+            testSearchAfter_whenSingleFieldSort_thenSuccessful(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD);
+            testSearchAfter_whenMultipleFieldSort_thenSuccessful(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD);
+        } finally {
+            wipeOfTestResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD, null, null, SEARCH_PIPELINE);
+        }
+    }
+
+    @SneakyThrows
+    public void testSearchAfterWithSortOnSingleShard_whenConcurrentSearchDisabled_thenSuccessful() {
+        try {
+            updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false);
+            prepareResourcesBeforeTestExecution(SHARDS_COUNT_IN_SINGLE_NODE_CLUSTER);
+            testSearchAfter_whenSingleFieldSort_thenSuccessful(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD);
+            testSearchAfter_whenMultipleFieldSort_thenSuccessful(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD);
+        } finally {
+            wipeOfTestResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD, null, null, SEARCH_PIPELINE);
+        }
+    }
+
+    @SneakyThrows
+    public void testSearchAfterWithSortOnMultipleShard_whenConcurrentSearchEnabled_thenSuccessful() {
+        try {
+            updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true);
+            prepareResourcesBeforeTestExecution(SHARDS_COUNT_IN_MULTI_NODE_CLUSTER);
+            testSearchAfter_whenSingleFieldSort_thenSuccessful(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS);
+            testSearchAfter_whenMultipleFieldSort_thenSuccessful(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS);
+        } finally {
+            wipeOfTestResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, null, null, SEARCH_PIPELINE);
+        }
+    }
+
+    @SneakyThrows
+    public void testSearchAfterWithSortOnMultipleShard_whenConcurrentSearchDisabled_thenSuccessful() {
+        try {
+            updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false);
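+            // same search_after scenarios as above, this time with concurrent segment search disabled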
+            prepareResourcesBeforeTestExecution(SHARDS_COUNT_IN_MULTI_NODE_CLUSTER);
+            testSearchAfter_whenSingleFieldSort_thenSuccessful(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS);
+            testSearchAfter_whenMultipleFieldSort_thenSuccessful(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS);
+        } finally {
+            wipeOfTestResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, null, null, SEARCH_PIPELINE);
+        }
+    }
+
+    @SneakyThrows
+    private void testSearchAfter_whenSingleFieldSort_thenSuccessful(String indexName) {
+        HybridQueryBuilder hybridQueryBuilder = createHybridQueryBuilderWithMatchTermAndRangeQuery(
+            "mission",
+            "part",
+            LTE_OF_RANGE_IN_HYBRID_QUERY,
+            GTE_OF_RANGE_IN_HYBRID_QUERY
+        );
+        Map<String, SortOrder> fieldSortOrderMap = new LinkedHashMap<>();
+        fieldSortOrderMap.put("stock", SortOrder.DESC);
+        List<Object> searchAfter = new ArrayList<>();
+        searchAfter.add(25);
+        Map<String, Object> searchResponseAsMap = search(
+            indexName,
+            hybridQueryBuilder,
+            null,
+            10,
+            Map.of("search_pipeline", SEARCH_PIPELINE),
+            null,
+            null,
+            createSortBuilders(fieldSortOrderMap, false),
+            false,
+            searchAfter
+        );
+        List<Map<String, Object>> nestedHits = validateHitsCountAndFetchNestedHits(searchResponseAsMap, 3, 6);
+        assertStockValueWithSortOrderInHybridQueryResults(
+            nestedHits,
+            SortOrder.DESC,
+            LARGEST_STOCK_VALUE_IN_SEARCH_AFTER_SINGLE_FIELD_QUERY_RESULT,
+            true,
+            true
+        );
+    }
+
+    @SneakyThrows
+    private void testSearchAfter_whenMultipleFieldSort_thenSuccessful(String indexName) {
+        HybridQueryBuilder hybridQueryBuilder = createHybridQueryBuilderWithMatchTermAndRangeQuery(
+            "mission",
+            "part",
+            LTE_OF_RANGE_IN_HYBRID_QUERY,
+            GTE_OF_RANGE_IN_HYBRID_QUERY
+        );
+        Map<String, SortOrder> fieldSortOrderMap = new LinkedHashMap<>();
+        fieldSortOrderMap.put("stock", SortOrder.DESC);
+        fieldSortOrderMap.put("_doc", SortOrder.DESC);
+        List<Object> searchAfter = new ArrayList<>();
+        searchAfter.add(25);
+        searchAfter.add(4);
+        Map<String, Object> searchResponseAsMap = search(
+            indexName,
+            hybridQueryBuilder,
+            null,
+            10,
+            Map.of("search_pipeline", SEARCH_PIPELINE),
+            null,
+            null,
+            createSortBuilders(fieldSortOrderMap, false),
+            false,
+            searchAfter
+        );
+        List<Map<String, Object>> nestedHits = validateHitsCountAndFetchNestedHits(searchResponseAsMap, 5, 6);
+        assertStockValueWithSortOrderInHybridQueryResults(
+            nestedHits,
+            SortOrder.DESC,
+            LARGEST_STOCK_VALUE_IN_SEARCH_AFTER_MULTIPLE_FIELD_QUERY_RESULT,
+            true,
+            false
+        );
+        assertDocValueWithSortOrderInHybridQueryResults(nestedHits, SortOrder.DESC, 0, false, false);
+    }
+
+    @SneakyThrows
+    private void testScoreSort_whenSingleFieldSort_thenSuccessful(String indexName) {
+        HybridQueryBuilder hybridQueryBuilder = createHybridQueryBuilderWithMatchTermAndRangeQuery(
+            "mission",
+            "part",
+            LTE_OF_RANGE_IN_HYBRID_QUERY,
+            GTE_OF_RANGE_IN_HYBRID_QUERY
+        );
+        Map<String, SortOrder> fieldSortOrderMap = new LinkedHashMap<>();
+        fieldSortOrderMap.put("_score", SortOrder.DESC);
+        Map<String, Object> searchResponseAsMap = search(
+            indexName,
+            hybridQueryBuilder,
+            null,
+            10,
+            Map.of("search_pipeline", SEARCH_PIPELINE),
+            null,
+            null,
+            createSortBuilders(fieldSortOrderMap, false),
+            false,
+            null
+        );
+        List<Map<String, Object>> nestedHits = validateHitsCountAndFetchNestedHits(searchResponseAsMap, 6, 6);
+        assertScoreWithSortOrderInHybridQueryResults(nestedHits, SortOrder.DESC, 1.0);
+    }
+
+    @SneakyThrows
+    public void testSort_whenSortFieldsSizeNotEqualToSearchAfterSize_thenFail() {
+        try {
+            prepareResourcesBeforeTestExecution(SHARDS_COUNT_IN_MULTI_NODE_CLUSTER);
+            HybridQueryBuilder hybridQueryBuilder = createHybridQueryBuilderWithMatchTermAndRangeQuery(
+                "mission",
+                "part",
+                LTE_OF_RANGE_IN_HYBRID_QUERY,
+                GTE_OF_RANGE_IN_HYBRID_QUERY
+            );
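+            // one sort field paired with two search_after values: request validation is expected to reject this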
+            Map<String, SortOrder> fieldSortOrderMap = new LinkedHashMap<>();
+            fieldSortOrderMap.put("stock", SortOrder.DESC);
+            List<Object> searchAfter = new ArrayList<>();
+            searchAfter.add(25);
+            searchAfter.add(0);
+            assertThrows(
+                "after.fields has 2 values but sort has 1",
+                ResponseException.class,
+                () -> search(
+                    TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS,
+                    hybridQueryBuilder,
+                    null,
+                    10,
+                    Map.of("search_pipeline", SEARCH_PIPELINE),
+                    null,
+                    null,
+                    createSortBuilders(fieldSortOrderMap, false),
+                    true,
+                    searchAfter
+                )
+            );
+        } finally {
+            wipeOfTestResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, null, null, SEARCH_PIPELINE);
+        }
+    }
+
+    @SneakyThrows
+    public void testSearchAfter_whenAfterFieldIsNotPassed_thenFail() {
+        try {
+            prepareResourcesBeforeTestExecution(SHARDS_COUNT_IN_MULTI_NODE_CLUSTER);
+            HybridQueryBuilder hybridQueryBuilder = createHybridQueryBuilderWithMatchTermAndRangeQuery(
+                "mission",
+                "part",
+                LTE_OF_RANGE_IN_HYBRID_QUERY,
+                GTE_OF_RANGE_IN_HYBRID_QUERY
+            );
+            Map<String, SortOrder> fieldSortOrderMap = new LinkedHashMap<>();
+            fieldSortOrderMap.put("stock", SortOrder.DESC);
+            List<Object> searchAfter = new ArrayList<>();
+            searchAfter.add(null);
+            assertThrows(
+                "after.fields wasn't set; you must pass fillFields=true for the previous search",
+                ResponseException.class,
+                () -> search(
+                    TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS,
+                    hybridQueryBuilder,
+                    null,
+                    10,
+                    Map.of("search_pipeline", SEARCH_PIPELINE),
+                    null,
+                    null,
+                    createSortBuilders(fieldSortOrderMap, false),
+                    true,
+                    searchAfter
+                )
+            );
+        } finally {
+            wipeOfTestResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, null, null, SEARCH_PIPELINE);
+        }
+    }
+
+    private HybridQueryBuilder createHybridQueryBuilderWithMatchTermAndRangeQuery(String text, String value, int lte, int gte) {
+        MatchQueryBuilder matchQueryBuilder = QueryBuilders.matchQuery(TEXT_FIELD_1_NAME, text);
+        TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEXT_FIELD_1_NAME, value);
+        RangeQueryBuilder rangeQueryBuilder = QueryBuilders.rangeQuery(INTEGER_FIELD_1_STOCK).gte(gte).lte(lte);
+        HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder();
+        hybridQueryBuilder.add(matchQueryBuilder).add(termQueryBuilder).add(rangeQueryBuilder);
+        return hybridQueryBuilder;
+    }
+
+    private List<SortBuilder<?>> createSortBuilders(Map<String, SortOrder> fieldSortOrderMap, boolean isSortByScore) {
+        List<SortBuilder<?>> sortBuilders = new ArrayList<>();
+        if (fieldSortOrderMap != null) {
+            for (Map.Entry<String, SortOrder> entry : fieldSortOrderMap.entrySet()) {
+                FieldSortBuilder fieldSortBuilder = SortBuilders.fieldSort(entry.getKey()).order(entry.getValue());
+                sortBuilders.add(fieldSortBuilder);
+            }
+        }
+
+        if (isSortByScore) {
+            ScoreSortBuilder scoreSortBuilder = SortBuilders.scoreSort().order(SortOrder.ASC);
+            sortBuilders.add(scoreSortBuilder);
+        }
+        return sortBuilders;
+    }
+
+    private void assertStockValueWithSortOrderInHybridQueryResults(
+        List<Map<String, Object>> hitsNestedList,
+        SortOrder sortOrder,
+        int baseStockValue,
+        boolean isPrimarySortField,
+        boolean isSingleFieldSort
+    ) {
+        for (Map<String, Object> oneHit : hitsNestedList) {
+            assertNotNull(oneHit.get("_source"));
+            Map<String, Object> source = (Map<String, Object>) oneHit.get("_source");
+            List<Object> sorts = (List<Object>) oneHit.get("sort");
+            int stock = (int) source.get(INTEGER_FIELD_1_STOCK);
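+            // the sort values returned with each hit should mirror the corresponding _source values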
sorted as per sort order", stock >= baseStockValue); + } + assertEquals(stock, stockValueInSort); + } + if (!isSingleFieldSort) { + assertNotNull(sorts.get(1)); + int stockValueInSort = (int) sorts.get(0); + assertEquals(stock, stockValueInSort); + } + baseStockValue = stock; + } + } + + private void assertDocValueWithSortOrderInHybridQueryResults( + List> hitsNestedList, + SortOrder sortOrder, + int baseDocIdValue, + boolean isPrimarySortField, + boolean isSingleFieldSort + ) { + for (Map oneHit : hitsNestedList) { + assertNotNull(oneHit.get("_source")); + List sorts = (List) oneHit.get("sort"); + if (isPrimarySortField) { + int docId = (int) sorts.get(0); + if (sortOrder == SortOrder.DESC) { + assertTrue("Doc Id value is sorted as per sort order", docId <= baseDocIdValue); + } else { + assertTrue("Doc Id value is sorted as per sort order", docId >= baseDocIdValue); + } + baseDocIdValue = docId; + } + if (!isSingleFieldSort) { + assertNotNull(sorts.get(1)); + } + } + } + + private void assertScoreWithSortOrderInHybridQueryResults( + List> hitsNestedList, + SortOrder sortOrder, + double baseScore + ) { + for (Map oneHit : hitsNestedList) { + assertNotNull(oneHit.get("_source")); + double score = (double) oneHit.get("_score"); + if (sortOrder == SortOrder.DESC) { + assertTrue("Stock value is sorted by descending sort order", score <= baseScore); + } else { + assertTrue("Stock value is sorted by ascending sort order", score >= baseScore); + } + baseScore = score; + } + } + + private List> validateHitsCountAndFetchNestedHits( + Map searchResponseAsMap, + int collectHitCountExpected, + int resultsExpected + ) { + assertHitResultsFromQueryWhenSortIsEnabled(collectHitCountExpected, resultsExpected, searchResponseAsMap); + return getNestedHits(searchResponseAsMap); + } + + @SneakyThrows + void prepareResourcesBeforeTestExecution(int numShards) { + if (numShards == 1) { + initializeIndexIfNotExists(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD, numShards); + } else { + initializeIndexIfNotExists(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, numShards); + } + createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE); + } + + @SneakyThrows + private void initializeIndexIfNotExists(String indexName, int numShards) { + if (!indexExists(indexName)) { + createIndexWithConfiguration( + indexName, + buildIndexConfiguration( + Collections.emptyList(), + Collections.emptyList(), + Collections.singletonList(INTEGER_FIELD_1_STOCK), + Collections.singletonList(KEYWORD_FIELD_2_CATEGORY), + Collections.emptyList(), + numShards + ), + "" + ); + + addKnnDoc( + indexName, + "1", + Collections.emptyList(), + Collections.emptyList(), + Collections.singletonList(TEXT_FIELD_1_NAME), + Collections.singletonList(TEXT_FIELD_VALUE_2_DUNES), + Collections.emptyList(), + Collections.emptyList(), + Collections.singletonList(INTEGER_FIELD_1_STOCK), + Collections.singletonList(INTEGER_FIELD_STOCK_1_25), + Collections.singletonList(KEYWORD_FIELD_2_CATEGORY), + Collections.singletonList(KEYWORD_FIELD_CATEGORY_1_DRAMA), + Collections.emptyList(), + Collections.emptyList() + ); + + addKnnDoc( + indexName, + "2", + Collections.emptyList(), + Collections.emptyList(), + Collections.singletonList(TEXT_FIELD_1_NAME), + Collections.singletonList(TEXT_FIELD_VALUE_1_DUNES), + Collections.emptyList(), + Collections.emptyList(), + Collections.singletonList(INTEGER_FIELD_1_STOCK), + Collections.singletonList(INTEGER_FIELD_STOCK_2_22), + Collections.singletonList(KEYWORD_FIELD_2_CATEGORY), + 
Collections.singletonList(KEYWORD_FIELD_CATEGORY_1_DRAMA), + Collections.emptyList(), + Collections.emptyList() + ); + + addKnnDoc( + indexName, + "3", + Collections.emptyList(), + Collections.emptyList(), + Collections.singletonList(TEXT_FIELD_1_NAME), + Collections.singletonList(TEXT_FIELD_VALUE_3_MI_1), + Collections.emptyList(), + Collections.emptyList(), + Collections.singletonList(INTEGER_FIELD_1_STOCK), + Collections.singletonList(INTEGER_FIELD_STOCK_3_256), + Collections.singletonList(KEYWORD_FIELD_2_CATEGORY), + Collections.singletonList(KEYWORD_FIELD_CATEGORY_2_ACTION), + Collections.emptyList(), + Collections.emptyList() + ); + + addKnnDoc( + indexName, + "4", + List.of(), + List.of(), + Collections.singletonList(TEXT_FIELD_1_NAME), + Collections.singletonList(TEXT_FIELD_VALUE_4_MI_2), + List.of(), + List.of(), + List.of(INTEGER_FIELD_1_STOCK), + List.of(INTEGER_FIELD_STOCK_4_25), + List.of(KEYWORD_FIELD_2_CATEGORY), + List.of(KEYWORD_FIELD_CATEGORY_2_ACTION), + List.of(), + List.of() + ); + + addKnnDoc( + indexName, + "5", + List.of(), + List.of(), + Collections.singletonList(TEXT_FIELD_1_NAME), + Collections.singletonList(TEXT_FIELD_VALUE_5_TERMINAL), + List.of(), + List.of(), + List.of(INTEGER_FIELD_1_STOCK), + List.of(INTEGER_FIELD_STOCK_5_20), + List.of(KEYWORD_FIELD_2_CATEGORY), + List.of(KEYWORD_FIELD_CATEGORY_1_DRAMA), + List.of(), + List.of() + ); + + addKnnDoc( + indexName, + "6", + List.of(), + List.of(), + Collections.singletonList(TEXT_FIELD_1_NAME), + Collections.singletonList(TEXT_FIELD_VALUE_6_AVENGERS), + List.of(), + List.of(), + List.of(INTEGER_FIELD_1_STOCK), + List.of(INTEGER_FIELD_STOCK_5_20), + List.of(KEYWORD_FIELD_2_CATEGORY), + List.of(KEYWORD_FIELD_CATEGORY_3_SCI_FI), + List.of(), + List.of() + ); + } + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/query/aggregation/MetricAggregationsWithHybridQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/aggregation/MetricAggregationsWithHybridQueryIT.java index 94f1e7207..08efd0811 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/aggregation/MetricAggregationsWithHybridQueryIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/aggregation/MetricAggregationsWithHybridQueryIT.java @@ -421,7 +421,10 @@ private void testSumAggsAndRangePostFilter() throws IOException { 10, Map.of("search_pipeline", SEARCH_PIPELINE), List.of(aggsBuilder), - rangeFilterQuery + rangeFilterQuery, + null, + false, + null ); Map aggregations = getAggregations(searchResponseAsMap); diff --git a/src/test/java/org/opensearch/neuralsearch/search/HybridTopFieldDocSortCollectorTests.java b/src/test/java/org/opensearch/neuralsearch/search/HybridTopFieldDocSortCollectorTests.java new file mode 100644 index 000000000..3bb0e6bcd --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/search/HybridTopFieldDocSortCollectorTests.java @@ -0,0 +1,246 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.search; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import java.util.stream.Stream; +import lombok.SneakyThrows; +import org.apache.commons.lang3.RandomUtils; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.Field; +import org.apache.lucene.document.FieldType; +import org.apache.lucene.document.IntField; +import org.apache.lucene.document.NumericDocValuesField; +import 
org.apache.lucene.document.TextField; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexOptions; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.FieldDoc; +import org.apache.lucene.search.LeafCollector; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.Sort; +import org.apache.lucene.search.SortField; +import org.apache.lucene.search.TopFieldDocs; +import org.apache.lucene.search.Weight; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.analysis.MockAnalyzer; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import org.opensearch.index.mapper.TextFieldMapper; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.QueryShardContext; +import org.opensearch.neuralsearch.query.HybridQueryScorer; +import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; +import org.opensearch.neuralsearch.search.collector.HybridTopFieldDocSortCollector; +import org.opensearch.neuralsearch.search.collector.PagingFieldCollector; +import org.opensearch.neuralsearch.search.collector.SimpleFieldCollector; + +public class HybridTopFieldDocSortCollectorTests extends OpenSearchQueryTestCase { + static final String TEXT_FIELD_NAME = "field"; + static final String INT_FIELD_NAME = "integerField"; + static final String DOC_FIELD_NAME = "_doc"; + private static final String TEST_QUERY_TEXT = "greeting"; + private static final String TEST_QUERY_TEXT2 = "salute"; + private static final int NUM_DOCS = 4; + private static final int NUM_HITS = 1; + private static final int TOTAL_HITS_UP_TO = 1000; + + private static final int DOC_ID_1 = RandomUtils.nextInt(0, 100_000); + private static final int DOC_ID_2 = RandomUtils.nextInt(0, 100_000); + private static final int DOC_ID_3 = RandomUtils.nextInt(0, 100_000); + private static final int DOC_ID_4 = RandomUtils.nextInt(0, 100_000); + private static final String FIELD_1_VALUE = "text1"; + private static final String FIELD_2_VALUE = "text2"; + private static final String FIELD_3_VALUE = "text3"; + private static final String FIELD_4_VALUE = "text4"; + + @SneakyThrows + public void testSimpleFieldCollectorTopDocs_whenCreateNewAndGetTopDocs_thenSuccessful() { + final Directory directory = newDirectory(); + final IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random()))); + FieldType ft = new FieldType(TextField.TYPE_NOT_STORED); + ft.setIndexOptions(random().nextBoolean() ? 
IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS);
+        ft.setOmitNorms(random().nextBoolean());
+        ft.freeze();
+
+        List<Document> documents = new ArrayList<>();
+        Document document1 = new Document();
+        document1.add(new NumericDocValuesField("_id", DOC_ID_1));
+        document1.add(new IntField(INT_FIELD_NAME, 100, Field.Store.YES));
+        document1.add(new TextField(TEXT_FIELD_NAME, FIELD_1_VALUE, Field.Store.YES));
+        documents.add(document1);
+        Document document2 = new Document();
+        document2.add(new NumericDocValuesField("_id", DOC_ID_2));
+        document2.add(new IntField(INT_FIELD_NAME, 200, Field.Store.YES));
+        document2.add(new TextField(TEXT_FIELD_NAME, FIELD_2_VALUE, Field.Store.YES));
+        documents.add(document2);
+        Document document3 = new Document();
+        document3.add(new NumericDocValuesField("_id", DOC_ID_3));
+        document3.add(new IntField(INT_FIELD_NAME, 300, Field.Store.YES));
+        document3.add(new TextField(TEXT_FIELD_NAME, FIELD_3_VALUE, Field.Store.YES));
+        documents.add(document3);
+        Document document4 = new Document();
+        document4.add(new NumericDocValuesField("_id", DOC_ID_4));
+        document4.add(new IntField(INT_FIELD_NAME, 400, Field.Store.YES));
+        document4.add(new TextField(TEXT_FIELD_NAME, FIELD_4_VALUE, Field.Store.YES));
+        documents.add(document4);
+        w.addDocuments(documents);
+        w.commit();
+
+        DirectoryReader reader = DirectoryReader.open(w);
+        LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0);
+        QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
+
+        TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME);
+
+        when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType);
+        SortField sortField = new SortField(DOC_FIELD_NAME, SortField.Type.DOC);
+        HybridTopFieldDocSortCollector hybridTopFieldDocSortCollector = new SimpleFieldCollector(
+            NUM_DOCS,
+            new HitsThresholdChecker(TOTAL_HITS_UP_TO),
+            new Sort(sortField)
+        );
+        Weight weight = mock(Weight.class);
+        hybridTopFieldDocSortCollector.setWeight(weight);
+        LeafCollector leafCollector = hybridTopFieldDocSortCollector.getLeafCollector(leafReaderContext);
+        assertNotNull(leafCollector);
+
+        int[] docIdsForQuery = new int[] { DOC_ID_1, DOC_ID_2, DOC_ID_3, DOC_ID_4 };
+        Arrays.sort(docIdsForQuery);
+        final List<Float> scores = Stream.generate(() -> random().nextFloat()).limit(NUM_DOCS).collect(Collectors.toList());
+
+        HybridQueryScorer hybridQueryScorer = new HybridQueryScorer(
+            weight,
+            Arrays.asList(scorer(docIdsForQuery, scores, fakeWeight(QueryBuilders.matchAllQuery().toQuery(mockQueryShardContext))))
+        );
+
+        leafCollector.setScorer(hybridQueryScorer);
+        DocIdSetIterator iterator = hybridQueryScorer.iterator();
+
+        int doc = iterator.nextDoc();
+        while (doc != DocIdSetIterator.NO_MORE_DOCS) {
+            leafCollector.collect(doc);
+            doc = iterator.nextDoc();
+        }
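+
+        // all matching docs are collected; the collector should expose one TopFieldDocs per sub-query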
+        List<TopFieldDocs> topFieldDocs = hybridTopFieldDocSortCollector.topDocs();
+
+        assertNotNull(topFieldDocs);
+        assertEquals(1, topFieldDocs.size());
+        for (TopFieldDocs topFieldDoc : topFieldDocs) {
+            // assert results for each sub-query: the number of matches must be correct, all doc ids must be present, and hits
+            // must be ordered by doc id as per the sort criteria
+            assertEquals(4, topFieldDoc.totalHits.value);
+            ScoreDoc[] scoreDocs = topFieldDoc.scoreDocs;
+            assertNotNull(scoreDocs);
+            assertEquals(4, scoreDocs.length);
+            assertTrue(IntStream.range(0, scoreDocs.length - 1).noneMatch(idx -> scoreDocs[idx].doc > scoreDocs[idx + 1].doc));
+            List<Integer> resultDocIds = Arrays.stream(scoreDocs).map(scoreDoc -> scoreDoc.doc).collect(Collectors.toList());
+            assertTrue(Arrays.stream(docIdsForQuery).allMatch(resultDocIds::contains));
+        }
+        w.close();
+        reader.close();
+        directory.close();
+    }
+
+    @SneakyThrows
+    public void testPagingFieldCollectorTopDocs_whenCreateNewAndGetTopDocs_thenSuccessful() {
+        final Directory directory = newDirectory();
+        final IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random())));
+        FieldType ft = new FieldType(TextField.TYPE_NOT_STORED);
+        ft.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS);
+        ft.setOmitNorms(random().nextBoolean());
+        ft.freeze();
+
+        List<Document> documents = new ArrayList<>();
+        Document document1 = new Document();
+        document1.add(new NumericDocValuesField("_id", DOC_ID_1));
+        document1.add(new IntField(INT_FIELD_NAME, 100, Field.Store.YES));
+        document1.add(new TextField(TEXT_FIELD_NAME, FIELD_1_VALUE, Field.Store.YES));
+        documents.add(document1);
+        Document document2 = new Document();
+        document2.add(new NumericDocValuesField("_id", DOC_ID_2));
+        document2.add(new IntField(INT_FIELD_NAME, 200, Field.Store.YES));
+        document2.add(new TextField(TEXT_FIELD_NAME, FIELD_2_VALUE, Field.Store.YES));
+        documents.add(document2);
+        Document document3 = new Document();
+        document3.add(new NumericDocValuesField("_id", DOC_ID_3));
+        document3.add(new IntField(INT_FIELD_NAME, 300, Field.Store.YES));
+        document3.add(new TextField(TEXT_FIELD_NAME, FIELD_3_VALUE, Field.Store.YES));
+        documents.add(document3);
+        Document document4 = new Document();
+        document4.add(new NumericDocValuesField("_id", DOC_ID_4));
+        document4.add(new IntField(INT_FIELD_NAME, 400, Field.Store.YES));
+        document4.add(new TextField(TEXT_FIELD_NAME, FIELD_4_VALUE, Field.Store.YES));
+        documents.add(document4);
+        w.addDocuments(documents);
+        w.commit();
+
+        DirectoryReader reader = DirectoryReader.open(w);
+        LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0);
+        QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
+
+        TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME);
+
+        when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType);
+        SortField sortField = new SortField(DOC_FIELD_NAME, SortField.Type.DOC);
+        HybridTopFieldDocSortCollector hybridTopFieldDocSortCollector = new PagingFieldCollector(
+            NUM_DOCS,
+            new HitsThresholdChecker(TOTAL_HITS_UP_TO),
+            new Sort(sortField),
+            new FieldDoc(Integer.MAX_VALUE, 0.0f, new Object[] { DOC_ID_2 })
+        );
+        Weight weight = mock(Weight.class);
+        hybridTopFieldDocSortCollector.setWeight(weight);
+        LeafCollector leafCollector = hybridTopFieldDocSortCollector.getLeafCollector(leafReaderContext);
+        assertNotNull(leafCollector);
+
+        int[] docIdsForQuery = new int[] { DOC_ID_1, DOC_ID_2, DOC_ID_3, DOC_ID_4 };
+        Arrays.sort(docIdsForQuery);
+        int indexPositionOfDocId2 = Arrays.binarySearch(docIdsForQuery, DOC_ID_2);
+        final List<Float> scores = Stream.generate(() -> random().nextFloat()).limit(NUM_DOCS).collect(Collectors.toList());
+
+        HybridQueryScorer hybridQueryScorer = new HybridQueryScorer(
+            weight,
+            Arrays.asList(scorer(docIdsForQuery, scores, fakeWeight(QueryBuilders.matchAllQuery().toQuery(mockQueryShardContext))))
+        );
+
+        leafCollector.setScorer(hybridQueryScorer);
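+        // collect every doc from the scorer; the paging collector should skip hits at or before the search_after doc id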
+        DocIdSetIterator iterator = hybridQueryScorer.iterator();
+
+        int doc = iterator.nextDoc();
+        while (doc != DocIdSetIterator.NO_MORE_DOCS) {
+            leafCollector.collect(doc);
+            doc = iterator.nextDoc();
+        }
+
+        List<TopFieldDocs> topFieldDocs = hybridTopFieldDocSortCollector.topDocs();
+
+        assertNotNull(topFieldDocs);
+        assertEquals(1, topFieldDocs.size());
+        for (TopFieldDocs topFieldDoc : topFieldDocs) {
+            // assert results for each sub-query: only hits after the search_after doc id must be collected, and they must be
+            // ordered by doc id as per the sort criteria
+            assertEquals(4 - (indexPositionOfDocId2 + 1), topFieldDoc.totalHits.value);
+            ScoreDoc[] scoreDocs = topFieldDoc.scoreDocs;
+            assertNotNull(scoreDocs);
+            assertEquals(4 - (indexPositionOfDocId2 + 1), scoreDocs.length);
+            assertTrue(IntStream.range(0, scoreDocs.length - 1).noneMatch(idx -> scoreDocs[idx].doc > scoreDocs[idx + 1].doc));
+            List<Integer> resultDocIds = Arrays.stream(scoreDocs).map(scoreDoc -> scoreDoc.doc).collect(Collectors.toList());
+            List<Integer> docIdsByQueryList = Arrays.stream(docIdsForQuery).boxed().collect(Collectors.toList());
+            resultDocIds.stream().forEach(val -> assertTrue(docIdsByQueryList.contains(val)));
+        }
+        w.close();
+        reader.close();
+        directory.close();
+    }
+}
diff --git a/src/test/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollectorTests.java b/src/test/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollectorTests.java
index 351ec680c..1fb66d5b7 100644
--- a/src/test/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollectorTests.java
+++ b/src/test/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollectorTests.java
@@ -46,6 +46,7 @@
 import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase;
 
 import lombok.SneakyThrows;
+import org.opensearch.neuralsearch.search.collector.HybridTopScoreDocCollector;
 
 public class HybridTopScoreDocCollectorTests extends OpenSearchQueryTestCase {
diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java
index 40d2ee3f6..de9c6006b 100644
--- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java
+++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java
@@ -5,6 +5,7 @@
 package org.opensearch.neuralsearch.search.query;
 
 import com.carrotsearch.randomizedtesting.RandomizedTest;
+import java.util.Arrays;
 import lombok.SneakyThrows;
 import org.apache.lucene.document.FieldType;
 import org.apache.lucene.document.TextField;
@@ -16,10 +17,13 @@
 import org.apache.lucene.search.BulkScorer;
 import org.apache.lucene.search.Collector;
 import org.apache.lucene.search.CollectorManager;
+import org.apache.lucene.search.FieldDoc;
 import org.apache.lucene.search.IndexSearcher;
 import org.apache.lucene.search.LeafCollector;
 import org.apache.lucene.search.ScoreDoc;
 import org.apache.lucene.search.ScoreMode;
+import org.apache.lucene.search.Sort;
+import org.apache.lucene.search.SortField;
 import org.apache.lucene.search.TotalHits;
 import org.apache.lucene.search.Weight;
 import org.apache.lucene.search.Query;
@@ -37,7 +41,10 @@
 import org.opensearch.neuralsearch.query.HybridQuery;
 import org.opensearch.neuralsearch.query.HybridQueryWeight;
 import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase;
-import org.opensearch.neuralsearch.search.HybridTopScoreDocCollector;
+import org.opensearch.neuralsearch.search.collector.HybridTopScoreDocCollector;
+import org.opensearch.neuralsearch.search.collector.PagingFieldCollector;
+import org.opensearch.neuralsearch.search.collector.SimpleFieldCollector;
+import org.opensearch.search.DocValueFormat;
 import org.opensearch.search.internal.ContextIndexSearcher;
 import org.opensearch.search.internal.SearchContext;
 import org.opensearch.search.query.QuerySearchResult;
@@ -52,6 +59,7 @@
 import static org.mockito.Mockito.when;
 import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.MAGIC_NUMBER_DELIMITER;
 import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.MAGIC_NUMBER_START_STOP;
+import org.opensearch.search.sort.SortAndFormats;
 
 public class HybridCollectorManagerTests extends OpenSearchQueryTestCase {
 
@@ -311,6 +319,171 @@ public void testReduce_whenMatchedDocs_thenSuccessful() {
         directory.close();
     }
 
+    @SneakyThrows
+    public void testNewCollector_whenNotConcurrentSearchAndSortingIsApplied_thenSuccessful() {
+        SearchContext searchContext = mock(SearchContext.class);
+        SortField sortField = new SortField("_doc", SortField.Type.DOC);
+        Sort sort = new Sort(sortField);
+        DocValueFormat docValueFormat[] = new DocValueFormat[] { DocValueFormat.RAW };
+        when(searchContext.sort()).thenReturn(new SortAndFormats(sort, docValueFormat));
+        QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
+        TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME);
+        when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType);
+        TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1);
+        HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)));
+
+        when(searchContext.query()).thenReturn(hybridQuery);
+        ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class);
+        IndexReader indexReader = mock(IndexReader.class);
+        when(indexSearcher.getIndexReader()).thenReturn(indexReader);
+        when(searchContext.searcher()).thenReturn(indexSearcher);
+
+        Map<Class<?>, CollectorManager<? extends Collector, ReduceableSearchResult>> classCollectorManagerMap = new HashMap<>();
+        when(searchContext.queryCollectorManagers()).thenReturn(classCollectorManagerMap);
+        when(searchContext.shouldUseConcurrentSearch()).thenReturn(false);
+
+        CollectorManager hybridCollectorManager = HybridCollectorManager.createHybridCollectorManager(searchContext);
+        assertNotNull(hybridCollectorManager);
+        assertTrue(hybridCollectorManager instanceof HybridCollectorManager.HybridCollectorNonConcurrentManager);
+
+        Collector collector = hybridCollectorManager.newCollector();
+        assertNotNull(collector);
+        assertTrue(collector instanceof SimpleFieldCollector);
+
+        Collector secondCollector = hybridCollectorManager.newCollector();
+        assertSame(collector, secondCollector);
+    }
+
+    @SneakyThrows
+    public void testNewCollector_whenNotConcurrentSearchAndSortingAndSearchAfterAreApplied_thenSuccessful() {
+        SearchContext searchContext = mock(SearchContext.class);
+        SortField sortField = new SortField("_doc", SortField.Type.DOC);
+        Sort sort = new Sort(sortField);
+        DocValueFormat docValueFormat[] = new DocValueFormat[] { DocValueFormat.RAW };
+        FieldDoc after = new FieldDoc(Integer.MAX_VALUE, 0.0f, new Object[] { 1 }, -1);
+        when(searchContext.sort()).thenReturn(new SortAndFormats(sort, docValueFormat));
+        when(searchContext.searchAfter()).thenReturn(after);
+        QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
+        TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME);
+        when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType);
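+        // given both sort and search_after in the context, newCollector() should hand back a paging collector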
+        TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1);
+        HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)));
+
+        when(searchContext.query()).thenReturn(hybridQuery);
+        ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class);
+        IndexReader indexReader = mock(IndexReader.class);
+        when(indexSearcher.getIndexReader()).thenReturn(indexReader);
+        when(searchContext.searcher()).thenReturn(indexSearcher);
+
+        Map<Class<?>, CollectorManager<? extends Collector, ReduceableSearchResult>> classCollectorManagerMap = new HashMap<>();
+        when(searchContext.queryCollectorManagers()).thenReturn(classCollectorManagerMap);
+        when(searchContext.shouldUseConcurrentSearch()).thenReturn(false);
+
+        CollectorManager hybridCollectorManager = HybridCollectorManager.createHybridCollectorManager(searchContext);
+        assertNotNull(hybridCollectorManager);
+        assertTrue(hybridCollectorManager instanceof HybridCollectorManager.HybridCollectorNonConcurrentManager);
+
+        Collector collector = hybridCollectorManager.newCollector();
+        assertNotNull(collector);
+        assertTrue(collector instanceof PagingFieldCollector);
+
+        Collector secondCollector = hybridCollectorManager.newCollector();
+        assertSame(collector, secondCollector);
+    }
+
+    @SneakyThrows
+    public void testReduce_whenMatchedDocsAndSortingIsApplied_thenSuccessful() {
+        SearchContext searchContext = mock(SearchContext.class);
+        QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
+        TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME);
+        when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType);
+
+        HybridQuery hybridQueryWithMatchAll = new HybridQuery(List.of(QueryBuilders.matchAllQuery().toQuery(mockQueryShardContext)));
+        when(searchContext.query()).thenReturn(hybridQueryWithMatchAll);
+        ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class);
+        IndexReader indexReader = mock(IndexReader.class);
+        when(indexReader.numDocs()).thenReturn(3);
+        when(indexSearcher.getIndexReader()).thenReturn(indexReader);
+        when(searchContext.searcher()).thenReturn(indexSearcher);
+        when(searchContext.size()).thenReturn(1);
+
+        Map<Class<?>, CollectorManager<? extends Collector, ReduceableSearchResult>> classCollectorManagerMap = new HashMap<>();
+        when(searchContext.queryCollectorManagers()).thenReturn(classCollectorManagerMap);
+        when(searchContext.shouldUseConcurrentSearch()).thenReturn(false);
+        SortField sortField = new SortField("_doc", SortField.Type.DOC);
+        Sort sort = new Sort(sortField);
+        DocValueFormat docValueFormat[] = new DocValueFormat[] { DocValueFormat.RAW };
+        when(searchContext.sort()).thenReturn(new SortAndFormats(sort, docValueFormat));
+
+        when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType);
+
+        Directory directory = newDirectory();
+        final IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random())));
+        FieldType ft = new FieldType(TextField.TYPE_NOT_STORED);
+        ft.setIndexOptions(random().nextBoolean() ?
IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS); + ft.setOmitNorms(random().nextBoolean()); + ft.freeze(); + + int docId1 = RandomizedTest.randomInt(); + int docId2 = RandomizedTest.randomInt(); + int docId3 = RandomizedTest.randomInt(); + w.addDocument(getDocument(TEXT_FIELD_NAME, docId1, TEST_DOC_TEXT1, ft)); + w.addDocument(getDocument(TEXT_FIELD_NAME, docId2, TEST_DOC_TEXT2, ft)); + w.addDocument(getDocument(TEXT_FIELD_NAME, docId3, TEST_DOC_TEXT3, ft)); + w.flush(); + w.commit(); + + IndexReader reader = DirectoryReader.open(w); + IndexSearcher searcher = newSearcher(reader); + + CollectorManager hybridCollectorManager = HybridCollectorManager.createHybridCollectorManager(searchContext); + SimpleFieldCollector simpleFieldCollector = (SimpleFieldCollector) hybridCollectorManager.newCollector(); + + FieldDoc after = new FieldDoc(Integer.MAX_VALUE, 0.0f, new Object[] { docId1 }, -1); + when(searchContext.searchAfter()).thenReturn(after); + CollectorManager hybridCollectorManager1 = HybridCollectorManager.createHybridCollectorManager(searchContext); + PagingFieldCollector pagingFieldCollector = (PagingFieldCollector) hybridCollectorManager1.newCollector(); + + Weight weight = new HybridQueryWeight(hybridQueryWithMatchAll, searcher, ScoreMode.TOP_SCORES, BoostingQueryBuilder.DEFAULT_BOOST); + simpleFieldCollector.setWeight(weight); + pagingFieldCollector.setWeight(weight); + LeafReaderContext leafReaderContext = searcher.getIndexReader().leaves().get(0); + LeafCollector leafCollector = simpleFieldCollector.getLeafCollector(leafReaderContext); + LeafCollector leafCollector1 = pagingFieldCollector.getLeafCollector(leafReaderContext); + BulkScorer scorer = weight.bulkScorer(leafReaderContext); + scorer.score(leafCollector, leafReaderContext.reader().getLiveDocs()); + leafCollector.finish(); + BulkScorer scorer1 = weight.bulkScorer(leafReaderContext); + scorer1.score(leafCollector1, leafReaderContext.reader().getLiveDocs()); + leafCollector1.finish(); + + Object results = hybridCollectorManager.reduce(List.of()); + Object results1 = hybridCollectorManager1.reduce(List.of()); + + assertNotNull(results); + assertNotNull(results1); + ReduceableSearchResult reduceableSearchResult = ((ReduceableSearchResult) results); + QuerySearchResult querySearchResult = new QuerySearchResult(); + reduceableSearchResult.reduce(querySearchResult); + TopDocsAndMaxScore topDocsAndMaxScore = querySearchResult.topDocs(); + + assertNotNull(topDocsAndMaxScore); + assertEquals(3, topDocsAndMaxScore.topDocs.totalHits.value); + assertEquals(TotalHits.Relation.EQUAL_TO, topDocsAndMaxScore.topDocs.totalHits.relation); + float maxScore = topDocsAndMaxScore.maxScore; + assertTrue(maxScore > 0); + ScoreDoc[] scoreDocs = topDocsAndMaxScore.topDocs.scoreDocs; + assertEquals(4, scoreDocs.length); + assertEquals(MAGIC_NUMBER_START_STOP, scoreDocs[0].score, DELTA_FOR_ASSERTION); + assertEquals(MAGIC_NUMBER_DELIMITER, scoreDocs[1].score, DELTA_FOR_ASSERTION); + assertEquals(maxScore, scoreDocs[2].score, DELTA_FOR_ASSERTION); + assertEquals(MAGIC_NUMBER_START_STOP, scoreDocs[3].score, DELTA_FOR_ASSERTION); + + w.close(); + reader.close(); + directory.close(); + } + @SneakyThrows public void testReduceWithConcurrentSegmentSearch_whenMultipleCollectorsMatchedDocs_thenSuccessful() { SearchContext searchContext = mock(SearchContext.class); @@ -436,4 +609,129 @@ public void testReduceWithConcurrentSegmentSearch_whenMultipleCollectorsMatchedD reader2.close(); directory2.close(); } + + @SneakyThrows + public void 
testReduceWithConcurrentSegmentSearch_whenMultipleCollectorsMatchedDocsWithSort_thenSuccessful() {
+        SearchContext searchContext = mock(SearchContext.class);
+        QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
+        TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME);
+        when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType);
+
+        HybridQuery hybridQueryWithTerm = new HybridQuery(List.of(QueryBuilders.matchAllQuery().toQuery(mockQueryShardContext)));
+        when(searchContext.query()).thenReturn(hybridQueryWithTerm);
+        ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class);
+        IndexReader indexReader = mock(IndexReader.class);
+        when(indexReader.numDocs()).thenReturn(2);
+        when(indexSearcher.getIndexReader()).thenReturn(indexReader);
+        when(searchContext.searcher()).thenReturn(indexSearcher);
+        when(searchContext.size()).thenReturn(1);
+
+        DocValueFormat docValueFormat[] = new DocValueFormat[] { DocValueFormat.RAW };
+        SortField sortField = new SortField("id", SortField.Type.DOC);
+        Sort sort = new Sort(sortField);
+        SortAndFormats sortAndFormats = new SortAndFormats(sort, docValueFormat);
+        when(searchContext.sort()).thenReturn(sortAndFormats);
+
+        Map<Class<?>, CollectorManager<? extends Collector, ReduceableSearchResult>> classCollectorManagerMap = new HashMap<>();
+        when(searchContext.queryCollectorManagers()).thenReturn(classCollectorManagerMap);
+        when(searchContext.shouldUseConcurrentSearch()).thenReturn(true);
+
+        when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType);
+
+        Directory directory = newDirectory();
+        final IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random())));
+        FieldType ft = new FieldType(TextField.TYPE_NOT_STORED);
+        ft.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS);
+        ft.setOmitNorms(random().nextBoolean());
+        ft.freeze();
+
+        int docId1 = RandomizedTest.randomInt();
+        int docId2 = RandomizedTest.randomInt();
+        int docId3 = RandomizedTest.randomInt();
+        int[] docIds = new int[] { docId1, docId2, docId3 };
+        Arrays.sort(docIds);
+
+        w.addDocument(getDocument(TEXT_FIELD_NAME, docIds[0], TEST_DOC_TEXT1, ft));
+        w.addDocument(getDocument(TEXT_FIELD_NAME, docIds[1], TEST_DOC_TEXT3, ft));
+        w.flush();
+        w.commit();
+
+        SearchContext searchContext2 = mock(SearchContext.class);
+
+        ContextIndexSearcher indexSearcher2 = mock(ContextIndexSearcher.class);
+        IndexReader indexReader2 = mock(IndexReader.class);
+        when(indexReader2.numDocs()).thenReturn(1);
+        when(indexSearcher2.getIndexReader()).thenReturn(indexReader);
+        when(searchContext2.searcher()).thenReturn(indexSearcher2);
+        when(searchContext2.size()).thenReturn(1);
+
+        when(searchContext2.queryCollectorManagers()).thenReturn(new HashMap<>());
+        when(searchContext2.shouldUseConcurrentSearch()).thenReturn(true);
+
+        when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType);
+
+        Directory directory2 = newDirectory();
+        final IndexWriter w2 = new IndexWriter(directory2, newIndexWriterConfig(new MockAnalyzer(random())));
+        FieldType ft2 = new FieldType(TextField.TYPE_NOT_STORED);
+        ft2.setIndexOptions(random().nextBoolean() ?
IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS); + ft2.setOmitNorms(random().nextBoolean()); + ft2.freeze(); + + w2.addDocument(getDocument(TEXT_FIELD_NAME, docIds[2], TEST_DOC_TEXT2, ft)); + w2.flush(); + w2.commit(); + + IndexReader reader = DirectoryReader.open(w); + IndexSearcher searcher = newSearcher(reader); + IndexReader reader2 = DirectoryReader.open(w2); + IndexSearcher searcher2 = newSearcher(reader2); + + CollectorManager hybridCollectorManager = HybridCollectorManager.createHybridCollectorManager(searchContext); + SimpleFieldCollector collector1 = (SimpleFieldCollector) hybridCollectorManager.newCollector(); + SimpleFieldCollector collector2 = (SimpleFieldCollector) hybridCollectorManager.newCollector(); + + Weight weight1 = new HybridQueryWeight(hybridQueryWithTerm, searcher, ScoreMode.TOP_SCORES, BoostingQueryBuilder.DEFAULT_BOOST); + Weight weight2 = new HybridQueryWeight(hybridQueryWithTerm, searcher2, ScoreMode.TOP_SCORES, BoostingQueryBuilder.DEFAULT_BOOST); + collector1.setWeight(weight1); + collector2.setWeight(weight2); + LeafReaderContext leafReaderContext = searcher.getIndexReader().leaves().get(0); + LeafCollector leafCollector1 = collector1.getLeafCollector(leafReaderContext); + + LeafReaderContext leafReaderContext2 = searcher2.getIndexReader().leaves().get(0); + LeafCollector leafCollector2 = collector2.getLeafCollector(leafReaderContext2); + BulkScorer scorer = weight1.bulkScorer(leafReaderContext); + scorer.score(leafCollector1, leafReaderContext.reader().getLiveDocs()); + leafCollector1.finish(); + BulkScorer scorer2 = weight2.bulkScorer(leafReaderContext2); + scorer2.score(leafCollector2, leafReaderContext2.reader().getLiveDocs()); + leafCollector2.finish(); + + Object results = hybridCollectorManager.reduce(List.of(collector1, collector2)); + + assertNotNull(results); + ReduceableSearchResult reduceableSearchResult = ((ReduceableSearchResult) results); + QuerySearchResult querySearchResult = new QuerySearchResult(); + reduceableSearchResult.reduce(querySearchResult); + TopDocsAndMaxScore topDocsAndMaxScore = querySearchResult.topDocs(); + + assertNotNull(topDocsAndMaxScore); + assertEquals(3, topDocsAndMaxScore.topDocs.totalHits.value); + assertEquals(TotalHits.Relation.EQUAL_TO, topDocsAndMaxScore.topDocs.totalHits.relation); + float maxScore = topDocsAndMaxScore.maxScore; + assertTrue(maxScore > 0); + FieldDoc[] fieldDocs = (FieldDoc[]) topDocsAndMaxScore.topDocs.scoreDocs; + assertEquals(5, fieldDocs.length); + assertEquals(1, fieldDocs[0].fields[0]); + assertEquals(1, fieldDocs[1].fields[0]); + assertEquals(fieldDocs[2].doc, fieldDocs[2].fields[0]); + assertEquals(fieldDocs[3].doc, fieldDocs[3].fields[0]); + assertEquals(1, fieldDocs[4].fields[0]); + + w.close(); + reader.close(); + directory.close(); + w2.close(); + reader2.close(); + directory2.close(); + } } diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryScoreDocsMergerTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryScoreDocsMergerTests.java index 2147578c9..196014220 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryScoreDocsMergerTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryScoreDocsMergerTests.java @@ -4,14 +4,20 @@ */ package org.opensearch.neuralsearch.search.query; +import org.apache.lucene.search.FieldDoc; import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.Sort; +import org.apache.lucene.search.SortField; import 
org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; -import static org.opensearch.neuralsearch.search.query.TopDocsMerger.SCORE_DOC_BY_SCORE_COMPARATOR; +import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createFieldDocDelimiterElementForHybridSearchResults; +import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createFieldDocStartStopElementForHybridSearchResults; import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createStartStopElementForHybridSearchResults; import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createDelimiterElementForHybridSearchResults; import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.MAGIC_NUMBER_START_STOP; import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.MAGIC_NUMBER_DELIMITER; +import org.opensearch.search.DocValueFormat; +import org.opensearch.search.sort.SortAndFormats; public class HybridQueryScoreDocsMergerTests extends OpenSearchQueryTestCase { @@ -19,7 +25,7 @@ public class HybridQueryScoreDocsMergerTests extends OpenSearchQueryTestCase { public void testIncorrectInput_whenScoreDocsAreNullOrNotEnoughElements_thenFail() { HybridQueryScoreDocsMerger scoreDocsMerger = new HybridQueryScoreDocsMerger<>(); - + TopDocsMerger topDocsMerger = new TopDocsMerger(null); ScoreDoc[] scores = new ScoreDoc[] { createStartStopElementForHybridSearchResults(2), createDelimiterElementForHybridSearchResults(2), @@ -28,31 +34,33 @@ public void testIncorrectInput_whenScoreDocsAreNullOrNotEnoughElements_thenFail( NullPointerException exception = assertThrows( NullPointerException.class, - () -> scoreDocsMerger.merge(scores, null, SCORE_DOC_BY_SCORE_COMPARATOR) + () -> scoreDocsMerger.merge(scores, null, topDocsMerger.SCORE_DOC_BY_SCORE_COMPARATOR, false) ); assertEquals("score docs cannot be null", exception.getMessage()); - exception = assertThrows(NullPointerException.class, () -> scoreDocsMerger.merge(scores, null, SCORE_DOC_BY_SCORE_COMPARATOR)); + exception = assertThrows( + NullPointerException.class, + () -> scoreDocsMerger.merge(scores, null, topDocsMerger.SCORE_DOC_BY_SCORE_COMPARATOR, false) + ); assertEquals("score docs cannot be null", exception.getMessage()); ScoreDoc[] lessElementsScoreDocs = new ScoreDoc[] { createStartStopElementForHybridSearchResults(2), new ScoreDoc(1, 0.7f) }; IllegalArgumentException notEnoughException = assertThrows( IllegalArgumentException.class, - () -> scoreDocsMerger.merge(lessElementsScoreDocs, scores, SCORE_DOC_BY_SCORE_COMPARATOR) + () -> scoreDocsMerger.merge(lessElementsScoreDocs, scores, topDocsMerger.SCORE_DOC_BY_SCORE_COMPARATOR, false) ); assertEquals("cannot merge top docs because it does not have enough elements", notEnoughException.getMessage()); notEnoughException = assertThrows( IllegalArgumentException.class, - () -> scoreDocsMerger.merge(scores, lessElementsScoreDocs, SCORE_DOC_BY_SCORE_COMPARATOR) + () -> scoreDocsMerger.merge(scores, lessElementsScoreDocs, topDocsMerger.SCORE_DOC_BY_SCORE_COMPARATOR, false) ); assertEquals("cannot merge top docs because it does not have enough elements", notEnoughException.getMessage()); } public void testMergeScoreDocs_whenBothTopDocsHasHits_thenSuccessful() { HybridQueryScoreDocsMerger scoreDocsMerger = new HybridQueryScoreDocsMerger<>(); - ScoreDoc[] scoreDocsOriginal = new ScoreDoc[] { createStartStopElementForHybridSearchResults(0), createDelimiterElementForHybridSearchResults(0), @@ -71,7 
+79,13 @@ public void testMergeScoreDocs_whenBothTopDocsHasHits_thenSuccessful() { new ScoreDoc(4, 0.6f), createStartStopElementForHybridSearchResults(2) }; - ScoreDoc[] mergedScoreDocs = scoreDocsMerger.merge(scoreDocsOriginal, scoreDocsNew, SCORE_DOC_BY_SCORE_COMPARATOR); + TopDocsMerger topDocsMerger = new TopDocsMerger(null); + ScoreDoc[] mergedScoreDocs = scoreDocsMerger.merge( + scoreDocsOriginal, + scoreDocsNew, + topDocsMerger.SCORE_DOC_BY_SCORE_COMPARATOR, + false + ); assertNotNull(mergedScoreDocs); assertEquals(10, mergedScoreDocs.length); @@ -91,7 +105,7 @@ public void testMergeScoreDocs_whenBothTopDocsHasHits_thenSuccessful() { public void testMergeScoreDocs_whenOneTopDocsHasHitsAndOtherIsEmpty_thenSuccessful() { HybridQueryScoreDocsMerger scoreDocsMerger = new HybridQueryScoreDocsMerger<>(); - + TopDocsMerger topDocsMerger = new TopDocsMerger(null); ScoreDoc[] scoreDocsOriginal = new ScoreDoc[] { createStartStopElementForHybridSearchResults(0), createDelimiterElementForHybridSearchResults(0), @@ -107,7 +121,12 @@ public void testMergeScoreDocs_whenOneTopDocsHasHitsAndOtherIsEmpty_thenSuccessf new ScoreDoc(4, 0.6f), createStartStopElementForHybridSearchResults(2) }; - ScoreDoc[] mergedScoreDocs = scoreDocsMerger.merge(scoreDocsOriginal, scoreDocsNew, SCORE_DOC_BY_SCORE_COMPARATOR); + ScoreDoc[] mergedScoreDocs = scoreDocsMerger.merge( + scoreDocsOriginal, + scoreDocsNew, + topDocsMerger.SCORE_DOC_BY_SCORE_COMPARATOR, + false + ); assertNotNull(mergedScoreDocs); assertEquals(8, mergedScoreDocs.length); @@ -124,7 +143,7 @@ public void testMergeScoreDocs_whenOneTopDocsHasHitsAndOtherIsEmpty_thenSuccessf public void testMergeScoreDocs_whenBothTopDocsHasNoHits_thenSuccessful() { HybridQueryScoreDocsMerger scoreDocsMerger = new HybridQueryScoreDocsMerger<>(); - + TopDocsMerger topDocsMerger = new TopDocsMerger(null); ScoreDoc[] scoreDocsOriginal = new ScoreDoc[] { createStartStopElementForHybridSearchResults(0), createDelimiterElementForHybridSearchResults(0), @@ -136,7 +155,12 @@ public void testMergeScoreDocs_whenBothTopDocsHasNoHits_thenSuccessful() { createDelimiterElementForHybridSearchResults(2), createStartStopElementForHybridSearchResults(2) }; - ScoreDoc[] mergedScoreDocs = scoreDocsMerger.merge(scoreDocsOriginal, scoreDocsNew, SCORE_DOC_BY_SCORE_COMPARATOR); + ScoreDoc[] mergedScoreDocs = scoreDocsMerger.merge( + scoreDocsOriginal, + scoreDocsNew, + topDocsMerger.SCORE_DOC_BY_SCORE_COMPARATOR, + false + ); assertNotNull(mergedScoreDocs); assertEquals(4, mergedScoreDocs.length); @@ -147,8 +171,183 @@ public void testMergeScoreDocs_whenBothTopDocsHasNoHits_thenSuccessful() { assertEquals(MAGIC_NUMBER_START_STOP, mergedScoreDocs[3].score, 0); } + public void testIncorrectInput_whenFieldDocsAreNullOrNotEnoughElements_thenFail() { + HybridQueryScoreDocsMerger fieldDocsMerger = new HybridQueryScoreDocsMerger<>(); + DocValueFormat docValueFormat[] = new DocValueFormat[] { DocValueFormat.RAW }; + SortField sortField = new SortField("stock", SortField.Type.INT, true); + Sort sort = new Sort(sortField); + SortAndFormats sortAndFormats = new SortAndFormats(sort, docValueFormat); + TopDocsMerger topDocsMerger = new TopDocsMerger(sortAndFormats); + + FieldDoc[] scores = new FieldDoc[] { + createFieldDocStartStopElementForHybridSearchResults(2, new Object[] { 1 }), + createFieldDocDelimiterElementForHybridSearchResults(2, new Object[] { 1 }), + new FieldDoc(1, 0.7f, new Object[] { 100 }), + createFieldDocStartStopElementForHybridSearchResults(2, new Object[] { 1 }) }; + + 
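+        // merge must fail fast when the array of new score docs is null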
+
+    public void testIncorrectInput_whenFieldDocsAreNullOrNotEnoughElements_thenFail() {
+        HybridQueryScoreDocsMerger<FieldDoc> fieldDocsMerger = new HybridQueryScoreDocsMerger<>();
+        DocValueFormat[] docValueFormat = new DocValueFormat[] { DocValueFormat.RAW };
+        SortField sortField = new SortField("stock", SortField.Type.INT, true);
+        Sort sort = new Sort(sortField);
+        SortAndFormats sortAndFormats = new SortAndFormats(sort, docValueFormat);
+        TopDocsMerger topDocsMerger = new TopDocsMerger(sortAndFormats);
+
+        FieldDoc[] scores = new FieldDoc[] {
+            createFieldDocStartStopElementForHybridSearchResults(2, new Object[] { 1 }),
+            createFieldDocDelimiterElementForHybridSearchResults(2, new Object[] { 1 }),
+            new FieldDoc(1, 0.7f, new Object[] { 100 }),
+            createFieldDocStartStopElementForHybridSearchResults(2, new Object[] { 1 }) };
+
+        NullPointerException exception = assertThrows(
+            NullPointerException.class,
+            () -> fieldDocsMerger.merge(scores, null, topDocsMerger.FIELD_DOC_BY_SORT_CRITERIA_COMPARATOR, true)
+        );
+        assertEquals("score docs cannot be null", exception.getMessage());
+
+        exception = assertThrows(
+            NullPointerException.class,
+            () -> fieldDocsMerger.merge(scores, null, topDocsMerger.FIELD_DOC_BY_SORT_CRITERIA_COMPARATOR, true)
+        );
+        assertEquals("score docs cannot be null", exception.getMessage());
+
+        FieldDoc[] lessElementsScoreDocs = new FieldDoc[] {
+            createFieldDocStartStopElementForHybridSearchResults(0, new Object[] { 1 }),
+            new FieldDoc(1, 0.7f, new Object[] { 100 }) };
+
+        IllegalArgumentException notEnoughException = assertThrows(
+            IllegalArgumentException.class,
+            () -> fieldDocsMerger.merge(lessElementsScoreDocs, scores, topDocsMerger.FIELD_DOC_BY_SORT_CRITERIA_COMPARATOR, true)
+        );
+        assertEquals("cannot merge top docs because it does not have enough elements", notEnoughException.getMessage());
+
+        notEnoughException = assertThrows(
+            IllegalArgumentException.class,
+            () -> fieldDocsMerger.merge(scores, lessElementsScoreDocs, topDocsMerger.FIELD_DOC_BY_SORT_CRITERIA_COMPARATOR, true)
+        );
+        assertEquals("cannot merge top docs because it does not have enough elements", notEnoughException.getMessage());
+    }
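All the field-doc fixtures in this file follow the hybrid-search shard-result layout from `HybridSearchResultFormatUtil`: a start marker, then one delimiter per sub-query with that sub-query's hits after it, then a closing marker. Annotated with the role of each element (same static imports as the tests; doc ids, scores, and sort values are the tests' own sample data):

```java
// Layout of one shard's hybrid result as these tests build it (two sub-queries).
FieldDoc[] shardResult = new FieldDoc[] {
    createFieldDocStartStopElementForHybridSearchResults(0, new Object[] { 1 }),   // start marker
    createFieldDocDelimiterElementForHybridSearchResults(0, new Object[] { 1 }),   // sub-query 1 begins
    new FieldDoc(0, 0.5f, new Object[] { 100 }),                                   // hit: doc 0, sort value 100
    createFieldDocDelimiterElementForHybridSearchResults(0, new Object[] { 1 }),   // sub-query 2 begins
    new FieldDoc(4, 0.6f, new Object[] { 40 }),                                    // hit: doc 4, sort value 40
    createFieldDocStartStopElementForHybridSearchResults(0, new Object[] { 1 }) }; // stop marker
```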
+
+    public void testMergeFieldDocs_whenBothTopDocsHasHits_thenSuccessful() {
+        HybridQueryScoreDocsMerger<FieldDoc> fieldDocsMerger = new HybridQueryScoreDocsMerger<>();
+        FieldDoc[] fieldDocsOriginal = new FieldDoc[] {
+            createFieldDocStartStopElementForHybridSearchResults(0, new Object[] { 1 }),
+            createFieldDocDelimiterElementForHybridSearchResults(0, new Object[] { 1 }),
+            new FieldDoc(0, 0.5f, new Object[] { 100 }),
+            new FieldDoc(2, 0.3f, new Object[] { 80 }),
+            createFieldDocDelimiterElementForHybridSearchResults(0, new Object[] { 1 }),
+            createFieldDocStartStopElementForHybridSearchResults(0, new Object[] { 1 }) };
+
+        FieldDoc[] fieldDocsNew = new FieldDoc[] {
+            createFieldDocStartStopElementForHybridSearchResults(2, new Object[] { 1 }),
+            createFieldDocDelimiterElementForHybridSearchResults(2, new Object[] { 1 }),
+            new FieldDoc(1, 0.7f, new Object[] { 10 }),
+            new FieldDoc(4, 0.3f, new Object[] { 5 }),
+            new FieldDoc(5, 0.05f, new Object[] { 2 }),
+            createFieldDocDelimiterElementForHybridSearchResults(2, new Object[] { 1 }),
+            new FieldDoc(4, 0.6f, new Object[] { 5 }),
+            createFieldDocStartStopElementForHybridSearchResults(2, new Object[] { 1 }) };
+
+        DocValueFormat[] docValueFormat = new DocValueFormat[] { DocValueFormat.RAW };
+        SortField sortField = new SortField("stock", SortField.Type.INT, true);
+        Sort sort = new Sort(sortField);
+        SortAndFormats sortAndFormats = new SortAndFormats(sort, docValueFormat);
+        TopDocsMerger topDocsMerger = new TopDocsMerger(sortAndFormats);
+
+        FieldDoc[] mergedFieldDocs = fieldDocsMerger.merge(
+            fieldDocsOriginal,
+            fieldDocsNew,
+            topDocsMerger.FIELD_DOC_BY_SORT_CRITERIA_COMPARATOR,
+            true
+        );
+
+        assertNotNull(mergedFieldDocs);
+        assertEquals(10, mergedFieldDocs.length);
+
+        // check format, all elements one by one
+        assertEquals(1, mergedFieldDocs[0].fields[0]);
+        assertEquals(1, mergedFieldDocs[1].fields[0]);
+        assertFieldDoc(mergedFieldDocs[2], 0, 100);
+        assertFieldDoc(mergedFieldDocs[3], 2, 80);
+        assertFieldDoc(mergedFieldDocs[4], 1, 10);
+        assertFieldDoc(mergedFieldDocs[5], 4, 5);
+        assertFieldDoc(mergedFieldDocs[6], 5, 2);
+        assertEquals(1, mergedFieldDocs[7].fields[0]);
+        assertFieldDoc(mergedFieldDocs[8], 4, 5);
+        assertEquals(1, mergedFieldDocs[9].fields[0]);
+    }
+
+    public void testMergeFieldDocs_whenOneTopDocsHasHitsAndOtherIsEmpty_thenSuccessful() {
+        HybridQueryScoreDocsMerger<FieldDoc> fieldDocsMerger = new HybridQueryScoreDocsMerger<>();
+        DocValueFormat[] docValueFormat = new DocValueFormat[] { DocValueFormat.RAW };
+        SortField sortField = new SortField("stock", SortField.Type.INT, true);
+        Sort sort = new Sort(sortField);
+        SortAndFormats sortAndFormats = new SortAndFormats(sort, docValueFormat);
+        TopDocsMerger topDocsMerger = new TopDocsMerger(sortAndFormats);
+
+        FieldDoc[] fieldDocsOriginal = new FieldDoc[] {
+            createFieldDocStartStopElementForHybridSearchResults(0, new Object[] { 1 }),
+            createFieldDocDelimiterElementForHybridSearchResults(0, new Object[] { 1 }),
+            createFieldDocDelimiterElementForHybridSearchResults(0, new Object[] { 1 }),
+            createFieldDocStartStopElementForHybridSearchResults(0, new Object[] { 1 }) };
+        FieldDoc[] fieldDocsNew = new FieldDoc[] {
+            createFieldDocStartStopElementForHybridSearchResults(2, new Object[] { 1 }),
+            createFieldDocDelimiterElementForHybridSearchResults(2, new Object[] { 1 }),
+            new FieldDoc(1, 0.7f, new Object[] { 100 }),
+            new FieldDoc(4, 0.3f, new Object[] { 80 }),
+            new FieldDoc(5, 0.05f, new Object[] { 20 }),
+            createFieldDocDelimiterElementForHybridSearchResults(2, new Object[] { 1 }),
+            new FieldDoc(4, 0.6f, new Object[] { 50 }),
+            createFieldDocStartStopElementForHybridSearchResults(2, new Object[] { 1 }) };
+
+        FieldDoc[] mergedFieldDocs = fieldDocsMerger.merge(
+            fieldDocsOriginal,
+            fieldDocsNew,
+            topDocsMerger.FIELD_DOC_BY_SORT_CRITERIA_COMPARATOR,
+            true
+        );
+
+        assertNotNull(mergedFieldDocs);
+        assertEquals(8, mergedFieldDocs.length);
+
+        assertEquals(1, mergedFieldDocs[0].fields[0]);
+        assertEquals(1, mergedFieldDocs[1].fields[0]);
+        assertFieldDoc(mergedFieldDocs[2], 1, 100);
+        assertFieldDoc(mergedFieldDocs[3], 4, 80);
+        assertFieldDoc(mergedFieldDocs[4], 5, 20);
+        assertEquals(1, mergedFieldDocs[5].fields[0]);
+        assertFieldDoc(mergedFieldDocs[6], 4, 50);
+        assertEquals(1, mergedFieldDocs[7].fields[0]);
+    }
+
+    public void testMergeFieldDocs_whenBothTopDocsHasNoHits_thenSuccessful() {
+        HybridQueryScoreDocsMerger<FieldDoc> fieldDocsMerger = new HybridQueryScoreDocsMerger<>();
+        DocValueFormat[] docValueFormat = new DocValueFormat[] { DocValueFormat.RAW };
+        SortField sortField = new SortField("stock", SortField.Type.INT, true);
+        Sort sort = new Sort(sortField);
+        SortAndFormats sortAndFormats = new SortAndFormats(sort, docValueFormat);
+        TopDocsMerger topDocsMerger = new TopDocsMerger(sortAndFormats);
+
+        FieldDoc[] fieldDocsOriginal = new FieldDoc[] {
+            createFieldDocStartStopElementForHybridSearchResults(0, new Object[] { 1 }),
+            createFieldDocDelimiterElementForHybridSearchResults(0, new Object[] { 1 }),
+            createFieldDocDelimiterElementForHybridSearchResults(0, new Object[] { 1 }),
+            createFieldDocStartStopElementForHybridSearchResults(0, new Object[] { 1 }) };
+        FieldDoc[] fieldDocsNew = new FieldDoc[] {
+            createFieldDocStartStopElementForHybridSearchResults(2, new Object[] { 1 }),
+            createFieldDocDelimiterElementForHybridSearchResults(2, new Object[] { 1 }),
+            createFieldDocDelimiterElementForHybridSearchResults(2, new Object[] { 1 }),
+            createFieldDocStartStopElementForHybridSearchResults(2, new Object[] { 1 }) };
+
+        FieldDoc[] mergedFieldDocs = fieldDocsMerger.merge(
+            fieldDocsOriginal,
+            fieldDocsNew,
+            topDocsMerger.FIELD_DOC_BY_SORT_CRITERIA_COMPARATOR,
+            true
+        );
+
+        assertNotNull(mergedFieldDocs);
+        assertEquals(4, mergedFieldDocs.length);
+        // check format, all elements one by one
+        assertEquals(1, mergedFieldDocs[0].fields[0]);
+        assertEquals(1, mergedFieldDocs[1].fields[0]);
+        assertEquals(1, mergedFieldDocs[2].fields[0]);
+        assertEquals(1, mergedFieldDocs[3].fields[0]);
+    }
+
     private void assertScoreDoc(ScoreDoc scoreDoc, int expectedDocId, float expectedScore) {
         assertEquals(expectedDocId, scoreDoc.doc);
         assertEquals(expectedScore, scoreDoc.score, DELTA_FOR_ASSERTION);
     }
+
+    private void assertFieldDoc(FieldDoc fieldDoc, int expectedDocId, int expectedSortValue) {
+        assertEquals(expectedDocId, fieldDoc.doc);
+        assertEquals(expectedSortValue, fieldDoc.fields[0]);
+    }
 }
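`TopDocsMergerTests` below reflects the constructor change: `TopDocsMerger` is now built from a `SortAndFormats` rather than from a `HybridQueryScoreDocsMerger`. Passing `null` keeps score-based (relevance) merging; passing sort criteria switches to field-based merging. The two setups exactly as the updated tests use them:

```java
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.SortField;
import org.opensearch.search.DocValueFormat;
import org.opensearch.search.sort.SortAndFormats;

// Score-based merging, used by the unchanged relevance tests:
TopDocsMerger scoreMerger = new TopDocsMerger(null);

// Sort-based merging over one descending integer field, used by the new tests:
DocValueFormat[] formats = new DocValueFormat[] { DocValueFormat.RAW };
Sort sort = new Sort(new SortField("stock", SortField.Type.INT, true));
TopDocsMerger sortMerger = new TopDocsMerger(new SortAndFormats(sort, formats));
```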
diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/TopDocsMergerTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/TopDocsMergerTests.java
index 5a99f3f3a..d10ca0668 100644
--- a/src/test/java/org/opensearch/neuralsearch/search/query/TopDocsMergerTests.java
+++ b/src/test/java/org/opensearch/neuralsearch/search/query/TopDocsMergerTests.java
@@ -5,8 +5,12 @@
 package org.opensearch.neuralsearch.search.query;

 import lombok.SneakyThrows;
+import org.apache.lucene.search.FieldDoc;
 import org.apache.lucene.search.ScoreDoc;
+import org.apache.lucene.search.Sort;
+import org.apache.lucene.search.SortField;
 import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.search.TopFieldDocs;
 import org.apache.lucene.search.TotalHits;
 import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
 import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase;
@@ -15,6 +19,11 @@
 import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createDelimiterElementForHybridSearchResults;
 import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.MAGIC_NUMBER_START_STOP;
 import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.MAGIC_NUMBER_DELIMITER;
+import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createFieldDocStartStopElementForHybridSearchResults;
+import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createFieldDocDelimiterElementForHybridSearchResults;
+
+import org.opensearch.search.DocValueFormat;
+import org.opensearch.search.sort.SortAndFormats;

 public class TopDocsMergerTests extends OpenSearchQueryTestCase {

@@ -22,8 +31,7 @@ public class TopDocsMergerTests extends OpenSearchQueryTestCase {

     @SneakyThrows
     public void testMergeScoreDocs_whenBothTopDocsHasHits_thenSuccessful() {
-        HybridQueryScoreDocsMerger<ScoreDoc> scoreDocsMerger = new HybridQueryScoreDocsMerger<>();
-        TopDocsMerger topDocsMerger = new TopDocsMerger(scoreDocsMerger);
+        TopDocsMerger topDocsMerger = new TopDocsMerger(null);

         TopDocs topDocsOriginal = new TopDocs(
             new TotalHits(2, TotalHits.Relation.EQUAL_TO),
@@ -78,8 +86,7 @@ public void testMergeScoreDocs_whenBothTopDocsHasHits_thenSuccessful() {

     @SneakyThrows
     public void testMergeScoreDocs_whenOneTopDocsHasHitsAndOtherIsEmpty_thenSuccessful() {
-        HybridQueryScoreDocsMerger<ScoreDoc> scoreDocsMerger = new HybridQueryScoreDocsMerger<>();
-        TopDocsMerger topDocsMerger = new TopDocsMerger(scoreDocsMerger);
+        TopDocsMerger topDocsMerger = new TopDocsMerger(null);

         TopDocs topDocsOriginal = new TopDocs(
             new TotalHits(0, TotalHits.Relation.EQUAL_TO),
@@ -130,8 +137,7 @@ public void testMergeScoreDocs_whenOneTopDocsHasHitsAndOtherIsEmpty_thenSuccessf

     @SneakyThrows
     public void testMergeScoreDocs_whenBothTopDocsHasNoHits_thenSuccessful() {
-        HybridQueryScoreDocsMerger<ScoreDoc> scoreDocsMerger = new HybridQueryScoreDocsMerger<>();
-        TopDocsMerger topDocsMerger = new TopDocsMerger(scoreDocsMerger);
+        TopDocsMerger topDocsMerger = new TopDocsMerger(null);

         TopDocs topDocsOriginal = new TopDocs(
             new TotalHits(0, TotalHits.Relation.EQUAL_TO),
@@ -172,8 +178,7 @@ public void testMergeScoreDocs_whenBothTopDocsHasNoHits_thenSuccessful() {

     @SneakyThrows
     public void testThreeSequentialMerges_whenAllTopDocsHasHits_thenSuccessful() {
-        HybridQueryScoreDocsMerger<ScoreDoc> scoreDocsMerger = new HybridQueryScoreDocsMerger<>();
-        TopDocsMerger topDocsMerger = new TopDocsMerger(scoreDocsMerger);
+        TopDocsMerger topDocsMerger = new TopDocsMerger(null);

         TopDocs topDocsOriginal = new TopDocs(
             new TotalHits(2, TotalHits.Relation.EQUAL_TO),
@@ -248,8 +253,260 @@ public void testThreeSequentialMerges_whenAllTopDocsHasHits_thenSuccessful() {
         assertEquals(MAGIC_NUMBER_START_STOP, scoreDocs[12].score, 0);
     }

+    @SneakyThrows
+    public void testMergeFieldDocs_whenBothTopDocsHasHits_thenSuccessful() {
+        DocValueFormat[] docValueFormat = new DocValueFormat[] { DocValueFormat.RAW };
+        SortField sortField = new SortField("stock", SortField.Type.INT, true);
+        Sort sort = new Sort(sortField);
+        SortAndFormats sortAndFormats = new SortAndFormats(sort, docValueFormat);
+        TopDocsMerger topDocsMerger = new TopDocsMerger(sortAndFormats);
+
+        TopDocs topDocsOriginal = new TopFieldDocs(
+            new TotalHits(2, TotalHits.Relation.EQUAL_TO),
+            new FieldDoc[] {
+                createFieldDocStartStopElementForHybridSearchResults(0, new Object[] { 1 }),
+                createFieldDocDelimiterElementForHybridSearchResults(0, new Object[] { 1 }),
+                new FieldDoc(0, 0.5f, new Object[] { 100 }),
+                new FieldDoc(2, 0.3f, new Object[] { 80 }),
+                createFieldDocDelimiterElementForHybridSearchResults(0, new Object[] { 1 }),
+                createFieldDocStartStopElementForHybridSearchResults(0, new Object[] { 1 }) },
+            sortAndFormats.sort.getSort()
+        );
+        TopDocsAndMaxScore topDocsAndMaxScoreOriginal = new TopDocsAndMaxScore(topDocsOriginal, 0.5f);
+        TopDocs topDocsNew = new TopFieldDocs(
+            new TotalHits(4, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO),
+            new FieldDoc[] {
+                createFieldDocStartStopElementForHybridSearchResults(1, new Object[] { 1 }),
+                createFieldDocDelimiterElementForHybridSearchResults(1, new Object[] { 1 }),
+                new FieldDoc(1, 0.7f, new Object[] { 70 }),
+                new FieldDoc(4, 0.3f, new Object[] { 60 }),
+                new FieldDoc(5, 0.05f, new Object[] { 30 }),
+                createFieldDocDelimiterElementForHybridSearchResults(1, new Object[] { 1 }),
+                new FieldDoc(4, 0.6f, new Object[] { 40 }),
+                createFieldDocStartStopElementForHybridSearchResults(1, new Object[] { 1 }) },
+            sortAndFormats.sort.getSort()
+        );
+        TopDocsAndMaxScore topDocsAndMaxScoreNew = new TopDocsAndMaxScore(topDocsNew, 0.7f);
+        TopDocsAndMaxScore mergedTopDocsAndMaxScore = topDocsMerger.merge(topDocsAndMaxScoreOriginal, topDocsAndMaxScoreNew);
+
+        assertNotNull(mergedTopDocsAndMaxScore);
+
+        assertEquals(0.7f, mergedTopDocsAndMaxScore.maxScore, DELTA_FOR_ASSERTION);
+        assertEquals(6, mergedTopDocsAndMaxScore.topDocs.totalHits.value);
+        assertEquals(TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO, mergedTopDocsAndMaxScore.topDocs.totalHits.relation);
+        // expected number of rows is 5 from sub-query1 and 1 from sub-query2, plus 2 start-stop elements + 2 delimiters
+        // 5 + 1 + 2 + 2 = 10
+        assertEquals(10, mergedTopDocsAndMaxScore.topDocs.scoreDocs.length);
+        // check format, all elements one by one
+        FieldDoc[] fieldDocs = (FieldDoc[]) mergedTopDocsAndMaxScore.topDocs.scoreDocs;
+        assertEquals(1, fieldDocs[0].fields[0]);
+        assertEquals(1, fieldDocs[1].fields[0]);
+        assertFieldDoc(fieldDocs[2], 0, 100);
+        assertFieldDoc(fieldDocs[3], 2, 80);
+        assertFieldDoc(fieldDocs[4], 1, 70);
+        assertFieldDoc(fieldDocs[5], 4, 60);
+        assertFieldDoc(fieldDocs[6], 5, 30);
+        assertEquals(1, fieldDocs[7].fields[0]);
+        assertFieldDoc(fieldDocs[8], 4, 40);
+        assertEquals(1, fieldDocs[9].fields[0]);
+    }
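The totalHits assertions in the test above (2 + 4 = 6 with relation GREATER_THAN_OR_EQUAL_TO) pin down a simple combination rule. A sketch of that rule as the assertions imply it; the merger's own code is not part of this diff:

```java
import org.apache.lucene.search.TotalHits;

// Rule implied by the assertions: counts add, and the relation degrades to a
// lower bound as soon as either input is itself a lower bound.
static TotalHits combineTotalHits(TotalHits a, TotalHits b) {
    boolean lowerBound = a.relation == TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO
        || b.relation == TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO;
    return new TotalHits(a.value + b.value, lowerBound ? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO : TotalHits.Relation.EQUAL_TO);
}
```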
+
+    @SneakyThrows
+    public void testMergeFieldDocs_whenOneTopDocsHasHitsAndOtherIsEmpty_thenSuccessful() {
+        DocValueFormat[] docValueFormat = new DocValueFormat[] { DocValueFormat.RAW };
+        SortField sortField = new SortField("stock", SortField.Type.INT, true);
+        Sort sort = new Sort(sortField);
+        SortAndFormats sortAndFormats = new SortAndFormats(sort, docValueFormat);
+        TopDocsMerger topDocsMerger = new TopDocsMerger(sortAndFormats);
+
+        TopDocs topDocsOriginal = new TopFieldDocs(
+            new TotalHits(0, TotalHits.Relation.EQUAL_TO),
+            new FieldDoc[] {
+                createFieldDocStartStopElementForHybridSearchResults(0, new Object[] { 1 }),
+                createFieldDocDelimiterElementForHybridSearchResults(0, new Object[] { 1 }),
+                createFieldDocDelimiterElementForHybridSearchResults(0, new Object[] { 1 }),
+                createFieldDocStartStopElementForHybridSearchResults(0, new Object[] { 1 }) },
+            sortAndFormats.sort.getSort()
+        );
+        TopDocsAndMaxScore topDocsAndMaxScoreOriginal = new TopDocsAndMaxScore(topDocsOriginal, 0.5f);
+        TopDocs topDocsNew = new TopFieldDocs(
+            new TotalHits(4, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO),
+            new FieldDoc[] {
+                createFieldDocStartStopElementForHybridSearchResults(2, new Object[] { 1 }),
+                createFieldDocDelimiterElementForHybridSearchResults(2, new Object[] { 1 }),
+                new FieldDoc(1, 0.7f, new Object[] { 100 }),
+                new FieldDoc(4, 0.3f, new Object[] { 60 }),
+                new FieldDoc(5, 0.05f, new Object[] { 30 }),
+                createFieldDocDelimiterElementForHybridSearchResults(2, new Object[] { 1 }),
+                new FieldDoc(4, 0.6f, new Object[] { 80 }),
+                createFieldDocStartStopElementForHybridSearchResults(2, new Object[] { 1 }) },
+            sortAndFormats.sort.getSort()
+        );
+        TopDocsAndMaxScore topDocsAndMaxScoreNew = new TopDocsAndMaxScore(topDocsNew, 0.7f);
+        TopDocsAndMaxScore mergedTopDocsAndMaxScore = topDocsMerger.merge(topDocsAndMaxScoreOriginal, topDocsAndMaxScoreNew);
+
+        assertNotNull(mergedTopDocsAndMaxScore);
+
+        assertEquals(0.7f, mergedTopDocsAndMaxScore.maxScore, DELTA_FOR_ASSERTION);
+        assertEquals(4, mergedTopDocsAndMaxScore.topDocs.totalHits.value);
+        assertEquals(TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO, mergedTopDocsAndMaxScore.topDocs.totalHits.relation);
+        // expected number of rows is 3 from sub-query1 and 1 from sub-query2, plus 2 start-stop elements + 2 delimiters
+        // 3 + 1 + 2 + 2 = 8
+        assertEquals(8, mergedTopDocsAndMaxScore.topDocs.scoreDocs.length);
+        // check format, all elements one by one
+        FieldDoc[] fieldDocs = (FieldDoc[]) mergedTopDocsAndMaxScore.topDocs.scoreDocs;
+        assertEquals(1, fieldDocs[0].fields[0]);
+        assertEquals(1, fieldDocs[1].fields[0]);
+        assertFieldDoc(fieldDocs[2], 1, 100);
+        assertFieldDoc(fieldDocs[3], 4, 60);
+        assertFieldDoc(fieldDocs[4], 5, 30);
+        assertEquals(1, fieldDocs[5].fields[0]);
+        assertFieldDoc(fieldDocs[6], 4, 80);
+        assertEquals(1, fieldDocs[7].fields[0]);
+    }
+
+    @SneakyThrows
+    public void testMergeFieldDocs_whenBothTopDocsHasNoHits_thenSuccessful() {
+        DocValueFormat[] docValueFormat = new DocValueFormat[] { DocValueFormat.RAW };
+        SortField sortField = new SortField("stock", SortField.Type.INT, true);
+        Sort sort = new Sort(sortField);
+        SortAndFormats sortAndFormats = new SortAndFormats(sort, docValueFormat);
+        TopDocsMerger topDocsMerger = new TopDocsMerger(sortAndFormats);
+
+        TopDocs topDocsOriginal = new TopFieldDocs(
+            new TotalHits(0, TotalHits.Relation.EQUAL_TO),
+            new FieldDoc[] {
+                createFieldDocStartStopElementForHybridSearchResults(0, new Object[] { 1 }),
+                createFieldDocDelimiterElementForHybridSearchResults(0, new Object[] { 1 }),
+                createFieldDocDelimiterElementForHybridSearchResults(0, new Object[] { 1 }),
+                createFieldDocStartStopElementForHybridSearchResults(0, new Object[] { 1 }) },
+            sortAndFormats.sort.getSort()
+        );
+        TopDocsAndMaxScore topDocsAndMaxScoreOriginal = new TopDocsAndMaxScore(topDocsOriginal, 0);
+        TopDocs topDocsNew = new TopFieldDocs(
+            new TotalHits(0, TotalHits.Relation.EQUAL_TO),
+            new FieldDoc[] {
+                createFieldDocStartStopElementForHybridSearchResults(2, new Object[] { 1 }),
+                createFieldDocDelimiterElementForHybridSearchResults(2, new Object[] { 1 }),
+                createFieldDocDelimiterElementForHybridSearchResults(2, new Object[] { 1 }),
+                createFieldDocStartStopElementForHybridSearchResults(2, new Object[] { 1 }) },
+            sortAndFormats.sort.getSort()
+        );
+        TopDocsAndMaxScore topDocsAndMaxScoreNew = new TopDocsAndMaxScore(topDocsNew, 0);
+        TopDocsAndMaxScore mergedTopDocsAndMaxScore = topDocsMerger.merge(topDocsAndMaxScoreOriginal, topDocsAndMaxScoreNew);
+
+        assertNotNull(mergedTopDocsAndMaxScore);
+
+        assertEquals(0f, mergedTopDocsAndMaxScore.maxScore, DELTA_FOR_ASSERTION);
+        assertEquals(0, mergedTopDocsAndMaxScore.topDocs.totalHits.value);
+        assertEquals(TotalHits.Relation.EQUAL_TO, mergedTopDocsAndMaxScore.topDocs.totalHits.relation);
+        assertEquals(4, mergedTopDocsAndMaxScore.topDocs.scoreDocs.length);
+        // check format, all elements one by one
+        FieldDoc[] fieldDocs = (FieldDoc[]) mergedTopDocsAndMaxScore.topDocs.scoreDocs;
+        assertEquals(1, fieldDocs[0].fields[0]);
+        assertEquals(1, fieldDocs[1].fields[0]);
+        assertEquals(1, fieldDocs[2].fields[0]);
+        assertEquals(1, fieldDocs[3].fields[0]);
+    }
+
+    @SneakyThrows
+    public void testThreeSequentialMergesWithFieldDocs_whenAllTopDocsHasHits_thenSuccessful() {
+        DocValueFormat[] docValueFormat = new DocValueFormat[] { DocValueFormat.RAW };
+        SortField sortField = new SortField("stock", SortField.Type.INT, true);
+        Sort sort = new Sort(sortField);
+        SortAndFormats sortAndFormats = new SortAndFormats(sort, docValueFormat);
+        TopDocsMerger topDocsMerger = new TopDocsMerger(sortAndFormats);
+
+        TopDocs topDocsOriginal = new TopFieldDocs(
+            new TotalHits(2, TotalHits.Relation.EQUAL_TO),
+            new FieldDoc[] {
+                createFieldDocStartStopElementForHybridSearchResults(0, new Object[] { 1 }),
+                createFieldDocDelimiterElementForHybridSearchResults(0, new Object[] { 1 }),
+                new FieldDoc(0, 0.5f, new Object[] { 100 }),
+                new FieldDoc(2, 0.3f, new Object[] { 20 }),
+                createFieldDocDelimiterElementForHybridSearchResults(0, new Object[] { 1 }),
+                createFieldDocStartStopElementForHybridSearchResults(0, new Object[] { 1 }) },
+            sortAndFormats.sort.getSort()
+        );
+        TopDocsAndMaxScore topDocsAndMaxScoreOriginal = new TopDocsAndMaxScore(topDocsOriginal, 0.5f);
+        TopDocs topDocsNew = new TopFieldDocs(
+            new TotalHits(4, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO),
+            new FieldDoc[] {
+                createFieldDocStartStopElementForHybridSearchResults(2, new Object[] { 1 }),
+                createFieldDocDelimiterElementForHybridSearchResults(2, new Object[] { 1 }),
+                new FieldDoc(1, 0.7f, new Object[] { 80 }),
+                new FieldDoc(4, 0.3f, new Object[] { 30 }),
+                new FieldDoc(5, 0.05f, new Object[] { 10 }),
+                createFieldDocDelimiterElementForHybridSearchResults(2, new Object[] { 1 }),
+                new FieldDoc(4, 0.6f, new Object[] { 30 }),
+                createFieldDocStartStopElementForHybridSearchResults(2, new Object[] { 1 }) },
+            sortAndFormats.sort.getSort()
+        );
+        TopDocsAndMaxScore topDocsAndMaxScoreNew = new TopDocsAndMaxScore(topDocsNew, 0.7f);
+        TopDocsAndMaxScore firstMergedTopDocsAndMaxScore = topDocsMerger.merge(topDocsAndMaxScoreOriginal, topDocsAndMaxScoreNew);
+
+        assertNotNull(firstMergedTopDocsAndMaxScore);
+
+        // merge results from collector 3
+        TopDocs topDocsThirdCollector = new TopFieldDocs(
+            new TotalHits(3, TotalHits.Relation.EQUAL_TO),
+            new FieldDoc[] {
+                createFieldDocStartStopElementForHybridSearchResults(3, new Object[] { 1 }),
+                createFieldDocDelimiterElementForHybridSearchResults(3, new Object[] { 1 }),
+                new FieldDoc(3, 0.4f, new Object[] { 90 }),
+                createFieldDocDelimiterElementForHybridSearchResults(3, new Object[] { 1 }),
+                new FieldDoc(7, 0.85f, new Object[] { 60 }),
+                new FieldDoc(9, 0.2f, new Object[] { 50 }),
+                createFieldDocStartStopElementForHybridSearchResults(3, new Object[] { 1 }) },
+            sortAndFormats.sort.getSort()
+        );
+        TopDocsAndMaxScore topDocsAndMaxScoreThirdCollector = new TopDocsAndMaxScore(topDocsThirdCollector, 0.85f);
+        TopDocsAndMaxScore finalMergedTopDocsAndMaxScore = topDocsMerger.merge(
+            firstMergedTopDocsAndMaxScore,
+            topDocsAndMaxScoreThirdCollector
+        );
+
+        assertEquals(0.85f, finalMergedTopDocsAndMaxScore.maxScore, DELTA_FOR_ASSERTION);
+        assertEquals(9, finalMergedTopDocsAndMaxScore.topDocs.totalHits.value);
+        assertEquals(TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO, finalMergedTopDocsAndMaxScore.topDocs.totalHits.relation);
+        // expected number of rows is 6 from sub-query1 and 3 from sub-query2, plus 2 start-stop elements + 2 delimiters
+        // 6 + 3 + 2 + 2 = 13
+        assertEquals(13, finalMergedTopDocsAndMaxScore.topDocs.scoreDocs.length);
+        // check format, all elements one by one
+        FieldDoc[] fieldDocs = (FieldDoc[]) finalMergedTopDocsAndMaxScore.topDocs.scoreDocs;
+        assertEquals(1, fieldDocs[0].fields[0]);
+        assertEquals(1, fieldDocs[1].fields[0]);
+        assertFieldDoc(fieldDocs[2], 0, 100);
+        assertFieldDoc(fieldDocs[3], 3, 90);
+        assertFieldDoc(fieldDocs[4], 1, 80);
+        assertFieldDoc(fieldDocs[5], 4, 30);
+        assertFieldDoc(fieldDocs[6], 2, 20);
+        assertFieldDoc(fieldDocs[7], 5, 10);
+        assertEquals(1, fieldDocs[8].fields[0]);
+        assertFieldDoc(fieldDocs[9], 7, 60);
+        assertFieldDoc(fieldDocs[10], 9, 50);
+        assertFieldDoc(fieldDocs[11], 4, 30);
+        assertEquals(1, fieldDocs[12].fields[0]);
+    }
+
     private void assertScoreDoc(ScoreDoc scoreDoc, int expectedDocId, float expectedScore) {
         assertEquals(expectedDocId, scoreDoc.doc);
         assertEquals(expectedScore, scoreDoc.score, DELTA_FOR_ASSERTION);
     }
+
+    private void assertFieldDoc(FieldDoc fieldDoc, int expectedDocId, int expectedSortValue) {
+        assertEquals(expectedDocId, fieldDoc.doc);
+        assertEquals(expectedSortValue, fieldDoc.fields[0]);
+    }
 }
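Sequential merging in the test above is a left fold over per-collector results; reusing the test's own variable names:

```java
// Pairwise merges accumulate left to right, as the sequential-merge test exercises.
TopDocsAndMaxScore merged = topDocsMerger.merge(topDocsAndMaxScoreOriginal, topDocsAndMaxScoreNew);
merged = topDocsMerger.merge(merged, topDocsAndMaxScoreThirdCollector);
// maxScore is the max of the inputs, totalHits values add up, and each
// sub-query section of the merged array stays ordered by the sort criteria.
```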
diff --git a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java
index 1fd3b47c7..966a06f49 100644
--- a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java
+++ b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java
@@ -49,6 +49,7 @@
 import org.opensearch.knn.index.SpaceType;
 import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil;
 import org.opensearch.neuralsearch.util.TokenWeightUtil;
+import org.opensearch.search.sort.SortBuilder;
 import org.opensearch.test.ClusterServiceUtils;
 import org.opensearch.threadpool.TestThreadPool;
 import org.opensearch.threadpool.ThreadPool;
@@ -512,7 +513,7 @@ protected Map<String, Object> search(
         Map<String, String> requestParams,
         List<Object> aggs
     ) {
-        return search(index, queryBuilder, rescorer, resultSize, requestParams, aggs, null);
+        return search(index, queryBuilder, rescorer, resultSize, requestParams, aggs, null, null, false, null);
     }

     @SneakyThrows
@@ -523,7 +524,10 @@ protected Map<String, Object> search(
         int resultSize,
         Map<String, String> requestParams,
         List<Object> aggs,
-        QueryBuilder postFilterBuilder
+        QueryBuilder postFilterBuilder,
+        List<SortBuilder<?>> sortBuilders,
+        boolean trackScores,
+        List<Object> searchAfter
     ) {
         XContentBuilder builder = XContentFactory.jsonBuilder().startObject();
@@ -548,21 +552,39 @@ protected Map<String, Object> search(
             builder.field("post_filter");
             postFilterBuilder.toXContent(builder, ToXContent.EMPTY_PARAMS);
         }
+        if (Objects.nonNull(sortBuilders) && !sortBuilders.isEmpty()) {
+            builder.startArray("sort");
+            for (SortBuilder<?> sortBuilder : sortBuilders) {
+                sortBuilder.toXContent(builder, ToXContent.EMPTY_PARAMS);
+            }
+            builder.endArray();
+        }
+
+        if (trackScores) {
+            builder.field("track_scores", trackScores);
+        }
+        if (searchAfter != null && !searchAfter.isEmpty()) {
+            builder.startArray("search_after");
+            for (Object searchAfterEntry : searchAfter) {
+                builder.value(searchAfterEntry);
+            }
+            builder.endArray();
+        }
         builder.endObject();

-        Request request = new Request("POST", "/" + index + "/_search");
+        Request request = new Request("GET", "/" + index + "/_search?timeout=1000s");
         request.addParameter("size", Integer.toString(resultSize));
         if (requestParams != null && !requestParams.isEmpty()) {
             requestParams.forEach(request::addParameter);
         }
+        logger.info("Sorting request " + builder.toString());
         request.setJsonEntity(builder.toString());
-
         Response response = client().performRequest(request);
         assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));

         String responseBody = EntityUtils.toString(response.getEntity());
-
+        logger.info("Response " + responseBody);
         return XContentHelper.convertToMap(XContentType.JSON.xContent(), responseBody, false);
     }
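With the widened fixture signature, an integration test can request sorted hybrid results and paginate with search_after. A hedged usage sketch for code that extends BaseNeuralSearchIT; the index name, sort field, match-all query, and search_after value are placeholders, not taken from this diff (a real test would pass its HybridQueryBuilder):

```java
import java.util.List;
import java.util.Map;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.search.sort.SortBuilder;
import org.opensearch.search.sort.SortBuilders;
import org.opensearch.search.sort.SortOrder;

// Sort by "stock" descending; keep per-hit scores via track_scores.
List<SortBuilder<?>> sorts = List.of(SortBuilders.fieldSort("stock").order(SortOrder.DESC));
QueryBuilder query = QueryBuilders.matchAllQuery();  // placeholder query

// First page.
Map<String, Object> firstPage = search("test-index", query, null, 10, Map.of(), null, null, sorts, true, null);

// Next page: feed the previous page's last sort value back as search_after.
List<Object> searchAfter = List.of(100);  // placeholder: last hit's "stock" value
Map<String, Object> nextPage = search("test-index", query, null, 10, Map.of(), null, null, sorts, true, searchAfter);
```

The search_after values come from the `sort` array of the last hit on the previous page, which is why `sort` and `search_after` are built into the request body together above.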
diff --git a/src/testFixtures/java/org/opensearch/neuralsearch/util/TestUtils.java b/src/testFixtures/java/org/opensearch/neuralsearch/util/TestUtils.java
index 0534f85bf..bc016aae2 100644
--- a/src/testFixtures/java/org/opensearch/neuralsearch/util/TestUtils.java
+++ b/src/testFixtures/java/org/opensearch/neuralsearch/util/TestUtils.java
@@ -325,6 +325,28 @@ public static void assertHitResultsFromQuery(int expected, Map<String, Object> s
         assertEquals(RELATION_EQUAL_TO, total.get("relation"));
     }

+    public static void assertHitResultsFromQueryWhenSortIsEnabled(
+        int expectedCollectedHits,
+        int expectedTotalHits,
+        Map<String, Object> searchResponseAsMap
+    ) {
+        assertEquals(expectedCollectedHits, getHitCount(searchResponseAsMap));
+
+        List<Map<String, Object>> hitsNestedList = getNestedHits(searchResponseAsMap);
+        List<String> ids = new ArrayList<>();
+        for (Map<String, Object> oneHit : hitsNestedList) {
+            ids.add((String) oneHit.get("_id"));
+        }
+        // verify that all ids are unique
+        assertEquals(Set.copyOf(ids).size(), ids.size());
+
+        Map<String, Object> total = getTotalHits(searchResponseAsMap);
+        assertNotNull(total.get("value"));
+        assertEquals(expectedTotalHits, total.get("value"));
+        assertNotNull(total.get("relation"));
+        assertEquals(RELATION_EQUAL_TO, total.get("relation"));
+    }
+
     private static List<Map<String, Object>> getNestedHits(Map<String, Object> searchResponseAsMap) {
         Map<String, Object> hitsMap = (Map<String, Object>) searchResponseAsMap.get("hits");
         return (List<Map<String, Object>>) hitsMap.get("hits");