diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 6f3f8302f..a47ed684f 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -14,6 +14,11 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 + # Spotless requires JDK 17+ + - uses: actions/setup-java@v3 + with: + java-version: 17 + distribution: temurin - name: Spotless Check run: ./gradlew spotlessCheck javadoc: @@ -46,7 +51,7 @@ jobs: distribution: temurin - name: Build and Run Tests run: | - ./gradlew check -x integTest -x yamlRestTest + ./gradlew check -x integTest -x yamlRestTest -x spotlessJava - name: Upload Coverage Report if: ${{ matrix.codecov }} uses: codecov/codecov-action@v3 diff --git a/build.gradle b/build.gradle index 57cf13702..1b854d91e 100644 --- a/build.gradle +++ b/build.gradle @@ -146,6 +146,7 @@ dependencies { configurations.all { resolutionStrategy { force("com.google.guava:guava:32.1.3-jre") // CVE for 31.1 + force("org.eclipse.platform:org.eclipse.core.runtime:3.29.0") // CVE for 3.26.100 } } } diff --git a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java index f44b72495..42ad04213 100644 --- a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java +++ b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java @@ -28,10 +28,13 @@ import org.opensearch.flowframework.common.FlowFrameworkFeatureEnabledSetting; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.flowframework.rest.RestCreateWorkflowAction; +import org.opensearch.flowframework.rest.RestGetWorkflowAction; import org.opensearch.flowframework.rest.RestProvisionWorkflowAction; import org.opensearch.flowframework.rest.RestSearchWorkflowAction; import org.opensearch.flowframework.transport.CreateWorkflowAction; import org.opensearch.flowframework.transport.CreateWorkflowTransportAction; +import org.opensearch.flowframework.transport.GetWorkflowAction; +import org.opensearch.flowframework.transport.GetWorkflowTransportAction; import org.opensearch.flowframework.transport.ProvisionWorkflowAction; import org.opensearch.flowframework.transport.ProvisionWorkflowTransportAction; import org.opensearch.flowframework.transport.SearchWorkflowAction; @@ -92,10 +95,9 @@ public Collection createComponents( flowFrameworkFeatureEnabledSetting = new FlowFrameworkFeatureEnabledSetting(clusterService, settings); MachineLearningNodeClient mlClient = new MachineLearningNodeClient(client); - WorkflowStepFactory workflowStepFactory = new WorkflowStepFactory(clusterService, client, mlClient); - WorkflowProcessSorter workflowProcessSorter = new WorkflowProcessSorter(workflowStepFactory, threadPool); - FlowFrameworkIndicesHandler flowFrameworkIndicesHandler = new FlowFrameworkIndicesHandler(client, clusterService); + WorkflowStepFactory workflowStepFactory = new WorkflowStepFactory(clusterService, client, mlClient, flowFrameworkIndicesHandler); + WorkflowProcessSorter workflowProcessSorter = new WorkflowProcessSorter(workflowStepFactory, threadPool); return ImmutableList.of(workflowStepFactory, workflowProcessSorter, flowFrameworkIndicesHandler); } @@ -113,7 +115,8 @@ public List getRestHandlers( return ImmutableList.of( new RestCreateWorkflowAction(flowFrameworkFeatureEnabledSetting, settings, clusterService), new RestProvisionWorkflowAction(flowFrameworkFeatureEnabledSetting), - new RestSearchWorkflowAction(flowFrameworkFeatureEnabledSetting) + new RestSearchWorkflowAction(flowFrameworkFeatureEnabledSetting), + new RestGetWorkflowAction(flowFrameworkFeatureEnabledSetting) ); } @@ -122,7 +125,8 @@ public List getRestHandlers( return ImmutableList.of( new ActionHandler<>(CreateWorkflowAction.INSTANCE, CreateWorkflowTransportAction.class), new ActionHandler<>(ProvisionWorkflowAction.INSTANCE, ProvisionWorkflowTransportAction.class), - new ActionHandler<>(SearchWorkflowAction.INSTANCE, SearchWorkflowTransportAction.class) + new ActionHandler<>(SearchWorkflowAction.INSTANCE, SearchWorkflowTransportAction.class), + new ActionHandler<>(GetWorkflowAction.INSTANCE, GetWorkflowTransportAction.class) ); } diff --git a/src/main/java/org/opensearch/flowframework/common/CommonValue.java b/src/main/java/org/opensearch/flowframework/common/CommonValue.java index ac2cb2a86..1f9f33f4b 100644 --- a/src/main/java/org/opensearch/flowframework/common/CommonValue.java +++ b/src/main/java/org/opensearch/flowframework/common/CommonValue.java @@ -46,7 +46,7 @@ private CommonValue() {} public static final String USER_FIELD = "user"; /** The transport action name prefix */ - public static final String TRANSPORT_ACION_NAME_PREFIX = "cluster:admin/opensearch/flow_framework/"; + public static final String TRANSPORT_ACTION_NAME_PREFIX = "cluster:admin/opensearch/flow_framework/"; /** The base URI for this plugin's rest actions */ public static final String FLOW_FRAMEWORK_BASE_URI = "/_plugins/_flow_framework"; /** The URI for this plugin's workflow rest actions */ @@ -148,4 +148,9 @@ private CommonValue() {} public static final String USER_OUTPUTS_FIELD = "user_outputs"; /** The template field name for template resources created */ public static final String RESOURCES_CREATED_FIELD = "resources_created"; + /** The field name for the ResourceCreated's resource ID */ + public static final String RESOURCE_ID_FIELD = "resource_id"; + /** The field name for the ResourceCreated's resource name */ + public static final String WORKFLOW_STEP_NAME = "workflow_step_name"; + } diff --git a/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java b/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java index 00c029a3e..7dcc89de6 100644 --- a/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java +++ b/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java @@ -38,6 +38,7 @@ import org.opensearch.flowframework.model.State; import org.opensearch.flowframework.model.Template; import org.opensearch.flowframework.model.WorkflowState; +import org.opensearch.script.Script; import java.io.IOException; import java.net.URL; @@ -311,7 +312,7 @@ public void putInitialStateToWorkflowState(String workflowId, User user, ActionL .state(State.NOT_STARTED.name()) .provisioningProgress(ProvisioningProgress.NOT_STARTED.name()) .user(user) - .resourcesCreated(Collections.emptyMap()) + .resourcesCreated(Collections.emptyList()) .userOutputs(Collections.emptyMap()) .build(); initWorkflowStateIndexIfAbsent(ActionListener.wrap(indexCreated -> { @@ -402,4 +403,38 @@ public void updateFlowFrameworkSystemIndexDoc( } } } + + /** + * Updates a document in the workflow state index + * @param indexName the index that we will be updating a document of. + * @param documentId the document ID + * @param script the given script to update doc + * @param listener action listener + */ + public void updateFlowFrameworkSystemIndexDocWithScript( + String indexName, + String documentId, + Script script, + ActionListener listener + ) { + if (!doesIndexExist(indexName)) { + String exceptionMessage = "Failed to update document for given workflow due to missing " + indexName + " index"; + logger.error(exceptionMessage); + listener.onFailure(new Exception(exceptionMessage)); + } else { + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + UpdateRequest updateRequest = new UpdateRequest(indexName, documentId); + // TODO: Also add ability to change other fields at the same time when adding detailed provision progress + updateRequest.script(script); + updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + // TODO: decide what condition can be considered as an update conflict and add retry strategy + client.update(updateRequest, ActionListener.runBefore(listener, () -> context.restore())); + } catch (Exception e) { + logger.error("Failed to update {} entry : {}. {}", indexName, documentId, e.getMessage()); + listener.onFailure( + new FlowFrameworkException("Failed to update " + indexName + "entry: " + documentId, ExceptionsHelper.status(e)) + ); + } + } + } } diff --git a/src/main/java/org/opensearch/flowframework/model/ProvisioningProgress.java b/src/main/java/org/opensearch/flowframework/model/ProvisioningProgress.java index d5a2a5734..fe46460c7 100644 --- a/src/main/java/org/opensearch/flowframework/model/ProvisioningProgress.java +++ b/src/main/java/org/opensearch/flowframework/model/ProvisioningProgress.java @@ -18,5 +18,7 @@ public enum ProvisioningProgress { /** In Progress State */ IN_PROGRESS, /** Done State */ - DONE + DONE, + /** Failed State */ + FAILED } diff --git a/src/main/java/org/opensearch/flowframework/model/ResourceCreated.java b/src/main/java/org/opensearch/flowframework/model/ResourceCreated.java new file mode 100644 index 000000000..0ec8f34d5 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/model/ResourceCreated.java @@ -0,0 +1,123 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.model; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; + +import java.io.IOException; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.flowframework.common.CommonValue.RESOURCE_ID_FIELD; +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STEP_NAME; + +/** + * This represents an object in the WorkflowState {@link WorkflowState}. + */ +// TODO: create an enum to add the resource name itself for each step example (create_connector_step -> connector) +public class ResourceCreated implements ToXContentObject, Writeable { + + private final String workflowStepName; + private final String resourceId; + + /** + * Create this resources created object with given resource name and ID. + * @param workflowStepName The workflow step name associating to the step where it was created + * @param resourceId The resources ID for relating to the created resource + */ + public ResourceCreated(String workflowStepName, String resourceId) { + this.workflowStepName = workflowStepName; + this.resourceId = resourceId; + } + + /** + * Create this resources created object with an StreamInput + * @param input the input stream to read from + * @throws IOException if failed to read input stream + */ + public ResourceCreated(StreamInput input) throws IOException { + this.workflowStepName = input.readString(); + this.resourceId = input.readString(); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + XContentBuilder xContentBuilder = builder.startObject() + .field(WORKFLOW_STEP_NAME, workflowStepName) + .field(RESOURCE_ID_FIELD, resourceId); + return xContentBuilder.endObject(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(workflowStepName); + out.writeString(resourceId); + } + + /** + * Gets the resource id. + * + * @return the resourceId. + */ + public String resourceId() { + return resourceId; + } + + /** + * Gets the workflow step name associated to the created resource + * + * @return the workflowStepName. + */ + public String workflowStepName() { + return workflowStepName; + } + + /** + * Parse raw JSON content into a ResourceCreated instance. + * + * @param parser JSON based content parser + * @return the parsed ResourceCreated instance + * @throws IOException if content can't be parsed correctly + */ + public static ResourceCreated parse(XContentParser parser) throws IOException { + String workflowStepName = null; + String resourceId = null; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case WORKFLOW_STEP_NAME: + workflowStepName = parser.text(); + break; + case RESOURCE_ID_FIELD: + resourceId = parser.text(); + break; + default: + throw new IOException("Unable to parse field [" + fieldName + "] in a resources_created object."); + } + } + if (workflowStepName == null || resourceId == null) { + throw new IOException("A ResourceCreated object requires both a workflowStepName and resourceId."); + } + return new ResourceCreated(workflowStepName, resourceId); + } + + @Override + public String toString() { + return "resources_Created [resource_name=" + workflowStepName + ", id=" + resourceId + "]"; + } + +} diff --git a/src/main/java/org/opensearch/flowframework/model/Template.java b/src/main/java/org/opensearch/flowframework/model/Template.java index 58178c2f4..bfb40b696 100644 --- a/src/main/java/org/opensearch/flowframework/model/Template.java +++ b/src/main/java/org/opensearch/flowframework/model/Template.java @@ -13,6 +13,8 @@ import org.opensearch.common.xcontent.json.JsonXContent; import org.opensearch.common.xcontent.yaml.YamlXContent; import org.opensearch.commons.authuser.User; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; @@ -39,7 +41,7 @@ /** * The Template is the central data structure which configures workflows. This object is used to parse JSON communicated via REST API. */ -public class Template implements ToXContentObject { +public class Template implements ToXContentObject, Writeable { private String name; private String description; @@ -240,6 +242,23 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws return xContentBuilder.endObject(); } + // TODO: fix writeable when implementing get workflow API + @Override + public void writeTo(StreamOutput output) throws IOException { + output.writeString(name); + output.writeOptionalString(description); + output.writeString(useCase); + output.writeVersion(templateVersion); + // output.writeList((List) compatibilityVersion); + output.writeMapWithConsistentOrder(workflows); + if (user != null) { + output.writeBoolean(true); // user exists + user.writeTo(output); + } else { + output.writeBoolean(false); // user does not exist + } + } + /** * Parse raw json content into a Template instance. * diff --git a/src/main/java/org/opensearch/flowframework/model/WorkflowState.java b/src/main/java/org/opensearch/flowframework/model/WorkflowState.java index f51520585..b8086aa4c 100644 --- a/src/main/java/org/opensearch/flowframework/model/WorkflowState.java +++ b/src/main/java/org/opensearch/flowframework/model/WorkflowState.java @@ -9,14 +9,23 @@ package org.opensearch.flowframework.model; import org.opensearch.commons.authuser.User; +import org.opensearch.core.common.ParsingException; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParseException; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.util.ParseUtils; import java.io.IOException; import java.time.Instant; +import java.util.ArrayList; import java.util.HashMap; +import java.util.List; import java.util.Map; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; @@ -35,17 +44,17 @@ * The WorkflowState is used to store all additional information regarding a workflow that isn't part of the * global context. */ -public class WorkflowState implements ToXContentObject { +public class WorkflowState implements ToXContentObject, Writeable { private String workflowId; private String error; private String state; - // TODO: Tranisiton the provisioning progress from a string to detailed array of objects + // TODO: Transition the provisioning progress from a string to detailed array of objects private String provisioningProgress; private Instant provisionStartTime; private Instant provisionEndTime; private User user; private Map userOutputs; - private Map resourcesCreated; + private List resourcesCreated; /** * Instantiate the object representing the workflow state @@ -69,7 +78,7 @@ public WorkflowState( Instant provisionEndTime, User user, Map userOutputs, - Map resourcesCreated + List resourcesCreated ) { this.workflowId = workflowId; this.error = error; @@ -79,11 +88,29 @@ public WorkflowState( this.provisionEndTime = provisionEndTime; this.user = user; this.userOutputs = Map.copyOf(userOutputs); - this.resourcesCreated = Map.copyOf(resourcesCreated); + this.resourcesCreated = List.copyOf(resourcesCreated); } private WorkflowState() {} + /** + * Instatiates a new WorkflowState from an input stream + * @param input the input stream to read from + * @throws IOException if the workflowId cannot be read from the input stream + */ + public WorkflowState(StreamInput input) throws IOException { + this.workflowId = input.readString(); + this.error = input.readOptionalString(); + this.state = input.readOptionalString(); + this.provisioningProgress = input.readOptionalString(); + this.provisionStartTime = input.readOptionalInstant(); + this.provisionEndTime = input.readOptionalInstant(); + // TODO: fix error: cannot access Response issue when integrating with access control + // this.user = input.readBoolean() ? new User(input) : null; + this.userOutputs = input.readBoolean() ? input.readMap() : null; + this.resourcesCreated = input.readList(ResourceCreated::new); + } + /** * Constructs a builder object for workflowState * @return Builder Object @@ -104,7 +131,7 @@ public static class Builder { private Instant provisionEndTime = null; private User user = null; private Map userOutputs = null; - private Map resourcesCreated = null; + private List resourcesCreated = null; /** * Empty Constructor for the Builder object @@ -196,8 +223,8 @@ public Builder userOutputs(Map userOutputs) { * @param resourcesCreated resourcesCreated * @return the Builder object */ - public Builder resourcesCreated(Map resourcesCreated) { - this.userOutputs = resourcesCreated; + public Builder resourcesCreated(List resourcesCreated) { + this.resourcesCreated = resourcesCreated; return this; } @@ -248,11 +275,35 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws xContentBuilder.field(USER_OUTPUTS_FIELD, userOutputs); } if (resourcesCreated != null && !resourcesCreated.isEmpty()) { - xContentBuilder.field(RESOURCES_CREATED_FIELD, resourcesCreated); + xContentBuilder.field(RESOURCES_CREATED_FIELD, resourcesCreated.toArray()); } return xContentBuilder.endObject(); } + @Override + public void writeTo(StreamOutput output) throws IOException { + output.writeString(workflowId); + output.writeOptionalString(error); + output.writeOptionalString(state); + output.writeOptionalString(provisioningProgress); + output.writeOptionalInstant(provisionStartTime); + output.writeOptionalInstant(provisionEndTime); + + if (user != null) { + user.writeTo(output); + } else { + output.writeBoolean(false); + } + + if (userOutputs != null) { + output.writeBoolean(true); + output.writeMap(userOutputs); + } else { + output.writeBoolean(false); + } + output.writeList(resourcesCreated); + } + /** * Parse raw json content into a Template instance. * @@ -269,7 +320,7 @@ public static WorkflowState parse(XContentParser parser) throws IOException { Instant provisionEndTime = null; User user = null; Map userOutputs = new HashMap<>(); - Map resourcesCreated = new HashMap<>(); + List resourcesCreated = new ArrayList<>(); ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -313,23 +364,17 @@ public static WorkflowState parse(XContentParser parser) throws IOException { } } break; - case RESOURCES_CREATED_FIELD: - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); - while (parser.nextToken() != XContentParser.Token.END_OBJECT) { - String resourcesCreatedField = parser.currentName(); - switch (parser.nextToken()) { - case VALUE_STRING: - resourcesCreated.put(resourcesCreatedField, parser.text()); - break; - case START_OBJECT: - resourcesCreated.put(resourcesCreatedField, parseStringToStringMap(parser)); - break; - default: - throw new IOException( - "Unable to parse field [" + resourcesCreatedField + "] in a resources_created object." - ); + try { + ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_ARRAY) { + resourcesCreated.add(ResourceCreated.parse(parser)); } + } catch (Exception e) { + if (e instanceof ParsingException || e instanceof XContentParseException) { + throw new FlowFrameworkException("Error parsing newly created resources", RestStatus.INTERNAL_SERVER_ERROR); + } + throw e; } break; default: @@ -361,7 +406,7 @@ public String getWorkflowId() { * @return the error */ public String getError() { - return workflowId; + return error; } /** @@ -416,7 +461,7 @@ public Map userOutputs() { * A map of all the resources created * @return the resources created */ - public Map resourcesCreated() { + public List resourcesCreated() { return resourcesCreated; } diff --git a/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java b/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java index 7bd5fbfbe..5e8509373 100644 --- a/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java +++ b/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java @@ -47,7 +47,6 @@ public class RestCreateWorkflowAction extends AbstractWorkflowAction { /** * Instantiates a new RestCreateWorkflowAction - * * @param flowFrameworkFeatureEnabledSetting Whether this API is enabled * @param settings Environment settings * @param clusterService clusterService diff --git a/src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowAction.java b/src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowAction.java new file mode 100644 index 000000000..6d9d5e3b5 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowAction.java @@ -0,0 +1,110 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.rest; + +import com.google.common.collect.ImmutableList; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ExceptionsHelper; +import org.opensearch.client.node.NodeClient; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.flowframework.common.FlowFrameworkFeatureEnabledSetting; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.transport.GetWorkflowAction; +import org.opensearch.flowframework.transport.GetWorkflowRequest; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.BytesRestResponse; +import org.opensearch.rest.RestRequest; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; + +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_ID; +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI; +import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED; + +/** + * Rest Action to facilitate requests to get a workflow status + */ +public class RestGetWorkflowAction extends BaseRestHandler { + + private static final String GET_WORKFLOW_ACTION = "get_workflow"; + private static final Logger logger = LogManager.getLogger(RestGetWorkflowAction.class); + private FlowFrameworkFeatureEnabledSetting flowFrameworkFeatureEnabledSetting; + + /** + * Instantiates a new RestGetWorkflowAction + * @param flowFrameworkFeatureEnabledSetting Whether this API is enabled + */ + public RestGetWorkflowAction(FlowFrameworkFeatureEnabledSetting flowFrameworkFeatureEnabledSetting) { + this.flowFrameworkFeatureEnabledSetting = flowFrameworkFeatureEnabledSetting; + } + + @Override + public String getName() { + return GET_WORKFLOW_ACTION; + } + + @Override + protected BaseRestHandler.RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + + try { + if (!flowFrameworkFeatureEnabledSetting.isFlowFrameworkEnabled()) { + throw new FlowFrameworkException( + "This API is disabled. To enable it, update the setting [" + FLOW_FRAMEWORK_ENABLED.getKey() + "] to true.", + RestStatus.FORBIDDEN + ); + } + + // Validate content + if (request.hasContent()) { + throw new FlowFrameworkException("No request body present", RestStatus.BAD_REQUEST); + } + // Validate params + String workflowId = request.param(WORKFLOW_ID); + if (workflowId == null) { + throw new FlowFrameworkException("workflow_id cannot be null", RestStatus.BAD_REQUEST); + } + + boolean all = request.paramAsBoolean("all", false); + GetWorkflowRequest getWorkflowRequest = new GetWorkflowRequest(workflowId, all); + return channel -> client.execute(GetWorkflowAction.INSTANCE, getWorkflowRequest, ActionListener.wrap(response -> { + XContentBuilder builder = response.toXContent(channel.newBuilder(), ToXContent.EMPTY_PARAMS); + channel.sendResponse(new BytesRestResponse(RestStatus.OK, builder)); + }, exception -> { + try { + FlowFrameworkException ex = new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception)); + XContentBuilder exceptionBuilder = ex.toXContent(channel.newErrorBuilder(), ToXContent.EMPTY_PARAMS); + channel.sendResponse(new BytesRestResponse(ex.getRestStatus(), exceptionBuilder)); + + } catch (IOException e) { + logger.error("Failed to send back provision workflow exception", e); + channel.sendResponse(new BytesRestResponse(ExceptionsHelper.status(e), e.getMessage())); + } + })); + + } catch (FlowFrameworkException ex) { + return channel -> channel.sendResponse( + new BytesRestResponse(ex.getRestStatus(), ex.toXContent(channel.newErrorBuilder(), ToXContent.EMPTY_PARAMS)) + ); + } + } + + @Override + public List routes() { + return ImmutableList.of( + // Provision workflow from indexed use case template + new Route(RestRequest.Method.GET, String.format(Locale.ROOT, "%s/{%s}/%s", WORKFLOW_URI, WORKFLOW_ID, "_status")) + ); + } +} diff --git a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowAction.java b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowAction.java index ff626ffd0..91fc37d92 100644 --- a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowAction.java @@ -10,7 +10,7 @@ import org.opensearch.action.ActionType; -import static org.opensearch.flowframework.common.CommonValue.TRANSPORT_ACION_NAME_PREFIX; +import static org.opensearch.flowframework.common.CommonValue.TRANSPORT_ACTION_NAME_PREFIX; /** * External Action for public facing RestCreateWorkflowAction @@ -18,7 +18,7 @@ public class CreateWorkflowAction extends ActionType { /** The name of this action */ - public static final String NAME = TRANSPORT_ACION_NAME_PREFIX + "workflow/create"; + public static final String NAME = TRANSPORT_ACTION_NAME_PREFIX + "workflow/create"; /** An instance of this action */ public static final CreateWorkflowAction INSTANCE = new CreateWorkflowAction(); diff --git a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java index a782740c4..a77e7dd79 100644 --- a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java @@ -195,26 +195,30 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener internalListener) { - QueryBuilder query = QueryBuilders.matchAllQuery(); - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query).size(0).timeout(requestTimeOut); + if (!flowFrameworkIndicesHandler.doesIndexExist(CommonValue.GLOBAL_CONTEXT_INDEX)) { + internalListener.onResponse(true); + } else { + QueryBuilder query = QueryBuilders.matchAllQuery(); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query).size(0).timeout(requestTimeOut); - SearchRequest searchRequest = new SearchRequest(CommonValue.GLOBAL_CONTEXT_INDEX).source(searchSourceBuilder); + SearchRequest searchRequest = new SearchRequest(CommonValue.GLOBAL_CONTEXT_INDEX).source(searchSourceBuilder); - client.search(searchRequest, ActionListener.wrap(searchResponse -> { - if (searchResponse.getHits().getTotalHits().value >= maxWorkflow) { - internalListener.onResponse(false); - } else { - internalListener.onResponse(true); - } - }, exception -> { - logger.error("Unable to fetch the workflows {}", exception); - internalListener.onFailure(new FlowFrameworkException("Unable to fetch the workflows", RestStatus.BAD_REQUEST)); - })); + client.search(searchRequest, ActionListener.wrap(searchResponse -> { + if (searchResponse.getHits().getTotalHits().value >= maxWorkflow) { + internalListener.onResponse(false); + } else { + internalListener.onResponse(true); + } + }, exception -> { + logger.error("Unable to fetch the workflows {}", exception); + internalListener.onFailure(new FlowFrameworkException("Unable to fetch the workflows", RestStatus.BAD_REQUEST)); + })); + } } private void validateWorkflows(Template template) throws Exception { for (Workflow workflow : template.workflows().values()) { - List sortedNodes = workflowProcessSorter.sortProcessNodes(workflow); + List sortedNodes = workflowProcessSorter.sortProcessNodes(workflow, null); workflowProcessSorter.validateGraph(sortedNodes); } } diff --git a/src/main/java/org/opensearch/flowframework/transport/GetWorkflowAction.java b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowAction.java new file mode 100644 index 000000000..6abee3867 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowAction.java @@ -0,0 +1,27 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.transport; + +import org.opensearch.action.ActionType; + +import static org.opensearch.flowframework.common.CommonValue.TRANSPORT_ACTION_NAME_PREFIX; + +/** + * External Action for public facing RestGetWorkflowAction + */ +public class GetWorkflowAction extends ActionType { + /** The name of this action */ + public static final String NAME = TRANSPORT_ACTION_NAME_PREFIX + "workflow/get"; + /** An instance of this action */ + public static final GetWorkflowAction INSTANCE = new GetWorkflowAction(); + + private GetWorkflowAction() { + super(NAME, GetWorkflowResponse::new); + } +} diff --git a/src/main/java/org/opensearch/flowframework/transport/GetWorkflowRequest.java b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowRequest.java new file mode 100644 index 000000000..c7594eb77 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowRequest.java @@ -0,0 +1,84 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.transport; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.common.Nullable; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +import java.io.IOException; + +/** + * Transport Request to get a workflow or workflow status + */ +public class GetWorkflowRequest extends ActionRequest { + + /** + * The documentId of the workflow entry within the Global Context index + */ + @Nullable + private String workflowId; + + /** + * The all parameter on the get request + */ + private boolean all; + + /** + * Instantiates a new GetWorkflowRequest + * @param workflowId the documentId of the workflow + * @param all whether the get request is looking for all fields in status + */ + public GetWorkflowRequest(@Nullable String workflowId, boolean all) { + this.workflowId = workflowId; + this.all = all; + } + + /** + * Instantiates a new GetWorkflowRequest request + * @param in The input stream to read from + * @throws IOException If the stream cannot be read properly + */ + public GetWorkflowRequest(StreamInput in) throws IOException { + super(in); + this.workflowId = in.readString(); + this.all = in.readBoolean(); + } + + /** + * Gets the workflow Id of the request + * @return the workflow Id + */ + @Nullable + public String getWorkflowId() { + return this.workflowId; + } + + /** + * Gets the value of the all parameter + * @return whether the all parameter was present or not in request + */ + public boolean getAll() { + return this.all; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(workflowId); + out.writeBoolean(all); + } + + @Override + public ActionRequestValidationException validate() { + return null; + } +} diff --git a/src/main/java/org/opensearch/flowframework/transport/GetWorkflowResponse.java b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowResponse.java new file mode 100644 index 000000000..b2c9bb884 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowResponse.java @@ -0,0 +1,65 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.transport; + +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.flowframework.model.WorkflowState; + +import java.io.IOException; + +/** + * Transport Response from getting a workflow status + */ +public class GetWorkflowResponse extends ActionResponse implements ToXContentObject { + + public WorkflowState workflowState; + public boolean allStatus; + + /** + * Instantiates a new GetWorkflowResponse from an input stream + * @param in the input stream to read from + * @throws IOException if the workflowId cannot be read from the input stream + */ + public GetWorkflowResponse(StreamInput in) throws IOException { + super(in); + workflowState = new WorkflowState(in); + allStatus = in.readBoolean(); + } + + /** + * Instatiates a new GetWorkflowResponse from an input stream + * @param workflowState the workflow state object + * @param allStatus whether to return all fields in state index + */ + public GetWorkflowResponse(WorkflowState workflowState, boolean allStatus) { + if (allStatus) { + this.workflowState = workflowState; + } else { + this.workflowState = new WorkflowState.Builder().workflowId(workflowState.getWorkflowId()) + .error(workflowState.getError()) + .state(workflowState.getState()) + .resourcesCreated(workflowState.resourcesCreated()) + .build(); + } + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + workflowState.writeTo(out); + } + + @Override + public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params) throws IOException { + return workflowState.toXContent(xContentBuilder, params); + } +} diff --git a/src/main/java/org/opensearch/flowframework/transport/GetWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowTransportAction.java new file mode 100644 index 000000000..f3bc1dd9e --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowTransportAction.java @@ -0,0 +1,99 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.transport; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ExceptionsHelper; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.model.WorkflowState; +import org.opensearch.flowframework.util.ParseUtils; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; + +//TODO: Currently we only get the workflow status but we should change to be able to get the +// full template as well +/** + * Transport Action to get a specific workflow. Currently, we only support the action with _status + * in the API path but will add the ability to get the workflow and not just the status in the future + */ +public class GetWorkflowTransportAction extends HandledTransportAction { + + private final Logger logger = LogManager.getLogger(GetWorkflowTransportAction.class); + + private final Client client; + private final NamedXContentRegistry xContentRegistry; + + /** + * Intantiates a new CreateWorkflowTransportAction + * @param transportService The TransportService + * @param actionFilters action filters + * @param client The client used to make the request to OS + * @param xContentRegistry contentRegister to parse get response + */ + @Inject + public GetWorkflowTransportAction( + TransportService transportService, + ActionFilters actionFilters, + Client client, + NamedXContentRegistry xContentRegistry + ) { + super(GetWorkflowAction.NAME, transportService, actionFilters, GetWorkflowRequest::new); + this.client = client; + this.xContentRegistry = xContentRegistry; + } + + @Override + protected void doExecute(Task task, GetWorkflowRequest request, ActionListener listener) { + String workflowId = request.getWorkflowId(); + User user = ParseUtils.getUserContext(client); + GetRequest getRequest = new GetRequest(WORKFLOW_STATE_INDEX).id(workflowId); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + client.get(getRequest, ActionListener.runBefore(ActionListener.wrap(r -> { + if (r != null && r.isExists()) { + try (XContentParser parser = ParseUtils.createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + WorkflowState workflowState = WorkflowState.parse(parser); + listener.onResponse(new GetWorkflowResponse(workflowState, request.getAll())); + } catch (Exception e) { + logger.error("Failed to parse workflowState" + r.getId(), e); + listener.onFailure(new FlowFrameworkException("Failed to parse workflowState" + r.getId(), RestStatus.BAD_REQUEST)); + } + } else { + listener.onFailure(new FlowFrameworkException("Fail to find workflow", RestStatus.NOT_FOUND)); + } + }, e -> { + if (e instanceof IndexNotFoundException) { + listener.onFailure(new FlowFrameworkException("Fail to find workflow", RestStatus.NOT_FOUND)); + } else { + logger.error("Failed to get workflow status of: " + workflowId, e); + listener.onFailure(new FlowFrameworkException("Failed to get workflow status of: " + workflowId, RestStatus.NOT_FOUND)); + } + }), () -> context.restore())); + } catch (Exception e) { + logger.error("Failed to get workflow: " + workflowId, e); + listener.onFailure(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); + } + } +} diff --git a/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowAction.java b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowAction.java index 022e73488..30f9f1437 100644 --- a/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowAction.java @@ -10,14 +10,14 @@ import org.opensearch.action.ActionType; -import static org.opensearch.flowframework.common.CommonValue.TRANSPORT_ACION_NAME_PREFIX; +import static org.opensearch.flowframework.common.CommonValue.TRANSPORT_ACTION_NAME_PREFIX; /** * External Action for public facing RestProvisionWorkflowAction */ public class ProvisionWorkflowAction extends ActionType { /** The name of this action */ - public static final String NAME = TRANSPORT_ACION_NAME_PREFIX + "workflow/provision"; + public static final String NAME = TRANSPORT_ACTION_NAME_PREFIX + "workflow/provision"; /** An instance of this action */ public static final ProvisionWorkflowAction INSTANCE = new ProvisionWorkflowAction(); diff --git a/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java index 4210effb0..1f19a4d04 100644 --- a/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java @@ -34,16 +34,20 @@ import java.time.Instant; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.Locale; import java.util.concurrent.CompletableFuture; import java.util.stream.Collectors; +import static org.opensearch.flowframework.common.CommonValue.ERROR_FIELD; import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX; import static org.opensearch.flowframework.common.CommonValue.PROVISIONING_PROGRESS_FIELD; +import static org.opensearch.flowframework.common.CommonValue.PROVISION_END_TIME_FIELD; import static org.opensearch.flowframework.common.CommonValue.PROVISION_START_TIME_FIELD; import static org.opensearch.flowframework.common.CommonValue.PROVISION_THREAD_POOL; import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW; +import static org.opensearch.flowframework.common.CommonValue.RESOURCES_CREATED_FIELD; import static org.opensearch.flowframework.common.CommonValue.STATE_FIELD; /** @@ -106,10 +110,9 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener provisionProcessSequence = workflowProcessSorter.sortProcessNodes(provisionWorkflow); + List provisionProcessSequence = workflowProcessSorter.sortProcessNodes(provisionWorkflow, workflowId); workflowProcessSorter.validateGraph(provisionProcessSequence); flowFrameworkIndicesHandler.updateFlowFrameworkSystemIndexDoc( @@ -120,7 +123,9 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener { logger.info("updated workflow {} state to PROVISIONING", request.getWorkflowId()); @@ -129,7 +134,7 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener { if (exception instanceof FlowFrameworkException) { @@ -150,31 +155,22 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener workflowSequence) { - // TODO : Update Action listener type to State index Request - ActionListener provisionWorkflowListener = ActionListener.wrap(response -> { - logger.info("Provisioning completed successuflly for workflow {}", workflowId); - - // TODO : Create State index request to update STATE entry status to READY - }, exception -> { - logger.error("Provisioning failed for workflow {} : {}", workflowId, exception); - - // TODO : Create State index request to update STATE entry status to FAILED - }); + private void executeWorkflowAsync(String workflowId, List workflowSequence, ActionListener listener) { try { - threadPool.executor(PROVISION_THREAD_POOL).execute(() -> { executeWorkflow(workflowSequence, provisionWorkflowListener); }); + threadPool.executor(PROVISION_THREAD_POOL).execute(() -> { executeWorkflow(workflowSequence, workflowId); }); } catch (Exception exception) { - provisionWorkflowListener.onFailure(new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception))); + listener.onFailure(new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception))); } } /** * Executes the given workflow sequence * @param workflowSequence The topologically sorted workflow to execute - * @param workflowListener The listener that updates the status of a workflow execution + * @param workflowId The workflowId associated with the workflow that is executing */ - private void executeWorkflow(List workflowSequence, ActionListener workflowListener) { + private void executeWorkflow(List workflowSequence, String workflowId) { try { List> workflowFutureList = new ArrayList<>(); @@ -199,13 +195,39 @@ private void executeWorkflow(List workflowSequence, ActionListener< // Attempt to join each workflow step future, may throw a CompletionException if any step completes exceptionally workflowFutureList.forEach(CompletableFuture::join); - // TODO : Create State Index request with provisioning state, start time, end time, etc, pending implementation. String for now - workflowListener.onResponse("READY"); - - } catch (IllegalArgumentException e) { - workflowListener.onFailure(new FlowFrameworkException(e.getMessage(), RestStatus.BAD_REQUEST)); + logger.info("Provisioning completed successfully for workflow {}", workflowId); + flowFrameworkIndicesHandler.updateFlowFrameworkSystemIndexDoc( + workflowId, + ImmutableMap.of( + STATE_FIELD, + State.COMPLETED, + PROVISIONING_PROGRESS_FIELD, + ProvisioningProgress.DONE, + PROVISION_END_TIME_FIELD, + Instant.now().toEpochMilli() + ), + ActionListener.wrap(updateResponse -> { + logger.info("updated workflow {} state to {}", workflowId, State.COMPLETED); + }, exception -> { logger.error("Failed to update workflow state : {}", exception.getMessage()); }) + ); } catch (Exception ex) { - workflowListener.onFailure(new FlowFrameworkException(ex.getMessage(), ExceptionsHelper.status(ex))); + logger.error("Provisioning failed for workflow {} : {}", workflowId, ex); + flowFrameworkIndicesHandler.updateFlowFrameworkSystemIndexDoc( + workflowId, + ImmutableMap.of( + STATE_FIELD, + State.FAILED, + ERROR_FIELD, + ex.getMessage(), // TODO: potentially improve the error message here + PROVISIONING_PROGRESS_FIELD, + ProvisioningProgress.FAILED, + PROVISION_END_TIME_FIELD, + Instant.now().toEpochMilli() + ), + ActionListener.wrap(updateResponse -> { + logger.info("updated workflow {} state to {}", workflowId, State.COMPLETED); + }, exceptionState -> { logger.error("Failed to update workflow state : {}", exceptionState.getMessage()); }) + ); } } diff --git a/src/main/java/org/opensearch/flowframework/transport/SearchWorkflowAction.java b/src/main/java/org/opensearch/flowframework/transport/SearchWorkflowAction.java index 7f29174d9..04f4be2ba 100644 --- a/src/main/java/org/opensearch/flowframework/transport/SearchWorkflowAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/SearchWorkflowAction.java @@ -11,7 +11,7 @@ import org.opensearch.action.ActionType; import org.opensearch.action.search.SearchResponse; -import static org.opensearch.flowframework.common.CommonValue.TRANSPORT_ACION_NAME_PREFIX; +import static org.opensearch.flowframework.common.CommonValue.TRANSPORT_ACTION_NAME_PREFIX; /** * External Action for public facing RestSearchWorkflowAction @@ -19,7 +19,7 @@ public class SearchWorkflowAction extends ActionType { /** The name of this action */ - public static final String NAME = TRANSPORT_ACION_NAME_PREFIX + "workflow/search"; + public static final String NAME = TRANSPORT_ACTION_NAME_PREFIX + "workflow/search"; /** An instance of this action */ public static final SearchWorkflowAction INSTANCE = new SearchWorkflowAction(); diff --git a/src/main/java/org/opensearch/flowframework/util/ParseUtils.java b/src/main/java/org/opensearch/flowframework/util/ParseUtils.java index 338f23cdc..0f725f687 100644 --- a/src/main/java/org/opensearch/flowframework/util/ParseUtils.java +++ b/src/main/java/org/opensearch/flowframework/util/ParseUtils.java @@ -12,9 +12,12 @@ import org.apache.logging.log4j.Logger; import org.opensearch.client.Client; import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentHelper; +import org.opensearch.common.xcontent.XContentType; import org.opensearch.common.xcontent.json.JsonXContent; import org.opensearch.commons.ConfigConstants; import org.opensearch.commons.authuser.User; +import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; @@ -113,4 +116,17 @@ public static User getUserContext(Client client) { return User.parse(userStr); } + /** + * Creates a XContentParser from a given Registry + * + * @param xContentRegistry main registry for serializable content + * @param bytesReference given bytes to be parsed + * @return bytesReference of {@link java.time.Instant} + * @throws IOException IOException if content can't be parsed correctly + */ + public static XContentParser createXContentParserFromRegistry(NamedXContentRegistry xContentRegistry, BytesReference bytesReference) + throws IOException { + return XContentHelper.createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, bytesReference, XContentType.JSON); + } + } diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java b/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java index 533d82c1e..41bd71489 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java @@ -11,14 +11,21 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.ExceptionsHelper; +import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; +import org.opensearch.flowframework.model.ResourceCreated; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.connector.ConnectorAction; import org.opensearch.ml.common.connector.ConnectorAction.ActionType; import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput; import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse; +import org.opensearch.script.Script; +import org.opensearch.script.ScriptType; import java.io.IOException; import java.security.AccessController; @@ -41,6 +48,7 @@ import static org.opensearch.flowframework.common.CommonValue.PARAMETERS_FIELD; import static org.opensearch.flowframework.common.CommonValue.PROTOCOL_FIELD; import static org.opensearch.flowframework.common.CommonValue.VERSION_FIELD; +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; /** * Step to create a connector for a remote model @@ -50,17 +58,21 @@ public class CreateConnectorStep implements WorkflowStep { private static final Logger logger = LogManager.getLogger(CreateConnectorStep.class); private MachineLearningNodeClient mlClient; + private final FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; static final String NAME = "create_connector"; /** * Instantiate this class * @param mlClient client to instantiate MLClient + * @param flowFrameworkIndicesHandler FlowFrameworkIndicesHandler class to update system indices */ - public CreateConnectorStep(MachineLearningNodeClient mlClient) { + public CreateConnectorStep(MachineLearningNodeClient mlClient, FlowFrameworkIndicesHandler flowFrameworkIndicesHandler) { this.mlClient = mlClient; + this.flowFrameworkIndicesHandler = flowFrameworkIndicesHandler; } + // TODO: need to add retry conflicts here @Override public CompletableFuture execute(List data) throws IOException { CompletableFuture createConnectorFuture = new CompletableFuture<>(); @@ -69,11 +81,44 @@ public CompletableFuture execute(List data) throws I @Override public void onResponse(MLCreateConnectorResponse mlCreateConnectorResponse) { - logger.info("Created connector successfully"); - // TODO Add the response to Global Context createConnectorFuture.complete( - new WorkflowData(Map.ofEntries(Map.entry("connector_id", mlCreateConnectorResponse.getConnectorId()))) + new WorkflowData( + Map.ofEntries(Map.entry("connector_id", mlCreateConnectorResponse.getConnectorId())), + data.get(0).getWorkflowId() + ) ); + try { + logger.info("Created connector successfully"); + String workflowId = data.get(0).getWorkflowId(); + String workflowStepName = getName(); + ResourceCreated newResource = new ResourceCreated(workflowStepName, mlCreateConnectorResponse.getConnectorId()); + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + newResource.toXContent(builder, ToXContentObject.EMPTY_PARAMS); + + // The script to append a new object to the resources_created array + Script script = new Script( + ScriptType.INLINE, + "painless", + "ctx._source.resources_created.add(params.newResource)", + Collections.singletonMap("newResource", newResource) + ); + + flowFrameworkIndicesHandler.updateFlowFrameworkSystemIndexDocWithScript( + WORKFLOW_STATE_INDEX, + workflowId, + script, + ActionListener.wrap(updateResponse -> { + logger.info("updated resources created of {}", workflowId); + }, exception -> { + createConnectorFuture.completeExceptionally( + new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception)) + ); + logger.error("Failed to update workflow state with newly created resource: {}", exception); + }) + ); + } catch (IOException e) { + logger.error("Failed to parse new created resource", e); + } } @Override diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java b/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java index 5fe47b2b0..f3a82b26c 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java @@ -61,7 +61,7 @@ public CompletableFuture execute(List data) { @Override public void onResponse(CreateIndexResponse createIndexResponse) { logger.info("created index: {}", createIndexResponse.index()); - future.complete(new WorkflowData(Map.of(INDEX_NAME, createIndexResponse.index()))); + future.complete(new WorkflowData(Map.of(INDEX_NAME, createIndexResponse.index()), data.get(0).getWorkflowId())); } @Override diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStep.java b/src/main/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStep.java index b8cc83651..a63a800fd 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStep.java @@ -125,7 +125,9 @@ public CompletableFuture execute(List data) { logger.info("Created ingest pipeline : " + putPipelineRequest.getId()); // PutPipelineRequest returns only an AcknowledgeResponse, returning pipelineId instead - createIngestPipelineFuture.complete(new WorkflowData(Map.of(PIPELINE_ID, putPipelineRequest.getId()))); + createIngestPipelineFuture.complete( + new WorkflowData(Map.of(PIPELINE_ID, putPipelineRequest.getId()), data.get(0).getWorkflowId()) + ); // TODO : Use node client to index response data to global context (pending global context index implementation) diff --git a/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java index ba22f3682..8ce89176c 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java @@ -50,7 +50,10 @@ public CompletableFuture execute(List data) { public void onResponse(MLDeployModelResponse mlDeployModelResponse) { logger.info("Model deployment state {}", mlDeployModelResponse.getStatus()); deployModelFuture.complete( - new WorkflowData(Map.ofEntries(Map.entry("deploy_model_status", mlDeployModelResponse.getStatus()))) + new WorkflowData( + Map.ofEntries(Map.entry("deploy_model_status", mlDeployModelResponse.getStatus())), + data.get(0).getWorkflowId() + ) ); } diff --git a/src/main/java/org/opensearch/flowframework/workflow/GetMLTaskStep.java b/src/main/java/org/opensearch/flowframework/workflow/GetMLTaskStep.java index 893f34a0d..ac84aaaa0 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/GetMLTaskStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/GetMLTaskStep.java @@ -56,7 +56,8 @@ public CompletableFuture execute(List data) { logger.info("ML Task retrieval successful"); getMLTaskFuture.complete( new WorkflowData( - Map.ofEntries(Map.entry(MODEL_ID, response.getModelId()), Map.entry(REGISTER_MODEL_STATUS, response.getState().name())) + Map.ofEntries(Map.entry(MODEL_ID, response.getModelId()), Map.entry(REGISTER_MODEL_STATUS, response.getState().name())), + data.get(0).getWorkflowId() ) ); }, exception -> { diff --git a/src/main/java/org/opensearch/flowframework/workflow/ModelGroupStep.java b/src/main/java/org/opensearch/flowframework/workflow/ModelGroupStep.java index 35a3bdfff..89c15c445 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/ModelGroupStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/ModelGroupStep.java @@ -66,7 +66,8 @@ public void onResponse(MLRegisterModelGroupResponse mlRegisterModelGroupResponse Map.ofEntries( Map.entry("model_group_id", mlRegisterModelGroupResponse.getModelGroupId()), Map.entry("model_group_status", mlRegisterModelGroupResponse.getStatus()) - ) + ), + data.get(0).getWorkflowId() ) ); } diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java index 17dd0b068..ad6cbff8f 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java @@ -77,7 +77,8 @@ public void onResponse(MLRegisterModelResponse mlRegisterModelResponse) { Map.ofEntries( Map.entry(TASK_ID, mlRegisterModelResponse.getTaskId()), Map.entry(REGISTER_MODEL_STATUS, mlRegisterModelResponse.getStatus()) - ) + ), + data.get(0).getWorkflowId() ) ); } diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java index 4dedc8bf2..d91cfc0e8 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java @@ -68,7 +68,8 @@ public void onResponse(MLRegisterModelResponse mlRegisterModelResponse) { Map.ofEntries( Map.entry(MODEL_ID, mlRegisterModelResponse.getModelId()), Map.entry(REGISTER_MODEL_STATUS, mlRegisterModelResponse.getStatus()) - ) + ), + data.get(0).getWorkflowId() ) ); } diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowData.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowData.java index 35ffb7e75..4f62885e9 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowData.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowData.java @@ -8,6 +8,8 @@ */ package org.opensearch.flowframework.workflow; +import org.opensearch.common.Nullable; + import java.util.Collections; import java.util.Map; @@ -24,26 +26,32 @@ public class WorkflowData { private final Map content; private final Map params; + @Nullable + private String workflowId; + private WorkflowData() { - this(Collections.emptyMap(), Collections.emptyMap()); + this(Collections.emptyMap(), Collections.emptyMap(), ""); } /** * Instantiate this object with content and empty params. * @param content The content map + * @param workflowId The workflow ID associated with this step */ - public WorkflowData(Map content) { - this(content, Collections.emptyMap()); + public WorkflowData(Map content, @Nullable String workflowId) { + this(content, Collections.emptyMap(), workflowId); } /** * Instantiate this object with content and params. * @param content The content map * @param params The params map + * @param workflowId The workflow ID associated with this step */ - public WorkflowData(Map content, Map params) { + public WorkflowData(Map content, Map params, @Nullable String workflowId) { this.content = Map.copyOf(content); this.params = Map.copyOf(params); + this.workflowId = workflowId; } /** @@ -62,4 +70,13 @@ public Map getContent() { public Map getParams() { return this.params; }; + + /** + * Returns the workflowId associated with this workflow. + * @return the workflowId of this data. + */ + @Nullable + public String getWorkflowId() { + return this.workflowId; + }; } diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java index 10a038cbb..cae91bcbc 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java @@ -61,16 +61,17 @@ public WorkflowProcessSorter(WorkflowStepFactory workflowStepFactory, ThreadPool /** * Sort a workflow into a topologically sorted list of process nodes. * @param workflow A workflow with (unsorted) nodes and edges which define predecessors and successors + * @param workflowId The workflowId associated with the step * @return A list of Process Nodes sorted topologically. All predecessors of any node will occur prior to it in the list. */ - public List sortProcessNodes(Workflow workflow) { + public List sortProcessNodes(Workflow workflow, String workflowId) { List sortedNodes = topologicalSort(workflow.nodes(), workflow.edges()); List nodes = new ArrayList<>(); Map idToNodeMap = new HashMap<>(); for (WorkflowNode node : sortedNodes) { WorkflowStep step = workflowStepFactory.createStep(node.type()); - WorkflowData data = new WorkflowData(node.userInputs(), workflow.userParams()); + WorkflowData data = new WorkflowData(node.userInputs(), workflow.userParams(), workflowId); List predecessorNodes = workflow.edges() .stream() .filter(e -> e.destination().equals(node.id())) diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStep.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStep.java index bdbe63caa..1f5545cdf 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStep.java @@ -30,5 +30,4 @@ public interface WorkflowStep { * @return the name of this workflow step. */ String getName(); - } diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java index af83f0ad9..fc65286e7 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java @@ -12,6 +12,7 @@ import org.opensearch.cluster.service.ClusterService; import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.ml.client.MachineLearningNodeClient; import java.util.HashMap; @@ -23,6 +24,7 @@ public class WorkflowStepFactory { private final Map stepMap = new HashMap<>(); + private final FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; /** * Instantiate this class. @@ -30,20 +32,31 @@ public class WorkflowStepFactory { * @param clusterService The OpenSearch cluster service * @param client The OpenSearch client steps can use * @param mlClient Machine Learning client to perform ml operations + * @param flowFrameworkIndicesHandler FlowFrameworkIndicesHandler class to update system indices */ - - public WorkflowStepFactory(ClusterService clusterService, Client client, MachineLearningNodeClient mlClient) { - populateMap(clusterService, client, mlClient); + public WorkflowStepFactory( + ClusterService clusterService, + Client client, + MachineLearningNodeClient mlClient, + FlowFrameworkIndicesHandler flowFrameworkIndicesHandler + ) { + this.flowFrameworkIndicesHandler = flowFrameworkIndicesHandler; + populateMap(clusterService, client, mlClient, flowFrameworkIndicesHandler); } - private void populateMap(ClusterService clusterService, Client client, MachineLearningNodeClient mlClient) { + private void populateMap( + ClusterService clusterService, + Client client, + MachineLearningNodeClient mlClient, + FlowFrameworkIndicesHandler flowFrameworkIndicesHandler + ) { stepMap.put(NoOpStep.NAME, new NoOpStep()); stepMap.put(CreateIndexStep.NAME, new CreateIndexStep(clusterService, client)); stepMap.put(CreateIngestPipelineStep.NAME, new CreateIngestPipelineStep(client)); stepMap.put(RegisterLocalModelStep.NAME, new RegisterLocalModelStep(mlClient)); stepMap.put(RegisterRemoteModelStep.NAME, new RegisterRemoteModelStep(mlClient)); stepMap.put(DeployModelStep.NAME, new DeployModelStep(mlClient)); - stepMap.put(CreateConnectorStep.NAME, new CreateConnectorStep(mlClient)); + stepMap.put(CreateConnectorStep.NAME, new CreateConnectorStep(mlClient, flowFrameworkIndicesHandler)); stepMap.put(ModelGroupStep.NAME, new ModelGroupStep(mlClient)); stepMap.put(GetMLTaskStep.NAME, new GetMLTaskStep(mlClient)); } diff --git a/src/main/resources/mappings/workflow-state.json b/src/main/resources/mappings/workflow-state.json index 86fbeef6e..21df5ccd6 100644 --- a/src/main/resources/mappings/workflow-state.json +++ b/src/main/resources/mappings/workflow-state.json @@ -31,7 +31,15 @@ "type": "object" }, "resources_created": { - "type": "object" + "type": "nested", + "properties": { + "workflow_step_name": { + "type": "keyword" + }, + "resource_id": { + "type": "keyword" + } + } }, "ui_metadata": { "type": "object", diff --git a/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java b/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java index d0b35f9d6..c8fcdf69e 100644 --- a/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java +++ b/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java @@ -80,8 +80,8 @@ public void testPlugin() throws IOException { 3, ffp.createComponents(client, clusterService, threadPool, null, null, null, environment, null, null, null, null).size() ); - assertEquals(3, ffp.getRestHandlers(settings, null, null, null, null, null, null).size()); - assertEquals(3, ffp.getActions().size()); + assertEquals(4, ffp.getRestHandlers(settings, null, null, null, null, null, null).size()); + assertEquals(4, ffp.getActions().size()); assertEquals(1, ffp.getExecutorBuilders(settings).size()); assertEquals(3, ffp.getSettings().size()); } diff --git a/src/test/java/org/opensearch/flowframework/TestHelpers.java b/src/test/java/org/opensearch/flowframework/TestHelpers.java index 9c3f8a07e..07221297a 100644 --- a/src/test/java/org/opensearch/flowframework/TestHelpers.java +++ b/src/test/java/org/opensearch/flowframework/TestHelpers.java @@ -13,8 +13,14 @@ import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentHelper; +import org.opensearch.common.xcontent.XContentType; import org.opensearch.commons.authuser.User; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.xcontent.XContentBuilder; +import java.io.IOException; +import java.util.Map; import java.util.Set; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -40,4 +46,12 @@ public static ClusterSettings clusterSetting(Settings settings, Setting... se ClusterSettings clusterSettings = new ClusterSettings(settings, settingsSet); return clusterSettings; } + + public static XContentBuilder builder() throws IOException { + return XContentBuilder.builder(XContentType.JSON.xContent()); + } + + public static Map XContentBuilderToMap(XContentBuilder builder) { + return XContentHelper.convertToMap(BytesReference.bytes(builder), false, builder.contentType()).v2(); + } } diff --git a/src/test/java/org/opensearch/flowframework/model/ResourceCreatedTests.java b/src/test/java/org/opensearch/flowframework/model/ResourceCreatedTests.java new file mode 100644 index 000000000..a8536dd43 --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/model/ResourceCreatedTests.java @@ -0,0 +1,46 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.model; + +import org.opensearch.test.OpenSearchTestCase; + +import java.io.IOException; + +public class ResourceCreatedTests extends OpenSearchTestCase { + + @Override + public void setUp() throws Exception { + super.setUp(); + } + + public void testParseFeature() throws IOException { + ResourceCreated ResourceCreated = new ResourceCreated("A", "B"); + assertEquals(ResourceCreated.workflowStepName(), "A"); + assertEquals(ResourceCreated.resourceId(), "B"); + + String expectedJson = "{\"workflow_step_name\":\"A\",\"resource_id\":\"B\"}"; + String json = TemplateTestJsonUtil.parseToJson(ResourceCreated); + assertEquals(expectedJson, json); + + ResourceCreated ResourceCreatedTwo = ResourceCreated.parse(TemplateTestJsonUtil.jsonToParser(json)); + assertEquals("A", ResourceCreatedTwo.workflowStepName()); + assertEquals("B", ResourceCreatedTwo.resourceId()); + } + + public void testExceptions() throws IOException { + String badJson = "{\"wrong\":\"A\",\"resource_id\":\"B\"}"; + IOException e = assertThrows(IOException.class, () -> ResourceCreated.parse(TemplateTestJsonUtil.jsonToParser(badJson))); + assertEquals("Unable to parse field [wrong] in a resources_created object.", e.getMessage()); + + String missingJson = "{\"resource_id\":\"B\"}"; + e = assertThrows(IOException.class, () -> ResourceCreated.parse(TemplateTestJsonUtil.jsonToParser(missingJson))); + assertEquals("A ResourceCreated object requires both a workflowStepName and resourceId.", e.getMessage()); + } + +} diff --git a/src/test/java/org/opensearch/flowframework/model/WorkflowValidatorTests.java b/src/test/java/org/opensearch/flowframework/model/WorkflowValidatorTests.java index 6c474a11e..28e7e0585 100644 --- a/src/test/java/org/opensearch/flowframework/model/WorkflowValidatorTests.java +++ b/src/test/java/org/opensearch/flowframework/model/WorkflowValidatorTests.java @@ -13,6 +13,7 @@ import org.opensearch.client.ClusterAdminClient; import org.opensearch.cluster.service.ClusterService; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.flowframework.workflow.WorkflowStepFactory; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.test.OpenSearchTestCase; @@ -67,8 +68,9 @@ public void testWorkflowStepFactoryHasValidators() throws IOException { when(client.admin()).thenReturn(adminClient); when(adminClient.cluster()).thenReturn(clusterAdminClient); MachineLearningNodeClient mlClient = mock(MachineLearningNodeClient.class); + FlowFrameworkIndicesHandler flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class); - WorkflowStepFactory workflowStepFactory = new WorkflowStepFactory(clusterService, client, mlClient); + WorkflowStepFactory workflowStepFactory = new WorkflowStepFactory(clusterService, client, mlClient, flowFrameworkIndicesHandler); // Read in workflow-steps.json WorkflowValidator workflowValidator = WorkflowValidator.parse("mappings/workflow-steps.json"); diff --git a/src/test/java/org/opensearch/flowframework/rest/RestGetWorkflowActionTests.java b/src/test/java/org/opensearch/flowframework/rest/RestGetWorkflowActionTests.java new file mode 100644 index 000000000..0f6ddab59 --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/rest/RestGetWorkflowActionTests.java @@ -0,0 +1,104 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.rest; + +import org.opensearch.client.node.NodeClient; +import org.opensearch.core.common.bytes.BytesArray; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.flowframework.common.FlowFrameworkFeatureEnabledSetting; +import org.opensearch.rest.RestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestChannel; +import org.opensearch.test.rest.FakeRestRequest; + +import java.util.List; +import java.util.Locale; + +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class RestGetWorkflowActionTests extends OpenSearchTestCase { + private RestGetWorkflowAction restGetWorkflowAction; + private String getPath; + private NodeClient nodeClient; + private FlowFrameworkFeatureEnabledSetting flowFrameworkFeatureEnabledSetting; + + @Override + public void setUp() throws Exception { + super.setUp(); + + this.getPath = String.format(Locale.ROOT, "%s/{%s}/%s", WORKFLOW_URI, "workflow_id", "_status"); + flowFrameworkFeatureEnabledSetting = mock(FlowFrameworkFeatureEnabledSetting.class); + when(flowFrameworkFeatureEnabledSetting.isFlowFrameworkEnabled()).thenReturn(true); + this.restGetWorkflowAction = new RestGetWorkflowAction(flowFrameworkFeatureEnabledSetting); + this.nodeClient = mock(NodeClient.class); + } + + public void testConstructor() { + RestGetWorkflowAction getWorkflowAction = new RestGetWorkflowAction(flowFrameworkFeatureEnabledSetting); + assertNotNull(getWorkflowAction); + } + + public void testRestGetWorkflowActionName() { + String name = restGetWorkflowAction.getName(); + assertEquals("get_workflow", name); + } + + public void testRestGetWorkflowActionRoutes() { + List routes = restGetWorkflowAction.routes(); + assertEquals(1, routes.size()); + assertEquals(RestRequest.Method.GET, routes.get(0).getMethod()); + assertEquals(this.getPath, routes.get(0).getPath()); + } + + public void testNullWorkflowId() throws Exception { + + // Request with no params + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) + .withPath(this.getPath) + .build(); + + FakeRestChannel channel = new FakeRestChannel(request, true, 1); + restGetWorkflowAction.handleRequest(request, channel, nodeClient); + + assertEquals(1, channel.errors().get()); + assertEquals(RestStatus.BAD_REQUEST, channel.capturedResponse().status()); + assertTrue(channel.capturedResponse().content().utf8ToString().contains("workflow_id cannot be null")); + } + + public void testInvalidRequestWithContent() { + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) + .withPath(this.getPath) + .withContent(new BytesArray("request body"), MediaTypeRegistry.JSON) + .build(); + + FakeRestChannel channel = new FakeRestChannel(request, false, 1); + IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, () -> { + restGetWorkflowAction.handleRequest(request, channel, nodeClient); + }); + assertEquals( + "request [POST /_plugins/_flow_framework/workflow/{workflow_id}/_status] does not support having a body", + ex.getMessage() + ); + } + + public void testFeatureFlagNotEnabled() throws Exception { + when(flowFrameworkFeatureEnabledSetting.isFlowFrameworkEnabled()).thenReturn(false); + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) + .withPath(this.getPath) + .build(); + FakeRestChannel channel = new FakeRestChannel(request, false, 1); + restGetWorkflowAction.handleRequest(request, channel, nodeClient); + assertEquals(RestStatus.FORBIDDEN, channel.capturedResponse().status()); + assertTrue(channel.capturedResponse().content().utf8ToString().contains("This API is disabled.")); + } +} diff --git a/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java index 5492ad822..22d831f2b 100644 --- a/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java @@ -34,6 +34,7 @@ import java.util.List; import java.util.Map; +import java.util.concurrent.TimeUnit; import org.mockito.ArgumentCaptor; @@ -198,6 +199,20 @@ public void testMaxWorkflow() { assertEquals(("Maximum workflows limit reached 1000"), exceptionCaptor.getValue().getMessage()); } + public void testMaxWorkflowWithNoIndex() { + @SuppressWarnings("unchecked") + ActionListener listener = new ActionListener() { + @Override + public void onResponse(Boolean booleanResponse) { + assertTrue(booleanResponse); + } + + @Override + public void onFailure(Exception e) {} + }; + createWorkflowTransportAction.checkMaxWorkflows(new TimeValue(10, TimeUnit.SECONDS), 10, listener); + } + public void testFailedToCreateNewWorkflow() { @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); diff --git a/src/test/java/org/opensearch/flowframework/transport/GetWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/GetWorkflowTransportActionTests.java new file mode 100644 index 000000000..c3991783d --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/transport/GetWorkflowTransportActionTests.java @@ -0,0 +1,127 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.transport; + +import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.Client; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.flowframework.TestHelpers; +import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; +import org.opensearch.flowframework.model.WorkflowState; +import org.opensearch.tasks.Task; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; +import org.junit.Assert; + +import java.io.IOException; +import java.time.Instant; +import java.util.Collections; +import java.util.Map; + +import org.mockito.Mockito; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class GetWorkflowTransportActionTests extends OpenSearchTestCase { + + private GetWorkflowTransportAction getWorkflowTransportAction; + private FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; + private Client client; + private ThreadPool threadPool; + private ThreadContext threadContext; + private ActionListener response; + private Task task; + + @Override + public void setUp() throws Exception { + super.setUp(); + this.client = mock(Client.class); + this.threadPool = mock(ThreadPool.class); + this.getWorkflowTransportAction = new GetWorkflowTransportAction( + mock(TransportService.class), + mock(ActionFilters.class), + client, + xContentRegistry() + ); + task = Mockito.mock(Task.class); + ThreadPool clientThreadPool = mock(ThreadPool.class); + ThreadContext threadContext = new ThreadContext(Settings.EMPTY); + + when(client.threadPool()).thenReturn(clientThreadPool); + when(clientThreadPool.getThreadContext()).thenReturn(threadContext); + + response = new ActionListener() { + @Override + public void onResponse(GetWorkflowResponse getResponse) { + assertTrue(true); + } + + @Override + public void onFailure(Exception e) {} + }; + + } + + public void testGetTransportAction() throws IOException { + GetWorkflowRequest getWorkflowRequest = new GetWorkflowRequest("1234", false); + getWorkflowTransportAction.doExecute(task, getWorkflowRequest, response); + } + + public void testGetAction() { + Assert.assertNotNull(GetWorkflowAction.INSTANCE.name()); + Assert.assertEquals(GetWorkflowAction.INSTANCE.name(), GetWorkflowAction.NAME); + } + + public void testGetAnomalyDetectorRequest() throws IOException { + GetWorkflowRequest request = new GetWorkflowRequest("1234", false); + BytesStreamOutput out = new BytesStreamOutput(); + request.writeTo(out); + StreamInput input = out.bytes().streamInput(); + GetWorkflowRequest newRequest = new GetWorkflowRequest(input); + Assert.assertEquals(request.getWorkflowId(), newRequest.getWorkflowId()); + Assert.assertEquals(request.getAll(), newRequest.getAll()); + Assert.assertNull(newRequest.validate()); + } + + public void testGetAnomalyDetectorResponse() throws IOException { + BytesStreamOutput out = new BytesStreamOutput(); + String workflowId = randomAlphaOfLength(5); + WorkflowState workFlowState = new WorkflowState( + workflowId, + "test", + "PROVISIONING", + "IN_PROGRESS", + Instant.now(), + Instant.now(), + TestHelpers.randomUser(), + Collections.emptyMap(), + Collections.emptyList() + ); + + GetWorkflowResponse response = new GetWorkflowResponse(workFlowState, false); + response.writeTo(out); + NamedWriteableAwareStreamInput input = new NamedWriteableAwareStreamInput(out.bytes().streamInput(), writableRegistry()); + GetWorkflowResponse newResponse = new GetWorkflowResponse(input); + XContentBuilder builder = TestHelpers.builder(); + Assert.assertNotNull(newResponse.toXContent(builder, ToXContent.EMPTY_PARAMS)); + + Map map = TestHelpers.XContentBuilderToMap(builder); + Assert.assertEquals(map.get("state"), workFlowState.getState()); + Assert.assertEquals(map.get("workflow_id"), workFlowState.getWorkflowId()); + } +} diff --git a/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java index 63855f7bd..a05a3927e 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java @@ -12,6 +12,7 @@ import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.common.CommonValue; import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.connector.ConnectorAction; import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput; @@ -30,6 +31,7 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; public class CreateConnectorStepTests extends OpenSearchTestCase { @@ -38,10 +40,12 @@ public class CreateConnectorStepTests extends OpenSearchTestCase { @Mock MachineLearningNodeClient machineLearningNodeClient; + private FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; + @Override public void setUp() throws Exception { super.setUp(); - + this.flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class); Map params = Map.ofEntries(Map.entry("endpoint", "endpoint"), Map.entry("temp", "7")); Map credentials = Map.ofEntries(Map.entry("key1", "value1"), Map.entry("key2", "value2")); Map[] actions = new Map[] { @@ -66,14 +70,15 @@ public void setUp() throws Exception { Map.entry(CommonValue.PARAMETERS_FIELD, params), Map.entry(CommonValue.CREDENTIALS_FIELD, credentials), Map.entry(CommonValue.ACTIONS_FIELD, actions) - ) + ), + "test-id" ); } public void testCreateConnector() throws IOException, ExecutionException, InterruptedException { String connectorId = "connect"; - CreateConnectorStep createConnectorStep = new CreateConnectorStep(machineLearningNodeClient); + CreateConnectorStep createConnectorStep = new CreateConnectorStep(machineLearningNodeClient, flowFrameworkIndicesHandler); @SuppressWarnings("unchecked") ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); @@ -95,7 +100,7 @@ public void testCreateConnector() throws IOException, ExecutionException, Interr } public void testCreateConnectorFailure() throws IOException { - CreateConnectorStep createConnectorStep = new CreateConnectorStep(machineLearningNodeClient); + CreateConnectorStep createConnectorStep = new CreateConnectorStep(machineLearningNodeClient, flowFrameworkIndicesHandler); @SuppressWarnings("unchecked") ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); diff --git a/src/test/java/org/opensearch/flowframework/workflow/CreateIndexStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/CreateIndexStepTests.java index 7a4db70a6..67cb6cb9b 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/CreateIndexStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/CreateIndexStepTests.java @@ -69,7 +69,7 @@ public class CreateIndexStepTests extends OpenSearchTestCase { public void setUp() throws Exception { super.setUp(); MockitoAnnotations.openMocks(this); - inputData = new WorkflowData(Map.ofEntries(Map.entry("index_name", "demo"), Map.entry("type", "knn"))); + inputData = new WorkflowData(Map.ofEntries(Map.entry("index_name", "demo"), Map.entry("type", "knn")), "test-id"); clusterService = mock(ClusterService.class); client = mock(Client.class); adminClient = mock(AdminClient.class); diff --git a/src/test/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStepTests.java index 039b0384f..194c80eb0 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStepTests.java @@ -50,11 +50,12 @@ public void setUp() throws Exception { Map.entry("model_id", "model_id"), Map.entry("input_field_name", "inputField"), Map.entry("output_field_name", "outputField") - ) + ), + "test-id" ); // Set output data to returned pipelineId - outpuData = new WorkflowData(Map.ofEntries(Map.entry("pipeline_id", "pipelineId"))); + outpuData = new WorkflowData(Map.ofEntries(Map.entry("pipeline_id", "pipelineId")), "test-id"); client = mock(Client.class); adminClient = mock(AdminClient.class); @@ -113,7 +114,8 @@ public void testMissingData() throws InterruptedException { Map.entry("description", "some description"), Map.entry("type", "text_embedding"), Map.entry("model_id", "model_id") - ) + ), + "test-id" ); CompletableFuture future = createIngestPipelineStep.execute(List.of(incorrectData)); diff --git a/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java index 4cdfaebae..fd856b945 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java @@ -44,7 +44,7 @@ public class DeployModelStepTests extends OpenSearchTestCase { public void setUp() throws Exception { super.setUp(); - inputData = new WorkflowData(Map.ofEntries(Map.entry("model_id", "modelId"))); + inputData = new WorkflowData(Map.ofEntries(Map.entry("model_id", "modelId")), "test-id"); MockitoAnnotations.openMocks(this); diff --git a/src/test/java/org/opensearch/flowframework/workflow/GetMLTaskStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/GetMLTaskStepTests.java index 3a83b1fdd..f5f5f7e7d 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/GetMLTaskStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/GetMLTaskStepTests.java @@ -48,7 +48,7 @@ public void setUp() throws Exception { MockitoAnnotations.openMocks(this); this.getMLTaskStep = new GetMLTaskStep(mlNodeClient); - this.workflowData = new WorkflowData(Map.ofEntries(Map.entry(TASK_ID, "test"))); + this.workflowData = new WorkflowData(Map.ofEntries(Map.entry(TASK_ID, "test")), "test-id"); } public void testGetMLTaskSuccess() throws Exception { diff --git a/src/test/java/org/opensearch/flowframework/workflow/ModelGroupStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/ModelGroupStepTests.java index 8868b628e..f763c8005 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/ModelGroupStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/ModelGroupStepTests.java @@ -53,7 +53,8 @@ public void setUp() throws Exception { Map.entry("backend_roles", ImmutableList.of("role-1")), Map.entry("access_mode", AccessMode.PUBLIC), Map.entry("add_all_backend_roles", false) - ) + ), + "test-id" ); } diff --git a/src/test/java/org/opensearch/flowframework/workflow/ProcessNodeTests.java b/src/test/java/org/opensearch/flowframework/workflow/ProcessNodeTests.java index 0cac95b49..6aae139e4 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/ProcessNodeTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/ProcessNodeTests.java @@ -58,7 +58,7 @@ public void testNode() throws InterruptedException, ExecutionException { @Override public CompletableFuture execute(List data) { CompletableFuture f = new CompletableFuture<>(); - f.complete(new WorkflowData(Map.of("test", "output"))); + f.complete(new WorkflowData(Map.of("test", "output"), "test-id")); return f; } @@ -68,7 +68,7 @@ public String getName() { } }, Map.of(), - new WorkflowData(Map.of("test", "input"), Map.of("foo", "bar")), + new WorkflowData(Map.of("test", "input"), Map.of("foo", "bar"), "test-id"), List.of(successfulNode), testThreadPool, TimeValue.timeValueMillis(50) @@ -77,6 +77,7 @@ public String getName() { assertEquals("test", nodeA.workflowStep().getName()); assertEquals("input", nodeA.input().getContent().get("test")); assertEquals("bar", nodeA.input().getParams().get("foo")); + assertEquals("test-id", nodeA.input().getWorkflowId()); assertEquals(1, nodeA.predecessors().size()); assertEquals(50, nodeA.nodeTimeout().millis()); assertEquals("A", nodeA.toString()); diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalModelStepTests.java index d41096624..bd40c50ad 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalModelStepTests.java @@ -69,7 +69,8 @@ public void setUp() throws Exception { Map.entry("embedding_dimension", "384"), Map.entry("framework_type", "sentence_transformers"), Map.entry("url", "something.com") - ) + ), + "test-id" ); } diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java index ca9d5e7a5..e60707f67 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java @@ -54,7 +54,8 @@ public void setUp() throws Exception { Map.entry("name", "xyz"), Map.entry("description", "description"), Map.entry("connector_id", "abcdefg") - ) + ), + "test-id" ); } diff --git a/src/test/java/org/opensearch/flowframework/workflow/WorkflowDataTests.java b/src/test/java/org/opensearch/flowframework/workflow/WorkflowDataTests.java index e2464dace..8a4a1fda9 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/WorkflowDataTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/WorkflowDataTests.java @@ -26,13 +26,14 @@ public void testWorkflowData() { assertTrue(empty.getContent().isEmpty()); Map expectedContent = Map.of("baz", new String[] { "qux", "quxx" }); - WorkflowData contentOnly = new WorkflowData(expectedContent); + WorkflowData contentOnly = new WorkflowData(expectedContent, "test-id-123"); assertTrue(contentOnly.getParams().isEmpty()); assertEquals(expectedContent, contentOnly.getContent()); Map expectedParams = Map.of("foo", "bar"); - WorkflowData contentAndParams = new WorkflowData(expectedContent, expectedParams); + WorkflowData contentAndParams = new WorkflowData(expectedContent, expectedParams, "test-id-123"); assertEquals(expectedParams, contentAndParams.getParams()); assertEquals(expectedContent, contentAndParams.getContent()); + assertEquals("test-id-123", contentAndParams.getWorkflowId()); } } diff --git a/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java b/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java index 3b1d55a69..99a6f6105 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java @@ -14,6 +14,7 @@ import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.flowframework.model.TemplateTestJsonUtil; import org.opensearch.flowframework.model.Workflow; import org.opensearch.flowframework.model.WorkflowEdge; @@ -50,7 +51,7 @@ public class WorkflowProcessSorterTests extends OpenSearchTestCase { private static List parseToNodes(String json) throws IOException { XContentParser parser = TemplateTestJsonUtil.jsonToParser(json); Workflow w = Workflow.parse(parser); - return workflowProcessSorter.sortProcessNodes(w); + return workflowProcessSorter.sortProcessNodes(w, "123"); } // Wrap parser into string list @@ -67,11 +68,12 @@ public static void setup() { ClusterService clusterService = mock(ClusterService.class); Client client = mock(Client.class); MachineLearningNodeClient mlClient = mock(MachineLearningNodeClient.class); + FlowFrameworkIndicesHandler flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class); when(client.admin()).thenReturn(adminClient); testThreadPool = new TestThreadPool(WorkflowProcessSorterTests.class.getName()); - WorkflowStepFactory factory = new WorkflowStepFactory(clusterService, client, mlClient); + WorkflowStepFactory factory = new WorkflowStepFactory(clusterService, client, mlClient, flowFrameworkIndicesHandler); workflowProcessSorter = new WorkflowProcessSorter(factory, testThreadPool); } @@ -252,7 +254,7 @@ public void testSuccessfulGraphValidation() throws Exception { Workflow workflow = new Workflow(Map.of(), List.of(createConnector, registerModel, deployModel), List.of(edge1, edge2)); - List sortedProcessNodes = workflowProcessSorter.sortProcessNodes(workflow); + List sortedProcessNodes = workflowProcessSorter.sortProcessNodes(workflow, "123"); workflowProcessSorter.validateGraph(sortedProcessNodes); } @@ -274,7 +276,7 @@ public void testFailedGraphValidation() { WorkflowEdge edge = new WorkflowEdge(registerModel.id(), deployModel.id()); Workflow workflow = new Workflow(Map.of(), List.of(registerModel, deployModel), List.of(edge)); - List sortedProcessNodes = workflowProcessSorter.sortProcessNodes(workflow); + List sortedProcessNodes = workflowProcessSorter.sortProcessNodes(workflow, "123"); FlowFrameworkException ex = expectThrows( FlowFrameworkException.class, () -> workflowProcessSorter.validateGraph(sortedProcessNodes)