diff --git a/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java b/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java index 93f3c458c1..828d6c3bd4 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java @@ -11,6 +11,7 @@ import static org.opensearch.ml.plugin.MachineLearningPlugin.DEPLOY_THREAD_POOL; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ALLOW_CUSTOM_DEPLOYMENT_PLAN; import static org.opensearch.ml.task.MLTaskManager.TASK_SEMAPHORE_TIMEOUT; +import static org.opensearch.ml.utils.MLExceptionUtils.REMOTE_INFERENCE_DISABLED_ERR_MSG; import java.time.Instant; import java.util.ArrayList; @@ -52,6 +53,7 @@ import org.opensearch.ml.engine.ModelHelper; import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.model.MLModelManager; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.ml.stats.MLNodeLevelStat; import org.opensearch.ml.stats.MLStats; import org.opensearch.ml.task.MLTaskDispatcher; @@ -83,6 +85,7 @@ public class TransportDeployModelAction extends HandledTransportAction { + if (mlModel.getAlgorithm() == FunctionName.REMOTE && !mlFeatureEnabledSetting.isRemoteInferenceEnabled()) { + throw new IllegalStateException(REMOTE_INFERENCE_DISABLED_ERR_MSG); + } modelAccessControlHelper.validateModelGroupAccess(user, mlModel.getModelGroupId(), client, ActionListener.wrap(access -> { if (!access) { listener diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index ce72558831..c873e56ea6 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -150,6 +150,7 @@ import org.opensearch.ml.rest.RestMLUpdateModelGroupAction; import org.opensearch.ml.rest.RestMLUploadModelChunkAction; import org.opensearch.ml.settings.MLCommonsSettings; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.ml.stats.MLClusterLevelStat; import org.opensearch.ml.stats.MLNodeLevelStat; import org.opensearch.ml.stats.MLStat; @@ -219,6 +220,8 @@ public class MachineLearningPlugin extends Plugin implements ActionPlugin { private ConnectorAccessControlHelper connectorAccessControlHelper; + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + @Override public List> getActions() { return ImmutableList @@ -325,6 +328,8 @@ public Collection createComponents( mlInputDatasetHandler = new MLInputDatasetHandler(client); modelAccessControlHelper = new ModelAccessControlHelper(clusterService, settings); connectorAccessControlHelper = new ConnectorAccessControlHelper(clusterService, settings); + mlFeatureEnabledSetting = new MLFeatureEnabledSetting(clusterService, settings); + mlModelChunkUploader = new MLModelChunkUploader(mlIndicesHandler, client, xContentRegistry, modelAccessControlHelper); MLTaskDispatcher mlTaskDispatcher = new MLTaskDispatcher(clusterService, client, settings, nodeHelper); @@ -431,6 +436,7 @@ public Collection createComponents( mlExecuteTaskRunner, modelAccessControlHelper, connectorAccessControlHelper, + mlFeatureEnabledSetting, mlSearchHandler, mlTaskDispatcher, mlModelChunkUploader, @@ -455,7 +461,7 @@ public List getRestHandlers( RestMLStatsAction restMLStatsAction = new RestMLStatsAction(mlStats, clusterService, indexUtils); RestMLTrainingAction restMLTrainingAction = new RestMLTrainingAction(); RestMLTrainAndPredictAction restMLTrainAndPredictAction = new RestMLTrainAndPredictAction(); - RestMLPredictionAction restMLPredictionAction = new RestMLPredictionAction(mlModelManager); + RestMLPredictionAction restMLPredictionAction = new RestMLPredictionAction(mlModelManager, mlFeatureEnabledSetting); RestMLExecuteAction restMLExecuteAction = new RestMLExecuteAction(); RestMLGetModelAction restMLGetModelAction = new RestMLGetModelAction(); RestMLDeleteModelAction restMLDeleteModelAction = new RestMLDeleteModelAction(); @@ -464,7 +470,11 @@ public List getRestHandlers( RestMLDeleteTaskAction restMLDeleteTaskAction = new RestMLDeleteTaskAction(); RestMLSearchTaskAction restMLSearchTaskAction = new RestMLSearchTaskAction(); RestMLProfileAction restMLProfileAction = new RestMLProfileAction(clusterService); - RestMLRegisterModelAction restMLRegisterModelAction = new RestMLRegisterModelAction(clusterService, settings); + RestMLRegisterModelAction restMLRegisterModelAction = new RestMLRegisterModelAction( + clusterService, + settings, + mlFeatureEnabledSetting + ); RestMLDeployModelAction restMLDeployModelAction = new RestMLDeployModelAction(); RestMLUndeployModelAction restMLUndeployModelAction = new RestMLUndeployModelAction(clusterService, settings); RestMLRegisterModelMetaAction restMLRegisterModelMetaAction = new RestMLRegisterModelMetaAction(clusterService, settings); @@ -473,7 +483,7 @@ public List getRestHandlers( RestMLUpdateModelGroupAction restMLUpdateModelGroupAction = new RestMLUpdateModelGroupAction(); RestMLSearchModelGroupAction restMLSearchModelGroupAction = new RestMLSearchModelGroupAction(); RestMLDeleteModelGroupAction restMLDeleteModelGroupAction = new RestMLDeleteModelGroupAction(); - RestMLCreateConnectorAction restMLCreateConnectorAction = new RestMLCreateConnectorAction(); + RestMLCreateConnectorAction restMLCreateConnectorAction = new RestMLCreateConnectorAction(mlFeatureEnabledSetting); RestMLGetConnectorAction restMLGetConnectorAction = new RestMLGetConnectorAction(); RestMLDeleteConnectorAction restMLDeleteConnectorAction = new RestMLDeleteConnectorAction(); RestMLSearchConnectorAction restMLSearchConnectorAction = new RestMLSearchConnectorAction(); @@ -606,7 +616,8 @@ public List> getSettings() { MLCommonsSettings.ML_COMMONS_ALLOW_LOCAL_FILE_UPLOAD, MLCommonsSettings.ML_COMMONS_MODEL_ACCESS_CONTROL_ENABLED, MLCommonsSettings.ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED, - MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX + MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX, + MLCommonsSettings.ML_COMMONS_REMOTE_INFERENCE_ENABLED ); return settings; } diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLCreateConnectorAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLCreateConnectorAction.java index 662e635505..a1e05ce7d6 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLCreateConnectorAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLCreateConnectorAction.java @@ -7,6 +7,7 @@ import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; +import static org.opensearch.ml.utils.MLExceptionUtils.REMOTE_INFERENCE_DISABLED_ERR_MSG; import java.io.IOException; import java.util.List; @@ -17,6 +18,7 @@ import org.opensearch.ml.common.transport.connector.MLCreateConnectorAction; import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput; import org.opensearch.ml.common.transport.connector.MLCreateConnectorRequest; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.rest.BaseRestHandler; import org.opensearch.rest.RestRequest; import org.opensearch.rest.action.RestToXContentListener; @@ -26,11 +28,15 @@ public class RestMLCreateConnectorAction extends BaseRestHandler { private static final String ML_CREATE_CONNECTOR_ACTION = "ml_create_connector_action"; + private final MLFeatureEnabledSetting mlFeatureEnabledSetting; /** - * Constructor * + * Constructor + * @param mlFeatureEnabledSetting */ - public RestMLCreateConnectorAction() {} + public RestMLCreateConnectorAction(MLFeatureEnabledSetting mlFeatureEnabledSetting) { + this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; + } @Override public String getName() { @@ -56,6 +62,9 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client */ @VisibleForTesting MLCreateConnectorRequest getRequest(RestRequest request) throws IOException { + if (!mlFeatureEnabledSetting.isRemoteInferenceEnabled()) { + throw new IllegalStateException(REMOTE_INFERENCE_DISABLED_ERR_MSG); + } if (!request.hasContent()) { throw new IOException("Create Connector request has empty body"); } diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java index 695a095ea8..84853d427c 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java @@ -7,6 +7,7 @@ import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; +import static org.opensearch.ml.utils.MLExceptionUtils.REMOTE_INFERENCE_DISABLED_ERR_MSG; import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_ALGORITHM; import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_ID; import static org.opensearch.ml.utils.RestActionUtils.getParameterId; @@ -29,6 +30,7 @@ import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; import org.opensearch.ml.model.MLModelManager; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.rest.BaseRestHandler; import org.opensearch.rest.BytesRestResponse; import org.opensearch.rest.RestRequest; @@ -45,11 +47,14 @@ public class RestMLPredictionAction extends BaseRestHandler { private MLModelManager modelManager; + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + /** * Constructor */ - public RestMLPredictionAction(MLModelManager modelManager) { + public RestMLPredictionAction(MLModelManager modelManager, MLFeatureEnabledSetting mlFeatureEnabledSetting) { this.modelManager = modelManager; + this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; } @Override @@ -117,6 +122,9 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client */ @VisibleForTesting MLPredictionTaskRequest getRequest(String modelId, String algorithm, RestRequest request) throws IOException { + if (algorithm.equals("REMOTE") && !mlFeatureEnabledSetting.isRemoteInferenceEnabled()) { + throw new IllegalStateException(REMOTE_INFERENCE_DISABLED_ERR_MSG); + } XContentParser parser = request.contentParser(); ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); MLInput mlInput = MLInput.parse(parser, algorithm); diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLRegisterModelAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLRegisterModelAction.java index 9e33022290..9e76a48c97 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLRegisterModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLRegisterModelAction.java @@ -8,6 +8,7 @@ import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ALLOW_MODEL_URL; +import static org.opensearch.ml.utils.MLExceptionUtils.REMOTE_INFERENCE_DISABLED_ERR_MSG; import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_DEPLOY_MODEL; import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_ID; import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_VERSION; @@ -20,9 +21,11 @@ import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.transport.register.MLRegisterModelAction; import org.opensearch.ml.common.transport.register.MLRegisterModelInput; import org.opensearch.ml.common.transport.register.MLRegisterModelRequest; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.rest.BaseRestHandler; import org.opensearch.rest.RestRequest; import org.opensearch.rest.action.RestToXContentListener; @@ -33,20 +36,24 @@ public class RestMLRegisterModelAction extends BaseRestHandler { private static final String ML_REGISTER_MODEL_ACTION = "ml_register_model_action"; private volatile boolean isModelUrlAllowed; + private final MLFeatureEnabledSetting mlFeatureEnabledSetting; /** * Constructor */ - public RestMLRegisterModelAction() {} + public RestMLRegisterModelAction(MLFeatureEnabledSetting mlFeatureEnabledSetting) { + this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; + } /** * Constructor * @param clusterService cluster service * @param settings settings */ - public RestMLRegisterModelAction(ClusterService clusterService, Settings settings) { + public RestMLRegisterModelAction(ClusterService clusterService, Settings settings, MLFeatureEnabledSetting mlFeatureEnabledSetting) { isModelUrlAllowed = ML_COMMONS_ALLOW_MODEL_URL.get(settings); clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_ALLOW_MODEL_URL, it -> isModelUrlAllowed = it); + this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; } @Override @@ -93,6 +100,9 @@ MLRegisterModelRequest getRequest(RestRequest request) throws IOException { XContentParser parser = request.contentParser(); ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); MLRegisterModelInput mlInput = MLRegisterModelInput.parse(parser, loadModel); + if (mlInput.getFunctionName() == FunctionName.REMOTE && !mlFeatureEnabledSetting.isRemoteInferenceEnabled()) { + throw new IllegalStateException(REMOTE_INFERENCE_DISABLED_ERR_MSG); + } if (mlInput.getUrl() != null && !isModelUrlAllowed) { throw new IllegalArgumentException( "To upload custom model user needs to enable allow_registering_model_via_url settings. Otherwise please use opensearch pre-trained models." diff --git a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java index 9f1a5308a2..69f1397f9a 100644 --- a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java +++ b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java @@ -108,6 +108,9 @@ private MLCommonsSettings() {} Setting.Property.Dynamic ); + public static final Setting ML_COMMONS_REMOTE_INFERENCE_ENABLED = Setting + .boolSetting("plugins.ml_commons.remote_inference_enabled", true, Setting.Property.NodeScope, Setting.Property.Dynamic); + public static final Setting ML_COMMONS_MODEL_ACCESS_CONTROL_ENABLED = Setting .boolSetting("plugins.ml_commons.model_access_control_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic); diff --git a/plugin/src/main/java/org/opensearch/ml/settings/MLFeatureEnabledSetting.java b/plugin/src/main/java/org/opensearch/ml/settings/MLFeatureEnabledSetting.java new file mode 100644 index 0000000000..4f0ff6713e --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/settings/MLFeatureEnabledSetting.java @@ -0,0 +1,34 @@ +/* + * + * * Copyright OpenSearch Contributors + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package org.opensearch.ml.settings; + +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_INFERENCE_ENABLED; + +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; + +public class MLFeatureEnabledSetting { + + private volatile Boolean isRemoteInferenceEnabled; + + public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings) { + isRemoteInferenceEnabled = ML_COMMONS_REMOTE_INFERENCE_ENABLED.get(settings); + clusterService + .getClusterSettings() + .addSettingsUpdateConsumer(ML_COMMONS_REMOTE_INFERENCE_ENABLED, it -> isRemoteInferenceEnabled = it); + } + + /** + * Whether the remote inference feature is enabled. If disabled, time series plugin rejects RESTful requests for this feature. + * @return whether Remote Inference is enabled. + */ + public boolean isRemoteInferenceEnabled() { + return isRemoteInferenceEnabled; + } + +} diff --git a/plugin/src/main/java/org/opensearch/ml/utils/MLExceptionUtils.java b/plugin/src/main/java/org/opensearch/ml/utils/MLExceptionUtils.java index e59e2ee0fe..abf347f80d 100644 --- a/plugin/src/main/java/org/opensearch/ml/utils/MLExceptionUtils.java +++ b/plugin/src/main/java/org/opensearch/ml/utils/MLExceptionUtils.java @@ -18,6 +18,8 @@ public class MLExceptionUtils { public static final String NOT_SERIALIZABLE_EXCEPTION_WRAPPER = "NotSerializableExceptionWrapper: "; + public static final String REMOTE_INFERENCE_DISABLED_ERR_MSG = + "Remote Inference is disabled. To enable update the setting plugins.ml_commons.remote_inference_enabled to true"; public static String getRootCauseMessage(final Throwable throwable) { String message = ExceptionUtils.getRootCauseMessage(throwable); diff --git a/plugin/src/test/java/org/opensearch/ml/action/deploy/TransportDeployModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/deploy/TransportDeployModelActionTests.java index c355da2d38..0d92f39571 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/deploy/TransportDeployModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/deploy/TransportDeployModelActionTests.java @@ -21,6 +21,7 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ALLOW_CUSTOM_DEPLOYMENT_PLAN; +import static org.opensearch.ml.utils.MLExceptionUtils.REMOTE_INFERENCE_DISABLED_ERR_MSG; import java.lang.reflect.Field; import java.nio.file.Path; @@ -65,6 +66,7 @@ import org.opensearch.ml.engine.encryptor.EncryptorImpl; import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.model.MLModelManager; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.ml.stats.MLNodeLevelStat; import org.opensearch.ml.stats.MLStat; import org.opensearch.ml.stats.MLStats; @@ -130,6 +132,9 @@ public class TransportDeployModelActionTests extends OpenSearchTestCase { @Mock private ModelAccessControlHelper modelAccessControlHelper; + @Mock + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + private final List eligibleNodes = mock(List.class); @Rule @@ -167,6 +172,8 @@ public void setup() { return null; }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + when(mlFeatureEnabledSetting.isRemoteInferenceEnabled()).thenReturn(true); + MLStat mlStat = mock(MLStat.class); when(mlStats.getStat(eq(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT))).thenReturn(mlStat); transportDeployModelAction = new TransportDeployModelAction( @@ -183,7 +190,8 @@ public void setup() { mlModelManager, mlStats, settings, - modelAccessControlHelper + modelAccessControlHelper, + mlFeatureEnabledSetting ); } @@ -231,6 +239,23 @@ public void testDoExecute_userHasNoAccessException() { assertEquals("User Doesn't have privilege to perform this operation on this model", argumentCaptor.getValue().getMessage()); } + public void testDoExecuteRemoteInferenceDisabled() { + MLModel mlModel = mock(MLModel.class); + when(mlModel.getAlgorithm()).thenReturn(FunctionName.REMOTE); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(mlModel); + return null; + }).when(mlModelManager).getModel(anyString(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); + + when(mlFeatureEnabledSetting.isRemoteInferenceEnabled()).thenReturn(false); + ActionListener deployModelResponseListener = mock(ActionListener.class); + transportDeployModelAction.doExecute(mock(Task.class), mlDeployModelRequest, deployModelResponseListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(IllegalStateException.class); + verify(deployModelResponseListener).onFailure(argumentCaptor.capture()); + assertEquals(REMOTE_INFERENCE_DISABLED_ERR_MSG, argumentCaptor.getValue().getMessage()); + } + public void test_ValidationFailedException() { MLModel mlModel = mock(MLModel.class); when(mlModel.getAlgorithm()).thenReturn(FunctionName.ANOMALY_LOCALIZATION); @@ -277,7 +302,8 @@ public void testDoExecute_DoNotAllowCustomDeploymentPlan() { mlModelManager, mlStats, settings, - modelAccessControlHelper + modelAccessControlHelper, + mlFeatureEnabledSetting ); transportDeployModelAction.doExecute(mock(Task.class), mlDeployModelRequest, mock(ActionListener.class)); @@ -302,7 +328,8 @@ public void testDoExecute_whenDeployModelRequestNodeIdsEmpty_thenMLResourceNotFo mlModelManager, mlStats, settings, - modelAccessControlHelper + modelAccessControlHelper, + mlFeatureEnabledSetting ) ); MLDeployModelRequest MLDeployModelRequest1 = mock(MLDeployModelRequest.class); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLCreateConnectorActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLCreateConnectorActionTests.java index 5ee78a200b..07f823f905 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLCreateConnectorActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLCreateConnectorActionTests.java @@ -11,6 +11,8 @@ import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.ml.utils.MLExceptionUtils.REMOTE_INFERENCE_DISABLED_ERR_MSG; import static org.opensearch.ml.utils.TestHelper.getCreateConnectorRestRequest; import static org.opensearch.ml.utils.TestHelper.verifyParsedCreateConnectorInput; @@ -24,6 +26,7 @@ import org.junit.rules.ExpectedException; import org.mockito.ArgumentCaptor; import org.mockito.Mock; +import org.mockito.MockitoAnnotations; import org.opensearch.client.node.NodeClient; import org.opensearch.common.settings.Settings; import org.opensearch.core.action.ActionListener; @@ -33,6 +36,7 @@ import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput; import org.opensearch.ml.common.transport.connector.MLCreateConnectorRequest; import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestHandler; import org.opensearch.rest.RestRequest; @@ -53,9 +57,14 @@ public class RestMLCreateConnectorActionTests extends OpenSearchTestCase { @Mock RestChannel channel; + @Mock + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + @Before public void setup() { - restMLCreateConnectorAction = new RestMLCreateConnectorAction(); + MockitoAnnotations.openMocks(this); + when(mlFeatureEnabledSetting.isRemoteInferenceEnabled()).thenReturn(true); + restMLCreateConnectorAction = new RestMLCreateConnectorAction(mlFeatureEnabledSetting); threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); client = spy(new NodeClient(Settings.EMPTY, threadPool)); @@ -75,7 +84,7 @@ public void tearDown() throws Exception { } public void testConstructor() { - RestMLCreateConnectorAction mlCreateConnectorAction = new RestMLCreateConnectorAction(); + RestMLCreateConnectorAction mlCreateConnectorAction = new RestMLCreateConnectorAction(mlFeatureEnabledSetting); assertNotNull(mlCreateConnectorAction); } @@ -120,4 +129,13 @@ public void testPrepareRequest_EmptyContent() throws Exception { restMLCreateConnectorAction.handleRequest(request, channel, client); } + + public void testPrepareRequestFeatureDisabled() throws Exception { + thrown.expect(IllegalStateException.class); + thrown.expectMessage(REMOTE_INFERENCE_DISABLED_ERR_MSG); + + when(mlFeatureEnabledSetting.isRemoteInferenceEnabled()).thenReturn(false); + RestRequest request = getCreateConnectorRestRequest(); + restMLCreateConnectorAction.handleRequest(request, channel, client); + } } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLPredictionActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLPredictionActionTests.java index af56baa13f..ceeda75277 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLPredictionActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLPredictionActionTests.java @@ -8,6 +8,7 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.*; +import static org.opensearch.ml.utils.MLExceptionUtils.REMOTE_INFERENCE_DISABLED_ERR_MSG; import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_ID; import static org.opensearch.ml.utils.TestHelper.getKMeansRestRequest; import static org.opensearch.ml.utils.TestHelper.verifyParsedKMeansMLInput; @@ -36,6 +37,7 @@ import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; import org.opensearch.ml.model.MLModelManager; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestHandler; import org.opensearch.rest.RestRequest; @@ -57,12 +59,15 @@ public class RestMLPredictionActionTests extends OpenSearchTestCase { RestChannel channel; @Mock MLModelManager modelManager; + @Mock + MLFeatureEnabledSetting mlFeatureEnabledSetting; @Before public void setup() { MockitoAnnotations.openMocks(this); when(modelManager.getOptionalModelFunctionName(anyString())).thenReturn(Optional.empty()); - restMLPredictionAction = new RestMLPredictionAction(modelManager); + when(mlFeatureEnabledSetting.isRemoteInferenceEnabled()).thenReturn(true); + restMLPredictionAction = new RestMLPredictionAction(modelManager, mlFeatureEnabledSetting); threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); client = spy(new NodeClient(Settings.EMPTY, threadPool)); @@ -81,7 +86,7 @@ public void tearDown() throws Exception { } public void testConstructor() { - RestMLPredictionAction mlPredictionAction = new RestMLPredictionAction(modelManager); + RestMLPredictionAction mlPredictionAction = new RestMLPredictionAction(modelManager, mlFeatureEnabledSetting); assertNotNull(mlPredictionAction); } @@ -108,6 +113,15 @@ public void testGetRequest() throws IOException { verifyParsedKMeansMLInput(mlInput); } + public void testGetRequest_RemoteInferenceDisabled() throws IOException { + thrown.expect(IllegalStateException.class); + thrown.expectMessage(REMOTE_INFERENCE_DISABLED_ERR_MSG); + + when(mlFeatureEnabledSetting.isRemoteInferenceEnabled()).thenReturn(false); + RestRequest request = getRestRequest_PredictModel(); + MLPredictionTaskRequest mlPredictionTaskRequest = restMLPredictionAction.getRequest("modelId", FunctionName.REMOTE.name(), request); + } + public void testPrepareRequest() throws Exception { RestRequest request = getRestRequest_PredictModel(); restMLPredictionAction.handleRequest(request, channel, client); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRegisterModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRegisterModelActionTests.java index a4f529b1d8..12bc3737fb 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRegisterModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRegisterModelActionTests.java @@ -13,6 +13,7 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ALLOW_MODEL_URL; +import static org.opensearch.ml.utils.MLExceptionUtils.REMOTE_INFERENCE_DISABLED_ERR_MSG; import static org.opensearch.ml.utils.TestHelper.clusterSetting; import java.util.List; @@ -34,10 +35,12 @@ import org.opensearch.core.common.Strings; import org.opensearch.core.common.bytes.BytesArray; import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.transport.model.MLModelGetResponse; import org.opensearch.ml.common.transport.register.MLRegisterModelAction; import org.opensearch.ml.common.transport.register.MLRegisterModelInput; import org.opensearch.ml.common.transport.register.MLRegisterModelRequest; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestHandler; import org.opensearch.rest.RestRequest; @@ -62,6 +65,9 @@ public class RestMLRegisterModelActionTests extends OpenSearchTestCase { @Mock private ClusterService clusterService; + @Mock + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + private Settings settings; @Before @@ -70,7 +76,8 @@ public void setup() { settings = Settings.builder().put(ML_COMMONS_ALLOW_MODEL_URL.getKey(), true).build(); ClusterSettings clusterSettings = clusterSetting(settings, ML_COMMONS_ALLOW_MODEL_URL); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); - restMLRegisterModelAction = new RestMLRegisterModelAction(clusterService, settings); + when(mlFeatureEnabledSetting.isRemoteInferenceEnabled()).thenReturn(true); + restMLRegisterModelAction = new RestMLRegisterModelAction(clusterService, settings, mlFeatureEnabledSetting); threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); client = spy(new NodeClient(Settings.EMPTY, threadPool)); doAnswer(invocation -> { @@ -87,7 +94,7 @@ public void tearDown() throws Exception { } public void testConstructor() { - RestMLRegisterModelAction registerModelAction = new RestMLRegisterModelAction(); + RestMLRegisterModelAction registerModelAction = new RestMLRegisterModelAction(mlFeatureEnabledSetting); assertNotNull(registerModelAction); } @@ -130,11 +137,20 @@ public void testRegisterModelRequest() throws Exception { assertEquals("TORCH_SCRIPT", registerModelInput.getModelFormat().toString()); } + public void testRegisterModelRequestRemoteInferenceDisabled() throws Exception { + exceptionRule.expect(IllegalStateException.class); + exceptionRule.expectMessage(REMOTE_INFERENCE_DISABLED_ERR_MSG); + + when(mlFeatureEnabledSetting.isRemoteInferenceEnabled()).thenReturn(false); + RestRequest request = getRestRequestWithNullModelId(); + restMLRegisterModelAction.handleRequest(request, channel, client); + } + public void testRegisterModelUrlNotAllowed() throws Exception { settings = Settings.builder().put(ML_COMMONS_ALLOW_MODEL_URL.getKey(), false).build(); ClusterSettings clusterSettings = clusterSetting(settings, ML_COMMONS_ALLOW_MODEL_URL); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); - restMLRegisterModelAction = new RestMLRegisterModelAction(clusterService, settings); + restMLRegisterModelAction = new RestMLRegisterModelAction(clusterService, settings, mlFeatureEnabledSetting); exceptionRule.expect(IllegalArgumentException.class); exceptionRule .expectMessage( @@ -226,7 +242,9 @@ private RestRequest getRestRequestWithNullModelId() { "model_format", "TORCH_SCRIPT", "model_config", - modelConfig + modelConfig, + "function_name", + FunctionName.REMOTE ); String requestContent = new Gson().toJson(model).toString(); RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY)