From a812e51b505a5453bac58e36101798f96c14828d Mon Sep 17 00:00:00 2001 From: Daniel Widdis Date: Mon, 29 Jan 2024 19:17:20 -0800 Subject: [PATCH] Add more thread pools (#465) * Add more thread pools Signed-off-by: Daniel Widdis * Increase maxFailures to 4 Signed-off-by: Daniel Widdis * Wait before starting tests Signed-off-by: Daniel Widdis * Improve ProcessNode timeout Signed-off-by: Daniel Widdis * Increase minimum thread pool requirement Signed-off-by: Daniel Widdis --------- Signed-off-by: Daniel Widdis --- build.gradle | 2 +- .../flowframework/FlowFrameworkPlugin.java | 20 +++++- .../flowframework/common/CommonValue.java | 6 +- .../DeprovisionWorkflowTransportAction.java | 6 +- .../ProvisionWorkflowTransportAction.java | 4 +- .../flowframework/workflow/ProcessNode.java | 25 +++----- .../workflow/WorkflowProcessSorter.java | 2 + .../resources/mappings/workflow-steps.json | 3 +- .../FlowFrameworkPluginTests.java | 2 +- .../opensearch/flowframework/TestHelpers.java | 9 +++ .../rest/FlowFrameworkRestApiIT.java | 17 +++++- ...provisionWorkflowTransportActionTests.java | 6 +- .../workflow/DeployModelStepTests.java | 12 +++- .../workflow/ProcessNodeTests.java | 61 ++++++++++++++----- .../RegisterLocalCustomModelStepTests.java | 12 +++- ...RegisterLocalPretrainedModelStepTests.java | 12 +++- ...sterLocalSparseEncodingModelStepTests.java | 12 +++- .../workflow/WorkflowProcessSorterTests.java | 12 ++-- 18 files changed, 162 insertions(+), 61 deletions(-) diff --git a/build.gradle b/build.gradle index fd5f23e94..d3092fd83 100644 --- a/build.gradle +++ b/build.gradle @@ -414,7 +414,7 @@ allprojects { retry { if (System.getenv().containsKey("CI")) { maxRetries = 1 - maxFailures = 3 + maxFailures = 4 failOnPassedAfterRetry = false } } diff --git a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java index ea1327c1d..779438bac 100644 --- a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java +++ b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java @@ -73,7 +73,9 @@ import java.util.List; import java.util.function.Supplier; +import static org.opensearch.flowframework.common.CommonValue.DEPROVISION_WORKFLOW_THREAD_POOL; import static org.opensearch.flowframework.common.CommonValue.FLOW_FRAMEWORK_THREAD_POOL_PREFIX; +import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW_THREAD_POOL; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_THREAD_POOL; import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED; import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOWS; @@ -185,9 +187,23 @@ public List> getExecutorBuilders(Settings settings) { new ScalingExecutorBuilder( WORKFLOW_THREAD_POOL, 1, - OpenSearchExecutors.allocatedProcessors(settings), - TimeValue.timeValueMinutes(5), + Math.max(2, OpenSearchExecutors.allocatedProcessors(settings) - 1), + TimeValue.timeValueMinutes(1), FLOW_FRAMEWORK_THREAD_POOL_PREFIX + WORKFLOW_THREAD_POOL + ), + new ScalingExecutorBuilder( + PROVISION_WORKFLOW_THREAD_POOL, + 1, + Math.max(4, OpenSearchExecutors.allocatedProcessors(settings) - 1), + TimeValue.timeValueMinutes(5), + FLOW_FRAMEWORK_THREAD_POOL_PREFIX + PROVISION_WORKFLOW_THREAD_POOL + ), + new ScalingExecutorBuilder( + DEPROVISION_WORKFLOW_THREAD_POOL, + 1, + Math.max(2, OpenSearchExecutors.allocatedProcessors(settings) - 1), + TimeValue.timeValueMinutes(1), + FLOW_FRAMEWORK_THREAD_POOL_PREFIX + DEPROVISION_WORKFLOW_THREAD_POOL ) ); } diff --git a/src/main/java/org/opensearch/flowframework/common/CommonValue.java b/src/main/java/org/opensearch/flowframework/common/CommonValue.java index ed3eb0395..5be7482e0 100644 --- a/src/main/java/org/opensearch/flowframework/common/CommonValue.java +++ b/src/main/java/org/opensearch/flowframework/common/CommonValue.java @@ -70,8 +70,12 @@ private CommonValue() {} */ /** Flow Framework plugin thread pool name prefix */ public static final String FLOW_FRAMEWORK_THREAD_POOL_PREFIX = "thread_pool.flow_framework."; - /** The provision workflow thread pool name */ + /** The general workflow thread pool name for most calls */ public static final String WORKFLOW_THREAD_POOL = "opensearch_workflow"; + /** The workflow thread pool name for provisioning */ + public static final String PROVISION_WORKFLOW_THREAD_POOL = "opensearch_provision_workflow"; + /** The workflow thread pool name for deprovisioning */ + public static final String DEPROVISION_WORKFLOW_THREAD_POOL = "opensearch_deprovision_workflow"; /* * Field names common to multiple classes diff --git a/src/main/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportAction.java index 7383e0f12..821ec39c8 100644 --- a/src/main/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportAction.java @@ -42,11 +42,11 @@ import java.util.Objects; import java.util.stream.Collectors; +import static org.opensearch.flowframework.common.CommonValue.DEPROVISION_WORKFLOW_THREAD_POOL; import static org.opensearch.flowframework.common.CommonValue.PROVISIONING_PROGRESS_FIELD; import static org.opensearch.flowframework.common.CommonValue.PROVISION_START_TIME_FIELD; import static org.opensearch.flowframework.common.CommonValue.RESOURCES_CREATED_FIELD; import static org.opensearch.flowframework.common.CommonValue.STATE_FIELD; -import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_THREAD_POOL; import static org.opensearch.flowframework.common.WorkflowResources.getDeprovisionStepByWorkflowStep; import static org.opensearch.flowframework.common.WorkflowResources.getResourceByWorkflowStep; @@ -102,7 +102,7 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener executeDeprovisionSequence(workflowId, response.getWorkflowState().resourcesCreated(), listener)); }, exception -> { String message = "Failed to get workflow state for workflow " + workflowId; @@ -143,6 +143,7 @@ private void executeDeprovisionSequence( new WorkflowData(Map.of(getResourceByWorkflowStep(stepName), resource.resourceId()), workflowId, deprovisionStepId), Collections.emptyList(), this.threadPool, + DEPROVISION_WORKFLOW_THREAD_POOL, flowFrameworkSettings.getRequestTimeout() ) ); @@ -196,6 +197,7 @@ private void executeDeprovisionSequence( pn.input(), pn.predecessors(), this.threadPool, + DEPROVISION_WORKFLOW_THREAD_POOL, pn.nodeTimeout() ); }).collect(Collectors.toList()); diff --git a/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java index 05d714055..2ed203c63 100644 --- a/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java @@ -49,9 +49,9 @@ import static org.opensearch.flowframework.common.CommonValue.PROVISION_END_TIME_FIELD; import static org.opensearch.flowframework.common.CommonValue.PROVISION_START_TIME_FIELD; import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW; +import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW_THREAD_POOL; import static org.opensearch.flowframework.common.CommonValue.RESOURCES_CREATED_FIELD; import static org.opensearch.flowframework.common.CommonValue.STATE_FIELD; -import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_THREAD_POOL; /** * Transport Action to provision a workflow from a stored use case template @@ -180,7 +180,7 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener workflowSequence, ActionListener listener) { try { - threadPool.executor(WORKFLOW_THREAD_POOL).execute(() -> { executeWorkflow(workflowSequence, workflowId); }); + threadPool.executor(PROVISION_WORKFLOW_THREAD_POOL).execute(() -> { executeWorkflow(workflowSequence, workflowId); }); } catch (Exception exception) { listener.onFailure(new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception))); } diff --git a/src/main/java/org/opensearch/flowframework/workflow/ProcessNode.java b/src/main/java/org/opensearch/flowframework/workflow/ProcessNode.java index 82aa2c2d1..2349b9fb7 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/ProcessNode.java +++ b/src/main/java/org/opensearch/flowframework/workflow/ProcessNode.java @@ -12,16 +12,12 @@ import org.apache.logging.log4j.Logger; import org.opensearch.action.support.PlainActionFuture; import org.opensearch.common.unit.TimeValue; -import org.opensearch.threadpool.Scheduler.ScheduledCancellable; import org.opensearch.threadpool.ThreadPool; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.TimeoutException; - -import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_THREAD_POOL; /** * Representation of a process node in a workflow graph. @@ -37,6 +33,7 @@ public class ProcessNode { private final WorkflowData input; private final List predecessors; private final ThreadPool threadPool; + private final String threadPoolName; private final TimeValue nodeTimeout; private final PlainActionFuture future = PlainActionFuture.newFuture(); @@ -50,6 +47,7 @@ public class ProcessNode { * @param input Input required by the node encoded in a {@link WorkflowData} instance. * @param predecessors Nodes preceding this one in the workflow * @param threadPool The OpenSearch thread pool + * @param threadPoolName The thread pool to use * @param nodeTimeout The timeout value for executing on this node */ public ProcessNode( @@ -59,6 +57,7 @@ public ProcessNode( WorkflowData input, List predecessors, ThreadPool threadPool, + String threadPoolName, TimeValue nodeTimeout ) { this.id = id; @@ -67,6 +66,7 @@ public ProcessNode( this.input = input; this.predecessors = predecessors; this.threadPool = threadPool; + this.threadPoolName = threadPoolName; this.nodeTimeout = nodeTimeout; } @@ -152,17 +152,9 @@ public PlainActionFuture execute() { WorkflowData wd = node.future().actionGet(); inputMap.put(wd.getNodeId(), wd); } - logger.info("Starting {}.", this.id); - ScheduledCancellable delayExec = null; - if (this.nodeTimeout.compareTo(TimeValue.ZERO) > 0) { - delayExec = threadPool.schedule(() -> { - if (!future.isDone()) { - future.onFailure(new TimeoutException("Execute timed out for " + this.id)); - } - }, this.nodeTimeout, ThreadPool.Names.SAME); - } // record start time for this step. + logger.info("Starting {}.", this.id); PlainActionFuture stepFuture = this.workflowStep.execute( this.id, this.input, @@ -170,16 +162,13 @@ public PlainActionFuture execute() { this.previousNodeInputs ); // If completed exceptionally, this is a no-op - future.onResponse(stepFuture.get()); + future.onResponse(stepFuture.actionGet(this.nodeTimeout)); // record end time passing workflow steps - if (delayExec != null) { - delayExec.cancel(); - } logger.info("Finished {}.", this.id); } catch (Exception e) { this.future.onFailure(e); } - }, threadPool.executor(WORKFLOW_THREAD_POOL)); + }, threadPool.executor(this.threadPoolName)); return this.future; } diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java index 437615c45..4b4078098 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java @@ -38,6 +38,7 @@ import java.util.stream.Collectors; import java.util.stream.Stream; +import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW_THREAD_POOL; import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOW_STEPS; import static org.opensearch.flowframework.model.WorkflowNode.NODE_TIMEOUT_DEFAULT_VALUE; import static org.opensearch.flowframework.model.WorkflowNode.NODE_TIMEOUT_FIELD; @@ -122,6 +123,7 @@ public List sortProcessNodes(Workflow workflow, String workflowId) data, predecessorNodes, threadPool, + PROVISION_WORKFLOW_THREAD_POOL, nodeTimeout ); idToNodeMap.put(processNode.id(), processNode); diff --git a/src/main/resources/mappings/workflow-steps.json b/src/main/resources/mappings/workflow-steps.json index c46dd3982..f1ff99de6 100644 --- a/src/main/resources/mappings/workflow-steps.json +++ b/src/main/resources/mappings/workflow-steps.json @@ -19,7 +19,8 @@ ], "required_plugins":[ "opensearch-ml" - ] + ], + "timeout": "60s" }, "delete_connector": { "inputs": [ diff --git a/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java b/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java index 51321618e..401ddbe9a 100644 --- a/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java +++ b/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java @@ -84,7 +84,7 @@ public void testPlugin() throws IOException { ); assertEquals(9, ffp.getRestHandlers(settings, null, null, null, null, null, null).size()); assertEquals(9, ffp.getActions().size()); - assertEquals(1, ffp.getExecutorBuilders(settings).size()); + assertEquals(3, ffp.getExecutorBuilders(settings).size()); assertEquals(5, ffp.getSettings().size()); } } diff --git a/src/test/java/org/opensearch/flowframework/TestHelpers.java b/src/test/java/org/opensearch/flowframework/TestHelpers.java index 4a0a055c3..6c4f3534b 100644 --- a/src/test/java/org/opensearch/flowframework/TestHelpers.java +++ b/src/test/java/org/opensearch/flowframework/TestHelpers.java @@ -113,6 +113,15 @@ public static Response makeRequest( if (entity != null) { request.setEntity(entity); } + try { + return client.performRequest(request); + } catch (IOException e) { + // In restricted resource cluster, initialization of REST clients on other nodes takes time + // Wait 10 seconds and try again + try { + Thread.sleep(10000); + } catch (InterruptedException ie) {} + } return client.performRequest(request); } diff --git a/src/test/java/org/opensearch/flowframework/rest/FlowFrameworkRestApiIT.java b/src/test/java/org/opensearch/flowframework/rest/FlowFrameworkRestApiIT.java index 408e8f811..1928099bf 100644 --- a/src/test/java/org/opensearch/flowframework/rest/FlowFrameworkRestApiIT.java +++ b/src/test/java/org/opensearch/flowframework/rest/FlowFrameworkRestApiIT.java @@ -22,6 +22,7 @@ import org.opensearch.flowframework.model.WorkflowEdge; import org.opensearch.flowframework.model.WorkflowNode; import org.opensearch.flowframework.model.WorkflowState; +import org.junit.Before; import org.junit.ComparisonFailure; import java.util.Collections; @@ -30,7 +31,9 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.stream.Collectors; import static org.opensearch.flowframework.common.CommonValue.CREDENTIAL_FIELD; @@ -39,6 +42,18 @@ public class FlowFrameworkRestApiIT extends FlowFrameworkRestTestCase { + private static AtomicBoolean waitToStart = new AtomicBoolean(true); + + @Before + public void waitToStart() throws Exception { + // ML Commons cron job runs every 10 seconds and takes 20+ seconds to initialize .plugins-ml-config index + // Delay on the first attempt for 25 seconds to allow this initialization and prevent flaky tests + if (waitToStart.getAndSet(false)) { + CountDownLatch latch = new CountDownLatch(1); + latch.await(25, TimeUnit.SECONDS); + } + } + public void testSearchWorkflows() throws Exception { // Create a Workflow that has a credential 12345 @@ -257,7 +272,7 @@ public void testCreateAndProvisionAgentFrameworkWorkflow() throws Exception { // wait and ensure state is completed/done assertBusy( () -> { getAndAssertWorkflowStatus(client(), workflowId, State.COMPLETED, ProvisioningProgress.DONE); }, - 30, + 120, TimeUnit.SECONDS ); diff --git a/src/test/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportActionTests.java index 3582b7974..b14dd2bb1 100644 --- a/src/test/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportActionTests.java @@ -38,8 +38,8 @@ import org.mockito.ArgumentCaptor; +import static org.opensearch.flowframework.common.CommonValue.DEPROVISION_WORKFLOW_THREAD_POOL; import static org.opensearch.flowframework.common.CommonValue.FLOW_FRAMEWORK_THREAD_POOL_PREFIX; -import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_THREAD_POOL; import static org.opensearch.flowframework.common.WorkflowResources.CONNECTOR_ID; import static org.opensearch.flowframework.common.WorkflowResources.MODEL_ID; import static org.mockito.ArgumentMatchers.any; @@ -57,11 +57,11 @@ public class DeprovisionWorkflowTransportActionTests extends OpenSearchTestCase private static ThreadPool threadPool = new TestThreadPool( DeprovisionWorkflowTransportActionTests.class.getName(), new ScalingExecutorBuilder( - WORKFLOW_THREAD_POOL, + DEPROVISION_WORKFLOW_THREAD_POOL, 1, OpenSearchExecutors.allocatedProcessors(Settings.EMPTY), TimeValue.timeValueMinutes(5), - FLOW_FRAMEWORK_THREAD_POOL_PREFIX + WORKFLOW_THREAD_POOL + FLOW_FRAMEWORK_THREAD_POOL_PREFIX + DEPROVISION_WORKFLOW_THREAD_POOL ) ); private Client client; diff --git a/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java index 8f9192639..ca955accc 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java @@ -44,6 +44,7 @@ import static org.opensearch.action.DocWriteResponse.Result.UPDATED; import static org.opensearch.flowframework.common.CommonValue.FLOW_FRAMEWORK_THREAD_POOL_PREFIX; +import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW_THREAD_POOL; import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_THREAD_POOL; @@ -85,9 +86,16 @@ public void setUp() throws Exception { new ScalingExecutorBuilder( WORKFLOW_THREAD_POOL, 1, - OpenSearchExecutors.allocatedProcessors(Settings.EMPTY), - TimeValue.timeValueMinutes(5), + Math.max(1, OpenSearchExecutors.allocatedProcessors(Settings.EMPTY) - 1), + TimeValue.timeValueMinutes(1), FLOW_FRAMEWORK_THREAD_POOL_PREFIX + WORKFLOW_THREAD_POOL + ), + new ScalingExecutorBuilder( + PROVISION_WORKFLOW_THREAD_POOL, + 1, + Math.max(1, OpenSearchExecutors.allocatedProcessors(Settings.EMPTY) - 1), + TimeValue.timeValueMinutes(5), + FLOW_FRAMEWORK_THREAD_POOL_PREFIX + PROVISION_WORKFLOW_THREAD_POOL ) ); this.deployModel = new DeployModelStep( diff --git a/src/test/java/org/opensearch/flowframework/workflow/ProcessNodeTests.java b/src/test/java/org/opensearch/flowframework/workflow/ProcessNodeTests.java index be822d26d..49aa4aaed 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/ProcessNodeTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/ProcessNodeTests.java @@ -8,11 +8,11 @@ */ package org.opensearch.flowframework.workflow; +import org.opensearch.OpenSearchTimeoutException; import org.opensearch.action.support.PlainActionFuture; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.concurrent.OpenSearchExecutors; -import org.opensearch.common.util.concurrent.UncategorizedExecutionException; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ScalingExecutorBuilder; import org.opensearch.threadpool.TestThreadPool; @@ -25,9 +25,10 @@ import java.util.Map; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; import static org.opensearch.flowframework.common.CommonValue.FLOW_FRAMEWORK_THREAD_POOL_PREFIX; -import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_THREAD_POOL; +import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW_THREAD_POOL; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -42,11 +43,11 @@ public static void setup() { testThreadPool = new TestThreadPool( ProcessNodeTests.class.getName(), new ScalingExecutorBuilder( - WORKFLOW_THREAD_POOL, + PROVISION_WORKFLOW_THREAD_POOL, 1, - OpenSearchExecutors.allocatedProcessors(Settings.EMPTY), + Math.max(1, OpenSearchExecutors.allocatedProcessors(Settings.EMPTY) - 1), TimeValue.timeValueMinutes(5), - FLOW_FRAMEWORK_THREAD_POOL_PREFIX + WORKFLOW_THREAD_POOL + FLOW_FRAMEWORK_THREAD_POOL_PREFIX + PROVISION_WORKFLOW_THREAD_POOL ) ); @@ -89,6 +90,7 @@ public String getName() { new WorkflowData(Map.of("test", "input"), Map.of("foo", "bar"), "test-id", "test-node-id"), List.of(successfulNode), testThreadPool, + PROVISION_WORKFLOW_THREAD_POOL, TimeValue.timeValueMillis(50) ); assertEquals("A", nodeA.id()); @@ -103,7 +105,7 @@ public String getName() { PlainActionFuture f = nodeA.execute(); assertEquals(f, nodeA.future()); - assertEquals("output", f.get().getContent().get("test")); + assertEquals("output", f.actionGet(1, TimeUnit.MINUTES).getContent().get("test")); } public void testNodeNoTimeout() throws InterruptedException, ExecutionException { @@ -117,7 +119,11 @@ public PlainActionFuture execute( Map previousNodeInputs ) { PlainActionFuture future = PlainActionFuture.newFuture(); - testThreadPool.schedule(() -> future.onResponse(WorkflowData.EMPTY), TimeValue.timeValueMillis(100), WORKFLOW_THREAD_POOL); + testThreadPool.schedule( + () -> future.onResponse(WorkflowData.EMPTY), + TimeValue.timeValueMillis(100), + PROVISION_WORKFLOW_THREAD_POOL + ); return future; } @@ -125,7 +131,14 @@ public PlainActionFuture execute( public String getName() { return "test"; } - }, Collections.emptyMap(), WorkflowData.EMPTY, Collections.emptyList(), testThreadPool, TimeValue.timeValueMillis(250)); + }, + Collections.emptyMap(), + WorkflowData.EMPTY, + Collections.emptyList(), + testThreadPool, + PROVISION_WORKFLOW_THREAD_POOL, + TimeValue.timeValueMillis(500) + ); assertEquals("B", nodeB.id()); assertEquals("test", nodeB.workflowStep().getName()); assertEquals(WorkflowData.EMPTY, nodeB.input()); @@ -134,7 +147,7 @@ public String getName() { PlainActionFuture f = nodeB.execute(); assertEquals(f, nodeB.future()); - assertEquals(WorkflowData.EMPTY, f.get()); + assertEquals(WorkflowData.EMPTY, f.actionGet(1, TimeUnit.MINUTES)); } public void testNodeTimeout() throws InterruptedException, ExecutionException { @@ -148,7 +161,11 @@ public PlainActionFuture execute( Map previousNodeInputs ) { PlainActionFuture future = PlainActionFuture.newFuture(); - testThreadPool.schedule(() -> future.onResponse(WorkflowData.EMPTY), TimeValue.timeValueMinutes(1), WORKFLOW_THREAD_POOL); + testThreadPool.schedule( + () -> future.onResponse(WorkflowData.EMPTY), + TimeValue.timeValueMinutes(1), + PROVISION_WORKFLOW_THREAD_POOL + ); return future; } @@ -156,7 +173,14 @@ public PlainActionFuture execute( public String getName() { return "sleepy"; } - }, Collections.emptyMap(), WorkflowData.EMPTY, Collections.emptyList(), testThreadPool, TimeValue.timeValueMillis(100)); + }, + Collections.emptyMap(), + WorkflowData.EMPTY, + Collections.emptyList(), + testThreadPool, + PROVISION_WORKFLOW_THREAD_POOL, + TimeValue.timeValueMillis(100) + ); assertEquals("Zzz", nodeZ.id()); assertEquals("sleepy", nodeZ.workflowStep().getName()); assertEquals(WorkflowData.EMPTY, nodeZ.input()); @@ -164,9 +188,9 @@ public String getName() { assertEquals("Zzz", nodeZ.toString()); PlainActionFuture f = nodeZ.execute(); - UncategorizedExecutionException exception = assertThrows(UncategorizedExecutionException.class, () -> f.actionGet()); + OpenSearchTimeoutException exception = assertThrows(OpenSearchTimeoutException.class, () -> f.actionGet()); assertTrue(f.isDone()); - assertEquals(ExecutionException.class, exception.getCause().getClass()); + assertEquals(TimeoutException.class, exception.getCause().getClass()); } public void testExceptions() { @@ -188,7 +212,14 @@ public PlainActionFuture execute( public String getName() { return "test"; } - }, Collections.emptyMap(), WorkflowData.EMPTY, List.of(successfulNode, failedNode), testThreadPool, TimeValue.timeValueSeconds(15)); + }, + Collections.emptyMap(), + WorkflowData.EMPTY, + List.of(successfulNode, failedNode), + testThreadPool, + PROVISION_WORKFLOW_THREAD_POOL, + TimeValue.timeValueSeconds(15) + ); assertEquals("E", nodeE.id()); assertEquals("test", nodeE.workflowStep().getName()); assertEquals(WorkflowData.EMPTY, nodeE.input()); @@ -196,7 +227,7 @@ public String getName() { assertEquals("E", nodeE.toString()); PlainActionFuture f = nodeE.execute(); - RuntimeException exception = assertThrows(RuntimeException.class, () -> f.actionGet()); + assertThrows(RuntimeException.class, () -> f.actionGet()); assertTrue(f.isDone()); // Tests where we already called execute assertThrows(IllegalStateException.class, () -> nodeE.execute()); diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalCustomModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalCustomModelStepTests.java index 4e2d74865..e891d97ba 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalCustomModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalCustomModelStepTests.java @@ -41,6 +41,7 @@ import static org.opensearch.action.DocWriteResponse.Result.UPDATED; import static org.opensearch.flowframework.common.CommonValue.FLOW_FRAMEWORK_THREAD_POOL_PREFIX; +import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW_THREAD_POOL; import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_THREAD_POOL; @@ -81,9 +82,16 @@ public void setUp() throws Exception { new ScalingExecutorBuilder( WORKFLOW_THREAD_POOL, 1, - OpenSearchExecutors.allocatedProcessors(Settings.EMPTY), - TimeValue.timeValueMinutes(5), + Math.max(1, OpenSearchExecutors.allocatedProcessors(Settings.EMPTY) - 1), + TimeValue.timeValueMinutes(1), FLOW_FRAMEWORK_THREAD_POOL_PREFIX + WORKFLOW_THREAD_POOL + ), + new ScalingExecutorBuilder( + PROVISION_WORKFLOW_THREAD_POOL, + 1, + Math.max(1, OpenSearchExecutors.allocatedProcessors(Settings.EMPTY) - 1), + TimeValue.timeValueMinutes(5), + FLOW_FRAMEWORK_THREAD_POOL_PREFIX + PROVISION_WORKFLOW_THREAD_POOL ) ); this.registerLocalModelStep = new RegisterLocalCustomModelStep( diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalPretrainedModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalPretrainedModelStepTests.java index a46c292e9..8eb9d7798 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalPretrainedModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalPretrainedModelStepTests.java @@ -41,6 +41,7 @@ import static org.opensearch.action.DocWriteResponse.Result.UPDATED; import static org.opensearch.flowframework.common.CommonValue.FLOW_FRAMEWORK_THREAD_POOL_PREFIX; +import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW_THREAD_POOL; import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_THREAD_POOL; @@ -81,9 +82,16 @@ public void setUp() throws Exception { new ScalingExecutorBuilder( WORKFLOW_THREAD_POOL, 1, - OpenSearchExecutors.allocatedProcessors(Settings.EMPTY), - TimeValue.timeValueMinutes(5), + Math.max(1, OpenSearchExecutors.allocatedProcessors(Settings.EMPTY) - 1), + TimeValue.timeValueMinutes(1), FLOW_FRAMEWORK_THREAD_POOL_PREFIX + WORKFLOW_THREAD_POOL + ), + new ScalingExecutorBuilder( + PROVISION_WORKFLOW_THREAD_POOL, + 1, + Math.max(1, OpenSearchExecutors.allocatedProcessors(Settings.EMPTY) - 1), + TimeValue.timeValueMinutes(5), + FLOW_FRAMEWORK_THREAD_POOL_PREFIX + PROVISION_WORKFLOW_THREAD_POOL ) ); this.registerLocalPretrainedModelStep = new RegisterLocalPretrainedModelStep( diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalSparseEncodingModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalSparseEncodingModelStepTests.java index b548e09a9..6ca63b9de 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalSparseEncodingModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalSparseEncodingModelStepTests.java @@ -41,6 +41,7 @@ import static org.opensearch.action.DocWriteResponse.Result.UPDATED; import static org.opensearch.flowframework.common.CommonValue.FLOW_FRAMEWORK_THREAD_POOL_PREFIX; +import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW_THREAD_POOL; import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_THREAD_POOL; @@ -81,9 +82,16 @@ public void setUp() throws Exception { new ScalingExecutorBuilder( WORKFLOW_THREAD_POOL, 1, - OpenSearchExecutors.allocatedProcessors(Settings.EMPTY), - TimeValue.timeValueMinutes(5), + Math.max(1, OpenSearchExecutors.allocatedProcessors(Settings.EMPTY) - 1), + TimeValue.timeValueMinutes(1), FLOW_FRAMEWORK_THREAD_POOL_PREFIX + WORKFLOW_THREAD_POOL + ), + new ScalingExecutorBuilder( + PROVISION_WORKFLOW_THREAD_POOL, + 1, + Math.max(1, OpenSearchExecutors.allocatedProcessors(Settings.EMPTY) - 1), + TimeValue.timeValueMinutes(5), + FLOW_FRAMEWORK_THREAD_POOL_PREFIX + PROVISION_WORKFLOW_THREAD_POOL ) ); this.registerLocalSparseEncodingModelStep = new RegisterLocalSparseEncodingModelStep( diff --git a/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java b/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java index ef9dd5dcf..8825b0815 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java @@ -43,8 +43,8 @@ import java.util.stream.Collectors; import java.util.stream.Stream; +import static org.opensearch.flowframework.common.CommonValue.DEPROVISION_WORKFLOW_THREAD_POOL; import static org.opensearch.flowframework.common.CommonValue.FLOW_FRAMEWORK_THREAD_POOL_PREFIX; -import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_THREAD_POOL; import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED; import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOWS; import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOW_STEPS; @@ -112,11 +112,11 @@ public static void setup() throws IOException { testThreadPool = new TestThreadPool( WorkflowProcessSorterTests.class.getName(), new ScalingExecutorBuilder( - WORKFLOW_THREAD_POOL, + DEPROVISION_WORKFLOW_THREAD_POOL, 1, - OpenSearchExecutors.allocatedProcessors(Settings.EMPTY), + Math.max(1, OpenSearchExecutors.allocatedProcessors(Settings.EMPTY) - 1), TimeValue.timeValueMinutes(5), - FLOW_FRAMEWORK_THREAD_POOL_PREFIX + WORKFLOW_THREAD_POOL + FLOW_FRAMEWORK_THREAD_POOL_PREFIX + DEPROVISION_WORKFLOW_THREAD_POOL ) ); WorkflowStepFactory factory = new WorkflowStepFactory( @@ -141,7 +141,7 @@ public void testNodeDetails() throws IOException { workflow = parseToNodes( workflow( List.of( - nodeWithType("default_timeout", "create_connector"), + nodeWithType("default_timeout", "noop"), nodeWithTypeAndTimeout("custom_timeout", "register_local_custom_model", "100ms") ), Collections.emptyList() @@ -149,7 +149,7 @@ public void testNodeDetails() throws IOException { ); ProcessNode node = workflow.get(0); assertEquals("default_timeout", node.id()); - assertEquals(CreateConnectorStep.class, node.workflowStep().getClass()); + assertEquals(NoOpStep.class, node.workflowStep().getClass()); assertEquals(10, node.nodeTimeout().seconds()); node = workflow.get(1); assertEquals("custom_timeout", node.id());