Skip to content

Commit

Permalink
Enable sorting and search_after features in Hybrid Search
Browse files Browse the repository at this point in the history
Signed-off-by: Varun Jain <[email protected]>
  • Loading branch information
vibrantvarun committed Jul 9, 2024
2 parents fb1f1fd + a0c82c6 commit f4307bb
Show file tree
Hide file tree
Showing 35 changed files with 3,520 additions and 211 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,17 @@
*/
package org.opensearch.neuralsearch.processor;

import org.apache.lucene.search.FieldDoc;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TopFieldDocs;
import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.isHybridQueryDelimiterElement;
import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.isHybridQueryStartStopElement;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;

import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;

import lombok.AllArgsConstructor;
import lombok.Getter;
Expand All @@ -39,14 +38,14 @@ public class CompoundTopDocs {
@Setter
private List<ScoreDoc> scoreDocs;

/**
 * Creates compound top docs from a total hit count and a list of top docs, delegating all
 * field setup to {@code initialize}.
 * NOTE(review): two signatures appear back to back here; this looks like a diff rendering in which
 * the two-argument form is replaced by the three-argument form - confirm against the real file.
 *
 * @param totalHits total hits value stored on this object
 * @param topDocs list of top docs stored on this object
 * @param isSortEnabled forwarded to {@code initialize} (three-argument form only)
 */
public CompoundTopDocs(final TotalHits totalHits, final List<TopDocs> topDocs) {
initialize(totalHits, topDocs);
public CompoundTopDocs(final TotalHits totalHits, final List<TopDocs> topDocs, final boolean isSortEnabled) {
initialize(totalHits, topDocs, isSortEnabled);
}

/**
 * Assigns the {@code totalHits} and {@code topDocs} fields and sets {@code scoreDocs} to the
 * list returned by {@code cloneLargestScoreDocs} for the given top docs.
 * NOTE(review): old and new variants of the signature and of the last assignment appear side by
 * side; presumably only the variants carrying {@code isSortEnabled} are current - confirm.
 *
 * @param totalHits total hits to store
 * @param topDocs list of top docs to store; also the input for {@code cloneLargestScoreDocs}
 * @param isSortEnabled passed through to {@code cloneLargestScoreDocs}
 */
private void initialize(TotalHits totalHits, List<TopDocs> topDocs) {
private void initialize(TotalHits totalHits, List<TopDocs> topDocs, boolean isSortEnabled) {
this.totalHits = totalHits;
this.topDocs = topDocs;
scoreDocs = cloneLargestScoreDocs(topDocs);
scoreDocs = cloneLargestScoreDocs(topDocs, isSortEnabled);
}

/**
Expand Down Expand Up @@ -74,9 +73,13 @@ private void initialize(TotalHits totalHits, List<TopDocs> topDocs) {
* 0, 9549511920.4881596047
*/
public CompoundTopDocs(final TopDocs topDocs) {
boolean isSortEnabled = false;
if (topDocs instanceof TopFieldDocs) {
isSortEnabled = true;
}
ScoreDoc[] scoreDocs = topDocs.scoreDocs;
if (Objects.isNull(scoreDocs) || scoreDocs.length < 2) {
initialize(topDocs.totalHits, new ArrayList<>());
initialize(topDocs.totalHits, new ArrayList<>(), isSortEnabled);
return;
}
// skipping first two elements, it's a start-stop element and delimiter for first series
Expand All @@ -88,17 +91,22 @@ public CompoundTopDocs(final TopDocs topDocs) {
if (isHybridQueryDelimiterElement(scoreDoc) || isHybridQueryStartStopElement(scoreDoc)) {
ScoreDoc[] subQueryScores = scoreDocList.toArray(new ScoreDoc[0]);
TotalHits totalHits = new TotalHits(subQueryScores.length, TotalHits.Relation.EQUAL_TO);
TopDocs subQueryTopDocs = new TopDocs(totalHits, subQueryScores);
TopDocs subQueryTopDocs;
if (isSortEnabled) {
subQueryTopDocs = new TopFieldDocs(totalHits, subQueryScores, ((TopFieldDocs) topDocs).fields);
} else {
subQueryTopDocs = new TopDocs(totalHits, subQueryScores);
}
topDocsList.add(subQueryTopDocs);
scoreDocList.clear();
} else {
scoreDocList.add(scoreDoc);
}
}
initialize(topDocs.totalHits, topDocsList);
initialize(topDocs.totalHits, topDocsList, isSortEnabled);
}

private List<ScoreDoc> cloneLargestScoreDocs(final List<TopDocs> docs) {
private List<ScoreDoc> cloneLargestScoreDocs(final List<TopDocs> docs, boolean isSortEnabled) {
if (docs == null) {
return null;
}
Expand All @@ -113,7 +121,20 @@ private List<ScoreDoc> cloneLargestScoreDocs(final List<TopDocs> docs) {
maxScoreDocs = topDoc.scoreDocs;
}
}

// do deep copy
return Arrays.stream(maxScoreDocs).map(doc -> new ScoreDoc(doc.doc, doc.score, doc.shardIndex)).collect(Collectors.toList());
List<ScoreDoc> scoreDocs = new ArrayList<>();
for (ScoreDoc scoreDoc : maxScoreDocs) {
scoreDocs.add(deepCopyScoreDoc(scoreDoc, isSortEnabled));
}
return scoreDocs;
}

/**
 * Creates a new score doc object carrying the same values as the given one.
 *
 * @param scoreDoc element to copy; it is cast to {@link FieldDoc} when {@code isSortEnabled} is true
 * @param isSortEnabled whether the copy has to carry sort field values
 * @return a new {@link FieldDoc} when sorting is enabled, otherwise a new plain {@link ScoreDoc}
 */
private ScoreDoc deepCopyScoreDoc(final ScoreDoc scoreDoc, final boolean isSortEnabled) {
    if (isSortEnabled) {
        final FieldDoc sortedDoc = (FieldDoc) scoreDoc;
        // the fields array reference is handed to the copy as is, it is not cloned
        return new FieldDoc(sortedDoc.doc, sortedDoc.score, sortedDoc.fields, sortedDoc.shardIndex);
    }
    return new ScoreDoc(scoreDoc.doc, scoreDoc.score, scoreDoc.shardIndex);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@

import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.TopFieldDocs;
import org.apache.lucene.search.FieldDoc;
import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
import org.opensearch.neuralsearch.processor.combination.CombineScoresDto;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
import org.opensearch.neuralsearch.processor.combination.ScoreCombiner;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
Expand All @@ -27,6 +31,8 @@

import lombok.AllArgsConstructor;
import lombok.extern.log4j.Log4j2;
import static org.opensearch.neuralsearch.processor.combination.ScoreCombiner.MAX_SCORE_WHEN_NO_HITS_FOUND;
import static org.opensearch.neuralsearch.search.util.HybridSearchSortUtil.evaluateSortCriteria;

/**
* Class abstracts steps required for score normalization and combination, this includes pre-processing of incoming data
Expand Down Expand Up @@ -62,13 +68,20 @@ public void execute(
log.debug("Do score normalization");
scoreNormalizer.normalizeScores(queryTopDocs, normalizationTechnique);

CombineScoresDto combineScoresDTO = CombineScoresDto.builder()
.queryTopDocs(queryTopDocs)
.scoreCombinationTechnique(combinationTechnique)
.querySearchResults(querySearchResults)
.sort(evaluateSortCriteria(querySearchResults, queryTopDocs))
.build();

// combine
log.debug("Do score combination");
scoreCombiner.combineScores(queryTopDocs, combinationTechnique);
scoreCombiner.combineScores(combineScoresDTO);

// post-process data
log.debug("Post-process query results after score normalization and combination");
updateOriginalQueryResults(querySearchResults, queryTopDocs);
updateOriginalQueryResults(combineScoresDTO);
updateOriginalFetchResults(querySearchResults, fetchSearchResultOptional, unprocessedDocIds);
}

Expand Down Expand Up @@ -96,7 +109,23 @@ private List<CompoundTopDocs> getQueryTopDocs(final List<QuerySearchResult> quer
return queryTopDocs;
}

private void updateOriginalQueryResults(final List<QuerySearchResult> querySearchResults, final List<CompoundTopDocs> queryTopDocs) {
/**
 * Replaces the top docs of every query search result with the combined top docs found at the
 * same position in the DTO's list of compound top docs.
 *
 * @param combineScoresDTO holder of the query search results, the compound top docs and the
 *                         optional sort criteria
 */
private void updateOriginalQueryResults(final CombineScoresDto combineScoresDTO) {
    final List<QuerySearchResult> shardResults = combineScoresDTO.getQuerySearchResults();
    // fails with IllegalStateException when both lists differ in size
    final List<CompoundTopDocs> combinedTopDocs = getCompoundTopDocs(combineScoresDTO, shardResults);
    final Sort sort = combineScoresDTO.getSort();
    final boolean isSortEnabled = sort != null;

    int position = 0;
    for (final QuerySearchResult shardResult : shardResults) {
        final CompoundTopDocs shardTopDocs = combinedTopDocs.get(position++);
        final TopDocs finalTopDocs = buildTopDocs(shardTopDocs, sort);
        final float maxScore = maxScoreForShard(shardTopDocs, isSortEnabled);
        // sort value formats already present on the result are passed back unchanged
        shardResult.topDocs(new TopDocsAndMaxScore(finalTopDocs, maxScore), shardResult.sortValueFormats());
    }
}

private List<CompoundTopDocs> getCompoundTopDocs(CombineScoresDto combineScoresDTO, List<QuerySearchResult> querySearchResults) {
final List<CompoundTopDocs> queryTopDocs = combineScoresDTO.getQueryTopDocs();
if (querySearchResults.size() != queryTopDocs.size()) {
throw new IllegalStateException(
String.format(
Expand All @@ -107,17 +136,42 @@ private void updateOriginalQueryResults(final List<QuerySearchResult> querySearc
)
);
}
for (int index = 0; index < querySearchResults.size(); index++) {
QuerySearchResult querySearchResult = querySearchResults.get(index);
CompoundTopDocs updatedTopDocs = queryTopDocs.get(index);
float maxScore = updatedTopDocs.getTotalHits().value > 0 ? updatedTopDocs.getScoreDocs().get(0).score : 0.0f;
return queryTopDocs;
}

// create final version of top docs with all updated values
TopDocs topDocs = new TopDocs(updatedTopDocs.getTotalHits(), updatedTopDocs.getScoreDocs().toArray(new ScoreDoc[0]));
/**
 * Determines the max score for the combined results of one shard.
 *
 * @param updatedTopDocs compound top docs on a shard
 * @param isSortEnabled true when sort criteria are applied, in which case the first score doc
 *                      is not assumed to hold the highest score
 * @return {@code MAX_SCORE_WHEN_NO_HITS_FOUND} when total hits is zero or there are no score docs,
 *         otherwise the max score
 */
private float maxScoreForShard(CompoundTopDocs updatedTopDocs, boolean isSortEnabled) {
    final List<ScoreDoc> shardScoreDocs = updatedTopDocs.getScoreDocs();
    if (updatedTopDocs.getTotalHits().value == 0 || shardScoreDocs.isEmpty()) {
        return MAX_SCORE_WHEN_NO_HITS_FOUND;
    }
    if (!isSortEnabled) {
        // without sorting the first entry is taken as the holder of the max score
        return shardScoreDocs.get(0).score;
    }
    // with sorting every score doc has to be inspected to find the max score
    float highestScore = MAX_SCORE_WHEN_NO_HITS_FOUND;
    for (int position = 0; position < shardScoreDocs.size(); position++) {
        highestScore = Math.max(highestScore, shardScoreDocs.get(position).score);
    }
    return highestScore;
}

TopDocsAndMaxScore updatedTopDocsAndMaxScore = new TopDocsAndMaxScore(topDocs, maxScore);
querySearchResult.topDocs(updatedTopDocsAndMaxScore, null);
/**
 * Builds the final top docs object for one shard from its compound top docs.
 *
 * @param updatedTopDocs compound top docs on a shard
 * @param sort sort criteria, or null when no sort is applied
 * @return a {@link TopFieldDocs} carrying the sort fields when {@code sort} is not null,
 *         otherwise a plain {@link TopDocs}
 */
private TopDocs buildTopDocs(CompoundTopDocs updatedTopDocs, Sort sort) {
    if (sort == null) {
        final ScoreDoc[] plainDocs = updatedTopDocs.getScoreDocs().toArray(new ScoreDoc[0]);
        return new TopDocs(updatedTopDocs.getTotalHits(), plainDocs);
    }
    // copying into a FieldDoc array requires every element of the list to be a FieldDoc
    final FieldDoc[] sortedDocs = updatedTopDocs.getScoreDocs().toArray(new FieldDoc[0]);
    return new TopFieldDocs(updatedTopDocs.getTotalHits(), sortedDocs, sort.getSort());
}

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.processor.combination;

import java.util.List;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Getter;
import lombok.NonNull;
import org.apache.lucene.search.Sort;
import org.opensearch.common.Nullable;
import org.opensearch.neuralsearch.processor.CompoundTopDocs;
import org.opensearch.search.query.QuerySearchResult;

/**
 * DTO object to hold data required for Score Combination.
 * <p>
 * Instances are created through the generated builder or the all-args constructor. No setters are
 * generated, so the field references are declared final; the referenced lists themselves are
 * not copied and stay shared with the caller.
 */
@AllArgsConstructor
@Builder
@Getter
public class CombineScoresDto {
    /** Compound top docs; the element at a given position belongs to the query search result at the same position. */
    @NonNull
    private final List<CompoundTopDocs> queryTopDocs;
    /** Technique used to combine the scores. */
    @NonNull
    private final ScoreCombinationTechnique scoreCombinationTechnique;
    /** Query search results that receive the combined top docs. */
    @NonNull
    private final List<QuerySearchResult> querySearchResults;
    /** Sort criteria; null when no sort is applied. */
    @Nullable
    private final Sort sort;
}
Loading

0 comments on commit f4307bb

Please sign in to comment.