From da29f8467baa64c4969827ce35d505633ff6cf31 Mon Sep 17 00:00:00 2001 From: Amit Galitzky Date: Mon, 6 Nov 2023 11:19:46 -0800 Subject: [PATCH] adding workflow state and create connector resources created Signed-off-by: Amit Galitzky --- .../flowframework/FlowFrameworkPlugin.java | 13 ++- .../flowframework/common/CommonValue.java | 7 +- .../indices/FlowFrameworkIndicesHandler.java | 36 +++++- .../model/ProvisioningProgress.java | 4 +- .../flowframework/model/ResourcesCreated.java | 96 ++++++++++++++++ .../flowframework/model/Template.java | 21 +++- .../flowframework/model/WorkflowState.java | 105 +++++++++++++----- .../rest/RestCreateWorkflowAction.java | 2 +- .../rest/RestGetWorkflowAction.java | 74 ++++++++++++ .../transport/CreateWorkflowAction.java | 4 +- .../transport/GetWorkflowAction.java | 27 +++++ .../transport/GetWorkflowResponse.java | 64 +++++++++++ .../transport/GetWorkflowTransportAction.java | 99 +++++++++++++++++ .../transport/ProvisionWorkflowAction.java | 4 +- .../ProvisionWorkflowTransportAction.java | 43 ++++++- .../transport/WorkflowRequest.java | 25 +++++ .../flowframework/util/ParseUtils.java | 16 +++ .../workflow/CreateConnectorStep.java | 47 +++++++- .../flowframework/workflow/WorkflowData.java | 24 +++- .../workflow/WorkflowProcessSorter.java | 5 +- .../flowframework/workflow/WorkflowStep.java | 1 - .../workflow/WorkflowStepFactory.java | 23 +++- .../resources/mappings/workflow-state.json | 10 +- .../FlowFrameworkPluginTests.java | 2 +- .../model/WorkflowValidatorTests.java | 4 +- .../workflow/CreateConnectorStepTests.java | 10 +- .../workflow/WorkflowProcessSorterTests.java | 10 +- 27 files changed, 708 insertions(+), 68 deletions(-) create mode 100644 src/main/java/org/opensearch/flowframework/model/ResourcesCreated.java create mode 100644 src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowAction.java create mode 100644 src/main/java/org/opensearch/flowframework/transport/GetWorkflowAction.java create mode 100644 src/main/java/org/opensearch/flowframework/transport/GetWorkflowResponse.java create mode 100644 src/main/java/org/opensearch/flowframework/transport/GetWorkflowTransportAction.java diff --git a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java index 907bde68b..dbe55cf91 100644 --- a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java +++ b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java @@ -26,9 +26,12 @@ import org.opensearch.env.NodeEnvironment; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.flowframework.rest.RestCreateWorkflowAction; +import org.opensearch.flowframework.rest.RestGetWorkflowAction; import org.opensearch.flowframework.rest.RestProvisionWorkflowAction; import org.opensearch.flowframework.transport.CreateWorkflowAction; import org.opensearch.flowframework.transport.CreateWorkflowTransportAction; +import org.opensearch.flowframework.transport.GetWorkflowAction; +import org.opensearch.flowframework.transport.GetWorkflowTransportAction; import org.opensearch.flowframework.transport.ProvisionWorkflowAction; import org.opensearch.flowframework.transport.ProvisionWorkflowTransportAction; import org.opensearch.flowframework.workflow.WorkflowProcessSorter; @@ -77,10 +80,9 @@ public Collection createComponents( Supplier repositoriesServiceSupplier ) { MachineLearningNodeClient mlClient = new MachineLearningNodeClient(client); - WorkflowStepFactory workflowStepFactory = new WorkflowStepFactory(clusterService, client, mlClient); - WorkflowProcessSorter workflowProcessSorter = new WorkflowProcessSorter(workflowStepFactory, threadPool); - FlowFrameworkIndicesHandler flowFrameworkIndicesHandler = new FlowFrameworkIndicesHandler(client, clusterService); + WorkflowStepFactory workflowStepFactory = new WorkflowStepFactory(clusterService, client, mlClient, flowFrameworkIndicesHandler); + WorkflowProcessSorter workflowProcessSorter = new WorkflowProcessSorter(workflowStepFactory, threadPool); return ImmutableList.of(workflowStepFactory, workflowProcessSorter, flowFrameworkIndicesHandler); } @@ -95,14 +97,15 @@ public List getRestHandlers( IndexNameExpressionResolver indexNameExpressionResolver, Supplier nodesInCluster ) { - return ImmutableList.of(new RestCreateWorkflowAction(), new RestProvisionWorkflowAction()); + return ImmutableList.of(new RestCreateWorkflowAction(), new RestProvisionWorkflowAction(), new RestGetWorkflowAction()); } @Override public List> getActions() { return ImmutableList.of( new ActionHandler<>(CreateWorkflowAction.INSTANCE, CreateWorkflowTransportAction.class), - new ActionHandler<>(ProvisionWorkflowAction.INSTANCE, ProvisionWorkflowTransportAction.class) + new ActionHandler<>(ProvisionWorkflowAction.INSTANCE, ProvisionWorkflowTransportAction.class), + new ActionHandler<>(GetWorkflowAction.INSTANCE, GetWorkflowTransportAction.class) ); } diff --git a/src/main/java/org/opensearch/flowframework/common/CommonValue.java b/src/main/java/org/opensearch/flowframework/common/CommonValue.java index ecce8ec50..4822a5f67 100644 --- a/src/main/java/org/opensearch/flowframework/common/CommonValue.java +++ b/src/main/java/org/opensearch/flowframework/common/CommonValue.java @@ -46,7 +46,7 @@ private CommonValue() {} public static final String USER_FIELD = "user"; /** The transport action name prefix */ - public static final String TRANSPORT_ACION_NAME_PREFIX = "cluster:admin/opensearch/flow_framework/"; + public static final String TRANSPORT_ACTION_NAME_PREFIX = "cluster:admin/opensearch/flow_framework/"; /** The base URI for this plugin's rest actions */ public static final String FLOW_FRAMEWORK_BASE_URI = "/_plugins/_flow_framework"; /** The URI for this plugin's workflow rest actions */ @@ -130,4 +130,9 @@ private CommonValue() {} public static final String USER_OUTPUTS_FIELD = "user_outputs"; /** The template field name for template resources created */ public static final String RESOURCES_CREATED_FIELD = "resources_created"; + /** The field name for the ResourcesCreated's resource ID */ + public static final String RESOURCE_ID_FIELD = "resource_id"; + /** The field name for the ResourcesCreated's resource name */ + public static final String RESOURCE_NAME_FIELD = "resource_type"; + } diff --git a/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java b/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java index 04a3fac5b..f98a649db 100644 --- a/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java +++ b/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java @@ -36,6 +36,7 @@ import org.opensearch.flowframework.model.State; import org.opensearch.flowframework.model.Template; import org.opensearch.flowframework.model.WorkflowState; +import org.opensearch.script.Script; import java.io.IOException; import java.net.URL; @@ -292,7 +293,7 @@ public void putInitialStateToWorkflowState(String workflowId, User user, ActionL .state(State.NOT_STARTED.name()) .provisioningProgress(ProvisioningProgress.NOT_STARTED.name()) .user(user) - .resourcesCreated(Collections.emptyMap()) + .resourcesCreated(Collections.emptyList()) .userOutputs(Collections.emptyMap()) .build(); initWorkflowStateIndexIfAbsent(ActionListener.wrap(indexCreated -> { @@ -381,4 +382,37 @@ public void updateFlowFrameworkSystemIndexDoc( } } } + + /** + * Updates a document in the workflow state index + * @param indexName the index that we will be updating a document of. + * @param documentId the document ID + * @param script the given script to update doc + * @param listener action listener + */ + public void updateFlowFrameworkSystemIndexDocWithScript( + String indexName, + String documentId, + Script script, + ActionListener listener + ) { + if (!doesIndexExist(indexName)) { + String exceptionMessage = "Failed to update document for given workflow due to missing " + indexName + " index"; + logger.error(exceptionMessage); + listener.onFailure(new Exception(exceptionMessage)); + } else { + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + UpdateRequest updateRequest = new UpdateRequest(indexName, documentId); + // TODO: Also add ability to change other fields at the same time when adding detailed provision progress + updateRequest.script(script); + updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + // TODO: decide what condition can be considered as an update conflict and add retry strategy + client.update(updateRequest, ActionListener.runBefore(listener, () -> context.restore())); + } catch (Exception e) { + logger.error("Failed to update {} entry : {}. {}", indexName, documentId, e.getMessage()); + listener.onFailure(e); + } + } + } + } diff --git a/src/main/java/org/opensearch/flowframework/model/ProvisioningProgress.java b/src/main/java/org/opensearch/flowframework/model/ProvisioningProgress.java index d5a2a5734..fe46460c7 100644 --- a/src/main/java/org/opensearch/flowframework/model/ProvisioningProgress.java +++ b/src/main/java/org/opensearch/flowframework/model/ProvisioningProgress.java @@ -18,5 +18,7 @@ public enum ProvisioningProgress { /** In Progress State */ IN_PROGRESS, /** Done State */ - DONE + DONE, + /** Failed State */ + FAILED } diff --git a/src/main/java/org/opensearch/flowframework/model/ResourcesCreated.java b/src/main/java/org/opensearch/flowframework/model/ResourcesCreated.java new file mode 100644 index 000000000..6fe9cfd82 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/model/ResourcesCreated.java @@ -0,0 +1,96 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.model; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; + +import java.io.IOException; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.flowframework.common.CommonValue.RESOURCE_ID_FIELD; +import static org.opensearch.flowframework.common.CommonValue.RESOURCE_NAME_FIELD; + +/** + * This represents an object in the WorkflowState {@link WorkflowState}. + */ +public class ResourcesCreated implements ToXContentObject, Writeable { + + private String resourceName; + private String resourceId; + + /** + * Create this resources created object with given resource name and ID. + * @param resourceName The resource name associating to the step name where it was created + * @param resourceId The resources ID for relating to the created resource + */ + public ResourcesCreated(String resourceName, String resourceId) { + this.resourceName = resourceName; + this.resourceId = resourceId; + } + + /** + * Create this resources created object with an StreamInput + * @param input the input stream to read from + * @throws IOException if failed to read input stream + */ + public ResourcesCreated(StreamInput input) throws IOException { + this.resourceName = input.readString(); + this.resourceId = input.readString(); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + XContentBuilder xContentBuilder = builder.startObject() + .field(RESOURCE_NAME_FIELD, resourceName) + .field(RESOURCE_ID_FIELD, resourceId); + return xContentBuilder.endObject(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(resourceName); + out.writeString(resourceId); + } + + /** + * Parse raw JSON content into a resourcesCreated instance. + * + * @param parser JSON based content parser + * @return the parsed ResourcesCreated instance + * @throws IOException if content can't be parsed correctly + */ + public static ResourcesCreated parse(XContentParser parser) throws IOException { + String resourceName = null; + String resourceId = null; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case RESOURCE_NAME_FIELD: + resourceName = parser.text(); + break; + case RESOURCE_ID_FIELD: + resourceId = parser.text(); + break; + default: + throw new IOException("Unable to parse field [" + fieldName + "] in a resources_created object."); + } + } + return new ResourcesCreated(resourceName, resourceId); + } + +} diff --git a/src/main/java/org/opensearch/flowframework/model/Template.java b/src/main/java/org/opensearch/flowframework/model/Template.java index a05c374d8..f143be51f 100644 --- a/src/main/java/org/opensearch/flowframework/model/Template.java +++ b/src/main/java/org/opensearch/flowframework/model/Template.java @@ -13,6 +13,8 @@ import org.opensearch.common.xcontent.json.JsonXContent; import org.opensearch.common.xcontent.yaml.YamlXContent; import org.opensearch.commons.authuser.User; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; @@ -38,7 +40,7 @@ /** * The Template is the central data structure which configures workflows. This object is used to parse JSON communicated via REST API. */ -public class Template implements ToXContentObject { +public class Template implements ToXContentObject, Writeable { private final String name; private final String description; @@ -111,6 +113,23 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws return xContentBuilder.endObject(); } + // TODO: fix writeable when implementing get workflow API + @Override + public void writeTo(StreamOutput output) throws IOException { + output.writeString(name); + output.writeOptionalString(description); + output.writeString(useCase); + output.writeVersion(templateVersion); + // output.writeList((List) compatibilityVersion); + output.writeMapWithConsistentOrder(workflows); + if (user != null) { + output.writeBoolean(true); // user exists + user.writeTo(output); + } else { + output.writeBoolean(false); // user does not exist + } + } + /** * Parse raw json content into a Template instance. * diff --git a/src/main/java/org/opensearch/flowframework/model/WorkflowState.java b/src/main/java/org/opensearch/flowframework/model/WorkflowState.java index c2b39f0ec..2703c4738 100644 --- a/src/main/java/org/opensearch/flowframework/model/WorkflowState.java +++ b/src/main/java/org/opensearch/flowframework/model/WorkflowState.java @@ -9,14 +9,23 @@ package org.opensearch.flowframework.model; import org.opensearch.commons.authuser.User; +import org.opensearch.core.common.ParsingException; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParseException; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.util.ParseUtils; import java.io.IOException; import java.time.Instant; +import java.util.ArrayList; import java.util.HashMap; +import java.util.List; import java.util.Map; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; @@ -36,18 +45,18 @@ * The WorkflowState is used to store all additional information regarding a workflow that isn't part of the * global context. */ -public class WorkflowState implements ToXContentObject { +public class WorkflowState implements ToXContentObject, Writeable { private String workflowId; private String error; private String state; - // TODO: Tranisiton the provisioning progress from a string to detailed array of objects + // TODO: Transition the provisioning progress from a string to detailed array of objects private String provisioningProgress; private Instant provisionStartTime; private Instant provisionEndTime; private User user; private Map uiMetadata; private Map userOutputs; - private Map resourcesCreated; + private List resourcesCreated; /** * Instantiate the object representing the workflow state @@ -73,7 +82,7 @@ public WorkflowState( User user, Map uiMetadata, Map userOutputs, - Map resourcesCreated + List resourcesCreated ) { this.workflowId = workflowId; this.error = error; @@ -84,11 +93,30 @@ public WorkflowState( this.user = user; this.uiMetadata = uiMetadata; this.userOutputs = Map.copyOf(userOutputs); - this.resourcesCreated = Map.copyOf(resourcesCreated); + this.resourcesCreated = List.copyOf(resourcesCreated); } private WorkflowState() {} + /** + * Instatiates a new WorkflowState from an input stream + * @param input the input stream to read from + * @throws IOException if the workflowId cannot be read from the input stream + */ + public WorkflowState(StreamInput input) throws IOException { + this.workflowId = input.readString(); + this.error = input.readOptionalString(); + this.state = input.readOptionalString(); + this.provisioningProgress = input.readOptionalString(); + this.provisionStartTime = input.readOptionalInstant(); + this.provisionEndTime = input.readOptionalInstant(); + // TODO: fix error: cannot access Response issue when integrating with access control + // this.user = input.readBoolean() ? new User(input) : null; + this.uiMetadata = input.readBoolean() ? input.readMap() : null; + this.userOutputs = input.readBoolean() ? input.readMap() : null; + this.resourcesCreated = input.readList(ResourcesCreated::new); + } + /** * Constructs a builder object for workflowState * @return Builder Object @@ -110,7 +138,7 @@ public static class Builder { private User user = null; private Map uiMetadata = null; private Map userOutputs = null; - private Map resourcesCreated = null; + private List resourcesCreated = null; /** * Empty Constructor for the Builder object @@ -212,8 +240,8 @@ public Builder userOutputs(Map userOutputs) { * @param resourcesCreated resourcesCreated * @return the Builder object */ - public Builder resourcesCreated(Map resourcesCreated) { - this.userOutputs = resourcesCreated; + public Builder resourcesCreated(List resourcesCreated) { + this.resourcesCreated = resourcesCreated; return this; } @@ -268,11 +296,41 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws xContentBuilder.field(USER_OUTPUTS_FIELD, userOutputs); } if (resourcesCreated != null && !resourcesCreated.isEmpty()) { - xContentBuilder.field(RESOURCES_CREATED_FIELD, resourcesCreated); + xContentBuilder.field(RESOURCES_CREATED_FIELD, resourcesCreated.toArray()); } return xContentBuilder.endObject(); } + @Override + public void writeTo(StreamOutput output) throws IOException { + output.writeString(workflowId); + output.writeOptionalString(error); + output.writeOptionalString(state); + output.writeOptionalString(provisioningProgress); + output.writeOptionalInstant(provisionStartTime); + output.writeOptionalInstant(provisionEndTime); + + if (user != null) { + output.writeBoolean(true); // user exists + user.writeTo(output); + } else { + output.writeBoolean(false); // user does not exist + } + if (uiMetadata != null) { + output.writeBoolean(true); + output.writeMap(uiMetadata); + } else { + output.writeBoolean(false); + } + if (userOutputs != null) { + output.writeBoolean(true); + output.writeMap(userOutputs); + } else { + output.writeBoolean(false); + } + output.writeList(resourcesCreated); + } + /** * Parse raw json content into a Template instance. * @@ -290,7 +348,7 @@ public static WorkflowState parse(XContentParser parser) throws IOException { User user = null; Map uiMetadata = null; Map userOutputs = new HashMap<>(); - Map resourcesCreated = new HashMap<>(); + List resourcesCreated = new ArrayList<>(); ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -337,23 +395,17 @@ public static WorkflowState parse(XContentParser parser) throws IOException { } } break; - case RESOURCES_CREATED_FIELD: - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); - while (parser.nextToken() != XContentParser.Token.END_OBJECT) { - String resourcesCreatedField = parser.currentName(); - switch (parser.nextToken()) { - case VALUE_STRING: - resourcesCreated.put(resourcesCreatedField, parser.text()); - break; - case START_OBJECT: - resourcesCreated.put(resourcesCreatedField, parseStringToStringMap(parser)); - break; - default: - throw new IOException( - "Unable to parse field [" + resourcesCreatedField + "] in a resources_created object." - ); + try { + ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_ARRAY) { + resourcesCreated.add(ResourcesCreated.parse(parser)); } + } catch (Exception e) { + if (e instanceof ParsingException || e instanceof XContentParseException) { + throw new FlowFrameworkException("Error parsing newly created resources", RestStatus.INTERNAL_SERVER_ERROR); + } + throw e; } break; default: @@ -449,7 +501,8 @@ public Map userOutputs() { * A map of all the resources created * @return the resources created */ - public Map resourcesCreated() { + public List resourcesCreated() { return resourcesCreated; } + } diff --git a/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java b/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java index ace440f75..b76d1adcd 100644 --- a/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java +++ b/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java @@ -32,7 +32,7 @@ public class RestCreateWorkflowAction extends BaseRestHandler { private static final String CREATE_WORKFLOW_ACTION = "create_workflow_action"; /** - * Intantiates a new RestCreateWorkflowAction + * Instantiates a new RestCreateWorkflowAction */ public RestCreateWorkflowAction() {} diff --git a/src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowAction.java b/src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowAction.java new file mode 100644 index 000000000..1f1a9295d --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowAction.java @@ -0,0 +1,74 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.rest; + +import com.google.common.collect.ImmutableList; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.client.node.NodeClient; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.transport.GetWorkflowAction; +import org.opensearch.flowframework.transport.WorkflowRequest; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; + +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_ID; +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI; + +/** + * Rest Action to facilitate requests to get a workflow status + */ +public class RestGetWorkflowAction extends BaseRestHandler { + + private static final String GET_WORKFLOW_ACTION = "get_workflow"; + private static final Logger logger = LogManager.getLogger(RestGetWorkflowAction.class); + + /** + * Instantiates a new RestGetWorkflowAction + */ + public RestGetWorkflowAction() {} + + @Override + public String getName() { + return GET_WORKFLOW_ACTION; + } + + @Override + protected BaseRestHandler.RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + // Validate content + if (request.hasContent()) { + throw new FlowFrameworkException("Invalid request format", RestStatus.BAD_REQUEST); + } + // Validate params + String workflowId = request.param(WORKFLOW_ID); + if (workflowId == null) { + throw new FlowFrameworkException("workflow_id cannot be null", RestStatus.BAD_REQUEST); + } + + String rawPath = request.rawPath(); + boolean all = request.paramAsBoolean("_all", false); + // Create request and provision + WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, null, all); + return channel -> client.execute(GetWorkflowAction.INSTANCE, workflowRequest, new RestToXContentListener<>(channel)); + } + + @Override + public List routes() { + return ImmutableList.of( + // Provision workflow from indexed use case template + new Route(RestRequest.Method.GET, String.format(Locale.ROOT, "%s/{%s}/%s", WORKFLOW_URI, WORKFLOW_ID, "_status")) + ); + } +} diff --git a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowAction.java b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowAction.java index 0f49c826f..ba9898f1f 100644 --- a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowAction.java @@ -10,7 +10,7 @@ import org.opensearch.action.ActionType; -import static org.opensearch.flowframework.common.CommonValue.TRANSPORT_ACION_NAME_PREFIX; +import static org.opensearch.flowframework.common.CommonValue.TRANSPORT_ACTION_NAME_PREFIX; /** * External Action for public facing RestCreateWorkflowActiom @@ -18,7 +18,7 @@ public class CreateWorkflowAction extends ActionType { /** The name of this action */ - public static final String NAME = TRANSPORT_ACION_NAME_PREFIX + "workflow/create"; + public static final String NAME = TRANSPORT_ACTION_NAME_PREFIX + "workflow/create"; /** An instance of this action */ public static final CreateWorkflowAction INSTANCE = new CreateWorkflowAction(); diff --git a/src/main/java/org/opensearch/flowframework/transport/GetWorkflowAction.java b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowAction.java new file mode 100644 index 000000000..6abee3867 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowAction.java @@ -0,0 +1,27 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.transport; + +import org.opensearch.action.ActionType; + +import static org.opensearch.flowframework.common.CommonValue.TRANSPORT_ACTION_NAME_PREFIX; + +/** + * External Action for public facing RestGetWorkflowAction + */ +public class GetWorkflowAction extends ActionType { + /** The name of this action */ + public static final String NAME = TRANSPORT_ACTION_NAME_PREFIX + "workflow/get"; + /** An instance of this action */ + public static final GetWorkflowAction INSTANCE = new GetWorkflowAction(); + + private GetWorkflowAction() { + super(NAME, GetWorkflowResponse::new); + } +} diff --git a/src/main/java/org/opensearch/flowframework/transport/GetWorkflowResponse.java b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowResponse.java new file mode 100644 index 000000000..9686201fa --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowResponse.java @@ -0,0 +1,64 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.transport; + +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.flowframework.model.WorkflowState; + +import java.io.IOException; + +/** + * Transport Response from getting a workflow status + */ +public class GetWorkflowResponse extends ActionResponse implements ToXContentObject { + + public WorkflowState workflowState; + public boolean allStatus; + + /** + * Instantiates a new GetWorkflowResponse from an input stream + * @param in the input stream to read from + * @throws IOException if the workflowId cannot be read from the input stream + */ + public GetWorkflowResponse(StreamInput in) throws IOException { + super(in); + workflowState = new WorkflowState(in); + allStatus = false; + } + + /** + * Instatiates a new GetWorkflowResponse from an input stream + * @param workflowState the workflow state object + * @param allStatus whether to return all fields in state index + */ + public GetWorkflowResponse(WorkflowState workflowState, boolean allStatus) { + if (allStatus) { + this.workflowState = workflowState; + } else { + this.workflowState = new WorkflowState.Builder().workflowId(workflowState.getWorkflowId()) + .error(workflowState.getError()) + .state(workflowState.getState()) + .build(); + } + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + workflowState.writeTo(out); + } + + @Override + public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params) throws IOException { + return workflowState.toXContent(xContentBuilder, params); + } +} diff --git a/src/main/java/org/opensearch/flowframework/transport/GetWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowTransportAction.java new file mode 100644 index 000000000..3471540cc --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowTransportAction.java @@ -0,0 +1,99 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.transport; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.model.WorkflowState; +import org.opensearch.flowframework.util.ParseUtils; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; + +//TODO: Currently we only get the workflow status but we should change to be able to get the +// full template as well +/** + * Transport Action to get status of a current workflow + */ +public class GetWorkflowTransportAction extends HandledTransportAction { + + private final Logger logger = LogManager.getLogger(GetWorkflowTransportAction.class); + + private final Client client; + private final NamedXContentRegistry xContentRegistry; + + /** + * Intantiates a new CreateWorkflowTransportAction + * @param transportService the TransportService + * @param actionFilters action filters + * @param client The client used to make the request to OS + * @param xContentRegistry contentRegister to parse get response + */ + @Inject + public GetWorkflowTransportAction( + TransportService transportService, + ActionFilters actionFilters, + Client client, + NamedXContentRegistry xContentRegistry + ) { + super(GetWorkflowAction.NAME, transportService, actionFilters, WorkflowRequest::new); + this.client = client; + this.xContentRegistry = xContentRegistry; + } + + @Override + protected void doExecute(Task task, WorkflowRequest request, ActionListener listener) { + String workflowId = request.getWorkflowId(); + User user = ParseUtils.getUserContext(client); + GetRequest getRequest = new GetRequest(WORKFLOW_STATE_INDEX).id(workflowId); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + client.get(getRequest, ActionListener.runBefore(ActionListener.wrap(r -> { + logger.debug("Completed Get Workflow Status Request, id:{}", workflowId); + + if (r != null && r.isExists()) { + try (XContentParser parser = ParseUtils.createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + WorkflowState workflowState = WorkflowState.parse(parser); + listener.onResponse(new GetWorkflowResponse(workflowState, request.getAll())); + } catch (Exception e) { + logger.error("Failed to parse workflowState" + r.getId(), e); + listener.onFailure(e); + } + } else { + listener.onFailure(new FlowFrameworkException("Fail to find workflow", RestStatus.BAD_REQUEST)); + } + }, e -> { + if (e instanceof IndexNotFoundException) { + listener.onFailure(new FlowFrameworkException("Fail to find workflow", RestStatus.BAD_REQUEST)); + } else { + logger.error("Failed to get workflow status of " + workflowId, e); + listener.onFailure(e); + } + }), () -> context.restore())); + } catch (Exception e) { + logger.error("Failed to get workflow: " + workflowId, e); + listener.onFailure(e); + } + } +} diff --git a/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowAction.java b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowAction.java index 022e73488..30f9f1437 100644 --- a/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowAction.java @@ -10,14 +10,14 @@ import org.opensearch.action.ActionType; -import static org.opensearch.flowframework.common.CommonValue.TRANSPORT_ACION_NAME_PREFIX; +import static org.opensearch.flowframework.common.CommonValue.TRANSPORT_ACTION_NAME_PREFIX; /** * External Action for public facing RestProvisionWorkflowAction */ public class ProvisionWorkflowAction extends ActionType { /** The name of this action */ - public static final String NAME = TRANSPORT_ACION_NAME_PREFIX + "workflow/provision"; + public static final String NAME = TRANSPORT_ACTION_NAME_PREFIX + "workflow/provision"; /** An instance of this action */ public static final ProvisionWorkflowAction INSTANCE = new ProvisionWorkflowAction(); diff --git a/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java index 22ac414e5..e6f171b50 100644 --- a/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java @@ -38,8 +38,10 @@ import java.util.concurrent.CompletableFuture; import java.util.stream.Collectors; +import static org.opensearch.flowframework.common.CommonValue.ERROR_FIELD; import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX; import static org.opensearch.flowframework.common.CommonValue.PROVISIONING_PROGRESS_FIELD; +import static org.opensearch.flowframework.common.CommonValue.PROVISION_END_TIME_FIELD; import static org.opensearch.flowframework.common.CommonValue.PROVISION_START_TIME_FIELD; import static org.opensearch.flowframework.common.CommonValue.PROVISION_THREAD_POOL; import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW; @@ -106,10 +108,11 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener provisionProcessSequence = workflowProcessSorter.sortProcessNodes(provisionWorkflow); + List provisionProcessSequence = workflowProcessSorter.sortProcessNodes(provisionWorkflow, workflowId); workflowProcessSorter.validateGraph(provisionProcessSequence); flowFrameworkIndicesHandler.updateFlowFrameworkSystemIndexDoc( @@ -155,13 +158,42 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener workflowSequence) { // TODO : Update Action listener type to State index Request ActionListener provisionWorkflowListener = ActionListener.wrap(response -> { - logger.info("Provisioning completed successuflly for workflow {}", workflowId); + logger.info("Provisioning completed successfully for workflow {}", workflowId); + flowFrameworkIndicesHandler.updateFlowFrameworkSystemIndexDoc( + WORKFLOW_STATE_INDEX, + workflowId, + ImmutableMap.of( + STATE_FIELD, + State.COMPLETED, + PROVISIONING_PROGRESS_FIELD, + ProvisioningProgress.DONE, + PROVISION_END_TIME_FIELD, + Instant.now().toEpochMilli() + ), + ActionListener.wrap(updateResponse -> { + logger.info("updated workflow {} state to {}", workflowId, State.COMPLETED); + }, exception -> { logger.error("Failed to update workflow state : {}", exception.getMessage()); }) + ); - // TODO : Create State index request to update STATE entry status to READY }, exception -> { logger.error("Provisioning failed for workflow {} : {}", workflowId, exception); - - // TODO : Create State index request to update STATE entry status to FAILED + flowFrameworkIndicesHandler.updateFlowFrameworkSystemIndexDoc( + WORKFLOW_STATE_INDEX, + workflowId, + ImmutableMap.of( + STATE_FIELD, + State.FAILED, + ERROR_FIELD, + "failed provision", // TODO: improve the error message here + PROVISIONING_PROGRESS_FIELD, + ProvisioningProgress.FAILED, + PROVISION_END_TIME_FIELD, + Instant.now().toEpochMilli() + ), + ActionListener.wrap(updateResponse -> { + logger.info("updated workflow {} state to {}", workflowId, State.COMPLETED); + }, exceptionState -> { logger.error("Failed to update workflow state : {}", exceptionState.getMessage()); }) + ); }); try { threadPool.executor(PROVISION_THREAD_POOL).execute(() -> { executeWorkflow(workflowSequence, provisionWorkflowListener); }); @@ -200,7 +232,6 @@ private void executeWorkflow(List workflowSequence, ActionListener< // Attempt to join each workflow step future, may throw a CompletionException if any step completes exceptionally workflowFutureList.forEach(CompletableFuture::join); - // TODO : Create State Index request with provisioning state, start time, end time, etc, pending implementation. String for now workflowListener.onResponse("READY"); } catch (IllegalArgumentException e) { diff --git a/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java b/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java index 0b105552f..da316250e 100644 --- a/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java +++ b/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java @@ -32,6 +32,10 @@ public class WorkflowRequest extends ActionRequest { */ @Nullable private Template template; + /** + * The all parameter on the get request + */ + private boolean all; /** * Instantiates a new WorkflowRequest @@ -41,6 +45,19 @@ public class WorkflowRequest extends ActionRequest { public WorkflowRequest(@Nullable String workflowId, @Nullable Template template) { this.workflowId = workflowId; this.template = template; + this.all = false; + } + + /** + * Instantiates a new WorkflowRequest + * @param workflowId the documentId of the workflow + * @param template the use case template which describes the workflow + * @param all whether the get request is looking for all fields in status + */ + public WorkflowRequest(@Nullable String workflowId, @Nullable Template template, boolean all) { + this.workflowId = workflowId; + this.template = template; + this.all = all; } /** @@ -73,6 +90,14 @@ public Template getTemplate() { return this.template; } + /** + * Gets the value of the all parameter + * @return whether the all parameter was present or not in request + */ + public boolean getAll() { + return this.all; + } + @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); diff --git a/src/main/java/org/opensearch/flowframework/util/ParseUtils.java b/src/main/java/org/opensearch/flowframework/util/ParseUtils.java index 338f23cdc..0f725f687 100644 --- a/src/main/java/org/opensearch/flowframework/util/ParseUtils.java +++ b/src/main/java/org/opensearch/flowframework/util/ParseUtils.java @@ -12,9 +12,12 @@ import org.apache.logging.log4j.Logger; import org.opensearch.client.Client; import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentHelper; +import org.opensearch.common.xcontent.XContentType; import org.opensearch.common.xcontent.json.JsonXContent; import org.opensearch.commons.ConfigConstants; import org.opensearch.commons.authuser.User; +import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; @@ -113,4 +116,17 @@ public static User getUserContext(Client client) { return User.parse(userStr); } + /** + * Creates a XContentParser from a given Registry + * + * @param xContentRegistry main registry for serializable content + * @param bytesReference given bytes to be parsed + * @return bytesReference of {@link java.time.Instant} + * @throws IOException IOException if content can't be parsed correctly + */ + public static XContentParser createXContentParserFromRegistry(NamedXContentRegistry xContentRegistry, BytesReference bytesReference) + throws IOException { + return XContentHelper.createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, bytesReference, XContentType.JSON); + } + } diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java b/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java index 533d82c1e..5ca0befae 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java @@ -11,14 +11,21 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.ExceptionsHelper; +import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; +import org.opensearch.flowframework.model.ResourcesCreated; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.connector.ConnectorAction; import org.opensearch.ml.common.connector.ConnectorAction.ActionType; import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput; import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse; +import org.opensearch.script.Script; +import org.opensearch.script.ScriptType; import java.io.IOException; import java.security.AccessController; @@ -41,6 +48,7 @@ import static org.opensearch.flowframework.common.CommonValue.PARAMETERS_FIELD; import static org.opensearch.flowframework.common.CommonValue.PROTOCOL_FIELD; import static org.opensearch.flowframework.common.CommonValue.VERSION_FIELD; +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; /** * Step to create a connector for a remote model @@ -50,15 +58,18 @@ public class CreateConnectorStep implements WorkflowStep { private static final Logger logger = LogManager.getLogger(CreateConnectorStep.class); private MachineLearningNodeClient mlClient; + private final FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; static final String NAME = "create_connector"; /** * Instantiate this class * @param mlClient client to instantiate MLClient + * @param flowFrameworkIndicesHandler FlowFrameworkIndicesHandler class to update system indices */ - public CreateConnectorStep(MachineLearningNodeClient mlClient) { + public CreateConnectorStep(MachineLearningNodeClient mlClient, FlowFrameworkIndicesHandler flowFrameworkIndicesHandler) { this.mlClient = mlClient; + this.flowFrameworkIndicesHandler = flowFrameworkIndicesHandler; } @Override @@ -69,8 +80,38 @@ public CompletableFuture execute(List data) throws I @Override public void onResponse(MLCreateConnectorResponse mlCreateConnectorResponse) { - logger.info("Created connector successfully"); - // TODO Add the response to Global Context + try { + logger.info("Created connector successfully"); + String workflowId = data.get(0).getWorkflowId(); + String workflowStepName = getName(); + ResourcesCreated newResource = new ResourcesCreated(workflowStepName, mlCreateConnectorResponse.getConnectorId()); + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + newResource.toXContent(builder, ToXContentObject.EMPTY_PARAMS); + + // The script to append a new object to the resources_created array + Script script = new Script( + ScriptType.INLINE, + "painless", + "ctx._source.resources_created.add(params.newResource)", + Collections.singletonMap("newResource", newResource) + ); + + flowFrameworkIndicesHandler.updateFlowFrameworkSystemIndexDocWithScript( + WORKFLOW_STATE_INDEX, + workflowId, + script, + ActionListener.wrap(updateResponse -> { + logger.info("updated resources craeted of {}", workflowId); + }, + exception -> { + logger.error("Failed to update workflow state with newly created resource: {}", exception.getMessage()); + } + ) + ); + } catch (IOException e) { + logger.error("Failed to parse new created resource", e); + } + createConnectorFuture.complete( new WorkflowData(Map.ofEntries(Map.entry("connector_id", mlCreateConnectorResponse.getConnectorId()))) ); diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowData.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowData.java index 35ffb7e75..83cd33f90 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowData.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowData.java @@ -23,9 +23,10 @@ public class WorkflowData { private final Map content; private final Map params; + private String workflowId; private WorkflowData() { - this(Collections.emptyMap(), Collections.emptyMap()); + this(Collections.emptyMap(), Collections.emptyMap(), ""); } /** @@ -44,6 +45,19 @@ public WorkflowData(Map content) { public WorkflowData(Map content, Map params) { this.content = Map.copyOf(content); this.params = Map.copyOf(params); + this.workflowId = ""; + } + + /** + * Instantiate this object with content and params. + * @param content The content map + * @param params The params map + * @param workflowId The workflow ID associated with this step + */ + public WorkflowData(Map content, Map params, String workflowId) { + this.content = Map.copyOf(content); + this.params = Map.copyOf(params); + this.workflowId = workflowId; } /** @@ -62,4 +76,12 @@ public Map getContent() { public Map getParams() { return this.params; }; + + /** + * Returns the workflowId associated with this workflow. + * @return the workflowId of this data. + */ + public String getWorkflowId() { + return this.workflowId; + }; } diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java index 745de5921..18ca8a53e 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java @@ -59,16 +59,17 @@ public WorkflowProcessSorter(WorkflowStepFactory workflowStepFactory, ThreadPool /** * Sort a workflow into a topologically sorted list of process nodes. * @param workflow A workflow with (unsorted) nodes and edges which define predecessors and successors + * @param workflowId The workflowId associated with the step * @return A list of Process Nodes sorted topologically. All predecessors of any node will occur prior to it in the list. */ - public List sortProcessNodes(Workflow workflow) { + public List sortProcessNodes(Workflow workflow, String workflowId) { List sortedNodes = topologicalSort(workflow.nodes(), workflow.edges()); List nodes = new ArrayList<>(); Map idToNodeMap = new HashMap<>(); for (WorkflowNode node : sortedNodes) { WorkflowStep step = workflowStepFactory.createStep(node.type()); - WorkflowData data = new WorkflowData(node.userInputs(), workflow.userParams()); + WorkflowData data = new WorkflowData(node.userInputs(), workflow.userParams(), workflowId); List predecessorNodes = workflow.edges() .stream() .filter(e -> e.destination().equals(node.id())) diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStep.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStep.java index bdbe63caa..1f5545cdf 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStep.java @@ -30,5 +30,4 @@ public interface WorkflowStep { * @return the name of this workflow step. */ String getName(); - } diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java index c30bdf87c..22b5f07c3 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java @@ -12,6 +12,7 @@ import org.opensearch.cluster.service.ClusterService; import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.ml.client.MachineLearningNodeClient; import java.util.HashMap; @@ -23,6 +24,7 @@ public class WorkflowStepFactory { private final Map stepMap = new HashMap<>(); + private final FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; /** * Instantiate this class. @@ -30,19 +32,30 @@ public class WorkflowStepFactory { * @param clusterService The OpenSearch cluster service * @param client The OpenSearch client steps can use * @param mlClient Machine Learning client to perform ml operations + * @param flowFrameworkIndicesHandler FlowFrameworkIndicesHandler class to update system indices */ - - public WorkflowStepFactory(ClusterService clusterService, Client client, MachineLearningNodeClient mlClient) { - populateMap(clusterService, client, mlClient); + public WorkflowStepFactory( + ClusterService clusterService, + Client client, + MachineLearningNodeClient mlClient, + FlowFrameworkIndicesHandler flowFrameworkIndicesHandler + ) { + this.flowFrameworkIndicesHandler = flowFrameworkIndicesHandler; + populateMap(clusterService, client, mlClient, flowFrameworkIndicesHandler); } - private void populateMap(ClusterService clusterService, Client client, MachineLearningNodeClient mlClient) { + private void populateMap( + ClusterService clusterService, + Client client, + MachineLearningNodeClient mlClient, + FlowFrameworkIndicesHandler flowFrameworkIndicesHandler + ) { stepMap.put(NoOpStep.NAME, new NoOpStep()); stepMap.put(CreateIndexStep.NAME, new CreateIndexStep(clusterService, client)); stepMap.put(CreateIngestPipelineStep.NAME, new CreateIngestPipelineStep(client)); stepMap.put(RegisterModelStep.NAME, new RegisterModelStep(mlClient)); stepMap.put(DeployModelStep.NAME, new DeployModelStep(mlClient)); - stepMap.put(CreateConnectorStep.NAME, new CreateConnectorStep(mlClient)); + stepMap.put(CreateConnectorStep.NAME, new CreateConnectorStep(mlClient, flowFrameworkIndicesHandler)); stepMap.put(ModelGroupStep.NAME, new ModelGroupStep(mlClient)); } diff --git a/src/main/resources/mappings/workflow-state.json b/src/main/resources/mappings/workflow-state.json index 86fbeef6e..21df5ccd6 100644 --- a/src/main/resources/mappings/workflow-state.json +++ b/src/main/resources/mappings/workflow-state.json @@ -31,7 +31,15 @@ "type": "object" }, "resources_created": { - "type": "object" + "type": "nested", + "properties": { + "workflow_step_name": { + "type": "keyword" + }, + "resource_id": { + "type": "keyword" + } + } }, "ui_metadata": { "type": "object", diff --git a/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java b/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java index cbbc6d7fe..afb3a7469 100644 --- a/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java +++ b/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java @@ -54,7 +54,7 @@ public void testPlugin() throws IOException { try (FlowFrameworkPlugin ffp = new FlowFrameworkPlugin()) { assertEquals(3, ffp.createComponents(client, null, threadPool, null, null, null, null, null, null, null, null).size()); assertEquals(2, ffp.getRestHandlers(null, null, null, null, null, null, null).size()); - assertEquals(2, ffp.getActions().size()); + assertEquals(3, ffp.getActions().size()); assertEquals(1, ffp.getExecutorBuilders(settings).size()); } } diff --git a/src/test/java/org/opensearch/flowframework/model/WorkflowValidatorTests.java b/src/test/java/org/opensearch/flowframework/model/WorkflowValidatorTests.java index 6c474a11e..28e7e0585 100644 --- a/src/test/java/org/opensearch/flowframework/model/WorkflowValidatorTests.java +++ b/src/test/java/org/opensearch/flowframework/model/WorkflowValidatorTests.java @@ -13,6 +13,7 @@ import org.opensearch.client.ClusterAdminClient; import org.opensearch.cluster.service.ClusterService; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.flowframework.workflow.WorkflowStepFactory; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.test.OpenSearchTestCase; @@ -67,8 +68,9 @@ public void testWorkflowStepFactoryHasValidators() throws IOException { when(client.admin()).thenReturn(adminClient); when(adminClient.cluster()).thenReturn(clusterAdminClient); MachineLearningNodeClient mlClient = mock(MachineLearningNodeClient.class); + FlowFrameworkIndicesHandler flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class); - WorkflowStepFactory workflowStepFactory = new WorkflowStepFactory(clusterService, client, mlClient); + WorkflowStepFactory workflowStepFactory = new WorkflowStepFactory(clusterService, client, mlClient, flowFrameworkIndicesHandler); // Read in workflow-steps.json WorkflowValidator workflowValidator = WorkflowValidator.parse("mappings/workflow-steps.json"); diff --git a/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java index 63855f7bd..111b787f8 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java @@ -12,6 +12,7 @@ import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.common.CommonValue; import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.connector.ConnectorAction; import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput; @@ -30,6 +31,7 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; public class CreateConnectorStepTests extends OpenSearchTestCase { @@ -38,10 +40,12 @@ public class CreateConnectorStepTests extends OpenSearchTestCase { @Mock MachineLearningNodeClient machineLearningNodeClient; + private FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; + @Override public void setUp() throws Exception { super.setUp(); - + this.flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class); Map params = Map.ofEntries(Map.entry("endpoint", "endpoint"), Map.entry("temp", "7")); Map credentials = Map.ofEntries(Map.entry("key1", "value1"), Map.entry("key2", "value2")); Map[] actions = new Map[] { @@ -73,7 +77,7 @@ public void setUp() throws Exception { public void testCreateConnector() throws IOException, ExecutionException, InterruptedException { String connectorId = "connect"; - CreateConnectorStep createConnectorStep = new CreateConnectorStep(machineLearningNodeClient); + CreateConnectorStep createConnectorStep = new CreateConnectorStep(machineLearningNodeClient, flowFrameworkIndicesHandler); @SuppressWarnings("unchecked") ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); @@ -95,7 +99,7 @@ public void testCreateConnector() throws IOException, ExecutionException, Interr } public void testCreateConnectorFailure() throws IOException { - CreateConnectorStep createConnectorStep = new CreateConnectorStep(machineLearningNodeClient); + CreateConnectorStep createConnectorStep = new CreateConnectorStep(machineLearningNodeClient, flowFrameworkIndicesHandler); @SuppressWarnings("unchecked") ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); diff --git a/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java b/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java index 9f629ff9e..1840259ad 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java @@ -14,6 +14,7 @@ import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.flowframework.model.TemplateTestJsonUtil; import org.opensearch.flowframework.model.Workflow; import org.opensearch.flowframework.model.WorkflowEdge; @@ -50,7 +51,7 @@ public class WorkflowProcessSorterTests extends OpenSearchTestCase { private static List parseToNodes(String json) throws IOException { XContentParser parser = TemplateTestJsonUtil.jsonToParser(json); Workflow w = Workflow.parse(parser); - return workflowProcessSorter.sortProcessNodes(w); + return workflowProcessSorter.sortProcessNodes(w, "123"); } // Wrap parser into string list @@ -67,11 +68,12 @@ public static void setup() { ClusterService clusterService = mock(ClusterService.class); Client client = mock(Client.class); MachineLearningNodeClient mlClient = mock(MachineLearningNodeClient.class); + FlowFrameworkIndicesHandler flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class); when(client.admin()).thenReturn(adminClient); testThreadPool = new TestThreadPool(WorkflowProcessSorterTests.class.getName()); - WorkflowStepFactory factory = new WorkflowStepFactory(clusterService, client, mlClient); + WorkflowStepFactory factory = new WorkflowStepFactory(clusterService, client, mlClient, flowFrameworkIndicesHandler); workflowProcessSorter = new WorkflowProcessSorter(factory, testThreadPool); } @@ -245,7 +247,7 @@ public void testSuccessfulGraphValidation() throws Exception { Workflow workflow = new Workflow(Map.of(), List.of(createConnector, registerModel, deployModel), List.of(edge1, edge2)); - List sortedProcessNodes = workflowProcessSorter.sortProcessNodes(workflow); + List sortedProcessNodes = workflowProcessSorter.sortProcessNodes(workflow, "123"); workflowProcessSorter.validateGraph(sortedProcessNodes); } @@ -267,7 +269,7 @@ public void testFailedGraphValidation() { WorkflowEdge edge = new WorkflowEdge(registerModel.id(), deployModel.id()); Workflow workflow = new Workflow(Map.of(), List.of(registerModel, deployModel), List.of(edge)); - List sortedProcessNodes = workflowProcessSorter.sortProcessNodes(workflow); + List sortedProcessNodes = workflowProcessSorter.sortProcessNodes(workflow, "123"); IllegalArgumentException ex = expectThrows( IllegalArgumentException.class, () -> workflowProcessSorter.validateGraph(sortedProcessNodes)