From c7c819b6f193a90cd46e9b585e56b50b56717c1d Mon Sep 17 00:00:00 2001 From: Joshua Palis Date: Mon, 9 Oct 2023 21:10:32 +0000 Subject: [PATCH] Addressing PR comments (Part 2), adding globalcontexthandler to create components, added updateTemplate(), indexExists() methods to handler and createIndex step respecitvely. Implemented CreateWorkflow/ProvisionWorkflow transport actions Signed-off-by: Joshua Palis --- .../flowframework/FlowFrameworkPlugin.java | 13 +++-- .../flowframework/common/CommonValue.java | 8 ++- .../indices/GlobalContextHandler.java | 32 ++++++++++++ .../rest/RestProvisionWorkflowAction.java | 24 +++------ .../CreateWorkflowTransportAction.java | 47 +++++++++-------- .../ProvisionWorkflowTransportAction.java | 51 +++++++++---------- .../workflow/CreateIndexStep.java | 10 ++++ .../FlowFrameworkPluginTests.java | 7 ++- .../RestProvisionWorkflowActionTests.java | 32 ++++-------- 9 files changed, 129 insertions(+), 95 deletions(-) diff --git a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java index 6f0ca9000..810674c47 100644 --- a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java +++ b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java @@ -18,17 +18,20 @@ import org.opensearch.common.settings.IndexScopedSettings; import org.opensearch.common.settings.Settings; import org.opensearch.common.settings.SettingsFilter; +import org.opensearch.common.util.concurrent.OpenSearchExecutors; import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.io.stream.NamedWriteableRegistry; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.env.Environment; import org.opensearch.env.NodeEnvironment; +import org.opensearch.flowframework.indices.GlobalContextHandler; import org.opensearch.flowframework.rest.RestCreateWorkflowAction; import org.opensearch.flowframework.rest.RestProvisionWorkflowAction; import org.opensearch.flowframework.transport.CreateWorkflowAction; import org.opensearch.flowframework.transport.CreateWorkflowTransportAction; import org.opensearch.flowframework.transport.ProvisionWorkflowAction; import org.opensearch.flowframework.transport.ProvisionWorkflowTransportAction; +import org.opensearch.flowframework.workflow.CreateIndexStep; import org.opensearch.flowframework.workflow.WorkflowProcessSorter; import org.opensearch.flowframework.workflow.WorkflowStepFactory; import org.opensearch.plugins.ActionPlugin; @@ -76,7 +79,10 @@ public Collection createComponents( WorkflowStepFactory workflowStepFactory = new WorkflowStepFactory(clusterService, client); WorkflowProcessSorter workflowProcessSorter = new WorkflowProcessSorter(workflowStepFactory, threadPool); - return ImmutableList.of(workflowStepFactory, workflowProcessSorter); + // TODO : Refactor, move system index creation/associated methods outside of the CreateIndexStep + GlobalContextHandler globalContextHandler = new GlobalContextHandler(client, new CreateIndexStep(clusterService, client)); + + return ImmutableList.of(workflowStepFactory, workflowProcessSorter, globalContextHandler); } @Override @@ -106,10 +112,9 @@ public List> getExecutorBuilders(Settings settings) { FixedExecutorBuilder provisionThreadPool = new FixedExecutorBuilder( settings, PROVISION_THREAD_POOL, - 1, + OpenSearchExecutors.allocatedProcessors(settings), 10, - FLOW_FRAMEWORK_THREAD_POOL_PREFIX + PROVISION_THREAD_POOL, - false + FLOW_FRAMEWORK_THREAD_POOL_PREFIX + PROVISION_THREAD_POOL ); return ImmutableList.of(provisionThreadPool); } diff --git a/src/main/java/org/opensearch/flowframework/common/CommonValue.java b/src/main/java/org/opensearch/flowframework/common/CommonValue.java index 0aa231ca2..da9aa3d2f 100644 --- a/src/main/java/org/opensearch/flowframework/common/CommonValue.java +++ b/src/main/java/org/opensearch/flowframework/common/CommonValue.java @@ -24,10 +24,14 @@ public class CommonValue { public static final String AI_FLOW_FRAMEWORK_BASE_URI = "/_plugins/_flow_framework"; /** The URI for this plugin's workflow rest actions */ public static final String WORKFLOWS_URI = AI_FLOW_FRAMEWORK_BASE_URI + "/workflows"; + /** Field name for workflow Id, the document Id of the indexed use case template */ + public static final String WORKFLOW_ID = "workflow_id"; + /** The field name for provision workflow within a use case template*/ + public static final String PROVISION_WORKFLOW = "provision"; + /** Flow Framework plugin thread pool name prefix */ public static final String FLOW_FRAMEWORK_THREAD_POOL_PREFIX = "thread_pool.flow_framework."; /** The provision workflow thread pool name */ public static final String PROVISION_THREAD_POOL = "opensearch_workflow_provision"; - /** Field name for workflow Id, the document Id of the indexed use case template */ - public static final String WORKFLOW_ID = "workflow_id"; + } diff --git a/src/main/java/org/opensearch/flowframework/indices/GlobalContextHandler.java b/src/main/java/org/opensearch/flowframework/indices/GlobalContextHandler.java index 994cdaeda..6d6565eb5 100644 --- a/src/main/java/org/opensearch/flowframework/indices/GlobalContextHandler.java +++ b/src/main/java/org/opensearch/flowframework/indices/GlobalContextHandler.java @@ -27,6 +27,7 @@ import java.io.IOException; import java.util.HashMap; +import java.util.Locale; import java.util.Map; import static org.opensearch.core.rest.RestStatus.INTERNAL_SERVER_ERROR; @@ -94,6 +95,37 @@ public void putTemplateToGlobalContext(Template template, ActionListener listener) { + if (!createIndexStep.doesIndexExist(GLOBAL_CONTEXT_INDEX)) { + String exceptionMessage = String.format( + Locale.ROOT, + "Failed to update template {}, global_context index does not exist.", + documentId + ); + logger.error(exceptionMessage); + listener.onFailure(new Exception(exceptionMessage)); + } else { + IndexRequest request = new IndexRequest(GLOBAL_CONTEXT_INDEX).id(documentId); + try ( + XContentBuilder builder = XContentFactory.jsonBuilder(); + ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext() + ) { + request.source(template.toXContent(builder, ToXContent.EMPTY_PARAMS)) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + client.index(request, ActionListener.runBefore(listener, () -> context.restore())); + } catch (Exception e) { + logger.error("Failed to update global_context entry : {}. {}", documentId, e.getMessage()); + listener.onFailure(e); + } + } + } + /** * Update global context index for specific fields * @param documentId global context index document id diff --git a/src/main/java/org/opensearch/flowframework/rest/RestProvisionWorkflowAction.java b/src/main/java/org/opensearch/flowframework/rest/RestProvisionWorkflowAction.java index 2c2b532b4..0c77d654f 100644 --- a/src/main/java/org/opensearch/flowframework/rest/RestProvisionWorkflowAction.java +++ b/src/main/java/org/opensearch/flowframework/rest/RestProvisionWorkflowAction.java @@ -10,7 +10,6 @@ import com.google.common.collect.ImmutableList; import org.opensearch.client.node.NodeClient; -import org.opensearch.flowframework.model.Template; import org.opensearch.flowframework.transport.ProvisionWorkflowAction; import org.opensearch.flowframework.transport.WorkflowRequest; import org.opensearch.rest.BaseRestHandler; @@ -31,12 +30,6 @@ public class RestProvisionWorkflowAction extends BaseRestHandler { private static final String PROVISION_WORKFLOW_ACTION = "provision_workflow_action"; - // TODO : move to common values class, pending implementation - /** - * Field name for workflow Id, the document Id of the indexed use case template - */ - public static final String WORKFLOW_ID = "workflow_id"; - /** * Instantiates a new RestProvisionWorkflowAction */ @@ -52,8 +45,6 @@ public String getName() { @Override public List routes() { return ImmutableList.of( - // Provision workflow from inline use case template - new Route(RestRequest.Method.POST, String.format(Locale.ROOT, "%s/%s", WORKFLOWS_URI, "_provision")), // Provision workflow from indexed use case template new Route(RestRequest.Method.POST, String.format(Locale.ROOT, "%s/{%s}/%s", WORKFLOWS_URI, WORKFLOW_ID, "_provision")) ); @@ -62,20 +53,19 @@ public List routes() { @Override protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { - String workflowId = request.param(WORKFLOW_ID); - Template template = null; - + // Validate content if (request.hasContent()) { - template = Template.parse(request.content().utf8ToString()); + throw new IOException("Invalid request format"); } - // Validate workflow request inputs - if (workflowId == null && template == null) { - throw new IOException("workflow_id and template cannot be both null"); + // Validate params + String workflowId = request.param(WORKFLOW_ID); + if (workflowId == null) { + throw new IOException("workflow_id cannot be null"); } // Create request and provision - WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, template); + WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, null); return channel -> client.execute(ProvisionWorkflowAction.INSTANCE, workflowRequest, new RestToXContentListener<>(channel)); } diff --git a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java index 555b8f767..51a668344 100644 --- a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java @@ -12,9 +12,9 @@ import org.apache.logging.log4j.Logger; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; -import org.opensearch.client.Client; import org.opensearch.common.inject.Inject; import org.opensearch.core.action.ActionListener; +import org.opensearch.flowframework.indices.GlobalContextHandler; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; @@ -25,41 +25,46 @@ public class CreateWorkflowTransportAction extends HandledTransportAction listener) { - - String workflowId; - // TODO : Check if global context index exists, and if it does not then create - if (request.getWorkflowId() == null) { - // TODO : Create new entry - // TODO : Insert doc - - // TODO : get document ID - workflowId = ""; - // TODO : check if state index exists, and if it does not, then create - // TODO : insert state index doc, mapped with documentId, defaulted to NOT_STARTED + // Create new global context and state index entries + globalContextHandler.putTemplateToGlobalContext(request.getTemplate(), ActionListener.wrap(response -> { + // TODO : Check if state index exists, create if not + // TODO : Create StateIndexRequest for workflowId, default to NOT_STARTED + listener.onResponse(new WorkflowResponse(response.getId())); + }, exception -> { + logger.error("Failed to save use case template : {}", exception.getMessage()); + listener.onFailure(exception); + })); } else { - // TODO : Update existing entry - workflowId = request.getWorkflowId(); - // TODO : Update state index entry, default back to NOT_STARTED + // Update existing entry, full document replacement + globalContextHandler.updateTemplate(request.getWorkflowId(), request.getTemplate(), ActionListener.wrap(response -> { + // TODO : Create StateIndexRequest for workflowId to reset entry to NOT_STARTED + listener.onResponse(new WorkflowResponse(request.getWorkflowId())); + }, exception -> { + logger.error("Failed to updated use case template {} : {}", request.getWorkflowId(), exception.getMessage()); + listener.onFailure(exception); + })); } - - listener.onResponse(new WorkflowResponse(workflowId)); } } diff --git a/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java index bf7a725a7..f81a623b7 100644 --- a/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java @@ -10,10 +10,12 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.action.get.GetRequest; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; import org.opensearch.client.Client; import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; import org.opensearch.flowframework.model.Template; import org.opensearch.flowframework.model.Workflow; @@ -31,7 +33,9 @@ import java.util.concurrent.CompletionException; import java.util.stream.Collectors; +import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX; import static org.opensearch.flowframework.common.CommonValue.PROVISION_THREAD_POOL; +import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW; /** * Transport Action to provision a workflow from a stored use case template @@ -40,12 +44,6 @@ public class ProvisionWorkflowTransportAction extends HandledTransportAction listener) { - if (request.getWorkflowId() == null) { - // Workflow provisioning from inline template, first parse and then index the given use case template - client.execute(CreateWorkflowAction.INSTANCE, request, ActionListener.wrap(workflowResponse -> { - String workflowId = workflowResponse.getWorkflowId(); - Template template = request.getTemplate(); - - // TODO : Use node client to update state index to PROVISIONING, given workflowId - - listener.onResponse(new WorkflowResponse(workflowId)); - - // Asychronously begin provision workflow excecution - executeWorkflowAsync(workflowId, template.workflows().get(PROVISION_WORKFLOW)); + // Retrieve use case template from global context + String workflowId = request.getWorkflowId(); + GetRequest getRequest = new GetRequest(GLOBAL_CONTEXT_INDEX, workflowId); - }, exception -> { listener.onFailure(exception); })); - } else { - // Use case template has been previously saved, retrieve entry and execute - String workflowId = request.getWorkflowId(); + // Stash thread context to interact with system index + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + client.get(getRequest, ActionListener.wrap(response -> { + context.restore(); - // TODO : Retrieve template from global context index using node client - Template template = null; // temporary, remove later + // Parse template from document source + Template template = Template.parse(response.getSourceAsString()); - // TODO : use node client to update state index entry to PROVISIONING, given workflowId + // TODO : Update state index entry to PROVISIONING, given workflowId - listener.onResponse(new WorkflowResponse(workflowId)); - executeWorkflowAsync(workflowId, template.workflows().get(PROVISION_WORKFLOW)); + // Respond to rest action then execute provisioning workflow async + listener.onResponse(new WorkflowResponse(workflowId)); + executeWorkflowAsync(workflowId, template.workflows().get(PROVISION_WORKFLOW)); + }, exception -> { + logger.error("Failed to retrieve template from global context.", exception); + listener.onFailure(exception); + })); + } catch (Exception e) { + logger.error("Failed to retrieve template from global context.", e); + listener.onFailure(e); } } diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java b/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java index 848f621a2..2b2f7338d 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java @@ -117,6 +117,16 @@ public String getName() { return NAME; } + // TODO : Move to index management class, pending implementation + /** + * Checks if the given index exists + * @param indexName the name of the index + * @return boolean indicating the existence of an index + */ + public boolean doesIndexExist(String indexName) { + return clusterService.state().metadata().hasIndex(indexName); + } + /** * Create Index if it's absent * @param index The index that needs to be created diff --git a/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java b/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java index 0d09fe239..76b2bfa6f 100644 --- a/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java +++ b/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java @@ -10,6 +10,7 @@ import org.opensearch.client.AdminClient; import org.opensearch.client.Client; +import org.opensearch.common.settings.Settings; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; @@ -24,6 +25,7 @@ public class FlowFrameworkPluginTests extends OpenSearchTestCase { private Client client; private ThreadPool threadPool; + private Settings settings; @Override public void setUp() throws Exception { @@ -31,6 +33,7 @@ public void setUp() throws Exception { client = mock(Client.class); when(client.admin()).thenReturn(mock(AdminClient.class)); threadPool = new TestThreadPool(FlowFrameworkPluginTests.class.getName()); + settings = Settings.EMPTY; } @Override @@ -41,10 +44,10 @@ public void tearDown() throws Exception { public void testPlugin() throws IOException { try (FlowFrameworkPlugin ffp = new FlowFrameworkPlugin()) { - assertEquals(2, ffp.createComponents(client, null, threadPool, null, null, null, null, null, null, null, null).size()); + assertEquals(3, ffp.createComponents(client, null, threadPool, null, null, null, null, null, null, null, null).size()); assertEquals(2, ffp.getRestHandlers(null, null, null, null, null, null, null).size()); assertEquals(2, ffp.getActions().size()); - assertEquals(1, ffp.getExecutorBuilders(null).size()); + assertEquals(1, ffp.getExecutorBuilders(settings).size()); } } } diff --git a/src/test/java/org/opensearch/flowframework/rest/RestProvisionWorkflowActionTests.java b/src/test/java/org/opensearch/flowframework/rest/RestProvisionWorkflowActionTests.java index f69dfe401..6e868f00d 100644 --- a/src/test/java/org/opensearch/flowframework/rest/RestProvisionWorkflowActionTests.java +++ b/src/test/java/org/opensearch/flowframework/rest/RestProvisionWorkflowActionTests.java @@ -25,25 +25,15 @@ public class RestProvisionWorkflowActionTests extends OpenSearchTestCase { - private String invalidTemplate; private RestProvisionWorkflowAction provisionWorkflowRestAction; - private String provisionInlineWorkflowPath; - private String provisionSavedWorkflowPath; + private String provisionWorkflowPath; private NodeClient nodeClient; @Override public void setUp() throws Exception { super.setUp(); - // Invalid template configuration, missing name field - this.invalidTemplate = "{\"description\":\"description\"," - + "\"use_case\":\"use case\"," - + "\"operations\":[\"operation\"]," - + "\"version\":{\"template\":\"1.0.0\",\"compatibility\":[\"2.0.0\",\"3.0.0\"]}," - + "\"user_inputs\":{\"userKey\":\"userValue\",\"userMapKey\":{\"nestedKey\":\"nestedValue\"}}," - + "\"workflows\":{\"workflow\":{\"user_params\":{\"key\":\"value\"},\"nodes\":[{\"id\":\"A\",\"type\":\"a-type\",\"inputs\":{\"foo\":\"bar\"}},{\"id\":\"B\",\"type\":\"b-type\",\"inputs\":{\"baz\":\"qux\"}}],\"edges\":[{\"source\":\"A\",\"dest\":\"B\"}]}}}"; this.provisionWorkflowRestAction = new RestProvisionWorkflowAction(); - this.provisionInlineWorkflowPath = String.format(Locale.ROOT, "%s/%s", WORKFLOWS_URI, "_provision"); - this.provisionSavedWorkflowPath = String.format(Locale.ROOT, "%s/{%s}/%s", WORKFLOWS_URI, "workflow_id", "_provision"); + this.provisionWorkflowPath = String.format(Locale.ROOT, "%s/{%s}/%s", WORKFLOWS_URI, "workflow_id", "_provision"); this.nodeClient = mock(NodeClient.class); } @@ -54,32 +44,30 @@ public void testRestProvisionWorkflowActionName() { public void testRestProvisiionWorkflowActionRoutes() { List routes = provisionWorkflowRestAction.routes(); - assertEquals(2, routes.size()); + assertEquals(1, routes.size()); assertEquals(RestRequest.Method.POST, routes.get(0).getMethod()); - assertEquals(RestRequest.Method.POST, routes.get(1).getMethod()); - assertEquals(this.provisionInlineWorkflowPath, routes.get(0).getPath()); - assertEquals(this.provisionSavedWorkflowPath, routes.get(1).getPath()); + assertEquals(this.provisionWorkflowPath, routes.get(0).getPath()); } public void testNullWorkflowIdAndTemplate() throws IOException { // Request with no content or params RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) - .withPath(this.provisionInlineWorkflowPath) + .withPath(this.provisionWorkflowPath) .build(); IOException ex = expectThrows(IOException.class, () -> { provisionWorkflowRestAction.prepareRequest(request, nodeClient); }); - assertEquals("workflow_id and template cannot be both null", ex.getMessage()); + assertEquals("workflow_id cannot be null", ex.getMessage()); } - public void testInvalidProvisionInlineWorkflowRequest() throws IOException { + public void testInvalidRequestWithContent() throws IOException { RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) - .withPath(this.provisionInlineWorkflowPath) - .withContent(new BytesArray(invalidTemplate), MediaTypeRegistry.JSON) + .withPath(this.provisionWorkflowPath) + .withContent(new BytesArray("request body"), MediaTypeRegistry.JSON) .build(); IOException ex = expectThrows(IOException.class, () -> { provisionWorkflowRestAction.prepareRequest(request, nodeClient); }); - assertEquals("An template object requires a name.", ex.getMessage()); + assertEquals("Invalid request format", ex.getMessage()); } }