diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalCustomModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalCustomModelStepTests.java index fcf1908d0..061d9f8c8 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalCustomModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalCustomModelStepTests.java @@ -36,6 +36,7 @@ import java.util.Map; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; import org.mockito.Mock; import org.mockito.MockitoAnnotations; @@ -221,6 +222,88 @@ public void testRegisterLocalCustomModelSuccess() throws Exception { assertEquals(status, future.get().getContent().get(REGISTER_MODEL_STATUS)); } + // This method tests code in the abstract parent + public void testRegisterLocalCustomModelDeployStateUpdateFail() throws Exception { + String taskId = "abcd"; + String modelId = "model-id"; + String status = MLTaskState.COMPLETED.name(); + + // Stub register for success case + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + MLRegisterModelResponse output = new MLRegisterModelResponse(taskId, status, null); + actionListener.onResponse(output); + return null; + }).when(machineLearningNodeClient).register(any(MLRegisterModelInput.class), any()); + + // Stub getTask for success case + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + MLTask output = new MLTask( + taskId, + modelId, + null, + null, + MLTaskState.COMPLETED, + null, + null, + null, + null, + null, + null, + null, + null, + false + ); + actionListener.onResponse(output); + return null; + }).when(machineLearningNodeClient).getTask(any(), any()); + + AtomicInteger invocationCount = new AtomicInteger(0); + doAnswer(invocation -> { + ActionListener updateResponseListener = invocation.getArgument(4); + if (invocationCount.getAndIncrement() == 0) { + // succeed on first call (update register) + updateResponseListener.onResponse(new WorkflowData(Map.of(MODEL_ID, modelId), "test-id", "test-node-id")); + } else { + // fail on second call (update deploy) + updateResponseListener.onFailure(new RuntimeException("Failed to update deploy resource")); + } + return null; + }).when(flowFrameworkIndicesHandler).addResourceToStateIndex(any(WorkflowData.class), anyString(), anyString(), anyString(), any()); + + WorkflowData boolStringWorkflowData = new WorkflowData( + Map.ofEntries( + Map.entry("name", "xyz"), + Map.entry("version", "1.0.0"), + Map.entry("description", "description"), + Map.entry("function_name", "SPARSE_TOKENIZE"), + Map.entry("model_format", "TORCH_SCRIPT"), + Map.entry(MODEL_GROUP_ID, "abcdefg"), + Map.entry("model_content_hash_value", "aiwoeifjoaijeofiwe"), + Map.entry("model_type", "bert"), + Map.entry("embedding_dimension", "384"), + Map.entry("framework_type", "sentence_transformers"), + Map.entry("url", "something.com"), + Map.entry(DEPLOY_FIELD, "true") + ), + "test-id", + "test-node-id" + ); + + PlainActionFuture future = registerLocalModelStep.execute( + boolStringWorkflowData.getNodeId(), + boolStringWorkflowData, + Collections.emptyMap(), + Collections.emptyMap(), + Collections.emptyMap() + ); + + ExecutionException ex = expectThrows(ExecutionException.class, () -> future.get().getClass()); + assertTrue(ex.getCause() instanceof FlowFrameworkException); + assertEquals("Failed to update simulated deploy step resource model-id", ex.getCause().getMessage()); + } + public void testRegisterLocalCustomModelFailure() { doAnswer(invocation -> {