diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java index 2699d2eb3..a4e39f448 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java @@ -96,11 +96,12 @@ private void combineShardScores( // - sort documents by scores and take first "max number" of docs // create a collection of doc ids that are sorted by their combined scores - Collection sortedDocsIds = getSortedDocIds( - combinedNormalizedScoresByDocId, - getTopFieldDocs(sort, topDocsPerSubQuery), - sort - ); + Collection sortedDocsIds; + if (sort != null) { + sortedDocsIds = getSortedDocIdsBySortCriteria(getTopFieldDocs(sort, topDocsPerSubQuery), sort); + } else { + sortedDocsIds = getSortedDocIds(combinedNormalizedScoresByDocId); + } // - update query search results with normalized scores updateQueryTopDocsWithCombinedScores( @@ -183,23 +184,21 @@ private Map getDocIdSortFieldsMap( return docIdSortFieldMap; } - private Collection getSortedDocIds( - final Map combinedNormalizedScoresByDocId, - final List topFieldDocs, - final Sort sort - ) { + private List getSortedDocIds(final Map combinedNormalizedScoresByDocId) { // we're merging docs with normalized and combined scores. we need to have only maxHits results - if (sort == null) { - List sortedDocsIds = new ArrayList<>(combinedNormalizedScoresByDocId.keySet()); - sortedDocsIds.sort((a, b) -> Float.compare(combinedNormalizedScoresByDocId.get(b), combinedNormalizedScoresByDocId.get(a))); - return sortedDocsIds; - } + List sortedDocsIds = new ArrayList<>(combinedNormalizedScoresByDocId.keySet()); + sortedDocsIds.sort((a, b) -> Float.compare(combinedNormalizedScoresByDocId.get(b), combinedNormalizedScoresByDocId.get(a))); + return sortedDocsIds; + } + + private Set getSortedDocIdsBySortCriteria(final List topFieldDocs, final Sort sort) { if (Objects.isNull(topFieldDocs)) { throw new IllegalArgumentException("topFieldDocs cannot be null when sorting is enabled."); } - int topN = 0; + // size will be equal to the number of score docs + int size = 0; for (TopFieldDocs topFieldDoc : topFieldDocs) { - topN += topFieldDoc.scoreDocs.length; + size += topFieldDoc.scoreDocs.length; } // Merge the sorted results of individual queries to form a one final result per shard which is sorted. @@ -214,7 +213,7 @@ private Collection getSortedDocIds( // < 0, 0.7, shardId, [90]> // < 1, 0.7, shardId, [70]> // < 1, 0.3, shardId, [70]> - final TopDocs sortedTopDocs = TopDocs.merge(sort, 0, topN, topFieldDocs.toArray(new TopFieldDocs[0]), SORTING_TIE_BREAKER); + final TopDocs sortedTopDocs = TopDocs.merge(sort, 0, size, topFieldDocs.toArray(new TopFieldDocs[0]), SORTING_TIE_BREAKER); // Remove duplicates from the sorted top docs. Set uniqueDocIds = new LinkedHashSet<>(); diff --git a/src/main/java/org/opensearch/neuralsearch/search/util/HybridSearchSortUtil.java b/src/main/java/org/opensearch/neuralsearch/search/util/HybridSearchSortUtil.java index bf92fc75d..fb7ca53a6 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/util/HybridSearchSortUtil.java +++ b/src/main/java/org/opensearch/neuralsearch/search/util/HybridSearchSortUtil.java @@ -78,9 +78,6 @@ private static boolean containsTopFieldDocs(List topDocs) { * More details here https://github.com/opensearch-project/OpenSearch/issues/6326 */ private static Sort createSort(TopFieldDocs[] topFieldDocs) { - if (topFieldDocs == null || topFieldDocs[0] == null) { - throw new IllegalArgumentException("topFieldDocs cannot be null when sorting is applied"); - } final SortField[] firstTopDocFields = topFieldDocs[0].fields; final SortField[] newFields = new SortField[firstTopDocFields.length];