Adding hybrid_search_enabled settings
Signed-off-by: Martin Gaievski <[email protected]>
martin-gaievski committed Jul 1, 2023
1 parent 7dda0c5 commit 1fa7e78
Showing 3 changed files with 127 additions and 1 deletion.
12 changes: 12 additions & 0 deletions src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java
@@ -17,11 +17,14 @@
import org.opensearch.cluster.metadata.IndexNameExpressionResolver;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.io.stream.NamedWriteableRegistry;
import org.opensearch.common.settings.Setting;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.env.Environment;
import org.opensearch.env.NodeEnvironment;
import org.opensearch.ingest.Processor;
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.neuralsearch.index.NeuralSearchSettings;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor;
import org.opensearch.neuralsearch.processor.factory.TextEmbeddingProcessorFactory;
@@ -45,6 +48,7 @@
public class NeuralSearch extends Plugin implements ActionPlugin, SearchPlugin, IngestPlugin, ExtensiblePlugin {

private MLCommonsClientAccessor clientAccessor;
private ClusterService clusterService;

@Override
public Collection<Object> createComponents(
@@ -60,6 +64,7 @@ public Collection<Object> createComponents(
final IndexNameExpressionResolver indexNameExpressionResolver,
final Supplier<RepositoriesService> repositoriesServiceSupplier
) {
NeuralSearchSettings.state().initialize(clusterService);
NeuralQueryBuilder.initialize(clientAccessor);
return List.of(clientAccessor);
}
@@ -82,4 +87,11 @@ public Map<String, Processor.Factory> getProcessors(Processor.Parameters paramet
public Optional<QueryPhaseSearcher> getQueryPhaseSearcher() {
return Optional.of(new HybridQueryPhaseSearcher());
}

@Override
public List<Setting<?>> getSettings() {
return List.of(
NeuralSearchSettings.INDEX_NEURAL_SEARCH_HYBRID_SEARCH_SETTING
);
}
}
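
The NeuralSearchSettings class that this commit wires in (the INDEX_NEURAL_SEARCH_HYBRID_SEARCH constants, state(), initialize(), and getSettingValue()) is not one of the three changed files shown here. Below is a minimal sketch of what such a settings holder could look like, assuming a boolean, node-scoped feature flag that defaults to false and is read from the node settings through the injected ClusterService; the setting key string and the body of getSettingValue are illustrative assumptions, not code from this commit.

package org.opensearch.neuralsearch.index;

import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Setting;

/**
 * Hypothetical sketch of the settings holder used by NeuralSearch and
 * HybridQueryPhaseSearcher in this commit; not part of the diff above.
 */
public final class NeuralSearchSettings {

    // Key of the hybrid-search feature flag; the actual key string is an assumption here.
    public static final String INDEX_NEURAL_SEARCH_HYBRID_SEARCH = "plugins.neural_search.hybrid_search_enabled";

    // Boolean, node-scoped setting registered via NeuralSearch#getSettings(); assumed to default to false.
    public static final Setting<Boolean> INDEX_NEURAL_SEARCH_HYBRID_SEARCH_SETTING = Setting.boolSetting(
        INDEX_NEURAL_SEARCH_HYBRID_SEARCH,
        false,
        Setting.Property.NodeScope
    );

    private static final NeuralSearchSettings INSTANCE = new NeuralSearchSettings();

    private ClusterService clusterService;

    private NeuralSearchSettings() {}

    // Singleton accessor, matching the NeuralSearchSettings.state() calls in the diff.
    public static NeuralSearchSettings state() {
        return INSTANCE;
    }

    // Called once from NeuralSearch#createComponents with the node's ClusterService.
    public void initialize(final ClusterService clusterService) {
        this.clusterService = clusterService;
    }

    // Looks up the registered setting by key and reads its current value
    // (or the setting's default) from the node settings.
    @SuppressWarnings("unchecked")
    public <T> T getSettingValue(final String settingKey) {
        final Setting<T> setting = (Setting<T>) clusterService.getClusterSettings().get(settingKey);
        if (setting == null) {
            throw new IllegalArgumentException("Unknown setting: " + settingKey);
        }
        return setting.get(clusterService.getSettings());
    }
}
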
@@ -20,6 +20,7 @@
import org.apache.lucene.search.TotalHitCountCollector;
import org.apache.lucene.search.TotalHits;
import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
import org.opensearch.neuralsearch.index.NeuralSearchSettings;
import org.opensearch.neuralsearch.query.HybridQuery;
import org.opensearch.neuralsearch.search.CompoundTopDocs;
import org.opensearch.neuralsearch.search.HitsThresholdChecker;
@@ -51,7 +52,8 @@ public boolean searchWith(
final boolean hasFilterCollector,
final boolean hasTimeout
) throws IOException {
if (query instanceof HybridQuery) {
boolean isQuerySearcherEnabled = NeuralSearchSettings.state().getSettingValue(NeuralSearchSettings.INDEX_NEURAL_SEARCH_HYBRID_SEARCH);
if (isQuerySearcherEnabled && query instanceof HybridQuery) {
return searchWithCollector(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout);
}
return super.searchWithCollector(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout);
@@ -10,14 +10,20 @@
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.mockStatic;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.opensearch.neuralsearch.index.NeuralSearchSettings.INDEX_NEURAL_SEARCH_HYBRID_SEARCH;
import static org.opensearch.neuralsearch.index.NeuralSearchSettings.INDEX_NEURAL_SEARCH_HYBRID_SEARCH_SETTING;

import java.io.IOException;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;

import lombok.SneakyThrows;

@@ -34,9 +40,16 @@
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.analysis.MockAnalyzer;
import org.mockito.MockedStatic;
import org.mockito.Mockito;
import org.opensearch.action.OriginalIndices;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.settings.Setting;
import org.opensearch.common.settings.Settings;
import org.opensearch.index.Index;
import org.opensearch.index.ShardIndexingPressureSettings;
import org.opensearch.index.mapper.TextFieldMapper;
import org.opensearch.index.query.MatchAllQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
Expand All @@ -45,6 +58,8 @@
import org.opensearch.index.query.TermQueryBuilder;
import org.opensearch.index.shard.ShardId;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
import org.opensearch.knn.index.memory.NativeMemoryCacheManager;
import org.opensearch.neuralsearch.index.NeuralSearchSettings;
import org.opensearch.neuralsearch.query.HybridQueryBuilder;
import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase;
import org.opensearch.neuralsearch.search.CompoundTopDocs;
@@ -127,6 +142,14 @@ public void testQueryType_whenQueryIsHybrid_thenCallHybridDocCollector() {

Query query = queryBuilder.toQuery(mockQueryShardContext);
when(searchContext.query()).thenReturn(query);
QuerySearchResult querySearchResult = new QuerySearchResult();
when(searchContext.queryResult()).thenReturn(querySearchResult);
Settings settings = Settings.builder().put(INDEX_NEURAL_SEARCH_HYBRID_SEARCH_SETTING.getKey(), true).build();
Set<Setting<?>> settingsSet = new HashSet<>(ClusterSettings.BUILT_IN_CLUSTER_SETTINGS);
settingsSet.add(INDEX_NEURAL_SEARCH_HYBRID_SEARCH_SETTING);
ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, settingsSet);
ClusterService clusterService = new ClusterService(settings, clusterSettings, null);
NeuralSearchSettings.state().initialize(clusterService);

hybridQueryPhaseSearcher.searchWith(searchContext, contextIndexSearcher, query, collectors, hasFilterCollector, hasTimeout);

@@ -185,6 +208,89 @@ public void testQueryType_whenQueryIsNotHybrid_thenDoNotCallHybridDocCollector()

Query query = termSubQuery.toQuery(mockQueryShardContext);
when(searchContext.query()).thenReturn(query);
Settings settings = Settings.builder().put(INDEX_NEURAL_SEARCH_HYBRID_SEARCH_SETTING.getKey(), true).build();
Set<Setting<?>> settingsSet = new HashSet<>(ClusterSettings.BUILT_IN_CLUSTER_SETTINGS);
settingsSet.add(INDEX_NEURAL_SEARCH_HYBRID_SEARCH_SETTING);
ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, settingsSet);
ClusterService clusterService = new ClusterService(settings, clusterSettings, null);
NeuralSearchSettings.state().initialize(clusterService);

hybridQueryPhaseSearcher.searchWith(searchContext, contextIndexSearcher, query, collectors, hasFilterCollector, hasTimeout);

releaseResources(directory, w, reader);

verify(hybridQueryPhaseSearcher, never()).searchWithCollector(any(), any(), any(), any(), anyBoolean(), anyBoolean());
}

@SneakyThrows
public void testSettings_whenHybridSearchDisabled_thenDoNotCallHybridDocCollector() {
HybridQueryPhaseSearcher hybridQueryPhaseSearcher = spy(new HybridQueryPhaseSearcher());
QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class);
when(mockQueryShardContext.index()).thenReturn(dummyIndex);
when(mockKNNVectorField.getDimension()).thenReturn(4);
when(mockQueryShardContext.fieldMapper(eq(VECTOR_FIELD_NAME))).thenReturn(mockKNNVectorField);
TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME);
when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType);

Directory directory = newDirectory();
IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random())));
FieldType ft = new FieldType(TextField.TYPE_NOT_STORED);
ft.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS);
ft.setOmitNorms(random().nextBoolean());
ft.freeze();

w.addDocument(getDocument(TEXT_FIELD_NAME, RandomizedTest.randomInt(), TEST_DOC_TEXT1, ft));
w.addDocument(getDocument(TEXT_FIELD_NAME, RandomizedTest.randomInt(), TEST_DOC_TEXT2, ft));
w.addDocument(getDocument(TEXT_FIELD_NAME, RandomizedTest.randomInt(), TEST_DOC_TEXT3, ft));
w.commit();

IndexReader reader = DirectoryReader.open(w);

ContextIndexSearcher contextIndexSearcher = new ContextIndexSearcher(
reader,
IndexSearcher.getDefaultSimilarity(),
IndexSearcher.getDefaultQueryCache(),
IndexSearcher.getDefaultQueryCachingPolicy(),
true,
null
);

SearchContext searchContext = mock(SearchContext.class);
ShardId shardId = new ShardId(dummyIndex, 1);
SearchShardTarget shardTarget = new SearchShardTarget(
randomAlphaOfLength(10),
shardId,
randomAlphaOfLength(10),
OriginalIndices.NONE
);
when(searchContext.shardTarget()).thenReturn(shardTarget);
when(searchContext.searcher()).thenReturn(contextIndexSearcher);

LinkedList<QueryCollectorContext> collectors = new LinkedList<>();
boolean hasFilterCollector = randomBoolean();
boolean hasTimeout = randomBoolean();

HybridQueryBuilder queryBuilder = new HybridQueryBuilder();

TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1);
queryBuilder.add(termSubQuery);

Query query = queryBuilder.toQuery(mockQueryShardContext);
when(searchContext.query()).thenReturn(query);
QuerySearchResult querySearchResult = new QuerySearchResult();
when(searchContext.queryResult()).thenReturn(querySearchResult);
Setting<Boolean> setting = Setting.boolSetting(
INDEX_NEURAL_SEARCH_HYBRID_SEARCH,
false,
Setting.Property.NodeScope
);
Settings settings = Settings.builder().put(setting.getKey(), false).build();
Set<Setting<?>> settingsSet = new HashSet<>(ClusterSettings.BUILT_IN_CLUSTER_SETTINGS);
settingsSet.add(INDEX_NEURAL_SEARCH_HYBRID_SEARCH_SETTING);
ClusterSettings clusterSettings = new ClusterSettings(settings, settingsSet);
ClusterService clusterService = new ClusterService(settings, clusterSettings, null);
NeuralSearchSettings.state().initialize(clusterService);

hybridQueryPhaseSearcher.searchWith(searchContext, contextIndexSearcher, query, collectors, hasFilterCollector, hasTimeout);

@@ -251,6 +357,12 @@ public void testQueryResult_whenOneSubQueryWithHits_thenHybridResultsAreSet() {

Query query = queryBuilder.toQuery(mockQueryShardContext);
when(searchContext.query()).thenReturn(query);
Settings settings = Settings.builder().put(INDEX_NEURAL_SEARCH_HYBRID_SEARCH_SETTING.getKey(), true).build();
Set<Setting<?>> settingsSet = new HashSet<>(ClusterSettings.BUILT_IN_CLUSTER_SETTINGS);
settingsSet.add(INDEX_NEURAL_SEARCH_HYBRID_SEARCH_SETTING);
ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, settingsSet);
ClusterService clusterService = new ClusterService(settings, clusterSettings, null);
NeuralSearchSettings.state().initialize(clusterService);

hybridQueryPhaseSearcher.searchWith(searchContext, contextIndexSearcher, query, collectors, hasFilterCollector, hasTimeout);
