diff --git a/build.gradle b/build.gradle index fcf08e16d..7664db368 100644 --- a/build.gradle +++ b/build.gradle @@ -89,6 +89,8 @@ repositories { dependencies { implementation "org.opensearch:opensearch:${opensearch_version}" + + implementation "com.google.code.gson:gson:2.10.1" compileOnly "com.google.guava:guava:32.1.2-jre" configurations.all { diff --git a/src/main/java/org/opensearch/flowframework/template/ProcessNode.java b/src/main/java/org/opensearch/flowframework/template/ProcessNode.java new file mode 100644 index 000000000..e1b57fc51 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/template/ProcessNode.java @@ -0,0 +1,94 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.template; + +import java.util.Collections; +import java.util.List; +import java.util.Objects; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; + +public class ProcessNode { + private final String id; + private CompletableFuture future; + + // will be populated during graph parsing + private Set predecessors = Collections.emptySet(); + + ProcessNode(String id) { + this.id = id; + } + + public String getId() { + return id; + } + + public CompletableFuture getFuture() { + return future; + } + + public Set getPredecessors() { + return predecessors; + } + + public void setPredecessors(Set predecessors) { + this.predecessors = Set.copyOf(predecessors); + } + + public CompletableFuture execute() { + this.future = new CompletableFuture<>(); + CompletableFuture.runAsync(() -> { + if (!predecessors.isEmpty()) { + List> predFutures = predecessors.stream().map(p -> 
p.getFuture()).collect(Collectors.toList()); + CompletableFuture waitForPredecessors = CompletableFuture.allOf(predFutures.toArray(new CompletableFuture[0])); + try { + waitForPredecessors.orTimeout(30, TimeUnit.SECONDS).get(); + } catch (InterruptedException | ExecutionException e) { + future.completeExceptionally(e); + } + } + if (future.isCompletedExceptionally()) { + return; + } + System.out.println(">>> Starting " + this.id); + sleep(id.contains("ingest") ? 8000 : 4000); + System.out.println("<<< Finished " + this.id); + future.complete(this.id); + }); + return this.future; + } + + private void sleep(long i) { + try { + Thread.sleep(i); + } catch (InterruptedException e) {} + } + + @Override + public int hashCode() { + return Objects.hash(id); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) return true; + if (obj == null) return false; + if (getClass() != obj.getClass()) return false; + ProcessNode other = (ProcessNode) obj; + return Objects.equals(id, other.id); + } + + @Override + public String toString() { + return this.id; + } +} diff --git a/src/main/java/org/opensearch/flowframework/template/ProcessSequenceEdge.java b/src/main/java/org/opensearch/flowframework/template/ProcessSequenceEdge.java new file mode 100644 index 000000000..4d768958f --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/template/ProcessSequenceEdge.java @@ -0,0 +1,48 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
/**
 * An immutable directed edge in a process sequence graph, linking a source node id
 * to a destination node id. Two edges are equal when both endpoints match.
 */
public class ProcessSequenceEdge {
    private final String source;
    private final String destination;

    /**
     * Creates an edge from {@code source} to {@code destination}.
     *
     * @param source      id of the node this edge leaves
     * @param destination id of the node this edge enters
     */
    ProcessSequenceEdge(String source, String destination) {
        this.source = source;
        this.destination = destination;
    }

    /** Returns the id of the node this edge leaves. */
    public String getSource() {
        return source;
    }

    /** Returns the id of the node this edge enters. */
    public String getDestination() {
        return destination;
    }

    @Override
    public int hashCode() {
        return Objects.hash(destination, source);
    }

    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (obj == null || getClass() != obj.getClass()) {
            return false;
        }
        ProcessSequenceEdge that = (ProcessSequenceEdge) obj;
        return Objects.equals(this.destination, that.destination) && Objects.equals(this.source, that.source);
    }

    @Override
    public String toString() {
        return source + "->" + destination;
    }
}
+ */ +package org.opensearch.flowframework.template; + +import com.google.gson.Gson; +import com.google.gson.JsonElement; +import com.google.gson.JsonObject; + +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Queue; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.function.Function; +import java.util.stream.Collectors; + +public class TemplateParser { + + public static void main(String[] args) { + String json = "{\n" + + " \"sequence\": {\n" + + " \"nodes\": [\n" + + " {\n" + + " \"id\": \"fetch_model\"\n" + + " },\n" + + " {\n" + + " \"id\": \"create_ingest_pipeline\"\n" + + " },\n" + + " {\n" + + " \"id\": \"create_search_pipeline\"\n" + + " },\n" + + " {\n" + + " \"id\": \"create_neural_search_index\"\n" + + " }\n" + + " ],\n" + + " \"edges\": [\n" + + " {\n" + + " \"source\": \"fetch_model\",\n" + + " \"dest\": \"create_ingest_pipeline\"\n" + + " },\n" + + " {\n" + + " \"source\": \"fetch_model\",\n" + + " \"dest\": \"create_search_pipeline\"\n" + + " },\n" + + " {\n" + + " \"source\": \"create_ingest_pipeline\",\n" + + " \"dest\": \"create_neural_search_index\"\n" + + " },\n" + + " {\n" + + " \"source\": \"create_search_pipeline\",\n" + + " \"dest\": \"create_neural_search_index\"\n" + // + " }\n," + // + " {\n" + // + " \"source\": \"create_neural_search_index\",\n" + // + " \"dest\": \"fetch_model\"\n" + + " }\n" + + " ]\n" + + " }\n" + + "}"; + + System.out.println(json); + + System.out.println("Parsing graph to sequence..."); + List processSequence = parseJsonGraphToSequence(json); + List> futureList = new ArrayList<>(); + + for (ProcessNode n : processSequence) { + Set predecessors = n.getPredecessors(); + System.out.format( + "Queueing process [%s]. %s.%n", + n.getId(), + predecessors.isEmpty() + ? "Can start immediately!" 
+ : String.format( + "Must wait for [%s] to complete first.", + predecessors.stream().map(p -> p.getId()).collect(Collectors.joining(", ")) + ) + ); + futureList.add(n.execute()); + } + futureList.forEach(CompletableFuture::join); + System.out.println("All done!"); + } + + private static List parseJsonGraphToSequence(String json) { + Gson gson = new Gson(); + JsonObject jsonObject = gson.fromJson(json, JsonObject.class); + + JsonObject graph = jsonObject.getAsJsonObject("sequence"); + + List nodes = new ArrayList<>(); + List edges = new ArrayList<>(); + + for (JsonElement nodeJson : graph.getAsJsonArray("nodes")) { + JsonObject nodeObject = nodeJson.getAsJsonObject(); + String nodeId = nodeObject.get("id").getAsString(); + nodes.add(new ProcessNode(nodeId)); + } + + for (JsonElement edgeJson : graph.getAsJsonArray("edges")) { + JsonObject edgeObject = edgeJson.getAsJsonObject(); + String sourceNodeId = edgeObject.get("source").getAsString(); + String destNodeId = edgeObject.get("dest").getAsString(); + edges.add(new ProcessSequenceEdge(sourceNodeId, destNodeId)); + } + + return topologicalSort(nodes, edges); + } + + private static List topologicalSort(List nodes, List edges) { + // Define the graph + Set graph = new HashSet<>(edges); + // Map node id string to object + Map nodeMap = nodes.stream().collect(Collectors.toMap(ProcessNode::getId, Function.identity())); + // Build predecessor and successor maps + Map> predecessorEdges = new HashMap<>(); + Map> successorEdges = new HashMap<>(); + for (ProcessSequenceEdge edge : edges) { + ProcessNode source = nodeMap.get(edge.getSource()); + ProcessNode dest = nodeMap.get(edge.getDestination()); + predecessorEdges.computeIfAbsent(dest, k -> new HashSet<>()).add(edge); + successorEdges.computeIfAbsent(source, k -> new HashSet<>()).add(edge); + } + // See https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm + // Find start node(s) which have no predecessors + Queue sourceNodes = new ArrayDeque<>(); + 
nodes.stream().filter(n -> !predecessorEdges.containsKey(n)).forEach(n -> sourceNodes.add(n)); + if (sourceNodes.isEmpty()) { + throw new IllegalArgumentException("No start node detected: all nodes have a predecessor."); + } + System.out.println("Start node(s): " + sourceNodes); + + // List to contain sorted elements + List sortedNodes = new ArrayList<>(); + // Keep adding successors + while (!sourceNodes.isEmpty()) { + ProcessNode n = sourceNodes.poll(); + sortedNodes.add(n); + if (predecessorEdges.containsKey(n)) { + n.setPredecessors(predecessorEdges.get(n).stream().map(e -> nodeMap.get(e.getSource())).collect(Collectors.toSet())); + } + // Add successors to the queue + for (ProcessSequenceEdge e : successorEdges.getOrDefault(n, Collections.emptySet())) { + graph.remove(e); + ProcessNode dest = nodeMap.get(e.getDestination()); + if (!sourceNodes.contains(dest) && !sortedNodes.contains(dest)) { + sourceNodes.add(dest); + } + } + } + if (!graph.isEmpty()) { + throw new IllegalArgumentException("Cycle detected: " + graph); + } + System.out.println("Execution sequence: " + sortedNodes); + return sortedNodes; + } +}