From 3991fe7c04e2f2c104c0be0365428c5747f21f6c Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Thu, 30 Nov 2023 19:26:16 -0800 Subject: [PATCH] Fixed Hybrid query for cases when it's wrapped into other compound queries (#498) * Fixed nested field case Signed-off-by: Martin Gaievski --- CHANGELOG.md | 5 +- .../query/HybridQueryPhaseSearcher.java | 124 ++++- .../common/BaseNeuralSearchIT.java | 44 +- .../neuralsearch/query/HybridQueryIT.java | 131 ++++- .../query/HybridQueryPhaseSearcherTests.java | 447 +++++++++++++++++- 5 files changed, 735 insertions(+), 16 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index fa866ca7e..093dbff70 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,8 +7,9 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Features ### Enhancements ### Bug Fixes -Fix async actions are left in neural_sparse query ([438](https://github.com/opensearch-project/neural-search/pull/438)) -Fixed exception for case when Hybrid query being wrapped into bool query ([#490](https://github.com/opensearch-project/neural-search/pull/490) +- Fix async actions are left in neural_sparse query ([#438](https://github.com/opensearch-project/neural-search/pull/438)) +- Fixed exception for case when Hybrid query being wrapped into bool query ([#490](https://github.com/opensearch-project/neural-search/pull/490)) +- Hybrid query and nested type fields ([#498](https://github.com/opensearch-project/neural-search/pull/498)) ### Infrastructure ### Documentation ### Maintenance diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java index f65e30222..26f580364 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java +++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java @@ -19,12 +19,19 @@ import lombok.extern.log4j.Log4j2; import 
org.apache.lucene.index.IndexReader; +import org.apache.lucene.search.BooleanClause; +import org.apache.lucene.search.BooleanQuery; +import org.apache.lucene.search.FieldExistsQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TotalHitCountCollector; import org.apache.lucene.search.TotalHits; import org.opensearch.common.lucene.search.TopDocsAndMaxScore; +import org.opensearch.common.settings.Settings; +import org.opensearch.index.mapper.MapperService; +import org.opensearch.index.mapper.SeqNoFieldMapper; +import org.opensearch.index.search.NestedHelper; import org.opensearch.neuralsearch.query.HybridQuery; import org.opensearch.neuralsearch.search.HitsThresholdChecker; import org.opensearch.neuralsearch.search.HybridTopScoreDocCollector; @@ -60,12 +67,120 @@ public boolean searchWith( final boolean hasFilterCollector, final boolean hasTimeout ) throws IOException { - if (query instanceof HybridQuery) { - return searchWithCollector(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout); + if (isHybridQuery(query, searchContext)) { + Query hybridQuery = extractHybridQuery(searchContext, query); + return searchWithCollector(searchContext, searcher, hybridQuery, collectors, hasFilterCollector, hasTimeout); } + validateQuery(searchContext, query); return super.searchWith(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout); } + private boolean isHybridQuery(final Query query, final SearchContext searchContext) { + if (query instanceof HybridQuery) { + return true; + } else if (isWrappedHybridQuery(query) && hasNestedFieldOrNestedDocs(query, searchContext)) { + /* Checking if this is a hybrid query that is wrapped into a Bool query by core Opensearch code + https://github.com/opensearch-project/OpenSearch/blob/main/server/src/main/java/org/opensearch/search/DefaultSearchContext.java#L367-L370. 
+ main reason for that is performance optimization, at time of writing we are ok with losing performance if that unblocks + hybrid query for indexes with nested field types. + in such case we consider query a valid hybrid query. Later in the code we will extract it and execute as a main query for + this search request. + below is sample structure of such query: + + Boolean { + should: { + hybrid: { + sub_query1 {} + sub_query2 {} + } + } + filter: { + exists: { + field: "_primary_term" + } + } + } + TODO Need to add logic for passing hybrid sub-queries through the same logic in core to ensure there is no latency regression */ + // we have already checked if query is an instance of Boolean in higher level else if condition + return ((BooleanQuery) query).clauses() + .stream() + .filter(clause -> clause.getQuery() instanceof HybridQuery == false) + .allMatch(clause -> { + return clause.getOccur() == BooleanClause.Occur.FILTER + && clause.getQuery() instanceof FieldExistsQuery + && SeqNoFieldMapper.PRIMARY_TERM_NAME.equals(((FieldExistsQuery) clause.getQuery()).getField()); + }); + } + return false; + } + + private boolean hasNestedFieldOrNestedDocs(final Query query, final SearchContext searchContext) { + return searchContext.mapperService().hasNested() && new NestedHelper(searchContext.mapperService()).mightMatchNestedDocs(query); + } + + private boolean isWrappedHybridQuery(final Query query) { + return query instanceof BooleanQuery + && ((BooleanQuery) query).clauses().stream().anyMatch(clauseQuery -> clauseQuery.getQuery() instanceof HybridQuery); + } + + private Query extractHybridQuery(final SearchContext searchContext, final Query query) { + if (hasNestedFieldOrNestedDocs(query, searchContext) + && isWrappedHybridQuery(query) + && ((BooleanQuery) query).clauses().size() > 0) { + // extract hybrid query and replace bool with hybrid query + List booleanClauses = ((BooleanQuery) query).clauses(); + if (booleanClauses.isEmpty() || 
booleanClauses.get(0).getQuery() instanceof HybridQuery == false) { + throw new IllegalStateException("cannot process hybrid query due to incorrect structure of top level bool query"); + } + return booleanClauses.get(0).getQuery(); + } + return query; + } + + /** + * Validate the query from neural-search plugin point of view. Current main goal for validation is to block cases + * when hybrid query is wrapped into other compound queries. + * For example, if we have Bool query like below we need to throw an error + * bool: { + * should: [ + * match: {}, + * hybrid: { + * sub_query1 {} + * sub_query2 {} + * } + * ] + * } + * TODO add similar validation for other compound type queries like dis_max, constant_score etc. + * + * @param query query to validate + */ + private void validateQuery(final SearchContext searchContext, final Query query) { + if (query instanceof BooleanQuery) { + List booleanClauses = ((BooleanQuery) query).clauses(); + for (BooleanClause booleanClause : booleanClauses) { + validateNestedBooleanQuery(booleanClause.getQuery(), getMaxDepthLimit(searchContext)); + } + } + } + + private void validateNestedBooleanQuery(final Query query, final int level) { + if (query instanceof HybridQuery) { + throw new IllegalArgumentException("hybrid query must be a top level query and cannot be wrapped into other queries"); + } + if (level <= 0) { + // ideally we should throw an error here but this code is on the main search workflow path and that might block + // execution of some queries. 
Instead, we silently exit and allow such query to execute and potentially produce incorrect + // results in case hybrid query is wrapped into such bool query + log.error("reached max nested query limit, cannot process bool query with that many nested clauses"); + return; + } + if (query instanceof BooleanQuery) { + for (BooleanClause booleanClause : ((BooleanQuery) query).clauses()) { + validateNestedBooleanQuery(booleanClause.getQuery(), level - 1); + } + } + } + @VisibleForTesting protected boolean searchWithCollector( final SearchContext searchContext, @@ -209,4 +324,9 @@ private float getMaxScore(final List topDocs) { private DocValueFormat[] getSortValueFormats(final SortAndFormats sortAndFormats) { return sortAndFormats == null ? null : sortAndFormats.formats; } + + private int getMaxDepthLimit(final SearchContext searchContext) { + Settings indexSettings = searchContext.getQueryShardContext().getIndexSettings().getSettings(); + return MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(indexSettings).intValue(); + } } diff --git a/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java b/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java index 33cdff9a0..e3e57a141 100644 --- a/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java +++ b/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java @@ -413,6 +413,18 @@ protected void addKnnDoc(String index, String docId, List vectorFieldNam addKnnDoc(index, docId, vectorFieldNames, vectors, Collections.emptyList(), Collections.emptyList()); } + @SneakyThrows + protected void addKnnDoc( + String index, + String docId, + List vectorFieldNames, + List vectors, + List textFieldNames, + List texts + ) { + addKnnDoc(index, docId, vectorFieldNames, vectors, textFieldNames, texts, Collections.emptyList(), Collections.emptyList()); + } + /** * Add a set of knn vectors and text to an index * @@ -422,6 +434,8 @@ protected void addKnnDoc(String index, 
String docId, List vectorFieldNam * @param vectors List of vectors corresponding to those fields * @param textFieldNames List of text fields to be added * @param texts List of text corresponding to those fields + * @param nestedFieldNames List of nested fields to be added + * @param nestedFields List of fields and values corresponding to those fields */ @SneakyThrows protected void addKnnDoc( @@ -430,7 +444,9 @@ protected void addKnnDoc( List vectorFieldNames, List vectors, List textFieldNames, - List texts + List texts, + List nestedFieldNames, + List> nestedFields ) { Request request = new Request("POST", "/" + index + "/_doc/" + docId + "?refresh=true"); XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); @@ -441,6 +457,16 @@ protected void addKnnDoc( for (int i = 0; i < textFieldNames.size(); i++) { builder.field(textFieldNames.get(i), texts.get(i)); } + + for (int i = 0; i < nestedFieldNames.size(); i++) { + builder.field(nestedFieldNames.get(i)); + builder.startObject(); + Map nestedValues = nestedFields.get(i); + for (Map.Entry entry : nestedValues.entrySet()) { + builder.field(entry.getKey(), entry.getValue()); + } + builder.endObject(); + } builder.endObject(); request.setJsonEntity(builder.toString()); @@ -523,7 +549,16 @@ protected boolean checkComplete(Map node) { } @SneakyThrows - private String buildIndexConfiguration(List knnFieldConfigs, int numberOfShards) { + protected String buildIndexConfiguration(final List knnFieldConfigs, final int numberOfShards) { + return buildIndexConfiguration(knnFieldConfigs, Collections.emptyList(), numberOfShards); + } + + @SneakyThrows + protected String buildIndexConfiguration( + final List knnFieldConfigs, + final List nestedFields, + final int numberOfShards + ) { XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() .startObject() .startObject("settings") @@ -544,6 +579,11 @@ private String buildIndexConfiguration(List knnFieldConfigs, int .endObject() .endObject(); } + + for (String 
nestedField : nestedFields) { + xContentBuilder.startObject(nestedField).field("type", "nested").endObject(); + } + xContentBuilder.endObject().endObject().endObject(); return xContentBuilder.toString(); } diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java index 171d2f4a4..4a8f0d065 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java @@ -5,6 +5,9 @@ package org.opensearch.neuralsearch.query; +import static org.hamcrest.Matchers.allOf; +import static org.hamcrest.Matchers.containsString; +import static org.opensearch.index.query.QueryBuilders.matchQuery; import static org.opensearch.neuralsearch.TestUtils.DELTA_FOR_SCORE_ASSERTION; import static org.opensearch.neuralsearch.TestUtils.createRandomVector; @@ -19,10 +22,13 @@ import lombok.SneakyThrows; +import org.apache.lucene.search.join.ScoreMode; import org.junit.After; import org.junit.Before; +import org.opensearch.client.ResponseException; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.MatchQueryBuilder; +import org.opensearch.index.query.NestedQueryBuilder; import org.opensearch.index.query.QueryBuilders; import org.opensearch.index.query.TermQueryBuilder; import org.opensearch.knn.index.SpaceType; @@ -35,6 +41,8 @@ public class HybridQueryIT extends BaseNeuralSearchIT { private static final String TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME = "test-neural-vector-doc-field-index"; private static final String TEST_MULTI_DOC_INDEX_NAME = "test-neural-multi-doc-index"; private static final String TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD = "test-neural-multi-doc-single-shard-index"; + private static final String TEST_MULTI_DOC_INDEX_WITH_NESTED_TYPE_NAME_ONE_SHARD = + "test-neural-multi-doc-nested-type--single-shard-index"; private static final String TEST_QUERY_TEXT = "greetings"; private 
static final String TEST_QUERY_TEXT2 = "salute"; private static final String TEST_QUERY_TEXT3 = "hello"; @@ -46,9 +54,14 @@ public class HybridQueryIT extends BaseNeuralSearchIT { private static final String TEST_KNN_VECTOR_FIELD_NAME_1 = "test-knn-vector-1"; private static final String TEST_KNN_VECTOR_FIELD_NAME_2 = "test-knn-vector-2"; private static final String TEST_TEXT_FIELD_NAME_1 = "test-text-field-1"; + private static final String TEST_NESTED_TYPE_FIELD_NAME_1 = "user"; private static final int TEST_DIMENSION = 768; private static final SpaceType TEST_SPACE_TYPE = SpaceType.L2; + private static final String NESTED_FIELD_1 = "firstname"; + private static final String NESTED_FIELD_2 = "lastname"; + private static final String NESTED_FIELD_1_VALUE = "john"; + private static final String NESTED_FIELD_2_VALUE = "black"; private final float[] testVector1 = createRandomVector(TEST_DIMENSION); private final float[] testVector2 = createRandomVector(TEST_DIMENSION); private final float[] testVector3 = createRandomVector(TEST_DIMENSION); @@ -191,7 +204,7 @@ public void testNoMatchResults_whenOnlyTermSubQueryWithoutMatch_thenEmptyResult( } @SneakyThrows - public void testNestedQuery_whenHybridQueryIsWrappedIntoOtherQuery_thenSuccess() { + public void testNestedQuery_whenHybridQueryIsWrappedIntoOtherQuery_thenFail() { initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD); MatchQueryBuilder matchQueryBuilder = QueryBuilders.matchQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); @@ -202,23 +215,104 @@ public void testNestedQuery_whenHybridQueryIsWrappedIntoOtherQuery_thenSuccess() MatchQueryBuilder matchQuery3Builder = QueryBuilders.matchQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery().should(hybridQueryBuilderOnlyTerm).should(matchQuery3Builder); + ResponseException exceptionNoNestedTypes = expectThrows( + ResponseException.class, + () -> search(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD, boolQueryBuilder, 
null, 10, Map.of("search_pipeline", SEARCH_PIPELINE)) + ); + + org.hamcrest.MatcherAssert.assertThat( + exceptionNoNestedTypes.getMessage(), + allOf( + containsString("hybrid query must be a top level query and cannot be wrapped into other queries"), + containsString("illegal_argument_exception") + ) + ); + + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_WITH_NESTED_TYPE_NAME_ONE_SHARD); + + ResponseException exceptionQWithNestedTypes = expectThrows( + ResponseException.class, + () -> search( + TEST_MULTI_DOC_INDEX_WITH_NESTED_TYPE_NAME_ONE_SHARD, + boolQueryBuilder, + null, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE) + ) + ); + + org.hamcrest.MatcherAssert.assertThat( + exceptionQWithNestedTypes.getMessage(), + allOf( + containsString("hybrid query must be a top level query and cannot be wrapped into other queries"), + containsString("illegal_argument_exception") + ) + ); + } + + @SneakyThrows + public void testIndexWithNestedFields_whenHybridQuery_thenSuccess() { + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_WITH_NESTED_TYPE_NAME_ONE_SHARD); + + TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); + TermQueryBuilder termQuery2Builder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT2); + HybridQueryBuilder hybridQueryBuilderOnlyTerm = new HybridQueryBuilder(); + hybridQueryBuilderOnlyTerm.add(termQueryBuilder); + hybridQueryBuilderOnlyTerm.add(termQuery2Builder); + Map searchResponseAsMap = search( - TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD, - boolQueryBuilder, + TEST_MULTI_DOC_INDEX_WITH_NESTED_TYPE_NAME_ONE_SHARD, + hybridQueryBuilderOnlyTerm, null, 10, Map.of("search_pipeline", SEARCH_PIPELINE) ); - assertTrue(getHitCount(searchResponseAsMap) > 0); + assertEquals(1, getHitCount(searchResponseAsMap)); assertTrue(getMaxScore(searchResponseAsMap).isPresent()); - assertTrue(getMaxScore(searchResponseAsMap).get() > 0.0f); + assertEquals(0.5f, getMaxScore(searchResponseAsMap).get(), 
DELTA_FOR_SCORE_ASSERTION); Map total = getTotalHits(searchResponseAsMap); assertNotNull(total.get("value")); - assertTrue((int) total.get("value") > 0); + assertEquals(1, total.get("value")); + assertNotNull(total.get("relation")); + assertEquals(RELATION_EQUAL_TO, total.get("relation")); } + @SneakyThrows + public void testIndexWithNestedFields_whenHybridQueryIncludesNested_thenSuccess() { + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_WITH_NESTED_TYPE_NAME_ONE_SHARD); + + TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT); + NestedQueryBuilder nestedQueryBuilder = QueryBuilders.nestedQuery( + TEST_NESTED_TYPE_FIELD_NAME_1, + matchQuery(TEST_NESTED_TYPE_FIELD_NAME_1 + "." + NESTED_FIELD_1, NESTED_FIELD_1_VALUE), + ScoreMode.Total + ); + HybridQueryBuilder hybridQueryBuilderOnlyTerm = new HybridQueryBuilder(); + hybridQueryBuilderOnlyTerm.add(termQueryBuilder); + hybridQueryBuilderOnlyTerm.add(nestedQueryBuilder); + + Map searchResponseAsMap = search( + TEST_MULTI_DOC_INDEX_WITH_NESTED_TYPE_NAME_ONE_SHARD, + hybridQueryBuilderOnlyTerm, + null, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE) + ); + + assertEquals(1, getHitCount(searchResponseAsMap)); + assertTrue(getMaxScore(searchResponseAsMap).isPresent()); + assertEquals(0.5f, getMaxScore(searchResponseAsMap).get(), DELTA_FOR_SCORE_ASSERTION); + + Map total = getTotalHits(searchResponseAsMap); + assertNotNull(total.get("value")); + assertEquals(1, total.get("value")); + assertNotNull(total.get("relation")); + assertEquals(RELATION_EQUAL_TO, total.get("relation")); + } + + @SneakyThrows private void initializeIndexIfNotExist(String indexName) throws IOException { if (TEST_BASIC_INDEX_NAME.equals(indexName) && !indexExists(TEST_BASIC_INDEX_NAME)) { prepareKnnIndex( @@ -284,6 +378,31 @@ private void initializeIndexIfNotExist(String indexName) throws IOException { ); addDocsToIndex(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD); } + + if 
(TEST_MULTI_DOC_INDEX_WITH_NESTED_TYPE_NAME_ONE_SHARD.equals(indexName) + && !indexExists(TEST_MULTI_DOC_INDEX_WITH_NESTED_TYPE_NAME_ONE_SHARD)) { + createIndexWithConfiguration( + indexName, + buildIndexConfiguration( + Collections.singletonList(new KNNFieldConfig(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DIMENSION, TEST_SPACE_TYPE)), + List.of(TEST_NESTED_TYPE_FIELD_NAME_1), + 1 + ), + "" + ); + + addDocsToIndex(TEST_MULTI_DOC_INDEX_WITH_NESTED_TYPE_NAME_ONE_SHARD); + addKnnDoc( + TEST_MULTI_DOC_INDEX_WITH_NESTED_TYPE_NAME_ONE_SHARD, + "4", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector1).toArray()), + List.of(), + List.of(), + List.of(TEST_NESTED_TYPE_FIELD_NAME_1), + List.of(Map.of(NESTED_FIELD_1, NESTED_FIELD_1_VALUE, NESTED_FIELD_2, NESTED_FIELD_2_VALUE)) + ); + } } private void addDocsToIndex(final String testMultiDocIndexName) { diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java index e9c55cc54..c4f3f4a3e 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java @@ -5,8 +5,10 @@ package org.opensearch.neuralsearch.search.query; +import static org.hamcrest.Matchers.containsString; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.mock; @@ -21,6 +23,8 @@ import java.util.ArrayList; import java.util.LinkedList; import java.util.List; +import java.util.Set; +import java.util.UUID; import lombok.SneakyThrows; @@ -30,6 +34,8 @@ import org.apache.lucene.index.IndexOptions; import 
org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.search.BooleanClause; +import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Query; import org.apache.lucene.search.ScoreDoc; @@ -38,10 +44,16 @@ import org.apache.lucene.store.Directory; import org.apache.lucene.tests.analysis.MockAnalyzer; import org.opensearch.action.OriginalIndices; +import org.opensearch.cluster.metadata.IndexMetadata; +import org.opensearch.common.lucene.search.Queries; import org.opensearch.common.lucene.search.TopDocsAndMaxScore; +import org.opensearch.common.settings.Settings; import org.opensearch.core.index.Index; import org.opensearch.core.index.shard.ShardId; +import org.opensearch.index.IndexSettings; +import org.opensearch.index.mapper.MapperService; import org.opensearch.index.mapper.TextFieldMapper; +import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.MatchAllQueryBuilder; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryBuilders; @@ -73,6 +85,8 @@ public class HybridQueryPhaseSearcherTests extends OpenSearchQueryTestCase { private static final String MODEL_ID = "mfgfgdsfgfdgsde"; private static final int K = 10; private static final QueryBuilder TEST_FILTER = new MatchAllQueryBuilder(); + private static final UUID INDEX_UUID = UUID.randomUUID(); + private static final String TEST_INDEX = "index"; @SneakyThrows public void testQueryType_whenQueryIsHybrid_thenCallHybridDocCollector() { @@ -82,7 +96,8 @@ public void testQueryType_whenQueryIsHybrid_thenCallHybridDocCollector() { when(mockQueryShardContext.index()).thenReturn(dummyIndex); when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockQueryShardContext.fieldMapper(eq(VECTOR_FIELD_NAME))).thenReturn(mockKNNVectorField); - TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) 
createMapperService().fieldType(TEXT_FIELD_NAME); + MapperService mapperService = createMapperService(); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) mapperService.fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); Directory directory = newDirectory(); @@ -125,6 +140,7 @@ public void testQueryType_whenQueryIsHybrid_thenCallHybridDocCollector() { when(indexShard.shardId()).thenReturn(new ShardId("test", "test", 0)); when(searchContext.indexShard()).thenReturn(indexShard); when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR); + when(searchContext.mapperService()).thenReturn(mapperService); LinkedList collectors = new LinkedList<>(); boolean hasFilterCollector = randomBoolean(); @@ -151,7 +167,8 @@ public void testQueryType_whenQueryIsHybrid_thenCallHybridDocCollector() { public void testQueryType_whenQueryIsNotHybrid_thenDoNotCallHybridDocCollector() { HybridQueryPhaseSearcher hybridQueryPhaseSearcher = spy(new HybridQueryPhaseSearcher()); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); - TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + MapperService mapperService = createMapperService(); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) mapperService.fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); Directory directory = newDirectory(); @@ -195,6 +212,7 @@ public void testQueryType_whenQueryIsNotHybrid_thenDoNotCallHybridDocCollector() when(searchContext.indexShard()).thenReturn(indexShard); when(searchContext.queryResult()).thenReturn(new QuerySearchResult()); when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR); + when(searchContext.mapperService()).thenReturn(mapperService); 
LinkedList collectors = new LinkedList<>(); boolean hasFilterCollector = randomBoolean(); @@ -217,7 +235,8 @@ public void testQueryResult_whenOneSubQueryWithHits_thenHybridResultsAreSet() { HybridQueryPhaseSearcher hybridQueryPhaseSearcher = spy(new HybridQueryPhaseSearcher()); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); - TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + MapperService mapperService = createMapperService(); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) mapperService.fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); Directory directory = newDirectory(); @@ -265,6 +284,7 @@ public void testQueryResult_whenOneSubQueryWithHits_thenHybridResultsAreSet() { QuerySearchResult querySearchResult = new QuerySearchResult(); when(searchContext.queryResult()).thenReturn(querySearchResult); when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR); + when(searchContext.mapperService()).thenReturn(mapperService); LinkedList collectors = new LinkedList<>(); boolean hasFilterCollector = randomBoolean(); @@ -310,7 +330,8 @@ public void testQueryResult_whenMultipleTextSubQueriesWithSomeHits_thenHybridRes HybridQueryPhaseSearcher hybridQueryPhaseSearcher = new HybridQueryPhaseSearcher(); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); - TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + MapperService mapperService = createMapperService(); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) mapperService.fieldType(TEXT_FIELD_NAME); 
when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); Directory directory = newDirectory(); @@ -360,6 +381,7 @@ public void testQueryResult_whenMultipleTextSubQueriesWithSomeHits_thenHybridRes when(indexShard.shardId()).thenReturn(new ShardId("test", "test", 0)); when(searchContext.indexShard()).thenReturn(indexShard); when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR); + when(searchContext.mapperService()).thenReturn(mapperService); LinkedList collectors = new LinkedList<>(); boolean hasFilterCollector = randomBoolean(); @@ -404,6 +426,412 @@ public void testQueryResult_whenMultipleTextSubQueriesWithSomeHits_thenHybridRes releaseResources(directory, w, reader); } + @SneakyThrows + public void testWrappedHybridQuery_whenHybridWrappedIntoBool_thenFail() { + HybridQueryPhaseSearcher hybridQueryPhaseSearcher = new HybridQueryPhaseSearcher(); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + MapperService mapperService = mock(MapperService.class); + when(mapperService.hasNested()).thenReturn(false); + + Directory directory = newDirectory(); + IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random()))); + FieldType ft = new FieldType(TextField.TYPE_NOT_STORED); + ft.setIndexOptions(random().nextBoolean() ? 
IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS); + ft.setOmitNorms(random().nextBoolean()); + ft.freeze(); + int docId1 = RandomizedTest.randomInt(); + w.addDocument(getDocument(TEXT_FIELD_NAME, docId1, TEST_DOC_TEXT1, ft)); + w.commit(); + + IndexReader reader = DirectoryReader.open(w); + SearchContext searchContext = mock(SearchContext.class); + + ContextIndexSearcher contextIndexSearcher = new ContextIndexSearcher( + reader, + IndexSearcher.getDefaultSimilarity(), + IndexSearcher.getDefaultQueryCache(), + IndexSearcher.getDefaultQueryCachingPolicy(), + true, + null, + searchContext + ); + + ShardId shardId = new ShardId(dummyIndex, 1); + SearchShardTarget shardTarget = new SearchShardTarget( + randomAlphaOfLength(10), + shardId, + randomAlphaOfLength(10), + OriginalIndices.NONE + ); + when(searchContext.shardTarget()).thenReturn(shardTarget); + when(searchContext.searcher()).thenReturn(contextIndexSearcher); + when(searchContext.size()).thenReturn(4); + QuerySearchResult querySearchResult = new QuerySearchResult(); + when(searchContext.queryResult()).thenReturn(querySearchResult); + when(searchContext.numberOfShards()).thenReturn(1); + when(searchContext.searcher()).thenReturn(contextIndexSearcher); + IndexShard indexShard = mock(IndexShard.class); + when(indexShard.shardId()).thenReturn(new ShardId("test", "test", 0)); + when(searchContext.indexShard()).thenReturn(indexShard); + when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR); + when(searchContext.mapperService()).thenReturn(mapperService); + when(searchContext.getQueryShardContext()).thenReturn(mockQueryShardContext); + IndexMetadata indexMetadata = mock(IndexMetadata.class); + when(indexMetadata.getIndex()).thenReturn(new Index(TEST_INDEX, INDEX_UUID.toString())); + when(indexMetadata.getSettings()).thenReturn(Settings.EMPTY); + Settings settings = Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, Integer.toString(1)).build(); + 
IndexSettings indexSettings = new IndexSettings(indexMetadata, settings); + when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); + + LinkedList collectors = new LinkedList<>(); + boolean hasFilterCollector = randomBoolean(); + boolean hasTimeout = randomBoolean(); + + HybridQueryBuilder queryBuilder = new HybridQueryBuilder(); + + queryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1)); + queryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT2)); + TermQueryBuilder termQuery3 = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1); + + BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery().should(queryBuilder).should(termQuery3); + + Query query = boolQueryBuilder.toQuery(mockQueryShardContext); + when(searchContext.query()).thenReturn(query); + + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> hybridQueryPhaseSearcher.searchWith( + searchContext, + contextIndexSearcher, + query, + collectors, + hasFilterCollector, + hasTimeout + ) + ); + + org.hamcrest.MatcherAssert.assertThat( + exception.getMessage(), + containsString("hybrid query must be a top level query and cannot be wrapped into other queries") + ); + + releaseResources(directory, w, reader); + } + + /** Checks that a hybrid query wrapped in a bool query with the wrong clause layout is rejected. The bool query built further down in this test puts the non-nested FILTER clause first and the hybrid query second as a SHOULD clause; searchWith is then expected to throw IllegalStateException whose message contains "cannot process hybrid query due to incorrect structure of top level bool query". The mapping used here declares a nested field named "user". */ @SneakyThrows + public void testWrappedHybridQuery_whenHybridWrappedIntoBoolAndIncorrectStructure_thenFail() { + HybridQueryPhaseSearcher hybridQueryPhaseSearcher = new HybridQueryPhaseSearcher(); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + MapperService mapperService = createMapperService(mapping(b -> { + b.startObject("field"); + b.field("type", "text") + .field("fielddata", true) + .startObject("fielddata_frequency_filter") + .field("min", 2d) + .field("min_segment_size", 1000) + .endObject(); + b.endObject(); + /* declares "user" as a nested field */ b.startObject("user"); + b.field("type", "nested"); + b.endObject(); + })); + + 
/* NOTE(review): fieldType is looked up on a fresh createMapperService() rather than on the mapperService built just above with the nested "user" mapping; confirm this is intended. */ TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + + Directory directory = newDirectory(); + IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random()))); + FieldType ft = new FieldType(TextField.TYPE_NOT_STORED); + ft.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS); + ft.setOmitNorms(random().nextBoolean()); + ft.freeze(); + /* a single document with a random id is indexed */ int docId1 = RandomizedTest.randomInt(); + w.addDocument(getDocument(TEXT_FIELD_NAME, docId1, TEST_DOC_TEXT1, ft)); + w.commit(); + + IndexReader reader = DirectoryReader.open(w); + SearchContext searchContext = mock(SearchContext.class); + + ContextIndexSearcher contextIndexSearcher = new ContextIndexSearcher( + reader, + IndexSearcher.getDefaultSimilarity(), + IndexSearcher.getDefaultQueryCache(), + IndexSearcher.getDefaultQueryCachingPolicy(), + true, + null, + searchContext + ); + + ShardId shardId = new ShardId(dummyIndex, 1); + SearchShardTarget shardTarget = new SearchShardTarget( + randomAlphaOfLength(10), + shardId, + randomAlphaOfLength(10), + OriginalIndices.NONE + ); + when(searchContext.shardTarget()).thenReturn(shardTarget); + when(searchContext.searcher()).thenReturn(contextIndexSearcher); + when(searchContext.size()).thenReturn(4); + QuerySearchResult querySearchResult = new QuerySearchResult(); + when(searchContext.queryResult()).thenReturn(querySearchResult); + when(searchContext.numberOfShards()).thenReturn(1); + /* repeats the searcher() stub already set a few statements earlier */ when(searchContext.searcher()).thenReturn(contextIndexSearcher); + IndexShard indexShard = mock(IndexShard.class); + when(indexShard.shardId()).thenReturn(new ShardId("test", "test", 0)); + when(searchContext.indexShard()).thenReturn(indexShard); + when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR); + 
when(searchContext.mapperService()).thenReturn(mapperService); + when(searchContext.getQueryShardContext()).thenReturn(mockQueryShardContext); + IndexMetadata indexMetadata = mock(IndexMetadata.class); + when(indexMetadata.getIndex()).thenReturn(new Index(TEST_INDEX, INDEX_UUID.toString())); + when(indexMetadata.getSettings()).thenReturn(Settings.EMPTY); + Settings settings = Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, Integer.toString(1)).build(); + IndexSettings indexSettings = new IndexSettings(indexMetadata, settings); + when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); + + LinkedList collectors = new LinkedList<>(); + boolean hasFilterCollector = randomBoolean(); + boolean hasTimeout = randomBoolean(); + + HybridQueryBuilder queryBuilder = new HybridQueryBuilder(); + + queryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1)); + queryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT2)); + + /* clause order in this test: non-nested FILTER first, hybrid query second as SHOULD */ BooleanQuery.Builder builder = new BooleanQuery.Builder(); + builder.add(Queries.newNonNestedFilter(), BooleanClause.Occur.FILTER) + .add(queryBuilder.toQuery(mockQueryShardContext), BooleanClause.Occur.SHOULD); + Query query = builder.build(); + + when(searchContext.query()).thenReturn(query); + + IllegalStateException exception = expectThrows( + IllegalStateException.class, + () -> hybridQueryPhaseSearcher.searchWith( + searchContext, + contextIndexSearcher, + query, + collectors, + hasFilterCollector, + hasTimeout + ) + ); + + org.hamcrest.MatcherAssert.assertThat( + exception.getMessage(), + containsString("cannot process hybrid query due to incorrect structure of top level bool query") + ); + + releaseResources(directory, w, reader); + } + + /** Checks the accepted wrapping of a hybrid query. The bool query built further down in this test has the hybrid query as its first clause (SHOULD) followed by the non-nested FILTER clause, and the mapping declares a nested field named "user". Four documents are indexed. After searchWith the test expects hybrid start/stop marker elements at both ends of the score docs and exactly two sub-query results: the first checked against docId1, the second checked against an empty id list. */ @SneakyThrows + public void testWrappedHybridQuery_whenHybridWrappedIntoBoolBecauseOfNested_thenSuccess() { + HybridQueryPhaseSearcher hybridQueryPhaseSearcher = new HybridQueryPhaseSearcher(); + QueryShardContext mockQueryShardContext = 
mock(QueryShardContext.class); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + + MapperService mapperService = createMapperService(mapping(b -> { + b.startObject("field"); + b.field("type", "text") + .field("fielddata", true) + .startObject("fielddata_frequency_filter") + .field("min", 2d) + .field("min_segment_size", 1000) + .endObject(); + b.endObject(); + /* declares "user" as a nested field */ b.startObject("user"); + b.field("type", "nested"); + b.endObject(); + })); + + /* here the field type comes from the mapperService built above */ TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) mapperService.fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + when(mockQueryShardContext.getMapperService()).thenReturn(mapperService); + when(mockQueryShardContext.simpleMatchToIndexNames(anyString())).thenReturn(Set.of(TEXT_FIELD_NAME)); + + Directory directory = newDirectory(); + IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random()))); + FieldType ft = new FieldType(TextField.TYPE_NOT_STORED); + ft.setIndexOptions(random().nextBoolean() ? 
IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS); + ft.setOmitNorms(random().nextBoolean()); + ft.freeze(); + /* four documents with random ids are indexed */ int docId1 = RandomizedTest.randomInt(); + int docId2 = RandomizedTest.randomInt(); + int docId3 = RandomizedTest.randomInt(); + int docId4 = RandomizedTest.randomInt(); + w.addDocument(getDocument(TEXT_FIELD_NAME, docId1, TEST_DOC_TEXT1, ft)); + w.addDocument(getDocument(TEXT_FIELD_NAME, docId2, TEST_DOC_TEXT2, ft)); + w.addDocument(getDocument(TEXT_FIELD_NAME, docId3, TEST_DOC_TEXT3, ft)); + w.addDocument(getDocument(TEXT_FIELD_NAME, docId4, TEST_DOC_TEXT4, ft)); + w.commit(); + + IndexReader reader = DirectoryReader.open(w); + SearchContext searchContext = mock(SearchContext.class); + + ContextIndexSearcher contextIndexSearcher = new ContextIndexSearcher( + reader, + IndexSearcher.getDefaultSimilarity(), + IndexSearcher.getDefaultQueryCache(), + IndexSearcher.getDefaultQueryCachingPolicy(), + true, + null, + searchContext + ); + + ShardId shardId = new ShardId(dummyIndex, 1); + SearchShardTarget shardTarget = new SearchShardTarget( + randomAlphaOfLength(10), + shardId, + randomAlphaOfLength(10), + OriginalIndices.NONE + ); + when(searchContext.shardTarget()).thenReturn(shardTarget); + when(searchContext.searcher()).thenReturn(contextIndexSearcher); + when(searchContext.size()).thenReturn(4); + QuerySearchResult querySearchResult = new QuerySearchResult(); + when(searchContext.queryResult()).thenReturn(querySearchResult); + when(searchContext.numberOfShards()).thenReturn(1); + /* repeats the searcher() stub already set a few statements earlier */ when(searchContext.searcher()).thenReturn(contextIndexSearcher); + IndexShard indexShard = mock(IndexShard.class); + when(indexShard.shardId()).thenReturn(new ShardId("test", "test", 0)); + when(searchContext.indexShard()).thenReturn(indexShard); + when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR); + when(searchContext.mapperService()).thenReturn(mapperService); + + LinkedList collectors = new LinkedList<>(); + boolean 
hasFilterCollector = randomBoolean(); + boolean hasTimeout = randomBoolean(); + + HybridQueryBuilder queryBuilder = new HybridQueryBuilder(); + + queryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1)); + queryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT2)); + + /* clause order in this test: hybrid query first as SHOULD, non-nested FILTER second */ BooleanQuery.Builder builder = new BooleanQuery.Builder(); + builder.add(queryBuilder.toQuery(mockQueryShardContext), BooleanClause.Occur.SHOULD) + .add(Queries.newNonNestedFilter(), BooleanClause.Occur.FILTER); + Query query = builder.build(); + + when(searchContext.query()).thenReturn(query); + + hybridQueryPhaseSearcher.searchWith(searchContext, contextIndexSearcher, query, collectors, hasFilterCollector, hasTimeout); + + assertNotNull(querySearchResult.topDocs()); + TopDocsAndMaxScore topDocsAndMaxScore = querySearchResult.topDocs(); + TopDocs topDocs = topDocsAndMaxScore.topDocs; + assertTrue(topDocs.totalHits.value > 0); + ScoreDoc[] scoreDocs = topDocs.scoreDocs; + assertNotNull(scoreDocs); + assertTrue(scoreDocs.length > 0); + /* first and last score docs must be hybrid start/stop marker elements */ assertTrue(isHybridQueryStartStopElement(scoreDocs[0])); + assertTrue(isHybridQueryStartStopElement(scoreDocs[scoreDocs.length - 1])); + List compoundTopDocs = getSubQueryResultsForSingleShard(topDocs); + assertNotNull(compoundTopDocs); + assertEquals(2, compoundTopDocs.size()); + + TopDocs subQueryTopDocs1 = compoundTopDocs.get(0); + List expectedIds1 = List.of(docId1); + assertQueryResults(subQueryTopDocs1, expectedIds1, reader); + + TopDocs subQueryTopDocs2 = compoundTopDocs.get(1); + List expectedIds2 = List.of(); + assertQueryResults(subQueryTopDocs2, expectedIds2, reader); + + releaseResources(directory, w, reader); + } + + /** Checks that a bool query containing no hybrid query still runs through searchWith. Two term queries are wrapped by createNestedBoolQuery with a level equal to the INDEX_MAPPING_DEPTH_LIMIT_SETTING default plus one, the MapperService is a mock whose hasNested() returns false, and the resulting score docs are expected NOT to carry hybrid start/stop marker elements at either end. */ @SneakyThrows + public void testBoolQuery_whenTooManyNestedLevels_thenSuccess() { + HybridQueryPhaseSearcher hybridQueryPhaseSearcher = new HybridQueryPhaseSearcher(); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + 
TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + /* the mapper service handed to the search context is a mock reporting no nested fields */ MapperService mapperService = mock(MapperService.class); + when(mapperService.hasNested()).thenReturn(false); + + Directory directory = newDirectory(); + IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random()))); + FieldType ft = new FieldType(TextField.TYPE_NOT_STORED); + ft.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS); + ft.setOmitNorms(random().nextBoolean()); + ft.freeze(); + int docId1 = RandomizedTest.randomInt(); + w.addDocument(getDocument(TEXT_FIELD_NAME, docId1, TEST_DOC_TEXT1, ft)); + w.commit(); + + IndexReader reader = DirectoryReader.open(w); + SearchContext searchContext = mock(SearchContext.class); + + ContextIndexSearcher contextIndexSearcher = new ContextIndexSearcher( + reader, + IndexSearcher.getDefaultSimilarity(), + IndexSearcher.getDefaultQueryCache(), + IndexSearcher.getDefaultQueryCachingPolicy(), + true, + null, + searchContext + ); + + ShardId shardId = new ShardId(dummyIndex, 1); + SearchShardTarget shardTarget = new SearchShardTarget( + randomAlphaOfLength(10), + shardId, + randomAlphaOfLength(10), + OriginalIndices.NONE + ); + when(searchContext.shardTarget()).thenReturn(shardTarget); + when(searchContext.searcher()).thenReturn(contextIndexSearcher); + when(searchContext.size()).thenReturn(4); + QuerySearchResult querySearchResult = new QuerySearchResult(); + when(searchContext.queryResult()).thenReturn(querySearchResult); + when(searchContext.numberOfShards()).thenReturn(1); + /* repeats the searcher() stub already set a few statements earlier */ when(searchContext.searcher()).thenReturn(contextIndexSearcher); + IndexShard indexShard = mock(IndexShard.class); + when(indexShard.shardId()).thenReturn(new ShardId("test", "test", 0)); + when(searchContext.indexShard()).thenReturn(indexShard); + 
when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR); + when(searchContext.mapperService()).thenReturn(mapperService); + when(searchContext.getQueryShardContext()).thenReturn(mockQueryShardContext); + IndexMetadata indexMetadata = mock(IndexMetadata.class); + when(indexMetadata.getIndex()).thenReturn(new Index(TEST_INDEX, INDEX_UUID.toString())); + when(indexMetadata.getSettings()).thenReturn(Settings.EMPTY); + Settings settings = Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, Integer.toString(1)).build(); + IndexSettings indexSettings = new IndexSettings(indexMetadata, settings); + when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); + + LinkedList collectors = new LinkedList<>(); + boolean hasFilterCollector = randomBoolean(); + boolean hasTimeout = randomBoolean(); + + /* nesting level passed below is the default mapping depth limit plus one */ Query query = createNestedBoolQuery( + QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1).toQuery(mockQueryShardContext), + QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT2).toQuery(mockQueryShardContext), + (int) (MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.getDefault(null) + 1) + ); + + when(searchContext.query()).thenReturn(query); + + hybridQueryPhaseSearcher.searchWith(searchContext, contextIndexSearcher, query, collectors, hasFilterCollector, hasTimeout); + + assertNotNull(querySearchResult.topDocs()); + TopDocsAndMaxScore topDocsAndMaxScore = querySearchResult.topDocs(); + TopDocs topDocs = topDocsAndMaxScore.topDocs; + assertTrue(topDocs.totalHits.value > 0); + ScoreDoc[] scoreDocs = topDocs.scoreDocs; + assertNotNull(scoreDocs); + assertTrue(scoreDocs.length > 0); + /* no hybrid start/stop marker elements are expected for a non-hybrid query */ assertFalse(isHybridQueryStartStopElement(scoreDocs[0])); + assertFalse(isHybridQueryStartStopElement(scoreDocs[scoreDocs.length - 1])); + + releaseResources(directory, w, reader); + } + @SneakyThrows private void assertQueryResults(TopDocs subQueryTopDocs, List expectedDocIds, IndexReader reader) { 
assertEquals(expectedDocIds.size(), subQueryTopDocs.totalHits.value); @@ -447,4 +875,15 @@ private List getSubQueryResultsForSingleShard(final TopDocs topDocs) { } return topDocsList; } + + /** Recursively builds a chain of bool queries. For level 0 it returns a bool query holding query1 and query2 as two SHOULD clauses; for any other level it returns a bool query whose single MUST clause is the result of the same call with level - 1, so a positive level adds that many single-clause wrappers around the innermost query. NOTE(review): a negative level never reaches the level == 0 base case as written; confirm callers only pass non-negative values. */ private BooleanQuery createNestedBoolQuery(final Query query1, final Query query2, int level) { + if (level == 0) { + BooleanQuery.Builder builder = new BooleanQuery.Builder(); + builder.add(query1, BooleanClause.Occur.SHOULD).add(query2, BooleanClause.Occur.SHOULD); + return builder.build(); + } + BooleanQuery.Builder builder = new BooleanQuery.Builder(); + builder.add(createNestedBoolQuery(query1, query2, level - 1), BooleanClause.Occur.MUST); + return builder.build(); + } }