Address review comments, part 1
Signed-off-by: Martin Gaievski <[email protected]>
martin-gaievski committed Jul 19, 2023
1 parent 3ebe720 commit 71c92bf
Showing 6 changed files with 83 additions and 89 deletions.
17 changes: 15 additions & 2 deletions src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java
@@ -9,6 +9,7 @@
 import java.util.Collection;
 import java.util.Collections;
 import java.util.List;
+import java.util.Locale;
 import java.util.Map;
 import java.util.Optional;
 import java.util.function.Supplier;
@@ -97,10 +98,22 @@ public Map<String, Processor.Factory> getProcessors(Processor.Parameters paramet
     @Override
     public Optional<QueryPhaseSearcher> getQueryPhaseSearcher() {
         if (FeatureFlags.isEnabled(NEURAL_SEARCH_HYBRID_SEARCH_ENABLED)) {
-            log.info("Registering hybrid query phase searcher");
+            log.info(
+                String.format(
+                    Locale.ROOT,
+                    "Registering hybrid query phase searcher with feature flag [%s]",
+                    NEURAL_SEARCH_HYBRID_SEARCH_ENABLED
+                )
+            );
             return Optional.of(new HybridQueryPhaseSearcher());
         }
-        log.info("Not registering hybrid query phase searcher because feature flag is disabled");
+        log.info(
+            String.format(
+                Locale.ROOT,
+                "Not registering hybrid query phase searcher because feature flag [%s] is disabled",
+                NEURAL_SEARCH_HYBRID_SEARCH_ENABLED
+            )
+        );
         // we want feature be disabled by default due to risk of colliding and breaking concurrent search in core
         return Optional.empty();
     }
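A note on the hunk above: the log messages now name the feature flag they key on, and String.format is pinned to Locale.ROOT so the output does not vary with the JVM's default locale. Below is a minimal standalone sketch of why the Locale.ROOT argument matters; the flag string is a hypothetical stand-in, not the plugin's real constant value.

import java.util.Locale;

// Standalone sketch (not plugin code): String.format with the default locale
// can change output depending on JVM settings; Locale.ROOT pins it.
public class LocaleRootDemo {
    public static void main(String[] args) {
        String flag = "neural_search.hybrid_search_enabled"; // hypothetical value
        System.out.println(String.format(Locale.ROOT, "Registering hybrid query phase searcher with feature flag [%s]", flag));
        // Locale sensitivity example: decimal separators differ by locale.
        System.out.println(String.format(Locale.GERMANY, "%.2f", 1.5)); // prints 1,50
        System.out.println(String.format(Locale.ROOT, "%.2f", 1.5));    // prints 1.50
    }
}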
src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java
@@ -12,12 +12,12 @@
 import java.util.Optional;
 import java.util.stream.Collectors;
 
-import joptsimple.internal.Strings;
 import lombok.AccessLevel;
 import lombok.Getter;
 import lombok.extern.log4j.Log4j2;
 
 import org.apache.commons.lang3.EnumUtils;
+import org.apache.commons.lang3.StringUtils;
 import org.opensearch.action.search.QueryPhaseResultConsumer;
 import org.opensearch.action.search.SearchPhaseContext;
 import org.opensearch.action.search.SearchPhaseName;
@@ -93,17 +93,14 @@ public <Result extends SearchPhaseResult> void process(
             return;
         }
 
-        TopDocsAndMaxScore[] topDocsAndMaxScores = getCompoundQueryTopDocsForResult(searchPhaseResult);
-        CompoundTopDocs[] queryTopDocs = Arrays.stream(topDocsAndMaxScores)
+        TopDocsAndMaxScore[] topDocsAndMaxScores = getCompoundQueryTopDocsFromSearchPhaseResult(searchPhaseResult);
+        List<CompoundTopDocs> queryTopDocs = Arrays.stream(topDocsAndMaxScores)
             .map(td -> td != null ? (CompoundTopDocs) td.topDocs : null)
-            .collect(Collectors.toList())
-            .toArray(CompoundTopDocs[]::new);
+            .collect(Collectors.toList());
 
-        ScoreNormalizer scoreNormalizer = new ScoreNormalizer();
-        scoreNormalizer.normalizeScores(queryTopDocs, normalizationTechnique);
+        ScoreNormalizer.normalizeScores(queryTopDocs, normalizationTechnique);
 
-        ScoreCombiner scoreCombinator = new ScoreCombiner();
-        List<Float> combinedMaxScores = scoreCombinator.combineScores(queryTopDocs, combinationTechnique);
+        List<Float> combinedMaxScores = ScoreCombiner.combineScores(queryTopDocs, combinationTechnique);
 
         updateOriginalQueryResults(searchPhaseResult, queryTopDocs, topDocsAndMaxScores, combinedMaxScores);
     }
@@ -140,10 +137,10 @@ public boolean isIgnoreFailure() {
     }
 
     protected void validateParameters(final String normalizationTechniqueName, final String combinationTechniqueName) {
-        if (Strings.isNullOrEmpty(normalizationTechniqueName)) {
+        if (StringUtils.isEmpty(normalizationTechniqueName)) {
             throw new IllegalArgumentException("normalization technique cannot be empty");
         }
-        if (Strings.isNullOrEmpty(combinationTechniqueName)) {
+        if (StringUtils.isEmpty(combinationTechniqueName)) {
             throw new IllegalArgumentException("combination technique cannot be empty");
         }
         if (!EnumUtils.isValidEnum(ScoreNormalizationTechnique.class, normalizationTechniqueName)) {
@@ -162,39 +159,37 @@ private boolean isNotHybridQuery(final Optional<SearchPhaseResult> maybeResult)
             || !(maybeResult.get().queryResult().topDocs().topDocs instanceof CompoundTopDocs);
     }
 
-    private <Result extends SearchPhaseResult> TopDocsAndMaxScore[] getCompoundQueryTopDocsForResult(
+    private <Result extends SearchPhaseResult> TopDocsAndMaxScore[] getCompoundQueryTopDocsFromSearchPhaseResult(
         final SearchPhaseResults<Result> results
     ) {
         List<Result> preShardResultList = results.getAtomicArray().asList();
         TopDocsAndMaxScore[] result = new TopDocsAndMaxScore[results.getAtomicArray().length()];
-        int idx = 0;
-        for (Result shardResult : preShardResultList) {
+        for (int idx = 0; idx < preShardResultList.size(); idx++) {
+            Result shardResult = preShardResultList.get(idx);
             if (shardResult == null) {
-                idx++;
                 continue;
             }
             QuerySearchResult querySearchResult = shardResult.queryResult();
             TopDocsAndMaxScore topDocsAndMaxScore = querySearchResult.topDocs();
             if (!(topDocsAndMaxScore.topDocs instanceof CompoundTopDocs)) {
-                idx++;
                 continue;
             }
-            result[idx++] = topDocsAndMaxScore;
+            result[idx] = topDocsAndMaxScore;
         }
         return result;
    }
 
     @VisibleForTesting
     protected <Result extends SearchPhaseResult> void updateOriginalQueryResults(
         final SearchPhaseResults<Result> results,
-        final CompoundTopDocs[] queryTopDocs,
+        final List<CompoundTopDocs> queryTopDocs,
         TopDocsAndMaxScore[] topDocsAndMaxScores,
         List<Float> combinedMaxScores
     ) {
         List<Result> preShardResultList = results.getAtomicArray().asList();
         for (int i = 0; i < preShardResultList.size(); i++) {
             QuerySearchResult querySearchResult = preShardResultList.get(i).queryResult();
-            CompoundTopDocs updatedTopDocs = queryTopDocs[i];
+            CompoundTopDocs updatedTopDocs = queryTopDocs.get(i);
             if (updatedTopDocs == null) {
                 continue;
             }
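A note on this file's refactor: the manual idx++ bookkeeping is replaced by an index-based loop so result[idx] stays aligned with each shard's position even when a shard is skipped, and queryTopDocs becomes a List while still preserving null slots so updateOriginalQueryResults can match entries back by index. Below is a small standalone sketch of that null-preserving positional mapping; the String type is a simplified stand-in for the plugin's TopDocsAndMaxScore/CompoundTopDocs.

import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.stream.Collectors;

// Standalone sketch: map per-shard values while keeping nulls in their slots,
// so downstream code can still match entries back to shards by index.
public class PositionalMappingSketch {
    public static void main(String[] args) {
        String[] perShard = { "shardA", null, "shardC" }; // null = no hybrid result for that shard
        List<String> mapped = Arrays.stream(perShard)
            .map(s -> s != null ? s.toUpperCase(Locale.ROOT) : null) // nulls stay in place
            .collect(Collectors.toList());
        System.out.println(mapped); // [SHARDA, null, SHARDC]
        // Note: Collectors.toList() tolerates null elements; List.of(...) would
        // throw NullPointerException, which is presumably why the processor
        // collects into a list rather than building one from a literal.
    }
}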
src/main/java/org/opensearch/neuralsearch/processor/ScoreCombiner.java
@@ -36,10 +36,13 @@ public class ScoreCombiner {
      * @param combinationTechnique exact combination method that should be applied
      * @return list of max combined scores for each shard
      */
-    public List<Float> combineScores(final CompoundTopDocs[] queryTopDocs, final ScoreCombinationTechnique combinationTechnique) {
+    public static List<Float> combineScores(
+        final List<CompoundTopDocs> queryTopDocs,
+        final ScoreCombinationTechnique combinationTechnique
+    ) {
         List<Float> maxScores = new ArrayList<>();
-        for (int i = 0; i < queryTopDocs.length; i++) {
-            CompoundTopDocs compoundQueryTopDocs = queryTopDocs[i];
+        for (int i = 0; i < queryTopDocs.size(); i++) {
+            CompoundTopDocs compoundQueryTopDocs = queryTopDocs.get(i);
             if (Objects.isNull(compoundQueryTopDocs) || compoundQueryTopDocs.totalHits.value == 0) {
                 maxScores.add(ZERO_SCORE);
                 continue;
@@ -76,9 +79,7 @@ public List<Float> combineScores(final CompoundTopDocs[] queryTopDocs, final Sco
             (a, b) -> Float.compare(combinedNormalizedScoresByDocId.get(b), combinedNormalizedScoresByDocId.get(a))
         );
         // we're merging docs with normalized and combined scores. we need to have only maxHits results
-        for (int docId : normalizedScoresPerDoc.keySet()) {
-            pq.add(docId);
-        }
+        pq.addAll(normalizedScoresPerDoc.keySet());
 
         ScoreDoc[] finalScoreDocs = new ScoreDoc[maxHits];
         float maxScore = combinedNormalizedScoresByDocId.get(pq.peek());
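The second hunk replaces a manual loop with PriorityQueue.addAll, which AbstractQueue implements as repeated add() calls, so behavior is unchanged. Below is a standalone sketch of the max-heap idiom the combiner uses, with illustrative names rather than the plugin's actual fields.

import java.util.HashMap;
import java.util.Map;
import java.util.PriorityQueue;

// Standalone sketch: a PriorityQueue ordered by a doc->score map, bulk-loaded
// with addAll, then polled to keep only the top maxHits documents.
public class ScoreHeapSketch {
    public static void main(String[] args) {
        Map<Integer, Float> scoreByDocId = new HashMap<>();
        scoreByDocId.put(1, 0.25f);
        scoreByDocId.put(2, 0.90f);
        scoreByDocId.put(4, 0.60f);

        // Highest combined score first, mirroring the comparator in combineScores.
        PriorityQueue<Integer> pq = new PriorityQueue<>(
            (a, b) -> Float.compare(scoreByDocId.get(b), scoreByDocId.get(a))
        );
        pq.addAll(scoreByDocId.keySet()); // equivalent to add() per element

        int maxHits = 2; // keep only the top maxHits docs
        for (int i = 0; i < maxHits && !pq.isEmpty(); i++) {
            int docId = pq.poll();
            System.out.println(docId + " -> " + scoreByDocId.get(docId));
        }
        // prints: 2 -> 0.9, then 4 -> 0.6
    }
}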
src/main/java/org/opensearch/neuralsearch/processor/ScoreNormalizer.java
@@ -5,7 +5,6 @@
 
 package org.opensearch.neuralsearch.processor;
 
-import java.util.Arrays;
 import java.util.List;
 import java.util.Locale;
 import java.util.Objects;
@@ -28,8 +27,8 @@ public class ScoreNormalizer {
      * @param queryTopDocs original query results from multiple shards and multiple sub-queries
      * @param normalizationTechnique exact normalization method that should be applied
      */
-    public void normalizeScores(final CompoundTopDocs[] queryTopDocs, final ScoreNormalizationTechnique normalizationTechnique) {
-        Optional<CompoundTopDocs> maybeCompoundQuery = Arrays.stream(queryTopDocs)
+    public static void normalizeScores(final List<CompoundTopDocs> queryTopDocs, final ScoreNormalizationTechnique normalizationTechnique) {
+        Optional<CompoundTopDocs> maybeCompoundQuery = queryTopDocs.stream()
             .filter(topDocs -> Objects.nonNull(topDocs) && !topDocs.getCompoundTopDocs().isEmpty())
             .findAny();
         if (maybeCompoundQuery.isEmpty()) {
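With this change both ScoreNormalizer.normalizeScores and ScoreCombiner.combineScores are static, so the NormalizationProcessor above no longer instantiates either helper. Below is a hedged usage sketch of the resulting call shape; the enum import packages and ScoreNormalizationTechnique.DEFAULT are assumptions made by analogy with the combination enum used in the tests, and an empty list stands in for real per-shard results.

import java.util.Collections;
import java.util.List;

import org.opensearch.neuralsearch.processor.ScoreCombinationTechnique; // package assumed
import org.opensearch.neuralsearch.processor.ScoreCombiner;
import org.opensearch.neuralsearch.processor.ScoreNormalizationTechnique; // package assumed
import org.opensearch.neuralsearch.processor.ScoreNormalizer;
import org.opensearch.neuralsearch.search.CompoundTopDocs;

// Hedged usage sketch: both helpers are now called statically, no instances needed.
public class StaticHelpersSketch {
    public static void main(String[] args) {
        List<CompoundTopDocs> queryTopDocs = Collections.emptyList();
        // DEFAULT for the normalization enum is an assumption, not confirmed by this diff.
        ScoreNormalizer.normalizeScores(queryTopDocs, ScoreNormalizationTechnique.DEFAULT);
        List<Float> maxScores = ScoreCombiner.combineScores(queryTopDocs, ScoreCombinationTechnique.DEFAULT);
        System.out.println(maxScores); // [] for empty input
    }
}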
src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinerTests.java
@@ -10,24 +10,22 @@
 import org.apache.lucene.search.ScoreDoc;
 import org.apache.lucene.search.TopDocs;
 import org.apache.lucene.search.TotalHits;
-import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
 import org.opensearch.neuralsearch.search.CompoundTopDocs;
 import org.opensearch.test.OpenSearchTestCase;
 
 public class ScoreCombinerTests extends OpenSearchTestCase {
 
     public void testEmptyResults_whenEmptyResultsAndDefaultMethod_thenNoProcessing() {
         ScoreCombiner scoreCombiner = new ScoreCombiner();
-        final CompoundTopDocs[] queryTopDocs = new CompoundTopDocs[0];
-        List<Float> maxScores = scoreCombiner.combineScores(queryTopDocs, ScoreCombinationTechnique.DEFAULT);
+        List<Float> maxScores = scoreCombiner.combineScores(List.of(), ScoreCombinationTechnique.DEFAULT);
         assertNotNull(maxScores);
         assertEquals(0, maxScores.size());
     }
 
     public void testCombination_whenMultipleSubqueriesResultsAndDefaultMethod_thenScoresCombined() {
         ScoreCombiner scoreCombiner = new ScoreCombiner();
 
-        final CompoundTopDocs[] queryTopDocs = new CompoundTopDocs[] {
+        final List<CompoundTopDocs> queryTopDocs = List.of(
             new CompoundTopDocs(
                 new TotalHits(3, TotalHits.Relation.EQUAL_TO),
                 List.of(
@@ -57,48 +55,33 @@ public void testCombination_whenMultipleSubqueriesResultsAndDefaultMethod_thenSc
                     new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]),
                     new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0])
                 )
-            ) };
+            )
+        );
 
-        TopDocsAndMaxScore[] topDocsAndMaxScore = new TopDocsAndMaxScore[] {
-            new TopDocsAndMaxScore(
-                new TopDocs(
-                    new TotalHits(3, TotalHits.Relation.EQUAL_TO),
-                    new ScoreDoc[] { new ScoreDoc(1, 1.0f), new ScoreDoc(2, .25f), new ScoreDoc(4, 0.001f) }
-                ),
-                1.0f
-            ),
-            new TopDocsAndMaxScore(
-                new TopDocs(
-                    new TotalHits(4, TotalHits.Relation.EQUAL_TO),
-                    new ScoreDoc[] { new ScoreDoc(2, 0.9f), new ScoreDoc(4, 0.6f), new ScoreDoc(7, 0.5f), new ScoreDoc(9, 0.01f) }
-                ),
-                0.9f
-            ),
-            new TopDocsAndMaxScore(new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]), 0.0f) };
         List<Float> combinedMaxScores = scoreCombiner.combineScores(queryTopDocs, ScoreCombinationTechnique.DEFAULT);
 
         assertNotNull(queryTopDocs);
-        assertEquals(3, queryTopDocs.length);
+        assertEquals(3, queryTopDocs.size());
 
-        assertEquals(3, queryTopDocs[0].scoreDocs.length);
-        assertEquals(1.0, queryTopDocs[0].scoreDocs[0].score, 0.001f);
-        assertEquals(1, queryTopDocs[0].scoreDocs[0].doc);
-        assertEquals(1.0, queryTopDocs[0].scoreDocs[1].score, 0.001f);
-        assertEquals(3, queryTopDocs[0].scoreDocs[1].doc);
-        assertEquals(0.25, queryTopDocs[0].scoreDocs[2].score, 0.001f);
-        assertEquals(2, queryTopDocs[0].scoreDocs[2].doc);
+        assertEquals(3, queryTopDocs.get(0).scoreDocs.length);
+        assertEquals(1.0, queryTopDocs.get(0).scoreDocs[0].score, 0.001f);
+        assertEquals(1, queryTopDocs.get(0).scoreDocs[0].doc);
+        assertEquals(1.0, queryTopDocs.get(0).scoreDocs[1].score, 0.001f);
+        assertEquals(3, queryTopDocs.get(0).scoreDocs[1].doc);
+        assertEquals(0.25, queryTopDocs.get(0).scoreDocs[2].score, 0.001f);
+        assertEquals(2, queryTopDocs.get(0).scoreDocs[2].doc);
 
-        assertEquals(4, queryTopDocs[1].scoreDocs.length);
-        assertEquals(0.9, queryTopDocs[1].scoreDocs[0].score, 0.001f);
-        assertEquals(2, queryTopDocs[1].scoreDocs[0].doc);
-        assertEquals(0.6, queryTopDocs[1].scoreDocs[1].score, 0.001f);
-        assertEquals(4, queryTopDocs[1].scoreDocs[1].doc);
-        assertEquals(0.5, queryTopDocs[1].scoreDocs[2].score, 0.001f);
-        assertEquals(7, queryTopDocs[1].scoreDocs[2].doc);
-        assertEquals(0.01, queryTopDocs[1].scoreDocs[3].score, 0.001f);
-        assertEquals(9, queryTopDocs[1].scoreDocs[3].doc);
+        assertEquals(4, queryTopDocs.get(1).scoreDocs.length);
+        assertEquals(0.9, queryTopDocs.get(1).scoreDocs[0].score, 0.001f);
+        assertEquals(2, queryTopDocs.get(1).scoreDocs[0].doc);
+        assertEquals(0.6, queryTopDocs.get(1).scoreDocs[1].score, 0.001f);
+        assertEquals(4, queryTopDocs.get(1).scoreDocs[1].doc);
+        assertEquals(0.5, queryTopDocs.get(1).scoreDocs[2].score, 0.001f);
+        assertEquals(7, queryTopDocs.get(1).scoreDocs[2].doc);
+        assertEquals(0.01, queryTopDocs.get(1).scoreDocs[3].score, 0.001f);
+        assertEquals(9, queryTopDocs.get(1).scoreDocs[3].doc);
 
-        assertEquals(0, queryTopDocs[2].scoreDocs.length);
+        assertEquals(0, queryTopDocs.get(2).scoreDocs.length);
 
         assertEquals(3, combinedMaxScores.size());
         assertEquals(1.0, combinedMaxScores.get(0), 0.001f);
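One subtlety in the switch to List.of(...) above: the returned list is structurally immutable, but the assertions still pass because combineScores rewrites the scoreDocs arrays inside each CompoundTopDocs rather than replacing list elements. Below is a standalone sketch of that distinction, using a simplified holder type in place of CompoundTopDocs.

import java.util.Arrays;
import java.util.List;

// Standalone sketch: List.of forbids add/remove/set, but mutating the state
// of an element it already holds is allowed.
public class ImmutableListMutationDemo {
    static final class Holder { int[] values = { 3, 1, 2 }; }

    public static void main(String[] args) {
        List<Holder> holders = List.of(new Holder());
        // holders.add(new Holder()); // would throw UnsupportedOperationException
        Arrays.sort(holders.get(0).values); // mutating an element is fine
        System.out.println(Arrays.toString(holders.get(0).values)); // [1, 2, 3]
    }
}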