diff --git a/.codecov.yml b/.codecov.yml
index e5bbd7262..827160da7 100644
--- a/.codecov.yml
+++ b/.codecov.yml
@@ -1,5 +1,5 @@
 codecov:
-  require_ci_to_pass: yes
+  require_ci_to_pass: true
 
 # ignore files in demo package
 ignore:
@@ -12,5 +12,8 @@ coverage:
   status:
     project:
       default:
-        target: 70% # the required coverage value
-        threshold: 1% # the leniency in hitting the target
+        target: auto
+        threshold: 2% # project coverage can drop
+    patch:
+      default:
+        target: 70% # required diff coverage value
diff --git a/src/main/java/demo/Demo.java b/src/main/java/demo/Demo.java
index 53cf3499c..12bd6925d 100644
--- a/src/main/java/demo/Demo.java
+++ b/src/main/java/demo/Demo.java
@@ -14,10 +14,12 @@
 import org.opensearch.client.node.NodeClient;
 import org.opensearch.common.SuppressForbidden;
 import org.opensearch.common.io.PathUtils;
+import org.opensearch.common.settings.Settings;
 import org.opensearch.flowframework.model.Template;
 import org.opensearch.flowframework.workflow.ProcessNode;
 import org.opensearch.flowframework.workflow.WorkflowProcessSorter;
 import org.opensearch.flowframework.workflow.WorkflowStepFactory;
+import org.opensearch.threadpool.ThreadPool;
 
 import java.io.IOException;
 import java.nio.charset.StandardCharsets;
@@ -26,8 +28,7 @@
 import java.util.List;
 import java.util.Locale;
 import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.ExecutorService;
-import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
 import java.util.stream.Collectors;
 
 /**
@@ -37,6 +38,8 @@ public class Demo {
 
     private static final Logger logger = LogManager.getLogger(Demo.class);
 
+    private Demo() {}
+
     /**
      * Demonstrate parsing a JSON graph.
      *
@@ -54,13 +57,14 @@ public static void main(String[] args) throws IOException {
             return;
         }
         Client client = new NodeClient(null, null);
-        WorkflowStepFactory factory = WorkflowStepFactory.create(client);
-        ExecutorService executor = Executors.newFixedThreadPool(10);
-        WorkflowProcessSorter.create(factory, executor);
+        WorkflowStepFactory factory = new WorkflowStepFactory(client);
+
+        ThreadPool threadPool = new ThreadPool(Settings.EMPTY);
+        WorkflowProcessSorter workflowProcessSorter = new WorkflowProcessSorter(factory, threadPool);
 
         logger.info("Parsing graph to sequence...");
         Template t = Template.parse(json);
-        List<ProcessNode> processSequence = WorkflowProcessSorter.get().sortProcessNodes(t.workflows().get("demo"));
+        List<ProcessNode> processSequence = workflowProcessSorter.sortProcessNodes(t.workflows().get("demo"));
         List<CompletableFuture<WorkflowData>> futureList = new ArrayList<>();
 
         for (ProcessNode n : processSequence) {
@@ -80,6 +84,6 @@ public static void main(String[] args) throws IOException {
         }
         futureList.forEach(CompletableFuture::join);
         logger.info("All done!");
-        executor.shutdown();
+        ThreadPool.terminate(threadPool, 500, TimeUnit.MILLISECONDS);
     }
 }
diff --git a/src/main/java/demo/TemplateParseDemo.java b/src/main/java/demo/TemplateParseDemo.java
index 307d707c0..dbe338217 100644
--- a/src/main/java/demo/TemplateParseDemo.java
+++ b/src/main/java/demo/TemplateParseDemo.java
@@ -14,16 +14,18 @@
 import org.opensearch.client.node.NodeClient;
 import org.opensearch.common.SuppressForbidden;
 import org.opensearch.common.io.PathUtils;
+import org.opensearch.common.settings.Settings;
 import org.opensearch.flowframework.model.Template;
 import org.opensearch.flowframework.model.Workflow;
 import org.opensearch.flowframework.workflow.WorkflowProcessSorter;
 import org.opensearch.flowframework.workflow.WorkflowStepFactory;
+import org.opensearch.threadpool.ThreadPool;
 
 import java.io.IOException;
 import java.nio.charset.StandardCharsets;
 import java.nio.file.Files;
 import java.util.Map.Entry;
-import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
 
 /**
  * Demo class exercising {@link WorkflowProcessSorter}. This will be moved to a unit test.
@@ -32,6 +34,8 @@ public class TemplateParseDemo {
 
     private static final Logger logger = LogManager.getLogger(TemplateParseDemo.class);
 
+    private TemplateParseDemo() {}
+
     /**
     * Demonstrate parsing a JSON graph.
     *
@@ -49,8 +53,9 @@ public static void main(String[] args) throws IOException {
             return;
         }
         Client client = new NodeClient(null, null);
-        WorkflowStepFactory factory = WorkflowStepFactory.create(client);
-        WorkflowProcessSorter.create(factory, Executors.newFixedThreadPool(10));
+        WorkflowStepFactory factory = new WorkflowStepFactory(client);
+        ThreadPool threadPool = new ThreadPool(Settings.EMPTY);
+        WorkflowProcessSorter workflowProcessSorter = new WorkflowProcessSorter(factory, threadPool);
 
         Template t = Template.parse(json);
 
@@ -59,7 +64,8 @@ public static void main(String[] args) throws IOException {
 
         for (Entry<String, Workflow> e : t.workflows().entrySet()) {
             logger.info("Parsing {} workflow.", e.getKey());
-            WorkflowProcessSorter.get().sortProcessNodes(e.getValue());
+            workflowProcessSorter.sortProcessNodes(e.getValue());
         }
+        ThreadPool.terminate(threadPool, 500, TimeUnit.MILLISECONDS);
     }
 }
diff --git a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java
index d701c832e..853c138db 100644
--- a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java
+++ b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java
@@ -32,6 +32,11 @@
  */
 public class FlowFrameworkPlugin extends Plugin {
 
+    /**
+     * Instantiate this plugin.
+     */
+    public FlowFrameworkPlugin() {}
+
     @Override
     public Collection<Object> createComponents(
         Client client,
@@ -46,8 +51,8 @@ public Collection<Object> createComponents(
         IndexNameExpressionResolver indexNameExpressionResolver,
         Supplier<RepositoriesService> repositoriesServiceSupplier
     ) {
-        WorkflowStepFactory workflowStepFactory = WorkflowStepFactory.create(client);
-        WorkflowProcessSorter workflowProcessSorter = WorkflowProcessSorter.create(workflowStepFactory, threadPool.generic());
+        WorkflowStepFactory workflowStepFactory = new WorkflowStepFactory(client);
+        WorkflowProcessSorter workflowProcessSorter = new WorkflowProcessSorter(workflowStepFactory, threadPool);
 
         return ImmutableList.of(workflowStepFactory, workflowProcessSorter);
     }
diff --git a/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java b/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java
index b48b6e0d2..8c4a6ae52 100644
--- a/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java
+++ b/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java
@@ -41,6 +41,10 @@ public class WorkflowNode implements ToXContentObject {
     public static final String INPUTS_FIELD = "inputs";
     /** The field defining processors in the inputs for search and ingest pipelines */
     public static final String PROCESSORS_FIELD = "processors";
+    /** The field defining the timeout value for this node */
+    public static final String NODE_TIMEOUT_FIELD = "node_timeout";
+    /** The default timeout value if the template doesn't override it */
+    public static final String NODE_TIMEOUT_DEFAULT_VALUE = "10s";
 
     private final String id; // unique id
     private final String type; // maps to a WorkflowStep
diff --git a/src/main/java/org/opensearch/flowframework/workflow/ProcessNode.java b/src/main/java/org/opensearch/flowframework/workflow/ProcessNode.java
index a2d7628c3..2f902755c 100644
--- a/src/main/java/org/opensearch/flowframework/workflow/ProcessNode.java
+++ b/src/main/java/org/opensearch/flowframework/workflow/ProcessNode.java
@@ -10,17 +10,19 @@
 
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
+import org.opensearch.common.unit.TimeValue;
+import org.opensearch.threadpool.Scheduler.ScheduledCancellable;
+import org.opensearch.threadpool.ThreadPool;
 
 import java.util.ArrayList;
 import java.util.List;
 import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.ExecutionException;
-import java.util.concurrent.Executor;
-import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
 import java.util.stream.Collectors;
 
 /**
- * Representation of a process node in a workflow graph. Tracks predecessor nodes which must be completed before it can start execution.
+ * Representation of a process node in a workflow graph.
+ * Tracks predecessor nodes which must be completed before it can start execution.
  */
 public class ProcessNode {
 
@@ -30,7 +32,8 @@ public class ProcessNode {
     private final WorkflowStep workflowStep;
     private final WorkflowData input;
     private final List<ProcessNode> predecessors;
-    private Executor executor;
+    private final ThreadPool threadPool;
+    private final TimeValue nodeTimeout;
 
     private final CompletableFuture<WorkflowData> future = new CompletableFuture<>();
 
@@ -41,14 +44,23 @@ public class ProcessNode {
     * @param workflowStep A java class implementing {@link WorkflowStep} to be executed when it's this node's turn.
     * @param input Input required by the node encoded in a {@link WorkflowData} instance.
     * @param predecessors Nodes preceding this one in the workflow
-     * @param executor The OpenSearch thread pool
+     * @param threadPool The OpenSearch thread pool
+     * @param nodeTimeout The timeout value for executing on this node
     */
-    public ProcessNode(String id, WorkflowStep workflowStep, WorkflowData input, List<ProcessNode> predecessors, Executor executor) {
+    public ProcessNode(
+        String id,
+        WorkflowStep workflowStep,
+        WorkflowData input,
+        List<ProcessNode> predecessors,
+        ThreadPool threadPool,
+        TimeValue nodeTimeout
+    ) {
         this.id = id;
         this.workflowStep = workflowStep;
         this.input = input;
         this.predecessors = predecessors;
-        this.executor = executor;
+        this.threadPool = threadPool;
+        this.nodeTimeout = nodeTimeout;
     }
 
     /**
@@ -90,64 +102,73 @@ public CompletableFuture<WorkflowData> future() {
     * Returns the predecessors of this node in the workflow.
     * The predecessor's {@link #future()} must complete before execution begins on this node.
     *
-     * @return a set of predecessor nodes, if any. At least one node in the graph must have no predecessors and serve as a start node.
+     * @return a set of predecessor nodes, if any.
+     * At least one node in the graph must have no predecessors and serve as a start node.
     */
    public List<ProcessNode> predecessors() {
        return predecessors;
    }
 
    /**
-     * Execute this node in the sequence. Initializes the node's {@link CompletableFuture} and completes it when the process completes.
+     * Returns the timeout value of this node in the workflow. A value of {@link TimeValue#ZERO} means no timeout.
+     * @return The node's timeout value.
+     */
+    public TimeValue nodeTimeout() {
+        return nodeTimeout;
+    }
+
+    /**
+     * Execute this node in the sequence.
+     * Initializes the node's {@link CompletableFuture} and completes it when the process completes.
     *
-     * @return this node's future. This is returned immediately, while process execution continues asynchronously.
+     * @return this node's future.
+     * This is returned immediately, while process execution continues asynchronously.
     */
    public CompletableFuture<WorkflowData> execute() {
-        // TODO this class will be instantiated with the OpenSearch thread pool (or one for tests!)
-        // the generic executor from that pool should be passed to this runAsync call
-        // https://github.com/opensearch-project/opensearch-ai-flow-framework/issues/42
+        if (this.future.isDone()) {
+            throw new IllegalStateException("Process Node [" + this.id + "] already executed.");
+        }
         CompletableFuture.runAsync(() -> {
             List<CompletableFuture<WorkflowData>> predFutures = predecessors.stream().map(p -> p.future()).collect(Collectors.toList());
-            if (!predecessors.isEmpty()) {
-                CompletableFuture<Void> waitForPredecessors = CompletableFuture.allOf(predFutures.toArray(new CompletableFuture[0]));
-                try {
-                    // We need timeouts to be part of the user template or in settings
-                    // https://github.com/opensearch-project/opensearch-ai-flow-framework/issues/45
-                    waitForPredecessors.orTimeout(30, TimeUnit.SECONDS).get();
-                } catch (InterruptedException | ExecutionException e) {
-                    handleException(e);
-                    return;
+            try {
+                if (!predecessors.isEmpty()) {
+                    CompletableFuture<Void> waitForPredecessors = CompletableFuture.allOf(predFutures.toArray(new CompletableFuture[0]));
+                    waitForPredecessors.join();
                 }
-            }
-            logger.info(">>> Starting {}.", this.id);
-            // get the input data from predecessor(s)
-            List<WorkflowData> input = new ArrayList<>();
-            input.add(this.input);
-            for (CompletableFuture<WorkflowData> cf : predFutures) {
-                try {
+
+                logger.info("Starting {}.", this.id);
+                // get the input data from predecessor(s)
+                List<WorkflowData> input = new ArrayList<>();
+                input.add(this.input);
+                for (CompletableFuture<WorkflowData> cf : predFutures) {
                     input.add(cf.get());
-                } catch (InterruptedException | ExecutionException e) {
-                    handleException(e);
-                    return;
                 }
-            }
-            CompletableFuture<WorkflowData> stepFuture = this.workflowStep.execute(input);
-            try {
-                stepFuture.orTimeout(15, TimeUnit.SECONDS).join();
-                logger.info(">>> Finished {}.", this.id);
+
+                ScheduledCancellable delayExec = null;
+                if (this.nodeTimeout.compareTo(TimeValue.ZERO) > 0) {
+                    delayExec = threadPool.schedule(() -> {
+                        if (!future.isDone()) {
+                            future.completeExceptionally(new TimeoutException("Execute timed out for " + this.id));
+                        }
+                    }, this.nodeTimeout, ThreadPool.Names.SAME);
+                }
+                CompletableFuture<WorkflowData> stepFuture = this.workflowStep.execute(input);
+                // If completed exceptionally, this is a no-op
                 future.complete(stepFuture.get());
-            } catch (InterruptedException | ExecutionException e) {
-                handleException(e);
+                if (delayExec != null) {
+                    delayExec.cancel();
+                }
+                logger.info("Finished {}.", this.id);
+            } catch (Throwable e) {
+                // TODO: better handling of getCause
+                this.future.completeExceptionally(e);
             }
-        }, executor);
+            // TODO: improve use of thread pool beyond generic
+            // https://github.com/opensearch-project/opensearch-ai-flow-framework/issues/61
+        }, threadPool.generic());
         return this.future;
     }
 
-    private void handleException(Exception e) {
-        // TODO: better handling of getCause
-        this.future.completeExceptionally(e);
-        logger.debug("<<< Completed Exceptionally {}", this.id, e.getCause());
-    }
-
     @Override
     public String toString() {
         return this.id;
diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java
index 3370f6384..71c44514e 100644
--- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java
+++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java
@@ -10,9 +10,11 @@
 
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
+import org.opensearch.common.unit.TimeValue;
 import org.opensearch.flowframework.model.Workflow;
 import org.opensearch.flowframework.model.WorkflowEdge;
 import org.opensearch.flowframework.model.WorkflowNode;
+import org.opensearch.threadpool.ThreadPool;
 
 import java.util.ArrayDeque;
 import java.util.ArrayList;
@@ -23,52 +25,32 @@
 import java.util.Map;
 import java.util.Queue;
 import java.util.Set;
-import java.util.concurrent.Executor;
 import java.util.function.Function;
 import java.util.stream.Collectors;
 
+import static org.opensearch.flowframework.model.WorkflowNode.INPUTS_FIELD;
+import static org.opensearch.flowframework.model.WorkflowNode.NODE_TIMEOUT_DEFAULT_VALUE;
+import static org.opensearch.flowframework.model.WorkflowNode.NODE_TIMEOUT_FIELD;
+
 /**
- * Utility class converting a workflow of nodes and edges into a topologically sorted list of Process Nodes.
+ * Converts a workflow of nodes and edges into a topologically sorted list of Process Nodes.
  */
 public class WorkflowProcessSorter {
 
     private static final Logger logger = LogManager.getLogger(WorkflowProcessSorter.class);
 
-    private static WorkflowProcessSorter instance = null;
-
     private WorkflowStepFactory workflowStepFactory;
-    private Executor executor;
-
-    /**
-     * Create the singleton instance of this class. Throws an {@link IllegalStateException} if already created.
-     *
-     * @param workflowStepFactory The singleton instance of {@link WorkflowStepFactory}
-     * @param executor A thread executor
-     * @return The created instance
-     */
-    public static synchronized WorkflowProcessSorter create(WorkflowStepFactory workflowStepFactory, Executor executor) {
-        if (instance != null) {
-            throw new IllegalStateException("This class was already created.");
-        }
-        instance = new WorkflowProcessSorter(workflowStepFactory, executor);
-        return instance;
-    }
+    private ThreadPool threadPool;
 
     /**
-     * Gets the singleton instance of this class. Throws an {@link IllegalStateException} if not yet created.
+     * Instantiate this class.
     *
-     * @return The created instance
+     * @param workflowStepFactory The factory which matches template step types to instances.
+     * @param threadPool The OpenSearch Thread pool to pass to process nodes.
     */
-    public static synchronized WorkflowProcessSorter get() {
-        if (instance == null) {
-            throw new IllegalStateException("This factory has not yet been created.");
-        }
-        return instance;
-    }
-
-    private WorkflowProcessSorter(WorkflowStepFactory workflowStepFactory, Executor executor) {
+    public WorkflowProcessSorter(WorkflowStepFactory workflowStepFactory, ThreadPool threadPool) {
         this.workflowStepFactory = workflowStepFactory;
-        this.executor = executor;
+        this.threadPool = threadPool;
     }
 
     /**
@@ -91,7 +73,8 @@ public List<ProcessNode> sortProcessNodes(Workflow workflow) {
                 .map(e -> idToNodeMap.get(e.source()))
                 .collect(Collectors.toList());
 
-            ProcessNode processNode = new ProcessNode(node.id(), step, data, predecessorNodes, executor);
+            TimeValue nodeTimeout = parseTimeout(node);
+            ProcessNode processNode = new ProcessNode(node.id(), step, data, predecessorNodes, threadPool, nodeTimeout);
             idToNodeMap.put(processNode.id(), processNode);
             nodes.add(processNode);
         }
@@ -99,6 +82,18 @@ public List<ProcessNode> sortProcessNodes(Workflow workflow) {
         return nodes;
     }
 
+    private TimeValue parseTimeout(WorkflowNode node) {
+        String timeoutValue = (String) node.inputs().getOrDefault(NODE_TIMEOUT_FIELD, NODE_TIMEOUT_DEFAULT_VALUE);
+        String fieldName = String.join(".", node.id(), INPUTS_FIELD, NODE_TIMEOUT_FIELD);
+        TimeValue timeValue = TimeValue.parseTimeValue(timeoutValue, fieldName);
+        if (timeValue.millis() < 0) {
+            throw new IllegalArgumentException(
+                "Failed to parse timeout value [" + timeoutValue + "] for field [" + fieldName + "]. Must be positive"
+            );
+        }
+        return timeValue;
+    }
+
     private static List<WorkflowNode> topologicalSort(List<WorkflowNode> workflowNodes, List<WorkflowEdge> workflowEdges) {
         // Basic validation
         Set<String> nodeIds = workflowNodes.stream().map(n -> n.id()).collect(Collectors.toSet());
diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStep.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStep.java
index 6cd5f5a28..41e627016 100644
--- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStep.java
+++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStep.java
@@ -24,7 +24,7 @@ public interface WorkflowStep {
     CompletableFuture<WorkflowData> execute(List<WorkflowData> data);
 
     /**
-     *
+     * Gets the name of the workflow step.
     * @return the name of this workflow step.
     */
    String getName();
diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java
index dc0dc29a2..26dab0f42 100644
--- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java
+++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java
@@ -22,37 +22,14 @@
  */
 public class WorkflowStepFactory {
 
-    private static WorkflowStepFactory instance = null;
-
     private final Map<String, WorkflowStep> stepMap = new HashMap<>();
 
     /**
-     * Create the singleton instance of this class. Throws an {@link IllegalStateException} if already created.
+     * Instantiate this class.
     *
     * @param client The OpenSearch client steps can use
-     * @return The created instance
-     */
-    public static synchronized WorkflowStepFactory create(Client client) {
-        if (instance != null) {
-            throw new IllegalStateException("This factory was already created.");
-        }
-        instance = new WorkflowStepFactory(client);
-        return instance;
-    }
-
-    /**
-     * Gets the singleton instance of this class. Throws an {@link IllegalStateException} if not yet created.
-     *
-     * @return The created instance
     */
-    public static synchronized WorkflowStepFactory get() {
-        if (instance == null) {
-            throw new IllegalStateException("This factory has not yet been created.");
-        }
-        return instance;
-    }
-
-    private WorkflowStepFactory(Client client) {
+    public WorkflowStepFactory(Client client) {
         populateMap(client);
     }
 
@@ -62,7 +39,7 @@ private void populateMap(Client client) {
 
         // TODO: These are from the demo class as placeholders, remove when demos are deleted
         stepMap.put("demo_delay_3", new DemoWorkflowStep(3000));
-        stepMap.put("demo_delay_5", new DemoWorkflowStep(3000));
+        stepMap.put("demo_delay_5", new DemoWorkflowStep(5000));
 
         // Use as a default until all the actual implementations are ready
         stepMap.put("placeholder", new WorkflowStep() {
diff --git a/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java b/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java
index 9f7075d19..d211e3928 100644
--- a/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java
+++ b/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java
@@ -8,8 +8,40 @@
  */
 package org.opensearch.flowframework;
 
+import org.opensearch.client.AdminClient;
+import org.opensearch.client.Client;
 import org.opensearch.test.OpenSearchTestCase;
+import org.opensearch.threadpool.TestThreadPool;
+import org.opensearch.threadpool.ThreadPool;
+
+import java.io.IOException;
+import java.util.concurrent.TimeUnit;
+
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
 
 public class FlowFrameworkPluginTests extends OpenSearchTestCase {
-    // Add unit tests for your plugin
+
+    private Client client;
+    private ThreadPool threadPool;
+
+    @Override
+    public void setUp() throws Exception {
+        super.setUp();
+        client = mock(Client.class);
+        when(client.admin()).thenReturn(mock(AdminClient.class));
+        threadPool = new TestThreadPool(FlowFrameworkPluginTests.class.getName());
+    }
+
+    @Override
+    public void tearDown() throws Exception {
+        ThreadPool.terminate(threadPool, 500, TimeUnit.MILLISECONDS);
+        super.tearDown();
+    }
+
+    public void testPlugin() throws IOException {
+        try (FlowFrameworkPlugin ffp = new FlowFrameworkPlugin()) {
+            assertEquals(2, ffp.createComponents(client, null, threadPool, null, null, null, null, null, null, null, null).size());
+        }
+    }
 }
diff --git a/src/test/java/org/opensearch/flowframework/model/TemplateTestJsonUtil.java b/src/test/java/org/opensearch/flowframework/model/TemplateTestJsonUtil.java
index 247521084..b38346b29 100644
--- a/src/test/java/org/opensearch/flowframework/model/TemplateTestJsonUtil.java
+++ b/src/test/java/org/opensearch/flowframework/model/TemplateTestJsonUtil.java
@@ -27,7 +27,29 @@ public class TemplateTestJsonUtil {
 
     public static String node(String id) {
-        return "{\"" + WorkflowNode.ID_FIELD + "\": \"" + id + "\", \"" + WorkflowNode.TYPE_FIELD + "\": \"" + "placeholder" + "\"}";
+        return nodeWithType(id, "placeholder");
+    }
+
+    public static String nodeWithType(String id, String type) {
+        return "{\"" + WorkflowNode.ID_FIELD + "\": \"" + id + "\", \"" + WorkflowNode.TYPE_FIELD + "\": \"" + type + "\"}";
+    }
+
+    public static String nodeWithTypeAndTimeout(String id, String type, String timeout) {
+        return "{\""
+            + WorkflowNode.ID_FIELD
+            + "\": \""
+            + id
+            + "\", \""
+            + WorkflowNode.TYPE_FIELD
+            + "\": \""
+            + type
+            + "\", \""
+            + WorkflowNode.INPUTS_FIELD
+            + "\": {\""
+            + WorkflowNode.NODE_TIMEOUT_FIELD
+            + "\": \""
+            + timeout
+            + "\"}}";
     }
 
     public static String edge(String sourceId, String destId) {
diff --git a/src/test/java/org/opensearch/flowframework/workflow/ProcessNodeTests.java b/src/test/java/org/opensearch/flowframework/workflow/ProcessNodeTests.java
index 1972d20eb..1e421c58c 100644
--- a/src/test/java/org/opensearch/flowframework/workflow/ProcessNodeTests.java
+++ b/src/test/java/org/opensearch/flowframework/workflow/ProcessNodeTests.java
@@ -8,37 +8,57 @@
  */
 package org.opensearch.flowframework.workflow;
 
+import org.opensearch.common.unit.TimeValue;
 import org.opensearch.test.OpenSearchTestCase;
-import org.junit.After;
-import org.junit.Before;
+import org.opensearch.threadpool.TestThreadPool;
+import org.opensearch.threadpool.ThreadPool;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
 
 import java.util.Collections;
 import java.util.List;
+import java.util.Map;
 import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.CompletionException;
 import java.util.concurrent.ExecutionException;
-import java.util.concurrent.ExecutorService;
-import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
 
 public class ProcessNodeTests extends OpenSearchTestCase {
 
-    private ExecutorService executor;
+    private static TestThreadPool testThreadPool;
+    private static ProcessNode successfulNode;
+    private static ProcessNode failedNode;
+
+    @BeforeClass
+    public static void setup() {
+        testThreadPool = new TestThreadPool(ProcessNodeTests.class.getName());
 
-    @Before
-    public void setup() {
-        executor = Executors.newFixedThreadPool(10);
+        CompletableFuture<WorkflowData> successfulFuture = new CompletableFuture<>();
+        successfulFuture.complete(WorkflowData.EMPTY);
+        CompletableFuture<WorkflowData> failedFuture = new CompletableFuture<>();
+        failedFuture.completeExceptionally(new RuntimeException("Test exception"));
+        successfulNode = mock(ProcessNode.class);
+        when(successfulNode.future()).thenReturn(successfulFuture);
+        failedNode = mock(ProcessNode.class);
+        when(failedNode.future()).thenReturn(failedFuture);
     }
 
-    @After
-    public void cleanup() {
-        executor.shutdown();
+    @AfterClass
+    public static void cleanup() {
+        ThreadPool.terminate(testThreadPool, 500, TimeUnit.MILLISECONDS);
     }
 
     public void testNode() throws InterruptedException, ExecutionException {
+        // Tests where execute has no timeout
         ProcessNode nodeA = new ProcessNode("A", new WorkflowStep() {
             @Override
             public CompletableFuture<WorkflowData> execute(List<WorkflowData> data) {
                 CompletableFuture<WorkflowData> f = new CompletableFuture<>();
-                f.complete(WorkflowData.EMPTY);
+                f.complete(new WorkflowData(Map.of("test", "output")));
                 return f;
             }
 
@@ -46,15 +66,109 @@ public CompletableFuture<WorkflowData> execute(List<WorkflowData> data) {
             public String getName() {
                 return "test";
             }
-        }, WorkflowData.EMPTY, Collections.emptyList(), executor);
+        },
+            new WorkflowData(Map.of("test", "input"), Map.of("foo", "bar")),
+            List.of(successfulNode),
+            testThreadPool,
+            TimeValue.timeValueMillis(50)
+        );
         assertEquals("A", nodeA.id());
         assertEquals("test", nodeA.workflowStep().getName());
-        assertEquals(WorkflowData.EMPTY, nodeA.input());
-        assertEquals(Collections.emptyList(), nodeA.predecessors());
+        assertEquals("input", nodeA.input().getContent().get("test"));
+        assertEquals("bar", nodeA.input().getParams().get("foo"));
+        assertEquals(1, nodeA.predecessors().size());
+        assertEquals(50, nodeA.nodeTimeout().millis());
         assertEquals("A", nodeA.toString());
 
         CompletableFuture<WorkflowData> f = nodeA.execute();
         assertEquals(f, nodeA.future());
+        assertEquals("output", f.get().getContent().get("test"));
+    }
+
+    public void testNodeNoTimeout() throws InterruptedException, ExecutionException {
+        // Tests where execute finishes before timeout
+        ProcessNode nodeB = new ProcessNode("B", new WorkflowStep() {
+            @Override
+            public CompletableFuture<WorkflowData> execute(List<WorkflowData> data) {
+                CompletableFuture<WorkflowData> future = new CompletableFuture<>();
+                testThreadPool.schedule(
+                    () -> future.complete(WorkflowData.EMPTY),
+                    TimeValue.timeValueMillis(100),
+                    ThreadPool.Names.GENERIC
+                );
+                return future;
+            }
+
+            @Override
+            public String getName() {
+                return "test";
+            }
+        }, WorkflowData.EMPTY, Collections.emptyList(), testThreadPool, TimeValue.timeValueMillis(250));
+        assertEquals("B", nodeB.id());
+        assertEquals("test", nodeB.workflowStep().getName());
+        assertEquals(WorkflowData.EMPTY, nodeB.input());
+        assertEquals(Collections.emptyList(), nodeB.predecessors());
+        assertEquals("B", nodeB.toString());
+
+        CompletableFuture<WorkflowData> f = nodeB.execute();
+        assertEquals(f, nodeB.future());
         assertEquals(WorkflowData.EMPTY, f.get());
     }
+
+    public void testNodeTimeout() throws InterruptedException, ExecutionException {
+        // Tests where execute finishes after timeout
+        ProcessNode nodeZ = new ProcessNode("Zzz", new WorkflowStep() {
+            @Override
+            public CompletableFuture<WorkflowData> execute(List<WorkflowData> data) {
+                CompletableFuture<WorkflowData> future = new CompletableFuture<>();
+                testThreadPool.schedule(() -> future.complete(WorkflowData.EMPTY), TimeValue.timeValueMinutes(1), ThreadPool.Names.GENERIC);
+                return future;
+            }
+
+            @Override
+            public String getName() {
+                return "sleepy";
+            }
+        }, WorkflowData.EMPTY, Collections.emptyList(), testThreadPool, TimeValue.timeValueMillis(100));
+        assertEquals("Zzz", nodeZ.id());
+        assertEquals("sleepy", nodeZ.workflowStep().getName());
+        assertEquals(WorkflowData.EMPTY, nodeZ.input());
+        assertEquals(Collections.emptyList(), nodeZ.predecessors());
+        assertEquals("Zzz", nodeZ.toString());
+
+        CompletableFuture<WorkflowData> f = nodeZ.execute();
+        CompletionException exception = assertThrows(CompletionException.class, () -> f.join());
+        assertTrue(f.isCompletedExceptionally());
+        assertEquals(TimeoutException.class, exception.getCause().getClass());
+    }
+
+    public void testExceptions() {
+        // Tests where a predecessor future completed exceptionally
+        ProcessNode nodeE = new ProcessNode("E", new WorkflowStep() {
+            @Override
+            public CompletableFuture<WorkflowData> execute(List<WorkflowData> data) {
+                CompletableFuture<WorkflowData> f = new CompletableFuture<>();
+                f.complete(WorkflowData.EMPTY);
+                return f;
+            }
+
+            @Override
+            public String getName() {
+                return "test";
+            }
+        }, WorkflowData.EMPTY, List.of(successfulNode, failedNode), testThreadPool, TimeValue.timeValueSeconds(15));
+        assertEquals("E", nodeE.id());
+        assertEquals("test", nodeE.workflowStep().getName());
+        assertEquals(WorkflowData.EMPTY, nodeE.input());
+        assertEquals(2, nodeE.predecessors().size());
+        assertEquals("E", nodeE.toString());
+
+        CompletableFuture<WorkflowData> f = nodeE.execute();
+        CompletionException exception = assertThrows(CompletionException.class, () -> f.join());
+        assertTrue(f.isCompletedExceptionally());
+        assertEquals("Test exception", exception.getCause().getMessage());
+
+        // Tests where we already called execute
+        assertThrows(IllegalStateException.class, () -> nodeE.execute());
+    }
 }
diff --git a/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java b/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java
index 1e9c8e808..74240d561 100644
--- a/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java
+++ b/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java
@@ -14,18 +14,21 @@
 import org.opensearch.flowframework.model.TemplateTestJsonUtil;
 import org.opensearch.flowframework.model.Workflow;
 import org.opensearch.test.OpenSearchTestCase;
+import org.opensearch.threadpool.TestThreadPool;
+import org.opensearch.threadpool.ThreadPool;
 import org.junit.AfterClass;
 import org.junit.BeforeClass;
 
 import java.io.IOException;
 import java.util.Collections;
 import java.util.List;
-import java.util.concurrent.ExecutorService;
-import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
 import java.util.stream.Collectors;
 
 import static org.opensearch.flowframework.model.TemplateTestJsonUtil.edge;
 import static org.opensearch.flowframework.model.TemplateTestJsonUtil.node;
+import static org.opensearch.flowframework.model.TemplateTestJsonUtil.nodeWithType;
+import static org.opensearch.flowframework.model.TemplateTestJsonUtil.nodeWithTypeAndTimeout;
 import static org.opensearch.flowframework.model.TemplateTestJsonUtil.workflow;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;
@@ -36,14 +39,19 @@ public class WorkflowProcessSorterTests extends OpenSearchTestCase {
     private static final String NO_START_NODE_DETECTED = "No start node detected: all nodes have a predecessor.";
     private static final String CYCLE_DETECTED = "Cycle detected:";
 
-    // Wrap parser into string list
-    private static List<String> parse(String json) throws IOException {
+    // Wrap parser into node list
+    private static List<ProcessNode> parseToNodes(String json) throws IOException {
         XContentParser parser = TemplateTestJsonUtil.jsonToParser(json);
         Workflow w = Workflow.parse(parser);
-        return workflowProcessSorter.sortProcessNodes(w).stream().map(ProcessNode::id).collect(Collectors.toList());
+        return workflowProcessSorter.sortProcessNodes(w);
+    }
+
+    // Wrap parser into string list
+    private static List<String> parse(String json) throws IOException {
+        return parseToNodes(json).stream().map(ProcessNode::id).collect(Collectors.toList());
     }
 
-    private static ExecutorService executor;
+    private static TestThreadPool testThreadPool;
     private static WorkflowProcessSorter workflowProcessSorter;
 
     @BeforeClass
@@ -52,14 +60,35 @@ public static void setup() {
         Client client = mock(Client.class);
         when(client.admin()).thenReturn(adminClient);
 
-        executor = Executors.newFixedThreadPool(10);
-        WorkflowStepFactory factory = WorkflowStepFactory.create(client);
-        workflowProcessSorter = WorkflowProcessSorter.create(factory, executor);
+        testThreadPool = new TestThreadPool(WorkflowProcessSorterTests.class.getName());
+        WorkflowStepFactory factory = new WorkflowStepFactory(client);
+        workflowProcessSorter = new WorkflowProcessSorter(factory, testThreadPool);
     }
 
     @AfterClass
     public static void cleanup() {
-        executor.shutdown();
+        ThreadPool.terminate(testThreadPool, 500, TimeUnit.MILLISECONDS);
+    }
+
+    public void testNodeDetails() throws IOException {
+        List<ProcessNode> workflow = null;
+        workflow = parseToNodes(
+            workflow(
+                List.of(
+                    nodeWithType("default_timeout", "create_ingest_pipeline"),
+                    nodeWithTypeAndTimeout("custom_timeout", "create_index", "100ms")
+                ),
+                Collections.emptyList()
+            )
+        );
+        ProcessNode node = workflow.get(0);
+        assertEquals("default_timeout", node.id());
+        assertEquals(CreateIngestPipelineStep.class, node.workflowStep().getClass());
+        assertEquals(10, node.nodeTimeout().seconds());
+        node = workflow.get(1);
+        assertEquals("custom_timeout", node.id());
+        assertEquals(CreateIndexStep.class, node.workflowStep().getClass());
+        assertEquals(100, node.nodeTimeout().millis());
     }
 
     public void testOrdering() throws IOException {
diff --git a/src/test/resources/template/finaltemplate.json b/src/test/resources/template/finaltemplate.json
index d8443c4c6..fe1a57e36 100644
--- a/src/test/resources/template/finaltemplate.json
+++ b/src/test/resources/template/finaltemplate.json
@@ -25,7 +25,8 @@
             "type": "create_index",
             "inputs": {
               "name": "user_inputs.index_name",
-              "settings": "user_inputs.index_settings"
+              "settings": "user_inputs.index_settings",
+              "node_timeout": "10s"
             }
           },
           {
@@ -41,7 +42,8 @@
                 "input_field": "text_passage",
                 "output_field": "text_embedding"
               }
-            }]
+            }],
+            "node_timeout": "10s"
           }
         }
       ],
@@ -60,7 +62,8 @@
           "inputs": {
             "index": "user_inputs.index_name",
             "ingest_pipeline": "my-ingest-pipeline",
-            "document": "user_params.document"
+            "document": "user_params.document",
+            "node_timeout": "10s"
           }
         }]
       },
@@ -73,7 +76,8 @@
           "type": "transform_query",
           "inputs": {
            "template": "neural-search-template-1",
-            "plaintext": "user_params.plaintext"
+            "plaintext": "user_params.plaintext",
+            "node_timeout": "10s"
          }
        },
        {
@@ -83,7 +87,8 @@
           "index": "user_inputs.index_name",
           "query": "{{output-from-prev-step}}.query",
           "search_request_processors": [],
-          "search_response_processors": []
+          "search_response_processors": [],
+          "node_timeout": "10s"
         }
       }
     ],