diff --git a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java index 6958e1f45..4dd583e11 100644 --- a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java +++ b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java @@ -84,8 +84,6 @@ public class FlowFrameworkPlugin extends Plugin implements ActionPlugin { private FlowFrameworkSettings flowFrameworkSettings; - private ClusterService clusterService; - /** * Instantiate this plugin. */ @@ -106,7 +104,6 @@ public Collection createComponents( Supplier repositoriesServiceSupplier ) { Settings settings = environment.settings(); - this.clusterService = clusterService; flowFrameworkSettings = new FlowFrameworkSettings(clusterService, settings); MachineLearningNodeClient mlClient = new MachineLearningNodeClient(client); EncryptorUtils encryptorUtils = new EncryptorUtils(clusterService, client); @@ -141,7 +138,7 @@ public List getRestHandlers( Supplier nodesInCluster ) { return List.of( - new RestCreateWorkflowAction(flowFrameworkSettings, settings, clusterService), + new RestCreateWorkflowAction(flowFrameworkSettings), new RestDeleteWorkflowAction(flowFrameworkSettings), new RestProvisionWorkflowAction(flowFrameworkSettings), new RestDeprovisionWorkflowAction(flowFrameworkSettings), diff --git a/src/main/java/org/opensearch/flowframework/common/FlowFrameworkSettings.java b/src/main/java/org/opensearch/flowframework/common/FlowFrameworkSettings.java index fda210800..5af559607 100644 --- a/src/main/java/org/opensearch/flowframework/common/FlowFrameworkSettings.java +++ b/src/main/java/org/opensearch/flowframework/common/FlowFrameworkSettings.java @@ -23,6 +23,8 @@ public class FlowFrameworkSettings { private volatile Integer maxWorkflowSteps; /** Max workflows that can be created*/ protected volatile Integer maxWorkflows; + /** Timeout for internal requests*/ + protected volatile TimeValue requestTimeout; /** The upper limit of max workflows that can be created */ public static final int MAX_WORKFLOWS_LIMIT = 10000; @@ -86,10 +88,12 @@ public FlowFrameworkSettings(ClusterService clusterService, Settings settings) { this.retryDuration = TASK_REQUEST_RETRY_DURATION.get(settings); this.maxWorkflowSteps = MAX_WORKFLOW_STEPS.get(settings); this.maxWorkflows = MAX_WORKFLOWS.get(settings); + this.requestTimeout = WORKFLOW_REQUEST_TIMEOUT.get(settings); clusterService.getClusterSettings().addSettingsUpdateConsumer(FLOW_FRAMEWORK_ENABLED, it -> isFlowFrameworkEnabled = it); clusterService.getClusterSettings().addSettingsUpdateConsumer(TASK_REQUEST_RETRY_DURATION, it -> retryDuration = it); clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_WORKFLOW_STEPS, it -> maxWorkflowSteps = it); clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_WORKFLOWS, it -> maxWorkflows = it); + clusterService.getClusterSettings().addSettingsUpdateConsumer(WORKFLOW_REQUEST_TIMEOUT, it -> requestTimeout = it); } /** @@ -123,4 +127,12 @@ public Integer getMaxWorkflowSteps() { public Integer getMaxWorkflows() { return maxWorkflows; } + + /** + * Getter for request timeout + * @return request timeout + */ + public TimeValue getRequestTimeout() { + return requestTimeout; + } } diff --git a/src/main/java/org/opensearch/flowframework/rest/AbstractWorkflowAction.java b/src/main/java/org/opensearch/flowframework/rest/AbstractWorkflowAction.java deleted file mode 100644 index 59c8d5fb0..000000000 --- a/src/main/java/org/opensearch/flowframework/rest/AbstractWorkflowAction.java +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - */ -package org.opensearch.flowframework.rest; - -import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.settings.Settings; -import org.opensearch.common.unit.TimeValue; -import org.opensearch.rest.BaseRestHandler; - -import static org.opensearch.flowframework.common.FlowFrameworkSettings.WORKFLOW_REQUEST_TIMEOUT; - -/** - * Abstract action for the rest actions - */ -public abstract class AbstractWorkflowAction extends BaseRestHandler { - /** Timeout for the request*/ - protected volatile TimeValue requestTimeout; - - /** - * Instantiates a new AbstractWorkflowAction - * - * @param settings Environment settings - * @param clusterService clusterService - */ - public AbstractWorkflowAction(Settings settings, ClusterService clusterService) { - this.requestTimeout = WORKFLOW_REQUEST_TIMEOUT.get(settings); - - clusterService.getClusterSettings().addSettingsUpdateConsumer(WORKFLOW_REQUEST_TIMEOUT, it -> requestTimeout = it); - } - -} diff --git a/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java b/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java index 98cffd050..ed0ae670b 100644 --- a/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java +++ b/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java @@ -11,8 +11,6 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.client.node.NodeClient; -import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.settings.Settings; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.ToXContent; @@ -23,6 +21,7 @@ import org.opensearch.flowframework.model.Template; import org.opensearch.flowframework.transport.CreateWorkflowAction; import org.opensearch.flowframework.transport.WorkflowRequest; +import org.opensearch.rest.BaseRestHandler; import org.opensearch.rest.BytesRestResponse; import org.opensearch.rest.RestRequest; @@ -40,7 +39,7 @@ /** * Rest Action to facilitate requests to create and update a use case template */ -public class RestCreateWorkflowAction extends AbstractWorkflowAction { +public class RestCreateWorkflowAction extends BaseRestHandler { private static final Logger logger = LogManager.getLogger(RestCreateWorkflowAction.class); private static final String CREATE_WORKFLOW_ACTION = "create_workflow_action"; @@ -50,11 +49,8 @@ public class RestCreateWorkflowAction extends AbstractWorkflowAction { /** * Instantiates a new RestCreateWorkflowAction * @param flowFrameworkSettings The settings for the flow framework plugin - * @param settings Environment settings - * @param clusterService clusterService */ - public RestCreateWorkflowAction(FlowFrameworkSettings flowFrameworkSettings, Settings settings, ClusterService clusterService) { - super(settings, clusterService); + public RestCreateWorkflowAction(FlowFrameworkSettings flowFrameworkSettings) { this.flowFrameworkSettings = flowFrameworkSettings; } @@ -92,7 +88,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli String[] validation = request.paramAsStringArray(VALIDATION, new String[] { "all" }); boolean provision = request.paramAsBoolean(PROVISION_WORKFLOW, false); - WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, template, validation, provision, requestTimeout); + WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, template, validation, provision); return channel -> client.execute(CreateWorkflowAction.INSTANCE, workflowRequest, ActionListener.wrap(response -> { XContentBuilder builder = response.toXContent(channel.newBuilder(), ToXContent.EMPTY_PARAMS); diff --git a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java index 7d836d9d3..fa164c198 100644 --- a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java @@ -116,94 +116,104 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener { - if (!max) { - String errorMessage = "Maximum workflows limit reached " + flowFrameworkSettings.getMaxWorkflows(); - logger.error(errorMessage); - FlowFrameworkException ffe = new FlowFrameworkException(errorMessage, RestStatus.BAD_REQUEST); - listener.onFailure(ffe); - return; - } else { - // Initialize config index and create new global context and state index entries - flowFrameworkIndicesHandler.initializeConfigIndex(ActionListener.wrap(isInitialized -> { - if (!isInitialized) { - listener.onFailure( - new FlowFrameworkException("Failed to initalize config index", RestStatus.INTERNAL_SERVER_ERROR) - ); - } else { - // Create new global context and state index entries - flowFrameworkIndicesHandler.putTemplateToGlobalContext( - templateWithUser, - ActionListener.wrap(globalContextResponse -> { - flowFrameworkIndicesHandler.putInitialStateToWorkflowState( - globalContextResponse.getId(), - user, - ActionListener.wrap(stateResponse -> { - logger.info("create state workflow doc"); - if (request.isProvision()) { - logger.info("provision parameter"); - WorkflowRequest workflowRequest = new WorkflowRequest(globalContextResponse.getId(), null); - client.execute( - ProvisionWorkflowAction.INSTANCE, - workflowRequest, - ActionListener.wrap(provisionResponse -> { - listener.onResponse(new WorkflowResponse(provisionResponse.getWorkflowId())); - }, exception -> { - if (exception instanceof FlowFrameworkException) { - listener.onFailure(exception); - } else { - listener.onFailure( - new FlowFrameworkException(exception.getMessage(), RestStatus.BAD_REQUEST) - ); - } - logger.error("Failed to send back provision workflow exception", exception); - }) - ); - } else { - listener.onResponse(new WorkflowResponse(globalContextResponse.getId())); - } - }, exception -> { - logger.error("Failed to save workflow state : {}", exception.getMessage()); - if (exception instanceof FlowFrameworkException) { - listener.onFailure(exception); - } else { - listener.onFailure( - new FlowFrameworkException(exception.getMessage(), RestStatus.BAD_REQUEST) - ); - } - }) - ); - }, exception -> { - logger.error("Failed to save use case template : {}", exception.getMessage()); - if (exception instanceof FlowFrameworkException) { - listener.onFailure(exception); - } else { - listener.onFailure( - new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception)) + checkMaxWorkflows( + flowFrameworkSettings.getRequestTimeout(), + flowFrameworkSettings.getMaxWorkflows(), + ActionListener.wrap(max -> { + if (!max) { + String errorMessage = "Maximum workflows limit reached " + flowFrameworkSettings.getMaxWorkflows(); + logger.error(errorMessage); + FlowFrameworkException ffe = new FlowFrameworkException(errorMessage, RestStatus.BAD_REQUEST); + listener.onFailure(ffe); + return; + } else { + // Initialize config index and create new global context and state index entries + flowFrameworkIndicesHandler.initializeConfigIndex(ActionListener.wrap(isInitialized -> { + if (!isInitialized) { + listener.onFailure( + new FlowFrameworkException("Failed to initalize config index", RestStatus.INTERNAL_SERVER_ERROR) + ); + } else { + // Create new global context and state index entries + flowFrameworkIndicesHandler.putTemplateToGlobalContext( + templateWithUser, + ActionListener.wrap(globalContextResponse -> { + flowFrameworkIndicesHandler.putInitialStateToWorkflowState( + globalContextResponse.getId(), + user, + ActionListener.wrap(stateResponse -> { + logger.info("create state workflow doc"); + if (request.isProvision()) { + logger.info("provision parameter"); + WorkflowRequest workflowRequest = new WorkflowRequest( + globalContextResponse.getId(), + null + ); + client.execute( + ProvisionWorkflowAction.INSTANCE, + workflowRequest, + ActionListener.wrap(provisionResponse -> { + listener.onResponse(new WorkflowResponse(provisionResponse.getWorkflowId())); + }, exception -> { + if (exception instanceof FlowFrameworkException) { + listener.onFailure(exception); + } else { + listener.onFailure( + new FlowFrameworkException( + exception.getMessage(), + RestStatus.BAD_REQUEST + ) + ); + } + logger.error("Failed to send back provision workflow exception", exception); + }) + ); + } else { + listener.onResponse(new WorkflowResponse(globalContextResponse.getId())); + } + }, exception -> { + logger.error("Failed to save workflow state : {}", exception.getMessage()); + if (exception instanceof FlowFrameworkException) { + listener.onFailure(exception); + } else { + listener.onFailure( + new FlowFrameworkException(exception.getMessage(), RestStatus.BAD_REQUEST) + ); + } + }) ); - } + }, exception -> { + logger.error("Failed to save use case template : {}", exception.getMessage()); + if (exception instanceof FlowFrameworkException) { + listener.onFailure(exception); + } else { + listener.onFailure( + new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception)) + ); + } - }) - ); - } - }, exception -> { - logger.error("Failed to initialize config index : {}", exception.getMessage()); - if (exception instanceof FlowFrameworkException) { - listener.onFailure(exception); - } else { - listener.onFailure(new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception))); - } + }) + ); + } + }, exception -> { + logger.error("Failed to initialize config index : {}", exception.getMessage()); + if (exception instanceof FlowFrameworkException) { + listener.onFailure(exception); + } else { + listener.onFailure(new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception))); + } - })); - } - }, e -> { - logger.error("Failed to updated use case template {} : {}", request.getWorkflowId(), e.getMessage()); - if (e instanceof FlowFrameworkException) { - listener.onFailure(e); - } else { - listener.onFailure(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); - } - })); + })); + } + }, e -> { + logger.error("Failed to updated use case template {} : {}", request.getWorkflowId(), e.getMessage()); + if (e instanceof FlowFrameworkException) { + listener.onFailure(e); + } else { + listener.onFailure(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); + } + }) + ); } else { // Update existing entry, full document replacement flowFrameworkIndicesHandler.updateTemplateInGlobalContext( diff --git a/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java b/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java index 4a517bf87..341c79742 100644 --- a/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java +++ b/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java @@ -11,7 +11,6 @@ import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.common.Nullable; -import org.opensearch.common.unit.TimeValue; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.flowframework.model.Template; @@ -43,29 +42,13 @@ public class WorkflowRequest extends ActionRequest { */ private boolean provision; - /** - * Timeout for request - */ - private TimeValue requestTimeout; - /** * Instantiates a new WorkflowRequest, set validation to false and set requestTimeout and maxWorkflows to null * @param workflowId the documentId of the workflow * @param template the use case template which describes the workflow */ public WorkflowRequest(@Nullable String workflowId, @Nullable Template template) { - this(workflowId, template, new String[] { "all" }, false, null); - } - - /** - * Instantiates a new WorkflowRequest and set validation to false - * @param workflowId the documentId of the workflow - * @param template the use case template which describes the workflow - * @param requestTimeout timeout of the request - * @param maxWorkflows max number of workflows - */ - public WorkflowRequest(@Nullable String workflowId, @Nullable Template template, @Nullable TimeValue requestTimeout) { - this(workflowId, template, new String[] { "all" }, false, requestTimeout); + this(workflowId, template, new String[] { "all" }, false); } /** @@ -74,20 +57,12 @@ public WorkflowRequest(@Nullable String workflowId, @Nullable Template template, * @param template the use case template which describes the workflow * @param validation flag to indicate if validation is necessary * @param provision flag to indicate if provision is necessary - * @param requestTimeout timeout of the request */ - public WorkflowRequest( - @Nullable String workflowId, - @Nullable Template template, - String[] validation, - boolean provision, - @Nullable TimeValue requestTimeout - ) { + public WorkflowRequest(@Nullable String workflowId, @Nullable Template template, String[] validation, boolean provision) { this.workflowId = workflowId; this.template = template; this.validation = validation; this.provision = provision; - this.requestTimeout = requestTimeout; } /** @@ -102,7 +77,6 @@ public WorkflowRequest(StreamInput in) throws IOException { this.template = templateJson == null ? null : Template.parse(templateJson); this.validation = in.readStringArray(); this.provision = in.readBoolean(); - this.requestTimeout = in.readOptionalTimeValue(); } /** @@ -139,14 +113,6 @@ public boolean isProvision() { return this.provision; } - /** - * Gets the timeout of the request - * @return the requestTimeout - */ - public TimeValue getRequestTimeout() { - return requestTimeout; - } - @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); @@ -154,12 +120,10 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalString(template == null ? null : template.toJson()); out.writeStringArray(validation); out.writeBoolean(provision); - out.writeOptionalTimeValue(requestTimeout); } @Override public ActionRequestValidationException validate() { return null; } - } diff --git a/src/test/java/org/opensearch/flowframework/common/FlowFrameworkSettingsTests.java b/src/test/java/org/opensearch/flowframework/common/FlowFrameworkSettingsTests.java index db48a37e1..20010cdea 100644 --- a/src/test/java/org/opensearch/flowframework/common/FlowFrameworkSettingsTests.java +++ b/src/test/java/org/opensearch/flowframework/common/FlowFrameworkSettingsTests.java @@ -42,7 +42,8 @@ public void setUp() throws Exception { FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED, FlowFrameworkSettings.TASK_REQUEST_RETRY_DURATION, FlowFrameworkSettings.MAX_WORKFLOW_STEPS, - FlowFrameworkSettings.MAX_WORKFLOWS + FlowFrameworkSettings.MAX_WORKFLOWS, + FlowFrameworkSettings.WORKFLOW_REQUEST_TIMEOUT ) ).collect(Collectors.toSet()); clusterSettings = new ClusterSettings(settings, settingsSet); @@ -61,5 +62,6 @@ public void testSettings() throws IOException { assertEquals(Optional.of(TimeValue.timeValueSeconds(5)), Optional.ofNullable(flowFrameworkSettings.getRetryDuration())); assertEquals(Optional.of(50), Optional.ofNullable(flowFrameworkSettings.getMaxWorkflowSteps())); assertEquals(Optional.of(1000), Optional.ofNullable(flowFrameworkSettings.getMaxWorkflows())); + assertEquals(Optional.of(TimeValue.timeValueSeconds(10)), Optional.ofNullable(flowFrameworkSettings.getRequestTimeout())); } } diff --git a/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java b/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java index 818c38c00..fcdaf5757 100644 --- a/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java +++ b/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java @@ -10,10 +10,6 @@ import org.opensearch.Version; import org.opensearch.client.node.NodeClient; -import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.settings.ClusterSettings; -import org.opensearch.common.settings.Settings; -import org.opensearch.common.unit.TimeValue; import org.opensearch.core.common.bytes.BytesArray; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.MediaTypeRegistry; @@ -35,10 +31,7 @@ import java.util.Map; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI; -import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOWS; -import static org.opensearch.flowframework.common.FlowFrameworkSettings.WORKFLOW_REQUEST_TIMEOUT; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.spy; import static org.mockito.Mockito.when; public class RestCreateWorkflowActionTests extends OpenSearchTestCase { @@ -49,21 +42,11 @@ public class RestCreateWorkflowActionTests extends OpenSearchTestCase { private String updateWorkflowPath; private NodeClient nodeClient; private FlowFrameworkSettings flowFrameworkFeatureEnabledSetting; - private Settings settings; - private ClusterService clusterService; @Override public void setUp() throws Exception { super.setUp(); flowFrameworkFeatureEnabledSetting = mock(FlowFrameworkSettings.class); - settings = Settings.builder() - .put(WORKFLOW_REQUEST_TIMEOUT.getKey(), TimeValue.timeValueMillis(10)) - .put(MAX_WORKFLOWS.getKey(), 2) - .build(); - - ClusterSettings clusterSettings = TestHelpers.clusterSetting(settings, WORKFLOW_REQUEST_TIMEOUT, MAX_WORKFLOWS); - clusterService = spy(new ClusterService(settings, clusterSettings, null)); - when(flowFrameworkFeatureEnabledSetting.isFlowFrameworkEnabled()).thenReturn(true); Version templateVersion = Version.fromString("1.0.0"); @@ -88,8 +71,7 @@ public void setUp() throws Exception { // Invalid template configuration, wrong field name this.invalidTemplate = template.toJson().replace("use_case", "invalid"); - when(clusterService.getClusterSettings()).thenReturn(clusterSettings); - this.createWorkflowRestAction = new RestCreateWorkflowAction(flowFrameworkFeatureEnabledSetting, settings, clusterService); + this.createWorkflowRestAction = new RestCreateWorkflowAction(flowFrameworkFeatureEnabledSetting); this.createWorkflowPath = String.format(Locale.ROOT, "%s", WORKFLOW_URI); this.updateWorkflowPath = String.format(Locale.ROOT, "%s/{%s}", WORKFLOW_URI, "workflow_id"); this.nodeClient = mock(NodeClient.class); diff --git a/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java index f68b925f5..e3a28a12f 100644 --- a/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java @@ -8,14 +8,38 @@ */ package org.opensearch.flowframework.transport; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.action.DocWriteResponse.Result.UPDATED; +import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX; +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; +import static org.opensearch.flowframework.common.WorkflowResources.CONNECTOR_ID; +import static org.opensearch.flowframework.common.WorkflowResources.CREATE_CONNECTOR; +import static org.opensearch.flowframework.common.WorkflowResources.DEPLOY_MODEL; +import static org.opensearch.flowframework.common.WorkflowResources.MODEL_ID; +import static org.opensearch.flowframework.common.WorkflowResources.REGISTER_REMOTE_MODEL; + +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; + +import org.mockito.ArgumentCaptor; import org.opensearch.Version; import org.opensearch.action.index.IndexResponse; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.Client; -import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.settings.ClusterSettings; -import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.concurrent.ThreadContext; @@ -35,42 +59,6 @@ import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; -import java.util.Collections; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.concurrent.TimeUnit; -import java.util.stream.Collectors; -import java.util.stream.Stream; - -import org.mockito.ArgumentCaptor; - -import static org.opensearch.action.DocWriteResponse.Result.UPDATED; -import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX; -import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; -import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED; -import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOWS; -import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOW_STEPS; -import static org.opensearch.flowframework.common.FlowFrameworkSettings.TASK_REQUEST_RETRY_DURATION; -import static org.opensearch.flowframework.common.FlowFrameworkSettings.WORKFLOW_REQUEST_TIMEOUT; -import static org.opensearch.flowframework.common.WorkflowResources.CONNECTOR_ID; -import static org.opensearch.flowframework.common.WorkflowResources.CREATE_CONNECTOR; -import static org.opensearch.flowframework.common.WorkflowResources.DEPLOY_MODEL; -import static org.opensearch.flowframework.common.WorkflowResources.MODEL_ID; -import static org.opensearch.flowframework.common.WorkflowResources.REGISTER_REMOTE_MODEL; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyInt; -import static org.mockito.ArgumentMatchers.anyString; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.doNothing; -import static org.mockito.Mockito.doThrow; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - public class CreateWorkflowTransportActionTests extends OpenSearchTestCase { private CreateWorkflowTransportAction createWorkflowTransportAction; @@ -79,9 +67,6 @@ public class CreateWorkflowTransportActionTests extends OpenSearchTestCase { private Template template; private Client client; private ThreadPool threadPool; - private ClusterSettings clusterSettings; - private ClusterService clusterService; - private Settings settings; private FlowFrameworkSettings flowFrameworkSettings; private PluginsService pluginsService; @@ -91,19 +76,9 @@ public void setUp() throws Exception { client = mock(Client.class); threadPool = mock(ThreadPool.class); - settings = Settings.builder() - .put("plugins.flow_framework.max_workflows", 2) - .put("plugins.flow_framework.request_timeout", TimeValue.timeValueSeconds(10)) - .build(); - final Set> settingsSet = Stream.concat( - ClusterSettings.BUILT_IN_CLUSTER_SETTINGS.stream(), - Stream.of(FLOW_FRAMEWORK_ENABLED, MAX_WORKFLOWS, MAX_WORKFLOW_STEPS, WORKFLOW_REQUEST_TIMEOUT, TASK_REQUEST_RETRY_DURATION) - ).collect(Collectors.toSet()); - clusterSettings = new ClusterSettings(settings, settingsSet); - clusterService = mock(ClusterService.class); - when(clusterService.getClusterSettings()).thenReturn(clusterSettings); this.flowFrameworkSettings = mock(FlowFrameworkSettings.class); when(flowFrameworkSettings.getMaxWorkflows()).thenReturn(2); + when(flowFrameworkSettings.getRequestTimeout()).thenReturn(TimeValue.timeValueSeconds(10)); this.flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class); // Validation functionality should not be invoked in these unit tests, mocking instead @@ -156,7 +131,7 @@ public void testValidation_withoutProvision_Success() { @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); - WorkflowRequest createNewWorkflow = new WorkflowRequest(null, validTemplate, new String[] { "all" }, false, null); + WorkflowRequest createNewWorkflow = new WorkflowRequest(null, validTemplate, new String[] { "all" }, false); createWorkflowTransportAction.doExecute(mock(Task.class), createNewWorkflow, listener); } @@ -216,7 +191,7 @@ public void testValidation_Failed() throws Exception { ActionListener listener = mock(ActionListener.class); // Stub validation failure doThrow(Exception.class).when(workflowProcessSorter).validate(any(), any()); - WorkflowRequest createNewWorkflow = new WorkflowRequest(null, cyclicalTemplate, new String[] { "all" }, false, null); + WorkflowRequest createNewWorkflow = new WorkflowRequest(null, cyclicalTemplate, new String[] { "all" }, false); createWorkflowTransportAction.doExecute(mock(Task.class), createNewWorkflow, listener); verify(listener, times(1)).onFailure(any()); @@ -225,13 +200,7 @@ public void testValidation_Failed() throws Exception { public void testMaxWorkflow() { @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); - WorkflowRequest workflowRequest = new WorkflowRequest( - null, - template, - new String[] { "off" }, - false, - WORKFLOW_REQUEST_TIMEOUT.get(settings) - ); + WorkflowRequest workflowRequest = new WorkflowRequest(null, template, new String[] { "off" }, false); doAnswer(invocation -> { ActionListener checkMaxWorkflowListener = invocation.getArgument(2); @@ -261,13 +230,7 @@ public void onFailure(Exception e) {} public void testFailedToCreateNewWorkflow() { @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); - WorkflowRequest workflowRequest = new WorkflowRequest( - null, - template, - new String[] { "off" }, - false, - WORKFLOW_REQUEST_TIMEOUT.get(settings) - ); + WorkflowRequest workflowRequest = new WorkflowRequest(null, template, new String[] { "off" }, false); // Bypass checkMaxWorkflows and force onResponse doAnswer(invocation -> { @@ -298,13 +261,7 @@ public void testFailedToCreateNewWorkflow() { public void testCreateNewWorkflow() { @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); - WorkflowRequest workflowRequest = new WorkflowRequest( - null, - template, - new String[] { "off" }, - false, - WORKFLOW_REQUEST_TIMEOUT.get(settings) - ); + WorkflowRequest workflowRequest = new WorkflowRequest(null, template, new String[] { "off" }, false); // Bypass checkMaxWorkflows and force onResponse doAnswer(invocation -> { @@ -392,13 +349,7 @@ public void testCreateWorkflow_withValidation_withProvision_Success() throws Exc ActionListener listener = mock(ActionListener.class); doNothing().when(workflowProcessSorter).validate(any(), any()); - WorkflowRequest workflowRequest = new WorkflowRequest( - null, - validTemplate, - new String[] { "all" }, - true, - WORKFLOW_REQUEST_TIMEOUT.get(settings) - ); + WorkflowRequest workflowRequest = new WorkflowRequest(null, validTemplate, new String[] { "all" }, true); // Bypass checkMaxWorkflows and force onResponse doAnswer(invocation -> { @@ -451,13 +402,7 @@ public void testCreateWorkflow_withValidation_withProvision_FailedProvisioning() @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); doNothing().when(workflowProcessSorter).validate(any(), any()); - WorkflowRequest workflowRequest = new WorkflowRequest( - null, - validTemplate, - new String[] { "all" }, - true, - WORKFLOW_REQUEST_TIMEOUT.get(settings) - ); + WorkflowRequest workflowRequest = new WorkflowRequest(null, validTemplate, new String[] { "all" }, true); // Bypass checkMaxWorkflows and force onResponse doAnswer(invocation -> {