From 9b0a28a03c1ff0c7a7cfe9bfd638412509aa8adc Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Tue, 13 Jun 2023 12:24:13 -0700 Subject: [PATCH] Adding search processor for score normalization and combination Signed-off-by: Martin Gaievski --- build.gradle | 6 +- .../neuralsearch/plugin/NeuralSearch.java | 18 +- .../processor/NormalizationProcessor.java | 204 ++++++++++ .../processor/ScoreCombinationTechnique.java | 43 ++ .../neuralsearch/processor/ScoreCombiner.java | 103 +++++ .../ScoreNormalizationTechnique.java | 57 +++ .../processor/ScoreNormalizer.java | 78 ++++ .../NormalizationProcessorFactory.java | 54 +++ .../common/BaseNeuralSearchIT.java | 111 +++++- .../plugin/NeuralSearchTests.java | 18 + .../NormalizationProcessorTests.java | 213 ++++++++++ .../processor/ScoreCombinerTests.java | 108 +++++ .../ScoreNormalizationCombinationIT.java | 372 ++++++++++++++++++ .../processor/ScoreNormalizerTests.java | 226 +++++++++++ .../NormalizationProcessorFactoryTests.java | 69 ++++ 15 files changed, 1675 insertions(+), 5 deletions(-) create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/ScoreCombinationTechnique.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/ScoreCombiner.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/ScoreNormalizationTechnique.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/ScoreNormalizer.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/factory/NormalizationProcessorFactory.java create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinerTests.java create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationCombinationIT.java create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizerTests.java create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/factory/NormalizationProcessorFactoryTests.java diff --git a/build.gradle b/build.gradle index 6e4b4ada4..541d70354 100644 --- a/build.gradle +++ b/build.gradle @@ -253,8 +253,12 @@ testClusters.integTest { // Increase heap size from default of 512mb to 1gb. When heap size is 512mb, our integ tests sporadically fail due // to ml-commons memory circuit breaker exception jvmArgs("-Xms1g", "-Xmx1g") - // enable hybrid search for testing + + // enable features for testing + // hybrid search systemProperty('neural_search_hybrid_search_enabled', 'true') + // search pipelines + systemProperty('opensearch.experimental.feature.search_pipeline.enabled', 'true') } // Remote Integration Tests diff --git a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java index 83f8c396b..de55c2b25 100644 --- a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java +++ b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java @@ -13,6 +13,8 @@ import java.util.Optional; import java.util.function.Supplier; +import lombok.extern.log4j.Log4j2; + import org.opensearch.client.Client; import org.opensearch.cluster.metadata.IndexNameExpressionResolver; import org.opensearch.cluster.service.ClusterService; @@ -24,7 +26,9 @@ import org.opensearch.ingest.Processor; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; +import org.opensearch.neuralsearch.processor.NormalizationProcessor; import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor; +import org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory; import org.opensearch.neuralsearch.processor.factory.TextEmbeddingProcessorFactory; import org.opensearch.neuralsearch.query.HybridQueryBuilder; import org.opensearch.neuralsearch.query.NeuralQueryBuilder; @@ -33,9 +37,11 @@ import org.opensearch.plugins.ExtensiblePlugin; import org.opensearch.plugins.IngestPlugin; import org.opensearch.plugins.Plugin; +import org.opensearch.plugins.SearchPipelinePlugin; import org.opensearch.plugins.SearchPlugin; import org.opensearch.repositories.RepositoriesService; import org.opensearch.script.ScriptService; +import org.opensearch.search.pipeline.SearchPhaseResultsProcessor; import org.opensearch.search.query.QueryPhaseSearcher; import org.opensearch.threadpool.ThreadPool; import org.opensearch.watcher.ResourceWatcherService; @@ -45,7 +51,8 @@ /** * Neural Search plugin class */ -public class NeuralSearch extends Plugin implements ActionPlugin, SearchPlugin, IngestPlugin, ExtensiblePlugin { +@Log4j2 +public class NeuralSearch extends Plugin implements ActionPlugin, SearchPlugin, IngestPlugin, ExtensiblePlugin, SearchPipelinePlugin { /** * Gates the functionality of hybrid search * Currently query phase searcher added with hybrid search will conflict with concurrent search in core. @@ -90,9 +97,18 @@ public Map getProcessors(Processor.Parameters paramet @Override public Optional getQueryPhaseSearcher() { if (FeatureFlags.isEnabled(NEURAL_SEARCH_HYBRID_SEARCH_ENABLED)) { + log.info("Registering hybrid query phase searcher"); return Optional.of(new HybridQueryPhaseSearcher()); } + log.info("Not registering hybrid query phase searcher because feature flag is disabled"); // we want feature be disabled by default due to risk of colliding and breaking concurrent search in core return Optional.empty(); } + + @Override + public Map> getSearchPhaseResultsProcessors( + Parameters parameters + ) { + return Map.of(NormalizationProcessor.TYPE, new NormalizationProcessorFactory()); + } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java new file mode 100644 index 000000000..45aa5bde8 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java @@ -0,0 +1,204 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor; + +import java.util.Arrays; +import java.util.List; +import java.util.Locale; +import java.util.Objects; +import java.util.Optional; +import java.util.stream.Collectors; + +import joptsimple.internal.Strings; +import lombok.AccessLevel; +import lombok.Getter; +import lombok.extern.log4j.Log4j2; + +import org.apache.commons.lang3.EnumUtils; +import org.opensearch.action.search.QueryPhaseResultConsumer; +import org.opensearch.action.search.SearchPhaseContext; +import org.opensearch.action.search.SearchPhaseName; +import org.opensearch.action.search.SearchPhaseResults; +import org.opensearch.common.lucene.search.TopDocsAndMaxScore; +import org.opensearch.neuralsearch.search.CompoundTopDocs; +import org.opensearch.search.SearchPhaseResult; +import org.opensearch.search.internal.SearchContext; +import org.opensearch.search.pipeline.SearchPhaseResultsProcessor; +import org.opensearch.search.query.QuerySearchResult; + +import com.google.common.annotations.VisibleForTesting; + +/** + * Processor for score normalization and combination on post query search results. Updates query results with + * normalized and combined scores for next phase (typically it's FETCH) + */ +@Log4j2 +public class NormalizationProcessor implements SearchPhaseResultsProcessor { + public static final String TYPE = "normalization-processor"; + public static final String NORMALIZATION_CLAUSE = "normalization"; + public static final String COMBINATION_CLAUSE = "combination"; + public static final String TECHNIQUE = "technique"; + + private final String tag; + private final String description; + @VisibleForTesting + @Getter(AccessLevel.PACKAGE) + final ScoreNormalizationTechnique normalizationTechnique; + @Getter(AccessLevel.PACKAGE) + final ScoreCombinationTechnique combinationTechnique; + + /** + * Need all args constructor to validate parameters and fail fast + * @param tag + * @param description + * @param normalizationTechnique + * @param combinationTechnique + */ + public NormalizationProcessor( + final String tag, + final String description, + final String normalizationTechnique, + final String combinationTechnique + ) { + this.tag = tag; + this.description = description; + validateParameters(normalizationTechnique, combinationTechnique); + this.normalizationTechnique = ScoreNormalizationTechnique.valueOf(normalizationTechnique); + this.combinationTechnique = ScoreCombinationTechnique.valueOf(combinationTechnique); + } + + /** + * Method abstracts functional aspect of score normalization and score combination. Exact methods for each processing stage + * are set as part of class constructor + * @param searchPhaseResult {@link SearchPhaseResults} DTO that has query search results. Results will be mutated as part of this method execution + * @param searchPhaseContext {@link SearchContext} + * @param + */ + @Override + public void process( + final SearchPhaseResults searchPhaseResult, + final SearchPhaseContext searchPhaseContext + ) { + if (searchPhaseResult instanceof QueryPhaseResultConsumer) { + QueryPhaseResultConsumer queryPhaseResultConsumer = (QueryPhaseResultConsumer) searchPhaseResult; + Optional maybeResult = queryPhaseResultConsumer.getAtomicArray() + .asList() + .stream() + .filter(Objects::nonNull) + .findFirst(); + if (isNotHybridQuery(maybeResult)) { + return; + } + + TopDocsAndMaxScore[] topDocsAndMaxScores = getCompoundQueryTopDocsForResult(searchPhaseResult); + CompoundTopDocs[] queryTopDocs = Arrays.stream(topDocsAndMaxScores) + .map(td -> td != null ? (CompoundTopDocs) td.topDocs : null) + .collect(Collectors.toList()) + .toArray(CompoundTopDocs[]::new); + + ScoreNormalizer scoreNormalizer = new ScoreNormalizer(); + scoreNormalizer.normalizeScores(queryTopDocs, normalizationTechnique); + + ScoreCombiner scoreCombinator = new ScoreCombiner(); + scoreCombinator.combineScores(topDocsAndMaxScores, queryTopDocs, combinationTechnique); + + updateOriginalQueryResults(searchPhaseResult, queryTopDocs); + } + } + + @Override + public SearchPhaseName getBeforePhase() { + return SearchPhaseName.QUERY; + } + + @Override + public SearchPhaseName getAfterPhase() { + return SearchPhaseName.FETCH; + } + + @Override + public String getType() { + return TYPE; + } + + @Override + public String getTag() { + return tag; + } + + @Override + public String getDescription() { + return description; + } + + @Override + public boolean isIgnoreFailure() { + return true; + } + + protected void validateParameters(final String normalizationTechniqueName, final String combinationTechniqueName) { + if (Strings.isNullOrEmpty(normalizationTechniqueName)) { + throw new IllegalArgumentException("normalization technique cannot be empty"); + } + if (Strings.isNullOrEmpty(combinationTechniqueName)) { + throw new IllegalArgumentException("combination technique cannot be empty"); + } + if (!EnumUtils.isValidEnum(ScoreNormalizationTechnique.class, normalizationTechniqueName)) { + log.error(String.format(Locale.ROOT, "provided normalization technique [%s] is not supported", normalizationTechniqueName)); + throw new IllegalArgumentException("provided normalization technique is not supported"); + } + if (!EnumUtils.isValidEnum(ScoreCombinationTechnique.class, combinationTechniqueName)) { + log.error(String.format(Locale.ROOT, "provided combination technique [%s] is not supported", combinationTechniqueName)); + throw new IllegalArgumentException("provided combination technique is not supported"); + } + } + + private boolean isNotHybridQuery(final Optional maybeResult) { + return maybeResult.isEmpty() + || Objects.isNull(maybeResult.get().queryResult()) + || !(maybeResult.get().queryResult().topDocs().topDocs instanceof CompoundTopDocs); + } + + private TopDocsAndMaxScore[] getCompoundQueryTopDocsForResult( + final SearchPhaseResults results + ) { + List preShardResultList = results.getAtomicArray().asList(); + TopDocsAndMaxScore[] result = new TopDocsAndMaxScore[results.getAtomicArray().length()]; + int idx = 0; + for (Result shardResult : preShardResultList) { + if (shardResult == null) { + idx++; + continue; + } + QuerySearchResult querySearchResult = shardResult.queryResult(); + TopDocsAndMaxScore topDocsAndMaxScore = querySearchResult.topDocs(); + if (!(topDocsAndMaxScore.topDocs instanceof CompoundTopDocs)) { + idx++; + continue; + } + result[idx++] = topDocsAndMaxScore; + } + return result; + } + + @VisibleForTesting + protected void updateOriginalQueryResults( + final SearchPhaseResults results, + final CompoundTopDocs[] queryTopDocs + ) { + List preShardResultList = results.getAtomicArray().asList(); + for (int i = 0; i < preShardResultList.size(); i++) { + QuerySearchResult querySearchResult = preShardResultList.get(i).queryResult(); + CompoundTopDocs updatedTopDocs = queryTopDocs[i]; + if (updatedTopDocs == null) { + continue; + } + float maxScore = updatedTopDocs.totalHits.value > 0 ? updatedTopDocs.scoreDocs[0].score : 0.0f; + TopDocsAndMaxScore topDocsAndMaxScore = new TopDocsAndMaxScore(updatedTopDocs, maxScore); + querySearchResult.topDocs(topDocsAndMaxScore, null); + } + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/ScoreCombinationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/ScoreCombinationTechnique.java new file mode 100644 index 000000000..adf23cd43 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/ScoreCombinationTechnique.java @@ -0,0 +1,43 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor; + +/** + * Collection of techniques for score combination + */ +public enum ScoreCombinationTechnique { + + /** + * Arithmetic mean method for combining scores. + * cscore = (score1 + score2 +...+ scoreN)/N + * + * Zero (0.0) scores are excluded from number of scores N + */ + ARITHMETIC_MEAN { + + @Override + public float combine(float[] scores) { + float combinedScore = 0.0f; + int count = 0; + for (float score : scores) { + if (score >= 0.0) { + combinedScore += score; + count++; + } + } + return combinedScore / count; + } + }; + + public static final ScoreCombinationTechnique DEFAULT = ARITHMETIC_MEAN; + + /** + * Defines combination function specific to this technique + * @param scores array of collected original scores + * @return combined score + */ + abstract float combine(float[] scores); +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/ScoreCombiner.java b/src/main/java/org/opensearch/neuralsearch/processor/ScoreCombiner.java new file mode 100644 index 000000000..c73467250 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/ScoreCombiner.java @@ -0,0 +1,103 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Objects; +import java.util.PriorityQueue; +import java.util.stream.Collectors; + +import lombok.extern.log4j.Log4j2; + +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; +import org.opensearch.common.lucene.search.TopDocsAndMaxScore; +import org.opensearch.neuralsearch.search.CompoundTopDocs; + +/** + * Abstracts combination of scores in query search results. + */ +@Log4j2 +public class ScoreCombiner { + + /** + * Performs score combination based on input combination technique. Mutates input object by updating combined scores + * @param topDocsAndMaxScores high level result object that needs to be mutated by setting total hits and max score based on results of normalization + * @param queryTopDocs query results that need to be normalized, mutated by method execution + * @param combinationTechnique exact combination method that should be applied + */ + public void combineScores( + final TopDocsAndMaxScore[] topDocsAndMaxScores, + final CompoundTopDocs[] queryTopDocs, + final ScoreCombinationTechnique combinationTechnique + ) { + for (int i = 0; i < queryTopDocs.length; i++) { + CompoundTopDocs compoundQueryTopDocs = queryTopDocs[i]; + if (Objects.isNull(compoundQueryTopDocs) || compoundQueryTopDocs.totalHits.value == 0) { + continue; + } + List topDocsPerSubQuery = compoundQueryTopDocs.getCompoundTopDocs(); + int shardId = compoundQueryTopDocs.scoreDocs[0].shardIndex; + Map normalizedScoresPerDoc = new HashMap<>(); + int maxHits = 0; + TotalHits.Relation totalHits = TotalHits.Relation.EQUAL_TO; + for (int j = 0; j < topDocsPerSubQuery.size(); j++) { + TopDocs topDocs = topDocsPerSubQuery.get(j); + int hits = 0; + for (ScoreDoc scoreDoc : topDocs.scoreDocs) { + if (!normalizedScoresPerDoc.containsKey(scoreDoc.doc)) { + float[] scores = new float[topDocsPerSubQuery.size()]; + // we initialize with -1.0, as after normalization it's possible that score is 0.0 + Arrays.fill(scores, -1.0f); + normalizedScoresPerDoc.put(scoreDoc.doc, scores); + } + normalizedScoresPerDoc.get(scoreDoc.doc)[j] = scoreDoc.score; + hits++; + } + maxHits = Math.max(maxHits, hits); + } + if (topDocsPerSubQuery.stream() + .anyMatch(topDocs -> topDocs.totalHits.relation == TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO)) { + totalHits = TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO; + } + Map combinedNormalizedScoresByDocId = normalizedScoresPerDoc.entrySet() + .stream() + .collect(Collectors.toMap(Map.Entry::getKey, entry -> combinationTechnique.combine(entry.getValue()))); + // create priority queue, make it max heap by the score + PriorityQueue pq = new PriorityQueue<>( + (a, b) -> Float.compare(combinedNormalizedScoresByDocId.get(b), combinedNormalizedScoresByDocId.get(a)) + ); + // we're merging docs with normalized and combined scores. we need to have only maxHits results + for (int docId : normalizedScoresPerDoc.keySet()) { + pq.add(docId); + } + + ScoreDoc[] finalScoreDocs = new ScoreDoc[maxHits]; + float maxScore = combinedNormalizedScoresByDocId.get(pq.peek()); + + for (int j = 0; j < maxHits; j++) { + int docId = pq.poll(); + finalScoreDocs[j] = new ScoreDoc(docId, combinedNormalizedScoresByDocId.get(docId), shardId); + } + compoundQueryTopDocs.scoreDocs = finalScoreDocs; + compoundQueryTopDocs.totalHits = new TotalHits(maxHits, totalHits); + log.warn( + String.format( + Locale.ROOT, + "update top docs maxScore, original value: %f, updated value %f", + topDocsAndMaxScores[i].maxScore, + maxScore + ) + ); + topDocsAndMaxScores[i].maxScore = maxScore; + } + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/ScoreNormalizationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/ScoreNormalizationTechnique.java new file mode 100644 index 000000000..f515168f7 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/ScoreNormalizationTechnique.java @@ -0,0 +1,57 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor; + +import lombok.Builder; +import lombok.Data; + +import com.google.common.primitives.Floats; + +/** + * Collection of techniques for score normalization + */ +public enum ScoreNormalizationTechnique { + + /** + * Min-max normalization method. + * nscore = (score - min_score)/(max_score - min_score) + */ + MIN_MAX { + @Override + float normalize(ScoreNormalizationRequest request) { + // edge case when there is only one score and min and max scores are same + if (Floats.compare(request.getMaxScore(), request.getMinScore()) == 0 + && Floats.compare(request.getMaxScore(), request.getScore()) == 0) { + return SINGLE_RESULT_SCORE; + } + float normalizedScore = (request.getScore() - request.getMinScore()) / (request.getMaxScore() - request.getMinScore()); + return normalizedScore == 0.0f ? MIN_SCORE : normalizedScore; + } + }; + + public static final ScoreNormalizationTechnique DEFAULT = MIN_MAX; + + static final float MIN_SCORE = 0.001f; + static final float SINGLE_RESULT_SCORE = 1.0f; + + /** + * Defines normalization function specific to this technique + * @param request complex request DTO that defines parameters like min/max scores etc. + * @return normalized score + */ + abstract float normalize(ScoreNormalizationRequest request); + + /** + * DTO for normalize method request + */ + @Data + @Builder + static class ScoreNormalizationRequest { + private final float score; + private final float minScore; + private final float maxScore; + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/ScoreNormalizer.java b/src/main/java/org/opensearch/neuralsearch/processor/ScoreNormalizer.java new file mode 100644 index 000000000..06d8b206a --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/ScoreNormalizer.java @@ -0,0 +1,78 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor; + +import java.util.Arrays; +import java.util.List; +import java.util.Objects; +import java.util.Optional; + +import lombok.extern.log4j.Log4j2; + +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.opensearch.neuralsearch.search.CompoundTopDocs; + +/** + * Abstracts normalization of scores in query search results. + */ +@Log4j2 +public class ScoreNormalizer { + + /** + * Performs score normalization based on input combination technique. Mutates input object by updating normalized scores. + * @param queryTopDocs original query results from multiple shards and multiple sub-queries + * @param normalizationTechnique exact normalization method that should be applied + */ + public void normalizeScores(final CompoundTopDocs[] queryTopDocs, final ScoreNormalizationTechnique normalizationTechnique) { + Optional maybeCompoundQuery = Arrays.stream(queryTopDocs) + .filter(topDocs -> Objects.nonNull(topDocs) && !topDocs.getCompoundTopDocs().isEmpty()) + .findAny(); + if (maybeCompoundQuery.isEmpty()) { + return; + } + + // init scores per sub-query + float[][] minMaxScores = new float[maybeCompoundQuery.get().getCompoundTopDocs().size()][]; + for (int i = 0; i < minMaxScores.length; i++) { + minMaxScores[i] = new float[] { Float.MAX_VALUE, Float.MIN_VALUE }; + } + + for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) { + if (compoundQueryTopDocs == null) { + continue; + } + List topDocsPerSubQuery = compoundQueryTopDocs.getCompoundTopDocs(); + for (int j = 0; j < topDocsPerSubQuery.size(); j++) { + TopDocs subQueryTopDoc = topDocsPerSubQuery.get(j); + // get min and max scores + for (ScoreDoc scoreDoc : subQueryTopDoc.scoreDocs) { + minMaxScores[j][0] = Math.min(minMaxScores[j][0], scoreDoc.score); + minMaxScores[j][1] = Math.max(minMaxScores[j][1], scoreDoc.score); + } + } + } + // do the normalization + for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) { + if (compoundQueryTopDocs == null) { + continue; + } + List topDocsPerSubQuery = compoundQueryTopDocs.getCompoundTopDocs(); + for (int j = 0; j < topDocsPerSubQuery.size(); j++) { + TopDocs subQueryTopDoc = topDocsPerSubQuery.get(j); + for (ScoreDoc scoreDoc : subQueryTopDoc.scoreDocs) { + ScoreNormalizationTechnique.ScoreNormalizationRequest normalizationRequest = + ScoreNormalizationTechnique.ScoreNormalizationRequest.builder() + .score(scoreDoc.score) + .minScore(minMaxScores[j][0]) + .maxScore(minMaxScores[j][1]) + .build(); + scoreDoc.score = normalizationTechnique.normalize(normalizationRequest); + } + } + } + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/NormalizationProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/NormalizationProcessorFactory.java new file mode 100644 index 000000000..de726ddc5 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/NormalizationProcessorFactory.java @@ -0,0 +1,54 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor.factory; + +import static org.opensearch.ingest.ConfigurationUtils.readOptionalMap; + +import java.util.Map; + +import org.opensearch.neuralsearch.processor.NormalizationProcessor; +import org.opensearch.neuralsearch.processor.ScoreCombinationTechnique; +import org.opensearch.neuralsearch.processor.ScoreNormalizationTechnique; +import org.opensearch.search.pipeline.Processor; +import org.opensearch.search.pipeline.SearchPhaseResultsProcessor; + +/** + * Factory for query results normalization processor for search pipeline. Instantiates processor based on user provided input. + */ +public class NormalizationProcessorFactory implements Processor.Factory { + + @Override + public SearchPhaseResultsProcessor create( + final Map> processorFactories, + final String tag, + final String description, + final boolean ignoreFailure, + final Map config, + final Processor.PipelineContext pipelineContext + ) throws Exception { + Map normalizationClause = readOptionalMap( + NormalizationProcessor.TYPE, + tag, + config, + NormalizationProcessor.NORMALIZATION_CLAUSE + ); + String normalizationTechnique = normalizationClause == null + ? ScoreNormalizationTechnique.DEFAULT.name() + : (String) normalizationClause.get(NormalizationProcessor.TECHNIQUE); + + Map combinationClause = readOptionalMap( + NormalizationProcessor.TYPE, + tag, + config, + NormalizationProcessor.COMBINATION_CLAUSE + ); + String combinationTechnique = combinationClause == null + ? ScoreCombinationTechnique.DEFAULT.name() + : (String) combinationClause.get(NormalizationProcessor.TECHNIQUE); + + return new NormalizationProcessor(tag, description, normalizationTechnique, combinationTechnique); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java b/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java index e1f907a5c..c6a3d722d 100644 --- a/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java +++ b/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java @@ -46,6 +46,8 @@ import org.opensearch.index.query.QueryBuilder; import org.opensearch.knn.index.SpaceType; import org.opensearch.neuralsearch.OpenSearchSecureRestTestCase; +import org.opensearch.neuralsearch.processor.ScoreCombinationTechnique; +import org.opensearch.neuralsearch.processor.ScoreNormalizationTechnique; import com.google.common.collect.ImmutableList; @@ -281,6 +283,27 @@ protected Map search(String index, QueryBuilder queryBuilder, in */ @SneakyThrows protected Map search(String index, QueryBuilder queryBuilder, QueryBuilder rescorer, int resultSize) { + return search(index, queryBuilder, rescorer, resultSize, Map.of()); + } + + /** + * Execute a search request initialized from a neural query builder that can add a rescore query to the request + * + * @param index Index to search against + * @param queryBuilder queryBuilder to produce source of query + * @param rescorer used for rescorer query builder + * @param resultSize number of results to return in the search + * @param requestParams additional request params for search + * @return Search results represented as a map + */ + @SneakyThrows + protected Map search( + String index, + QueryBuilder queryBuilder, + QueryBuilder rescorer, + int resultSize, + Map requestParams + ) { XContentBuilder builder = XContentFactory.jsonBuilder().startObject().field("query"); queryBuilder.toXContent(builder, ToXContent.EMPTY_PARAMS); @@ -294,6 +317,9 @@ protected Map search(String index, QueryBuilder queryBuilder, Qu Request request = new Request("POST", "/" + index + "/_search"); request.addParameter("size", Integer.toString(resultSize)); + if (requestParams != null && !requestParams.isEmpty()) { + requestParams.forEach(request::addParameter); + } request.setJsonEntity(Strings.toString(builder)); Response response = client().performRequest(request); @@ -386,7 +412,12 @@ protected int getHitCount(Map searchResponseAsMap) { */ @SneakyThrows protected void prepareKnnIndex(String indexName, List knnFieldConfigs) { - createIndexWithConfiguration(indexName, buildIndexConfiguration(knnFieldConfigs), ""); + prepareKnnIndex(indexName, knnFieldConfigs, 3); + } + + @SneakyThrows + protected void prepareKnnIndex(String indexName, List knnFieldConfigs, int numOfShards) { + createIndexWithConfiguration(indexName, buildIndexConfiguration(knnFieldConfigs, numOfShards), ""); } /** @@ -425,11 +456,11 @@ protected boolean checkComplete(Map node) { } @SneakyThrows - private String buildIndexConfiguration(List knnFieldConfigs) { + private String buildIndexConfiguration(List knnFieldConfigs, int numberOfShards) { XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() .startObject() .startObject("settings") - .field("number_of_shards", 3) + .field("number_of_shards", numberOfShards) .field("index.knn", true) .endObject() .startObject("mappings") @@ -524,4 +555,78 @@ protected void deleteModel(String modelId) { public boolean isUpdateClusterSettings() { return true; } + + @SneakyThrows + protected void createSearchPipelineWithResultsPostProcessor(final String pipelineId) { + createSearchPipeline( + pipelineId, + ScoreNormalizationTechnique.MIN_MAX.name(), + ScoreCombinationTechnique.ARITHMETIC_MEAN.name(), + Map.of() + ); + } + + @SneakyThrows + protected void createSearchPipeline( + final String pipelineId, + final String normalizationMethod, + String combinationMethod, + final Map combinationParams + ) { + makeRequest( + client(), + "PUT", + String.format(LOCALE, "/_search/pipeline/%s", pipelineId), + null, + toHttpEntity( + String.format( + LOCALE, + "{\"description\": \"Post processor pipeline\"," + + "\"phase_results_processors\": [{ " + + "\"normalization-processor\": {" + + "\"normalization\": {" + + "\"technique\": \"%s\"" + + "}," + + "\"combination\": {" + + "\"technique\": \"%s\"" + + "}" + + "}}]}", + normalizationMethod, + combinationMethod + ) + ), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT)) + ); + } + + @SneakyThrows + protected void createSearchPipelineWithDefaultResultsPostProcessor(final String pipelineId) { + makeRequest( + client(), + "PUT", + String.format(LOCALE, "/_search/pipeline/%s", pipelineId), + null, + toHttpEntity( + String.format( + LOCALE, + "{\"description\": \"Post processor pipeline\"," + + "\"phase_results_processors\": [{ " + + "\"normalization-processor\": {}}]}" + ) + ), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT)) + ); + } + + @SneakyThrows + protected void deleteSearchPipeline(final String pipelineId) { + makeRequest( + client(), + "DELETE", + String.format(LOCALE, "/_search/pipeline/%s", pipelineId), + null, + toHttpEntity(""), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT)) + ); + } } diff --git a/src/test/java/org/opensearch/neuralsearch/plugin/NeuralSearchTests.java b/src/test/java/org/opensearch/neuralsearch/plugin/NeuralSearchTests.java index c4b1d49f7..7918126c5 100644 --- a/src/test/java/org/opensearch/neuralsearch/plugin/NeuralSearchTests.java +++ b/src/test/java/org/opensearch/neuralsearch/plugin/NeuralSearchTests.java @@ -12,12 +12,16 @@ import java.util.Optional; import org.opensearch.ingest.Processor; +import org.opensearch.neuralsearch.processor.NormalizationProcessor; import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor; +import org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory; import org.opensearch.neuralsearch.query.HybridQueryBuilder; import org.opensearch.neuralsearch.query.NeuralQueryBuilder; import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; import org.opensearch.neuralsearch.search.query.HybridQueryPhaseSearcher; +import org.opensearch.plugins.SearchPipelinePlugin; import org.opensearch.plugins.SearchPlugin; +import org.opensearch.search.pipeline.SearchPhaseResultsProcessor; import org.opensearch.search.query.QueryPhaseSearcher; public class NeuralSearchTests extends OpenSearchQueryTestCase { @@ -55,4 +59,18 @@ public void testProcessors() { assertNotNull(processors); assertNotNull(processors.get(TextEmbeddingProcessor.TYPE)); } + + public void testSearchPhaseResultsProcessors() { + NeuralSearch plugin = new NeuralSearch(); + SearchPipelinePlugin.Parameters parameters = mock(SearchPipelinePlugin.Parameters.class); + Map> searchPhaseResultsProcessors = plugin + .getSearchPhaseResultsProcessors(parameters); + assertNotNull(searchPhaseResultsProcessors); + assertEquals(1, searchPhaseResultsProcessors.size()); + assertTrue(searchPhaseResultsProcessors.containsKey("normalization-processor")); + org.opensearch.search.pipeline.Processor.Factory scoringProcessor = searchPhaseResultsProcessors.get( + NormalizationProcessor.TYPE + ); + assertTrue(scoringProcessor instanceof NormalizationProcessorFactory); + } } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java new file mode 100644 index 000000000..0f129e424 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java @@ -0,0 +1,213 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicReference; + +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; +import org.junit.After; +import org.junit.Before; +import org.opensearch.action.OriginalIndices; +import org.opensearch.action.search.QueryPhaseResultConsumer; +import org.opensearch.action.search.SearchPhaseContext; +import org.opensearch.action.search.SearchPhaseController; +import org.opensearch.action.search.SearchPhaseName; +import org.opensearch.action.search.SearchPhaseResults; +import org.opensearch.action.search.SearchProgressListener; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.common.breaker.CircuitBreaker; +import org.opensearch.common.breaker.NoopCircuitBreaker; +import org.opensearch.common.lucene.search.TopDocsAndMaxScore; +import org.opensearch.common.util.BigArrays; +import org.opensearch.common.util.concurrent.AtomicArray; +import org.opensearch.common.util.concurrent.OpenSearchExecutors; +import org.opensearch.common.util.concurrent.OpenSearchThreadPoolExecutor; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.neuralsearch.search.CompoundTopDocs; +import org.opensearch.search.DocValueFormat; +import org.opensearch.search.SearchPhaseResult; +import org.opensearch.search.SearchShardTarget; +import org.opensearch.search.aggregations.InternalAggregation; +import org.opensearch.search.aggregations.pipeline.PipelineAggregator; +import org.opensearch.search.query.QuerySearchResult; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; + +public class NormalizationProcessorTests extends OpenSearchTestCase { + private static final String PROCESSOR_TAG = "mockTag"; + private static final String DESCRIPTION = "mockDescription"; + private static final String INDEX_NAME = "index1"; + private SearchPhaseController searchPhaseController; + private ThreadPool threadPool; + private OpenSearchThreadPoolExecutor executor; + + @Before + public void setup() { + searchPhaseController = new SearchPhaseController(writableRegistry(), s -> new InternalAggregation.ReduceContextBuilder() { + @Override + public InternalAggregation.ReduceContext forPartialReduction() { + return InternalAggregation.ReduceContext.forPartialReduction( + BigArrays.NON_RECYCLING_INSTANCE, + null, + () -> PipelineAggregator.PipelineTree.EMPTY + ); + } + + public InternalAggregation.ReduceContext forFinalReduction() { + return InternalAggregation.ReduceContext.forFinalReduction( + BigArrays.NON_RECYCLING_INSTANCE, + null, + b -> {}, + PipelineAggregator.PipelineTree.EMPTY + ); + }; + }); + threadPool = new TestThreadPool(NormalizationProcessorTests.class.getName()); + executor = OpenSearchExecutors.newFixed( + "test", + 1, + 10, + OpenSearchExecutors.daemonThreadFactory("test"), + threadPool.getThreadContext() + ); + } + + @After + public void cleanup() { + executor.shutdownNow(); + terminate(threadPool); + } + + public void testSearchResultTypes_whenNotCompoundDocsOrEmptyResults_thenNoProcessing() { + NormalizationProcessor normalizationProcessor = spy( + new NormalizationProcessor( + PROCESSOR_TAG, + DESCRIPTION, + ScoreNormalizationTechnique.MIN_MAX.name(), + ScoreCombinationTechnique.ARITHMETIC_MEAN.name() + ) + ); + + assertEquals(SearchPhaseName.FETCH, normalizationProcessor.getAfterPhase()); + assertEquals(SearchPhaseName.QUERY, normalizationProcessor.getBeforePhase()); + + SearchPhaseResults searchPhaseResults = mock(SearchPhaseResults.class); + SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); + normalizationProcessor.process(searchPhaseResults, searchPhaseContext); + + verify(normalizationProcessor, never()).updateOriginalQueryResults(any(), any()); + + AtomicArray resultAtomicArray = new AtomicArray<>(1); + when(searchPhaseResults.getAtomicArray()).thenReturn(resultAtomicArray); + normalizationProcessor.process(searchPhaseResults, searchPhaseContext); + + verify(normalizationProcessor, never()).updateOriginalQueryResults(any(), any()); + } + + public void testInputValidation_whenInvalidParameters_thenFail() { + expectThrows( + IllegalArgumentException.class, + () -> new NormalizationProcessor(PROCESSOR_TAG, DESCRIPTION, "", ScoreCombinationTechnique.ARITHMETIC_MEAN.name()) + ); + + expectThrows( + IllegalArgumentException.class, + () -> new NormalizationProcessor(PROCESSOR_TAG, DESCRIPTION, ScoreNormalizationTechnique.MIN_MAX.name(), "") + ); + expectThrows( + IllegalArgumentException.class, + () -> new NormalizationProcessor( + PROCESSOR_TAG, + DESCRIPTION, + "random_name_for_normalization", + ScoreCombinationTechnique.ARITHMETIC_MEAN.name() + ) + ); + + expectThrows( + IllegalArgumentException.class, + () -> new NormalizationProcessor( + PROCESSOR_TAG, + DESCRIPTION, + ScoreNormalizationTechnique.MIN_MAX.name(), + "random_name_for_combination" + ) + ); + } + + public void testSearchResultTypes_whenCompoundDocs_thenDoNormalizationCombination() { + NormalizationProcessor normalizationProcessor = spy( + new NormalizationProcessor( + PROCESSOR_TAG, + DESCRIPTION, + ScoreNormalizationTechnique.MIN_MAX.name(), + ScoreCombinationTechnique.ARITHMETIC_MEAN.name() + ) + ); + + assertEquals(SearchPhaseName.FETCH, normalizationProcessor.getAfterPhase()); + assertEquals(SearchPhaseName.QUERY, normalizationProcessor.getBeforePhase()); + + SearchRequest searchRequest = new SearchRequest(INDEX_NAME); + searchRequest.setBatchedReduceSize(4); + AtomicReference onPartialMergeFailure = new AtomicReference<>(); + QueryPhaseResultConsumer queryPhaseResultConsumer = new QueryPhaseResultConsumer( + searchRequest, + executor, + new NoopCircuitBreaker(CircuitBreaker.REQUEST), + searchPhaseController, + SearchProgressListener.NOOP, + writableRegistry(), + 10, + e -> onPartialMergeFailure.accumulateAndGet(e, (prev, curr) -> { + curr.addSuppressed(prev); + return curr; + }) + ); + CountDownLatch partialReduceLatch = new CountDownLatch(5); + for (int shardId = 0; shardId < 4; shardId++) { + SearchShardTarget searchShardTarget = new SearchShardTarget( + "node", + new ShardId("index", "uuid", shardId), + null, + OriginalIndices.NONE + ); + QuerySearchResult querySearchResult = new QuerySearchResult(); + CompoundTopDocs topDocs = new CompoundTopDocs( + new TotalHits(4, TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs( + new TotalHits(4, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(0, 0.5f), new ScoreDoc(2, 0.3f), new ScoreDoc(4, 0.25f), new ScoreDoc(10, 0.2f) } + ) + ) + ); + querySearchResult.topDocs(new TopDocsAndMaxScore(topDocs, 0.5f), new DocValueFormat[0]); + querySearchResult.setSearchShardTarget(searchShardTarget); + querySearchResult.setShardIndex(shardId); + + queryPhaseResultConsumer.consumeResult(querySearchResult, partialReduceLatch::countDown); + } + + SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); + normalizationProcessor.process(queryPhaseResultConsumer, searchPhaseContext); + + verify(normalizationProcessor, times(1)).updateOriginalQueryResults(any(), any()); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinerTests.java b/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinerTests.java new file mode 100644 index 000000000..8ef16f6f9 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinerTests.java @@ -0,0 +1,108 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor; + +import java.util.List; + +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; +import org.opensearch.common.lucene.search.TopDocsAndMaxScore; +import org.opensearch.neuralsearch.search.CompoundTopDocs; +import org.opensearch.test.OpenSearchTestCase; + +public class ScoreCombinerTests extends OpenSearchTestCase { + + public void testEmptyResults_whenEmptyResultsAndDefaultMethod_thenNoProcessing() { + ScoreCombiner scoreCombiner = new ScoreCombiner(); + final CompoundTopDocs[] queryTopDocs = new CompoundTopDocs[0]; + TopDocsAndMaxScore[] topDocsAndMaxScore = new TopDocsAndMaxScore[] { + new TopDocsAndMaxScore(new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]), 0.0f) }; + scoreCombiner.combineScores(topDocsAndMaxScore, queryTopDocs, ScoreCombinationTechnique.DEFAULT); + } + + public void testCombination_whenMultipleSubqueriesResultsAndDefaultMethod_thenScoresCombined() { + ScoreCombiner scoreCombiner = new ScoreCombiner(); + + final CompoundTopDocs[] queryTopDocs = new CompoundTopDocs[] { + new CompoundTopDocs( + new TotalHits(3, TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs( + new TotalHits(3, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(1, 1.0f), new ScoreDoc(2, .25f), new ScoreDoc(4, 0.001f) } + ), + new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(3, 1.0f), new ScoreDoc(5, 0.001f) } + ) + ) + ), + new CompoundTopDocs( + new TotalHits(4, TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]), + new TopDocs( + new TotalHits(4, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(2, 0.9f), new ScoreDoc(4, 0.6f), new ScoreDoc(7, 0.5f), new ScoreDoc(9, 0.01f) } + ) + ) + ), + new CompoundTopDocs( + new TotalHits(0, TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]), + new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]) + ) + ) }; + + TopDocsAndMaxScore[] topDocsAndMaxScore = new TopDocsAndMaxScore[] { + new TopDocsAndMaxScore( + new TopDocs( + new TotalHits(3, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(1, 1.0f), new ScoreDoc(2, .25f), new ScoreDoc(4, 0.001f) } + ), + 1.0f + ), + new TopDocsAndMaxScore( + new TopDocs( + new TotalHits(4, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(2, 0.9f), new ScoreDoc(4, 0.6f), new ScoreDoc(7, 0.5f), new ScoreDoc(9, 0.01f) } + ), + 0.9f + ), + new TopDocsAndMaxScore(new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]), 0.0f) }; + scoreCombiner.combineScores(topDocsAndMaxScore, queryTopDocs, ScoreCombinationTechnique.DEFAULT); + + assertNotNull(queryTopDocs); + assertEquals(3, queryTopDocs.length); + + assertEquals(3, queryTopDocs[0].scoreDocs.length); + assertEquals(1.0, queryTopDocs[0].scoreDocs[0].score, 0.001f); + assertEquals(1, queryTopDocs[0].scoreDocs[0].doc); + assertEquals(1.0, queryTopDocs[0].scoreDocs[1].score, 0.001f); + assertEquals(3, queryTopDocs[0].scoreDocs[1].doc); + assertEquals(0.25, queryTopDocs[0].scoreDocs[2].score, 0.001f); + assertEquals(2, queryTopDocs[0].scoreDocs[2].doc); + + assertEquals(4, queryTopDocs[1].scoreDocs.length); + assertEquals(0.9, queryTopDocs[1].scoreDocs[0].score, 0.001f); + assertEquals(2, queryTopDocs[1].scoreDocs[0].doc); + assertEquals(0.6, queryTopDocs[1].scoreDocs[1].score, 0.001f); + assertEquals(4, queryTopDocs[1].scoreDocs[1].doc); + assertEquals(0.5, queryTopDocs[1].scoreDocs[2].score, 0.001f); + assertEquals(7, queryTopDocs[1].scoreDocs[2].doc); + assertEquals(0.01, queryTopDocs[1].scoreDocs[3].score, 0.001f); + assertEquals(9, queryTopDocs[1].scoreDocs[3].doc); + + assertEquals(0, queryTopDocs[2].scoreDocs.length); + + assertEquals(3, topDocsAndMaxScore.length); + assertEquals(1.0, topDocsAndMaxScore[0].maxScore, 0.001f); + assertEquals(0.9, topDocsAndMaxScore[1].maxScore, 0.001f); + assertEquals(0.0, topDocsAndMaxScore[2].maxScore, 0.001f); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationCombinationIT.java b/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationCombinationIT.java new file mode 100644 index 000000000..1045cf982 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationCombinationIT.java @@ -0,0 +1,372 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor; + +import static org.opensearch.neuralsearch.TestUtils.createRandomVector; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.IntStream; + +import lombok.SneakyThrows; + +import org.junit.Before; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.neuralsearch.common.BaseNeuralSearchIT; +import org.opensearch.neuralsearch.query.HybridQueryBuilder; +import org.opensearch.neuralsearch.query.NeuralQueryBuilder; + +import com.google.common.primitives.Floats; + +public class ScoreNormalizationCombinationIT extends BaseNeuralSearchIT { + private static final String TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME = "test-neural-multi-doc-one-shard-index"; + private static final String TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME = "test-neural-multi-doc-three-shards-index"; + private static final String TEST_QUERY_TEXT3 = "hello"; + private static final String TEST_QUERY_TEXT4 = "place"; + private static final String TEST_QUERY_TEXT5 = "welcome"; + private static final String TEST_QUERY_TEXT6 = "notexistingword"; + private static final String TEST_QUERY_TEXT7 = "notexistingwordtwo"; + private static final String TEST_DOC_TEXT1 = "Hello world"; + private static final String TEST_DOC_TEXT2 = "Hi to this place"; + private static final String TEST_DOC_TEXT3 = "We would like to welcome everyone"; + private static final String TEST_DOC_TEXT4 = "Hello, I'm glad to you see you pal"; + private static final String TEST_DOC_TEXT5 = "Say hello and enter my friend"; + private static final String TEST_KNN_VECTOR_FIELD_NAME_1 = "test-knn-vector-1"; + private static final String TEST_TEXT_FIELD_NAME_1 = "test-text-field-1"; + private static final int TEST_DIMENSION = 768; + private static final SpaceType TEST_SPACE_TYPE = SpaceType.L2; + private static final AtomicReference modelId = new AtomicReference<>(); + private static final String SEARCH_PIPELINE = "phase-results-pipeline"; + private final float[] testVector1 = createRandomVector(TEST_DIMENSION); + private final float[] testVector2 = createRandomVector(TEST_DIMENSION); + private final float[] testVector3 = createRandomVector(TEST_DIMENSION); + private final float[] testVector4 = createRandomVector(TEST_DIMENSION); + private final static String RELATION_EQUAL_TO = "eq"; + private final static String RELATION_GREATER_OR_EQUAL_TO = "gte"; + + @Before + public void setUp() throws Exception { + super.setUp(); + updateClusterSettings(); + modelId.compareAndSet(null, prepareModel()); + } + + @Override + public boolean isUpdateClusterSettings() { + return false; + } + + @Override + protected boolean preserveClusterUponCompletion() { + return true; + } + + /** + * Using search pipelines with result processor configs like below: + * { + * "description": "Post processor for hybrid search", + * "phase_results_processors": [ + * { + * "normalization-processor": { + * "normalization": { + * "technique": "min-max" + * }, + * "combination": { + * "technique": "sum", + * "parameters": { + * "weights": [ + * 0.4, 0.7 + * ] + * } + * } + * } + * } + * ] + * } + */ + @SneakyThrows + public void testResultProcessor_whenOneShardAndQueryMatches_thenSuccessful() { + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME); + createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE); + + NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId.get(), 5, null, null); + TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); + + HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); + hybridQueryBuilder.add(neuralQueryBuilder); + hybridQueryBuilder.add(termQueryBuilder); + + Map searchResponseAsMap = search( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + hybridQueryBuilder, + null, + 5, + Map.of("search_pipeline", SEARCH_PIPELINE) + ); + assertQueryResults(searchResponseAsMap, 5, false); + deleteSearchPipeline(SEARCH_PIPELINE); + } + + /** + * Using search pipelines with default result processor configs: + * { + * "description": "Post processor for hybrid search", + * "phase_results_processors": [ + * { + * "normalization-processor": { + * } + * } + * ] + * } + */ + @SneakyThrows + public void testResultProcessor_whenDefaultProcessorConfigAndQueryMatches_thenSuccessful() { + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME); + createSearchPipelineWithDefaultResultsPostProcessor(SEARCH_PIPELINE); + + NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId.get(), 5, null, null); + TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); + + HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); + hybridQueryBuilder.add(neuralQueryBuilder); + hybridQueryBuilder.add(termQueryBuilder); + + Map searchResponseAsMap = search( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + hybridQueryBuilder, + null, + 5, + Map.of("search_pipeline", SEARCH_PIPELINE) + ); + assertQueryResults(searchResponseAsMap, 5, false); + deleteSearchPipeline(SEARCH_PIPELINE); + } + + @SneakyThrows + public void testResultProcessor_whenMultipleShardsAndQueryMatches_thenSuccessful() { + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME); + createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE); + + NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId.get(), 6, null, null); + TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); + + HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); + hybridQueryBuilder.add(neuralQueryBuilder); + hybridQueryBuilder.add(termQueryBuilder); + + Map searchResponseAsMap = search( + TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, + hybridQueryBuilder, + null, + 6, + Map.of("search_pipeline", SEARCH_PIPELINE) + ); + assertQueryResults(searchResponseAsMap, 6, false); + deleteSearchPipeline(SEARCH_PIPELINE); + } + + @SneakyThrows + public void testResultProcessor_whenMultipleShardsAndNoMatches_thenSuccessful() { + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME); + createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE); + + HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); + hybridQueryBuilder.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT6)); + hybridQueryBuilder.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT7)); + + Map searchResponseAsMap = search( + TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, + hybridQueryBuilder, + null, + 5, + Map.of("search_pipeline", SEARCH_PIPELINE) + ); + assertQueryResults(searchResponseAsMap, 0, true); + deleteSearchPipeline(SEARCH_PIPELINE); + } + + @SneakyThrows + public void testResultProcessor_whenMultipleShardsAndPartialMatches_thenSuccessful() { + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME); + createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE); + + HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); + hybridQueryBuilder.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); + hybridQueryBuilder.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT4)); + hybridQueryBuilder.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT7)); + + Map searchResponseAsMap = search( + TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, + hybridQueryBuilder, + null, + 5, + Map.of("search_pipeline", SEARCH_PIPELINE) + ); + assertQueryResults(searchResponseAsMap, 4, true); + deleteSearchPipeline(SEARCH_PIPELINE); + } + + private void initializeIndexIfNotExist(String indexName) throws IOException { + if (TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME.equalsIgnoreCase(indexName) && !indexExists(TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME)) { + prepareKnnIndex( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + Collections.singletonList(new KNNFieldConfig(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DIMENSION, TEST_SPACE_TYPE)), + 1 + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + "1", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector1).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT1) + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + "2", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector2).toArray()) + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + "3", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector3).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT2) + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + "4", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector4).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT3) + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + "5", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector4).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT4) + ); + assertEquals(5, getDocCount(TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME)); + } + + if (TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME.equalsIgnoreCase(indexName) && !indexExists(TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME)) { + prepareKnnIndex( + TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, + Collections.singletonList(new KNNFieldConfig(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DIMENSION, TEST_SPACE_TYPE)), + 3 + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, + "1", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector1).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT1) + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, + "2", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector2).toArray()) + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, + "3", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector3).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT2) + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, + "4", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector4).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT3) + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, + "5", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector4).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT4) + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, + "6", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector4).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT5) + ); + assertEquals(6, getDocCount(TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME)); + } + } + + private List> getNestedHits(Map searchResponseAsMap) { + Map hitsMap = (Map) searchResponseAsMap.get("hits"); + return (List>) hitsMap.get("hits"); + } + + private Map getTotalHits(Map searchResponseAsMap) { + Map hitsMap = (Map) searchResponseAsMap.get("hits"); + return (Map) hitsMap.get("total"); + } + + private Optional getMaxScore(Map searchResponseAsMap) { + Map hitsMap = (Map) searchResponseAsMap.get("hits"); + return hitsMap.get("max_score") == null ? Optional.empty() : Optional.of(((Double) hitsMap.get("max_score")).floatValue()); + } + + private void assertQueryResults(Map searchResponseAsMap, int totalExpectedDocQty, boolean assertMinScore) { + assertNotNull(searchResponseAsMap); + Map total = getTotalHits(searchResponseAsMap); + assertNotNull(total.get("value")); + assertEquals(totalExpectedDocQty, total.get("value")); + assertNotNull(total.get("relation")); + assertEquals(RELATION_EQUAL_TO, total.get("relation")); + assertTrue(getMaxScore(searchResponseAsMap).isPresent()); + if (totalExpectedDocQty > 0) { + assertEquals(1.0, getMaxScore(searchResponseAsMap).get(), 0.001f); + } else { + assertEquals(0.0, getMaxScore(searchResponseAsMap).get(), 0.001f); + } + + List> hits1NestedList = getNestedHits(searchResponseAsMap); + List ids = new ArrayList<>(); + List scores = new ArrayList<>(); + for (Map oneHit : hits1NestedList) { + ids.add((String) oneHit.get("_id")); + scores.add((Double) oneHit.get("_score")); + } + // verify scores order + assertTrue(IntStream.range(0, scores.size() - 1).noneMatch(idx -> scores.get(idx) < scores.get(idx + 1))); + // verify the scores are normalized + if (totalExpectedDocQty > 0) { + assertEquals(1.0, (double) scores.stream().max(Double::compare).get(), 0.001); + if (assertMinScore) { + assertEquals(0.001, (double) scores.stream().min(Double::compare).get(), 0.001); + } + } + // verify that all ids are unique + assertEquals(Set.copyOf(ids).size(), ids.size()); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizerTests.java b/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizerTests.java new file mode 100644 index 000000000..79455936b --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizerTests.java @@ -0,0 +1,226 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor; + +import java.util.List; + +import lombok.SneakyThrows; + +import org.apache.commons.lang3.Range; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; +import org.opensearch.neuralsearch.search.CompoundTopDocs; +import org.opensearch.test.OpenSearchTestCase; + +public class ScoreNormalizerTests extends OpenSearchTestCase { + + public void testEmptyResults_whenEmptyResultsAndDefaultMethod_thenNoProcessing() { + ScoreNormalizer scoreNormalizer = new ScoreNormalizer(); + final CompoundTopDocs[] queryTopDocs = new CompoundTopDocs[0]; + scoreNormalizer.normalizeScores(queryTopDocs, ScoreNormalizationTechnique.DEFAULT); + } + + @SneakyThrows + public void testNormalization_whenOneSubqueryAndOneShardAndDefaultMethod_thenScoreNormalized() { + ScoreNormalizer scoreNormalizer = new ScoreNormalizer(); + final CompoundTopDocs[] queryTopDocs = new CompoundTopDocs[] { + new CompoundTopDocs( + new TotalHits(1, TotalHits.Relation.EQUAL_TO), + List.of(new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(1, 2.0f) })) + ) }; + scoreNormalizer.normalizeScores(queryTopDocs, ScoreNormalizationTechnique.DEFAULT); + assertNotNull(queryTopDocs); + assertEquals(1, queryTopDocs.length); + CompoundTopDocs resultDoc = queryTopDocs[0]; + assertNotNull(resultDoc.getCompoundTopDocs()); + assertEquals(1, resultDoc.getCompoundTopDocs().size()); + TopDocs topDoc = resultDoc.getCompoundTopDocs().get(0); + assertEquals(1, topDoc.totalHits.value); + assertEquals(TotalHits.Relation.EQUAL_TO, topDoc.totalHits.relation); + assertNotNull(topDoc.scoreDocs); + assertEquals(1, topDoc.scoreDocs.length); + ScoreDoc scoreDoc = topDoc.scoreDocs[0]; + assertEquals(1.0, scoreDoc.score, 0.001f); + assertEquals(1, scoreDoc.doc); + } + + @SneakyThrows + public void testNormalization_whenOneSubqueryMultipleHitsAndOneShardAndDefaultMethod_thenScoreNormalized() { + ScoreNormalizer scoreNormalizer = new ScoreNormalizer(); + final CompoundTopDocs[] queryTopDocs = new CompoundTopDocs[] { + new CompoundTopDocs( + new TotalHits(3, TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs( + new TotalHits(3, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(1, 10.0f), new ScoreDoc(2, 2.5f), new ScoreDoc(4, 0.1f) } + ) + ) + ) }; + scoreNormalizer.normalizeScores(queryTopDocs, ScoreNormalizationTechnique.DEFAULT); + assertNotNull(queryTopDocs); + assertEquals(1, queryTopDocs.length); + CompoundTopDocs resultDoc = queryTopDocs[0]; + assertNotNull(resultDoc.getCompoundTopDocs()); + assertEquals(1, resultDoc.getCompoundTopDocs().size()); + TopDocs topDoc = resultDoc.getCompoundTopDocs().get(0); + assertEquals(3, topDoc.totalHits.value); + assertEquals(TotalHits.Relation.EQUAL_TO, topDoc.totalHits.relation); + assertNotNull(topDoc.scoreDocs); + assertEquals(3, topDoc.scoreDocs.length); + ScoreDoc highScoreDoc = topDoc.scoreDocs[0]; + assertEquals(1.0, highScoreDoc.score, 0.001f); + assertEquals(1, highScoreDoc.doc); + ScoreDoc lowScoreDoc = topDoc.scoreDocs[topDoc.scoreDocs.length - 1]; + assertEquals(0.0, lowScoreDoc.score, 0.001f); + assertEquals(4, lowScoreDoc.doc); + } + + public void testNormalization_whenMultipleSubqueriesMultipleHitsAndOneShardAndDefaultMethod_thenScoreNormalized() { + ScoreNormalizer scoreNormalizer = new ScoreNormalizer(); + final CompoundTopDocs[] queryTopDocs = new CompoundTopDocs[] { + new CompoundTopDocs( + new TotalHits(3, TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs( + new TotalHits(3, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(1, 10.0f), new ScoreDoc(2, 2.5f), new ScoreDoc(4, 0.1f) } + ), + new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(3, 0.8f), new ScoreDoc(5, 0.5f) } + ) + ) + ) }; + scoreNormalizer.normalizeScores(queryTopDocs, ScoreNormalizationTechnique.DEFAULT); + + assertNotNull(queryTopDocs); + assertEquals(1, queryTopDocs.length); + CompoundTopDocs resultDoc = queryTopDocs[0]; + assertNotNull(resultDoc.getCompoundTopDocs()); + assertEquals(2, resultDoc.getCompoundTopDocs().size()); + // sub-query one + TopDocs topDocSubqueryOne = resultDoc.getCompoundTopDocs().get(0); + assertEquals(3, topDocSubqueryOne.totalHits.value); + assertEquals(TotalHits.Relation.EQUAL_TO, topDocSubqueryOne.totalHits.relation); + assertNotNull(topDocSubqueryOne.scoreDocs); + assertEquals(3, topDocSubqueryOne.scoreDocs.length); + ScoreDoc highScoreDoc = topDocSubqueryOne.scoreDocs[0]; + assertEquals(1.0, highScoreDoc.score, 0.001f); + assertEquals(1, highScoreDoc.doc); + ScoreDoc lowScoreDoc = topDocSubqueryOne.scoreDocs[topDocSubqueryOne.scoreDocs.length - 1]; + assertEquals(0.0, lowScoreDoc.score, 0.001f); + assertEquals(4, lowScoreDoc.doc); + // sub-query two + TopDocs topDocSubqueryTwo = resultDoc.getCompoundTopDocs().get(1); + assertEquals(2, topDocSubqueryTwo.totalHits.value); + assertEquals(TotalHits.Relation.EQUAL_TO, topDocSubqueryTwo.totalHits.relation); + assertNotNull(topDocSubqueryTwo.scoreDocs); + assertEquals(2, topDocSubqueryTwo.scoreDocs.length); + assertEquals(1.0, topDocSubqueryTwo.scoreDocs[0].score, 0.001f); + assertEquals(3, topDocSubqueryTwo.scoreDocs[0].doc); + assertEquals(0.0, topDocSubqueryTwo.scoreDocs[topDocSubqueryTwo.scoreDocs.length - 1].score, 0.001f); + assertEquals(5, topDocSubqueryTwo.scoreDocs[topDocSubqueryTwo.scoreDocs.length - 1].doc); + } + + public void testNormalization_whenMultipleSubqueriesMultipleHitsMultipleShardsAndDefaultMethod_thenScoreNormalized() { + ScoreNormalizer scoreNormalizer = new ScoreNormalizer(); + final CompoundTopDocs[] queryTopDocs = new CompoundTopDocs[] { + new CompoundTopDocs( + new TotalHits(3, TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs( + new TotalHits(3, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(1, 10.0f), new ScoreDoc(2, 2.5f), new ScoreDoc(4, 0.1f) } + ), + new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(3, 0.8f), new ScoreDoc(5, 0.5f) } + ) + ) + ), + new CompoundTopDocs( + new TotalHits(4, TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]), + new TopDocs( + new TotalHits(4, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(2, 2.2f), new ScoreDoc(4, 1.8f), new ScoreDoc(7, 0.9f), new ScoreDoc(9, 0.01f) } + ) + ) + ), + new CompoundTopDocs( + new TotalHits(0, TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]), + new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]) + ) + ) }; + scoreNormalizer.normalizeScores(queryTopDocs, ScoreNormalizationTechnique.DEFAULT); + assertNotNull(queryTopDocs); + assertEquals(3, queryTopDocs.length); + // shard one + CompoundTopDocs resultDocShardOne = queryTopDocs[0]; + assertEquals(2, resultDocShardOne.getCompoundTopDocs().size()); + // sub-query one + TopDocs topDocSubqueryOne = resultDocShardOne.getCompoundTopDocs().get(0); + assertEquals(3, topDocSubqueryOne.totalHits.value); + assertEquals(TotalHits.Relation.EQUAL_TO, topDocSubqueryOne.totalHits.relation); + assertNotNull(topDocSubqueryOne.scoreDocs); + assertEquals(3, topDocSubqueryOne.scoreDocs.length); + ScoreDoc highScoreDoc = topDocSubqueryOne.scoreDocs[0]; + assertEquals(1.0, highScoreDoc.score, 0.001f); + assertEquals(1, highScoreDoc.doc); + ScoreDoc lowScoreDoc = topDocSubqueryOne.scoreDocs[topDocSubqueryOne.scoreDocs.length - 1]; + assertEquals(0.0, lowScoreDoc.score, 0.001f); + assertEquals(4, lowScoreDoc.doc); + // sub-query two + TopDocs topDocSubqueryTwo = resultDocShardOne.getCompoundTopDocs().get(1); + assertEquals(2, topDocSubqueryTwo.totalHits.value); + assertEquals(TotalHits.Relation.EQUAL_TO, topDocSubqueryTwo.totalHits.relation); + assertNotNull(topDocSubqueryTwo.scoreDocs); + assertEquals(2, topDocSubqueryTwo.scoreDocs.length); + assertTrue(Range.between(0.0f, 1.0f).contains(topDocSubqueryTwo.scoreDocs[0].score)); + assertEquals(3, topDocSubqueryTwo.scoreDocs[0].doc); + assertTrue(Range.between(0.0f, 1.0f).contains(topDocSubqueryTwo.scoreDocs[topDocSubqueryTwo.scoreDocs.length - 1].score)); + assertEquals(5, topDocSubqueryTwo.scoreDocs[topDocSubqueryTwo.scoreDocs.length - 1].doc); + + // shard two + CompoundTopDocs resultDocShardTwo = queryTopDocs[1]; + assertEquals(2, resultDocShardTwo.getCompoundTopDocs().size()); + // sub-query one + TopDocs topDocShardTwoSubqueryOne = resultDocShardTwo.getCompoundTopDocs().get(0); + assertEquals(0, topDocShardTwoSubqueryOne.totalHits.value); + assertEquals(TotalHits.Relation.EQUAL_TO, topDocShardTwoSubqueryOne.totalHits.relation); + assertNotNull(topDocShardTwoSubqueryOne.scoreDocs); + assertEquals(0, topDocShardTwoSubqueryOne.scoreDocs.length); + // sub-query two + TopDocs topDocShardTwoSubqueryTwo = resultDocShardTwo.getCompoundTopDocs().get(1); + assertEquals(4, topDocShardTwoSubqueryTwo.totalHits.value); + assertEquals(TotalHits.Relation.EQUAL_TO, topDocShardTwoSubqueryTwo.totalHits.relation); + assertNotNull(topDocShardTwoSubqueryTwo.scoreDocs); + assertEquals(4, topDocShardTwoSubqueryTwo.scoreDocs.length); + assertEquals(1.0, topDocShardTwoSubqueryTwo.scoreDocs[0].score, 0.001f); + assertEquals(2, topDocShardTwoSubqueryTwo.scoreDocs[0].doc); + assertEquals(0.0, topDocShardTwoSubqueryTwo.scoreDocs[topDocShardTwoSubqueryTwo.scoreDocs.length - 1].score, 0.001f); + assertEquals(9, topDocShardTwoSubqueryTwo.scoreDocs[topDocShardTwoSubqueryTwo.scoreDocs.length - 1].doc); + + // shard three + CompoundTopDocs resultDocShardThree = queryTopDocs[2]; + assertEquals(2, resultDocShardThree.getCompoundTopDocs().size()); + // sub-query one + TopDocs topDocShardThreeSubqueryOne = resultDocShardThree.getCompoundTopDocs().get(0); + assertEquals(0, topDocShardThreeSubqueryOne.totalHits.value); + assertEquals(TotalHits.Relation.EQUAL_TO, topDocShardThreeSubqueryOne.totalHits.relation); + assertEquals(0, topDocShardThreeSubqueryOne.scoreDocs.length); + // sub-query two + TopDocs topDocShardThreeSubqueryTwo = resultDocShardThree.getCompoundTopDocs().get(1); + assertEquals(0, topDocShardThreeSubqueryTwo.totalHits.value); + assertEquals(TotalHits.Relation.EQUAL_TO, topDocShardThreeSubqueryTwo.totalHits.relation); + assertEquals(0, topDocShardThreeSubqueryTwo.scoreDocs.length); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/factory/NormalizationProcessorFactoryTests.java b/src/test/java/org/opensearch/neuralsearch/processor/factory/NormalizationProcessorFactoryTests.java new file mode 100644 index 000000000..ec8649ce5 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/factory/NormalizationProcessorFactoryTests.java @@ -0,0 +1,69 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor.factory; + +import static org.mockito.Mockito.mock; + +import java.util.HashMap; +import java.util.Map; + +import lombok.SneakyThrows; + +import org.opensearch.neuralsearch.processor.NormalizationProcessor; +import org.opensearch.search.pipeline.Processor; +import org.opensearch.search.pipeline.SearchPhaseResultsProcessor; +import org.opensearch.test.OpenSearchTestCase; + +public class NormalizationProcessorFactoryTests extends OpenSearchTestCase { + + @SneakyThrows + public void testNormalizationProcessor_whenNoParams_thenSuccessful() { + NormalizationProcessorFactory normalizationProcessorFactory = new NormalizationProcessorFactory(); + final Map> processorFactories = new HashMap<>(); + String tag = "tag"; + String description = "description"; + boolean ignoreFailure = false; + Map config = new HashMap<>(); + Processor.PipelineContext pipelineContext = mock(Processor.PipelineContext.class); + SearchPhaseResultsProcessor searchPhaseResultsProcessor = normalizationProcessorFactory.create( + processorFactories, + tag, + description, + ignoreFailure, + config, + pipelineContext + ); + assertNotNull(searchPhaseResultsProcessor); + assertTrue(searchPhaseResultsProcessor instanceof NormalizationProcessor); + NormalizationProcessor normalizationProcessor = (NormalizationProcessor) searchPhaseResultsProcessor; + assertEquals("normalization-processor", normalizationProcessor.getType()); + } + + @SneakyThrows + public void testNormalizationProcessor_whenWithParams_thenSuccessful() { + NormalizationProcessorFactory normalizationProcessorFactory = new NormalizationProcessorFactory(); + final Map> processorFactories = new HashMap<>(); + String tag = "tag"; + String description = "description"; + boolean ignoreFailure = false; + Map config = new HashMap<>(); + config.put("normalization", Map.of("technique", "MIN_MAX")); + config.put("combination", Map.of("technique", "ARITHMETIC_MEAN")); + Processor.PipelineContext pipelineContext = mock(Processor.PipelineContext.class); + SearchPhaseResultsProcessor searchPhaseResultsProcessor = normalizationProcessorFactory.create( + processorFactories, + tag, + description, + ignoreFailure, + config, + pipelineContext + ); + assertNotNull(searchPhaseResultsProcessor); + assertTrue(searchPhaseResultsProcessor instanceof NormalizationProcessor); + NormalizationProcessor normalizationProcessor = (NormalizationProcessor) searchPhaseResultsProcessor; + assertEquals("normalization-processor", normalizationProcessor.getType()); + } +}