[Feature] Support for Default Model Id #337

Merged 19 commits on Sep 29, 2023
Changes from 8 commits
@@ -29,6 +29,7 @@
import org.opensearch.ingest.Processor;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.processor.NeuralQueryProcessor;
import org.opensearch.neuralsearch.processor.NormalizationProcessor;
import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow;
import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor;
@@ -41,6 +42,7 @@
import org.opensearch.neuralsearch.query.HybridQueryBuilder;
import org.opensearch.neuralsearch.query.NeuralQueryBuilder;
import org.opensearch.neuralsearch.search.query.HybridQueryPhaseSearcher;
import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil;
import org.opensearch.plugins.ActionPlugin;
import org.opensearch.plugins.ExtensiblePlugin;
import org.opensearch.plugins.IngestPlugin;
@@ -50,6 +52,7 @@
import org.opensearch.repositories.RepositoriesService;
import org.opensearch.script.ScriptService;
import org.opensearch.search.pipeline.SearchPhaseResultsProcessor;
import org.opensearch.search.pipeline.SearchRequestProcessor;
import org.opensearch.search.query.QueryPhaseSearcher;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.watcher.ResourceWatcherService;
@@ -78,6 +81,7 @@
final IndexNameExpressionResolver indexNameExpressionResolver,
final Supplier<RepositoriesService> repositoriesServiceSupplier
) {
NeuralSearchClusterUtil.instance().initialize(clusterService);

Codecov / codecov/patch warning: added line #L84 in src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java was not covered by tests.
Review comment (Collaborator):

We should start getting rid of these kinds of inits and move towards a singleton pattern that does not need them.

Maybe an AI (action item) for maintainers.

NeuralQueryBuilder.initialize(clientAccessor);
normalizationProcessorWorkflow = new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner());
return List.of(clientAccessor);
@@ -127,4 +131,11 @@
public List<Setting<?>> getSettings() {
return List.of(NEURAL_SEARCH_HYBRID_SEARCH_DISABLED);
}

@Override
public Map<String, org.opensearch.search.pipeline.Processor.Factory<SearchRequestProcessor>> getRequestProcessors(
Parameters parameters
) {
return Map.of(NeuralQueryProcessor.TYPE, new NeuralQueryProcessor.Factory());
}
}
@@ -0,0 +1,81 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.neuralsearch.processor;

import java.util.Map;

import org.opensearch.action.search.SearchRequest;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.ingest.ConfigurationUtils;
import org.opensearch.neuralsearch.query.visitor.NeuralSearchQueryVisitor;
import org.opensearch.search.pipeline.AbstractProcessor;
import org.opensearch.search.pipeline.Processor;
import org.opensearch.search.pipeline.SearchRequestProcessor;

public class NeuralQueryProcessor extends AbstractProcessor implements SearchRequestProcessor {
Review comment (Member):

I would rename the class to something closer to the processor type name, although I leave the final call to you.


/**
* Key to reference this processor type from a search pipeline.
*/
public static final String TYPE = "neural_query";

final String modelId;

final Map<String, Object> neuralFieldDefaultIdMap;
Review comment (Member):

private? Also, could we add a comment for this member?

Also, why is this map <String, Object> and not <String, String>?

Reply (Member Author):

Because ConfigurationUtils, the OpenSearch class used to retrieve a map from the processor config, is generic and defines the map as Map<String, Object>. ConfigurationUtils is designed to read values from the processor configuration in OpenSearch, so I am using the same class.
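
To illustrate the trade-off raised in this thread, here is a minimal sketch that keeps `Map<String, Object>` at the configuration boundary and narrows it once before use. `FieldModelIdMaps` and `toStringMap` are hypothetical names for illustration only; they are not part of this PR or of OpenSearch.

```java
import java.util.LinkedHashMap;
import java.util.Map;

// Hypothetical helper, not part of the PR: narrows a loosely typed config map.
final class FieldModelIdMaps {
    private FieldModelIdMaps() {}

    /**
     * Copies the raw map into a Map<String, String>, rejecting any value
     * that is not a String so that later casts cannot fail at query time.
     */
    static Map<String, String> toStringMap(Map<String, Object> raw) {
        Map<String, String> typed = new LinkedHashMap<>();
        for (Map.Entry<String, Object> entry : raw.entrySet()) {
            Object value = entry.getValue();
            if (!(value instanceof String)) {
                throw new IllegalArgumentException(
                    "model id for field [" + entry.getKey() + "] must be a string"
                );
            }
            typed.put(entry.getKey(), (String) value);
        }
        return typed;
    }
}
```

With a conversion like this, a non-string value would be rejected when the processor is created rather than surfacing later as a `ClassCastException` in the visitor.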


/**
* Returns the type of the processor.
*
* @return The processor type.
*/
@Override
public String getType() {
return TYPE;

Codecov / codecov/patch warning: added line #L36 in src/main/java/org/opensearch/neuralsearch/processor/NeuralQueryProcessor.java was not covered by tests.
}

protected NeuralQueryProcessor(
String tag,
String description,
boolean ignoreFailure,
String modelId,
Map<String, Object> neuralFieldDefaultIdMap
) {
super(tag, description, ignoreFailure);
this.modelId = modelId;
this.neuralFieldDefaultIdMap = neuralFieldDefaultIdMap;
}

@Override
public SearchRequest processRequest(SearchRequest searchRequest) {
QueryBuilder queryBuilder = searchRequest.source().query();
queryBuilder.visit(new NeuralSearchQueryVisitor(modelId, neuralFieldDefaultIdMap));
return searchRequest;
}

public static class Factory implements Processor.Factory<SearchRequestProcessor> {
private static final String DEFAULT_MODEL_ID = "default_model_id";
private static final String NEURAL_FIELD_DEFAULT_ID = "neural_field_default_id";

@Override
public NeuralQueryProcessor create(
Map<String, Processor.Factory<SearchRequestProcessor>> processorFactories,
String tag,
String description,
boolean ignoreFailure,
Map<String, Object> config,
PipelineContext pipelineContext
) throws Exception {
String modelId = (String) config.remove(DEFAULT_MODEL_ID);
Map<String, Object> neuralInfoMap = ConfigurationUtils.readOptionalMap(TYPE, tag, config, NEURAL_FIELD_DEFAULT_ID);

if (modelId == null && neuralInfoMap == null) {
throw new IllegalArgumentException("model Id or neural info map either of them should be provided");
}

return new NeuralQueryProcessor(tag, description, ignoreFailure, modelId, neuralInfoMap);
}
}
}
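
The factory's either-or check above can be tried outside OpenSearch with a plain map. The sketch below is a stand-in under stated assumptions: `NeuralQueryConfigSketch` is a hypothetical class, it reuses only the two config keys from the diff, and it does not call `ConfigurationUtils`.

```java
import java.util.Map;

// Hypothetical stand-in for the validation in NeuralQueryProcessor.Factory#create.
final class NeuralQueryConfigSketch {
    static final String DEFAULT_MODEL_ID = "default_model_id";
    static final String NEURAL_FIELD_DEFAULT_ID = "neural_field_default_id";

    private NeuralQueryConfigSketch() {}

    /** Throws if neither a pipeline-wide default nor a per-field map is configured. */
    static void validate(Map<String, Object> config) {
        Object modelId = config.get(DEFAULT_MODEL_ID);
        Object fieldDefaults = config.get(NEURAL_FIELD_DEFAULT_ID);
        if (modelId == null && fieldDefaults == null) {
            throw new IllegalArgumentException(
                "either " + DEFAULT_MODEL_ID + " or " + NEURAL_FIELD_DEFAULT_ID + " must be provided"
            );
        }
    }
}
```

A config containing only `default_model_id`, only `neural_field_default_id`, or both passes; an empty config is rejected.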
@@ -22,6 +22,7 @@
import org.apache.commons.lang.builder.EqualsBuilder;
import org.apache.commons.lang.builder.HashCodeBuilder;
import org.apache.lucene.search.Query;
import org.opensearch.Version;
import org.opensearch.common.SetOnce;
import org.opensearch.core.ParseField;
import org.opensearch.core.action.ActionListener;
@@ -37,6 +38,7 @@
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.knn.index.query.KNNQueryBuilder;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil;

import com.google.common.annotations.VisibleForTesting;

@@ -82,6 +84,7 @@
@Setter(AccessLevel.PACKAGE)
private Supplier<float[]> vectorSupplier;
private QueryBuilder filter;
private static final Version MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID = Version.V_2_11_0;

/**
* Constructor from stream input
@@ -93,7 +96,11 @@
super(in);
this.fieldName = in.readString();
this.queryText = in.readString();
this.modelId = in.readString();
if (isClusterOnOrAfterMinRequiredVersion()) {
this.modelId = in.readOptionalString();
} else {
this.modelId = in.readString();

Codecov / codecov/patch warning: added line #L102 in src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java was not covered by tests.
}
this.k = in.readVInt();
this.filter = in.readOptionalNamedWriteable(QueryBuilder.class);
}
@@ -102,7 +109,11 @@
protected void doWriteTo(StreamOutput out) throws IOException {
out.writeString(this.fieldName);
out.writeString(this.queryText);
out.writeString(this.modelId);
if (isClusterOnOrAfterMinRequiredVersion()) {
out.writeOptionalString(this.modelId);
} else {
out.writeString(this.modelId);

Codecov / codecov/patch warning: added line #L115 in src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java was not covered by tests.
}
out.writeVInt(this.k);
out.writeOptionalNamedWriteable(this.filter);
}
@@ -112,7 +123,9 @@
xContentBuilder.startObject(NAME);
xContentBuilder.startObject(fieldName);
xContentBuilder.field(QUERY_TEXT_FIELD.getPreferredName(), queryText);
xContentBuilder.field(MODEL_ID_FIELD.getPreferredName(), modelId);
if (!isClusterOnOrAfterMinRequiredVersion() || (isClusterOnOrAfterMinRequiredVersion() && modelId != null)) {
xContentBuilder.field(MODEL_ID_FIELD.getPreferredName(), modelId);
}
xContentBuilder.field(K_FIELD.getPreferredName(), k);
if (filter != null) {
xContentBuilder.field(FILTER_FIELD.getPreferredName(), filter);
@@ -164,8 +177,9 @@
}
requireValue(neuralQueryBuilder.queryText(), "Query text must be provided for neural query");
requireValue(neuralQueryBuilder.fieldName(), "Field name must be provided for neural query");
requireValue(neuralQueryBuilder.modelId(), "Model ID must be provided for neural query");

if (!isClusterOnOrAfterMinRequiredVersion()) {
requireValue(neuralQueryBuilder.modelId(), "Model ID must be provided for neural query");

Codecov / codecov/patch warning: added line #L181 in src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java was not covered by tests.
}
return neuralQueryBuilder;
}

@@ -258,4 +272,8 @@
public String getWriteableName() {
return NAME;
}

private static boolean isClusterOnOrAfterMinRequiredVersion() {
return NeuralSearchClusterUtil.instance().getClusterMinVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID);
}
}
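
The version gate in the stream constructor and `doWriteTo` matters because the reader and writer must agree on the byte layout. The sketch below illustrates this with `java.io` streams rather than OpenSearch's `StreamInput`/`StreamOutput`. It assumes an optional string is encoded as a presence flag followed by the value; `OptionalStringWireSketch` and its methods are hypothetical names.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Objects;

// Hypothetical illustration of a version-gated wire format; not OpenSearch code.
final class OptionalStringWireSketch {
    private OptionalStringWireSketch() {}

    /** Writes modelId either as "flag + value" (optional format) or as a mandatory string. */
    static byte[] write(String modelId, boolean optionalFormat) {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (DataOutputStream out = new DataOutputStream(bytes)) {
            if (optionalFormat) {
                out.writeBoolean(modelId != null);
                if (modelId != null) {
                    out.writeUTF(modelId);
                }
            } else {
                // The legacy layout has no way to represent a missing value.
                out.writeUTF(Objects.requireNonNull(modelId, "modelId is required in the legacy format"));
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return bytes.toByteArray();
    }

    /** Reads modelId using the same layout flag that the writer used. */
    static String read(byte[] payload, boolean optionalFormat) {
        try (DataInputStream in = new DataInputStream(new ByteArrayInputStream(payload))) {
            if (optionalFormat) {
                return in.readBoolean() ? in.readUTF() : null;
            }
            return in.readUTF();
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }
}
```

Both sides take the same `optionalFormat` decision, which in the diff comes from the cluster's minimum node version. A null model id can only round-trip when the optional layout is in use.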
@@ -0,0 +1,48 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.neuralsearch.query.visitor;

import java.util.Map;

import org.apache.lucene.search.BooleanClause;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryBuilderVisitor;
import org.opensearch.neuralsearch.query.NeuralQueryBuilder;

public class NeuralSearchQueryVisitor implements QueryBuilderVisitor {

private String modelId;
private Map<String, Object> neuralFieldMap;

public NeuralSearchQueryVisitor(String modelId, Map<String, Object> neuralFieldMap) {
this.modelId = modelId;
this.neuralFieldMap = neuralFieldMap;
}

@Override
public void accept(QueryBuilder queryBuilder) {
if (queryBuilder instanceof NeuralQueryBuilder) {
NeuralQueryBuilder neuralQueryBuilder = (NeuralQueryBuilder) queryBuilder;
if (neuralFieldMap != null
&& neuralQueryBuilder.fieldName() != null
&& neuralFieldMap.get(neuralQueryBuilder.fieldName()) != null) {
String fieldDefaultModelId = (String) neuralFieldMap.get(neuralQueryBuilder.fieldName());
neuralQueryBuilder.modelId(fieldDefaultModelId);

Codecov / codecov/patch warning: added lines #L32 - L33 in src/main/java/org/opensearch/neuralsearch/query/visitor/NeuralSearchQueryVisitor.java were not covered by tests.
} else if (modelId != null) {
neuralQueryBuilder.modelId(modelId);
} else {
throw new IllegalArgumentException(

Codecov / codecov/patch warning: added line #L37 in src/main/java/org/opensearch/neuralsearch/query/visitor/NeuralSearchQueryVisitor.java was not covered by tests.
"model id must be provided in neural query or a default model id must be set in search request processor"
);
}
}
}

@Override
public QueryBuilderVisitor getChildVisitor(BooleanClause.Occur occur) {
return this;

Codecov / codecov/patch warning: added line #L46 in src/main/java/org/opensearch/neuralsearch/query/visitor/NeuralSearchQueryVisitor.java was not covered by tests.
}
}
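
The branch order in `accept` can be read as a small precedence rule: a per-field default wins, then the pipeline-wide default, otherwise the request fails. The sketch below restates that order as a pure function so it can be checked in isolation. `ModelIdResolutionSketch` and `resolve` are hypothetical names, and the sketch mirrors only the revision shown here, in which a model id already set on the query is not consulted.

```java
import java.util.Map;

// Hypothetical restatement of the visitor's branch order; not part of the PR.
final class ModelIdResolutionSketch {
    private ModelIdResolutionSketch() {}

    /**
     * Returns the model id that would be applied to a neural query on fieldName:
     * the per-field default if present, else the pipeline default, else an error.
     */
    static String resolve(String fieldName, String defaultModelId, Map<String, Object> fieldDefaults) {
        if (fieldDefaults != null && fieldName != null && fieldDefaults.get(fieldName) != null) {
            return (String) fieldDefaults.get(fieldName);
        }
        if (defaultModelId != null) {
            return defaultModelId;
        }
        throw new IllegalArgumentException(
            "model id must be provided in neural query or a default model id must be set in search request processor"
        );
    }
}
```

For example, with a pipeline default of `model-a` and a field map entry `passage_embedding -> model-b`, a query on `passage_embedding` resolves to `model-b` and a query on any other field resolves to `model-a`. The field name and model ids here are made up.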
@@ -0,0 +1,63 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.neuralsearch.util;

import java.util.Locale;

import lombok.AccessLevel;
import lombok.NoArgsConstructor;
import lombok.extern.log4j.Log4j2;

import org.opensearch.Version;
import org.opensearch.cluster.service.ClusterService;

@NoArgsConstructor(access = AccessLevel.PRIVATE)
@Log4j2
public class NeuralSearchClusterUtil {
private ClusterService clusterService;

private static NeuralSearchClusterUtil instance;

/**
* Return instance of the cluster context, must be initialized first for proper usage
* @return instance of cluster context
*/
public static synchronized NeuralSearchClusterUtil instance() {
if (instance == null) {
instance = new NeuralSearchClusterUtil();
}
return instance;
}

/**
* Initializes instance of cluster context by injecting dependencies
* @param clusterService
*/
public void initialize(final ClusterService clusterService) {
this.clusterService = clusterService;
}

/**
* Return minimal OpenSearch version based on all nodes currently discoverable in the cluster
* @return minimal installed OpenSearch version, default to Version.CURRENT which is typically the latest version
*/
public Version getClusterMinVersion() {
try {
return this.clusterService.state().getNodes().getMinNodeVersion();
} catch (Exception exception) {
log.error(
String.format(
Locale.ROOT,
"Failed to get cluster minimum node version, returning current node version %s instead.",
Version.CURRENT
),
exception
);
return Version.CURRENT;
}
}

}
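
To see why the gate uses the cluster's minimum node version rather than the local node's version, consider a mixed cluster during a rolling upgrade. The sketch below uses a simplified, hypothetical `SimpleVersion` type instead of `org.opensearch.Version`. The fallback for an empty node list is an assumption made for the sketch, loosely analogous to the `Version.CURRENT` fallback in the diff.

```java
import java.util.Collections;
import java.util.List;

// Hypothetical, simplified version type; not org.opensearch.Version.
final class SimpleVersion implements Comparable<SimpleVersion> {
    final int major;
    final int minor;
    final int patch;

    SimpleVersion(int major, int minor, int patch) {
        this.major = major;
        this.minor = minor;
        this.patch = patch;
    }

    @Override
    public int compareTo(SimpleVersion other) {
        if (major != other.major) {
            return Integer.compare(major, other.major);
        }
        if (minor != other.minor) {
            return Integer.compare(minor, other.minor);
        }
        return Integer.compare(patch, other.patch);
    }

    /** True when this version is the same as or newer than other. */
    boolean onOrAfter(SimpleVersion other) {
        return compareTo(other) >= 0;
    }
}

// Hypothetical helper that picks the oldest version among the known nodes.
final class ClusterMinVersionSketch {
    private ClusterMinVersionSketch() {}

    static SimpleVersion minNodeVersion(List<SimpleVersion> nodeVersions, SimpleVersion fallback) {
        if (nodeVersions == null || nodeVersions.isEmpty()) {
            return fallback; // assumption: use the caller's fallback when no nodes are known
        }
        return Collections.min(nodeVersions);
    }
}
```

With nodes on 2.11.0 and 2.10.0, the minimum is 2.10.0, so a check against a 2.11.0 gate fails and the legacy behavior stays in effect until every node has been upgraded.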