From 8b097161319499b029887765cb0813548796d3c7 Mon Sep 17 00:00:00 2001 From: Amit Galitzky Date: Mon, 6 Nov 2023 11:19:46 -0800 Subject: [PATCH] adding workflow state and create connector resources created Signed-off-by: Amit Galitzky --- .../flowframework/FlowFrameworkPlugin.java | 5 +- .../flowframework/common/CommonValue.java | 7 +- .../indices/FlowFrameworkIndicesHandler.java | 36 +++++- .../model/ProvisioningProgress.java | 4 +- .../flowframework/model/ResourcesCreated.java | 76 +++++++++++++ .../flowframework/model/Template.java | 32 +++++- .../flowframework/model/WorkflowState.java | 103 ++++++++++++----- .../rest/RestGetWorkflowAction.java | 87 ++++++++++++++ .../transport/CreateWorkflowAction.java | 4 +- .../transport/GetWorkflowAction.java | 27 +++++ .../transport/GetWorkflowResponse.java | 106 ++++++++++++++++++ .../transport/GetWorkflowResponseTest.java | 62 ++++++++++ .../transport/GetWorkflowTransportAction.java | 87 ++++++++++++++ .../transport/ProvisionWorkflowAction.java | 4 +- .../ProvisionWorkflowTransportAction.java | 43 ++++++- .../flowframework/util/ParseUtils.java | 8 ++ .../workflow/CreateConnectorStep.java | 49 +++++++- .../flowframework/workflow/WorkflowData.java | 24 +++- .../workflow/WorkflowProcessSorter.java | 4 +- .../flowframework/workflow/WorkflowStep.java | 1 - .../workflow/WorkflowStepFactory.java | 21 +++- .../resources/mappings/workflow-state.json | 10 +- 22 files changed, 745 insertions(+), 55 deletions(-) create mode 100644 src/main/java/org/opensearch/flowframework/model/ResourcesCreated.java create mode 100644 src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowAction.java create mode 100644 src/main/java/org/opensearch/flowframework/transport/GetWorkflowAction.java create mode 100644 src/main/java/org/opensearch/flowframework/transport/GetWorkflowResponse.java create mode 100644 src/main/java/org/opensearch/flowframework/transport/GetWorkflowResponseTest.java create mode 100644 src/main/java/org/opensearch/flowframework/transport/GetWorkflowTransportAction.java diff --git a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java index 907bde68b..eb5b07883 100644 --- a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java +++ b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java @@ -77,10 +77,9 @@ public Collection createComponents( Supplier repositoriesServiceSupplier ) { MachineLearningNodeClient mlClient = new MachineLearningNodeClient(client); - WorkflowStepFactory workflowStepFactory = new WorkflowStepFactory(clusterService, client, mlClient); - WorkflowProcessSorter workflowProcessSorter = new WorkflowProcessSorter(workflowStepFactory, threadPool); - FlowFrameworkIndicesHandler flowFrameworkIndicesHandler = new FlowFrameworkIndicesHandler(client, clusterService); + WorkflowStepFactory workflowStepFactory = new WorkflowStepFactory(clusterService, client, mlClient, flowFrameworkIndicesHandler); + WorkflowProcessSorter workflowProcessSorter = new WorkflowProcessSorter(workflowStepFactory, threadPool); return ImmutableList.of(workflowStepFactory, workflowProcessSorter, flowFrameworkIndicesHandler); } diff --git a/src/main/java/org/opensearch/flowframework/common/CommonValue.java b/src/main/java/org/opensearch/flowframework/common/CommonValue.java index ecce8ec50..4822a5f67 100644 --- a/src/main/java/org/opensearch/flowframework/common/CommonValue.java +++ b/src/main/java/org/opensearch/flowframework/common/CommonValue.java @@ -46,7 +46,7 @@ private CommonValue() {} public static final String USER_FIELD = "user"; /** The transport action name prefix */ - public static final String TRANSPORT_ACION_NAME_PREFIX = "cluster:admin/opensearch/flow_framework/"; + public static final String TRANSPORT_ACTION_NAME_PREFIX = "cluster:admin/opensearch/flow_framework/"; /** The base URI for this plugin's rest actions */ public static final String FLOW_FRAMEWORK_BASE_URI = "/_plugins/_flow_framework"; /** The URI for this plugin's workflow rest actions */ @@ -130,4 +130,9 @@ private CommonValue() {} public static final String USER_OUTPUTS_FIELD = "user_outputs"; /** The template field name for template resources created */ public static final String RESOURCES_CREATED_FIELD = "resources_created"; + /** The field name for the ResourcesCreated's resource ID */ + public static final String RESOURCE_ID_FIELD = "resource_id"; + /** The field name for the ResourcesCreated's resource name */ + public static final String RESOURCE_NAME_FIELD = "resource_type"; + } diff --git a/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java b/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java index 04a3fac5b..f98a649db 100644 --- a/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java +++ b/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java @@ -36,6 +36,7 @@ import org.opensearch.flowframework.model.State; import org.opensearch.flowframework.model.Template; import org.opensearch.flowframework.model.WorkflowState; +import org.opensearch.script.Script; import java.io.IOException; import java.net.URL; @@ -292,7 +293,7 @@ public void putInitialStateToWorkflowState(String workflowId, User user, ActionL .state(State.NOT_STARTED.name()) .provisioningProgress(ProvisioningProgress.NOT_STARTED.name()) .user(user) - .resourcesCreated(Collections.emptyMap()) + .resourcesCreated(Collections.emptyList()) .userOutputs(Collections.emptyMap()) .build(); initWorkflowStateIndexIfAbsent(ActionListener.wrap(indexCreated -> { @@ -381,4 +382,37 @@ public void updateFlowFrameworkSystemIndexDoc( } } } + + /** + * Updates a document in the workflow state index + * @param indexName the index that we will be updating a document of. + * @param documentId the document ID + * @param script the given script to update doc + * @param listener action listener + */ + public void updateFlowFrameworkSystemIndexDocWithScript( + String indexName, + String documentId, + Script script, + ActionListener listener + ) { + if (!doesIndexExist(indexName)) { + String exceptionMessage = "Failed to update document for given workflow due to missing " + indexName + " index"; + logger.error(exceptionMessage); + listener.onFailure(new Exception(exceptionMessage)); + } else { + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + UpdateRequest updateRequest = new UpdateRequest(indexName, documentId); + // TODO: Also add ability to change other fields at the same time when adding detailed provision progress + updateRequest.script(script); + updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + // TODO: decide what condition can be considered as an update conflict and add retry strategy + client.update(updateRequest, ActionListener.runBefore(listener, () -> context.restore())); + } catch (Exception e) { + logger.error("Failed to update {} entry : {}. {}", indexName, documentId, e.getMessage()); + listener.onFailure(e); + } + } + } + } diff --git a/src/main/java/org/opensearch/flowframework/model/ProvisioningProgress.java b/src/main/java/org/opensearch/flowframework/model/ProvisioningProgress.java index d5a2a5734..fe46460c7 100644 --- a/src/main/java/org/opensearch/flowframework/model/ProvisioningProgress.java +++ b/src/main/java/org/opensearch/flowframework/model/ProvisioningProgress.java @@ -18,5 +18,7 @@ public enum ProvisioningProgress { /** In Progress State */ IN_PROGRESS, /** Done State */ - DONE + DONE, + /** Failed State */ + FAILED } diff --git a/src/main/java/org/opensearch/flowframework/model/ResourcesCreated.java b/src/main/java/org/opensearch/flowframework/model/ResourcesCreated.java new file mode 100644 index 000000000..9bcd200f4 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/model/ResourcesCreated.java @@ -0,0 +1,76 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.model; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; + +import java.io.IOException; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.flowframework.common.CommonValue.RESOURCE_ID_FIELD; +import static org.opensearch.flowframework.common.CommonValue.RESOURCE_NAME_FIELD; + +public class ResourcesCreated implements ToXContentObject, Writeable { + + public String resourceName; + public String resourceId; + + public ResourcesCreated(String resourceName, String resourceId) { + this.resourceName = resourceName; + this.resourceId = resourceId; + } + + public ResourcesCreated(StreamInput input) throws IOException { + this.resourceName = input.readString(); + this.resourceId = input.readString(); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + XContentBuilder xContentBuilder = builder.startObject() + .field(RESOURCE_NAME_FIELD, resourceName) + .field(RESOURCE_ID_FIELD, resourceId); + return xContentBuilder.endObject(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(resourceName); + out.writeString(resourceId); + } + + public static ResourcesCreated parse(XContentParser parser) throws IOException { + String resourceName = null; + String resourceId = null; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case RESOURCE_NAME_FIELD: + resourceName = parser.text(); + break; + case RESOURCE_ID_FIELD: + resourceId = parser.text(); + break; + default: + throw new IOException("Unable to parse field [" + fieldName + "] in a resources_created object."); + } + } + return new ResourcesCreated(resourceName, resourceId); + } + +} diff --git a/src/main/java/org/opensearch/flowframework/model/Template.java b/src/main/java/org/opensearch/flowframework/model/Template.java index a05c374d8..4c2e9d68d 100644 --- a/src/main/java/org/opensearch/flowframework/model/Template.java +++ b/src/main/java/org/opensearch/flowframework/model/Template.java @@ -13,6 +13,8 @@ import org.opensearch.common.xcontent.json.JsonXContent; import org.opensearch.common.xcontent.yaml.YamlXContent; import org.opensearch.commons.authuser.User; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; @@ -38,7 +40,7 @@ /** * The Template is the central data structure which configures workflows. This object is used to parse JSON communicated via REST API. */ -public class Template implements ToXContentObject { +public class Template implements ToXContentObject, Writeable { private final String name; private final String description; @@ -77,6 +79,18 @@ public Template( this.user = user; } + // public Template(StreamInput input) throws IOException { + // this.name = input.readString(); + // this.description = input.readOptionalString(); + // this.useCase = input.readString(); + // this.templateVersion = input.readVersion(); + // this.compatibilityVersion = input.readList(Version::new); // Replace with actual method if different + // this.workflows = input.readMap(StreamInput::readString, WorkFlow::new); // Replace with the actual function to read WorkFlow objects + // if (input.readBoolean()) { + // this.user = new User(input); // Replace with the actual constructor or factory method for User + // } + // } + @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { XContentBuilder xContentBuilder = builder.startObject(); @@ -111,6 +125,22 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws return xContentBuilder.endObject(); } + @Override + public void writeTo(StreamOutput output) throws IOException { + output.writeString(name); + output.writeOptionalString(description); + output.writeString(useCase); + output.writeVersion(templateVersion); + output.writeList((List) compatibilityVersion); + output.writeMapWithConsistentOrder(workflows); + if (user != null) { + output.writeBoolean(true); // user exists + user.writeTo(output); + } else { + output.writeBoolean(false); // user does not exist + } + } + /** * Parse raw json content into a Template instance. * diff --git a/src/main/java/org/opensearch/flowframework/model/WorkflowState.java b/src/main/java/org/opensearch/flowframework/model/WorkflowState.java index c2b39f0ec..5d589983d 100644 --- a/src/main/java/org/opensearch/flowframework/model/WorkflowState.java +++ b/src/main/java/org/opensearch/flowframework/model/WorkflowState.java @@ -9,14 +9,23 @@ package org.opensearch.flowframework.model; import org.opensearch.commons.authuser.User; +import org.opensearch.core.common.ParsingException; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParseException; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.util.ParseUtils; import java.io.IOException; import java.time.Instant; +import java.util.ArrayList; import java.util.HashMap; +import java.util.List; import java.util.Map; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; @@ -36,18 +45,18 @@ * The WorkflowState is used to store all additional information regarding a workflow that isn't part of the * global context. */ -public class WorkflowState implements ToXContentObject { +public class WorkflowState implements ToXContentObject, Writeable { private String workflowId; private String error; private String state; - // TODO: Tranisiton the provisioning progress from a string to detailed array of objects + // TODO: Transition the provisioning progress from a string to detailed array of objects private String provisioningProgress; private Instant provisionStartTime; private Instant provisionEndTime; private User user; private Map uiMetadata; private Map userOutputs; - private Map resourcesCreated; + private List resourcesCreated; /** * Instantiate the object representing the workflow state @@ -73,7 +82,7 @@ public WorkflowState( User user, Map uiMetadata, Map userOutputs, - Map resourcesCreated + List resourcesCreated ) { this.workflowId = workflowId; this.error = error; @@ -84,11 +93,24 @@ public WorkflowState( this.user = user; this.uiMetadata = uiMetadata; this.userOutputs = Map.copyOf(userOutputs); - this.resourcesCreated = Map.copyOf(resourcesCreated); + this.resourcesCreated = List.copyOf(resourcesCreated); } private WorkflowState() {} + public WorkflowState(StreamInput input) throws IOException { + this.workflowId = input.readString(); + this.error = input.readOptionalString(); + this.state = input.readOptionalString(); + this.provisioningProgress = input.readOptionalString(); + this.provisionStartTime = input.readOptionalInstant(); + this.provisionEndTime = input.readOptionalInstant(); + this.user = input.readBoolean() ? new User(input) : null; + this.uiMetadata = input.readBoolean() ? input.readMap() : null; + this.userOutputs = input.readBoolean() ? input.readMap() : null; + this.resourcesCreated = input.readList(ResourcesCreated::new); + } + /** * Constructs a builder object for workflowState * @return Builder Object @@ -110,7 +132,7 @@ public static class Builder { private User user = null; private Map uiMetadata = null; private Map userOutputs = null; - private Map resourcesCreated = null; + private List resourcesCreated = null; /** * Empty Constructor for the Builder object @@ -212,8 +234,8 @@ public Builder userOutputs(Map userOutputs) { * @param resourcesCreated resourcesCreated * @return the Builder object */ - public Builder resourcesCreated(Map resourcesCreated) { - this.userOutputs = resourcesCreated; + public Builder resourcesCreated(List resourcesCreated) { + this.resourcesCreated = resourcesCreated; return this; } @@ -268,11 +290,41 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws xContentBuilder.field(USER_OUTPUTS_FIELD, userOutputs); } if (resourcesCreated != null && !resourcesCreated.isEmpty()) { - xContentBuilder.field(RESOURCES_CREATED_FIELD, resourcesCreated); + xContentBuilder.field(RESOURCES_CREATED_FIELD, resourcesCreated.toArray()); } return xContentBuilder.endObject(); } + @Override + public void writeTo(StreamOutput output) throws IOException { + output.writeString(workflowId); + output.writeOptionalString(error); + output.writeOptionalString(state); + output.writeOptionalString(provisioningProgress); + output.writeOptionalInstant(provisionStartTime); + output.writeOptionalInstant(provisionEndTime); + + if (user != null) { + output.writeBoolean(true); // user exists + user.writeTo(output); + } else { + output.writeBoolean(false); // user does not exist + } + if (uiMetadata != null) { + output.writeBoolean(true); + output.writeMap(uiMetadata); + } else { + output.writeBoolean(false); + } + if (userOutputs != null) { + output.writeBoolean(true); + output.writeMap(userOutputs); + } else { + output.writeBoolean(false); + } + output.writeList(resourcesCreated); + } + /** * Parse raw json content into a Template instance. * @@ -290,7 +342,7 @@ public static WorkflowState parse(XContentParser parser) throws IOException { User user = null; Map uiMetadata = null; Map userOutputs = new HashMap<>(); - Map resourcesCreated = new HashMap<>(); + List resourcesCreated = new ArrayList<>(); ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -337,23 +389,17 @@ public static WorkflowState parse(XContentParser parser) throws IOException { } } break; - case RESOURCES_CREATED_FIELD: - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); - while (parser.nextToken() != XContentParser.Token.END_OBJECT) { - String resourcesCreatedField = parser.currentName(); - switch (parser.nextToken()) { - case VALUE_STRING: - resourcesCreated.put(resourcesCreatedField, parser.text()); - break; - case START_OBJECT: - resourcesCreated.put(resourcesCreatedField, parseStringToStringMap(parser)); - break; - default: - throw new IOException( - "Unable to parse field [" + resourcesCreatedField + "] in a resources_created object." - ); + try { + ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_ARRAY) { + resourcesCreated.add(ResourcesCreated.parse(parser)); + } + } catch (Exception e) { + if (e instanceof ParsingException || e instanceof XContentParseException) { + throw new FlowFrameworkException("Error parsing newly created resources", RestStatus.INTERNAL_SERVER_ERROR); } + throw e; } break; default: @@ -449,7 +495,12 @@ public Map userOutputs() { * A map of all the resources created * @return the resources created */ - public Map resourcesCreated() { + public List resourcesCreated() { return resourcesCreated; } + + public static WorkflowState fromStream(StreamInput in) throws IOException { + WorkflowState workflowState = new WorkflowState(in); + return workflowState; + } } diff --git a/src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowAction.java b/src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowAction.java new file mode 100644 index 000000000..5666fea78 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowAction.java @@ -0,0 +1,87 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.rest; + +import com.google.common.collect.ImmutableList; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.client.node.NodeClient; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.transport.ProvisionWorkflowAction; +import org.opensearch.flowframework.transport.WorkflowRequest; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; + +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_ID; +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI; + +public class RestGetWorkflowAction extends BaseRestHandler { + + private static final String GET_WORKFLOW_ACTION = "get_workflow"; + private static final Logger logger = LogManager.getLogger(RestGetWorkflowAction.class); + + /** + * Instantiates a new RestGetWorkflowAction + */ + public RestGetWorkflowAction() {} + + @Override + public String getName() { + return GET_WORKFLOW_ACTION; + } + + @Override + protected BaseRestHandler.RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + // Validate content + if (request.hasContent()) { + throw new FlowFrameworkException("Invalid request format", RestStatus.BAD_REQUEST); + } + + // Validate params + String workflowId = request.param(WORKFLOW_ID); + if (workflowId == null) { + throw new FlowFrameworkException("workflow_id cannot be null", RestStatus.BAD_REQUEST); + } + + String rawPath = request.rawPath(); + boolean all = request.paramAsBoolean("_all", false); + + // Create request and provision + WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, null); + return channel -> client.execute(ProvisionWorkflowAction.INSTANCE, workflowRequest, new RestToXContentListener<>(channel)); + + // GetAnomalyDetectorRequest getAnomalyDetectorRequest = new GetAnomalyDetectorRequest( + // detectorId, + // RestActions.parseVersion(request), + // returnJob, + // returnTask, + // typesStr, + // rawPath, + // all, + // buildEntity(request, detectorId) + // ); + // + // return channel -> client + // .execute(GetAnomalyDetectorAction.INSTANCE, getAnomalyDetectorRequest, new RestToXContentListener<>(channel)); + } + + @Override + public List routes() { + return ImmutableList.of( + // Provision workflow from indexed use case template + new Route(RestRequest.Method.POST, String.format(Locale.ROOT, "%s/{%s}/%s", WORKFLOW_URI, WORKFLOW_ID, "_status")) + ); + } +} diff --git a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowAction.java b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowAction.java index 0f49c826f..ba9898f1f 100644 --- a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowAction.java @@ -10,7 +10,7 @@ import org.opensearch.action.ActionType; -import static org.opensearch.flowframework.common.CommonValue.TRANSPORT_ACION_NAME_PREFIX; +import static org.opensearch.flowframework.common.CommonValue.TRANSPORT_ACTION_NAME_PREFIX; /** * External Action for public facing RestCreateWorkflowActiom @@ -18,7 +18,7 @@ public class CreateWorkflowAction extends ActionType { /** The name of this action */ - public static final String NAME = TRANSPORT_ACION_NAME_PREFIX + "workflow/create"; + public static final String NAME = TRANSPORT_ACTION_NAME_PREFIX + "workflow/create"; /** An instance of this action */ public static final CreateWorkflowAction INSTANCE = new CreateWorkflowAction(); diff --git a/src/main/java/org/opensearch/flowframework/transport/GetWorkflowAction.java b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowAction.java new file mode 100644 index 000000000..6abee3867 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowAction.java @@ -0,0 +1,27 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.transport; + +import org.opensearch.action.ActionType; + +import static org.opensearch.flowframework.common.CommonValue.TRANSPORT_ACTION_NAME_PREFIX; + +/** + * External Action for public facing RestGetWorkflowAction + */ +public class GetWorkflowAction extends ActionType { + /** The name of this action */ + public static final String NAME = TRANSPORT_ACTION_NAME_PREFIX + "workflow/get"; + /** An instance of this action */ + public static final GetWorkflowAction INSTANCE = new GetWorkflowAction(); + + private GetWorkflowAction() { + super(NAME, GetWorkflowResponse::new); + } +} diff --git a/src/main/java/org/opensearch/flowframework/transport/GetWorkflowResponse.java b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowResponse.java new file mode 100644 index 000000000..6a1c3b497 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowResponse.java @@ -0,0 +1,106 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.transport; + +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.flowframework.model.Template; +import org.opensearch.flowframework.model.WorkflowState; + +import java.io.IOException; + +import static org.opensearch.flowframework.common.CommonValue.TEMPLATE_FIELD; + +public class GetWorkflowResponse extends ActionResponse implements ToXContentObject { + + public static final String WORKFLOW_STATUS = "workflowStatus"; + private String id; + private WorkflowState workflowState; + private Template template; + private boolean workflowStatus; + + public GetWorkflowResponse(String id, WorkflowState workflowState, Template template, boolean workflowStatus) { + this.id = id; + this.template = template; + this.workflowState = workflowState; + this.workflowStatus = workflowStatus; + } + + /* + if (workflowStatus) { + out.writeBoolean(true); // profileResponse is true + if (workflowState != null) { + out.writeString(WORKFLOW_STATUS); + WorkflowState.writeTo(out); + } + if (template != null) { + out.writeString(TEMPLATE_FIELD); + template.writeTo(out); + } + } else { + out.writeBoolean(false); // profileResponse is false + out.writeString(id); + template.writeTo(out); + } + */ + public GetWorkflowResponse(StreamInput in) throws IOException { + super(in); + workflowStatus = in.readBoolean(); + if (workflowStatus) { + if (workflowState != null) { + workflowState = new WorkflowState(in); + } + if (template != null) { + template = new Template(in); + } + } else { + workflowState = null; + id = in.readString(); + template = new Template(in); + } + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + if (workflowStatus) { + builder.startObject(); + builder.field(WORKFLOW_STATUS, workflowState); + builder.field(TEMPLATE_FIELD, template); + builder.endObject(); + } else { + builder.startObject(); + builder.field("_id", id); + builder.field(TEMPLATE_FIELD, template); + builder.endObject(); + } + return builder; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + if (workflowStatus) { + out.writeBoolean(true); // profileResponse is true + // if (workflowState != null) { + // out.writeString(WORKFLOW_STATUS); + // WorkflowState.writeTo(out); + // } + if (template != null) { + out.writeString(TEMPLATE_FIELD); + template.writeTo(out); + } + } else { + out.writeBoolean(false); // profileResponse is false + out.writeString(id); + template.writeTo(out); + } + } +} diff --git a/src/main/java/org/opensearch/flowframework/transport/GetWorkflowResponseTest.java b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowResponseTest.java new file mode 100644 index 000000000..bd1cc05ea --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowResponseTest.java @@ -0,0 +1,62 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.transport; + +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.flowframework.model.WorkflowState; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + +public class GetWorkflowResponseTest extends ActionResponse implements ToXContentObject { + + public WorkflowState workflowState; + + public GetWorkflowResponseTest(StreamInput in) throws IOException { + super(in); + workflowState = workflowState.fromStream(in); + } + + public GetWorkflowResponseTest(WorkflowState workflowState) { + this.workflowState = workflowState; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + workflowState.writeTo(out); + } + + @Override + public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params) throws IOException { + return workflowState.toXContent(xContentBuilder, params); + } + + public static GetWorkflowResponseTest fromActionResponse(ActionResponse actionResponse) { + if (actionResponse instanceof GetWorkflowResponseTest) { + return (GetWorkflowResponseTest) actionResponse; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionResponse.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new GetWorkflowResponseTest(input); + } + } catch (IOException e) { + throw new UncheckedIOException("failed to parse ActionResponse into GetWorkflowResponseTest", e); + } + } +} diff --git a/src/main/java/org/opensearch/flowframework/transport/GetWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowTransportAction.java new file mode 100644 index 000000000..b76d495f0 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowTransportAction.java @@ -0,0 +1,87 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.transport; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.model.WorkflowState; +import org.opensearch.flowframework.util.ParseUtils; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; + +public class GetWorkflowTransportAction extends HandledTransportAction { + + private final Logger logger = LogManager.getLogger(GetWorkflowTransportAction.class); + + private final Client client; + private final NamedXContentRegistry xContentRegistry; + + @Inject + public GetWorkflowTransportAction( + TransportService transportService, + ActionFilters actionFilters, + Client client, + NamedXContentRegistry xContentRegistry + ) { + super(GetWorkflowAction.NAME, transportService, actionFilters, WorkflowRequest::new); + this.client = client; + this.xContentRegistry = xContentRegistry; + } + + @Override + protected void doExecute(Task task, WorkflowRequest request, ActionListener listener) { + String workflowId = request.getWorkflowId(); + User user = ParseUtils.getUserContext(client); + GetRequest getRequest = new GetRequest(WORKFLOW_STATE_INDEX).id(workflowId); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + client.get(getRequest, ActionListener.runBefore(ActionListener.wrap(r -> { + logger.debug("Completed Get Workflow Status Request, id:{}", workflowId); + + if (r != null && r.isExists()) { + try (XContentParser parser = ParseUtils.createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + WorkflowState workflowState = WorkflowState.parse(parser); + listener.onResponse(new GetWorkflowResponseTest(workflowState)); + } catch (Exception e) { + logger.error("Failed to parse workflowState" + r.getId(), e); + listener.onFailure(e); + } + } else { + listener.onFailure(new FlowFrameworkException("Fail to find workflow", RestStatus.BAD_REQUEST)); + } + }, e -> { + if (e instanceof IndexNotFoundException) { + listener.onFailure(new FlowFrameworkException("Fail to find workflow", RestStatus.BAD_REQUEST)); + } else { + logger.error("Failed to get workflow status of " + workflowId, e); + listener.onFailure(e); + } + }), () -> context.restore())); + } catch (Exception e) { + logger.error("Failed to get workflow: " + workflowId, e); + listener.onFailure(e); + } + } +} diff --git a/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowAction.java b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowAction.java index 022e73488..30f9f1437 100644 --- a/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowAction.java @@ -10,14 +10,14 @@ import org.opensearch.action.ActionType; -import static org.opensearch.flowframework.common.CommonValue.TRANSPORT_ACION_NAME_PREFIX; +import static org.opensearch.flowframework.common.CommonValue.TRANSPORT_ACTION_NAME_PREFIX; /** * External Action for public facing RestProvisionWorkflowAction */ public class ProvisionWorkflowAction extends ActionType { /** The name of this action */ - public static final String NAME = TRANSPORT_ACION_NAME_PREFIX + "workflow/provision"; + public static final String NAME = TRANSPORT_ACTION_NAME_PREFIX + "workflow/provision"; /** An instance of this action */ public static final ProvisionWorkflowAction INSTANCE = new ProvisionWorkflowAction(); diff --git a/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java index 22ac414e5..e6f171b50 100644 --- a/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java @@ -38,8 +38,10 @@ import java.util.concurrent.CompletableFuture; import java.util.stream.Collectors; +import static org.opensearch.flowframework.common.CommonValue.ERROR_FIELD; import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX; import static org.opensearch.flowframework.common.CommonValue.PROVISIONING_PROGRESS_FIELD; +import static org.opensearch.flowframework.common.CommonValue.PROVISION_END_TIME_FIELD; import static org.opensearch.flowframework.common.CommonValue.PROVISION_START_TIME_FIELD; import static org.opensearch.flowframework.common.CommonValue.PROVISION_THREAD_POOL; import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW; @@ -106,10 +108,11 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener provisionProcessSequence = workflowProcessSorter.sortProcessNodes(provisionWorkflow); + List provisionProcessSequence = workflowProcessSorter.sortProcessNodes(provisionWorkflow, workflowId); workflowProcessSorter.validateGraph(provisionProcessSequence); flowFrameworkIndicesHandler.updateFlowFrameworkSystemIndexDoc( @@ -155,13 +158,42 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener workflowSequence) { // TODO : Update Action listener type to State index Request ActionListener provisionWorkflowListener = ActionListener.wrap(response -> { - logger.info("Provisioning completed successuflly for workflow {}", workflowId); + logger.info("Provisioning completed successfully for workflow {}", workflowId); + flowFrameworkIndicesHandler.updateFlowFrameworkSystemIndexDoc( + WORKFLOW_STATE_INDEX, + workflowId, + ImmutableMap.of( + STATE_FIELD, + State.COMPLETED, + PROVISIONING_PROGRESS_FIELD, + ProvisioningProgress.DONE, + PROVISION_END_TIME_FIELD, + Instant.now().toEpochMilli() + ), + ActionListener.wrap(updateResponse -> { + logger.info("updated workflow {} state to {}", workflowId, State.COMPLETED); + }, exception -> { logger.error("Failed to update workflow state : {}", exception.getMessage()); }) + ); - // TODO : Create State index request to update STATE entry status to READY }, exception -> { logger.error("Provisioning failed for workflow {} : {}", workflowId, exception); - - // TODO : Create State index request to update STATE entry status to FAILED + flowFrameworkIndicesHandler.updateFlowFrameworkSystemIndexDoc( + WORKFLOW_STATE_INDEX, + workflowId, + ImmutableMap.of( + STATE_FIELD, + State.FAILED, + ERROR_FIELD, + "failed provision", // TODO: improve the error message here + PROVISIONING_PROGRESS_FIELD, + ProvisioningProgress.FAILED, + PROVISION_END_TIME_FIELD, + Instant.now().toEpochMilli() + ), + ActionListener.wrap(updateResponse -> { + logger.info("updated workflow {} state to {}", workflowId, State.COMPLETED); + }, exceptionState -> { logger.error("Failed to update workflow state : {}", exceptionState.getMessage()); }) + ); }); try { threadPool.executor(PROVISION_THREAD_POOL).execute(() -> { executeWorkflow(workflowSequence, provisionWorkflowListener); }); @@ -200,7 +232,6 @@ private void executeWorkflow(List workflowSequence, ActionListener< // Attempt to join each workflow step future, may throw a CompletionException if any step completes exceptionally workflowFutureList.forEach(CompletableFuture::join); - // TODO : Create State Index request with provisioning state, start time, end time, etc, pending implementation. String for now workflowListener.onResponse("READY"); } catch (IllegalArgumentException e) { diff --git a/src/main/java/org/opensearch/flowframework/util/ParseUtils.java b/src/main/java/org/opensearch/flowframework/util/ParseUtils.java index 338f23cdc..9a7c1f00e 100644 --- a/src/main/java/org/opensearch/flowframework/util/ParseUtils.java +++ b/src/main/java/org/opensearch/flowframework/util/ParseUtils.java @@ -12,9 +12,12 @@ import org.apache.logging.log4j.Logger; import org.opensearch.client.Client; import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentHelper; +import org.opensearch.common.xcontent.XContentType; import org.opensearch.common.xcontent.json.JsonXContent; import org.opensearch.commons.ConfigConstants; import org.opensearch.commons.authuser.User; +import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; @@ -113,4 +116,9 @@ public static User getUserContext(Client client) { return User.parse(userStr); } + public static XContentParser createXContentParserFromRegistry(NamedXContentRegistry xContentRegistry, BytesReference bytesReference) + throws IOException { + return XContentHelper.createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, bytesReference, XContentType.JSON); + } + } diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java b/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java index 533d82c1e..7b1933e4b 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java @@ -11,14 +11,21 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.ExceptionsHelper; +import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; +import org.opensearch.flowframework.model.ResourcesCreated; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.connector.ConnectorAction; import org.opensearch.ml.common.connector.ConnectorAction.ActionType; import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput; import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse; +import org.opensearch.script.Script; +import org.opensearch.script.ScriptType; import java.io.IOException; import java.security.AccessController; @@ -41,6 +48,7 @@ import static org.opensearch.flowframework.common.CommonValue.PARAMETERS_FIELD; import static org.opensearch.flowframework.common.CommonValue.PROTOCOL_FIELD; import static org.opensearch.flowframework.common.CommonValue.VERSION_FIELD; +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; /** * Step to create a connector for a remote model @@ -50,6 +58,7 @@ public class CreateConnectorStep implements WorkflowStep { private static final Logger logger = LogManager.getLogger(CreateConnectorStep.class); private MachineLearningNodeClient mlClient; + private final FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; static final String NAME = "create_connector"; @@ -57,8 +66,9 @@ public class CreateConnectorStep implements WorkflowStep { * Instantiate this class * @param mlClient client to instantiate MLClient */ - public CreateConnectorStep(MachineLearningNodeClient mlClient) { + public CreateConnectorStep(MachineLearningNodeClient mlClient, FlowFrameworkIndicesHandler flowFrameworkIndicesHandler) { this.mlClient = mlClient; + this.flowFrameworkIndicesHandler = flowFrameworkIndicesHandler; } @Override @@ -69,8 +79,41 @@ public CompletableFuture execute(List data) throws I @Override public void onResponse(MLCreateConnectorResponse mlCreateConnectorResponse) { - logger.info("Created connector successfully"); - // TODO Add the response to Global Context + try { + logger.info("Created connector successfully"); + // TODO Add the response to Global Context + String workflowId = data.get(0).getWorkflowId(); + String workflowStepName = getName(); + System.out.println("workflowStepName: " + workflowStepName); + System.out.println("workflowId: " + workflowId); + ResourcesCreated newResource = new ResourcesCreated(workflowStepName, mlCreateConnectorResponse.getConnectorId()); + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + newResource.toXContent(builder, ToXContentObject.EMPTY_PARAMS); + + // The script to append a new object to the resources_created array + Script script = new Script( + ScriptType.INLINE, + "painless", + "ctx._source.resources_created.add(params.newResource)", + Collections.singletonMap("newResource", newResource) + ); + + flowFrameworkIndicesHandler.updateFlowFrameworkSystemIndexDocWithScript( + WORKFLOW_STATE_INDEX, + workflowId, + script, + ActionListener.wrap(updateResponse -> { + logger.info("updated resources craeted of {}", workflowId); + }, + exception -> { + logger.error("Failed to update workflow state with newly created resource: {}", exception.getMessage()); + } + ) + ); + } catch (IOException e) { + logger.error("Failed to parse new created resource", e); + } + createConnectorFuture.complete( new WorkflowData(Map.ofEntries(Map.entry("connector_id", mlCreateConnectorResponse.getConnectorId()))) ); diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowData.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowData.java index 35ffb7e75..83cd33f90 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowData.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowData.java @@ -23,9 +23,10 @@ public class WorkflowData { private final Map content; private final Map params; + private String workflowId; private WorkflowData() { - this(Collections.emptyMap(), Collections.emptyMap()); + this(Collections.emptyMap(), Collections.emptyMap(), ""); } /** @@ -44,6 +45,19 @@ public WorkflowData(Map content) { public WorkflowData(Map content, Map params) { this.content = Map.copyOf(content); this.params = Map.copyOf(params); + this.workflowId = ""; + } + + /** + * Instantiate this object with content and params. + * @param content The content map + * @param params The params map + * @param workflowId The workflow ID associated with this step + */ + public WorkflowData(Map content, Map params, String workflowId) { + this.content = Map.copyOf(content); + this.params = Map.copyOf(params); + this.workflowId = workflowId; } /** @@ -62,4 +76,12 @@ public Map getContent() { public Map getParams() { return this.params; }; + + /** + * Returns the workflowId associated with this workflow. + * @return the workflowId of this data. + */ + public String getWorkflowId() { + return this.workflowId; + }; } diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java index 745de5921..33aa0d191 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java @@ -61,14 +61,14 @@ public WorkflowProcessSorter(WorkflowStepFactory workflowStepFactory, ThreadPool * @param workflow A workflow with (unsorted) nodes and edges which define predecessors and successors * @return A list of Process Nodes sorted topologically. All predecessors of any node will occur prior to it in the list. */ - public List sortProcessNodes(Workflow workflow) { + public List sortProcessNodes(Workflow workflow, String workflowId) { List sortedNodes = topologicalSort(workflow.nodes(), workflow.edges()); List nodes = new ArrayList<>(); Map idToNodeMap = new HashMap<>(); for (WorkflowNode node : sortedNodes) { WorkflowStep step = workflowStepFactory.createStep(node.type()); - WorkflowData data = new WorkflowData(node.userInputs(), workflow.userParams()); + WorkflowData data = new WorkflowData(node.userInputs(), workflow.userParams(), workflowId); List predecessorNodes = workflow.edges() .stream() .filter(e -> e.destination().equals(node.id())) diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStep.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStep.java index bdbe63caa..1f5545cdf 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStep.java @@ -30,5 +30,4 @@ public interface WorkflowStep { * @return the name of this workflow step. */ String getName(); - } diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java index c30bdf87c..c8f20cd95 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java @@ -12,6 +12,7 @@ import org.opensearch.cluster.service.ClusterService; import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.ml.client.MachineLearningNodeClient; import java.util.HashMap; @@ -23,6 +24,7 @@ public class WorkflowStepFactory { private final Map stepMap = new HashMap<>(); + private final FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; /** * Instantiate this class. @@ -32,17 +34,28 @@ public class WorkflowStepFactory { * @param mlClient Machine Learning client to perform ml operations */ - public WorkflowStepFactory(ClusterService clusterService, Client client, MachineLearningNodeClient mlClient) { - populateMap(clusterService, client, mlClient); + public WorkflowStepFactory( + ClusterService clusterService, + Client client, + MachineLearningNodeClient mlClient, + FlowFrameworkIndicesHandler flowFrameworkIndicesHandler + ) { + this.flowFrameworkIndicesHandler = flowFrameworkIndicesHandler; + populateMap(clusterService, client, mlClient, flowFrameworkIndicesHandler); } - private void populateMap(ClusterService clusterService, Client client, MachineLearningNodeClient mlClient) { + private void populateMap( + ClusterService clusterService, + Client client, + MachineLearningNodeClient mlClient, + FlowFrameworkIndicesHandler flowFrameworkIndicesHandler + ) { stepMap.put(NoOpStep.NAME, new NoOpStep()); stepMap.put(CreateIndexStep.NAME, new CreateIndexStep(clusterService, client)); stepMap.put(CreateIngestPipelineStep.NAME, new CreateIngestPipelineStep(client)); stepMap.put(RegisterModelStep.NAME, new RegisterModelStep(mlClient)); stepMap.put(DeployModelStep.NAME, new DeployModelStep(mlClient)); - stepMap.put(CreateConnectorStep.NAME, new CreateConnectorStep(mlClient)); + stepMap.put(CreateConnectorStep.NAME, new CreateConnectorStep(mlClient, flowFrameworkIndicesHandler)); stepMap.put(ModelGroupStep.NAME, new ModelGroupStep(mlClient)); } diff --git a/src/main/resources/mappings/workflow-state.json b/src/main/resources/mappings/workflow-state.json index 86fbeef6e..21df5ccd6 100644 --- a/src/main/resources/mappings/workflow-state.json +++ b/src/main/resources/mappings/workflow-state.json @@ -31,7 +31,15 @@ "type": "object" }, "resources_created": { - "type": "object" + "type": "nested", + "properties": { + "workflow_step_name": { + "type": "keyword" + }, + "resource_id": { + "type": "keyword" + } + } }, "ui_metadata": { "type": "object",