Skip to content

Commit

Permalink
add feature flag for remote inference
Browse files Browse the repository at this point in the history
  • Loading branch information
Zhangxunmt committed Aug 16, 2023
1 parent 5c2dc5d commit f46432e
Show file tree
Hide file tree
Showing 13 changed files with 181 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import static org.opensearch.ml.plugin.MachineLearningPlugin.DEPLOY_THREAD_POOL;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ALLOW_CUSTOM_DEPLOYMENT_PLAN;
import static org.opensearch.ml.task.MLTaskManager.TASK_SEMAPHORE_TIMEOUT;
import static org.opensearch.ml.utils.MLExceptionUtils.REMOTE_INFERENCE_DISABLED_ERR_MSG;

import java.time.Instant;
import java.util.ArrayList;
Expand Down Expand Up @@ -52,6 +53,7 @@
import org.opensearch.ml.engine.ModelHelper;
import org.opensearch.ml.helper.ModelAccessControlHelper;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.stats.MLNodeLevelStat;
import org.opensearch.ml.stats.MLStats;
import org.opensearch.ml.task.MLTaskDispatcher;
Expand Down Expand Up @@ -83,6 +85,7 @@ public class TransportDeployModelAction extends HandledTransportAction<ActionReq

private volatile boolean allowCustomDeploymentPlan;
private ModelAccessControlHelper modelAccessControlHelper;
private MLFeatureEnabledSetting mlFeatureEnabledSetting;

@Inject
public TransportDeployModelAction(
Expand All @@ -99,7 +102,8 @@ public TransportDeployModelAction(
MLModelManager mlModelManager,
MLStats mlStats,
Settings settings,
ModelAccessControlHelper modelAccessControlHelper
ModelAccessControlHelper modelAccessControlHelper,
MLFeatureEnabledSetting mlFeatureEnabledSetting
) {
super(MLDeployModelAction.NAME, transportService, actionFilters, MLDeployModelRequest::new);
this.transportService = transportService;
Expand All @@ -114,6 +118,7 @@ public TransportDeployModelAction(
this.mlModelManager = mlModelManager;
this.mlStats = mlStats;
this.modelAccessControlHelper = modelAccessControlHelper;
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
allowCustomDeploymentPlan = ML_COMMONS_ALLOW_CUSTOM_DEPLOYMENT_PLAN.get(settings);
clusterService
.getClusterSettings()
Expand All @@ -129,6 +134,9 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLDepl

try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
mlModelManager.getModel(modelId, null, excludes, ActionListener.wrap(mlModel -> {
if (mlModel.getAlgorithm() == FunctionName.REMOTE && !mlFeatureEnabledSetting.isRemoteInferenceEnabled()) {
throw new IllegalStateException(REMOTE_INFERENCE_DISABLED_ERR_MSG);
}
modelAccessControlHelper.validateModelGroupAccess(user, mlModel.getModelGroupId(), client, ActionListener.wrap(access -> {
if (!access) {
listener
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import static org.opensearch.ml.common.MLTask.STATE_FIELD;
import static org.opensearch.ml.common.MLTaskState.FAILED;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_INFERENCE_ENABLED;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_URL_REGEX;
import static org.opensearch.ml.task.MLTaskManager.TASK_SEMAPHORE_TIMEOUT;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@
import org.opensearch.ml.rest.RestMLUpdateModelGroupAction;
import org.opensearch.ml.rest.RestMLUploadModelChunkAction;
import org.opensearch.ml.settings.MLCommonsSettings;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.stats.MLClusterLevelStat;
import org.opensearch.ml.stats.MLNodeLevelStat;
import org.opensearch.ml.stats.MLStat;
Expand Down Expand Up @@ -219,6 +220,8 @@ public class MachineLearningPlugin extends Plugin implements ActionPlugin {

private ConnectorAccessControlHelper connectorAccessControlHelper;

private MLFeatureEnabledSetting mlFeatureEnabledSetting;

@Override
public List<ActionHandler<? extends ActionRequest, ? extends ActionResponse>> getActions() {
return ImmutableList
Expand Down Expand Up @@ -325,6 +328,8 @@ public Collection<Object> createComponents(
mlInputDatasetHandler = new MLInputDatasetHandler(client);
modelAccessControlHelper = new ModelAccessControlHelper(clusterService, settings);
connectorAccessControlHelper = new ConnectorAccessControlHelper(clusterService, settings);
mlFeatureEnabledSetting = new MLFeatureEnabledSetting(clusterService, settings);

mlModelChunkUploader = new MLModelChunkUploader(mlIndicesHandler, client, xContentRegistry, modelAccessControlHelper);

MLTaskDispatcher mlTaskDispatcher = new MLTaskDispatcher(clusterService, client, settings, nodeHelper);
Expand Down Expand Up @@ -431,6 +436,7 @@ public Collection<Object> createComponents(
mlExecuteTaskRunner,
modelAccessControlHelper,
connectorAccessControlHelper,
mlFeatureEnabledSetting,
mlSearchHandler,
mlTaskDispatcher,
mlModelChunkUploader,
Expand All @@ -455,7 +461,7 @@ public List<RestHandler> getRestHandlers(
RestMLStatsAction restMLStatsAction = new RestMLStatsAction(mlStats, clusterService, indexUtils);
RestMLTrainingAction restMLTrainingAction = new RestMLTrainingAction();
RestMLTrainAndPredictAction restMLTrainAndPredictAction = new RestMLTrainAndPredictAction();
RestMLPredictionAction restMLPredictionAction = new RestMLPredictionAction(mlModelManager);
RestMLPredictionAction restMLPredictionAction = new RestMLPredictionAction(mlModelManager, mlFeatureEnabledSetting);
RestMLExecuteAction restMLExecuteAction = new RestMLExecuteAction();
RestMLGetModelAction restMLGetModelAction = new RestMLGetModelAction();
RestMLDeleteModelAction restMLDeleteModelAction = new RestMLDeleteModelAction();
Expand All @@ -464,7 +470,7 @@ public List<RestHandler> getRestHandlers(
RestMLDeleteTaskAction restMLDeleteTaskAction = new RestMLDeleteTaskAction();
RestMLSearchTaskAction restMLSearchTaskAction = new RestMLSearchTaskAction();
RestMLProfileAction restMLProfileAction = new RestMLProfileAction(clusterService);
RestMLRegisterModelAction restMLRegisterModelAction = new RestMLRegisterModelAction(clusterService, settings);
RestMLRegisterModelAction restMLRegisterModelAction = new RestMLRegisterModelAction(clusterService, settings, mlFeatureEnabledSetting);
RestMLDeployModelAction restMLDeployModelAction = new RestMLDeployModelAction();
RestMLUndeployModelAction restMLUndeployModelAction = new RestMLUndeployModelAction(clusterService, settings);
RestMLRegisterModelMetaAction restMLRegisterModelMetaAction = new RestMLRegisterModelMetaAction(clusterService, settings);
Expand All @@ -473,7 +479,7 @@ public List<RestHandler> getRestHandlers(
RestMLUpdateModelGroupAction restMLUpdateModelGroupAction = new RestMLUpdateModelGroupAction();
RestMLSearchModelGroupAction restMLSearchModelGroupAction = new RestMLSearchModelGroupAction();
RestMLDeleteModelGroupAction restMLDeleteModelGroupAction = new RestMLDeleteModelGroupAction();
RestMLCreateConnectorAction restMLCreateConnectorAction = new RestMLCreateConnectorAction();
RestMLCreateConnectorAction restMLCreateConnectorAction = new RestMLCreateConnectorAction(mlFeatureEnabledSetting);
RestMLGetConnectorAction restMLGetConnectorAction = new RestMLGetConnectorAction();
RestMLDeleteConnectorAction restMLDeleteConnectorAction = new RestMLDeleteConnectorAction();
RestMLSearchConnectorAction restMLSearchConnectorAction = new RestMLSearchConnectorAction();
Expand Down Expand Up @@ -606,7 +612,8 @@ public List<Setting<?>> getSettings() {
MLCommonsSettings.ML_COMMONS_ALLOW_LOCAL_FILE_UPLOAD,
MLCommonsSettings.ML_COMMONS_MODEL_ACCESS_CONTROL_ENABLED,
MLCommonsSettings.ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED,
MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX
MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX,
MLCommonsSettings.ML_COMMONS_REMOTE_INFERENCE_ENABLED
);
return settings;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI;
import static org.opensearch.ml.utils.MLExceptionUtils.REMOTE_INFERENCE_DISABLED_ERR_MSG;

import java.io.IOException;
import java.util.List;
Expand All @@ -17,6 +18,7 @@
import org.opensearch.ml.common.transport.connector.MLCreateConnectorAction;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorRequest;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.rest.BaseRestHandler;
import org.opensearch.rest.RestRequest;
import org.opensearch.rest.action.RestToXContentListener;
Expand All @@ -26,11 +28,15 @@

public class RestMLCreateConnectorAction extends BaseRestHandler {
private static final String ML_CREATE_CONNECTOR_ACTION = "ml_create_connector_action";
private final MLFeatureEnabledSetting mlFeatureEnabledSetting;

/**
* Constructor *
* Constructor
* @param mlFeatureEnabledSetting
*/
public RestMLCreateConnectorAction() {}
public RestMLCreateConnectorAction(MLFeatureEnabledSetting mlFeatureEnabledSetting) {
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
}

@Override
public String getName() {
Expand All @@ -56,6 +62,9 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client
*/
@VisibleForTesting
MLCreateConnectorRequest getRequest(RestRequest request) throws IOException {
if (!mlFeatureEnabledSetting.isRemoteInferenceEnabled()) {
throw new IllegalStateException(REMOTE_INFERENCE_DISABLED_ERR_MSG);
}
if (!request.hasContent()) {
throw new IOException("Create Connector request has empty body");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI;
import static org.opensearch.ml.utils.MLExceptionUtils.REMOTE_INFERENCE_DISABLED_ERR_MSG;
import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_ALGORITHM;
import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_ID;
import static org.opensearch.ml.utils.RestActionUtils.getParameterId;
Expand All @@ -29,6 +30,7 @@
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.rest.BaseRestHandler;
import org.opensearch.rest.BytesRestResponse;
import org.opensearch.rest.RestRequest;
Expand All @@ -45,11 +47,14 @@ public class RestMLPredictionAction extends BaseRestHandler {

private MLModelManager modelManager;

private MLFeatureEnabledSetting mlFeatureEnabledSetting;

/**
* Constructor
*/
public RestMLPredictionAction(MLModelManager modelManager) {
public RestMLPredictionAction(MLModelManager modelManager, MLFeatureEnabledSetting mlFeatureEnabledSetting) {
this.modelManager = modelManager;
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
}

@Override
Expand Down Expand Up @@ -117,6 +122,9 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client
*/
@VisibleForTesting
MLPredictionTaskRequest getRequest(String modelId, String algorithm, RestRequest request) throws IOException {
if (algorithm.equals("REMOTE") && !mlFeatureEnabledSetting.isRemoteInferenceEnabled()) {
throw new IllegalStateException(REMOTE_INFERENCE_DISABLED_ERR_MSG);
}
XContentParser parser = request.contentParser();
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
MLInput mlInput = MLInput.parse(parser, algorithm);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ALLOW_MODEL_URL;
import static org.opensearch.ml.utils.MLExceptionUtils.REMOTE_INFERENCE_DISABLED_ERR_MSG;
import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_DEPLOY_MODEL;
import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_ID;
import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_VERSION;
Expand All @@ -20,9 +21,11 @@
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.transport.register.MLRegisterModelAction;
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
import org.opensearch.ml.common.transport.register.MLRegisterModelRequest;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.rest.BaseRestHandler;
import org.opensearch.rest.RestRequest;
import org.opensearch.rest.action.RestToXContentListener;
Expand All @@ -33,20 +36,24 @@
public class RestMLRegisterModelAction extends BaseRestHandler {
private static final String ML_REGISTER_MODEL_ACTION = "ml_register_model_action";
private volatile boolean isModelUrlAllowed;
private final MLFeatureEnabledSetting mlFeatureEnabledSetting;

/**
* Constructor
*/
public RestMLRegisterModelAction() {}
public RestMLRegisterModelAction(MLFeatureEnabledSetting mlFeatureEnabledSetting) {
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
}

/**
* Constructor
* @param clusterService cluster service
* @param settings settings
*/
public RestMLRegisterModelAction(ClusterService clusterService, Settings settings) {
public RestMLRegisterModelAction(ClusterService clusterService, Settings settings, MLFeatureEnabledSetting mlFeatureEnabledSetting) {
isModelUrlAllowed = ML_COMMONS_ALLOW_MODEL_URL.get(settings);
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_ALLOW_MODEL_URL, it -> isModelUrlAllowed = it);
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
}

@Override
Expand Down Expand Up @@ -93,6 +100,9 @@ MLRegisterModelRequest getRequest(RestRequest request) throws IOException {
XContentParser parser = request.contentParser();
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
MLRegisterModelInput mlInput = MLRegisterModelInput.parse(parser, loadModel);
if (mlInput.getFunctionName() == FunctionName.REMOTE && !mlFeatureEnabledSetting.isRemoteInferenceEnabled()) {
throw new IllegalStateException(REMOTE_INFERENCE_DISABLED_ERR_MSG);
}
if (mlInput.getUrl() != null && !isModelUrlAllowed) {
throw new IllegalArgumentException(
"To upload custom model user needs to enable allow_registering_model_via_url settings. Otherwise please use opensearch pre-trained models."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,9 @@ private MLCommonsSettings() {}
Setting.Property.Dynamic
);

public static final Setting<Boolean> ML_COMMONS_REMOTE_INFERENCE_ENABLED = Setting
.boolSetting("plugins.ml_commons.remote_inference_enabled", true, Setting.Property.NodeScope, Setting.Property.Dynamic);

public static final Setting<Boolean> ML_COMMONS_MODEL_ACCESS_CONTROL_ENABLED = Setting
.boolSetting("plugins.ml_commons.model_access_control_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic);

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
*
* * Copyright OpenSearch Contributors
* * SPDX-License-Identifier: Apache-2.0
*
*/

package org.opensearch.ml.settings;

import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;

import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_INFERENCE_ENABLED;

public class MLFeatureEnabledSetting {

private volatile Boolean isRemoteInferenceEnabled;

public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings) {
isRemoteInferenceEnabled = ML_COMMONS_REMOTE_INFERENCE_ENABLED.get(settings);
clusterService
.getClusterSettings()
.addSettingsUpdateConsumer(ML_COMMONS_REMOTE_INFERENCE_ENABLED, it -> isRemoteInferenceEnabled = it);
}

/**
* Whether the remote inference feature is enabled. If disabled, time series plugin rejects RESTful requests for this feature.
* @return whether Remote Inference is enabled.
*/
public boolean isRemoteInferenceEnabled() {
return isRemoteInferenceEnabled;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
public class MLExceptionUtils {

public static final String NOT_SERIALIZABLE_EXCEPTION_WRAPPER = "NotSerializableExceptionWrapper: ";
public static final String REMOTE_INFERENCE_DISABLED_ERR_MSG =
"Remote Inference is disabled. To enable update the setting plugins.ml_commons.remote_inference_enabled to true";

public static String getRootCauseMessage(final Throwable throwable) {
String message = ExceptionUtils.getRootCauseMessage(throwable);
Expand Down
Loading

0 comments on commit f46432e

Please sign in to comment.