diff --git a/CHANGELOG.md b/CHANGELOG.md index bdaed0ea0..f5d147ad8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,3 +24,4 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.1.0/) ### Documentation ### Maintenance ### Refactoring +- Refactor workflow step resource updates to eliminate duplication ([#796](https://github.com/opensearch-project/flow-framework/pull/796)) diff --git a/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java b/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java index 0d583e0a1..c5ac91eba 100644 --- a/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java +++ b/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java @@ -637,7 +637,7 @@ public void deleteFlowFrameworkSystemIndexDoc(String documentId, ActionListener< * @param script the given script to update doc * @param listener action listener */ - public void updateFlowFrameworkSystemIndexDocWithScript( + private void updateFlowFrameworkSystemIndexDocWithScript( String indexName, String documentId, Script script, @@ -667,13 +667,13 @@ public void updateFlowFrameworkSystemIndexDocWithScript( /** * Creates a new ResourceCreated object and a script to update the state index * @param workflowId workflowId for the relevant step - * @param nodeId WorkflowData object with relevent step information + * @param nodeId current process node (workflow step) id * @param workflowStepName the workflowstep name that created the resource * @param resourceId the id of the newly created resource * @param listener the ActionListener for this step to handle completing the future after update * @throws IOException if parsing fails on new resource */ - public void updateResourceInStateIndex( + private void updateResourceInStateIndex( String workflowId, String nodeId, String workflowStepName, @@ -704,7 +704,7 @@ public void updateResourceInStateIndex( /** * Adds a resource to the state index, including common exception handling * @param currentNodeInputs Inputs to the current node - * @param nodeId WorkflowData object with relevent step information + * @param nodeId current process node (workflow step) id * @param workflowStepName the workflow step name that created the resource * @param resourceId the id of the newly created resource * @param listener the ActionListener for this step to handle completing the future after update diff --git a/src/test/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandlerTests.java b/src/test/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandlerTests.java index f55cb3bc1..ecaec46b5 100644 --- a/src/test/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandlerTests.java +++ b/src/test/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandlerTests.java @@ -37,13 +37,16 @@ import org.opensearch.core.index.shard.ShardId; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.flowframework.TestHelpers; +import org.opensearch.flowframework.common.WorkflowResources; import org.opensearch.flowframework.model.ProvisioningProgress; import org.opensearch.flowframework.model.ResourceCreated; import org.opensearch.flowframework.model.Template; import org.opensearch.flowframework.model.Workflow; import org.opensearch.flowframework.model.WorkflowState; import org.opensearch.flowframework.util.EncryptorUtils; +import org.opensearch.flowframework.workflow.CreateConnectorStep; import org.opensearch.flowframework.workflow.CreateIndexStep; +import org.opensearch.flowframework.workflow.WorkflowData; import org.opensearch.index.get.GetResult; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; @@ -488,4 +491,52 @@ public void testDeleteFlowFrameworkSystemIndexDoc() throws IOException { exceptionCaptor.getValue().getMessage() ); } + + public void testAddResourceToStateIndex() throws IOException { + ClusterState mockClusterState = mock(ClusterState.class); + Metadata mockMetaData = mock(Metadata.class); + when(clusterService.state()).thenReturn(mockClusterState); + when(mockClusterState.metadata()).thenReturn(mockMetaData); + when(mockMetaData.hasIndex(WORKFLOW_STATE_INDEX)).thenReturn(true); + + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + // test success + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(1); + responseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "this_id", -2, 0, 0, Result.UPDATED)); + return null; + }).when(client).update(any(UpdateRequest.class), any()); + + flowFrameworkIndicesHandler.addResourceToStateIndex( + new WorkflowData(Collections.emptyMap(), null, null), + "node_id", + CreateConnectorStep.NAME, + "this_id", + listener + ); + + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(WorkflowData.class); + verify(listener, times(1)).onResponse(responseCaptor.capture()); + assertEquals("this_id", responseCaptor.getValue().getContent().get(WorkflowResources.CONNECTOR_ID)); + + // test failure + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(1); + responseListener.onFailure(new Exception("Failed to update state")); + return null; + }).when(client).update(any(UpdateRequest.class), any()); + + flowFrameworkIndicesHandler.addResourceToStateIndex( + new WorkflowData(Collections.emptyMap(), null, null), + "node_id", + CreateConnectorStep.NAME, + "this_id", + listener + ); + + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener, times(1)).onFailure(exceptionCaptor.capture()); + assertEquals("Failed to update new created node_id resource create_connector id this_id", exceptionCaptor.getValue().getMessage()); + } } diff --git a/src/test/java/org/opensearch/flowframework/workflow/NoOpStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/NoOpStepTests.java index 4aabee215..b53574cfe 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/NoOpStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/NoOpStepTests.java @@ -15,6 +15,9 @@ import java.io.IOException; import java.util.Collections; import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; import static org.opensearch.flowframework.common.CommonValue.DELAY_FIELD; @@ -50,6 +53,49 @@ public void testNoOpStepDelay() throws IOException, InterruptedException { assertTrue(System.nanoTime() - start > 900_000_000L); } + public void testNoOpStepInterrupt() throws IOException, InterruptedException { + NoOpStep noopStep = new NoOpStep(); + WorkflowData delayData = new WorkflowData(Map.of(DELAY_FIELD, "5s"), null, null); + + CountDownLatch latch = new CountDownLatch(1); + // Fetch errors from the separate thread + AtomicReference assertionError = new AtomicReference<>(); + + Thread testThread = new Thread(() -> { + try { + PlainActionFuture future = noopStep.execute( + "nodeId", + delayData, + Collections.emptyMap(), + Collections.emptyMap(), + Collections.emptyMap() + ); + try { + future.actionGet(); + } catch (Exception e) { + // Ignore the IllegalStateExcption/InterruptedExcpetion + } + assertTrue(future.isDone()); + assertTrue(future.isCancelled()); + assertTrue(Thread.currentThread().isInterrupted()); + } catch (AssertionError e) { + assertionError.set(e); + } finally { + latch.countDown(); + } + }); + + testThread.start(); + Thread.sleep(100); + testThread.interrupt(); + + latch.await(1, TimeUnit.SECONDS); + + if (assertionError.get() != null) { + throw assertionError.get(); + } + } + public void testNoOpStepParse() throws IOException { NoOpStep noopStep = new NoOpStep(); WorkflowData delayData = new WorkflowData(Map.of(DELAY_FIELD, "foo"), null, null); diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalCustomModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalCustomModelStepTests.java index fcf1908d0..061d9f8c8 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalCustomModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalCustomModelStepTests.java @@ -36,6 +36,7 @@ import java.util.Map; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; import org.mockito.Mock; import org.mockito.MockitoAnnotations; @@ -221,6 +222,88 @@ public void testRegisterLocalCustomModelSuccess() throws Exception { assertEquals(status, future.get().getContent().get(REGISTER_MODEL_STATUS)); } + // This method tests code in the abstract parent + public void testRegisterLocalCustomModelDeployStateUpdateFail() throws Exception { + String taskId = "abcd"; + String modelId = "model-id"; + String status = MLTaskState.COMPLETED.name(); + + // Stub register for success case + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + MLRegisterModelResponse output = new MLRegisterModelResponse(taskId, status, null); + actionListener.onResponse(output); + return null; + }).when(machineLearningNodeClient).register(any(MLRegisterModelInput.class), any()); + + // Stub getTask for success case + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + MLTask output = new MLTask( + taskId, + modelId, + null, + null, + MLTaskState.COMPLETED, + null, + null, + null, + null, + null, + null, + null, + null, + false + ); + actionListener.onResponse(output); + return null; + }).when(machineLearningNodeClient).getTask(any(), any()); + + AtomicInteger invocationCount = new AtomicInteger(0); + doAnswer(invocation -> { + ActionListener updateResponseListener = invocation.getArgument(4); + if (invocationCount.getAndIncrement() == 0) { + // succeed on first call (update register) + updateResponseListener.onResponse(new WorkflowData(Map.of(MODEL_ID, modelId), "test-id", "test-node-id")); + } else { + // fail on second call (update deploy) + updateResponseListener.onFailure(new RuntimeException("Failed to update deploy resource")); + } + return null; + }).when(flowFrameworkIndicesHandler).addResourceToStateIndex(any(WorkflowData.class), anyString(), anyString(), anyString(), any()); + + WorkflowData boolStringWorkflowData = new WorkflowData( + Map.ofEntries( + Map.entry("name", "xyz"), + Map.entry("version", "1.0.0"), + Map.entry("description", "description"), + Map.entry("function_name", "SPARSE_TOKENIZE"), + Map.entry("model_format", "TORCH_SCRIPT"), + Map.entry(MODEL_GROUP_ID, "abcdefg"), + Map.entry("model_content_hash_value", "aiwoeifjoaijeofiwe"), + Map.entry("model_type", "bert"), + Map.entry("embedding_dimension", "384"), + Map.entry("framework_type", "sentence_transformers"), + Map.entry("url", "something.com"), + Map.entry(DEPLOY_FIELD, "true") + ), + "test-id", + "test-node-id" + ); + + PlainActionFuture future = registerLocalModelStep.execute( + boolStringWorkflowData.getNodeId(), + boolStringWorkflowData, + Collections.emptyMap(), + Collections.emptyMap(), + Collections.emptyMap() + ); + + ExecutionException ex = expectThrows(ExecutionException.class, () -> future.get().getClass()); + assertTrue(ex.getCause() instanceof FlowFrameworkException); + assertEquals("Failed to update simulated deploy step resource model-id", ex.getCause().getMessage()); + } + public void testRegisterLocalCustomModelFailure() { doAnswer(invocation -> { diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java index b0ee2e4bf..0e2ab91e9 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java @@ -28,6 +28,7 @@ import java.util.Collections; import java.util.Map; import java.util.concurrent.ExecutionException; +import java.util.concurrent.atomic.AtomicInteger; import org.mockito.Mock; import org.mockito.MockitoAnnotations; @@ -224,6 +225,99 @@ public void testRegisterRemoteModelFailure() { } + public void testRegisterRemoteModelUpdateFailure() { + String taskId = "abcd"; + String modelId = "efgh"; + String status = MLTaskState.CREATED.name(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + MLRegisterModelResponse output = new MLRegisterModelResponse(taskId, status, modelId); + actionListener.onResponse(output); + return null; + }).when(mlNodeClient).register(any(MLRegisterModelInput.class), any()); + + doAnswer(invocation -> { + ActionListener updateResponseListener = invocation.getArgument(4); + updateResponseListener.onFailure(new RuntimeException("Failed to update register resource")); + return null; + }).when(flowFrameworkIndicesHandler).addResourceToStateIndex(any(WorkflowData.class), anyString(), anyString(), anyString(), any()); + + WorkflowData deployWorkflowData = new WorkflowData( + Map.ofEntries( + Map.entry("name", "xyz"), + Map.entry("description", "description"), + Map.entry(CONNECTOR_ID, "abcdefg"), + Map.entry(DEPLOY_FIELD, true) + ), + "test-id", + "test-node-id" + ); + + PlainActionFuture future = this.registerRemoteModelStep.execute( + deployWorkflowData.getNodeId(), + deployWorkflowData, + Collections.emptyMap(), + Collections.emptyMap(), + Collections.emptyMap() + ); + + assertTrue(future.isDone()); + ExecutionException ex = expectThrows(ExecutionException.class, () -> future.get().getClass()); + assertTrue(ex.getCause() instanceof FlowFrameworkException); + assertEquals("Failed to update new created test-node-id resource register_remote_model id efgh", ex.getCause().getMessage()); + } + + public void testRegisterRemoteModelDeployUpdateFailure() { + String taskId = "abcd"; + String modelId = "efgh"; + String status = MLTaskState.CREATED.name(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + MLRegisterModelResponse output = new MLRegisterModelResponse(taskId, status, modelId); + actionListener.onResponse(output); + return null; + }).when(mlNodeClient).register(any(MLRegisterModelInput.class), any()); + + AtomicInteger invocationCount = new AtomicInteger(0); + doAnswer(invocation -> { + ActionListener updateResponseListener = invocation.getArgument(4); + if (invocationCount.getAndIncrement() == 0) { + // succeed on first call (update register) + updateResponseListener.onResponse(new WorkflowData(Map.of(MODEL_ID, modelId), "test-id", "test-node-id")); + } else { + // fail on second call (update deploy) + updateResponseListener.onFailure(new RuntimeException("Failed to update deploy resource")); + } + return null; + }).when(flowFrameworkIndicesHandler).addResourceToStateIndex(any(WorkflowData.class), anyString(), anyString(), anyString(), any()); + + WorkflowData deployWorkflowData = new WorkflowData( + Map.ofEntries( + Map.entry("name", "xyz"), + Map.entry("description", "description"), + Map.entry(CONNECTOR_ID, "abcdefg"), + Map.entry(DEPLOY_FIELD, true) + ), + "test-id", + "test-node-id" + ); + + PlainActionFuture future = this.registerRemoteModelStep.execute( + deployWorkflowData.getNodeId(), + deployWorkflowData, + Collections.emptyMap(), + Collections.emptyMap(), + Collections.emptyMap() + ); + + assertTrue(future.isDone()); + ExecutionException ex = expectThrows(ExecutionException.class, () -> future.get().getClass()); + assertTrue(ex.getCause() instanceof FlowFrameworkException); + assertEquals("Failed to update simulated deploy step resource efgh", ex.getCause().getMessage()); + } + public void testReisterRemoteModelInterfaceFailure() { doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1);