From b142c22177443e59ebbacc25924aae3228964b25 Mon Sep 17 00:00:00 2001 From: "opensearch-trigger-bot[bot]" <98922864+opensearch-trigger-bot[bot]@users.noreply.github.com> Date: Tue, 9 Jan 2024 10:16:40 -0800 Subject: [PATCH] [Backport 2.x] Change max retries to retry duration, refactor settings for consistency (#390) Change max retries to retry duration, refactor settings for consistency (#381) * Change max retries to retry duration * Move max workflows setting update consumer to settings class * Move workflow timeout setting update consumer to settings class * Use timeout in other search requests * Improve test coverage --------- (cherry picked from commit a6fb53216e96296589b5b0ab4c7e6977764ef92c) Signed-off-by: Daniel Widdis Signed-off-by: github-actions[bot] Co-authored-by: github-actions[bot] --- .../flowframework/FlowFrameworkPlugin.java | 11 +- .../common/FlowFrameworkSettings.java | 59 ++++-- .../rest/AbstractSearchWorkflowAction.java | 11 +- .../rest/AbstractWorkflowAction.java | 42 ---- .../rest/RestCreateWorkflowAction.java | 31 +-- .../CreateWorkflowTransportAction.java | 192 +++++++++--------- .../transport/WorkflowRequest.java | 63 +----- .../AbstractRetryableWorkflowStep.java | 18 +- .../FlowFrameworkPluginTests.java | 6 +- .../common/FlowFrameworkSettingsTests.java | 11 +- .../model/WorkflowValidatorTests.java | 4 +- .../rest/RestCreateWorkflowActionTests.java | 20 +- .../CreateWorkflowTransportActionTests.java | 103 +++------- .../workflow/DeployModelStepTests.java | 21 +- .../RegisterLocalCustomModelStepTests.java | 20 +- ...RegisterLocalPretrainedModelStepTests.java | 20 +- ...sterLocalSparseEncodingModelStepTests.java | 20 +- .../workflow/WorkflowProcessSorterTests.java | 4 +- 18 files changed, 222 insertions(+), 434 deletions(-) delete mode 100644 src/main/java/org/opensearch/flowframework/rest/AbstractWorkflowAction.java diff --git a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java index 48a3f51bc..4dd583e11 100644 --- a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java +++ b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java @@ -72,9 +72,9 @@ import static org.opensearch.flowframework.common.CommonValue.FLOW_FRAMEWORK_THREAD_POOL_PREFIX; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_THREAD_POOL; import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED; -import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_GET_TASK_REQUEST_RETRY; import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOWS; import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOW_STEPS; +import static org.opensearch.flowframework.common.FlowFrameworkSettings.TASK_REQUEST_RETRY_DURATION; import static org.opensearch.flowframework.common.FlowFrameworkSettings.WORKFLOW_REQUEST_TIMEOUT; /** @@ -84,8 +84,6 @@ public class FlowFrameworkPlugin extends Plugin implements ActionPlugin { private FlowFrameworkSettings flowFrameworkSettings; - private ClusterService clusterService; - /** * Instantiate this plugin. */ @@ -106,7 +104,6 @@ public Collection createComponents( Supplier repositoriesServiceSupplier ) { Settings settings = environment.settings(); - this.clusterService = clusterService; flowFrameworkSettings = new FlowFrameworkSettings(clusterService, settings); MachineLearningNodeClient mlClient = new MachineLearningNodeClient(client); EncryptorUtils encryptorUtils = new EncryptorUtils(clusterService, client); @@ -127,7 +124,7 @@ public Collection createComponents( flowFrameworkSettings ); - return List.of(workflowStepFactory, workflowProcessSorter, encryptorUtils, flowFrameworkIndicesHandler); + return List.of(workflowStepFactory, workflowProcessSorter, encryptorUtils, flowFrameworkIndicesHandler, flowFrameworkSettings); } @Override @@ -141,7 +138,7 @@ public List getRestHandlers( Supplier nodesInCluster ) { return List.of( - new RestCreateWorkflowAction(flowFrameworkSettings, settings, clusterService), + new RestCreateWorkflowAction(flowFrameworkSettings), new RestDeleteWorkflowAction(flowFrameworkSettings), new RestProvisionWorkflowAction(flowFrameworkSettings), new RestDeprovisionWorkflowAction(flowFrameworkSettings), @@ -168,7 +165,7 @@ public List getRestHandlers( @Override public List> getSettings() { - return List.of(FLOW_FRAMEWORK_ENABLED, MAX_WORKFLOWS, MAX_WORKFLOW_STEPS, WORKFLOW_REQUEST_TIMEOUT, MAX_GET_TASK_REQUEST_RETRY); + return List.of(FLOW_FRAMEWORK_ENABLED, MAX_WORKFLOWS, MAX_WORKFLOW_STEPS, WORKFLOW_REQUEST_TIMEOUT, TASK_REQUEST_RETRY_DURATION); } @Override diff --git a/src/main/java/org/opensearch/flowframework/common/FlowFrameworkSettings.java b/src/main/java/org/opensearch/flowframework/common/FlowFrameworkSettings.java index d5bec1871..5af559607 100644 --- a/src/main/java/org/opensearch/flowframework/common/FlowFrameworkSettings.java +++ b/src/main/java/org/opensearch/flowframework/common/FlowFrameworkSettings.java @@ -17,17 +17,21 @@ public class FlowFrameworkSettings { private volatile Boolean isFlowFrameworkEnabled; - /** The maximum number of transport request retries */ - private volatile Integer maxRetry; - /** Max workflow steps that can be created*/ + /** The duration between request retries */ + private volatile TimeValue retryDuration; + /** Max workflow steps that can be created */ private volatile Integer maxWorkflowSteps; + /** Max workflows that can be created*/ + protected volatile Integer maxWorkflows; + /** Timeout for internal requests*/ + protected volatile TimeValue requestTimeout; /** The upper limit of max workflows that can be created */ public static final int MAX_WORKFLOWS_LIMIT = 10000; /** The upper limit of max workflow steps that can be in a single workflow */ public static final int MAX_WORKFLOW_STEPS_LIMIT = 500; - /** This setting sets max workflows that can be created */ + /** This setting sets max workflows that can be created */ public static final Setting MAX_WORKFLOWS = Setting.intSetting( "plugins.flow_framework.max_workflows", 1000, @@ -37,7 +41,7 @@ public class FlowFrameworkSettings { Setting.Property.Dynamic ); - /** This setting sets max workflows that can be created */ + /** This setting sets max workflows that can be created */ public static final Setting MAX_WORKFLOW_STEPS = Setting.intSetting( "plugins.flow_framework.max_workflow_steps", 50, @@ -47,7 +51,7 @@ public class FlowFrameworkSettings { Setting.Property.Dynamic ); - /** This setting sets the timeout for the request */ + /** This setting sets the timeout for the request */ public static final Setting WORKFLOW_REQUEST_TIMEOUT = Setting.positiveTimeSetting( "plugins.flow_framework.request_timeout", TimeValue.timeValueSeconds(10), @@ -63,11 +67,10 @@ public class FlowFrameworkSettings { Setting.Property.Dynamic ); - /** This setting sets the maximum number of get task request retries */ - public static final Setting MAX_GET_TASK_REQUEST_RETRY = Setting.intSetting( - "plugins.flow_framework.max_get_task_request_retry", - 5, - 0, + /** This setting sets the time between task request retries */ + public static final Setting TASK_REQUEST_RETRY_DURATION = Setting.positiveTimeSetting( + "plugins.flow_framework.task_request_retry_duration", + TimeValue.timeValueSeconds(5), Setting.Property.NodeScope, Setting.Property.Dynamic ); @@ -82,15 +85,19 @@ public FlowFrameworkSettings(ClusterService clusterService, Settings settings) { // Currently this is just an on/off switch for the entire plugin's API. // If desired more fine-tuned feature settings can be added below. this.isFlowFrameworkEnabled = FLOW_FRAMEWORK_ENABLED.get(settings); - this.maxRetry = MAX_GET_TASK_REQUEST_RETRY.get(settings); + this.retryDuration = TASK_REQUEST_RETRY_DURATION.get(settings); this.maxWorkflowSteps = MAX_WORKFLOW_STEPS.get(settings); + this.maxWorkflows = MAX_WORKFLOWS.get(settings); + this.requestTimeout = WORKFLOW_REQUEST_TIMEOUT.get(settings); clusterService.getClusterSettings().addSettingsUpdateConsumer(FLOW_FRAMEWORK_ENABLED, it -> isFlowFrameworkEnabled = it); - clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_GET_TASK_REQUEST_RETRY, it -> maxRetry = it); + clusterService.getClusterSettings().addSettingsUpdateConsumer(TASK_REQUEST_RETRY_DURATION, it -> retryDuration = it); clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_WORKFLOW_STEPS, it -> maxWorkflowSteps = it); + clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_WORKFLOWS, it -> maxWorkflows = it); + clusterService.getClusterSettings().addSettingsUpdateConsumer(WORKFLOW_REQUEST_TIMEOUT, it -> requestTimeout = it); } /** - * Whether the flow framework feature is enabled. If disabled, no REST APIs will be availble. + * Whether the flow framework feature is enabled. If disabled, no REST APIs will be available. * @return whether Flow Framework is enabled. */ public boolean isFlowFrameworkEnabled() { @@ -98,11 +105,11 @@ public boolean isFlowFrameworkEnabled() { } /** - * Getter for max retry count - * @return count of max retry + * Getter for retry duration + * @return retry duration */ - public Integer getMaxRetry() { - return maxRetry; + public TimeValue getRetryDuration() { + return retryDuration; } /** @@ -112,4 +119,20 @@ public Integer getMaxRetry() { public Integer getMaxWorkflowSteps() { return maxWorkflowSteps; } + + /** + * Getter for max workflows + * @return max workflows + */ + public Integer getMaxWorkflows() { + return maxWorkflows; + } + + /** + * Getter for request timeout + * @return request timeout + */ + public TimeValue getRequestTimeout() { + return requestTimeout; + } } diff --git a/src/main/java/org/opensearch/flowframework/rest/AbstractSearchWorkflowAction.java b/src/main/java/org/opensearch/flowframework/rest/AbstractSearchWorkflowAction.java index e5ad65c43..0e62de62f 100644 --- a/src/main/java/org/opensearch/flowframework/rest/AbstractSearchWorkflowAction.java +++ b/src/main/java/org/opensearch/flowframework/rest/AbstractSearchWorkflowAction.java @@ -48,7 +48,7 @@ public abstract class AbstractSearchWorkflowAction e /** Search action type*/ protected final ActionType actionType; /** Settings to enable FlowFramework API*/ - protected final FlowFrameworkSettings flowFrameworkFeatureEnabledSetting; + protected final FlowFrameworkSettings flowFrameworkSettings; /** * Instantiates a new AbstractSearchWorkflowAction @@ -56,25 +56,25 @@ public abstract class AbstractSearchWorkflowAction e * @param index index the search should be done on * @param clazz model class * @param actionType from which action abstract class is called - * @param flowFrameworkFeatureEnabledSetting Whether this API is enabled + * @param flowFrameworkSettings Whether this API is enabled */ public AbstractSearchWorkflowAction( List urlPaths, String index, Class clazz, ActionType actionType, - FlowFrameworkSettings flowFrameworkFeatureEnabledSetting + FlowFrameworkSettings flowFrameworkSettings ) { this.urlPaths = urlPaths; this.index = index; this.clazz = clazz; this.actionType = actionType; - this.flowFrameworkFeatureEnabledSetting = flowFrameworkFeatureEnabledSetting; + this.flowFrameworkSettings = flowFrameworkSettings; } @Override protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { - if (!flowFrameworkFeatureEnabledSetting.isFlowFrameworkEnabled()) { + if (!flowFrameworkSettings.isFlowFrameworkEnabled()) { FlowFrameworkException ffe = new FlowFrameworkException( "This API is disabled. To enable it, update the setting [" + FLOW_FRAMEWORK_ENABLED.getKey() + "] to true.", RestStatus.FORBIDDEN @@ -87,6 +87,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli searchSourceBuilder.parseXContent(request.contentOrSourceParamParser()); searchSourceBuilder.fetchSource(getSourceContext(request, searchSourceBuilder)); searchSourceBuilder.seqNoAndPrimaryTerm(true).version(true); + searchSourceBuilder.timeout(flowFrameworkSettings.getRequestTimeout()); SearchRequest searchRequest = new SearchRequest().source(searchSourceBuilder).indices(index); return channel -> client.execute(actionType, searchRequest, search(channel)); } diff --git a/src/main/java/org/opensearch/flowframework/rest/AbstractWorkflowAction.java b/src/main/java/org/opensearch/flowframework/rest/AbstractWorkflowAction.java deleted file mode 100644 index 251ce750e..000000000 --- a/src/main/java/org/opensearch/flowframework/rest/AbstractWorkflowAction.java +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - */ -package org.opensearch.flowframework.rest; - -import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.settings.Settings; -import org.opensearch.common.unit.TimeValue; -import org.opensearch.rest.BaseRestHandler; - -import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOWS; -import static org.opensearch.flowframework.common.FlowFrameworkSettings.WORKFLOW_REQUEST_TIMEOUT; - -/** - * Abstract action for the rest actions - */ -public abstract class AbstractWorkflowAction extends BaseRestHandler { - /** Timeout for the request*/ - protected volatile TimeValue requestTimeout; - /** Max workflows that can be created*/ - protected volatile Integer maxWorkflows; - - /** - * Instantiates a new AbstractWorkflowAction - * - * @param settings Environment settings - * @param clusterService clusterService - */ - public AbstractWorkflowAction(Settings settings, ClusterService clusterService) { - this.requestTimeout = WORKFLOW_REQUEST_TIMEOUT.get(settings); - this.maxWorkflows = MAX_WORKFLOWS.get(settings); - - clusterService.getClusterSettings().addSettingsUpdateConsumer(WORKFLOW_REQUEST_TIMEOUT, it -> requestTimeout = it); - clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_WORKFLOWS, it -> maxWorkflows = it); - } - -} diff --git a/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java b/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java index 6a0558e89..ed0ae670b 100644 --- a/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java +++ b/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java @@ -11,8 +11,6 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.client.node.NodeClient; -import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.settings.Settings; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.ToXContent; @@ -23,6 +21,7 @@ import org.opensearch.flowframework.model.Template; import org.opensearch.flowframework.transport.CreateWorkflowAction; import org.opensearch.flowframework.transport.WorkflowRequest; +import org.opensearch.rest.BaseRestHandler; import org.opensearch.rest.BytesRestResponse; import org.opensearch.rest.RestRequest; @@ -40,26 +39,19 @@ /** * Rest Action to facilitate requests to create and update a use case template */ -public class RestCreateWorkflowAction extends AbstractWorkflowAction { +public class RestCreateWorkflowAction extends BaseRestHandler { private static final Logger logger = LogManager.getLogger(RestCreateWorkflowAction.class); private static final String CREATE_WORKFLOW_ACTION = "create_workflow_action"; - private FlowFrameworkSettings flowFrameworkFeatureEnabledSetting; + private FlowFrameworkSettings flowFrameworkSettings; /** * Instantiates a new RestCreateWorkflowAction - * @param flowFrameworkFeatureEnabledSetting Whether this API is enabled - * @param settings Environment settings - * @param clusterService clusterService + * @param flowFrameworkSettings The settings for the flow framework plugin */ - public RestCreateWorkflowAction( - FlowFrameworkSettings flowFrameworkFeatureEnabledSetting, - Settings settings, - ClusterService clusterService - ) { - super(settings, clusterService); - this.flowFrameworkFeatureEnabledSetting = flowFrameworkFeatureEnabledSetting; + public RestCreateWorkflowAction(FlowFrameworkSettings flowFrameworkSettings) { + this.flowFrameworkSettings = flowFrameworkSettings; } @Override @@ -80,7 +72,7 @@ public List routes() { @Override protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { String workflowId = request.param(WORKFLOW_ID); - if (!flowFrameworkFeatureEnabledSetting.isFlowFrameworkEnabled()) { + if (!flowFrameworkSettings.isFlowFrameworkEnabled()) { FlowFrameworkException ffe = new FlowFrameworkException( "This API is disabled. To enable it, set [" + FLOW_FRAMEWORK_ENABLED.getKey() + "] to true.", RestStatus.FORBIDDEN @@ -96,14 +88,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli String[] validation = request.paramAsStringArray(VALIDATION, new String[] { "all" }); boolean provision = request.paramAsBoolean(PROVISION_WORKFLOW, false); - WorkflowRequest workflowRequest = new WorkflowRequest( - workflowId, - template, - validation, - provision, - requestTimeout, - maxWorkflows - ); + WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, template, validation, provision); return channel -> client.execute(CreateWorkflowAction.INSTANCE, workflowRequest, ActionListener.wrap(response -> { XContentBuilder builder = response.toXContent(channel.newBuilder(), ToXContent.EMPTY_PARAMS); diff --git a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java index 671cff688..2fd9bd042 100644 --- a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java @@ -16,12 +16,12 @@ import org.opensearch.action.support.HandledTransportAction; import org.opensearch.client.Client; import org.opensearch.common.inject.Inject; -import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.common.CommonValue; +import org.opensearch.flowframework.common.FlowFrameworkSettings; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.flowframework.model.ProvisioningProgress; @@ -55,7 +55,7 @@ public class CreateWorkflowTransportAction extends HandledTransportAction { - if (!max) { - String errorMessage = "Maximum workflows limit reached " + request.getMaxWorkflows(); - logger.error(errorMessage); - FlowFrameworkException ffe = new FlowFrameworkException(errorMessage, RestStatus.BAD_REQUEST); - listener.onFailure(ffe); - return; - } else { - // Initialize config index and create new global context and state index entries - flowFrameworkIndicesHandler.initializeConfigIndex(ActionListener.wrap(isInitialized -> { - if (!isInitialized) { - listener.onFailure( - new FlowFrameworkException("Failed to initalize config index", RestStatus.INTERNAL_SERVER_ERROR) - ); - } else { - // Create new global context and state index entries - flowFrameworkIndicesHandler.putTemplateToGlobalContext( - templateWithUser, - ActionListener.wrap(globalContextResponse -> { - flowFrameworkIndicesHandler.putInitialStateToWorkflowState( - globalContextResponse.getId(), - user, - ActionListener.wrap(stateResponse -> { - logger.info("create state workflow doc"); - if (request.isProvision()) { - logger.info("provision parameter"); - WorkflowRequest workflowRequest = new WorkflowRequest(globalContextResponse.getId(), null); - client.execute( - ProvisionWorkflowAction.INSTANCE, - workflowRequest, - ActionListener.wrap(provisionResponse -> { - listener.onResponse(new WorkflowResponse(provisionResponse.getWorkflowId())); - }, exception -> { - if (exception instanceof FlowFrameworkException) { - listener.onFailure(exception); - } else { - listener.onFailure( - new FlowFrameworkException(exception.getMessage(), RestStatus.BAD_REQUEST) - ); - } - logger.error("Failed to send back provision workflow exception", exception); - }) - ); - } else { - listener.onResponse(new WorkflowResponse(globalContextResponse.getId())); - } - }, exception -> { - logger.error("Failed to save workflow state : {}", exception.getMessage()); - if (exception instanceof FlowFrameworkException) { - listener.onFailure(exception); - } else { - listener.onFailure( - new FlowFrameworkException(exception.getMessage(), RestStatus.BAD_REQUEST) - ); - } - }) - ); - }, exception -> { - logger.error("Failed to save use case template : {}", exception.getMessage()); - if (exception instanceof FlowFrameworkException) { - listener.onFailure(exception); - } else { - listener.onFailure( - new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception)) + checkMaxWorkflows( + flowFrameworkSettings.getRequestTimeout(), + flowFrameworkSettings.getMaxWorkflows(), + ActionListener.wrap(max -> { + if (!max) { + String errorMessage = "Maximum workflows limit reached " + flowFrameworkSettings.getMaxWorkflows(); + logger.error(errorMessage); + FlowFrameworkException ffe = new FlowFrameworkException(errorMessage, RestStatus.BAD_REQUEST); + listener.onFailure(ffe); + return; + } else { + // Initialize config index and create new global context and state index entries + flowFrameworkIndicesHandler.initializeConfigIndex(ActionListener.wrap(isInitialized -> { + if (!isInitialized) { + listener.onFailure( + new FlowFrameworkException("Failed to initalize config index", RestStatus.INTERNAL_SERVER_ERROR) + ); + } else { + // Create new global context and state index entries + flowFrameworkIndicesHandler.putTemplateToGlobalContext( + templateWithUser, + ActionListener.wrap(globalContextResponse -> { + flowFrameworkIndicesHandler.putInitialStateToWorkflowState( + globalContextResponse.getId(), + user, + ActionListener.wrap(stateResponse -> { + logger.info("create state workflow doc"); + if (request.isProvision()) { + logger.info("provision parameter"); + WorkflowRequest workflowRequest = new WorkflowRequest( + globalContextResponse.getId(), + null + ); + client.execute( + ProvisionWorkflowAction.INSTANCE, + workflowRequest, + ActionListener.wrap(provisionResponse -> { + listener.onResponse(new WorkflowResponse(provisionResponse.getWorkflowId())); + }, exception -> { + if (exception instanceof FlowFrameworkException) { + listener.onFailure(exception); + } else { + listener.onFailure( + new FlowFrameworkException( + exception.getMessage(), + RestStatus.BAD_REQUEST + ) + ); + } + logger.error("Failed to send back provision workflow exception", exception); + }) + ); + } else { + listener.onResponse(new WorkflowResponse(globalContextResponse.getId())); + } + }, exception -> { + logger.error("Failed to save workflow state : {}", exception.getMessage()); + if (exception instanceof FlowFrameworkException) { + listener.onFailure(exception); + } else { + listener.onFailure( + new FlowFrameworkException(exception.getMessage(), RestStatus.BAD_REQUEST) + ); + } + }) ); - } + }, exception -> { + logger.error("Failed to save use case template : {}", exception.getMessage()); + if (exception instanceof FlowFrameworkException) { + listener.onFailure(exception); + } else { + listener.onFailure( + new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception)) + ); + } - }) - ); - } - }, exception -> { - logger.error("Failed to initialize config index : {}", exception.getMessage()); - if (exception instanceof FlowFrameworkException) { - listener.onFailure(exception); - } else { - listener.onFailure(new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception))); - } + }) + ); + } + }, exception -> { + logger.error("Failed to initialize config index : {}", exception.getMessage()); + if (exception instanceof FlowFrameworkException) { + listener.onFailure(exception); + } else { + listener.onFailure(new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception))); + } - })); - } - }, e -> { - logger.error("Failed to updated use case template {} : {}", request.getWorkflowId(), e.getMessage()); - if (e instanceof FlowFrameworkException) { - listener.onFailure(e); - } else { - listener.onFailure(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); - } - })); + })); + } + }, e -> { + logger.error("Failed to updated use case template {} : {}", request.getWorkflowId(), e.getMessage()); + if (e instanceof FlowFrameworkException) { + listener.onFailure(e); + } else { + listener.onFailure(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); + } + }) + ); } else { // Update existing entry, full document replacement flowFrameworkIndicesHandler.updateTemplateInGlobalContext( @@ -247,7 +257,7 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener internalListener) { + void checkMaxWorkflows(TimeValue requestTimeOut, Integer maxWorkflow, ActionListener internalListener) { if (!flowFrameworkIndicesHandler.doesIndexExist(CommonValue.GLOBAL_CONTEXT_INDEX)) { internalListener.onResponse(true); } else { diff --git a/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java b/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java index 72657e854..341c79742 100644 --- a/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java +++ b/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java @@ -11,7 +11,6 @@ import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.common.Nullable; -import org.opensearch.common.unit.TimeValue; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.flowframework.model.Template; @@ -43,39 +42,13 @@ public class WorkflowRequest extends ActionRequest { */ private boolean provision; - /** - * Timeout for request - */ - private TimeValue requestTimeout; - - /** - * Max workflows - */ - private Integer maxWorkflows; - /** * Instantiates a new WorkflowRequest, set validation to false and set requestTimeout and maxWorkflows to null * @param workflowId the documentId of the workflow * @param template the use case template which describes the workflow */ public WorkflowRequest(@Nullable String workflowId, @Nullable Template template) { - this(workflowId, template, new String[] { "all" }, false, null, null); - } - - /** - * Instantiates a new WorkflowRequest and set validation to false - * @param workflowId the documentId of the workflow - * @param template the use case template which describes the workflow - * @param requestTimeout timeout of the request - * @param maxWorkflows max number of workflows - */ - public WorkflowRequest( - @Nullable String workflowId, - @Nullable Template template, - @Nullable TimeValue requestTimeout, - @Nullable Integer maxWorkflows - ) { - this(workflowId, template, new String[] { "all" }, false, requestTimeout, maxWorkflows); + this(workflowId, template, new String[] { "all" }, false); } /** @@ -84,23 +57,12 @@ public WorkflowRequest( * @param template the use case template which describes the workflow * @param validation flag to indicate if validation is necessary * @param provision flag to indicate if provision is necessary - * @param requestTimeout timeout of the request - * @param maxWorkflows max number of workflows */ - public WorkflowRequest( - @Nullable String workflowId, - @Nullable Template template, - String[] validation, - boolean provision, - @Nullable TimeValue requestTimeout, - @Nullable Integer maxWorkflows - ) { + public WorkflowRequest(@Nullable String workflowId, @Nullable Template template, String[] validation, boolean provision) { this.workflowId = workflowId; this.template = template; this.validation = validation; this.provision = provision; - this.requestTimeout = requestTimeout; - this.maxWorkflows = maxWorkflows; } /** @@ -115,8 +77,6 @@ public WorkflowRequest(StreamInput in) throws IOException { this.template = templateJson == null ? null : Template.parse(templateJson); this.validation = in.readStringArray(); this.provision = in.readBoolean(); - this.requestTimeout = in.readOptionalTimeValue(); - this.maxWorkflows = in.readOptionalInt(); } /** @@ -153,22 +113,6 @@ public boolean isProvision() { return this.provision; } - /** - * Gets the timeout of the request - * @return the requestTimeout - */ - public TimeValue getRequestTimeout() { - return requestTimeout; - } - - /** - * Gets the max workflows - * @return the maxWorkflows - */ - public Integer getMaxWorkflows() { - return maxWorkflows; - } - @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); @@ -176,13 +120,10 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalString(template == null ? null : template.toJson()); out.writeStringArray(validation); out.writeBoolean(provision); - out.writeOptionalTimeValue(requestTimeout); - out.writeOptionalInt(maxWorkflows); } @Override public ActionRequestValidationException validate() { return null; } - } diff --git a/src/main/java/org/opensearch/flowframework/workflow/AbstractRetryableWorkflowStep.java b/src/main/java/org/opensearch/flowframework/workflow/AbstractRetryableWorkflowStep.java index fa2512c70..c277ab27d 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/AbstractRetryableWorkflowStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/AbstractRetryableWorkflowStep.java @@ -11,6 +11,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.ExceptionsHelper; +import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.concurrent.FutureUtils; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; @@ -22,7 +23,6 @@ import org.opensearch.threadpool.ThreadPool; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.atomic.AtomicInteger; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_THREAD_POOL; import static org.opensearch.flowframework.common.WorkflowResources.getResourceByWorkflowStep; @@ -32,8 +32,7 @@ */ public abstract class AbstractRetryableWorkflowStep implements WorkflowStep { private static final Logger logger = LogManager.getLogger(AbstractRetryableWorkflowStep.class); - /** The maximum number of transport request retries */ - protected volatile Integer maxRetry; + private TimeValue retryDuration; private final MachineLearningNodeClient mlClient; private final FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; private ThreadPool threadPool; @@ -52,7 +51,7 @@ protected AbstractRetryableWorkflowStep( FlowFrameworkSettings flowFrameworkSettings ) { this.threadPool = threadPool; - this.maxRetry = flowFrameworkSettings.getMaxRetry(); + this.retryDuration = flowFrameworkSettings.getRetryDuration(); this.mlClient = mlClient; this.flowFrameworkIndicesHandler = flowFrameworkIndicesHandler; } @@ -74,9 +73,8 @@ protected void retryableGetMlTask( String workflowStep, ActionListener mlTaskListener ) { - AtomicInteger retries = new AtomicInteger(); CompletableFuture.runAsync(() -> { - while (retries.getAndIncrement() < this.maxRetry && !future.isDone()) { + while (!future.isDone()) { mlClient.getTask(taskId, ActionListener.wrap(response -> { switch (response.getState()) { case COMPLETED: @@ -123,19 +121,13 @@ protected void retryableGetMlTask( logger.error(errorMessage); mlTaskListener.onFailure(new FlowFrameworkException(errorMessage, RestStatus.BAD_REQUEST)); })); - // Wait long enough for future to possibly complete try { - Thread.sleep(5000); + Thread.sleep(this.retryDuration.getMillis()); } catch (InterruptedException e) { FutureUtils.cancel(future); Thread.currentThread().interrupt(); } } - if (!future.isDone()) { - String errorMessage = workflowStep + " did not complete after " + maxRetry + " retries"; - logger.error(errorMessage); - mlTaskListener.onFailure(new FlowFrameworkException(errorMessage, RestStatus.REQUEST_TIMEOUT)); - } }, threadPool.executor(WORKFLOW_THREAD_POOL)); } diff --git a/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java b/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java index b08c27cfb..6970beb14 100644 --- a/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java +++ b/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java @@ -27,9 +27,9 @@ import java.util.stream.Stream; import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED; -import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_GET_TASK_REQUEST_RETRY; import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOWS; import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOW_STEPS; +import static org.opensearch.flowframework.common.FlowFrameworkSettings.TASK_REQUEST_RETRY_DURATION; import static org.opensearch.flowframework.common.FlowFrameworkSettings.WORKFLOW_REQUEST_TIMEOUT; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -63,7 +63,7 @@ public void setUp() throws Exception { final Set> settingsSet = Stream.concat( ClusterSettings.BUILT_IN_CLUSTER_SETTINGS.stream(), - Stream.of(FLOW_FRAMEWORK_ENABLED, MAX_WORKFLOWS, MAX_WORKFLOW_STEPS, WORKFLOW_REQUEST_TIMEOUT, MAX_GET_TASK_REQUEST_RETRY) + Stream.of(FLOW_FRAMEWORK_ENABLED, MAX_WORKFLOWS, MAX_WORKFLOW_STEPS, WORKFLOW_REQUEST_TIMEOUT, TASK_REQUEST_RETRY_DURATION) ).collect(Collectors.toSet()); clusterSettings = new ClusterSettings(settings, settingsSet); clusterService = mock(ClusterService.class); @@ -79,7 +79,7 @@ public void tearDown() throws Exception { public void testPlugin() throws IOException { try (FlowFrameworkPlugin ffp = new FlowFrameworkPlugin()) { assertEquals( - 4, + 5, ffp.createComponents(client, clusterService, threadPool, null, null, null, environment, null, null, null, null).size() ); assertEquals(8, ffp.getRestHandlers(settings, null, null, null, null, null, null).size()); diff --git a/src/test/java/org/opensearch/flowframework/common/FlowFrameworkSettingsTests.java b/src/test/java/org/opensearch/flowframework/common/FlowFrameworkSettingsTests.java index 04754f274..20010cdea 100644 --- a/src/test/java/org/opensearch/flowframework/common/FlowFrameworkSettingsTests.java +++ b/src/test/java/org/opensearch/flowframework/common/FlowFrameworkSettingsTests.java @@ -12,6 +12,7 @@ import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; import org.opensearch.test.OpenSearchTestCase; import java.io.IOException; @@ -39,8 +40,10 @@ public void setUp() throws Exception { ClusterSettings.BUILT_IN_CLUSTER_SETTINGS.stream(), Stream.of( FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED, - FlowFrameworkSettings.MAX_GET_TASK_REQUEST_RETRY, - FlowFrameworkSettings.MAX_WORKFLOW_STEPS + FlowFrameworkSettings.TASK_REQUEST_RETRY_DURATION, + FlowFrameworkSettings.MAX_WORKFLOW_STEPS, + FlowFrameworkSettings.MAX_WORKFLOWS, + FlowFrameworkSettings.WORKFLOW_REQUEST_TIMEOUT ) ).collect(Collectors.toSet()); clusterSettings = new ClusterSettings(settings, settingsSet); @@ -56,7 +59,9 @@ public void tearDown() throws Exception { public void testSettings() throws IOException { assertFalse(flowFrameworkSettings.isFlowFrameworkEnabled()); - assertEquals(Optional.of(5), Optional.ofNullable(flowFrameworkSettings.getMaxRetry())); + assertEquals(Optional.of(TimeValue.timeValueSeconds(5)), Optional.ofNullable(flowFrameworkSettings.getRetryDuration())); assertEquals(Optional.of(50), Optional.ofNullable(flowFrameworkSettings.getMaxWorkflowSteps())); + assertEquals(Optional.of(1000), Optional.ofNullable(flowFrameworkSettings.getMaxWorkflows())); + assertEquals(Optional.of(TimeValue.timeValueSeconds(10)), Optional.ofNullable(flowFrameworkSettings.getRequestTimeout())); } } diff --git a/src/test/java/org/opensearch/flowframework/model/WorkflowValidatorTests.java b/src/test/java/org/opensearch/flowframework/model/WorkflowValidatorTests.java index de5ae440b..1b085ae1e 100644 --- a/src/test/java/org/opensearch/flowframework/model/WorkflowValidatorTests.java +++ b/src/test/java/org/opensearch/flowframework/model/WorkflowValidatorTests.java @@ -31,8 +31,8 @@ import java.util.stream.Stream; import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED; -import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_GET_TASK_REQUEST_RETRY; import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOWS; +import static org.opensearch.flowframework.common.FlowFrameworkSettings.TASK_REQUEST_RETRY_DURATION; import static org.opensearch.flowframework.common.FlowFrameworkSettings.WORKFLOW_REQUEST_TIMEOUT; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -89,7 +89,7 @@ public void testWorkflowStepFactoryHasValidators() throws IOException { final Set> settingsSet = Stream.concat( ClusterSettings.BUILT_IN_CLUSTER_SETTINGS.stream(), - Stream.of(FLOW_FRAMEWORK_ENABLED, MAX_WORKFLOWS, WORKFLOW_REQUEST_TIMEOUT, MAX_GET_TASK_REQUEST_RETRY) + Stream.of(FLOW_FRAMEWORK_ENABLED, MAX_WORKFLOWS, WORKFLOW_REQUEST_TIMEOUT, TASK_REQUEST_RETRY_DURATION) ).collect(Collectors.toSet()); ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, settingsSet); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); diff --git a/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java b/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java index 818c38c00..fcdaf5757 100644 --- a/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java +++ b/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java @@ -10,10 +10,6 @@ import org.opensearch.Version; import org.opensearch.client.node.NodeClient; -import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.settings.ClusterSettings; -import org.opensearch.common.settings.Settings; -import org.opensearch.common.unit.TimeValue; import org.opensearch.core.common.bytes.BytesArray; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.MediaTypeRegistry; @@ -35,10 +31,7 @@ import java.util.Map; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI; -import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOWS; -import static org.opensearch.flowframework.common.FlowFrameworkSettings.WORKFLOW_REQUEST_TIMEOUT; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.spy; import static org.mockito.Mockito.when; public class RestCreateWorkflowActionTests extends OpenSearchTestCase { @@ -49,21 +42,11 @@ public class RestCreateWorkflowActionTests extends OpenSearchTestCase { private String updateWorkflowPath; private NodeClient nodeClient; private FlowFrameworkSettings flowFrameworkFeatureEnabledSetting; - private Settings settings; - private ClusterService clusterService; @Override public void setUp() throws Exception { super.setUp(); flowFrameworkFeatureEnabledSetting = mock(FlowFrameworkSettings.class); - settings = Settings.builder() - .put(WORKFLOW_REQUEST_TIMEOUT.getKey(), TimeValue.timeValueMillis(10)) - .put(MAX_WORKFLOWS.getKey(), 2) - .build(); - - ClusterSettings clusterSettings = TestHelpers.clusterSetting(settings, WORKFLOW_REQUEST_TIMEOUT, MAX_WORKFLOWS); - clusterService = spy(new ClusterService(settings, clusterSettings, null)); - when(flowFrameworkFeatureEnabledSetting.isFlowFrameworkEnabled()).thenReturn(true); Version templateVersion = Version.fromString("1.0.0"); @@ -88,8 +71,7 @@ public void setUp() throws Exception { // Invalid template configuration, wrong field name this.invalidTemplate = template.toJson().replace("use_case", "invalid"); - when(clusterService.getClusterSettings()).thenReturn(clusterSettings); - this.createWorkflowRestAction = new RestCreateWorkflowAction(flowFrameworkFeatureEnabledSetting, settings, clusterService); + this.createWorkflowRestAction = new RestCreateWorkflowAction(flowFrameworkFeatureEnabledSetting); this.createWorkflowPath = String.format(Locale.ROOT, "%s", WORKFLOW_URI); this.updateWorkflowPath = String.format(Locale.ROOT, "%s/{%s}", WORKFLOW_URI, "workflow_id"); this.nodeClient = mock(NodeClient.class); diff --git a/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java index 61a39809d..ce08cdb8a 100644 --- a/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java @@ -8,20 +8,21 @@ */ package org.opensearch.flowframework.transport; +import org.apache.lucene.search.TotalHits; import org.opensearch.Version; import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.Client; -import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.settings.ClusterSettings; -import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; import org.opensearch.core.index.shard.ShardId; import org.opensearch.flowframework.TestHelpers; +import org.opensearch.flowframework.common.FlowFrameworkSettings; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.flowframework.model.Template; import org.opensearch.flowframework.model.Workflow; @@ -29,6 +30,8 @@ import org.opensearch.flowframework.model.WorkflowNode; import org.opensearch.flowframework.workflow.WorkflowProcessSorter; import org.opensearch.plugins.PluginsService; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; import org.opensearch.tasks.Task; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; @@ -37,21 +40,13 @@ import java.util.Collections; import java.util.List; import java.util.Map; -import java.util.Set; import java.util.concurrent.TimeUnit; -import java.util.stream.Collectors; -import java.util.stream.Stream; import org.mockito.ArgumentCaptor; import static org.opensearch.action.DocWriteResponse.Result.UPDATED; import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; -import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED; -import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_GET_TASK_REQUEST_RETRY; -import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOWS; -import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOW_STEPS; -import static org.opensearch.flowframework.common.FlowFrameworkSettings.WORKFLOW_REQUEST_TIMEOUT; import static org.opensearch.flowframework.common.WorkflowResources.CONNECTOR_ID; import static org.opensearch.flowframework.common.WorkflowResources.CREATE_CONNECTOR; import static org.opensearch.flowframework.common.WorkflowResources.DEPLOY_MODEL; @@ -78,9 +73,7 @@ public class CreateWorkflowTransportActionTests extends OpenSearchTestCase { private Template template; private Client client; private ThreadPool threadPool; - private ClusterSettings clusterSettings; - private ClusterService clusterService; - private Settings settings; + private FlowFrameworkSettings flowFrameworkSettings; private PluginsService pluginsService; @Override @@ -89,17 +82,9 @@ public void setUp() throws Exception { client = mock(Client.class); threadPool = mock(ThreadPool.class); - settings = Settings.builder() - .put("plugins.flow_framework.max_workflows", 2) - .put("plugins.flow_framework.request_timeout", TimeValue.timeValueSeconds(10)) - .build(); - final Set> settingsSet = Stream.concat( - ClusterSettings.BUILT_IN_CLUSTER_SETTINGS.stream(), - Stream.of(FLOW_FRAMEWORK_ENABLED, MAX_WORKFLOWS, MAX_WORKFLOW_STEPS, WORKFLOW_REQUEST_TIMEOUT, MAX_GET_TASK_REQUEST_RETRY) - ).collect(Collectors.toSet()); - clusterSettings = new ClusterSettings(settings, settingsSet); - clusterService = mock(ClusterService.class); - when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + this.flowFrameworkSettings = mock(FlowFrameworkSettings.class); + when(flowFrameworkSettings.getMaxWorkflows()).thenReturn(2); + when(flowFrameworkSettings.getRequestTimeout()).thenReturn(TimeValue.timeValueSeconds(10)); this.flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class); // Validation functionality should not be invoked in these unit tests, mocking instead @@ -113,7 +98,7 @@ public void setUp() throws Exception { mock(ActionFilters.class), workflowProcessSorter, flowFrameworkIndicesHandler, - settings, + flowFrameworkSettings, client, pluginsService ) @@ -152,7 +137,7 @@ public void testValidation_withoutProvision_Success() { @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); - WorkflowRequest createNewWorkflow = new WorkflowRequest(null, validTemplate, new String[] { "all" }, false, null, null); + WorkflowRequest createNewWorkflow = new WorkflowRequest(null, validTemplate, new String[] { "all" }, false); createWorkflowTransportAction.doExecute(mock(Task.class), createNewWorkflow, listener); } @@ -212,29 +197,27 @@ public void testValidation_Failed() throws Exception { ActionListener listener = mock(ActionListener.class); // Stub validation failure doThrow(Exception.class).when(workflowProcessSorter).validate(any(), any()); - WorkflowRequest createNewWorkflow = new WorkflowRequest(null, cyclicalTemplate, new String[] { "all" }, false, null, null); + WorkflowRequest createNewWorkflow = new WorkflowRequest(null, cyclicalTemplate, new String[] { "all" }, false); createWorkflowTransportAction.doExecute(mock(Task.class), createNewWorkflow, listener); verify(listener, times(1)).onFailure(any()); } public void testMaxWorkflow() { + when(flowFrameworkIndicesHandler.doesIndexExist(anyString())).thenReturn(true); + @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); - WorkflowRequest workflowRequest = new WorkflowRequest( - null, - template, - new String[] { "off" }, - false, - WORKFLOW_REQUEST_TIMEOUT.get(settings), - MAX_WORKFLOWS.get(settings) - ); + WorkflowRequest workflowRequest = new WorkflowRequest(null, template, new String[] { "off" }, false); doAnswer(invocation -> { - ActionListener checkMaxWorkflowListener = invocation.getArgument(2); - checkMaxWorkflowListener.onResponse(false); + ActionListener searchListener = invocation.getArgument(1); + SearchResponse searchResponse = mock(SearchResponse.class); + SearchHits searchHits = new SearchHits(new SearchHit[0], new TotalHits(3, TotalHits.Relation.EQUAL_TO), 1.0f); + when(searchResponse.getHits()).thenReturn(searchHits); + searchListener.onResponse(searchResponse); return null; - }).when(createWorkflowTransportAction).checkMaxWorkflows(any(TimeValue.class), anyInt(), any()); + }).when(client).search(any(SearchRequest.class), any()); createWorkflowTransportAction.doExecute(mock(Task.class), workflowRequest, listener); ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); @@ -243,6 +226,8 @@ public void testMaxWorkflow() { } public void testMaxWorkflowWithNoIndex() { + when(flowFrameworkIndicesHandler.doesIndexExist(anyString())).thenReturn(false); + ActionListener listener = new ActionListener() { @Override public void onResponse(Boolean booleanResponse) { @@ -250,7 +235,9 @@ public void onResponse(Boolean booleanResponse) { } @Override - public void onFailure(Exception e) {} + public void onFailure(Exception e) { + fail("Should call onResponse"); + } }; createWorkflowTransportAction.checkMaxWorkflows(new TimeValue(10, TimeUnit.SECONDS), 10, listener); } @@ -258,14 +245,7 @@ public void onFailure(Exception e) {} public void testFailedToCreateNewWorkflow() { @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); - WorkflowRequest workflowRequest = new WorkflowRequest( - null, - template, - new String[] { "off" }, - false, - WORKFLOW_REQUEST_TIMEOUT.get(settings), - MAX_WORKFLOWS.get(settings) - ); + WorkflowRequest workflowRequest = new WorkflowRequest(null, template, new String[] { "off" }, false); // Bypass checkMaxWorkflows and force onResponse doAnswer(invocation -> { @@ -296,14 +276,7 @@ public void testFailedToCreateNewWorkflow() { public void testCreateNewWorkflow() { @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); - WorkflowRequest workflowRequest = new WorkflowRequest( - null, - template, - new String[] { "off" }, - false, - WORKFLOW_REQUEST_TIMEOUT.get(settings), - MAX_WORKFLOWS.get(settings) - ); + WorkflowRequest workflowRequest = new WorkflowRequest(null, template, new String[] { "off" }, false); // Bypass checkMaxWorkflows and force onResponse doAnswer(invocation -> { @@ -391,14 +364,7 @@ public void testCreateWorkflow_withValidation_withProvision_Success() throws Exc ActionListener listener = mock(ActionListener.class); doNothing().when(workflowProcessSorter).validate(any(), any()); - WorkflowRequest workflowRequest = new WorkflowRequest( - null, - validTemplate, - new String[] { "all" }, - true, - WORKFLOW_REQUEST_TIMEOUT.get(settings), - MAX_WORKFLOWS.get(settings) - ); + WorkflowRequest workflowRequest = new WorkflowRequest(null, validTemplate, new String[] { "all" }, true); // Bypass checkMaxWorkflows and force onResponse doAnswer(invocation -> { @@ -451,14 +417,7 @@ public void testCreateWorkflow_withValidation_withProvision_FailedProvisioning() @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); doNothing().when(workflowProcessSorter).validate(any(), any()); - WorkflowRequest workflowRequest = new WorkflowRequest( - null, - validTemplate, - new String[] { "all" }, - true, - WORKFLOW_REQUEST_TIMEOUT.get(settings), - MAX_WORKFLOWS.get(settings) - ); + WorkflowRequest workflowRequest = new WorkflowRequest(null, validTemplate, new String[] { "all" }, true); // Bypass checkMaxWorkflows and force onResponse doAnswer(invocation -> { diff --git a/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java index 90f2afc8d..2d17da062 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java @@ -11,10 +11,8 @@ import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; import org.opensearch.action.update.UpdateResponse; -import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.settings.ClusterSettings; -import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.concurrent.OpenSearchExecutors; import org.opensearch.core.action.ActionListener; import org.opensearch.core.index.shard.ShardId; @@ -36,12 +34,9 @@ import java.io.IOException; import java.util.Collections; import java.util.Map; -import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; -import java.util.stream.Collectors; -import java.util.stream.Stream; import org.mockito.ArgumentCaptor; import org.mockito.Mock; @@ -52,7 +47,6 @@ import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_THREAD_POOL; -import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_GET_TASK_REQUEST_RETRY; import static org.opensearch.flowframework.common.WorkflowResources.MODEL_ID; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; @@ -82,20 +76,9 @@ public void setUp() throws Exception { this.flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class); MockitoAnnotations.openMocks(this); - ClusterService clusterService = mock(ClusterService.class); - final Set> settingsSet = Stream.concat( - ClusterSettings.BUILT_IN_CLUSTER_SETTINGS.stream(), - Stream.of(MAX_GET_TASK_REQUEST_RETRY) - ).collect(Collectors.toSet()); - - // Set max request retry setting to 1 to limit sleeping the thread to one retry iteration - Settings testMaxRetrySetting = Settings.builder().put(MAX_GET_TASK_REQUEST_RETRY.getKey(), 1).build(); - ClusterSettings clusterSettings = new ClusterSettings(testMaxRetrySetting, settingsSet); - when(clusterService.getClusterSettings()).thenReturn(clusterSettings); - flowFrameworkSettings = mock(FlowFrameworkSettings.class); when(flowFrameworkSettings.isFlowFrameworkEnabled()).thenReturn(true); - when(flowFrameworkSettings.getMaxRetry()).thenReturn(5); + when(flowFrameworkSettings.getRetryDuration()).thenReturn(TimeValue.timeValueSeconds(5)); testThreadPool = new TestThreadPool( DeployModelStepTests.class.getName(), diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalCustomModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalCustomModelStepTests.java index 0f4eeaaa5..942279e6a 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalCustomModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalCustomModelStepTests.java @@ -11,10 +11,8 @@ import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; import org.opensearch.action.update.UpdateResponse; -import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.settings.ClusterSettings; -import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.concurrent.OpenSearchExecutors; import org.opensearch.core.action.ActionListener; import org.opensearch.core.index.shard.ShardId; @@ -34,12 +32,9 @@ import java.util.Collections; import java.util.Map; -import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; -import java.util.stream.Collectors; -import java.util.stream.Stream; import org.mockito.Mock; import org.mockito.MockitoAnnotations; @@ -49,7 +44,6 @@ import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_THREAD_POOL; -import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_GET_TASK_REQUEST_RETRY; import static org.opensearch.flowframework.common.WorkflowResources.MODEL_GROUP_ID; import static org.opensearch.flowframework.common.WorkflowResources.MODEL_ID; import static org.mockito.ArgumentMatchers.any; @@ -77,20 +71,10 @@ public void setUp() throws Exception { super.setUp(); this.flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class); MockitoAnnotations.openMocks(this); - ClusterService clusterService = mock(ClusterService.class); - final Set> settingsSet = Stream.concat( - ClusterSettings.BUILT_IN_CLUSTER_SETTINGS.stream(), - Stream.of(MAX_GET_TASK_REQUEST_RETRY) - ).collect(Collectors.toSet()); - - // Set max request retry setting to 1 to limit sleeping the thread to one retry iteration - Settings testMaxRetrySetting = Settings.builder().put(MAX_GET_TASK_REQUEST_RETRY.getKey(), 1).build(); - ClusterSettings clusterSettings = new ClusterSettings(testMaxRetrySetting, settingsSet); - when(clusterService.getClusterSettings()).thenReturn(clusterSettings); flowFrameworkSettings = mock(FlowFrameworkSettings.class); when(flowFrameworkSettings.isFlowFrameworkEnabled()).thenReturn(true); - when(flowFrameworkSettings.getMaxRetry()).thenReturn(5); + when(flowFrameworkSettings.getRetryDuration()).thenReturn(TimeValue.timeValueSeconds(5)); testThreadPool = new TestThreadPool( RegisterLocalCustomModelStepTests.class.getName(), diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalPretrainedModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalPretrainedModelStepTests.java index 5f61da559..afb76d92a 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalPretrainedModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalPretrainedModelStepTests.java @@ -11,10 +11,8 @@ import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; import org.opensearch.action.update.UpdateResponse; -import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.settings.ClusterSettings; -import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.concurrent.OpenSearchExecutors; import org.opensearch.core.action.ActionListener; import org.opensearch.core.index.shard.ShardId; @@ -34,12 +32,9 @@ import java.util.Collections; import java.util.Map; -import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; -import java.util.stream.Collectors; -import java.util.stream.Stream; import org.mockito.Mock; import org.mockito.MockitoAnnotations; @@ -49,7 +44,6 @@ import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_THREAD_POOL; -import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_GET_TASK_REQUEST_RETRY; import static org.opensearch.flowframework.common.WorkflowResources.MODEL_GROUP_ID; import static org.opensearch.flowframework.common.WorkflowResources.MODEL_ID; import static org.mockito.ArgumentMatchers.any; @@ -77,20 +71,10 @@ public void setUp() throws Exception { super.setUp(); this.flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class); MockitoAnnotations.openMocks(this); - ClusterService clusterService = mock(ClusterService.class); - final Set> settingsSet = Stream.concat( - ClusterSettings.BUILT_IN_CLUSTER_SETTINGS.stream(), - Stream.of(MAX_GET_TASK_REQUEST_RETRY) - ).collect(Collectors.toSet()); - - // Set max request retry setting to 1 to limit sleeping the thread to one retry iteration - Settings testMaxRetrySetting = Settings.builder().put(MAX_GET_TASK_REQUEST_RETRY.getKey(), 1).build(); - ClusterSettings clusterSettings = new ClusterSettings(testMaxRetrySetting, settingsSet); - when(clusterService.getClusterSettings()).thenReturn(clusterSettings); flowFrameworkSettings = mock(FlowFrameworkSettings.class); when(flowFrameworkSettings.isFlowFrameworkEnabled()).thenReturn(true); - when(flowFrameworkSettings.getMaxRetry()).thenReturn(5); + when(flowFrameworkSettings.getRetryDuration()).thenReturn(TimeValue.timeValueSeconds(5)); testThreadPool = new TestThreadPool( RegisterLocalCustomModelStepTests.class.getName(), diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalSparseEncodingModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalSparseEncodingModelStepTests.java index 8110ebdd0..6756913ec 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalSparseEncodingModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalSparseEncodingModelStepTests.java @@ -11,10 +11,8 @@ import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; import org.opensearch.action.update.UpdateResponse; -import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.settings.ClusterSettings; -import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.concurrent.OpenSearchExecutors; import org.opensearch.core.action.ActionListener; import org.opensearch.core.index.shard.ShardId; @@ -34,12 +32,9 @@ import java.util.Collections; import java.util.Map; -import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; -import java.util.stream.Collectors; -import java.util.stream.Stream; import org.mockito.Mock; import org.mockito.MockitoAnnotations; @@ -49,7 +44,6 @@ import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_THREAD_POOL; -import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_GET_TASK_REQUEST_RETRY; import static org.opensearch.flowframework.common.WorkflowResources.MODEL_GROUP_ID; import static org.opensearch.flowframework.common.WorkflowResources.MODEL_ID; import static org.mockito.ArgumentMatchers.any; @@ -77,20 +71,10 @@ public void setUp() throws Exception { super.setUp(); this.flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class); MockitoAnnotations.openMocks(this); - ClusterService clusterService = mock(ClusterService.class); - final Set> settingsSet = Stream.concat( - ClusterSettings.BUILT_IN_CLUSTER_SETTINGS.stream(), - Stream.of(MAX_GET_TASK_REQUEST_RETRY) - ).collect(Collectors.toSet()); - - // Set max request retry setting to 1 to limit sleeping the thread to one retry iteration - Settings testMaxRetrySetting = Settings.builder().put(MAX_GET_TASK_REQUEST_RETRY.getKey(), 1).build(); - ClusterSettings clusterSettings = new ClusterSettings(testMaxRetrySetting, settingsSet); - when(clusterService.getClusterSettings()).thenReturn(clusterSettings); flowFrameworkSettings = mock(FlowFrameworkSettings.class); when(flowFrameworkSettings.isFlowFrameworkEnabled()).thenReturn(true); - when(flowFrameworkSettings.getMaxRetry()).thenReturn(5); + when(flowFrameworkSettings.getRetryDuration()).thenReturn(TimeValue.timeValueSeconds(5)); testThreadPool = new TestThreadPool( RegisterLocalCustomModelStepTests.class.getName(), diff --git a/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java b/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java index ccf47ac70..2d32bdb04 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java @@ -46,9 +46,9 @@ import static org.opensearch.flowframework.common.CommonValue.FLOW_FRAMEWORK_THREAD_POOL_PREFIX; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_THREAD_POOL; import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED; -import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_GET_TASK_REQUEST_RETRY; import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOWS; import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOW_STEPS; +import static org.opensearch.flowframework.common.FlowFrameworkSettings.TASK_REQUEST_RETRY_DURATION; import static org.opensearch.flowframework.common.FlowFrameworkSettings.WORKFLOW_REQUEST_TIMEOUT; import static org.opensearch.flowframework.common.WorkflowResources.CONNECTOR_ID; import static org.opensearch.flowframework.common.WorkflowResources.MODEL_ID; @@ -98,7 +98,7 @@ public static void setup() throws IOException { Settings settings = Settings.builder().put("plugins.flow_framework.max_workflow_steps", 5).build(); final Set> settingsSet = Stream.concat( ClusterSettings.BUILT_IN_CLUSTER_SETTINGS.stream(), - Stream.of(FLOW_FRAMEWORK_ENABLED, MAX_WORKFLOWS, MAX_WORKFLOW_STEPS, WORKFLOW_REQUEST_TIMEOUT, MAX_GET_TASK_REQUEST_RETRY) + Stream.of(FLOW_FRAMEWORK_ENABLED, MAX_WORKFLOWS, MAX_WORKFLOW_STEPS, WORKFLOW_REQUEST_TIMEOUT, TASK_REQUEST_RETRY_DURATION) ).collect(Collectors.toSet()); ClusterSettings clusterSettings = new ClusterSettings(settings, settingsSet); when(clusterService.getClusterSettings()).thenReturn(clusterSettings);