diff --git a/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java b/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java index b0a629216..741ce4d4f 100644 --- a/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java +++ b/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java @@ -395,6 +395,17 @@ public void putInitialStateToWorkflowState(String workflowId, User user, ActionL * @param listener action listener */ public void updateTemplateInGlobalContext(String documentId, Template template, ActionListener listener) { + updateTemplateInGlobalContext(documentId, template, listener, false); + } + + /** + * Replaces a document in the global context index + * @param documentId the document Id + * @param template the use-case template + * @param listener action listener + * @param force if set true, ignores the requirement that the provisioning is not started + */ + public void updateTemplateInGlobalContext(String documentId, Template template, ActionListener listener, boolean force) { if (!doesIndexExist(GLOBAL_CONTEXT_INDEX)) { String errorMessage = "Failed to update template for workflow_id : " + documentId + ", global_context index does not exist."; logger.error(errorMessage); @@ -404,7 +415,7 @@ public void updateTemplateInGlobalContext(String documentId, Template template, doesTemplateExist(documentId, templateExists -> { if (templateExists) { isWorkflowNotStarted(documentId, workflowIsNotStarted -> { - if (workflowIsNotStarted) { + if (workflowIsNotStarted || force) { IndexRequest request = new IndexRequest(GLOBAL_CONTEXT_INDEX).id(documentId); try ( XContentBuilder builder = XContentFactory.jsonBuilder(); diff --git a/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java index 86da71575..c5f52ec71 100644 --- a/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java @@ -174,7 +174,9 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener { String errorMessage = "Failed to update workflow state: " + workflowId; diff --git a/src/test/java/org/opensearch/flowframework/rest/FlowFrameworkRestApiIT.java b/src/test/java/org/opensearch/flowframework/rest/FlowFrameworkRestApiIT.java index 5e65d8cb1..d6c3ffa65 100644 --- a/src/test/java/org/opensearch/flowframework/rest/FlowFrameworkRestApiIT.java +++ b/src/test/java/org/opensearch/flowframework/rest/FlowFrameworkRestApiIT.java @@ -80,7 +80,7 @@ public void testFailedUpdateWorkflow() throws Exception { Template template = TestHelpers.createTemplateFromFile("createconnector-registerremotemodel-deploymodel.json"); ResponseException exception = expectThrows(ResponseException.class, () -> updateWorkflow(client(), "123", template)); - assertTrue(exception.getMessage().contains("Failed to get template: 123")); + assertTrue(exception.getMessage().contains("Failed to retrieve template (123) from global context.")); Response response = createWorkflow(client(), template); assertEquals(RestStatus.CREATED, TestHelpers.restStatus(response)); diff --git a/src/test/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportActionTests.java index 3969ec4ee..336e25e23 100644 --- a/src/test/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportActionTests.java @@ -46,6 +46,7 @@ import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; @@ -153,7 +154,7 @@ public void testProvisionWorkflow() { ActionListener responseListener = invocation.getArgument(2); responseListener.onResponse(new IndexResponse(new ShardId(GLOBAL_CONTEXT_INDEX, "", 1), "1", 1L, 1L, 1L, true)); return null; - }).when(flowFrameworkIndicesHandler).updateTemplateInGlobalContext(any(), any(Template.class), any()); + }).when(flowFrameworkIndicesHandler).updateTemplateInGlobalContext(any(), any(Template.class), any(), anyBoolean()); provisionWorkflowTransportAction.doExecute(mock(Task.class), workflowRequest, listener); ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(WorkflowResponse.class);