diff --git a/src/main/java/org/opensearch/flowframework/workflow/GetMLTaskStep.java b/src/main/java/org/opensearch/flowframework/workflow/GetMLTaskStep.java deleted file mode 100644 index 018783b19..000000000 --- a/src/main/java/org/opensearch/flowframework/workflow/GetMLTaskStep.java +++ /dev/null @@ -1,149 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - */ -package org.opensearch.flowframework.workflow; - -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.opensearch.ExceptionsHelper; -import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.settings.Settings; -import org.opensearch.common.util.concurrent.FutureUtils; -import org.opensearch.core.action.ActionListener; -import org.opensearch.core.rest.RestStatus; -import org.opensearch.flowframework.exception.FlowFrameworkException; -import org.opensearch.ml.client.MachineLearningNodeClient; -import org.opensearch.ml.common.MLTaskState; - -import java.util.ArrayList; -import java.util.List; -import java.util.Map; -import java.util.Map.Entry; -import java.util.concurrent.CompletableFuture; - -import static org.opensearch.flowframework.common.CommonValue.MODEL_ID; -import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS; -import static org.opensearch.flowframework.common.CommonValue.TASK_ID; - -/** - * Step to retrieve an ML Task - */ -public class GetMLTaskStep extends AbstractRetryableWorkflowStep { - - private static final Logger logger = LogManager.getLogger(GetMLTaskStep.class); - private MachineLearningNodeClient mlClient; - static final String NAME = "get_ml_task"; - - /** - * Instantiate this class - * @param settings the Opensearch settings - * @param clusterService the OpenSearch cluster service - * @param mlClient client to instantiate MLClient - */ - public GetMLTaskStep(Settings settings, ClusterService clusterService, MachineLearningNodeClient mlClient) { - super(settings, clusterService); - this.mlClient = mlClient; - } - - @Override - public CompletableFuture execute( - String currentNodeId, - WorkflowData currentNodeInputs, - Map outputs, - Map previousNodeInputs - ) { - - CompletableFuture getMLTaskFuture = new CompletableFuture<>(); - - String taskId = null; - - // TODO: Recreating the list to get this compiling - // Need to refactor the below iteration to pull directly from the maps - List data = new ArrayList<>(); - data.add(currentNodeInputs); - data.addAll(outputs.values()); - - for (WorkflowData workflowData : data) { - Map content = workflowData.getContent(); - for (Entry entry : content.entrySet()) { - switch (entry.getKey()) { - case TASK_ID: - taskId = (String) content.get(TASK_ID); - break; - default: - break; - } - } - } - - if (taskId == null) { - logger.error("Failed to retrieve ML Task"); - getMLTaskFuture.completeExceptionally(new FlowFrameworkException("Required fields are not provided", RestStatus.BAD_REQUEST)); - } else { - retryableGetMlTask(currentNodeInputs.getWorkflowId(), currentNodeInputs.getNodeId(), getMLTaskFuture, taskId, 0); - } - - return getMLTaskFuture; - } - - @Override - public String getName() { - return NAME; - } - - /** - * Retryable GetMLTask - * @param workflowId the workflow id - * @param nodeId the node id - * @param getMLTaskFuture the workflow step future - * @param taskId the ml task id - * @param retries the current number of request retries - */ - protected void retryableGetMlTask( - String workflowId, - String nodeId, - CompletableFuture getMLTaskFuture, - String taskId, - int retries - ) { - mlClient.getTask(taskId, ActionListener.wrap(response -> { - if (response.getState() != MLTaskState.COMPLETED) { - throw new IllegalStateException("MLTask is not yet completed"); - } else { - logger.info("ML Task retrieval successful"); - getMLTaskFuture.complete( - new WorkflowData( - Map.ofEntries( - Map.entry(MODEL_ID, response.getModelId()), - Map.entry(REGISTER_MODEL_STATUS, response.getState().name()) - ), - workflowId, - nodeId - ) - ); - } - }, exception -> { - if (retries < maxRetry) { - // Sleep thread prior to retrying request - try { - Thread.sleep(5000); - } catch (Exception e) { - FutureUtils.cancel(getMLTaskFuture); - } - final int retryAdd = retries + 1; - retryableGetMlTask(workflowId, nodeId, getMLTaskFuture, taskId, retryAdd); - } else { - logger.error("Failed to retrieve ML Task, maximum retries exceeded"); - getMLTaskFuture.completeExceptionally( - new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception)) - ); - } - })); - } - -} diff --git a/src/main/java/org/opensearch/flowframework/workflow/GetTask.java b/src/main/java/org/opensearch/flowframework/workflow/GetTask.java deleted file mode 100644 index a3d1caa4e..000000000 --- a/src/main/java/org/opensearch/flowframework/workflow/GetTask.java +++ /dev/null @@ -1,65 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - */ -package org.opensearch.flowframework.workflow; - -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.opensearch.common.SuppressForbidden; -import org.opensearch.core.action.ActionListener; -import org.opensearch.ml.client.MachineLearningNodeClient; -import org.opensearch.ml.common.MLTask; -import org.opensearch.ml.common.MLTaskState; -import org.opensearch.ml.common.transport.task.MLTaskGetResponse; - -/** - * Step to get modelID of a registered local model - */ -@SuppressForbidden(reason = "This class is for the future work of registering local model") -public class GetTask { - - private static final Logger logger = LogManager.getLogger(GetTask.class); - private MachineLearningNodeClient machineLearningNodeClient; - private String taskId; - - /** - * Instantiate this class - * @param machineLearningNodeClient client to instantiate ml-commons APIs - * @param taskId taskID of the model - */ - public GetTask(MachineLearningNodeClient machineLearningNodeClient, String taskId) { - this.machineLearningNodeClient = machineLearningNodeClient; - this.taskId = taskId; - } - - /** - * Invokes get task API of ml-commons - */ - public void getTask() { - - ActionListener actionListener = new ActionListener<>() { - @Override - public void onResponse(MLTask mlTask) { - if (mlTask.getState() == MLTaskState.COMPLETED) { - logger.info("Model registration successful"); - MLTaskGetResponse response = MLTaskGetResponse.builder().mlTask(mlTask).build(); - logger.info("Response from task {}", response); - } - } - - @Override - public void onFailure(Exception e) { - logger.error("Model registration failed"); - } - }; - - machineLearningNodeClient.getTask(taskId, actionListener); - - } - -} diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java index 27aa5e537..19229efd1 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java @@ -11,10 +11,14 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.ExceptionsHelper; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.FutureUtils; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.ml.client.MachineLearningNodeClient; +import org.opensearch.ml.common.MLTaskState; import org.opensearch.ml.common.model.MLModelConfig; import org.opensearch.ml.common.model.MLModelFormat; import org.opensearch.ml.common.model.TextEmbeddingModelConfig; @@ -38,17 +42,17 @@ import static org.opensearch.flowframework.common.CommonValue.MODEL_CONTENT_HASH_VALUE; import static org.opensearch.flowframework.common.CommonValue.MODEL_FORMAT; import static org.opensearch.flowframework.common.CommonValue.MODEL_GROUP_ID; +import static org.opensearch.flowframework.common.CommonValue.MODEL_ID; import static org.opensearch.flowframework.common.CommonValue.MODEL_TYPE; import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD; import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS; -import static org.opensearch.flowframework.common.CommonValue.TASK_ID; import static org.opensearch.flowframework.common.CommonValue.URL; import static org.opensearch.flowframework.common.CommonValue.VERSION_FIELD; /** * Step to register a local model */ -public class RegisterLocalModelStep implements WorkflowStep { +public class RegisterLocalModelStep extends AbstractRetryableWorkflowStep { private static final Logger logger = LogManager.getLogger(RegisterLocalModelStep.class); @@ -58,9 +62,12 @@ public class RegisterLocalModelStep implements WorkflowStep { /** * Instantiate this class + * @param settings The OpenSearch settings + * @param clusterService The cluster service * @param mlClient client to instantiate MLClient */ - public RegisterLocalModelStep(MachineLearningNodeClient mlClient) { + public RegisterLocalModelStep(Settings settings, ClusterService clusterService, MachineLearningNodeClient mlClient) { + super(settings, clusterService); this.mlClient = mlClient; } @@ -74,20 +81,21 @@ public CompletableFuture execute( CompletableFuture registerLocalModelFuture = new CompletableFuture<>(); + // TODO: Recreating the list to get this compiling + // Need to refactor the below iteration to pull directly from the maps + List data = new ArrayList<>(); + data.add(currentNodeInputs); + data.addAll(outputs.values()); + ActionListener actionListener = new ActionListener<>() { @Override public void onResponse(MLRegisterModelResponse mlRegisterModelResponse) { logger.info("Local Model registration task creation successful"); - registerLocalModelFuture.complete( - new WorkflowData( - Map.ofEntries( - Map.entry(TASK_ID, mlRegisterModelResponse.getTaskId()), - Map.entry(REGISTER_MODEL_STATUS, mlRegisterModelResponse.getStatus()) - ), - currentNodeInputs.getWorkflowId(), - currentNodeInputs.getNodeId() - ) - ); + + String taskId = mlRegisterModelResponse.getTaskId(); + + // Attempt to retrieve the model ID + retryableGetMlTask(currentNodeInputs.getWorkflowId(), currentNodeId, registerLocalModelFuture, taskId, 0); } @Override @@ -109,12 +117,6 @@ public void onFailure(Exception e) { String allConfig = null; String url = null; - // TODO: Recreating the list to get this compiling - // Need to refactor the below iteration to pull directly from the maps - List data = new ArrayList<>(); - data.add(currentNodeInputs); - data.addAll(outputs.values()); - for (WorkflowData workflowData : data) { Map content = workflowData.getContent(); @@ -211,4 +213,63 @@ public void onFailure(Exception e) { public String getName() { return NAME; } + + /** + * Retryable get ml task + * @param workflowId the workflow id + * @param nodeId the workflow node id + * @param getMLTaskFuture the workflow step future + * @param taskId the ml task id + * @param retries the current number of request retries + */ + void retryableGetMlTask( + String workflowId, + String nodeId, + CompletableFuture registerLocalModelFuture, + String taskId, + int retries + ) { + mlClient.getTask(taskId, ActionListener.wrap(response -> { + MLTaskState currentState = response.getState(); + if (currentState != MLTaskState.COMPLETED) { + if (Stream.of(MLTaskState.FAILED, MLTaskState.COMPLETED_WITH_ERROR).anyMatch(x -> x == currentState)) { + // Model registration failed or completed with errors + String errorMessage = "Local model registration failed with error : " + response.getError(); + logger.error(errorMessage); + registerLocalModelFuture.completeExceptionally(new FlowFrameworkException(errorMessage, RestStatus.BAD_REQUEST)); + } else { + // Task still in progress, attempt retry + throw new IllegalStateException("Local model registration is not yet completed"); + } + } else { + logger.info("Local model registeration successful"); + registerLocalModelFuture.complete( + new WorkflowData( + Map.ofEntries( + Map.entry(MODEL_ID, response.getModelId()), + Map.entry(REGISTER_MODEL_STATUS, response.getState().name()) + ), + workflowId, + nodeId + ) + ); + } + }, exception -> { + if (retries < maxRetry) { + // Sleep thread prior to retrying request + try { + Thread.sleep(5000); + } catch (Exception e) { + FutureUtils.cancel(registerLocalModelFuture); + } + final int retryAdd = retries + 1; + retryableGetMlTask(workflowId, nodeId, registerLocalModelFuture, taskId, retryAdd); + } else { + logger.error("Failed to retrieve local model registration task, maximum retries exceeded"); + registerLocalModelFuture.completeExceptionally( + new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception)) + ); + } + })); + } } diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java index 2e450d5b0..b95a0449d 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java @@ -57,12 +57,11 @@ private void populateMap( stepMap.put(NoOpStep.NAME, new NoOpStep()); stepMap.put(CreateIndexStep.NAME, new CreateIndexStep(clusterService, client)); stepMap.put(CreateIngestPipelineStep.NAME, new CreateIngestPipelineStep(client)); - stepMap.put(RegisterLocalModelStep.NAME, new RegisterLocalModelStep(mlClient)); + stepMap.put(RegisterLocalModelStep.NAME, new RegisterLocalModelStep(settings, clusterService, mlClient)); stepMap.put(RegisterRemoteModelStep.NAME, new RegisterRemoteModelStep(mlClient)); stepMap.put(DeployModelStep.NAME, new DeployModelStep(mlClient)); stepMap.put(CreateConnectorStep.NAME, new CreateConnectorStep(mlClient, flowFrameworkIndicesHandler)); stepMap.put(ModelGroupStep.NAME, new ModelGroupStep(mlClient)); - stepMap.put(GetMLTaskStep.NAME, new GetMLTaskStep(settings, clusterService, mlClient)); } /** diff --git a/src/main/resources/mappings/workflow-steps.json b/src/main/resources/mappings/workflow-steps.json index 5bd88147b..6256189c1 100644 --- a/src/main/resources/mappings/workflow-steps.json +++ b/src/main/resources/mappings/workflow-steps.json @@ -52,7 +52,7 @@ "url" ], "outputs":[ - "task_id", + "model_id", "register_model_status" ] }, @@ -83,14 +83,5 @@ "model_group_id", "model_group_status" ] - }, - "get_ml_task": { - "inputs":[ - "task_id" - ], - "outputs":[ - "model_id", - "register_model_status" - ] } } diff --git a/src/test/java/org/opensearch/flowframework/workflow/GetMLTaskStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/GetMLTaskStepTests.java deleted file mode 100644 index bd62ddfc7..000000000 --- a/src/test/java/org/opensearch/flowframework/workflow/GetMLTaskStepTests.java +++ /dev/null @@ -1,137 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - */ -package org.opensearch.flowframework.workflow; - -import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; - -import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.settings.ClusterSettings; -import org.opensearch.common.settings.Setting; -import org.opensearch.common.settings.Settings; -import org.opensearch.core.action.ActionListener; -import org.opensearch.flowframework.exception.FlowFrameworkException; -import org.opensearch.ml.client.MachineLearningNodeClient; -import org.opensearch.ml.common.MLTask; -import org.opensearch.ml.common.MLTaskState; -import org.opensearch.test.OpenSearchTestCase; - -import java.util.Collections; -import java.util.Map; -import java.util.Set; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.ExecutionException; -import java.util.stream.Collectors; -import java.util.stream.Stream; - -import org.mockito.Mock; -import org.mockito.MockitoAnnotations; - -import static org.opensearch.flowframework.common.CommonValue.MODEL_ID; -import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS; -import static org.opensearch.flowframework.common.CommonValue.TASK_ID; -import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_GET_TASK_REQUEST_RETRY; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -@ThreadLeakScope(ThreadLeakScope.Scope.NONE) -public class GetMLTaskStepTests extends OpenSearchTestCase { - - private GetMLTaskStep getMLTaskStep; - private WorkflowData workflowData; - - @Mock - MachineLearningNodeClient mlNodeClient; - - @Override - public void setUp() throws Exception { - super.setUp(); - - MockitoAnnotations.openMocks(this); - ClusterService clusterService = mock(ClusterService.class); - final Set> settingsSet = Stream.concat( - ClusterSettings.BUILT_IN_CLUSTER_SETTINGS.stream(), - Stream.of(MAX_GET_TASK_REQUEST_RETRY) - ).collect(Collectors.toSet()); - - // Set max request retry setting to 0 to avoid sleeping the thread during unit test failure cases - Settings testMaxRetrySetting = Settings.builder().put(MAX_GET_TASK_REQUEST_RETRY.getKey(), 0).build(); - ClusterSettings clusterSettings = new ClusterSettings(testMaxRetrySetting, settingsSet); - when(clusterService.getClusterSettings()).thenReturn(clusterSettings); - - this.getMLTaskStep = spy(new GetMLTaskStep(testMaxRetrySetting, clusterService, mlNodeClient)); - this.workflowData = new WorkflowData(Map.ofEntries(Map.entry(TASK_ID, "test")), "test-id", "test-node-id"); - } - - public void testGetMLTaskSuccess() throws Exception { - String taskId = "test"; - String modelId = "abcd"; - MLTaskState status = MLTaskState.COMPLETED; - - doAnswer(invocation -> { - ActionListener actionListener = invocation.getArgument(1); - MLTask output = new MLTask(taskId, modelId, null, null, status, null, null, null, null, null, null, null, null, false); - actionListener.onResponse(output); - return null; - }).when(mlNodeClient).getTask(any(), any()); - - CompletableFuture future = this.getMLTaskStep.execute( - workflowData.getNodeId(), - workflowData, - Collections.emptyMap(), - Collections.emptyMap() - ); - - verify(mlNodeClient, times(1)).getTask(any(), any()); - - assertTrue(future.isDone()); - assertTrue(!future.isCompletedExceptionally()); - assertEquals(modelId, future.get().getContent().get(MODEL_ID)); - assertEquals(status.name(), future.get().getContent().get(REGISTER_MODEL_STATUS)); - } - - public void testGetMLTaskFailure() { - doAnswer(invocation -> { - ActionListener actionListener = invocation.getArgument(1); - actionListener.onFailure(new IllegalArgumentException("test")); - return null; - }).when(mlNodeClient).getTask(any(), any()); - - CompletableFuture future = this.getMLTaskStep.execute( - workflowData.getNodeId(), - workflowData, - Collections.emptyMap(), - Collections.emptyMap() - ); - assertTrue(future.isDone()); - assertTrue(future.isCompletedExceptionally()); - ExecutionException ex = expectThrows(ExecutionException.class, () -> future.get().getClass()); - assertTrue(ex.getCause() instanceof FlowFrameworkException); - assertEquals("test", ex.getCause().getMessage()); - } - - public void testMissingInputs() { - CompletableFuture future = this.getMLTaskStep.execute( - "nodeID", - WorkflowData.EMPTY, - Collections.emptyMap(), - Collections.emptyMap() - ); - assertTrue(future.isDone()); - assertTrue(future.isCompletedExceptionally()); - ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); - assertTrue(ex.getCause() instanceof FlowFrameworkException); - assertEquals("Required fields are not provided", ex.getCause().getMessage()); - } - -} diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalModelStepTests.java index b7b47de46..d169812a9 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalModelStepTests.java @@ -10,29 +10,39 @@ import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; import org.opensearch.core.action.ActionListener; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.ml.client.MachineLearningNodeClient; +import org.opensearch.ml.common.MLTask; import org.opensearch.ml.common.MLTaskState; -import org.opensearch.ml.common.model.MLModelConfig; -import org.opensearch.ml.common.model.TextEmbeddingModelConfig; import org.opensearch.ml.common.transport.register.MLRegisterModelInput; import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; import org.opensearch.test.OpenSearchTestCase; import java.util.Collections; import java.util.Map; +import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; +import java.util.stream.Collectors; +import java.util.stream.Stream; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import static org.opensearch.flowframework.common.CommonValue.MODEL_ID; import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS; -import static org.opensearch.flowframework.common.CommonValue.TASK_ID; +import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_GET_TASK_REQUEST_RETRY; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; @ThreadLeakScope(ThreadLeakScope.Scope.NONE) public class RegisterLocalModelStepTests extends OpenSearchTestCase { @@ -48,14 +58,18 @@ public void setUp() throws Exception { super.setUp(); MockitoAnnotations.openMocks(this); - this.registerLocalModelStep = new RegisterLocalModelStep(machineLearningNodeClient); + ClusterService clusterService = mock(ClusterService.class); + final Set> settingsSet = Stream.concat( + ClusterSettings.BUILT_IN_CLUSTER_SETTINGS.stream(), + Stream.of(MAX_GET_TASK_REQUEST_RETRY) + ).collect(Collectors.toSet()); - MLModelConfig config = TextEmbeddingModelConfig.builder() - .modelType("testModelType") - .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") - .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) - .embeddingDimension(100) - .build(); + // Set max request retry setting to 0 to avoid sleeping the thread during unit test failure cases + Settings testMaxRetrySetting = Settings.builder().put(MAX_GET_TASK_REQUEST_RETRY.getKey(), 0).build(); + ClusterSettings clusterSettings = new ClusterSettings(testMaxRetrySetting, settingsSet); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + + this.registerLocalModelStep = new RegisterLocalModelStep(testMaxRetrySetting, clusterService, machineLearningNodeClient); this.workflowData = new WorkflowData( Map.ofEntries( @@ -79,8 +93,10 @@ public void setUp() throws Exception { public void testRegisterLocalModelSuccess() throws Exception { String taskId = "abcd"; - String status = MLTaskState.CREATED.name(); + String modelId = "model-id"; + String status = MLTaskState.COMPLETED.name(); + // Stub register for success case doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); MLRegisterModelResponse output = new MLRegisterModelResponse(taskId, status, null); @@ -88,17 +104,42 @@ public void testRegisterLocalModelSuccess() throws Exception { return null; }).when(machineLearningNodeClient).register(any(MLRegisterModelInput.class), any()); + // Stub getTask for success case + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + MLTask output = new MLTask( + taskId, + modelId, + null, + null, + MLTaskState.COMPLETED, + null, + null, + null, + null, + null, + null, + null, + null, + false + ); + actionListener.onResponse(output); + return null; + }).when(machineLearningNodeClient).getTask(any(), any()); + CompletableFuture future = registerLocalModelStep.execute( workflowData.getNodeId(), workflowData, Collections.emptyMap(), Collections.emptyMap() ); - verify(machineLearningNodeClient).register(any(MLRegisterModelInput.class), any()); + ; + verify(machineLearningNodeClient, times(1)).register(any(MLRegisterModelInput.class), any()); + verify(machineLearningNodeClient, times(1)).getTask(any(), any()); assertTrue(future.isDone()); assertTrue(!future.isCompletedExceptionally()); - assertEquals(taskId, future.get().getContent().get(TASK_ID)); + assertEquals(modelId, future.get().getContent().get(MODEL_ID)); assertEquals(status, future.get().getContent().get(REGISTER_MODEL_STATUS)); } @@ -124,6 +165,57 @@ public void testRegisterLocalModelFailure() { assertEquals("test", ex.getCause().getMessage()); } + public void testRegisterLocalModelTaskFailure() { + + String taskId = "abcd"; + String modelId = "model-id"; + String testErrorMessage = "error"; + + // Stub register for success case + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + MLRegisterModelResponse output = new MLRegisterModelResponse(taskId, MLTaskState.RUNNING.name(), null); + actionListener.onResponse(output); + return null; + }).when(machineLearningNodeClient).register(any(MLRegisterModelInput.class), any()); + + // Stub get ml task for failure case + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + MLTask output = new MLTask( + taskId, + modelId, + null, + null, + MLTaskState.FAILED, + null, + null, + null, + null, + null, + null, + testErrorMessage, + null, + false + ); + actionListener.onResponse(output); + return null; + }).when(machineLearningNodeClient).getTask(any(), any()); + + CompletableFuture future = this.registerLocalModelStep.execute( + workflowData.getNodeId(), + workflowData, + Collections.emptyMap(), + Collections.emptyMap() + ); + assertTrue(future.isDone()); + assertTrue(future.isCompletedExceptionally()); + ExecutionException ex = expectThrows(ExecutionException.class, () -> future.get().getClass()); + assertTrue(ex.getCause() instanceof FlowFrameworkException); + assertEquals("Local model registration failed with error : " + testErrorMessage, ex.getCause().getMessage()); + + } + public void testMissingInputs() { CompletableFuture future = registerLocalModelStep.execute( "nodeId",