Skip to content

Commit

Permalink
Fixed nested field case, draft version
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski committed Nov 22, 2023
1 parent b3c73bd commit fdcb5d8
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,15 @@
import lombok.extern.log4j.Log4j2;

import org.apache.lucene.index.IndexReader;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHitCountCollector;
import org.apache.lucene.search.TotalHits;
import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
import org.opensearch.index.search.NestedHelper;
import org.opensearch.neuralsearch.query.HybridQuery;
import org.opensearch.neuralsearch.search.HitsThresholdChecker;
import org.opensearch.neuralsearch.search.HybridTopScoreDocCollector;
Expand All @@ -48,24 +51,73 @@
@Log4j2
public class HybridQueryPhaseSearcher extends QueryPhaseSearcherWrapper {

final static int MAX_NESTED_SUBQUERY_LIMIT = 20;

public HybridQueryPhaseSearcher() {
super();
}

public boolean searchWith(
final SearchContext searchContext,
final ContextIndexSearcher searcher,
final Query query,
Query query,
final LinkedList<QueryCollectorContext> collectors,
final boolean hasFilterCollector,
final boolean hasTimeout
) throws IOException {
if (query instanceof HybridQuery) {
if (isHybridQuery(query, searchContext)) {
query = extractHybridQuery(searchContext, query);
return searchWithCollector(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout);
}
validateHybridQuery(query);
return super.searchWith(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout);
}

void validateHybridQuery(final Query query) {
if (query instanceof BooleanQuery) {
List<BooleanClause> booleanClauses = ((BooleanQuery) query).clauses();
for (BooleanClause booleanClause : booleanClauses) {
validateNestedBooleanQuery(booleanClause.getQuery(), 1);
}
}
}

void validateNestedBooleanQuery(final Query query, int level) {
if (query instanceof HybridQuery) {
throw new IllegalArgumentException("hybrid query must be a top level query and cannot be wrapped into other queries");
}
if (level >= MAX_NESTED_SUBQUERY_LIMIT) {
throw new IllegalStateException("reached max nested query limit, cannot process query");
}
if (query instanceof BooleanQuery) {
for (BooleanClause booleanClause : ((BooleanQuery) query).clauses()) {
validateNestedBooleanQuery(booleanClause.getQuery(), level + 1);
}
}
}

private Query extractHybridQuery(SearchContext searchContext, Query query) {
if (query instanceof BooleanQuery
&& new NestedHelper(searchContext.mapperService()).mightMatchNestedDocs(query)
&& ((BooleanQuery) query).clauses().stream().anyMatch(clauseQuery -> clauseQuery.getQuery() instanceof HybridQuery)) {
//extract hybrid query and replace bool with hybrid query
query = ((BooleanQuery) query).clauses().get(0).getQuery();
}
return query;
}

boolean isHybridQuery(Query query, SearchContext searchContext) {
if (query instanceof HybridQuery) {
return true;
}
else if (new NestedHelper(searchContext.mapperService()).mightMatchNestedDocs(query)
&& query instanceof BooleanQuery
&& ((BooleanQuery) query).clauses().stream().anyMatch(clauseQuery -> clauseQuery.getQuery() instanceof HybridQuery)) {
return true;
}
return false;
}

@VisibleForTesting
protected boolean searchWithCollector(
final SearchContext searchContext,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
import org.opensearch.core.index.Index;
import org.opensearch.core.index.shard.ShardId;
import org.opensearch.index.mapper.MapperService;
import org.opensearch.index.mapper.TextFieldMapper;
import org.opensearch.index.query.MatchAllQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
Expand Down Expand Up @@ -204,6 +205,8 @@ public void testQueryType_whenQueryIsNotHybrid_thenDoNotCallHybridDocCollector()

Query query = termSubQuery.toQuery(mockQueryShardContext);
when(searchContext.query()).thenReturn(query);
MapperService mapperService = mock(MapperService.class);
when(searchContext.mapperService()).thenReturn(mapperService);

hybridQueryPhaseSearcher.searchWith(searchContext, contextIndexSearcher, query, collectors, hasFilterCollector, hasTimeout);

Expand Down

0 comments on commit fdcb5d8

Please sign in to comment.