Skip to content

Commit

Permalink
Refactor code, adding exceptions for case of inconsistent internal state
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski committed Aug 29, 2023
1 parent 03e43fc commit d33370e
Show file tree
Hide file tree
Showing 12 changed files with 92 additions and 85 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,30 +25,29 @@
import org.apache.lucene.search.TotalHits;

/**
* Class stores collection of TopDocs for each sub query from hybrid query
* Class stores collection of TopDocs for each sub query from hybrid query. Collection of results is at shard level. We do store
* list of TopDocs and list of ScoreDoc as well as total hits for the shard.
*/
@ToString(includeFieldNames = true)
@AllArgsConstructor
@Getter
@ToString(includeFieldNames = true)
@Log4j2
public class CompoundTopDocs {

@Getter
@Setter
private TotalHits totalHits;
@Getter
private List<TopDocs> compoundTopDocs;
@Getter
private List<TopDocs> topDocs;
@Setter
private List<ScoreDoc> scoreDocs;

public CompoundTopDocs(final TotalHits totalHits, final List<TopDocs> compoundTopDocs) {
initialize(totalHits, compoundTopDocs);
public CompoundTopDocs(final TotalHits totalHits, final List<TopDocs> topDocs) {
initialize(totalHits, topDocs);
}

private void initialize(TotalHits totalHits, List<TopDocs> compoundTopDocs) {
private void initialize(TotalHits totalHits, List<TopDocs> topDocs) {
this.totalHits = totalHits;
this.compoundTopDocs = compoundTopDocs;
scoreDocs = cloneLargestScoreDocs(compoundTopDocs);
this.topDocs = topDocs;
scoreDocs = cloneLargestScoreDocs(topDocs);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;

import lombok.AllArgsConstructor;
Expand Down Expand Up @@ -57,7 +58,7 @@ public <Result extends SearchPhaseResult> void process(
return;
}
List<QuerySearchResult> querySearchResults = getQueryPhaseSearchResults(searchPhaseResult);
FetchSearchResult fetchSearchResult = searchPhaseResult.getAtomicArray().asList().get(0).fetchResult();
Optional<FetchSearchResult> fetchSearchResult = getFetchSearchResults(searchPhaseResult);
normalizationWorkflow.execute(querySearchResults, fetchSearchResult, normalizationTechnique, combinationTechnique);
}

Expand Down Expand Up @@ -123,4 +124,11 @@ private <Result extends SearchPhaseResult> List<QuerySearchResult> getQueryPhase
.map(result -> result == null ? null : result.queryResult())
.collect(Collectors.toList());
}

private <Result extends SearchPhaseResult> Optional<FetchSearchResult> getFetchSearchResults(
    final SearchPhaseResults<Result> searchPhaseResults
) {
    // Fetch results exist at most once per phase; we only need the first shard-level
    // entry (if any) and its fetch-phase payload. An absent or null fetch result
    // collapses to Optional.empty() for the caller.
    List<Result> shardResults = searchPhaseResults.getAtomicArray().asList();
    if (shardResults.isEmpty()) {
        return Optional.empty();
    }
    return Optional.ofNullable(shardResults.get(0).fetchResult());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ public class NormalizationProcessorWorkflow {
*/
public void execute(
final List<QuerySearchResult> querySearchResults,
final FetchSearchResult fetchSearchResult,
final Optional<FetchSearchResult> fetchSearchResultOptional,
final ScoreNormalizationTechnique normalizationTechnique,
final ScoreCombinationTechnique combinationTechnique
) {
Expand All @@ -67,7 +67,7 @@ public void execute(
// post-process data
log.debug("Post-process query results after score normalization and combination");
updateOriginalQueryResults(querySearchResults, queryTopDocs);
updateOriginalFetchResults(querySearchResults, Optional.ofNullable(fetchSearchResult));
updateOriginalFetchResults(querySearchResults, fetchSearchResultOptional);
}

/**
Expand All @@ -82,7 +82,15 @@ private List<CompoundTopDocs> getQueryTopDocs(final List<QuerySearchResult> quer
.map(CompoundTopDocs::new)
.collect(Collectors.toList());
if (queryTopDocs.size() != querySearchResults.size()) {
log.warn("Some of querySearchResults are not produced by hybrid query");
log.error(
String.format(
Locale.ROOT,
"sizes of querySearchResults [%d] and queryTopDocs [%d] must match. Most likely some of query results were not formatted correctly by the hybrid query",
querySearchResults.size(),
queryTopDocs.size()
)
);
throw new IllegalStateException("found inconsistent system state while processing score normalization and combination");
}
return queryTopDocs;
}
Expand Down Expand Up @@ -131,7 +139,10 @@ private void updateOriginalFetchResults(
FetchSearchResult fetchSearchResult = fetchSearchResultOptional.get();
SearchHits searchHits = fetchSearchResult.hits();

// create map of docId to index of search hits, handles (2)
// create map of docId to index of search hits. This solves (2), duplicates are from
// delimiter and start/stop elements, they all have same valid doc_id. For this map
// we use doc_id as a key, and all those special elements are collapsed into a single
// key-value pair.
Map<Integer, SearchHit> docIdToSearchHit = Arrays.stream(searchHits.getHits())
.collect(Collectors.toMap(SearchHit::docId, Function.identity(), (a1, a2) -> a1));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ private void combineShardScores(final ScoreCombinationTechnique scoreCombination
if (Objects.isNull(compoundQueryTopDocs) || compoundQueryTopDocs.getTotalHits().value == 0) {
return;
}
List<TopDocs> topDocsPerSubQuery = compoundQueryTopDocs.getCompoundTopDocs();
List<TopDocs> topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs();
// - create map of normalized scores results returned from the single shard
Map<Integer, float[]> normalizedScoresPerDoc = getNormalizedScoresPerDocument(topDocsPerSubQuery);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ public void normalize(final List<CompoundTopDocs> queryTopDocs) {
if (Objects.isNull(compoundQueryTopDocs)) {
continue;
}
List<TopDocs> topDocsPerSubQuery = compoundQueryTopDocs.getCompoundTopDocs();
List<TopDocs> topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs();
for (int j = 0; j < topDocsPerSubQuery.size(); j++) {
TopDocs subQueryTopDoc = topDocsPerSubQuery.get(j);
for (ScoreDoc scoreDoc : subQueryTopDoc.scoreDocs) {
Expand All @@ -57,17 +57,17 @@ private List<Float> getL2Norm(final List<CompoundTopDocs> queryTopDocs) {
// rest of sub-queries with zero total hits
int numOfSubqueries = queryTopDocs.stream()
.filter(Objects::nonNull)
.filter(topDocs -> topDocs.getCompoundTopDocs().size() > 0)
.filter(topDocs -> topDocs.getTopDocs().size() > 0)
.findAny()
.get()
.getCompoundTopDocs()
.getTopDocs()
.size();
float[] l2Norms = new float[numOfSubqueries];
for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) {
if (Objects.isNull(compoundQueryTopDocs)) {
continue;
}
List<TopDocs> topDocsPerSubQuery = compoundQueryTopDocs.getCompoundTopDocs();
List<TopDocs> topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs();
int bound = topDocsPerSubQuery.size();
for (int index = 0; index < bound; index++) {
for (ScoreDoc scoreDocs : topDocsPerSubQuery.get(index).scoreDocs) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@ public class MinMaxScoreNormalizationTechnique implements ScoreNormalizationTech
public void normalize(final List<CompoundTopDocs> queryTopDocs) {
int numOfSubqueries = queryTopDocs.stream()
.filter(Objects::nonNull)
.filter(topDocs -> topDocs.getCompoundTopDocs().size() > 0)
.filter(topDocs -> topDocs.getTopDocs().size() > 0)
.findAny()
.get()
.getCompoundTopDocs()
.getTopDocs()
.size();
// get min scores for each sub query
float[] minScoresPerSubquery = getMinScores(queryTopDocs, numOfSubqueries);
Expand All @@ -54,7 +54,7 @@ public void normalize(final List<CompoundTopDocs> queryTopDocs) {
if (Objects.isNull(compoundQueryTopDocs)) {
continue;
}
List<TopDocs> topDocsPerSubQuery = compoundQueryTopDocs.getCompoundTopDocs();
List<TopDocs> topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs();
for (int j = 0; j < topDocsPerSubQuery.size(); j++) {
TopDocs subQueryTopDoc = topDocsPerSubQuery.get(j);
for (ScoreDoc scoreDoc : subQueryTopDoc.scoreDocs) {
Expand All @@ -71,7 +71,7 @@ private float[] getMaxScores(final List<CompoundTopDocs> queryTopDocs, final int
if (Objects.isNull(compoundQueryTopDocs)) {
continue;
}
List<TopDocs> topDocsPerSubQuery = compoundQueryTopDocs.getCompoundTopDocs();
List<TopDocs> topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs();
for (int j = 0; j < topDocsPerSubQuery.size(); j++) {
maxScores[j] = Math.max(
maxScores[j],
Expand All @@ -92,7 +92,7 @@ private float[] getMinScores(final List<CompoundTopDocs> queryTopDocs, final int
if (Objects.isNull(compoundQueryTopDocs)) {
continue;
}
List<TopDocs> topDocsPerSubQuery = compoundQueryTopDocs.getCompoundTopDocs();
List<TopDocs> topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs();
for (int j = 0; j < topDocsPerSubQuery.size(); j++) {
minScores[j] = Math.min(
minScores[j],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,6 @@ public void normalizeScores(final List<CompoundTopDocs> queryTopDocs, final Scor
}

private boolean canQueryResultsBeNormalized(final List<CompoundTopDocs> queryTopDocs) {
return queryTopDocs.stream().filter(Objects::nonNull).anyMatch(topDocs -> topDocs.getCompoundTopDocs().size() > 0);
return queryTopDocs.stream().filter(Objects::nonNull).anyMatch(topDocs -> topDocs.getTopDocs().size() > 0);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ public void testBasics_whenCreateWithTopDocsArray_thenSuccessful() {
List<TopDocs> topDocs = List.of(topDocs1, topDocs2);
CompoundTopDocs compoundTopDocs = new CompoundTopDocs(new TotalHits(3, TotalHits.Relation.EQUAL_TO), topDocs);
assertNotNull(compoundTopDocs);
assertEquals(topDocs, compoundTopDocs.getCompoundTopDocs());
assertEquals(topDocs, compoundTopDocs.getTopDocs());
}

public void testBasics_whenCreateWithoutTopDocs_thenTopDocsIsNull() {
Expand All @@ -49,7 +49,7 @@ public void testBasics_whenCreateWithoutTopDocs_thenTopDocsIsNull() {
);
assertNotNull(hybridQueryScoreTopDocs);
assertNotNull(hybridQueryScoreTopDocs.getScoreDocs());
assertNotNull(hybridQueryScoreTopDocs.getCompoundTopDocs());
assertNotNull(hybridQueryScoreTopDocs.getTopDocs());
}

public void testBasics_whenMultipleTopDocsOfDifferentLength_thenReturnTopDocsWithMostHits() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;

import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
Expand Down Expand Up @@ -72,7 +73,7 @@ public void testSearchResultTypes_whenResultsOfHybridSearch_thenDoNormalizationC

normalizationProcessorWorkflow.execute(
querySearchResults,
null,
Optional.empty(),
ScoreNormalizationFactory.DEFAULT_METHOD,
ScoreCombinationFactory.DEFAULT_METHOD
);
Expand Down Expand Up @@ -114,7 +115,7 @@ public void testSearchResultTypes_whenNoMatches_thenReturnZeroResults() {

normalizationProcessorWorkflow.execute(
querySearchResults,
null,
Optional.empty(),
ScoreNormalizationFactory.DEFAULT_METHOD,
ScoreCombinationFactory.DEFAULT_METHOD
);
Expand Down Expand Up @@ -170,7 +171,7 @@ public void testFetchResults_whenOneShardAndQueryAndFetchResultsPresent_thenDoNo

normalizationProcessorWorkflow.execute(
querySearchResults,
fetchSearchResult,
Optional.of(fetchSearchResult),
ScoreNormalizationFactory.DEFAULT_METHOD,
ScoreCombinationFactory.DEFAULT_METHOD
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ public void testNormalization_whenOneSubqueryAndOneShardAndDefaultMethod_thenSco
assertNotNull(queryTopDocs);
assertEquals(1, queryTopDocs.size());
CompoundTopDocs resultDoc = queryTopDocs.get(0);
assertNotNull(resultDoc.getCompoundTopDocs());
assertEquals(1, resultDoc.getCompoundTopDocs().size());
TopDocs topDoc = resultDoc.getCompoundTopDocs().get(0);
assertNotNull(resultDoc.getTopDocs());
assertEquals(1, resultDoc.getTopDocs().size());
TopDocs topDoc = resultDoc.getTopDocs().get(0);
assertEquals(1, topDoc.totalHits.value);
assertEquals(TotalHits.Relation.EQUAL_TO, topDoc.totalHits.relation);
assertNotNull(topDoc.scoreDocs);
Expand Down Expand Up @@ -67,9 +67,9 @@ public void testNormalization_whenOneSubqueryMultipleHitsAndOneShardAndDefaultMe
assertNotNull(queryTopDocs);
assertEquals(1, queryTopDocs.size());
CompoundTopDocs resultDoc = queryTopDocs.get(0);
assertNotNull(resultDoc.getCompoundTopDocs());
assertEquals(1, resultDoc.getCompoundTopDocs().size());
TopDocs topDoc = resultDoc.getCompoundTopDocs().get(0);
assertNotNull(resultDoc.getTopDocs());
assertEquals(1, resultDoc.getTopDocs().size());
TopDocs topDoc = resultDoc.getTopDocs().get(0);
assertEquals(3, topDoc.totalHits.value);
assertEquals(TotalHits.Relation.EQUAL_TO, topDoc.totalHits.relation);
assertNotNull(topDoc.scoreDocs);
Expand Down Expand Up @@ -104,10 +104,10 @@ public void testNormalization_whenMultipleSubqueriesMultipleHitsAndOneShardAndDe
assertNotNull(queryTopDocs);
assertEquals(1, queryTopDocs.size());
CompoundTopDocs resultDoc = queryTopDocs.get(0);
assertNotNull(resultDoc.getCompoundTopDocs());
assertEquals(2, resultDoc.getCompoundTopDocs().size());
assertNotNull(resultDoc.getTopDocs());
assertEquals(2, resultDoc.getTopDocs().size());
// sub-query one
TopDocs topDocSubqueryOne = resultDoc.getCompoundTopDocs().get(0);
TopDocs topDocSubqueryOne = resultDoc.getTopDocs().get(0);
assertEquals(3, topDocSubqueryOne.totalHits.value);
assertEquals(TotalHits.Relation.EQUAL_TO, topDocSubqueryOne.totalHits.relation);
assertNotNull(topDocSubqueryOne.scoreDocs);
Expand All @@ -119,7 +119,7 @@ public void testNormalization_whenMultipleSubqueriesMultipleHitsAndOneShardAndDe
assertEquals(0.0, lowScoreDoc.score, 0.001f);
assertEquals(4, lowScoreDoc.doc);
// sub-query two
TopDocs topDocSubqueryTwo = resultDoc.getCompoundTopDocs().get(1);
TopDocs topDocSubqueryTwo = resultDoc.getTopDocs().get(1);
assertEquals(2, topDocSubqueryTwo.totalHits.value);
assertEquals(TotalHits.Relation.EQUAL_TO, topDocSubqueryTwo.totalHits.relation);
assertNotNull(topDocSubqueryTwo.scoreDocs);
Expand Down Expand Up @@ -169,9 +169,9 @@ public void testNormalization_whenMultipleSubqueriesMultipleHitsMultipleShardsAn
assertEquals(3, queryTopDocs.size());
// shard one
CompoundTopDocs resultDocShardOne = queryTopDocs.get(0);
assertEquals(2, resultDocShardOne.getCompoundTopDocs().size());
assertEquals(2, resultDocShardOne.getTopDocs().size());
// sub-query one
TopDocs topDocSubqueryOne = resultDocShardOne.getCompoundTopDocs().get(0);
TopDocs topDocSubqueryOne = resultDocShardOne.getTopDocs().get(0);
assertEquals(3, topDocSubqueryOne.totalHits.value);
assertEquals(TotalHits.Relation.EQUAL_TO, topDocSubqueryOne.totalHits.relation);
assertNotNull(topDocSubqueryOne.scoreDocs);
Expand All @@ -183,7 +183,7 @@ public void testNormalization_whenMultipleSubqueriesMultipleHitsMultipleShardsAn
assertEquals(0.0, lowScoreDoc.score, 0.001f);
assertEquals(4, lowScoreDoc.doc);
// sub-query two
TopDocs topDocSubqueryTwo = resultDocShardOne.getCompoundTopDocs().get(1);
TopDocs topDocSubqueryTwo = resultDocShardOne.getTopDocs().get(1);
assertEquals(2, topDocSubqueryTwo.totalHits.value);
assertEquals(TotalHits.Relation.EQUAL_TO, topDocSubqueryTwo.totalHits.relation);
assertNotNull(topDocSubqueryTwo.scoreDocs);
Expand All @@ -195,15 +195,15 @@ public void testNormalization_whenMultipleSubqueriesMultipleHitsMultipleShardsAn

// shard two
CompoundTopDocs resultDocShardTwo = queryTopDocs.get(1);
assertEquals(2, resultDocShardTwo.getCompoundTopDocs().size());
assertEquals(2, resultDocShardTwo.getTopDocs().size());
// sub-query one
TopDocs topDocShardTwoSubqueryOne = resultDocShardTwo.getCompoundTopDocs().get(0);
TopDocs topDocShardTwoSubqueryOne = resultDocShardTwo.getTopDocs().get(0);
assertEquals(0, topDocShardTwoSubqueryOne.totalHits.value);
assertEquals(TotalHits.Relation.EQUAL_TO, topDocShardTwoSubqueryOne.totalHits.relation);
assertNotNull(topDocShardTwoSubqueryOne.scoreDocs);
assertEquals(0, topDocShardTwoSubqueryOne.scoreDocs.length);
// sub-query two
TopDocs topDocShardTwoSubqueryTwo = resultDocShardTwo.getCompoundTopDocs().get(1);
TopDocs topDocShardTwoSubqueryTwo = resultDocShardTwo.getTopDocs().get(1);
assertEquals(4, topDocShardTwoSubqueryTwo.totalHits.value);
assertEquals(TotalHits.Relation.EQUAL_TO, topDocShardTwoSubqueryTwo.totalHits.relation);
assertNotNull(topDocShardTwoSubqueryTwo.scoreDocs);
Expand All @@ -215,14 +215,14 @@ public void testNormalization_whenMultipleSubqueriesMultipleHitsMultipleShardsAn

// shard three
CompoundTopDocs resultDocShardThree = queryTopDocs.get(2);
assertEquals(2, resultDocShardThree.getCompoundTopDocs().size());
assertEquals(2, resultDocShardThree.getTopDocs().size());
// sub-query one
TopDocs topDocShardThreeSubqueryOne = resultDocShardThree.getCompoundTopDocs().get(0);
TopDocs topDocShardThreeSubqueryOne = resultDocShardThree.getTopDocs().get(0);
assertEquals(0, topDocShardThreeSubqueryOne.totalHits.value);
assertEquals(TotalHits.Relation.EQUAL_TO, topDocShardThreeSubqueryOne.totalHits.relation);
assertEquals(0, topDocShardThreeSubqueryOne.scoreDocs.length);
// sub-query two
TopDocs topDocShardThreeSubqueryTwo = resultDocShardThree.getCompoundTopDocs().get(1);
TopDocs topDocShardThreeSubqueryTwo = resultDocShardThree.getTopDocs().get(1);
assertEquals(0, topDocShardThreeSubqueryTwo.totalHits.value);
assertEquals(TotalHits.Relation.EQUAL_TO, topDocShardThreeSubqueryTwo.totalHits.relation);
assertEquals(0, topDocShardThreeSubqueryTwo.scoreDocs.length);
Expand Down
Loading

0 comments on commit d33370e

Please sign in to comment.