diff --git a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java index de55c2b25..d2933f508 100644 --- a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java +++ b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java @@ -9,6 +9,7 @@ import java.util.Collection; import java.util.Collections; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.Optional; import java.util.function.Supplier; @@ -97,10 +98,22 @@ public Map getProcessors(Processor.Parameters paramet @Override public Optional getQueryPhaseSearcher() { if (FeatureFlags.isEnabled(NEURAL_SEARCH_HYBRID_SEARCH_ENABLED)) { - log.info("Registering hybrid query phase searcher"); + log.info( + String.format( + Locale.ROOT, + "Registering hybrid query phase searcher with feature flag [%s]", + NEURAL_SEARCH_HYBRID_SEARCH_ENABLED + ) + ); return Optional.of(new HybridQueryPhaseSearcher()); } - log.info("Not registering hybrid query phase searcher because feature flag is disabled"); + log.info( + String.format( + Locale.ROOT, + "Not registering hybrid query phase searcher because feature flag [%s] is disabled", + NEURAL_SEARCH_HYBRID_SEARCH_ENABLED + ) + ); // we want feature be disabled by default due to risk of colliding and breaking concurrent search in core return Optional.empty(); } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java index cfa6a4187..af9dc6096 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java @@ -12,12 +12,12 @@ import java.util.Optional; import java.util.stream.Collectors; -import joptsimple.internal.Strings; import lombok.AccessLevel; import lombok.Getter; import lombok.extern.log4j.Log4j2; import org.apache.commons.lang3.EnumUtils; +import org.apache.commons.lang3.StringUtils; import org.opensearch.action.search.QueryPhaseResultConsumer; import org.opensearch.action.search.SearchPhaseContext; import org.opensearch.action.search.SearchPhaseName; @@ -93,17 +93,14 @@ public void process( return; } - TopDocsAndMaxScore[] topDocsAndMaxScores = getCompoundQueryTopDocsForResult(searchPhaseResult); - CompoundTopDocs[] queryTopDocs = Arrays.stream(topDocsAndMaxScores) + TopDocsAndMaxScore[] topDocsAndMaxScores = getCompoundQueryTopDocsFromSearchPhaseResult(searchPhaseResult); + List queryTopDocs = Arrays.stream(topDocsAndMaxScores) .map(td -> td != null ? (CompoundTopDocs) td.topDocs : null) - .collect(Collectors.toList()) - .toArray(CompoundTopDocs[]::new); + .collect(Collectors.toList()); - ScoreNormalizer scoreNormalizer = new ScoreNormalizer(); - scoreNormalizer.normalizeScores(queryTopDocs, normalizationTechnique); + ScoreNormalizer.normalizeScores(queryTopDocs, normalizationTechnique); - ScoreCombiner scoreCombinator = new ScoreCombiner(); - List combinedMaxScores = scoreCombinator.combineScores(queryTopDocs, combinationTechnique); + List combinedMaxScores = ScoreCombiner.combineScores(queryTopDocs, combinationTechnique); updateOriginalQueryResults(searchPhaseResult, queryTopDocs, topDocsAndMaxScores, combinedMaxScores); } @@ -140,10 +137,10 @@ public boolean isIgnoreFailure() { } protected void validateParameters(final String normalizationTechniqueName, final String combinationTechniqueName) { - if (Strings.isNullOrEmpty(normalizationTechniqueName)) { + if (StringUtils.isEmpty(normalizationTechniqueName)) { throw new IllegalArgumentException("normalization technique cannot be empty"); } - if (Strings.isNullOrEmpty(combinationTechniqueName)) { + if (StringUtils.isEmpty(combinationTechniqueName)) { throw new IllegalArgumentException("combination technique cannot be empty"); } if (!EnumUtils.isValidEnum(ScoreNormalizationTechnique.class, normalizationTechniqueName)) { @@ -162,24 +159,22 @@ private boolean isNotHybridQuery(final Optional maybeResult) || !(maybeResult.get().queryResult().topDocs().topDocs instanceof CompoundTopDocs); } - private TopDocsAndMaxScore[] getCompoundQueryTopDocsForResult( + private TopDocsAndMaxScore[] getCompoundQueryTopDocsFromSearchPhaseResult( final SearchPhaseResults results ) { List preShardResultList = results.getAtomicArray().asList(); TopDocsAndMaxScore[] result = new TopDocsAndMaxScore[results.getAtomicArray().length()]; - int idx = 0; - for (Result shardResult : preShardResultList) { + for (int idx = 0; idx < preShardResultList.size(); idx++) { + Result shardResult = preShardResultList.get(idx); if (shardResult == null) { - idx++; continue; } QuerySearchResult querySearchResult = shardResult.queryResult(); TopDocsAndMaxScore topDocsAndMaxScore = querySearchResult.topDocs(); if (!(topDocsAndMaxScore.topDocs instanceof CompoundTopDocs)) { - idx++; continue; } - result[idx++] = topDocsAndMaxScore; + result[idx] = topDocsAndMaxScore; } return result; } @@ -187,14 +182,14 @@ private TopDocsAndMaxScore[] getCompoundQuery @VisibleForTesting protected void updateOriginalQueryResults( final SearchPhaseResults results, - final CompoundTopDocs[] queryTopDocs, + final List queryTopDocs, TopDocsAndMaxScore[] topDocsAndMaxScores, List combinedMaxScores ) { List preShardResultList = results.getAtomicArray().asList(); for (int i = 0; i < preShardResultList.size(); i++) { QuerySearchResult querySearchResult = preShardResultList.get(i).queryResult(); - CompoundTopDocs updatedTopDocs = queryTopDocs[i]; + CompoundTopDocs updatedTopDocs = queryTopDocs.get(i); if (updatedTopDocs == null) { continue; } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/ScoreCombiner.java b/src/main/java/org/opensearch/neuralsearch/processor/ScoreCombiner.java index 98604cb83..7e433068d 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/ScoreCombiner.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/ScoreCombiner.java @@ -36,10 +36,13 @@ public class ScoreCombiner { * @param combinationTechnique exact combination method that should be applied * @return list of max combined scores for each shard */ - public List combineScores(final CompoundTopDocs[] queryTopDocs, final ScoreCombinationTechnique combinationTechnique) { + public static List combineScores( + final List queryTopDocs, + final ScoreCombinationTechnique combinationTechnique + ) { List maxScores = new ArrayList<>(); - for (int i = 0; i < queryTopDocs.length; i++) { - CompoundTopDocs compoundQueryTopDocs = queryTopDocs[i]; + for (int i = 0; i < queryTopDocs.size(); i++) { + CompoundTopDocs compoundQueryTopDocs = queryTopDocs.get(i); if (Objects.isNull(compoundQueryTopDocs) || compoundQueryTopDocs.totalHits.value == 0) { maxScores.add(ZERO_SCORE); continue; @@ -76,9 +79,7 @@ public List combineScores(final CompoundTopDocs[] queryTopDocs, final Sco (a, b) -> Float.compare(combinedNormalizedScoresByDocId.get(b), combinedNormalizedScoresByDocId.get(a)) ); // we're merging docs with normalized and combined scores. we need to have only maxHits results - for (int docId : normalizedScoresPerDoc.keySet()) { - pq.add(docId); - } + pq.addAll(normalizedScoresPerDoc.keySet()); ScoreDoc[] finalScoreDocs = new ScoreDoc[maxHits]; float maxScore = combinedNormalizedScoresByDocId.get(pq.peek()); diff --git a/src/main/java/org/opensearch/neuralsearch/processor/ScoreNormalizer.java b/src/main/java/org/opensearch/neuralsearch/processor/ScoreNormalizer.java index 944be8ae8..ed2b6834a 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/ScoreNormalizer.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/ScoreNormalizer.java @@ -5,7 +5,6 @@ package org.opensearch.neuralsearch.processor; -import java.util.Arrays; import java.util.List; import java.util.Locale; import java.util.Objects; @@ -28,8 +27,8 @@ public class ScoreNormalizer { * @param queryTopDocs original query results from multiple shards and multiple sub-queries * @param normalizationTechnique exact normalization method that should be applied */ - public void normalizeScores(final CompoundTopDocs[] queryTopDocs, final ScoreNormalizationTechnique normalizationTechnique) { - Optional maybeCompoundQuery = Arrays.stream(queryTopDocs) + public static void normalizeScores(final List queryTopDocs, final ScoreNormalizationTechnique normalizationTechnique) { + Optional maybeCompoundQuery = queryTopDocs.stream() .filter(topDocs -> Objects.nonNull(topDocs) && !topDocs.getCompoundTopDocs().isEmpty()) .findAny(); if (maybeCompoundQuery.isEmpty()) { diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinerTests.java b/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinerTests.java index a2b37d6ee..513ed04b7 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinerTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinerTests.java @@ -10,7 +10,6 @@ import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TotalHits; -import org.opensearch.common.lucene.search.TopDocsAndMaxScore; import org.opensearch.neuralsearch.search.CompoundTopDocs; import org.opensearch.test.OpenSearchTestCase; @@ -18,8 +17,7 @@ public class ScoreCombinerTests extends OpenSearchTestCase { public void testEmptyResults_whenEmptyResultsAndDefaultMethod_thenNoProcessing() { ScoreCombiner scoreCombiner = new ScoreCombiner(); - final CompoundTopDocs[] queryTopDocs = new CompoundTopDocs[0]; - List maxScores = scoreCombiner.combineScores(queryTopDocs, ScoreCombinationTechnique.DEFAULT); + List maxScores = scoreCombiner.combineScores(List.of(), ScoreCombinationTechnique.DEFAULT); assertNotNull(maxScores); assertEquals(0, maxScores.size()); } @@ -27,7 +25,7 @@ public void testEmptyResults_whenEmptyResultsAndDefaultMethod_thenNoProcessing() public void testCombination_whenMultipleSubqueriesResultsAndDefaultMethod_thenScoresCombined() { ScoreCombiner scoreCombiner = new ScoreCombiner(); - final CompoundTopDocs[] queryTopDocs = new CompoundTopDocs[] { + final List queryTopDocs = List.of( new CompoundTopDocs( new TotalHits(3, TotalHits.Relation.EQUAL_TO), List.of( @@ -57,48 +55,33 @@ public void testCombination_whenMultipleSubqueriesResultsAndDefaultMethod_thenSc new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]), new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]) ) - ) }; + ) + ); - TopDocsAndMaxScore[] topDocsAndMaxScore = new TopDocsAndMaxScore[] { - new TopDocsAndMaxScore( - new TopDocs( - new TotalHits(3, TotalHits.Relation.EQUAL_TO), - new ScoreDoc[] { new ScoreDoc(1, 1.0f), new ScoreDoc(2, .25f), new ScoreDoc(4, 0.001f) } - ), - 1.0f - ), - new TopDocsAndMaxScore( - new TopDocs( - new TotalHits(4, TotalHits.Relation.EQUAL_TO), - new ScoreDoc[] { new ScoreDoc(2, 0.9f), new ScoreDoc(4, 0.6f), new ScoreDoc(7, 0.5f), new ScoreDoc(9, 0.01f) } - ), - 0.9f - ), - new TopDocsAndMaxScore(new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]), 0.0f) }; List combinedMaxScores = scoreCombiner.combineScores(queryTopDocs, ScoreCombinationTechnique.DEFAULT); assertNotNull(queryTopDocs); - assertEquals(3, queryTopDocs.length); + assertEquals(3, queryTopDocs.size()); - assertEquals(3, queryTopDocs[0].scoreDocs.length); - assertEquals(1.0, queryTopDocs[0].scoreDocs[0].score, 0.001f); - assertEquals(1, queryTopDocs[0].scoreDocs[0].doc); - assertEquals(1.0, queryTopDocs[0].scoreDocs[1].score, 0.001f); - assertEquals(3, queryTopDocs[0].scoreDocs[1].doc); - assertEquals(0.25, queryTopDocs[0].scoreDocs[2].score, 0.001f); - assertEquals(2, queryTopDocs[0].scoreDocs[2].doc); + assertEquals(3, queryTopDocs.get(0).scoreDocs.length); + assertEquals(1.0, queryTopDocs.get(0).scoreDocs[0].score, 0.001f); + assertEquals(1, queryTopDocs.get(0).scoreDocs[0].doc); + assertEquals(1.0, queryTopDocs.get(0).scoreDocs[1].score, 0.001f); + assertEquals(3, queryTopDocs.get(0).scoreDocs[1].doc); + assertEquals(0.25, queryTopDocs.get(0).scoreDocs[2].score, 0.001f); + assertEquals(2, queryTopDocs.get(0).scoreDocs[2].doc); - assertEquals(4, queryTopDocs[1].scoreDocs.length); - assertEquals(0.9, queryTopDocs[1].scoreDocs[0].score, 0.001f); - assertEquals(2, queryTopDocs[1].scoreDocs[0].doc); - assertEquals(0.6, queryTopDocs[1].scoreDocs[1].score, 0.001f); - assertEquals(4, queryTopDocs[1].scoreDocs[1].doc); - assertEquals(0.5, queryTopDocs[1].scoreDocs[2].score, 0.001f); - assertEquals(7, queryTopDocs[1].scoreDocs[2].doc); - assertEquals(0.01, queryTopDocs[1].scoreDocs[3].score, 0.001f); - assertEquals(9, queryTopDocs[1].scoreDocs[3].doc); + assertEquals(4, queryTopDocs.get(1).scoreDocs.length); + assertEquals(0.9, queryTopDocs.get(1).scoreDocs[0].score, 0.001f); + assertEquals(2, queryTopDocs.get(1).scoreDocs[0].doc); + assertEquals(0.6, queryTopDocs.get(1).scoreDocs[1].score, 0.001f); + assertEquals(4, queryTopDocs.get(1).scoreDocs[1].doc); + assertEquals(0.5, queryTopDocs.get(1).scoreDocs[2].score, 0.001f); + assertEquals(7, queryTopDocs.get(1).scoreDocs[2].doc); + assertEquals(0.01, queryTopDocs.get(1).scoreDocs[3].score, 0.001f); + assertEquals(9, queryTopDocs.get(1).scoreDocs[3].doc); - assertEquals(0, queryTopDocs[2].scoreDocs.length); + assertEquals(0, queryTopDocs.get(2).scoreDocs.length); assertEquals(3, combinedMaxScores.size()); assertEquals(1.0, combinedMaxScores.get(0), 0.001f); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizerTests.java b/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizerTests.java index 79455936b..78674c3d2 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizerTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizerTests.java @@ -20,22 +20,22 @@ public class ScoreNormalizerTests extends OpenSearchTestCase { public void testEmptyResults_whenEmptyResultsAndDefaultMethod_thenNoProcessing() { ScoreNormalizer scoreNormalizer = new ScoreNormalizer(); - final CompoundTopDocs[] queryTopDocs = new CompoundTopDocs[0]; - scoreNormalizer.normalizeScores(queryTopDocs, ScoreNormalizationTechnique.DEFAULT); + scoreNormalizer.normalizeScores(List.of(), ScoreNormalizationTechnique.DEFAULT); } @SneakyThrows public void testNormalization_whenOneSubqueryAndOneShardAndDefaultMethod_thenScoreNormalized() { ScoreNormalizer scoreNormalizer = new ScoreNormalizer(); - final CompoundTopDocs[] queryTopDocs = new CompoundTopDocs[] { + final List queryTopDocs = List.of( new CompoundTopDocs( new TotalHits(1, TotalHits.Relation.EQUAL_TO), List.of(new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(1, 2.0f) })) - ) }; + ) + ); scoreNormalizer.normalizeScores(queryTopDocs, ScoreNormalizationTechnique.DEFAULT); assertNotNull(queryTopDocs); - assertEquals(1, queryTopDocs.length); - CompoundTopDocs resultDoc = queryTopDocs[0]; + assertEquals(1, queryTopDocs.size()); + CompoundTopDocs resultDoc = queryTopDocs.get(0); assertNotNull(resultDoc.getCompoundTopDocs()); assertEquals(1, resultDoc.getCompoundTopDocs().size()); TopDocs topDoc = resultDoc.getCompoundTopDocs().get(0); @@ -51,7 +51,7 @@ public void testNormalization_whenOneSubqueryAndOneShardAndDefaultMethod_thenSco @SneakyThrows public void testNormalization_whenOneSubqueryMultipleHitsAndOneShardAndDefaultMethod_thenScoreNormalized() { ScoreNormalizer scoreNormalizer = new ScoreNormalizer(); - final CompoundTopDocs[] queryTopDocs = new CompoundTopDocs[] { + final List queryTopDocs = List.of( new CompoundTopDocs( new TotalHits(3, TotalHits.Relation.EQUAL_TO), List.of( @@ -60,11 +60,12 @@ public void testNormalization_whenOneSubqueryMultipleHitsAndOneShardAndDefaultMe new ScoreDoc[] { new ScoreDoc(1, 10.0f), new ScoreDoc(2, 2.5f), new ScoreDoc(4, 0.1f) } ) ) - ) }; + ) + ); scoreNormalizer.normalizeScores(queryTopDocs, ScoreNormalizationTechnique.DEFAULT); assertNotNull(queryTopDocs); - assertEquals(1, queryTopDocs.length); - CompoundTopDocs resultDoc = queryTopDocs[0]; + assertEquals(1, queryTopDocs.size()); + CompoundTopDocs resultDoc = queryTopDocs.get(0); assertNotNull(resultDoc.getCompoundTopDocs()); assertEquals(1, resultDoc.getCompoundTopDocs().size()); TopDocs topDoc = resultDoc.getCompoundTopDocs().get(0); @@ -82,7 +83,7 @@ public void testNormalization_whenOneSubqueryMultipleHitsAndOneShardAndDefaultMe public void testNormalization_whenMultipleSubqueriesMultipleHitsAndOneShardAndDefaultMethod_thenScoreNormalized() { ScoreNormalizer scoreNormalizer = new ScoreNormalizer(); - final CompoundTopDocs[] queryTopDocs = new CompoundTopDocs[] { + final List queryTopDocs = List.of( new CompoundTopDocs( new TotalHits(3, TotalHits.Relation.EQUAL_TO), List.of( @@ -95,12 +96,13 @@ public void testNormalization_whenMultipleSubqueriesMultipleHitsAndOneShardAndDe new ScoreDoc[] { new ScoreDoc(3, 0.8f), new ScoreDoc(5, 0.5f) } ) ) - ) }; + ) + ); scoreNormalizer.normalizeScores(queryTopDocs, ScoreNormalizationTechnique.DEFAULT); assertNotNull(queryTopDocs); - assertEquals(1, queryTopDocs.length); - CompoundTopDocs resultDoc = queryTopDocs[0]; + assertEquals(1, queryTopDocs.size()); + CompoundTopDocs resultDoc = queryTopDocs.get(0); assertNotNull(resultDoc.getCompoundTopDocs()); assertEquals(2, resultDoc.getCompoundTopDocs().size()); // sub-query one @@ -129,7 +131,7 @@ public void testNormalization_whenMultipleSubqueriesMultipleHitsAndOneShardAndDe public void testNormalization_whenMultipleSubqueriesMultipleHitsMultipleShardsAndDefaultMethod_thenScoreNormalized() { ScoreNormalizer scoreNormalizer = new ScoreNormalizer(); - final CompoundTopDocs[] queryTopDocs = new CompoundTopDocs[] { + final List queryTopDocs = List.of( new CompoundTopDocs( new TotalHits(3, TotalHits.Relation.EQUAL_TO), List.of( @@ -159,12 +161,13 @@ public void testNormalization_whenMultipleSubqueriesMultipleHitsMultipleShardsAn new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]), new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]) ) - ) }; + ) + ); scoreNormalizer.normalizeScores(queryTopDocs, ScoreNormalizationTechnique.DEFAULT); assertNotNull(queryTopDocs); - assertEquals(3, queryTopDocs.length); + assertEquals(3, queryTopDocs.size()); // shard one - CompoundTopDocs resultDocShardOne = queryTopDocs[0]; + CompoundTopDocs resultDocShardOne = queryTopDocs.get(0); assertEquals(2, resultDocShardOne.getCompoundTopDocs().size()); // sub-query one TopDocs topDocSubqueryOne = resultDocShardOne.getCompoundTopDocs().get(0); @@ -190,7 +193,7 @@ public void testNormalization_whenMultipleSubqueriesMultipleHitsMultipleShardsAn assertEquals(5, topDocSubqueryTwo.scoreDocs[topDocSubqueryTwo.scoreDocs.length - 1].doc); // shard two - CompoundTopDocs resultDocShardTwo = queryTopDocs[1]; + CompoundTopDocs resultDocShardTwo = queryTopDocs.get(1); assertEquals(2, resultDocShardTwo.getCompoundTopDocs().size()); // sub-query one TopDocs topDocShardTwoSubqueryOne = resultDocShardTwo.getCompoundTopDocs().get(0); @@ -210,7 +213,7 @@ public void testNormalization_whenMultipleSubqueriesMultipleHitsMultipleShardsAn assertEquals(9, topDocShardTwoSubqueryTwo.scoreDocs[topDocShardTwoSubqueryTwo.scoreDocs.length - 1].doc); // shard three - CompoundTopDocs resultDocShardThree = queryTopDocs[2]; + CompoundTopDocs resultDocShardThree = queryTopDocs.get(2); assertEquals(2, resultDocShardThree.getCompoundTopDocs().size()); // sub-query one TopDocs topDocShardThreeSubqueryOne = resultDocShardThree.getCompoundTopDocs().get(0);