Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reciprocal Rank Fusion (RRF) normalization technique in hybrid query #874

Open
wants to merge 5 commits into
base: feature/rrf-score-normalization-v2
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),

## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.16...2.x)
### Features
- Implement Reciprocal Rank Fusion document score normalization/combination technique in hybrid query ([#874](https://github.com/opensearch-project/neural-search/pull/874))
### Enhancements
### Bug Fixes
### Infrastructure
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,20 +30,22 @@
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.processor.NeuralQueryEnricherProcessor;
import org.opensearch.neuralsearch.processor.NeuralSparseTwoPhaseProcessor;
import org.opensearch.neuralsearch.processor.NormalizationProcessor;
import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow;
import org.opensearch.neuralsearch.processor.SparseEncodingProcessor;
import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor;
import org.opensearch.neuralsearch.processor.TextChunkingProcessor;
import org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor;
import org.opensearch.neuralsearch.processor.RRFProcessor;
import org.opensearch.neuralsearch.processor.NormalizationProcessor;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory;
import org.opensearch.neuralsearch.processor.combination.ScoreCombiner;
import org.opensearch.neuralsearch.processor.factory.TextChunkingProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.RerankProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.SparseEncodingProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.TextEmbeddingProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.TextImageEmbeddingProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.RRFProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationFactory;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer;
import org.opensearch.neuralsearch.processor.rerank.RerankProcessor;
Expand Down Expand Up @@ -154,7 +156,9 @@ public Map<String, org.opensearch.search.pipeline.Processor.Factory<SearchPhaseR
) {
return Map.of(
NormalizationProcessor.TYPE,
new NormalizationProcessorFactory(normalizationProcessorWorkflow, scoreNormalizationFactory, scoreCombinationFactory)
new NormalizationProcessorFactory(normalizationProcessorWorkflow, scoreNormalizationFactory, scoreCombinationFactory),
RRFProcessor.TYPE,
new RRFProcessorFactory(normalizationProcessorWorkflow, scoreNormalizationFactory, scoreCombinationFactory)
);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.processor;

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Getter;
import lombok.NonNull;
import org.opensearch.common.Nullable;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
import org.opensearch.search.fetch.FetchSearchResult;
import org.opensearch.search.query.QuerySearchResult;

import java.util.List;
import java.util.Optional;

/**
* DTO object to hold data required for score normalization passed to execute() function
* in NormalizationProcessorWorkflow. Field rankConstant
* used in RRF but will be null for other normalization techniques
*/
@AllArgsConstructor
@Builder
@Getter
public class NormalizationExecuteDTO {
@NonNull
private List<QuerySearchResult> querySearchResults;
@NonNull
private Optional<FetchSearchResult> fetchSearchResultOptional;
@NonNull
private ScoreNormalizationTechnique normalizationTechnique;
@NonNull
private ScoreCombinationTechnique combinationTechnique;
@Nullable
private int rankConstant;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the rank constant should be a parameter for RRF normalization technique class, not the high level field in this DTO. You can pass it as param to RRF normalization technique constructor while creating its instance in a factory class

}
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,16 @@ public <Result extends SearchPhaseResult> void process(
}
List<QuerySearchResult> querySearchResults = getQueryPhaseSearchResults(searchPhaseResult);
Optional<FetchSearchResult> fetchSearchResult = getFetchSearchResults(searchPhaseResult);
normalizationWorkflow.execute(querySearchResults, fetchSearchResult, normalizationTechnique, combinationTechnique);
// Builds data transfer object to pass into execute, DTO has nullable field for rankConstant which
// is only used for RRF technique
NormalizationExecuteDTO normalizationExecuteDTO = NormalizationExecuteDTO.builder()
.querySearchResults(querySearchResults)
.fetchSearchResultOptional(fetchSearchResult)
.normalizationTechnique(normalizationTechnique)
.combinationTechnique(combinationTechnique)
.build();
normalizationWorkflow.execute(normalizationExecuteDTO);

}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,26 +47,33 @@ public class NormalizationProcessorWorkflow {

/**
* Start execution of this workflow
* @param querySearchResults input data with QuerySearchResult from multiple shards
* @param normalizationTechnique technique for score normalization
* @param combinationTechnique technique for score combination
* @param normalizationExecuteDTO contains querySearchResults input data with QuerySearchResult
* from multiple shards, fetchSearchResultOptional, normalizationTechnique technique for score normalization
* combinationTechnique technique for score combination, and nullable rankConstant only used in RRF technique
*/
public void execute(
final List<QuerySearchResult> querySearchResults,
final Optional<FetchSearchResult> fetchSearchResultOptional,
final ScoreNormalizationTechnique normalizationTechnique,
final ScoreCombinationTechnique combinationTechnique
) {
public void execute(final NormalizationExecuteDTO normalizationExecuteDTO) {
final List<QuerySearchResult> querySearchResults = normalizationExecuteDTO.getQuerySearchResults();
final Optional<FetchSearchResult> fetchSearchResultOptional = normalizationExecuteDTO.getFetchSearchResultOptional();
final ScoreNormalizationTechnique normalizationTechnique = normalizationExecuteDTO.getNormalizationTechnique();
final ScoreCombinationTechnique combinationTechnique = normalizationExecuteDTO.getCombinationTechnique();
final int rankConstant = normalizationExecuteDTO.getRankConstant();
// save original state
List<Integer> unprocessedDocIds = unprocessedDocIds(querySearchResults);

// pre-process data
log.debug("Pre-process query results");
List<CompoundTopDocs> queryTopDocs = getQueryTopDocs(querySearchResults);

// Data transfer object for score normalization used to pass nullable rankConstant which is only used in RRF
NormalizeScoresDTO normalizeScoresDTO = NormalizeScoresDTO.builder()
.queryTopDocs(queryTopDocs)
.normalizationTechnique(normalizationTechnique)
.rankConstant(rankConstant)
.build();

// normalize
log.debug("Do score normalization");
scoreNormalizer.normalizeScores(queryTopDocs, normalizationTechnique);
scoreNormalizer.normalizeScores(normalizeScoresDTO);

CombineScoresDto combineScoresDTO = CombineScoresDto.builder()
.queryTopDocs(queryTopDocs)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.processor;

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Getter;
import lombok.NonNull;
import org.opensearch.common.Nullable;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;

import java.util.List;

/**
* DTO object to hold data required for score normalization.
*/
@AllArgsConstructor
@Builder
@Getter
public class NormalizeScoresDTO {
@NonNull
private List<CompoundTopDocs> queryTopDocs;
@NonNull
private ScoreNormalizationTechnique normalizationTechnique;
@Nullable
private int rankConstant;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.processor;

import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.isHybridQueryStartStopElement;

import java.util.stream.Collectors;

import java.util.List;
import java.util.Objects;
import java.util.Optional;

import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
import org.opensearch.search.fetch.FetchSearchResult;
import org.opensearch.search.query.QuerySearchResult;

import org.opensearch.action.search.QueryPhaseResultConsumer;
import org.opensearch.action.search.SearchPhaseContext;
import org.opensearch.action.search.SearchPhaseName;
import org.opensearch.action.search.SearchPhaseResults;
import org.opensearch.search.SearchPhaseResult;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.pipeline.SearchPhaseResultsProcessor;

import lombok.AllArgsConstructor;
import lombok.extern.log4j.Log4j2;

/**
* Processor for implementing reciprocal rank fusion technique on post
* query search results. Updates query results with
* normalized and combined scores for next phase (typically it's FETCH)
* by using ranks from individual subqueries to calculate 'normalized'
* scores before combining results from subqueries into final results
*/
@Log4j2
@AllArgsConstructor
public class RRFProcessor implements SearchPhaseResultsProcessor {
public static final String TYPE = "score-ranker-processor";

private final String tag;
private final String description;
private final ScoreNormalizationTechnique normalizationTechnique;
private final ScoreCombinationTechnique combinationTechnique;
private final NormalizationProcessorWorkflow normalizationWorkflow;
private final int rankConstant;

/**
* Method abstracts functional aspect of score normalization and score combination. Exact methods for each processing stage
* are set as part of class constructor
* @param searchPhaseResult {@link SearchPhaseResults} DTO that has query search results. Results will be mutated as part of this method execution
* @param searchPhaseContext {@link SearchContext}
*/
@Override
public <Result extends SearchPhaseResult> void process(
final SearchPhaseResults<Result> searchPhaseResult,
final SearchPhaseContext searchPhaseContext
) {
if (shouldSkipProcessor(searchPhaseResult)) {
log.debug("Query results are not compatible with RRF processor");
return;
}
List<QuerySearchResult> querySearchResults = getQueryPhaseSearchResults(searchPhaseResult);
Optional<FetchSearchResult> fetchSearchResult = getFetchSearchResults(searchPhaseResult);

// make data transfer object to pass in, execute will get object with 4 or 5 fields, depending
// on coming from NormalizationProcessor or RRFProcessor
NormalizationExecuteDTO normalizationExecuteDTO = NormalizationExecuteDTO.builder()
.querySearchResults(querySearchResults)
.fetchSearchResultOptional(fetchSearchResult)
.normalizationTechnique(normalizationTechnique)
.combinationTechnique(combinationTechnique)
.rankConstant(rankConstant)
.build();
normalizationWorkflow.execute(normalizationExecuteDTO);
}

@Override
public SearchPhaseName getBeforePhase() {
return SearchPhaseName.QUERY;
}

@Override
public SearchPhaseName getAfterPhase() {
return SearchPhaseName.FETCH;
}

@Override
public String getType() {
return TYPE;
}

@Override
public String getTag() {
return tag;
}

@Override
public String getDescription() {
return description;
}

@Override
public boolean isIgnoreFailure() {
return false;
}

private <Result extends SearchPhaseResult> boolean shouldSkipProcessor(SearchPhaseResults<Result> searchPhaseResult) {
if (Objects.isNull(searchPhaseResult) || !(searchPhaseResult instanceof QueryPhaseResultConsumer)) {
return true;
}

QueryPhaseResultConsumer queryPhaseResultConsumer = (QueryPhaseResultConsumer) searchPhaseResult;
return queryPhaseResultConsumer.getAtomicArray().asList().stream().filter(Objects::nonNull).noneMatch(this::isHybridQuery);
}

/**
* Return true if results are from hybrid query.
* @param searchPhaseResult
* @return true if results are from hybrid query
*/
private boolean isHybridQuery(final SearchPhaseResult searchPhaseResult) {
// check for delimiter at the end of the score docs.
return Objects.nonNull(searchPhaseResult.queryResult())
&& Objects.nonNull(searchPhaseResult.queryResult().topDocs())
&& Objects.nonNull(searchPhaseResult.queryResult().topDocs().topDocs.scoreDocs)
&& searchPhaseResult.queryResult().topDocs().topDocs.scoreDocs.length > 0
&& isHybridQueryStartStopElement(searchPhaseResult.queryResult().topDocs().topDocs.scoreDocs[0]);
}

private <Result extends SearchPhaseResult> List<QuerySearchResult> getQueryPhaseSearchResults(
final SearchPhaseResults<Result> results
) {
return results.getAtomicArray()
.asList()
.stream()
.map(result -> result == null ? null : result.queryResult())
.collect(Collectors.toList());
}

private <Result extends SearchPhaseResult> Optional<FetchSearchResult> getFetchSearchResults(
final SearchPhaseResults<Result> searchPhaseResults
) {
Optional<Result> optionalFirstSearchPhaseResult = searchPhaseResults.getAtomicArray().asList().stream().findFirst();
return optionalFirstSearchPhaseResult.map(SearchPhaseResult::fetchResult);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.processor.combination;

import lombok.ToString;
import lombok.extern.log4j.Log4j2;

import java.util.Map;

@Log4j2
/**
* Abstracts combination of scores based on reciprocal rank fusion algorithm
*/
@ToString(onlyExplicitlyIncluded = true)

public class RRFScoreCombinationTechnique implements ScoreCombinationTechnique {
@ToString.Include
public static final String TECHNIQUE_NAME = "rrf";

// Not currently using weights for RRF, no need to modify or verify these params
public RRFScoreCombinationTechnique(final Map<String, Object> params, final ScoreCombinationUtil combinationUtil) {
;
}

@Override
public float combine(final float[] scores) {
float sumScores = 0.0f;
for (float score : scores) {
sumScores += score;
}
return sumScores;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,17 @@ public class ScoreCombinationFactory {
Map.of(),
scoreCombinationUtil
);
public static final ScoreCombinationTechnique RRF_METHOD = new RRFScoreCombinationTechnique(Map.of(), scoreCombinationUtil);

private final Map<String, Function<Map<String, Object>, ScoreCombinationTechnique>> scoreCombinationMethodsMap = Map.of(
ArithmeticMeanScoreCombinationTechnique.TECHNIQUE_NAME,
params -> new ArithmeticMeanScoreCombinationTechnique(params, scoreCombinationUtil),
HarmonicMeanScoreCombinationTechnique.TECHNIQUE_NAME,
params -> new HarmonicMeanScoreCombinationTechnique(params, scoreCombinationUtil),
GeometricMeanScoreCombinationTechnique.TECHNIQUE_NAME,
params -> new GeometricMeanScoreCombinationTechnique(params, scoreCombinationUtil)
params -> new GeometricMeanScoreCombinationTechnique(params, scoreCombinationUtil),
RRFScoreCombinationTechnique.TECHNIQUE_NAME,
params -> new RRFScoreCombinationTechnique(params, scoreCombinationUtil)
);

/**
Expand Down
Loading
Loading