Skip to content

Commit

Permalink
Address Vijay Comments
Browse files Browse the repository at this point in the history
Signed-off-by: Varun Jain <[email protected]>
  • Loading branch information
vibrantvarun committed Jul 2, 2024
1 parent 16ca3d5 commit 99122af
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,12 @@ private void combineShardScores(

// - sort documents by scores and take first "max number" of docs
// create a collection of doc ids that are sorted by their combined scores
Collection<Integer> sortedDocsIds = getSortedDocIds(
combinedNormalizedScoresByDocId,
getTopFieldDocs(sort, topDocsPerSubQuery),
sort
);
Collection<Integer> sortedDocsIds;
if (sort != null) {
sortedDocsIds = getSortedDocIdsBySortCriteria(getTopFieldDocs(sort, topDocsPerSubQuery), sort);
} else {
sortedDocsIds = getSortedDocIds(combinedNormalizedScoresByDocId);
}

// - update query search results with normalized scores
updateQueryTopDocsWithCombinedScores(
Expand Down Expand Up @@ -183,23 +184,21 @@ private Map<Integer, Object[]> getDocIdSortFieldsMap(
return docIdSortFieldMap;
}

private Collection<Integer> getSortedDocIds(
final Map<Integer, Float> combinedNormalizedScoresByDocId,
final List<TopFieldDocs> topFieldDocs,
final Sort sort
) {
private List<Integer> getSortedDocIds(final Map<Integer, Float> combinedNormalizedScoresByDocId) {
// we're merging docs with normalized and combined scores. we need to have only maxHits results
if (sort == null) {
List<Integer> sortedDocsIds = new ArrayList<>(combinedNormalizedScoresByDocId.keySet());
sortedDocsIds.sort((a, b) -> Float.compare(combinedNormalizedScoresByDocId.get(b), combinedNormalizedScoresByDocId.get(a)));
return sortedDocsIds;
}
List<Integer> sortedDocsIds = new ArrayList<>(combinedNormalizedScoresByDocId.keySet());
sortedDocsIds.sort((a, b) -> Float.compare(combinedNormalizedScoresByDocId.get(b), combinedNormalizedScoresByDocId.get(a)));
return sortedDocsIds;
}

private Set<Integer> getSortedDocIdsBySortCriteria(final List<TopFieldDocs> topFieldDocs, final Sort sort) {
if (Objects.isNull(topFieldDocs)) {
throw new IllegalArgumentException("topFieldDocs cannot be null when sorting is enabled.");
}
int topN = 0;
// size will be equal to the number of score docs
int size = 0;
for (TopFieldDocs topFieldDoc : topFieldDocs) {
topN += topFieldDoc.scoreDocs.length;
size += topFieldDoc.scoreDocs.length;
}

// Merge the sorted results of individual queries to form a one final result per shard which is sorted.
Expand All @@ -214,7 +213,7 @@ private Collection<Integer> getSortedDocIds(
// < 0, 0.7, shardId, [90]>
// < 1, 0.7, shardId, [70]>
// < 1, 0.3, shardId, [70]>
final TopDocs sortedTopDocs = TopDocs.merge(sort, 0, topN, topFieldDocs.toArray(new TopFieldDocs[0]), SORTING_TIE_BREAKER);
final TopDocs sortedTopDocs = TopDocs.merge(sort, 0, size, topFieldDocs.toArray(new TopFieldDocs[0]), SORTING_TIE_BREAKER);

// Remove duplicates from the sorted top docs.
Set<Integer> uniqueDocIds = new LinkedHashSet<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,6 @@ private static boolean containsTopFieldDocs(List<TopDocs> topDocs) {
* More details here https://github.com/opensearch-project/OpenSearch/issues/6326
*/
private static Sort createSort(TopFieldDocs[] topFieldDocs) {
if (topFieldDocs == null || topFieldDocs[0] == null) {
throw new IllegalArgumentException("topFieldDocs cannot be null when sorting is applied");
}
final SortField[] firstTopDocFields = topFieldDocs[0].fields;
final SortField[] newFields = new SortField[firstTopDocFields.length];

Expand Down

0 comments on commit 99122af

Please sign in to comment.