diff --git a/src/main/java/org/opensearch/flowframework/model/Workflow.java b/src/main/java/org/opensearch/flowframework/model/Workflow.java index 81f2677a7..7902f4f40 100644 --- a/src/main/java/org/opensearch/flowframework/model/Workflow.java +++ b/src/main/java/org/opensearch/flowframework/model/Workflow.java @@ -19,6 +19,7 @@ import java.util.List; import java.util.Map; import java.util.Map.Entry; +import java.util.stream.Collectors; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; @@ -119,13 +120,14 @@ public static Workflow parse(XContentParser parser) throws IOException { if (nodes.isEmpty()) { throw new IOException("A workflow must have at least one node."); } - if (edges.isEmpty()) { - // infer edges from sequence of nodes - // Start iteration at 1, will skip for a one-node array - for (int i = 1; i < nodes.size(); i++) { - edges.add(new WorkflowEdge(nodes.get(i - 1).id(), nodes.get(i).id())); - } - } + // Iterate the nodes and infer edges from previous node inputs + List inferredEdges = nodes.stream() + .flatMap(node -> node.previousNodeInputs().keySet().stream().map(previousNode -> new WorkflowEdge(previousNode, node.id()))) + .collect(Collectors.toList()); + // Remove any that are already in edges list + inferredEdges.removeAll(edges); + // Then add them to the edges + edges.addAll(inferredEdges); return new Workflow(userParams, nodes, edges); } diff --git a/src/test/java/org/opensearch/flowframework/model/TemplateTestJsonUtil.java b/src/test/java/org/opensearch/flowframework/model/TemplateTestJsonUtil.java index ca5ee7a92..d6b41e371 100644 --- a/src/test/java/org/opensearch/flowframework/model/TemplateTestJsonUtil.java +++ b/src/test/java/org/opensearch/flowframework/model/TemplateTestJsonUtil.java @@ -17,6 +17,7 @@ import org.opensearch.flowframework.workflow.NoOpStep; import java.io.IOException; +import java.util.Arrays; import java.util.List; import java.util.stream.Collectors; @@ -53,6 +54,22 @@ public static String nodeWithTypeAndTimeout(String id, String type, String timeo + "\"}}"; } + public static String nodeWithTypeAndPreviousNodes(String id, String type, String... previousNodes) { + return "{\"" + + WorkflowNode.ID_FIELD + + "\": \"" + + id + + "\", \"" + + WorkflowNode.TYPE_FIELD + + "\": \"" + + type + + "\", \"" + + WorkflowNode.PREVIOUS_NODE_INPUTS_FIELD + + "\": {" + + Arrays.stream(previousNodes).map(n -> "\"" + n + "\": \"output_value\"").collect(Collectors.joining(",")) + + "}}"; + } + public static String edge(String sourceId, String destId) { return "{\"" + WorkflowEdge.SOURCE_FIELD + "\": \"" + sourceId + "\", \"" + WorkflowEdge.DEST_FIELD + "\": \"" + destId + "\"}"; } diff --git a/src/test/java/org/opensearch/flowframework/rest/FlowFrameworkRestApiIT.java b/src/test/java/org/opensearch/flowframework/rest/FlowFrameworkRestApiIT.java index 8fa2fa43d..058d2b898 100644 --- a/src/test/java/org/opensearch/flowframework/rest/FlowFrameworkRestApiIT.java +++ b/src/test/java/org/opensearch/flowframework/rest/FlowFrameworkRestApiIT.java @@ -140,7 +140,7 @@ public void testCreateAndProvisionRemoteModelWorkflow() throws Exception { Workflow cyclicalWorkflow = new Workflow( originalWorkflow.userParams(), originalWorkflow.nodes(), - List.of(new WorkflowEdge("workflow_step_1", "workflow_step_2"), new WorkflowEdge("workflow_step_2", "workflow_step_1")) + List.of(new WorkflowEdge("workflow_step_2", "workflow_step_3"), new WorkflowEdge("workflow_step_3", "workflow_step_2")) ); Template cyclicalTemplate = new Template.Builder().name(template.name()) @@ -155,7 +155,10 @@ public void testCreateAndProvisionRemoteModelWorkflow() throws Exception { // Hit dry run ResponseException exception = expectThrows(ResponseException.class, () -> createWorkflowValidation(cyclicalTemplate)); - assertTrue(exception.getMessage().contains("Cycle detected: [workflow_step_2->workflow_step_1, workflow_step_1->workflow_step_2]")); + // output order not guaranteed + assertTrue(exception.getMessage().contains("Cycle detected")); + assertTrue(exception.getMessage().contains("workflow_step_2->workflow_step_3")); + assertTrue(exception.getMessage().contains("workflow_step_3->workflow_step_2")); // Hit Create Workflow API with original template Response response = createWorkflow(template); diff --git a/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java b/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java index ef467b8dd..706771b93 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java @@ -24,7 +24,6 @@ import org.opensearch.common.unit.TimeValue; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; -import org.opensearch.core.xcontent.XContentParser; import org.opensearch.flowframework.common.FlowFrameworkSettings; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; @@ -59,6 +58,7 @@ import static org.opensearch.flowframework.model.TemplateTestJsonUtil.edge; import static org.opensearch.flowframework.model.TemplateTestJsonUtil.node; import static org.opensearch.flowframework.model.TemplateTestJsonUtil.nodeWithType; +import static org.opensearch.flowframework.model.TemplateTestJsonUtil.nodeWithTypeAndPreviousNodes; import static org.opensearch.flowframework.model.TemplateTestJsonUtil.nodeWithTypeAndTimeout; import static org.opensearch.flowframework.model.TemplateTestJsonUtil.workflow; import static org.mockito.ArgumentMatchers.any; @@ -72,11 +72,14 @@ public class WorkflowProcessSorterTests extends OpenSearchTestCase { private static final String NO_START_NODE_DETECTED = "No start node detected: all nodes have a predecessor."; private static final String CYCLE_DETECTED = "Cycle detected:"; + // Wrap parser into workflow + private static Workflow parseToWorkflow(String json) throws IOException { + return Workflow.parse(TemplateTestJsonUtil.jsonToParser(json)); + } + // Wrap parser into node list private static List parseToNodes(String json) throws IOException { - XContentParser parser = TemplateTestJsonUtil.jsonToParser(json); - Workflow w = Workflow.parse(parser); - return workflowProcessSorter.sortProcessNodes(w, "123"); + return workflowProcessSorter.sortProcessNodes(parseToWorkflow(json), "123"); } // Wrap parser into string list @@ -242,6 +245,56 @@ public void testNoEdges() throws IOException { assertTrue(workflow.contains("B")); } + public void testInferredEdges() throws IOException { + Workflow w = parseToWorkflow( + workflow(List.of(nodeWithTypeAndPreviousNodes("A", "noop"), nodeWithTypeAndPreviousNodes("B", "noop")), Collections.emptyList()) + ); + assertTrue(w.edges().isEmpty()); + + w = parseToWorkflow( + workflow(List.of(nodeWithTypeAndPreviousNodes("A", "noop"), nodeWithTypeAndPreviousNodes("B", "noop")), List.of(edge("B", "A"))) + ); + // edge from previous inputs only + assertEquals(List.of(new WorkflowEdge("B", "A")), w.edges()); + + w = parseToWorkflow( + workflow( + List.of(nodeWithTypeAndPreviousNodes("A", "noop", "B"), nodeWithTypeAndPreviousNodes("B", "noop")), + Collections.emptyList() + ) + ); + // edge from edges only + assertEquals(List.of(new WorkflowEdge("B", "A")), w.edges()); + + w = parseToWorkflow( + workflow( + List.of( + nodeWithTypeAndPreviousNodes("A", "noop", "B"), + nodeWithTypeAndPreviousNodes("B", "noop"), + nodeWithTypeAndPreviousNodes("C", "noop") + ), + List.of(edge("C", "A")) + ) + ); + // combine sources, order not guaranteed + assertEquals(2, w.edges().size()); + assertTrue(w.edges().contains(new WorkflowEdge("B", "A"))); + assertTrue(w.edges().contains(new WorkflowEdge("C", "A"))); + + w = parseToWorkflow( + workflow( + List.of( + nodeWithTypeAndPreviousNodes("A", "noop", "B"), + nodeWithTypeAndPreviousNodes("B", "noop"), + nodeWithTypeAndPreviousNodes("C", "noop") + ), + List.of(edge("B", "A")) + ) + ); + // duplicates, only 1 + assertEquals(List.of(new WorkflowEdge("B", "A")), w.edges()); + } + public void testExceptions() throws IOException { Exception ex = assertThrows( FlowFrameworkException.class,