Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.x] Infer edges from previous node inputs #335

Merged
merged 1 commit into from
Dec 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions src/main/java/org/opensearch/flowframework/model/Workflow.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.stream.Collectors;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;

Expand Down Expand Up @@ -119,13 +120,14 @@ public static Workflow parse(XContentParser parser) throws IOException {
if (nodes.isEmpty()) {
throw new IOException("A workflow must have at least one node.");
}
if (edges.isEmpty()) {
// infer edges from sequence of nodes
// Start iteration at 1, will skip for a one-node array
for (int i = 1; i < nodes.size(); i++) {
edges.add(new WorkflowEdge(nodes.get(i - 1).id(), nodes.get(i).id()));
}
}
// Iterate the nodes and infer edges from previous node inputs
List<WorkflowEdge> inferredEdges = nodes.stream()
.flatMap(node -> node.previousNodeInputs().keySet().stream().map(previousNode -> new WorkflowEdge(previousNode, node.id())))
.collect(Collectors.toList());
// Remove any that are already in edges list
inferredEdges.removeAll(edges);
// Then add them to the edges
edges.addAll(inferredEdges);
return new Workflow(userParams, nodes, edges);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.opensearch.flowframework.workflow.NoOpStep;

import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

Expand Down Expand Up @@ -53,6 +54,22 @@ public static String nodeWithTypeAndTimeout(String id, String type, String timeo
+ "\"}}";
}

public static String nodeWithTypeAndPreviousNodes(String id, String type, String... previousNodes) {
return "{\""
+ WorkflowNode.ID_FIELD
+ "\": \""
+ id
+ "\", \""
+ WorkflowNode.TYPE_FIELD
+ "\": \""
+ type
+ "\", \""
+ WorkflowNode.PREVIOUS_NODE_INPUTS_FIELD
+ "\": {"
+ Arrays.stream(previousNodes).map(n -> "\"" + n + "\": \"output_value\"").collect(Collectors.joining(","))
+ "}}";
}

public static String edge(String sourceId, String destId) {
return "{\"" + WorkflowEdge.SOURCE_FIELD + "\": \"" + sourceId + "\", \"" + WorkflowEdge.DEST_FIELD + "\": \"" + destId + "\"}";
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ public void testCreateAndProvisionRemoteModelWorkflow() throws Exception {
Workflow cyclicalWorkflow = new Workflow(
originalWorkflow.userParams(),
originalWorkflow.nodes(),
List.of(new WorkflowEdge("workflow_step_1", "workflow_step_2"), new WorkflowEdge("workflow_step_2", "workflow_step_1"))
List.of(new WorkflowEdge("workflow_step_2", "workflow_step_3"), new WorkflowEdge("workflow_step_3", "workflow_step_2"))
);

Template cyclicalTemplate = new Template.Builder().name(template.name())
Expand All @@ -155,7 +155,10 @@ public void testCreateAndProvisionRemoteModelWorkflow() throws Exception {

// Hit dry run
ResponseException exception = expectThrows(ResponseException.class, () -> createWorkflowValidation(cyclicalTemplate));
assertTrue(exception.getMessage().contains("Cycle detected: [workflow_step_2->workflow_step_1, workflow_step_1->workflow_step_2]"));
// output order not guaranteed
assertTrue(exception.getMessage().contains("Cycle detected"));
assertTrue(exception.getMessage().contains("workflow_step_2->workflow_step_3"));
assertTrue(exception.getMessage().contains("workflow_step_3->workflow_step_2"));

// Hit Create Workflow API with original template
Response response = createWorkflow(template);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import org.opensearch.common.unit.TimeValue;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.flowframework.common.FlowFrameworkSettings;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler;
Expand Down Expand Up @@ -59,6 +58,7 @@
import static org.opensearch.flowframework.model.TemplateTestJsonUtil.edge;
import static org.opensearch.flowframework.model.TemplateTestJsonUtil.node;
import static org.opensearch.flowframework.model.TemplateTestJsonUtil.nodeWithType;
import static org.opensearch.flowframework.model.TemplateTestJsonUtil.nodeWithTypeAndPreviousNodes;
import static org.opensearch.flowframework.model.TemplateTestJsonUtil.nodeWithTypeAndTimeout;
import static org.opensearch.flowframework.model.TemplateTestJsonUtil.workflow;
import static org.mockito.ArgumentMatchers.any;
Expand All @@ -72,11 +72,14 @@ public class WorkflowProcessSorterTests extends OpenSearchTestCase {
private static final String NO_START_NODE_DETECTED = "No start node detected: all nodes have a predecessor.";
private static final String CYCLE_DETECTED = "Cycle detected:";

// Wrap parser into workflow
private static Workflow parseToWorkflow(String json) throws IOException {
return Workflow.parse(TemplateTestJsonUtil.jsonToParser(json));
}

// Wrap parser into node list
private static List<ProcessNode> parseToNodes(String json) throws IOException {
XContentParser parser = TemplateTestJsonUtil.jsonToParser(json);
Workflow w = Workflow.parse(parser);
return workflowProcessSorter.sortProcessNodes(w, "123");
return workflowProcessSorter.sortProcessNodes(parseToWorkflow(json), "123");
}

// Wrap parser into string list
Expand Down Expand Up @@ -242,6 +245,56 @@ public void testNoEdges() throws IOException {
assertTrue(workflow.contains("B"));
}

public void testInferredEdges() throws IOException {
Workflow w = parseToWorkflow(
workflow(List.of(nodeWithTypeAndPreviousNodes("A", "noop"), nodeWithTypeAndPreviousNodes("B", "noop")), Collections.emptyList())
);
assertTrue(w.edges().isEmpty());

w = parseToWorkflow(
workflow(List.of(nodeWithTypeAndPreviousNodes("A", "noop"), nodeWithTypeAndPreviousNodes("B", "noop")), List.of(edge("B", "A")))
);
// edge from previous inputs only
assertEquals(List.of(new WorkflowEdge("B", "A")), w.edges());

w = parseToWorkflow(
workflow(
List.of(nodeWithTypeAndPreviousNodes("A", "noop", "B"), nodeWithTypeAndPreviousNodes("B", "noop")),
Collections.emptyList()
)
);
// edge from edges only
assertEquals(List.of(new WorkflowEdge("B", "A")), w.edges());

w = parseToWorkflow(
workflow(
List.of(
nodeWithTypeAndPreviousNodes("A", "noop", "B"),
nodeWithTypeAndPreviousNodes("B", "noop"),
nodeWithTypeAndPreviousNodes("C", "noop")
),
List.of(edge("C", "A"))
)
);
// combine sources, order not guaranteed
assertEquals(2, w.edges().size());
assertTrue(w.edges().contains(new WorkflowEdge("B", "A")));
assertTrue(w.edges().contains(new WorkflowEdge("C", "A")));

w = parseToWorkflow(
workflow(
List.of(
nodeWithTypeAndPreviousNodes("A", "noop", "B"),
nodeWithTypeAndPreviousNodes("B", "noop"),
nodeWithTypeAndPreviousNodes("C", "noop")
),
List.of(edge("B", "A"))
)
);
// duplicates, only 1
assertEquals(List.of(new WorkflowEdge("B", "A")), w.edges());
}

public void testExceptions() throws IOException {
Exception ex = assertThrows(
FlowFrameworkException.class,
Expand Down
Loading