diff --git a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java index a51082e0e..6958e1f45 100644 --- a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java +++ b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java @@ -127,7 +127,7 @@ public Collection createComponents( flowFrameworkSettings ); - return List.of(workflowStepFactory, workflowProcessSorter, encryptorUtils, flowFrameworkIndicesHandler); + return List.of(workflowStepFactory, workflowProcessSorter, encryptorUtils, flowFrameworkIndicesHandler, flowFrameworkSettings); } @Override diff --git a/src/main/java/org/opensearch/flowframework/common/FlowFrameworkSettings.java b/src/main/java/org/opensearch/flowframework/common/FlowFrameworkSettings.java index 780aa165b..fda210800 100644 --- a/src/main/java/org/opensearch/flowframework/common/FlowFrameworkSettings.java +++ b/src/main/java/org/opensearch/flowframework/common/FlowFrameworkSettings.java @@ -21,6 +21,8 @@ public class FlowFrameworkSettings { private volatile TimeValue retryDuration; /** Max workflow steps that can be created */ private volatile Integer maxWorkflowSteps; + /** Max workflows that can be created*/ + protected volatile Integer maxWorkflows; /** The upper limit of max workflows that can be created */ public static final int MAX_WORKFLOWS_LIMIT = 10000; @@ -83,9 +85,11 @@ public FlowFrameworkSettings(ClusterService clusterService, Settings settings) { this.isFlowFrameworkEnabled = FLOW_FRAMEWORK_ENABLED.get(settings); this.retryDuration = TASK_REQUEST_RETRY_DURATION.get(settings); this.maxWorkflowSteps = MAX_WORKFLOW_STEPS.get(settings); + this.maxWorkflows = MAX_WORKFLOWS.get(settings); clusterService.getClusterSettings().addSettingsUpdateConsumer(FLOW_FRAMEWORK_ENABLED, it -> isFlowFrameworkEnabled = it); clusterService.getClusterSettings().addSettingsUpdateConsumer(TASK_REQUEST_RETRY_DURATION, it -> retryDuration = it); clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_WORKFLOW_STEPS, it -> maxWorkflowSteps = it); + clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_WORKFLOWS, it -> maxWorkflows = it); } /** @@ -111,4 +115,12 @@ public TimeValue getRetryDuration() { public Integer getMaxWorkflowSteps() { return maxWorkflowSteps; } + + /** + * Getter for max workflows + * @return max workflows + */ + public Integer getMaxWorkflows() { + return maxWorkflows; + } } diff --git a/src/main/java/org/opensearch/flowframework/rest/AbstractWorkflowAction.java b/src/main/java/org/opensearch/flowframework/rest/AbstractWorkflowAction.java index 251ce750e..59c8d5fb0 100644 --- a/src/main/java/org/opensearch/flowframework/rest/AbstractWorkflowAction.java +++ b/src/main/java/org/opensearch/flowframework/rest/AbstractWorkflowAction.java @@ -13,7 +13,6 @@ import org.opensearch.common.unit.TimeValue; import org.opensearch.rest.BaseRestHandler; -import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOWS; import static org.opensearch.flowframework.common.FlowFrameworkSettings.WORKFLOW_REQUEST_TIMEOUT; /** @@ -22,8 +21,6 @@ public abstract class AbstractWorkflowAction extends BaseRestHandler { /** Timeout for the request*/ protected volatile TimeValue requestTimeout; - /** Max workflows that can be created*/ - protected volatile Integer maxWorkflows; /** * Instantiates a new AbstractWorkflowAction @@ -33,10 +30,8 @@ public abstract class AbstractWorkflowAction extends BaseRestHandler { */ public AbstractWorkflowAction(Settings settings, ClusterService clusterService) { this.requestTimeout = WORKFLOW_REQUEST_TIMEOUT.get(settings); - this.maxWorkflows = MAX_WORKFLOWS.get(settings); clusterService.getClusterSettings().addSettingsUpdateConsumer(WORKFLOW_REQUEST_TIMEOUT, it -> requestTimeout = it); - clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_WORKFLOWS, it -> maxWorkflows = it); } } diff --git a/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java b/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java index 6a0558e89..98cffd050 100644 --- a/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java +++ b/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java @@ -45,21 +45,17 @@ public class RestCreateWorkflowAction extends AbstractWorkflowAction { private static final Logger logger = LogManager.getLogger(RestCreateWorkflowAction.class); private static final String CREATE_WORKFLOW_ACTION = "create_workflow_action"; - private FlowFrameworkSettings flowFrameworkFeatureEnabledSetting; + private FlowFrameworkSettings flowFrameworkSettings; /** * Instantiates a new RestCreateWorkflowAction - * @param flowFrameworkFeatureEnabledSetting Whether this API is enabled + * @param flowFrameworkSettings The settings for the flow framework plugin * @param settings Environment settings * @param clusterService clusterService */ - public RestCreateWorkflowAction( - FlowFrameworkSettings flowFrameworkFeatureEnabledSetting, - Settings settings, - ClusterService clusterService - ) { + public RestCreateWorkflowAction(FlowFrameworkSettings flowFrameworkSettings, Settings settings, ClusterService clusterService) { super(settings, clusterService); - this.flowFrameworkFeatureEnabledSetting = flowFrameworkFeatureEnabledSetting; + this.flowFrameworkSettings = flowFrameworkSettings; } @Override @@ -80,7 +76,7 @@ public List routes() { @Override protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { String workflowId = request.param(WORKFLOW_ID); - if (!flowFrameworkFeatureEnabledSetting.isFlowFrameworkEnabled()) { + if (!flowFrameworkSettings.isFlowFrameworkEnabled()) { FlowFrameworkException ffe = new FlowFrameworkException( "This API is disabled. To enable it, set [" + FLOW_FRAMEWORK_ENABLED.getKey() + "] to true.", RestStatus.FORBIDDEN @@ -96,14 +92,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli String[] validation = request.paramAsStringArray(VALIDATION, new String[] { "all" }); boolean provision = request.paramAsBoolean(PROVISION_WORKFLOW, false); - WorkflowRequest workflowRequest = new WorkflowRequest( - workflowId, - template, - validation, - provision, - requestTimeout, - maxWorkflows - ); + WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, template, validation, provision, requestTimeout); return channel -> client.execute(CreateWorkflowAction.INSTANCE, workflowRequest, ActionListener.wrap(response -> { XContentBuilder builder = response.toXContent(channel.newBuilder(), ToXContent.EMPTY_PARAMS); diff --git a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java index 671cff688..7d836d9d3 100644 --- a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java @@ -16,12 +16,12 @@ import org.opensearch.action.support.HandledTransportAction; import org.opensearch.client.Client; import org.opensearch.common.inject.Inject; -import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.common.CommonValue; +import org.opensearch.flowframework.common.FlowFrameworkSettings; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.flowframework.model.ProvisioningProgress; @@ -55,7 +55,7 @@ public class CreateWorkflowTransportAction extends HandledTransportAction { + checkMaxWorkflows(request.getRequestTimeout(), flowFrameworkSettings.getMaxWorkflows(), ActionListener.wrap(max -> { if (!max) { - String errorMessage = "Maximum workflows limit reached " + request.getMaxWorkflows(); + String errorMessage = "Maximum workflows limit reached " + flowFrameworkSettings.getMaxWorkflows(); logger.error(errorMessage); FlowFrameworkException ffe = new FlowFrameworkException(errorMessage, RestStatus.BAD_REQUEST); listener.onFailure(ffe); diff --git a/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java b/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java index 72657e854..4a517bf87 100644 --- a/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java +++ b/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java @@ -48,18 +48,13 @@ public class WorkflowRequest extends ActionRequest { */ private TimeValue requestTimeout; - /** - * Max workflows - */ - private Integer maxWorkflows; - /** * Instantiates a new WorkflowRequest, set validation to false and set requestTimeout and maxWorkflows to null * @param workflowId the documentId of the workflow * @param template the use case template which describes the workflow */ public WorkflowRequest(@Nullable String workflowId, @Nullable Template template) { - this(workflowId, template, new String[] { "all" }, false, null, null); + this(workflowId, template, new String[] { "all" }, false, null); } /** @@ -69,13 +64,8 @@ public WorkflowRequest(@Nullable String workflowId, @Nullable Template template) * @param requestTimeout timeout of the request * @param maxWorkflows max number of workflows */ - public WorkflowRequest( - @Nullable String workflowId, - @Nullable Template template, - @Nullable TimeValue requestTimeout, - @Nullable Integer maxWorkflows - ) { - this(workflowId, template, new String[] { "all" }, false, requestTimeout, maxWorkflows); + public WorkflowRequest(@Nullable String workflowId, @Nullable Template template, @Nullable TimeValue requestTimeout) { + this(workflowId, template, new String[] { "all" }, false, requestTimeout); } /** @@ -85,22 +75,19 @@ public WorkflowRequest( * @param validation flag to indicate if validation is necessary * @param provision flag to indicate if provision is necessary * @param requestTimeout timeout of the request - * @param maxWorkflows max number of workflows */ public WorkflowRequest( @Nullable String workflowId, @Nullable Template template, String[] validation, boolean provision, - @Nullable TimeValue requestTimeout, - @Nullable Integer maxWorkflows + @Nullable TimeValue requestTimeout ) { this.workflowId = workflowId; this.template = template; this.validation = validation; this.provision = provision; this.requestTimeout = requestTimeout; - this.maxWorkflows = maxWorkflows; } /** @@ -116,7 +103,6 @@ public WorkflowRequest(StreamInput in) throws IOException { this.validation = in.readStringArray(); this.provision = in.readBoolean(); this.requestTimeout = in.readOptionalTimeValue(); - this.maxWorkflows = in.readOptionalInt(); } /** @@ -161,14 +147,6 @@ public TimeValue getRequestTimeout() { return requestTimeout; } - /** - * Gets the max workflows - * @return the maxWorkflows - */ - public Integer getMaxWorkflows() { - return maxWorkflows; - } - @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); @@ -177,7 +155,6 @@ public void writeTo(StreamOutput out) throws IOException { out.writeStringArray(validation); out.writeBoolean(provision); out.writeOptionalTimeValue(requestTimeout); - out.writeOptionalInt(maxWorkflows); } @Override diff --git a/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java b/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java index 7a03077d9..6970beb14 100644 --- a/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java +++ b/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java @@ -79,7 +79,7 @@ public void tearDown() throws Exception { public void testPlugin() throws IOException { try (FlowFrameworkPlugin ffp = new FlowFrameworkPlugin()) { assertEquals( - 4, + 5, ffp.createComponents(client, clusterService, threadPool, null, null, null, environment, null, null, null, null).size() ); assertEquals(8, ffp.getRestHandlers(settings, null, null, null, null, null, null).size()); diff --git a/src/test/java/org/opensearch/flowframework/common/FlowFrameworkSettingsTests.java b/src/test/java/org/opensearch/flowframework/common/FlowFrameworkSettingsTests.java index 430d11508..db48a37e1 100644 --- a/src/test/java/org/opensearch/flowframework/common/FlowFrameworkSettingsTests.java +++ b/src/test/java/org/opensearch/flowframework/common/FlowFrameworkSettingsTests.java @@ -41,7 +41,8 @@ public void setUp() throws Exception { Stream.of( FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED, FlowFrameworkSettings.TASK_REQUEST_RETRY_DURATION, - FlowFrameworkSettings.MAX_WORKFLOW_STEPS + FlowFrameworkSettings.MAX_WORKFLOW_STEPS, + FlowFrameworkSettings.MAX_WORKFLOWS ) ).collect(Collectors.toSet()); clusterSettings = new ClusterSettings(settings, settingsSet); @@ -59,5 +60,6 @@ public void testSettings() throws IOException { assertFalse(flowFrameworkSettings.isFlowFrameworkEnabled()); assertEquals(Optional.of(TimeValue.timeValueSeconds(5)), Optional.ofNullable(flowFrameworkSettings.getRetryDuration())); assertEquals(Optional.of(50), Optional.ofNullable(flowFrameworkSettings.getMaxWorkflowSteps())); + assertEquals(Optional.of(1000), Optional.ofNullable(flowFrameworkSettings.getMaxWorkflows())); } } diff --git a/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java index 718e7a31c..f68b925f5 100644 --- a/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java @@ -22,6 +22,7 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.core.index.shard.ShardId; import org.opensearch.flowframework.TestHelpers; +import org.opensearch.flowframework.common.FlowFrameworkSettings; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.flowframework.model.Template; import org.opensearch.flowframework.model.Workflow; @@ -81,6 +82,7 @@ public class CreateWorkflowTransportActionTests extends OpenSearchTestCase { private ClusterSettings clusterSettings; private ClusterService clusterService; private Settings settings; + private FlowFrameworkSettings flowFrameworkSettings; private PluginsService pluginsService; @Override @@ -100,6 +102,8 @@ public void setUp() throws Exception { clusterSettings = new ClusterSettings(settings, settingsSet); clusterService = mock(ClusterService.class); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + this.flowFrameworkSettings = mock(FlowFrameworkSettings.class); + when(flowFrameworkSettings.getMaxWorkflows()).thenReturn(2); this.flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class); // Validation functionality should not be invoked in these unit tests, mocking instead @@ -113,7 +117,7 @@ public void setUp() throws Exception { mock(ActionFilters.class), workflowProcessSorter, flowFrameworkIndicesHandler, - settings, + flowFrameworkSettings, client, pluginsService ) @@ -152,7 +156,7 @@ public void testValidation_withoutProvision_Success() { @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); - WorkflowRequest createNewWorkflow = new WorkflowRequest(null, validTemplate, new String[] { "all" }, false, null, null); + WorkflowRequest createNewWorkflow = new WorkflowRequest(null, validTemplate, new String[] { "all" }, false, null); createWorkflowTransportAction.doExecute(mock(Task.class), createNewWorkflow, listener); } @@ -212,7 +216,7 @@ public void testValidation_Failed() throws Exception { ActionListener listener = mock(ActionListener.class); // Stub validation failure doThrow(Exception.class).when(workflowProcessSorter).validate(any(), any()); - WorkflowRequest createNewWorkflow = new WorkflowRequest(null, cyclicalTemplate, new String[] { "all" }, false, null, null); + WorkflowRequest createNewWorkflow = new WorkflowRequest(null, cyclicalTemplate, new String[] { "all" }, false, null); createWorkflowTransportAction.doExecute(mock(Task.class), createNewWorkflow, listener); verify(listener, times(1)).onFailure(any()); @@ -226,8 +230,7 @@ public void testMaxWorkflow() { template, new String[] { "off" }, false, - WORKFLOW_REQUEST_TIMEOUT.get(settings), - MAX_WORKFLOWS.get(settings) + WORKFLOW_REQUEST_TIMEOUT.get(settings) ); doAnswer(invocation -> { @@ -263,8 +266,7 @@ public void testFailedToCreateNewWorkflow() { template, new String[] { "off" }, false, - WORKFLOW_REQUEST_TIMEOUT.get(settings), - MAX_WORKFLOWS.get(settings) + WORKFLOW_REQUEST_TIMEOUT.get(settings) ); // Bypass checkMaxWorkflows and force onResponse @@ -301,8 +303,7 @@ public void testCreateNewWorkflow() { template, new String[] { "off" }, false, - WORKFLOW_REQUEST_TIMEOUT.get(settings), - MAX_WORKFLOWS.get(settings) + WORKFLOW_REQUEST_TIMEOUT.get(settings) ); // Bypass checkMaxWorkflows and force onResponse @@ -396,8 +397,7 @@ public void testCreateWorkflow_withValidation_withProvision_Success() throws Exc validTemplate, new String[] { "all" }, true, - WORKFLOW_REQUEST_TIMEOUT.get(settings), - MAX_WORKFLOWS.get(settings) + WORKFLOW_REQUEST_TIMEOUT.get(settings) ); // Bypass checkMaxWorkflows and force onResponse @@ -456,8 +456,7 @@ public void testCreateWorkflow_withValidation_withProvision_FailedProvisioning() validTemplate, new String[] { "all" }, true, - WORKFLOW_REQUEST_TIMEOUT.get(settings), - MAX_WORKFLOWS.get(settings) + WORKFLOW_REQUEST_TIMEOUT.get(settings) ); // Bypass checkMaxWorkflows and force onResponse