From 9bf239eada7015791eb390e57a42308ce4678db5 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Wed, 10 Jul 2024 01:15:49 +0000 Subject: [PATCH] Support editing of certain workflow fields on a provisioned workflow (#757) * Support editing of certain workflow fields on a provisioned workflow Signed-off-by: Daniel Widdis * Add integ test Signed-off-by: Daniel Widdis * Address review comments Signed-off-by: Daniel Widdis * Refactor field update method to Template class Signed-off-by: Daniel Widdis * Update tests to ensure update timestamp increments Signed-off-by: Daniel Widdis --------- Signed-off-by: Daniel Widdis (cherry picked from commit 7d45f92d088de9f8338b66e29ca11cd89c691b5e) Signed-off-by: github-actions[bot] --- CHANGELOG.md | 2 + .../flowframework/common/CommonValue.java | 4 +- .../flowframework/model/Template.java | 86 ++++++++++++++++--- .../rest/RestCreateWorkflowAction.java | 32 +++++-- .../CreateWorkflowTransportAction.java | 66 ++++++++------ .../transport/WorkflowRequest.java | 41 +++++++-- .../FlowFrameworkRestTestCase.java | 19 ++++ .../flowframework/model/TemplateTests.java | 58 +++++++++++++ .../rest/FlowFrameworkRestApiIT.java | 38 ++++++++ .../rest/RestCreateWorkflowActionTests.java | 78 ++++++++++++++++- .../CreateWorkflowTransportActionTests.java | 81 ++++++++++++++++- .../WorkflowRequestResponseTests.java | 86 ++++++++++++++----- 12 files changed, 515 insertions(+), 76 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ffe9b399b..e4a79d856 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,8 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.1.0/) ## [Unreleased 2.x](https://github.com/opensearch-project/flow-framework/compare/2.14...2.x) ### Features +- Support editing of certain workflow fields on a provisioned workflow ([#757](https://github.com/opensearch-project/flow-framework/pull/757)) + ### Enhancements - Register system index descriptors through SystemIndexPlugin.getSystemIndexDescriptors ([#750](https://github.com/opensearch-project/flow-framework/pull/750)) diff --git a/src/main/java/org/opensearch/flowframework/common/CommonValue.java b/src/main/java/org/opensearch/flowframework/common/CommonValue.java index 10a23357a..2f9b0764e 100644 --- a/src/main/java/org/opensearch/flowframework/common/CommonValue.java +++ b/src/main/java/org/opensearch/flowframework/common/CommonValue.java @@ -68,8 +68,10 @@ private CommonValue() {} public static final String WORKFLOW_ID = "workflow_id"; /** Field name for template validation, the flag to indicate if validation is necessary */ public static final String VALIDATION = "validation"; - /** The field name for provision workflow within a use case template*/ + /** The param name for provision workflow in create API */ public static final String PROVISION_WORKFLOW = "provision"; + /** The param name for update workflow field in create API */ + public static final String UPDATE_WORKFLOW_FIELDS = "update_fields"; /** The field name for workflow steps. This field represents the name of the workflow steps to be fetched. */ public static final String WORKFLOW_STEP = "workflow_step"; /** The param name for default use case, used by the create workflow API */ diff --git a/src/main/java/org/opensearch/flowframework/model/Template.java b/src/main/java/org/opensearch/flowframework/model/Template.java index 71632d003..a99b87c4b 100644 --- a/src/main/java/org/opensearch/flowframework/model/Template.java +++ b/src/main/java/org/opensearch/flowframework/model/Template.java @@ -8,6 +8,7 @@ */ package org.opensearch.flowframework.model; +import org.apache.logging.log4j.util.Strings; import org.opensearch.Version; import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.common.xcontent.json.JsonXContent; @@ -19,6 +20,7 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.model.Template.Builder; import org.opensearch.flowframework.util.ParseUtils; import java.io.IOException; @@ -29,6 +31,7 @@ import java.util.List; import java.util.Map; import java.util.Map.Entry; +import java.util.Set; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.flowframework.common.CommonValue.CREATED_TIME; @@ -53,6 +56,14 @@ public class Template implements ToXContentObject { public static final String TEMPLATE_FIELD = "template"; /** The template field name for template use case */ public static final String USE_CASE_FIELD = "use_case"; + /** Fields which may be updated in the template even if provisioned */ + public static final Set UPDATE_FIELD_ALLOWLIST = Set.of( + NAME_FIELD, + DESCRIPTION_FIELD, + USE_CASE_FIELD, + VERSION_FIELD, + UI_METADATA_FIELD + ); private final String name; private final String description; @@ -77,9 +88,9 @@ public class Template implements ToXContentObject { * @param workflows Workflow graph definitions corresponding to the defined operations. * @param uiMetadata The UI metadata related to the given workflow * @param user The user extracted from the thread context from the request - * @param createdTime Created time in milliseconds since the epoch - * @param lastUpdatedTime Last Updated time in milliseconds since the epoch - * @param lastProvisionedTime Last Provisioned time in milliseconds since the epoch + * @param createdTime Created time as an Instant + * @param lastUpdatedTime Last Updated time as an Instant + * @param lastProvisionedTime Last Provisioned time as an Instant */ public Template( String name, @@ -286,9 +297,9 @@ public Template build() { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { XContentBuilder xContentBuilder = builder.startObject(); - xContentBuilder.field(NAME_FIELD, this.name); - xContentBuilder.field(DESCRIPTION_FIELD, this.description); - xContentBuilder.field(USE_CASE_FIELD, this.useCase); + xContentBuilder.field(NAME_FIELD, this.name.trim()); + xContentBuilder.field(DESCRIPTION_FIELD, this.description == null ? "" : this.description.trim()); + xContentBuilder.field(USE_CASE_FIELD, this.useCase == null ? "" : this.useCase.trim()); if (this.templateVersion != null || !this.compatibilityVersion.isEmpty()) { xContentBuilder.startObject(VERSION_FIELD); @@ -334,6 +345,35 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws return xContentBuilder.endObject(); } + /** + * Merges two templates by updating the fields from an existing template with the (non-null) fields of another one. + * @param existingTemplate An existing complete template. + * @param templateWithNewFields A template containing only fields to update. The fields must correspond to the field names in {@link #UPDATE_FIELD_ALLOWLIST}. + * @return the updated template. + */ + public static Template updateExistingTemplate(Template existingTemplate, Template templateWithNewFields) { + Builder builder = new Template.Builder(existingTemplate).lastUpdatedTime(Instant.now()); + if (templateWithNewFields.name() != null) { + builder.name(templateWithNewFields.name()); + } + if (!Strings.isBlank(templateWithNewFields.description())) { + builder.description(templateWithNewFields.description()); + } + if (!Strings.isBlank(templateWithNewFields.useCase())) { + builder.useCase(templateWithNewFields.useCase()); + } + if (templateWithNewFields.templateVersion() != null) { + builder.templateVersion(templateWithNewFields.templateVersion()); + } + if (!templateWithNewFields.compatibilityVersion().isEmpty()) { + builder.compatibilityVersion(templateWithNewFields.compatibilityVersion()); + } + if (templateWithNewFields.getUiMetadata() != null) { + builder.uiMetadata(templateWithNewFields.getUiMetadata()); + } + return builder.build(); + } + /** * Parse raw xContent into a Template instance. * @@ -342,9 +382,21 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws * @throws IOException if content can't be parsed correctly */ public static Template parse(XContentParser parser) throws IOException { + return parse(parser, false); + } + + /** + * Parse raw xContent into a Template instance. + * + * @param parser xContent based content parser + * @param fieldUpdate if set true, will be used for updating an existing template + * @return an instance of the template + * @throws IOException if content can't be parsed correctly + */ + public static Template parse(XContentParser parser, boolean fieldUpdate) throws IOException { String name = null; - String description = ""; - String useCase = ""; + String description = null; + String useCase = null; Version templateVersion = null; List compatibilityVersion = new ArrayList<>(); Map workflows = new HashMap<>(); @@ -357,6 +409,12 @@ public static Template parse(XContentParser parser) throws IOException { ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { String fieldName = parser.currentName(); + if (fieldUpdate && !UPDATE_FIELD_ALLOWLIST.contains(fieldName)) { + throw new FlowFrameworkException( + "You can not update the field [" + fieldName + "] without updating the whole template.", + RestStatus.BAD_REQUEST + ); + } parser.nextToken(); switch (fieldName) { case NAME_FIELD: @@ -421,8 +479,16 @@ public static Template parse(XContentParser parser) throws IOException { ); } } - if (name == null) { - throw new FlowFrameworkException("A template object requires a name.", RestStatus.BAD_REQUEST); + if (!fieldUpdate) { + if (name == null) { + throw new FlowFrameworkException("A template object requires a name.", RestStatus.BAD_REQUEST); + } + if (description == null) { + description = ""; + } + if (useCase == null) { + useCase = ""; + } } return new Builder().name(name) diff --git a/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java b/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java index 5db17d2b7..bb604e8d6 100644 --- a/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java +++ b/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java @@ -39,6 +39,7 @@ import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW; +import static org.opensearch.flowframework.common.CommonValue.UPDATE_WORKFLOW_FIELDS; import static org.opensearch.flowframework.common.CommonValue.USE_CASE; import static org.opensearch.flowframework.common.CommonValue.VALIDATION; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_ID; @@ -83,6 +84,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli String workflowId = request.param(WORKFLOW_ID); String[] validation = request.paramAsStringArray(VALIDATION, new String[] { "all" }); boolean provision = request.paramAsBoolean(PROVISION_WORKFLOW, false); + boolean updateFields = request.paramAsBoolean(UPDATE_WORKFLOW_FIELDS, false); String useCase = request.param(USE_CASE); // If provisioning, consume all other params and pass to provision transport action Map params = provision @@ -117,11 +119,23 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli new BytesRestResponse(ffe.getRestStatus(), ffe.toXContent(channel.newErrorBuilder(), ToXContent.EMPTY_PARAMS)) ); } + if (provision && updateFields) { + // Consume params and content so custom exception is processed + params.keySet().stream().forEach(request::param); + request.content(); + FlowFrameworkException ffe = new FlowFrameworkException( + "You can not use both the " + PROVISION_WORKFLOW + " and " + UPDATE_WORKFLOW_FIELDS + " parameters in the same request.", + RestStatus.BAD_REQUEST + ); + return channel -> channel.sendResponse( + new BytesRestResponse(ffe.getRestStatus(), ffe.toXContent(channel.newErrorBuilder(), ToXContent.EMPTY_PARAMS)) + ); + } try { - Template template; Map useCaseDefaultsMap = Collections.emptyMap(); if (useCase != null) { + // Reconstruct the template from a substitution-ready use case String useCaseTemplateFileInStringFormat = ParseUtils.resourceToString( "/" + DefaultUseCases.getSubstitutionReadyFileByUseCaseName(useCase) ); @@ -178,21 +192,25 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli null, useCaseDefaultsMap ); - XContentParser parserTestJson = ParseUtils.jsonToParser(useCaseTemplateFileInStringFormat); - ensureExpectedToken(XContentParser.Token.START_OBJECT, parserTestJson.currentToken(), parserTestJson); - template = Template.parse(parserTestJson); - + XContentParser useCaseParser = ParseUtils.jsonToParser(useCaseTemplateFileInStringFormat); + ensureExpectedToken(XContentParser.Token.START_OBJECT, useCaseParser.currentToken(), useCaseParser); + template = Template.parse(useCaseParser); } else { XContentParser parser = request.contentParser(); ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - template = Template.parse(parser); + template = Template.parse(parser, updateFields); + } + + // If not provisioning, params map is empty. Use it to pass updateFields flag to WorkflowRequest + if (updateFields) { + params = Map.of(UPDATE_WORKFLOW_FIELDS, "true"); } WorkflowRequest workflowRequest = new WorkflowRequest( workflowId, template, validation, - provision, + provision || updateFields, params, useCase, useCaseDefaultsMap diff --git a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java index 720fa1ea6..0732fc106 100644 --- a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java @@ -233,6 +233,7 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener { - flowFrameworkIndicesHandler.updateFlowFrameworkSystemIndexDoc( - request.getWorkflowId(), - Map.ofEntries( - Map.entry(STATE_FIELD, State.NOT_STARTED), - Map.entry(PROVISIONING_PROGRESS_FIELD, ProvisioningProgress.NOT_STARTED) - ), - ActionListener.wrap(updateResponse -> { - logger.info("updated workflow {} state to {}", request.getWorkflowId(), State.NOT_STARTED.name()); - listener.onResponse(new WorkflowResponse(request.getWorkflowId())); - }, exception -> { - String errorMessage = "Failed to update workflow " + request.getWorkflowId() + " in template index"; - logger.error(errorMessage, exception); - if (exception instanceof FlowFrameworkException) { - listener.onFailure(exception); - } else { - listener.onFailure( - new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception)) + // Ignore state index if updating fields + if (!isFieldUpdate) { + flowFrameworkIndicesHandler.updateFlowFrameworkSystemIndexDoc( + request.getWorkflowId(), + Map.ofEntries( + Map.entry(STATE_FIELD, State.NOT_STARTED), + Map.entry(PROVISIONING_PROGRESS_FIELD, ProvisioningProgress.NOT_STARTED) + ), + ActionListener.wrap(updateResponse -> { + logger.info( + "updated workflow {} state to {}", + request.getWorkflowId(), + State.NOT_STARTED.name() ); - } - }) - ); + listener.onResponse(new WorkflowResponse(request.getWorkflowId())); + }, exception -> { + String errorMessage = "Failed to update workflow " + + request.getWorkflowId() + + " in template index"; + logger.error(errorMessage, exception); + if (exception instanceof FlowFrameworkException) { + listener.onFailure(exception); + } else { + listener.onFailure( + new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception)) + ); + } + }) + ); + } else { + listener.onResponse(new WorkflowResponse(request.getWorkflowId())); + } }, exception -> { String errorMessage = "Failed to update use case template " + request.getWorkflowId(); logger.error(errorMessage, exception); @@ -278,7 +291,8 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener params, String useCase, Map defaultParams @@ -111,11 +118,12 @@ public WorkflowRequest( this.workflowId = workflowId; this.template = template; this.validation = validation; - this.provision = provision; - if (!provision && !params.isEmpty()) { + this.provision = provisionOrUpdate && !params.containsKey(UPDATE_WORKFLOW_FIELDS); + this.updateFields = !provision && Boolean.parseBoolean(params.get(UPDATE_WORKFLOW_FIELDS)); + if (!this.provision && params.keySet().stream().anyMatch(k -> !UPDATE_WORKFLOW_FIELDS.equals(k))) { throw new IllegalArgumentException("Params may only be included when provisioning."); } - this.params = params; + this.params = this.updateFields ? Collections.emptyMap() : params; this.useCase = useCase; this.defaultParams = defaultParams; } @@ -131,8 +139,13 @@ public WorkflowRequest(StreamInput in) throws IOException { String templateJson = in.readOptionalString(); this.template = templateJson == null ? null : Template.parse(templateJson); this.validation = in.readStringArray(); - this.provision = in.readBoolean(); - this.params = this.provision ? in.readMap(StreamInput::readString, StreamInput::readString) : Collections.emptyMap(); + boolean provisionOrUpdate = in.readBoolean(); + this.params = provisionOrUpdate ? in.readMap(StreamInput::readString, StreamInput::readString) : Collections.emptyMap(); + this.provision = provisionOrUpdate && !params.containsKey(UPDATE_WORKFLOW_FIELDS); + this.updateFields = !provision && Boolean.parseBoolean(params.get(UPDATE_WORKFLOW_FIELDS)); + if (this.updateFields) { + this.params = Collections.emptyMap(); + } } /** @@ -169,6 +182,14 @@ public boolean isProvision() { return this.provision; } + /** + * Gets the update fields flag + * @return the update fields boolean + */ + public boolean isUpdateFields() { + return this.updateFields; + } + /** * Gets the params map * @return the params map @@ -199,9 +220,11 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalString(workflowId); out.writeOptionalString(template == null ? null : template.toJson()); out.writeStringArray(validation); - out.writeBoolean(provision); + out.writeBoolean(provision || updateFields); if (provision) { out.writeMap(params, StreamOutput::writeString, StreamOutput::writeString); + } else if (updateFields) { + out.writeMap(Map.of(UPDATE_WORKFLOW_FIELDS, "true"), StreamOutput::writeString, StreamOutput::writeString); } } diff --git a/src/test/java/org/opensearch/flowframework/FlowFrameworkRestTestCase.java b/src/test/java/org/opensearch/flowframework/FlowFrameworkRestTestCase.java index e190a42ca..87047de2c 100644 --- a/src/test/java/org/opensearch/flowframework/FlowFrameworkRestTestCase.java +++ b/src/test/java/org/opensearch/flowframework/FlowFrameworkRestTestCase.java @@ -425,6 +425,25 @@ protected Response updateWorkflow(RestClient client, String workflowId, Template ); } + /** + * Helper method to invoke the Update Workflow API + * @param client the rest client + * @param workflowId the document id + * @param templateFields the JSON containing some template fields + * @throws Exception if the request fails + * @return a rest response + */ + protected Response updateWorkflowWithFields(RestClient client, String workflowId, String templateFields) throws Exception { + return TestHelpers.makeRequest( + client, + "PUT", + String.format(Locale.ROOT, "%s/%s?update_fields=true", WORKFLOW_URI, workflowId), + Collections.emptyMap(), + templateFields, + null + ); + } + /** * Helper method to invoke the Provision Workflow Rest Action * @param client the rest client diff --git a/src/test/java/org/opensearch/flowframework/model/TemplateTests.java b/src/test/java/org/opensearch/flowframework/model/TemplateTests.java index 63dbf31f6..0ce66c2de 100644 --- a/src/test/java/org/opensearch/flowframework/model/TemplateTests.java +++ b/src/test/java/org/opensearch/flowframework/model/TemplateTests.java @@ -9,7 +9,11 @@ package org.opensearch.flowframework.model; import org.opensearch.Version; +import org.opensearch.common.xcontent.json.JsonXContent; import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.DeprecationHandler; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.test.OpenSearchTestCase; @@ -19,6 +23,8 @@ import java.util.List; import java.util.Map; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + public class TemplateTests extends OpenSearchTestCase { private String expectedTemplate = @@ -84,6 +90,58 @@ public void testTemplate() throws IOException { assertEquals(now, template.lastUpdatedTime()); assertNull(template.lastProvisionedTime()); assertEquals("Workflow [userParams={key=value}, nodes=[A, B], edges=[A->B]]", wfX.toString()); + + // Test invalid field if updating + XContentParser parser = JsonXContent.jsonXContent.createParser( + NamedXContentRegistry.EMPTY, + DeprecationHandler.IGNORE_DEPRECATIONS, + json + ); + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + assertThrows(FlowFrameworkException.class, () -> Template.parse(parser, true)); + } + + public void testUpdateExistingTemplate() { + // Time travel to guarantee update increments + Instant now = Instant.now().minusMillis(100); + + Template original = new Template( + "name one", + "description one", + "use case one", + Version.fromString("1.1.1"), + List.of(Version.fromString("1.1.1"), Version.fromString("1.1.1")), + Collections.emptyMap(), + Map.of("uiMetadata", "one"), + null, + now, + now, + null + ); + Template updated = new Template.Builder().name("name two").description("description two").useCase("use case two").build(); + Template merged = Template.updateExistingTemplate(original, updated); + assertEquals("name two", merged.name()); + assertEquals("description two", merged.description()); + assertEquals("use case two", merged.useCase()); + assertEquals("1.1.1", merged.templateVersion().toString()); + assertEquals("1.1.1", merged.compatibilityVersion().get(0).toString()); + assertEquals("1.1.1", merged.compatibilityVersion().get(1).toString()); + assertEquals("one", merged.getUiMetadata().get("uiMetadata")); + + updated = new Template.Builder().templateVersion(Version.fromString("2.2.2")) + .compatibilityVersion(List.of(Version.fromString("2.2.2"), Version.fromString("2.2.2"))) + .uiMetadata(Map.of("uiMetadata", "two")) + .build(); + merged = Template.updateExistingTemplate(original, updated); + assertEquals("name one", merged.name()); + assertEquals("description one", merged.description()); + assertEquals("use case one", merged.useCase()); + assertEquals("2.2.2", merged.templateVersion().toString()); + assertEquals("2.2.2", merged.compatibilityVersion().get(0).toString()); + assertEquals("2.2.2", merged.compatibilityVersion().get(1).toString()); + assertEquals("two", merged.getUiMetadata().get("uiMetadata")); + + assertTrue(merged.lastUpdatedTime().isAfter(now)); } public void testExceptions() throws IOException { diff --git a/src/test/java/org/opensearch/flowframework/rest/FlowFrameworkRestApiIT.java b/src/test/java/org/opensearch/flowframework/rest/FlowFrameworkRestApiIT.java index 6224ba5b4..f3759b563 100644 --- a/src/test/java/org/opensearch/flowframework/rest/FlowFrameworkRestApiIT.java +++ b/src/test/java/org/opensearch/flowframework/rest/FlowFrameworkRestApiIT.java @@ -116,7 +116,45 @@ public void testFailedUpdateWorkflow() throws Exception { assertTrue( exceptionProvisioned.getMessage().contains("The template can not be updated unless its provisioning state is NOT_STARTED") ); + } + + public void testUpdateWorkflowUsingFields() throws Exception { + Template template = TestHelpers.createTemplateFromFile("createconnector-registerremotemodel-deploymodel.json"); + Response response = createWorkflow(client(), template); + assertEquals(RestStatus.CREATED, TestHelpers.restStatus(response)); + + Map responseMap = entityAsMap(response); + String workflowId = (String) responseMap.get(WORKFLOW_ID); + + // Ensure Ml config index is initialized as creating a connector requires this, then hit Provision API and assert status + Response provisionResponse; + if (!indexExistsWithAdminClient(".plugins-ml-config")) { + assertBusy(() -> assertTrue(indexExistsWithAdminClient(".plugins-ml-config")), 40, TimeUnit.SECONDS); + provisionResponse = provisionWorkflow(client(), workflowId); + } else { + provisionResponse = provisionWorkflow(client(), workflowId); + } + assertEquals(RestStatus.OK, TestHelpers.restStatus(provisionResponse)); + getAndAssertWorkflowStatus(client(), workflowId, State.PROVISIONING, ProvisioningProgress.IN_PROGRESS); + // Attempt to update with update_fields with illegal field + // Fails because contains workflow field + ResponseException exceptionProvisioned = expectThrows( + ResponseException.class, + () -> updateWorkflowWithFields(client(), workflowId, "{\"workflows\":{}}") + ); + assertTrue( + exceptionProvisioned.getMessage().contains("You can not update the field [workflows] without updating the whole template.") + ); + // Change just the name and description + response = updateWorkflowWithFields(client(), workflowId, "{\"name\":\"foo\",\"description\":\"bar\"}"); + assertEquals(RestStatus.CREATED, TestHelpers.restStatus(response)); + // Get the updated template + response = getWorkflow(client(), workflowId); + assertEquals(RestStatus.OK.getStatus(), response.getStatusLine().getStatusCode()); + Template t = Template.parse(EntityUtils.toString(response.getEntity(), StandardCharsets.UTF_8)); + assertEquals("foo", t.name()); + assertEquals("bar", t.description()); } public void testCreateAndProvisionLocalModelWorkflow() throws Exception { diff --git a/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java b/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java index b55d6b1f2..7e537566c 100644 --- a/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java +++ b/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java @@ -36,6 +36,7 @@ import static org.opensearch.flowframework.common.CommonValue.CREATE_CONNECTOR_CREDENTIAL_KEY; import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW; +import static org.opensearch.flowframework.common.CommonValue.UPDATE_WORKFLOW_FIELDS; import static org.opensearch.flowframework.common.CommonValue.USE_CASE; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI; import static org.mockito.ArgumentMatchers.any; @@ -47,6 +48,7 @@ public class RestCreateWorkflowActionTests extends OpenSearchTestCase { private String validTemplate; private String invalidTemplate; + private String validUpdateTemplate; private RestCreateWorkflowAction createWorkflowRestAction; private String createWorkflowPath; private String updateWorkflowPath; @@ -82,9 +84,11 @@ public void setUp() throws Exception { null ); - // Invalid template configuration, wrong field name this.validTemplate = template.toJson(); + // Invalid template configuration, wrong field name this.invalidTemplate = this.validTemplate.replace("use_case", "invalid"); + // Partial update of some fields + this.validUpdateTemplate = "{\"description\":\"new description\",\"ui_metadata\":{\"foo\":\"bar\"}}"; this.createWorkflowRestAction = new RestCreateWorkflowAction(flowFrameworkFeatureEnabledSetting); this.createWorkflowPath = String.format(Locale.ROOT, "%s", WORKFLOW_URI); this.updateWorkflowPath = String.format(Locale.ROOT, "%s/{%s}", WORKFLOW_URI, "workflow_id"); @@ -137,6 +141,78 @@ public void testCreateWorkflowRequestWithParamsButNoProvision() throws Exception ); } + public void testCreateWorkflowRequestWithUpdateAndProvision() throws Exception { + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) + .withPath(this.createWorkflowPath) + .withParams(Map.ofEntries(Map.entry(PROVISION_WORKFLOW, "true"), Map.entry(UPDATE_WORKFLOW_FIELDS, "true"))) + .withContent(new BytesArray(validTemplate), MediaTypeRegistry.JSON) + .build(); + FakeRestChannel channel = new FakeRestChannel(request, false, 1); + createWorkflowRestAction.handleRequest(request, channel, nodeClient); + assertEquals(RestStatus.BAD_REQUEST, channel.capturedResponse().status()); + assertTrue( + channel.capturedResponse() + .content() + .utf8ToString() + .contains( + "You can not use both the " + PROVISION_WORKFLOW + " and " + UPDATE_WORKFLOW_FIELDS + " parameters in the same request." + ) + ); + } + + public void testCreateWorkflowRequestWithUpdateAndParams() throws Exception { + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) + .withPath(this.createWorkflowPath) + .withParams(Map.ofEntries(Map.entry(UPDATE_WORKFLOW_FIELDS, "true"), Map.entry("foo", "bar"))) + .withContent(new BytesArray(validTemplate), MediaTypeRegistry.JSON) + .build(); + FakeRestChannel channel = new FakeRestChannel(request, false, 1); + createWorkflowRestAction.handleRequest(request, channel, nodeClient); + assertEquals(RestStatus.BAD_REQUEST, channel.capturedResponse().status()); + assertTrue( + channel.capturedResponse().content().utf8ToString().contains("are permitted unless the provision parameter is set to true.") + ); + } + + public void testUpdateWorkflowRequestWithFullTemplateUpdateAndNoParams() throws Exception { + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.PUT) + .withPath(this.updateWorkflowPath) + .withParams(Map.of(UPDATE_WORKFLOW_FIELDS, "true")) + .withContent(new BytesArray(validTemplate), MediaTypeRegistry.JSON) + .build(); + FakeRestChannel channel = new FakeRestChannel(request, false, 1); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(new WorkflowResponse("id-123")); + return null; + }).when(nodeClient).execute(any(), any(WorkflowRequest.class), any()); + createWorkflowRestAction.handleRequest(request, channel, nodeClient); + assertEquals(RestStatus.BAD_REQUEST, channel.capturedResponse().status()); + assertTrue( + channel.capturedResponse() + .content() + .utf8ToString() + .contains("You can not update the field [workflows] without updating the whole template.") + ); + } + + public void testUpdateWorkflowRequestWithUpdateTemplateUpdateAndNoParams() throws Exception { + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.PUT) + .withPath(this.updateWorkflowPath) + .withParams(Map.of(UPDATE_WORKFLOW_FIELDS, "true")) + .withContent(new BytesArray(validUpdateTemplate), MediaTypeRegistry.JSON) + .build(); + FakeRestChannel channel = new FakeRestChannel(request, false, 1); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(new WorkflowResponse("id-123")); + return null; + }).when(nodeClient).execute(any(), any(WorkflowRequest.class), any()); + createWorkflowRestAction.handleRequest(request, channel, nodeClient); + assertEquals(RestStatus.CREATED, channel.capturedResponse().status()); + assertTrue(channel.capturedResponse().content().utf8ToString().contains("id-123")); + } + public void testCreateWorkflowRequestWithUseCaseButNoProvision() throws Exception { RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) .withPath(this.createWorkflowPath) diff --git a/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java index 11b620f3d..776ce1460 100644 --- a/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java @@ -48,6 +48,7 @@ import static org.opensearch.action.DocWriteResponse.Result.UPDATED; import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX; +import static org.opensearch.flowframework.common.CommonValue.UPDATE_WORKFLOW_FIELDS; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; import static org.opensearch.flowframework.common.WorkflowResources.CONNECTOR_ID; import static org.opensearch.flowframework.common.WorkflowResources.CREATE_CONNECTOR; @@ -55,6 +56,7 @@ import static org.opensearch.flowframework.common.WorkflowResources.MODEL_ID; import static org.opensearch.flowframework.common.WorkflowResources.REGISTER_REMOTE_MODEL; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; @@ -352,7 +354,7 @@ public void testFailedToUpdateWorkflow() { ActionListener responseListener = invocation.getArgument(2); responseListener.onFailure(new Exception("failed")); return null; - }).when(flowFrameworkIndicesHandler).updateTemplateInGlobalContext(any(), any(Template.class), any()); + }).when(flowFrameworkIndicesHandler).updateTemplateInGlobalContext(anyString(), any(Template.class), any(), anyBoolean()); createWorkflowTransportAction.doExecute(mock(Task.class), updateWorkflow, listener); ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); @@ -403,7 +405,7 @@ public void testUpdateWorkflow() { ActionListener responseListener = invocation.getArgument(2); responseListener.onResponse(new IndexResponse(new ShardId(GLOBAL_CONTEXT_INDEX, "", 1), "1", 1L, 1L, 1L, true)); return null; - }).when(flowFrameworkIndicesHandler).updateTemplateInGlobalContext(any(), any(Template.class), any()); + }).when(flowFrameworkIndicesHandler).updateTemplateInGlobalContext(anyString(), any(Template.class), any(), anyBoolean()); doAnswer(invocation -> { ActionListener updateResponseListener = invocation.getArgument(2); @@ -418,6 +420,81 @@ public void testUpdateWorkflow() { assertEquals("1", responseCaptor.getValue().getWorkflowId()); } + public void testUpdateWorkflowWithField() { + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + WorkflowRequest updateWorkflow = new WorkflowRequest( + "1", + new Template.Builder().name("new name").description("test").useCase(null).uiMetadata(Map.of("foo", "bar")).build(), + Map.of(UPDATE_WORKFLOW_FIELDS, "true") + ); + + doAnswer(invocation -> { + ActionListener getListener = invocation.getArgument(1); + GetResponse getResponse = mock(GetResponse.class); + when(getResponse.isExists()).thenReturn(true); + when(getResponse.getSourceAsString()).thenReturn(template.toJson()); + getListener.onResponse(getResponse); + return null; + }).when(client).get(any(GetRequest.class), any()); + + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(2); + responseListener.onResponse(new IndexResponse(new ShardId(GLOBAL_CONTEXT_INDEX, "", 1), "1", 1L, 1L, 1L, true)); + return null; + }).when(flowFrameworkIndicesHandler).updateTemplateInGlobalContext(anyString(), any(Template.class), any(), anyBoolean()); + + createWorkflowTransportAction.doExecute(mock(Task.class), updateWorkflow, listener); + verify(listener, times(1)).onResponse(any()); + + ArgumentCaptor