diff --git a/build.gradle b/build.gradle index c3b193b15..2f6d80aa9 100644 --- a/build.gradle +++ b/build.gradle @@ -106,7 +106,6 @@ repositories { dependencies { implementation "org.opensearch:opensearch:${opensearch_version}" implementation 'org.junit.jupiter:junit-jupiter:5.10.0' - implementation "com.google.code.gson:gson:2.10.1" implementation "com.google.guava:guava:32.1.2-jre" api group: 'org.opensearch', name:'opensearch-ml-client', version: "${opensearch_build}" diff --git a/src/main/java/demo/CreateIndexWorkflowStep.java b/src/main/java/demo/CreateIndexWorkflowStep.java deleted file mode 100644 index 6b2ab0a7b..000000000 --- a/src/main/java/demo/CreateIndexWorkflowStep.java +++ /dev/null @@ -1,78 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - */ -package demo; - -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.opensearch.action.admin.indices.create.CreateIndexResponse; -import org.opensearch.flowframework.workflow.WorkflowData; -import org.opensearch.flowframework.workflow.WorkflowStep; - -import java.util.List; -import java.util.Map; -import java.util.concurrent.CompletableFuture; - -/** - * Sample to show other devs how to pass data around. Will be deleted once other PRs are merged. - */ -public class CreateIndexWorkflowStep implements WorkflowStep { - - private static final Logger logger = LogManager.getLogger(CreateIndexWorkflowStep.class); - - private final String name; - - /** - * Instantiate this class. - */ - public CreateIndexWorkflowStep() { - this.name = "CREATE_INDEX"; - } - - @Override - public CompletableFuture execute(List data) { - CompletableFuture future = new CompletableFuture<>(); - // TODO we will be passing a thread pool to this object when it's instantiated - // we should either add the generic executor from that pool to this call - // or use executorservice.submit or any of various threading options - // https://github.com/opensearch-project/opensearch-ai-flow-framework/issues/42 - CompletableFuture.runAsync(() -> { - String inputIndex = null; - boolean first = true; - for (WorkflowData wfData : data) { - logger.debug( - "{} sent params: {}, content: {}", - first ? "Initialization" : "Previous step", - wfData.getParams(), - wfData.getContent() - ); - if (first) { - Map params = data.get(0).getParams(); - if (params.containsKey("index")) { - inputIndex = params.get("index"); - } - first = false; - } - } - // do some work, simulating a REST API call - try { - Thread.sleep(2000); - } catch (InterruptedException e) {} - // Simulate response of created index - CreateIndexResponse response = new CreateIndexResponse(true, true, inputIndex); - future.complete(new WorkflowData(Map.of("index", response.index()))); - }); - - return future; - } - - @Override - public String getName() { - return name; - } -} diff --git a/src/main/java/demo/DataDemo.java b/src/main/java/demo/DataDemo.java deleted file mode 100644 index f2d606f07..000000000 --- a/src/main/java/demo/DataDemo.java +++ /dev/null @@ -1,85 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - */ -package demo; - -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.opensearch.common.SuppressForbidden; -import org.opensearch.common.io.PathUtils; -import org.opensearch.flowframework.template.ProcessNode; -import org.opensearch.flowframework.template.TemplateParser; -import org.opensearch.flowframework.workflow.WorkflowStep; - -import java.io.IOException; -import java.nio.charset.StandardCharsets; -import java.nio.file.Files; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Locale; -import java.util.Map; -import java.util.Set; -import java.util.concurrent.CompletableFuture; -import java.util.stream.Collectors; - -/** - * Demo class exercising {@link TemplateParser}. This will be moved to a unit test. - */ -public class DataDemo { - - private static final Logger logger = LogManager.getLogger(DataDemo.class); - - // This is temporary. We need a factory class to generate these workflow steps - // based on a field in the JSON. - private static Map workflowMap = new HashMap<>(); - static { - workflowMap.put("create_index", new CreateIndexWorkflowStep()); - workflowMap.put("create_another_index", new CreateIndexWorkflowStep()); - } - - /** - * Demonstrate parsing a JSON graph. - * - * @param args unused - */ - @SuppressForbidden(reason = "just a demo class that will be deleted") - public static void main(String[] args) { - String path = "src/test/resources/template/datademo.json"; - String json; - try { - json = new String(Files.readAllBytes(PathUtils.get(path)), StandardCharsets.UTF_8); - } catch (IOException e) { - logger.error("Failed to read JSON at path {}", path); - return; - } - - logger.info("Parsing graph to sequence..."); - List processSequence = TemplateParser.parseJsonGraphToSequence(json, workflowMap); - List> futureList = new ArrayList<>(); - - for (ProcessNode n : processSequence) { - Set predecessors = n.getPredecessors(); - logger.info( - "Queueing process [{}].{}", - n.id(), - predecessors.isEmpty() - ? " Can start immediately!" - : String.format( - Locale.getDefault(), - " Must wait for [%s] to complete first.", - predecessors.stream().map(p -> p.id()).collect(Collectors.joining(", ")) - ) - ); - futureList.add(n.execute()); - } - futureList.forEach(CompletableFuture::join); - logger.info("All done!"); - } - -} diff --git a/src/main/java/demo/Demo.java b/src/main/java/demo/Demo.java index 58d977827..53cf3499c 100644 --- a/src/main/java/demo/Demo.java +++ b/src/main/java/demo/Demo.java @@ -10,48 +10,41 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.client.Client; +import org.opensearch.client.node.NodeClient; import org.opensearch.common.SuppressForbidden; import org.opensearch.common.io.PathUtils; -import org.opensearch.flowframework.template.ProcessNode; -import org.opensearch.flowframework.template.TemplateParser; -import org.opensearch.flowframework.workflow.WorkflowStep; +import org.opensearch.flowframework.model.Template; +import org.opensearch.flowframework.workflow.ProcessNode; +import org.opensearch.flowframework.workflow.WorkflowProcessSorter; +import org.opensearch.flowframework.workflow.WorkflowStepFactory; import java.io.IOException; import java.nio.charset.StandardCharsets; import java.nio.file.Files; import java.util.ArrayList; -import java.util.HashMap; import java.util.List; import java.util.Locale; -import java.util.Map; -import java.util.Set; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; import java.util.stream.Collectors; /** - * Demo class exercising {@link TemplateParser}. This will be moved to a unit test. + * Demo class exercising {@link WorkflowProcessSorter}. This will be moved to a unit test. */ public class Demo { private static final Logger logger = LogManager.getLogger(Demo.class); - // This is temporary. We need a factory class to generate these workflow steps - // based on a field in the JSON. - private static Map workflowMap = new HashMap<>(); - static { - workflowMap.put("fetch_model", new DemoWorkflowStep(3000)); - workflowMap.put("create_ingest_pipeline", new DemoWorkflowStep(3000)); - workflowMap.put("create_search_pipeline", new DemoWorkflowStep(5000)); - workflowMap.put("create_neural_search_index", new DemoWorkflowStep(2000)); - } - /** * Demonstrate parsing a JSON graph. * * @param args unused + * @throws IOException on a failure */ @SuppressForbidden(reason = "just a demo class that will be deleted") - public static void main(String[] args) { + public static void main(String[] args) throws IOException { String path = "src/test/resources/template/demo.json"; String json; try { @@ -60,13 +53,18 @@ public static void main(String[] args) { logger.error("Failed to read JSON at path {}", path); return; } + Client client = new NodeClient(null, null); + WorkflowStepFactory factory = WorkflowStepFactory.create(client); + ExecutorService executor = Executors.newFixedThreadPool(10); + WorkflowProcessSorter.create(factory, executor); logger.info("Parsing graph to sequence..."); - List processSequence = TemplateParser.parseJsonGraphToSequence(json, workflowMap); + Template t = Template.parse(json); + List processSequence = WorkflowProcessSorter.get().sortProcessNodes(t.workflows().get("demo")); List> futureList = new ArrayList<>(); for (ProcessNode n : processSequence) { - Set predecessors = n.getPredecessors(); + List predecessors = n.predecessors(); logger.info( "Queueing process [{}].{}", n.id(), @@ -78,11 +76,10 @@ public static void main(String[] args) { predecessors.stream().map(p -> p.id()).collect(Collectors.joining(", ")) ) ); - // TODO need to handle this better, passing an argument when we start them all at the beginning is silly futureList.add(n.execute()); } futureList.forEach(CompletableFuture::join); logger.info("All done!"); + executor.shutdown(); } - } diff --git a/src/main/java/demo/TemplateParseDemo.java b/src/main/java/demo/TemplateParseDemo.java new file mode 100644 index 000000000..307d707c0 --- /dev/null +++ b/src/main/java/demo/TemplateParseDemo.java @@ -0,0 +1,65 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package demo; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.client.Client; +import org.opensearch.client.node.NodeClient; +import org.opensearch.common.SuppressForbidden; +import org.opensearch.common.io.PathUtils; +import org.opensearch.flowframework.model.Template; +import org.opensearch.flowframework.model.Workflow; +import org.opensearch.flowframework.workflow.WorkflowProcessSorter; +import org.opensearch.flowframework.workflow.WorkflowStepFactory; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.util.Map.Entry; +import java.util.concurrent.Executors; + +/** + * Demo class exercising {@link WorkflowProcessSorter}. This will be moved to a unit test. + */ +public class TemplateParseDemo { + + private static final Logger logger = LogManager.getLogger(TemplateParseDemo.class); + + /** + * Demonstrate parsing a JSON graph. + * + * @param args unused + * @throws IOException on error. + */ + @SuppressForbidden(reason = "just a demo class that will be deleted") + public static void main(String[] args) throws IOException { + String path = "src/test/resources/template/finaltemplate.json"; + String json; + try { + json = new String(Files.readAllBytes(PathUtils.get(path)), StandardCharsets.UTF_8); + } catch (IOException e) { + logger.error("Failed to read JSON at path {}", path); + return; + } + Client client = new NodeClient(null, null); + WorkflowStepFactory factory = WorkflowStepFactory.create(client); + WorkflowProcessSorter.create(factory, Executors.newFixedThreadPool(10)); + + Template t = Template.parse(json); + + System.out.println(t.toJson()); + System.out.println(t.toYaml()); + + for (Entry e : t.workflows().entrySet()) { + logger.info("Parsing {} workflow.", e.getKey()); + WorkflowProcessSorter.get().sortProcessNodes(e.getValue()); + } + } +} diff --git a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java index e5df0bf46..d701c832e 100644 --- a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java +++ b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java @@ -16,8 +16,8 @@ import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.env.Environment; import org.opensearch.env.NodeEnvironment; -import org.opensearch.flowframework.workflow.CreateIndex.CreateIndexStep; -import org.opensearch.flowframework.workflow.CreateIngestPipelineStep; +import org.opensearch.flowframework.workflow.WorkflowProcessSorter; +import org.opensearch.flowframework.workflow.WorkflowStepFactory; import org.opensearch.plugins.Plugin; import org.opensearch.repositories.RepositoriesService; import org.opensearch.script.ScriptService; @@ -32,8 +32,6 @@ */ public class FlowFrameworkPlugin extends Plugin { - private Client client; - @Override public Collection createComponents( Client client, @@ -48,9 +46,9 @@ public Collection createComponents( IndexNameExpressionResolver indexNameExpressionResolver, Supplier repositoriesServiceSupplier ) { - this.client = client; - CreateIngestPipelineStep createIngestPipelineStep = new CreateIngestPipelineStep(client); - CreateIndexStep createIndexStep = new CreateIndexStep(client); - return ImmutableList.of(createIngestPipelineStep, createIndexStep); + WorkflowStepFactory workflowStepFactory = WorkflowStepFactory.create(client); + WorkflowProcessSorter workflowProcessSorter = WorkflowProcessSorter.create(workflowStepFactory, threadPool.generic()); + + return ImmutableList.of(workflowStepFactory, workflowProcessSorter); } } diff --git a/src/main/java/org/opensearch/flowframework/model/PipelineProcessor.java b/src/main/java/org/opensearch/flowframework/model/PipelineProcessor.java new file mode 100644 index 000000000..1407036b3 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/model/PipelineProcessor.java @@ -0,0 +1,101 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.model; + +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +/** + * This represents a processor associated with search and ingest pipelines in the {@link Template}. + */ +public class PipelineProcessor implements ToXContentObject { + + /** The type field name for pipeline processors */ + public static final String TYPE_FIELD = "type"; + /** The params field name for pipeline processors */ + public static final String PARAMS_FIELD = "params"; + + private final String type; + private final Map params; + + /** + * Create this processor with a type and map of parameters + * @param type the processor type + * @param params a map of params + */ + public PipelineProcessor(String type, Map params) { + this.type = type; + this.params = params; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + XContentBuilder xContentBuilder = builder.startObject(); + xContentBuilder.field(TYPE_FIELD, this.type); + xContentBuilder.field(PARAMS_FIELD); + Template.buildStringToStringMap(xContentBuilder, this.params); + return xContentBuilder.endObject(); + } + + /** + * Parse raw json content into a processor instance. + * + * @param parser json based content parser + * @return the parsed PipelineProcessor instance + * @throws IOException if content can't be parsed correctly + */ + public static PipelineProcessor parse(XContentParser parser) throws IOException { + String type = null; + Map params = new HashMap<>(); + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + switch (fieldName) { + case TYPE_FIELD: + type = parser.text(); + break; + case PARAMS_FIELD: + params = Template.parseStringToStringMap(parser); + break; + default: + throw new IOException("Unable to parse field [" + fieldName + "] in a pipeline processor object."); + } + } + if (type == null) { + throw new IOException("A processor object requires a type field."); + } + + return new PipelineProcessor(type, params); + } + + /** + * Get the processor type + * @return the type + */ + public String type() { + return type; + } + + /** + * Get the processor params + * @return the params + */ + public Map params() { + return params; + } +} diff --git a/src/main/java/org/opensearch/flowframework/model/Template.java b/src/main/java/org/opensearch/flowframework/model/Template.java new file mode 100644 index 000000000..dd998aefa --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/model/Template.java @@ -0,0 +1,393 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.model; + +import org.opensearch.Version; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.json.JsonXContent; +import org.opensearch.common.xcontent.yaml.YamlXContent; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +/** + * The Template is the central data structure which configures workflows. This object is used to parse JSON communicated via REST API. + */ +public class Template implements ToXContentObject { + + /** The template field name for template name */ + public static final String NAME_FIELD = "name"; + /** The template field name for template description */ + public static final String DESCRIPTION_FIELD = "description"; + /** The template field name for template use case */ + public static final String USE_CASE_FIELD = "use_case"; + /** The template field name for template operations */ + public static final String OPERATIONS_FIELD = "operations"; + /** The template field name for template version information */ + public static final String VERSION_FIELD = "version"; + /** The template field name for template version */ + public static final String TEMPLATE_FIELD = "template"; + /** The template field name for template compatibility with OpenSearch versions */ + public static final String COMPATIBILITY_FIELD = "compatibility"; + /** The template field name for template user inputs */ + public static final String USER_INPUTS_FIELD = "user_inputs"; + /** The template field name for template workflows */ + public static final String WORKFLOWS_FIELD = "workflows"; + + private final String name; + private final String description; + private final String useCase; // probably an ENUM actually + private final List operations; // probably an ENUM actually + private final Version templateVersion; + private final List compatibilityVersion; + private final Map userInputs; + private final Map workflows; + + /** + * Instantiate the object representing a use case template + * + * @param name The template's name + * @param description A description of the template's use case + * @param useCase A string defining the internal use case type + * @param operations Expected operations of this template. Should match defined workflows. + * @param templateVersion The version of this template + * @param compatibilityVersion OpenSearch version compatibility of this template + * @param userInputs Optional user inputs to apply globally + * @param workflows Workflow graph definitions corresponding to the defined operations. + */ + public Template( + String name, + String description, + String useCase, + List operations, + Version templateVersion, + List compatibilityVersion, + Map userInputs, + Map workflows + ) { + this.name = name; + this.description = description; + this.useCase = useCase; + this.operations = List.copyOf(operations); + this.templateVersion = templateVersion; + this.compatibilityVersion = List.copyOf(compatibilityVersion); + this.userInputs = Map.copyOf(userInputs); + this.workflows = Map.copyOf(workflows); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + XContentBuilder xContentBuilder = builder.startObject(); + xContentBuilder.field(NAME_FIELD, this.name); + xContentBuilder.field(DESCRIPTION_FIELD, this.description); + xContentBuilder.field(USE_CASE_FIELD, this.useCase); + xContentBuilder.startArray(OPERATIONS_FIELD); + for (String op : this.operations) { + xContentBuilder.value(op); + } + xContentBuilder.endArray(); + + if (this.templateVersion != null || !this.compatibilityVersion.isEmpty()) { + xContentBuilder.startObject(VERSION_FIELD); + if (this.templateVersion != null) { + xContentBuilder.field(TEMPLATE_FIELD, this.templateVersion); + } + if (!this.compatibilityVersion.isEmpty()) { + xContentBuilder.startArray(COMPATIBILITY_FIELD); + for (Version v : this.compatibilityVersion) { + xContentBuilder.value(v); + } + xContentBuilder.endArray(); + } + xContentBuilder.endObject(); + } + + if (!this.userInputs.isEmpty()) { + xContentBuilder.startObject(USER_INPUTS_FIELD); + for (Entry e : userInputs.entrySet()) { + xContentBuilder.field(e.getKey(), e.getValue()); + } + xContentBuilder.endObject(); + } + + xContentBuilder.startObject(WORKFLOWS_FIELD); + for (Entry e : workflows.entrySet()) { + xContentBuilder.field(e.getKey(), e.getValue(), params); + } + xContentBuilder.endObject(); + + return xContentBuilder.endObject(); + } + + /** + * Parse raw json content into a Template instance. + * + * @param parser json based content parser + * @return an instance of the template + * @throws IOException if content can't be parsed correctly + */ + public static Template parse(XContentParser parser) throws IOException { + String name = null; + String description = ""; + String useCase = ""; + List operations = new ArrayList<>(); + Version templateVersion = null; + List compatibilityVersion = new ArrayList<>(); + Map userInputs = new HashMap<>(); + Map workflows = new HashMap<>(); + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + switch (fieldName) { + case NAME_FIELD: + name = parser.text(); + break; + case DESCRIPTION_FIELD: + description = parser.text(); + break; + case USE_CASE_FIELD: + useCase = parser.text(); + break; + case OPERATIONS_FIELD: + ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_ARRAY) { + operations.add(parser.text()); + } + break; + case VERSION_FIELD: + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String versionFieldName = parser.currentName(); + parser.nextToken(); + switch (versionFieldName) { + case TEMPLATE_FIELD: + templateVersion = Version.fromString(parser.text()); + break; + case COMPATIBILITY_FIELD: + ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_ARRAY) { + compatibilityVersion.add(Version.fromString(parser.text())); + } + break; + default: + throw new IOException("Unable to parse field [" + fieldName + "] in a version object."); + } + } + break; + case USER_INPUTS_FIELD: + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String inputFieldName = parser.currentName(); + switch (parser.nextToken()) { + case VALUE_STRING: + userInputs.put(inputFieldName, parser.text()); + break; + case START_OBJECT: + userInputs.put(inputFieldName, parseStringToStringMap(parser)); + break; + default: + throw new IOException("Unable to parse field [" + inputFieldName + "] in a user inputs object."); + } + } + break; + case WORKFLOWS_FIELD: + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String workflowFieldName = parser.currentName(); + parser.nextToken(); + workflows.put(workflowFieldName, Workflow.parse(parser)); + } + break; + + default: + throw new IOException("Unable to parse field [" + fieldName + "] in a template object."); + } + } + if (name == null) { + throw new IOException("An template object requires a name."); + } + + return new Template(name, description, useCase, operations, templateVersion, compatibilityVersion, userInputs, workflows); + } + + /** + * Parse a JSON use case template + * + * @param json A string containing a JSON representation of a use case template + * @return A {@link Template} represented by the JSON. + * @throws IOException on failure to parse + */ + public static Template parse(String json) throws IOException { + XContentParser parser = JsonXContent.jsonXContent.createParser( + NamedXContentRegistry.EMPTY, + LoggingDeprecationHandler.INSTANCE, + json + ); + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + return parse(parser); + } + + /** + * Builds an XContent object representing a map of String keys to String values. + * + * @param xContentBuilder An XContent builder whose position is at the start of the map object to build + * @param map A map as key-value String pairs. + * @throws IOException on a build failure + */ + public static void buildStringToStringMap(XContentBuilder xContentBuilder, Map map) throws IOException { + xContentBuilder.startObject(); + for (Entry e : map.entrySet()) { + xContentBuilder.field((String) e.getKey(), (String) e.getValue()); + } + xContentBuilder.endObject(); + } + + /** + * Parses an XContent object representing a map of String keys to String values. + * + * @param parser An XContent parser whose position is at the start of the map object to parse + * @return A map as identified by the key-value pairs in the XContent + * @throws IOException on a parse failure + */ + public static Map parseStringToStringMap(XContentParser parser) throws IOException { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + Map map = new HashMap<>(); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + map.put(fieldName, parser.text()); + } + return map; + } + + /** + * Output this object in a compact JSON string. + * + * @return a JSON representation of the template. + */ + public String toJson() { + try { + XContentBuilder builder = JsonXContent.contentBuilder(); + return this.toXContent(builder, EMPTY_PARAMS).toString(); + } catch (IOException e) { + return "{\"error\": \"couldn't create JSON: " + e.getMessage() + "\"}"; + } + } + + /** + * Output this object in YAML. + * + * @return a YAML representation of the template. + */ + public String toYaml() { + try { + XContentBuilder builder = YamlXContent.contentBuilder(); + return this.toXContent(builder, EMPTY_PARAMS).toString(); + } catch (IOException e) { + return "error: couldn't create YAML: " + e.getMessage(); + } + } + + /** + * The name of this template + * @return the name + */ + public String name() { + return name; + } + + /** + * A description of what this template does + * @return the description + */ + public String description() { + return description; + } + + /** + * A canonical use case name for this template + * @return the useCase + */ + public String useCase() { + return useCase; + } + + /** + * Operations this use case supports + * @return the operations + */ + public List operations() { + return operations; + } + + /** + * The version of this template + * @return the templateVersion + */ + public Version templateVersion() { + return templateVersion; + } + + /** + * OpenSearch version compatibility of this template + * @return the compatibilityVersion + */ + public List compatibilityVersion() { + return compatibilityVersion; + } + + /** + * A map of user inputs + * @return the userInputs + */ + public Map userInputs() { + return userInputs; + } + + /** + * Workflows encoded in this template, generally corresponding to the operations returned by {@link #operations()}. + * @return the workflows + */ + public Map workflows() { + return workflows; + } + + @Override + public String toString() { + return "Template [name=" + + name + + ", description=" + + description + + ", useCase=" + + useCase + + ", operations=" + + operations + + ", templateVersion=" + + templateVersion + + ", compatibilityVersion=" + + compatibilityVersion + + ", userInputs=" + + userInputs + + ", workflows=" + + workflows + + "]"; + } +} diff --git a/src/main/java/org/opensearch/flowframework/model/Workflow.java b/src/main/java/org/opensearch/flowframework/model/Workflow.java new file mode 100644 index 000000000..81f2677a7 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/model/Workflow.java @@ -0,0 +1,160 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.model; + +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.flowframework.workflow.WorkflowData; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +/** + * This represents an object in the workflows section of a {@link Template}. + */ +public class Workflow implements ToXContentObject { + + /** The template field name for workflow user params */ + public static final String USER_PARAMS_FIELD = "user_params"; + /** The template field name for workflow nodes */ + public static final String NODES_FIELD = "nodes"; + /** The template field name for workflow edges */ + public static final String EDGES_FIELD = "edges"; + + private final Map userParams; + private final List nodes; + private final List edges; + + /** + * Create this workflow with any user params and the graph of nodes and edges. + * + * @param userParams A map of user params. + * @param nodes An array of {@link WorkflowNode} objects + * @param edges An array of {@link WorkflowEdge} objects. + */ + public Workflow(Map userParams, List nodes, List edges) { + this.userParams = Map.copyOf(userParams); + this.nodes = List.copyOf(nodes); + this.edges = List.copyOf(edges); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + XContentBuilder xContentBuilder = builder.startObject(); + + xContentBuilder.startObject(USER_PARAMS_FIELD); + for (Entry e : userParams.entrySet()) { + xContentBuilder.field(e.getKey(), e.getValue()); + } + xContentBuilder.endObject(); + + xContentBuilder.startArray(NODES_FIELD); + for (WorkflowNode n : nodes) { + xContentBuilder.value(n); + } + xContentBuilder.endArray(); + + xContentBuilder.startArray(EDGES_FIELD); + for (WorkflowEdge e : edges) { + xContentBuilder.value(e); + } + xContentBuilder.endArray(); + + return xContentBuilder.endObject(); + } + + /** + * Parse raw JSON content into a workflow instance. + * + * @param parser JSON based content parser + * @return the parsed Workflow instance + * @throws IOException if content can't be parsed correctly + */ + public static Workflow parse(XContentParser parser) throws IOException { + Map userParams = new HashMap<>(); + List nodes = new ArrayList<>(); + List edges = new ArrayList<>(); + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + switch (fieldName) { + case USER_PARAMS_FIELD: + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String userParamFieldName = parser.currentName(); + parser.nextToken(); + userParams.put(userParamFieldName, parser.text()); + } + break; + case NODES_FIELD: + ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_ARRAY) { + nodes.add(WorkflowNode.parse(parser)); + } + break; + case EDGES_FIELD: + ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_ARRAY) { + edges.add(WorkflowEdge.parse(parser)); + } + break; + } + + } + if (nodes.isEmpty()) { + throw new IOException("A workflow must have at least one node."); + } + if (edges.isEmpty()) { + // infer edges from sequence of nodes + // Start iteration at 1, will skip for a one-node array + for (int i = 1; i < nodes.size(); i++) { + edges.add(new WorkflowEdge(nodes.get(i - 1).id(), nodes.get(i).id())); + } + } + return new Workflow(userParams, nodes, edges); + } + + /** + * Get user parameters. These will be passed to all workflow nodes and available as {@link WorkflowData#getParams()} + * @return the userParams + */ + public Map userParams() { + return userParams; + } + + /** + * Get the nodes in the workflow. Ordering matches the user template which may or may not match execution order. + * @return the nodes + */ + public List nodes() { + return nodes; + } + + /** + * Get the edges in the workflow. These specify connections of nodes which form a graph. + * @return the edges + */ + public List edges() { + return edges; + } + + @Override + public String toString() { + return "Workflow [userParams=" + userParams + ", nodes=" + nodes + ", edges=" + edges + "]"; + } +} diff --git a/src/main/java/org/opensearch/flowframework/model/WorkflowEdge.java b/src/main/java/org/opensearch/flowframework/model/WorkflowEdge.java new file mode 100644 index 000000000..7fbdaf568 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/model/WorkflowEdge.java @@ -0,0 +1,121 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.model; + +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Objects; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +/** + * This represents an edge between process nodes (steps) in a workflow graph in the {@link Template}. + */ +public class WorkflowEdge implements ToXContentObject { + + /** The template field name for source node */ + public static final String SOURCE_FIELD = "source"; + /** The template field name for destination node */ + public static final String DEST_FIELD = "dest"; + + private final String source; + private final String destination; + + /** + * Create this edge with the id's of the source and destination nodes. + * + * @param source The source node id. + * @param destination The destination node id. + */ + public WorkflowEdge(String source, String destination) { + this.source = source; + this.destination = destination; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + XContentBuilder xContentBuilder = builder.startObject(); + xContentBuilder.field(SOURCE_FIELD, this.source); + xContentBuilder.field(DEST_FIELD, this.destination); + return xContentBuilder.endObject(); + } + + /** + * Parse raw json content into a workflow edge instance. + * + * @param parser json based content parser + * @return the parsed WorkflowEdge instance + * @throws IOException if content can't be parsed correctly + */ + public static WorkflowEdge parse(XContentParser parser) throws IOException { + String source = null; + String destination = null; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + switch (fieldName) { + case SOURCE_FIELD: + source = parser.text(); + break; + case DEST_FIELD: + destination = parser.text(); + break; + default: + throw new IOException("Unable to parse field [" + fieldName + "] in an edge object."); + } + } + if (source == null || destination == null) { + throw new IOException("An edge object requires both a source and dest field."); + } + + return new WorkflowEdge(source, destination); + } + + /** + * Gets the source node id. + * + * @return the source node id. + */ + public String source() { + return source; + } + + /** + * Gets the destination node id. + * + * @return the destination node id. + */ + public String destination() { + return destination; + } + + @Override + public int hashCode() { + return Objects.hash(destination, source); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) return true; + if (obj == null) return false; + if (getClass() != obj.getClass()) return false; + WorkflowEdge other = (WorkflowEdge) obj; + return Objects.equals(destination, other.destination) && Objects.equals(source, other.source); + } + + @Override + public String toString() { + return this.source + "->" + this.destination; + } +} diff --git a/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java b/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java new file mode 100644 index 000000000..b48b6e0d2 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java @@ -0,0 +1,201 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.model; + +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.flowframework.workflow.ProcessNode; +import org.opensearch.flowframework.workflow.WorkflowData; +import org.opensearch.flowframework.workflow.WorkflowStep; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Objects; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +/** + * This represents a process node (step) in a workflow graph in the {@link Template}. + * It will have a one-to-one correspondence with a {@link ProcessNode}, + * where its type is used to determine the correct {@link WorkflowStep} object, + * and its inputs are used to populate the {@link WorkflowData} input. + */ +public class WorkflowNode implements ToXContentObject { + + /** The template field name for node id */ + public static final String ID_FIELD = "id"; + /** The template field name for node type */ + public static final String TYPE_FIELD = "type"; + /** The template field name for node inputs */ + public static final String INPUTS_FIELD = "inputs"; + /** The field defining processors in the inputs for search and ingest pipelines */ + public static final String PROCESSORS_FIELD = "processors"; + + private final String id; // unique id + private final String type; // maps to a WorkflowStep + private final Map inputs; // maps to WorkflowData + + /** + * Create this node with the id and type, and any user input. + * + * @param id A unique string identifying this node + * @param type The type of {@link WorkflowStep} to create for the corresponding {@link ProcessNode} + * @param inputs Optional input to populate params in {@link WorkflowData} + */ + public WorkflowNode(String id, String type, Map inputs) { + this.id = id; + this.type = type; + this.inputs = Map.copyOf(inputs); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + XContentBuilder xContentBuilder = builder.startObject(); + xContentBuilder.field(ID_FIELD, this.id); + xContentBuilder.field(TYPE_FIELD, this.type); + + xContentBuilder.startObject(INPUTS_FIELD); + for (Entry e : inputs.entrySet()) { + xContentBuilder.field(e.getKey()); + if (e.getValue() instanceof String) { + xContentBuilder.value(e.getValue()); + } else if (e.getValue() instanceof Map) { + Template.buildStringToStringMap(xContentBuilder, (Map) e.getValue()); + } else if (e.getValue() instanceof Object[]) { + xContentBuilder.startArray(); + if (PROCESSORS_FIELD.equals(e.getKey())) { + for (PipelineProcessor p : (PipelineProcessor[]) e.getValue()) { + xContentBuilder.value(p); + } + } else { + for (Map map : (Map[]) e.getValue()) { + Template.buildStringToStringMap(xContentBuilder, map); + } + } + xContentBuilder.endArray(); + } + } + xContentBuilder.endObject(); + + return xContentBuilder.endObject(); + } + + /** + * Parse raw json content into a workflow node instance. + * + * @param parser json based content parser + * @return the parsed WorkflowNode instance + * @throws IOException if content can't be parsed correctly + */ + public static WorkflowNode parse(XContentParser parser) throws IOException { + String id = null; + String type = null; + Map inputs = new HashMap<>(); + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + switch (fieldName) { + case ID_FIELD: + id = parser.text(); + break; + case TYPE_FIELD: + type = parser.text(); + break; + case INPUTS_FIELD: + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String inputFieldName = parser.currentName(); + switch (parser.nextToken()) { + case VALUE_STRING: + inputs.put(inputFieldName, parser.text()); + break; + case START_OBJECT: + inputs.put(inputFieldName, Template.parseStringToStringMap(parser)); + break; + case START_ARRAY: + if (PROCESSORS_FIELD.equals(inputFieldName)) { + List processorList = new ArrayList<>(); + while (parser.nextToken() != XContentParser.Token.END_ARRAY) { + processorList.add(PipelineProcessor.parse(parser)); + } + inputs.put(inputFieldName, processorList.toArray(new PipelineProcessor[0])); + } else { + List> mapList = new ArrayList<>(); + while (parser.nextToken() != XContentParser.Token.END_ARRAY) { + mapList.add(Template.parseStringToStringMap(parser)); + } + inputs.put(inputFieldName, mapList.toArray(new Map[0])); + } + break; + default: + throw new IOException("Unable to parse field [" + inputFieldName + "] in a node object."); + } + } + break; + default: + throw new IOException("Unable to parse field [" + fieldName + "] in a node object."); + } + } + if (id == null || type == null) { + throw new IOException("An node object requires both an id and type field."); + } + + return new WorkflowNode(id, type, inputs); + } + + /** + * Return this node's id + * @return the id + */ + public String id() { + return id; + } + + /** + * Return this node's type + * @return the type + */ + public String type() { + return type; + } + + /** + * Return this node's input data + * @return the inputs + */ + public Map inputs() { + return inputs; + } + + @Override + public int hashCode() { + return Objects.hash(id); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) return true; + if (obj == null) return false; + if (getClass() != obj.getClass()) return false; + WorkflowNode other = (WorkflowNode) obj; + return Objects.equals(id, other.id); + } + + @Override + public String toString() { + return this.id; + } +} diff --git a/src/main/java/org/opensearch/flowframework/template/ProcessSequenceEdge.java b/src/main/java/org/opensearch/flowframework/template/ProcessSequenceEdge.java deleted file mode 100644 index 9544620fb..000000000 --- a/src/main/java/org/opensearch/flowframework/template/ProcessSequenceEdge.java +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - */ -package org.opensearch.flowframework.template; - -import java.util.Objects; - -/** - * Representation of an edge between process nodes in a workflow graph. - */ -public class ProcessSequenceEdge { - private final String source; - private final String destination; - - /** - * Create this edge with the id's of the source and destination nodes. - * - * @param source The source node id. - * @param destination The destination node id. - */ - ProcessSequenceEdge(String source, String destination) { - this.source = source; - this.destination = destination; - } - - /** - * Gets the source node id. - * - * @return the source node id. - */ - public String getSource() { - return source; - } - - /** - * Gets the destination node id. - * - * @return the destination node id. - */ - public String getDestination() { - return destination; - } - - @Override - public int hashCode() { - return Objects.hash(destination, source); - } - - @Override - public boolean equals(Object obj) { - if (this == obj) return true; - if (obj == null) return false; - if (getClass() != obj.getClass()) return false; - ProcessSequenceEdge other = (ProcessSequenceEdge) obj; - return Objects.equals(destination, other.destination) && Objects.equals(source, other.source); - } - - @Override - public String toString() { - return this.source + "->" + this.destination; - } -} diff --git a/src/main/java/org/opensearch/flowframework/template/TemplateParser.java b/src/main/java/org/opensearch/flowframework/template/TemplateParser.java deleted file mode 100644 index 56635f1b4..000000000 --- a/src/main/java/org/opensearch/flowframework/template/TemplateParser.java +++ /dev/null @@ -1,154 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - */ -package org.opensearch.flowframework.template; - -import com.google.gson.Gson; -import com.google.gson.JsonElement; -import com.google.gson.JsonObject; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.opensearch.action.admin.indices.create.CreateIndexRequest; -import org.opensearch.flowframework.workflow.WorkflowData; -import org.opensearch.flowframework.workflow.WorkflowStep; - -import java.util.ArrayDeque; -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashMap; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Queue; -import java.util.Set; -import java.util.function.Function; -import java.util.stream.Collectors; - -/** - * Utility class for parsing templates. - */ -public class TemplateParser { - - private static final Logger logger = LogManager.getLogger(TemplateParser.class); - - // Field names in the JSON. Package private for tests. - static final String WORKFLOW = "sequence"; - static final String NODES = "nodes"; - static final String NODE_ID = "id"; - static final String EDGES = "edges"; - static final String SOURCE = "source"; - static final String DESTINATION = "dest"; - - /** - * Prevent instantiating this class. - */ - private TemplateParser() {} - - /** - * Parse a JSON representation of nodes and edges into a topologically sorted list of process nodes. - * @param json A string containing a JSON representation of nodes and edges - * @param workflowSteps A map linking JSON node names to Java objects implementing {@link WorkflowStep} - * @return A list of Process Nodes sorted topologically. All predecessors of any node will occur prior to it in the list. - */ - public static List parseJsonGraphToSequence(String json, Map workflowSteps) { - Gson gson = new Gson(); - JsonObject jsonObject = gson.fromJson(json, JsonObject.class); - - JsonObject graph = jsonObject.getAsJsonObject(WORKFLOW); - - List nodes = new ArrayList<>(); - List edges = new ArrayList<>(); - - for (JsonElement nodeJson : graph.getAsJsonArray(NODES)) { - JsonObject nodeObject = nodeJson.getAsJsonObject(); - String nodeId = nodeObject.get(NODE_ID).getAsString(); - // The below steps will be replaced by a generator class that instantiates a WorkflowStep - // based on user_input data from the template. - // https://github.com/opensearch-project/opensearch-ai-flow-framework/issues/41 - WorkflowStep workflowStep = workflowSteps.get(nodeId); - // temporary demo POC of getting from a request to input data - // this will be refactored into something pulling from user template as part of the above issue - WorkflowData inputData = WorkflowData.EMPTY; - if (List.of("create_index", "create_another_index").contains(nodeId)) { - CreateIndexRequest request = new CreateIndexRequest(nodeObject.get("index_name").getAsString()); - inputData = new WorkflowData( - Map.of("mappings", request.mappings(), "settings", request.settings(), "aliases", request.aliases()), - Map.of("index", request.index()) - ); - } - nodes.add(new ProcessNode(nodeId, workflowStep, inputData)); - } - - for (JsonElement edgeJson : graph.getAsJsonArray(EDGES)) { - JsonObject edgeObject = edgeJson.getAsJsonObject(); - String sourceNodeId = edgeObject.get(SOURCE).getAsString(); - String destNodeId = edgeObject.get(DESTINATION).getAsString(); - if (sourceNodeId.equals(destNodeId)) { - throw new IllegalArgumentException("Edge connects node " + sourceNodeId + " to itself."); - } - edges.add(new ProcessSequenceEdge(sourceNodeId, destNodeId)); - } - - return topologicalSort(nodes, edges); - } - - private static List topologicalSort(List nodes, List edges) { - // Define the graph - Set graph = new HashSet<>(edges); - // Map node id string to object - Map nodeMap = nodes.stream().collect(Collectors.toMap(ProcessNode::id, Function.identity())); - // Build predecessor and successor maps - Map> predecessorEdges = new HashMap<>(); - Map> successorEdges = new HashMap<>(); - for (ProcessSequenceEdge edge : edges) { - ProcessNode source = nodeMap.get(edge.getSource()); - ProcessNode dest = nodeMap.get(edge.getDestination()); - predecessorEdges.computeIfAbsent(dest, k -> new HashSet<>()).add(edge); - successorEdges.computeIfAbsent(source, k -> new HashSet<>()).add(edge); - } - // update predecessors on the node object - nodes.stream().filter(n -> predecessorEdges.containsKey(n)).forEach(n -> { - n.setPredecessors(predecessorEdges.get(n).stream().map(e -> nodeMap.get(e.getSource())).collect(Collectors.toSet())); - }); - - // See https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm - // L <- Empty list that will contain the sorted elements - List sortedNodes = new ArrayList<>(); - // S <- Set of all nodes with no incoming edge - Queue sourceNodes = new ArrayDeque<>(); - nodes.stream().filter(n -> !predecessorEdges.containsKey(n)).forEach(n -> sourceNodes.add(n)); - if (sourceNodes.isEmpty()) { - throw new IllegalArgumentException("No start node detected: all nodes have a predecessor."); - } - logger.debug("Start node(s): {}", sourceNodes); - - // while S is not empty do - while (!sourceNodes.isEmpty()) { - // remove a node n from S - ProcessNode n = sourceNodes.poll(); - // add n to L - sortedNodes.add(n); - // for each node m with an edge e from n to m do - for (ProcessSequenceEdge e : successorEdges.getOrDefault(n, Collections.emptySet())) { - ProcessNode m = nodeMap.get(e.getDestination()); - // remove edge e from the graph - graph.remove(e); - // if m has no other incoming edges then - if (!predecessorEdges.get(m).stream().anyMatch(i -> graph.contains(i))) { - // insert m into S - sourceNodes.add(m); - } - } - } - if (!graph.isEmpty()) { - throw new IllegalArgumentException("Cycle detected: " + graph); - } - logger.debug("Execution sequence: {}", sortedNodes); - return sortedNodes; - } -} diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateIndex/CreateIndexStep.java b/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java similarity index 91% rename from src/main/java/org/opensearch/flowframework/workflow/CreateIndex/CreateIndexStep.java rename to src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java index 7f92b8057..1f0d074c2 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/CreateIndex/CreateIndexStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java @@ -6,7 +6,7 @@ * this file be licensed under the Apache-2.0 license or a * compatible open source license. */ -package org.opensearch.flowframework.workflow.CreateIndex; +package org.opensearch.flowframework.workflow; import com.google.common.base.Charsets; import com.google.common.io.Resources; @@ -16,10 +16,8 @@ import org.opensearch.action.admin.indices.create.CreateIndexResponse; import org.opensearch.client.Client; import org.opensearch.common.settings.Settings; -import org.opensearch.common.xcontent.XContentType; +import org.opensearch.common.xcontent.json.JsonXContent; import org.opensearch.core.action.ActionListener; -import org.opensearch.flowframework.workflow.WorkflowData; -import org.opensearch.flowframework.workflow.WorkflowStep; import java.io.IOException; import java.net.URL; @@ -34,7 +32,9 @@ public class CreateIndexStep implements WorkflowStep { private static final Logger logger = LogManager.getLogger(CreateIndexStep.class); private Client client; - private final String NAME = "create_index_step"; + + /** The name of this step, used as a key in the template and the {@link WorkflowStepFactory} */ + static final String NAME = "create_index"; /** * Instantiate this class @@ -81,7 +81,7 @@ public void onFailure(Exception e) { try { CreateIndexRequest request = new CreateIndexRequest(index).mapping( getIndexMappings("mappings/" + type + ".json"), - XContentType.JSON + JsonXContent.jsonXContent.mediaType() ); client.admin().indices().create(request, actionListener); } catch (Exception e) { diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStep.java b/src/main/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStep.java index 8382925b2..4770b94a9 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStep.java @@ -32,7 +32,9 @@ public class CreateIngestPipelineStep implements WorkflowStep { private static final Logger logger = LogManager.getLogger(CreateIngestPipelineStep.class); - private static final String NAME = "create_ingest_pipeline_step"; + + /** The name of this step, used as a key in the template and the {@link WorkflowStepFactory} */ + static final String NAME = "create_ingest_pipeline"; // Common pipeline configuration fields private static final String PIPELINE_ID_FIELD = "id"; diff --git a/src/main/java/org/opensearch/flowframework/template/ProcessNode.java b/src/main/java/org/opensearch/flowframework/workflow/ProcessNode.java similarity index 70% rename from src/main/java/org/opensearch/flowframework/template/ProcessNode.java rename to src/main/java/org/opensearch/flowframework/workflow/ProcessNode.java index 08a7ec841..a2d7628c3 100644 --- a/src/main/java/org/opensearch/flowframework/template/ProcessNode.java +++ b/src/main/java/org/opensearch/flowframework/workflow/ProcessNode.java @@ -6,20 +6,16 @@ * this file be licensed under the Apache-2.0 license or a * compatible open source license. */ -package org.opensearch.flowframework.template; +package org.opensearch.flowframework.workflow; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.opensearch.flowframework.workflow.WorkflowData; -import org.opensearch.flowframework.workflow.WorkflowStep; import java.util.ArrayList; -import java.util.Collections; import java.util.List; -import java.util.Objects; -import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; +import java.util.concurrent.Executor; import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; @@ -33,32 +29,26 @@ public class ProcessNode { private final String id; private final WorkflowStep workflowStep; private final WorkflowData input; - private CompletableFuture future = null; + private final List predecessors; + private Executor executor; - // will be populated during graph parsing - private Set predecessors = Collections.emptySet(); + private final CompletableFuture future = new CompletableFuture<>(); /** - * Create this node linked to its executing process. + * Create this node linked to its executing process, including input data and any predecessor nodes. * * @param id A string identifying the workflow step * @param workflowStep A java class implementing {@link WorkflowStep} to be executed when it's this node's turn. + * @param input Input required by the node encoded in a {@link WorkflowData} instance. + * @param predecessors Nodes preceding this one in the workflow + * @param executor The OpenSearch thread pool */ - ProcessNode(String id, WorkflowStep workflowStep) { - this(id, workflowStep, WorkflowData.EMPTY); - } - - /** - * Create this node linked to its executing process. - * - * @param id A string identifying the workflow step - * @param workflowStep A java class implementing {@link WorkflowStep} to be executed when it's this node's turn. - * @param input Input required by the node - */ - public ProcessNode(String id, WorkflowStep workflowStep, WorkflowData input) { + public ProcessNode(String id, WorkflowStep workflowStep, WorkflowData input, List predecessors, Executor executor) { this.id = id; this.workflowStep = workflowStep; this.input = input; + this.predecessors = predecessors; + this.executor = executor; } /** @@ -92,41 +82,31 @@ public WorkflowData input() { * @return A future indicating the processing state of this node. * Returns {@code null} if it has not begun executing, should not happen if a workflow is sorted and executed topologically. */ - public CompletableFuture getFuture() { + public CompletableFuture future() { return future; } /** * Returns the predecessors of this node in the workflow. - * The predecessor's {@link #getFuture()} must complete before execution begins on this node. + * The predecessor's {@link #future()} must complete before execution begins on this node. * * @return a set of predecessor nodes, if any. At least one node in the graph must have no predecessors and serve as a start node. */ - public Set getPredecessors() { + public List predecessors() { return predecessors; } - /** - * Sets the predecessor node. Called by {@link TemplateParser}. - * - * @param predecessors The predecessors of this node. - */ - void setPredecessors(Set predecessors) { - this.predecessors = Set.copyOf(predecessors); - } - /** * Execute this node in the sequence. Initializes the node's {@link CompletableFuture} and completes it when the process completes. * * @return this node's future. This is returned immediately, while process execution continues asynchronously. */ public CompletableFuture execute() { - this.future = new CompletableFuture<>(); // TODO this class will be instantiated with the OpenSearch thread pool (or one for tests!) // the generic executor from that pool should be passed to this runAsync call // https://github.com/opensearch-project/opensearch-ai-flow-framework/issues/42 CompletableFuture.runAsync(() -> { - List> predFutures = predecessors.stream().map(p -> p.getFuture()).collect(Collectors.toList()); + List> predFutures = predecessors.stream().map(p -> p.future()).collect(Collectors.toList()); if (!predecessors.isEmpty()) { CompletableFuture waitForPredecessors = CompletableFuture.allOf(predFutures.toArray(new CompletableFuture[0])); try { @@ -152,34 +132,20 @@ public CompletableFuture execute() { } CompletableFuture stepFuture = this.workflowStep.execute(input); try { - stepFuture.join(); + stepFuture.orTimeout(15, TimeUnit.SECONDS).join(); + logger.info(">>> Finished {}.", this.id); future.complete(stepFuture.get()); - logger.debug("<<< Completed {}", this.id); } catch (InterruptedException | ExecutionException e) { handleException(e); } - }); + }, executor); return this.future; } private void handleException(Exception e) { // TODO: better handling of getCause this.future.completeExceptionally(e); - logger.debug("<<< Completed Exceptionally {}", this.id); - } - - @Override - public int hashCode() { - return Objects.hash(id); - } - - @Override - public boolean equals(Object obj) { - if (this == obj) return true; - if (obj == null) return false; - if (getClass() != obj.getClass()) return false; - ProcessNode other = (ProcessNode) obj; - return Objects.equals(id, other.id); + logger.debug("<<< Completed Exceptionally {}", this.id, e.getCause()); } @Override diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java new file mode 100644 index 000000000..3370f6384 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java @@ -0,0 +1,166 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.workflow; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.flowframework.model.Workflow; +import org.opensearch.flowframework.model.WorkflowEdge; +import org.opensearch.flowframework.model.WorkflowNode; + +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Queue; +import java.util.Set; +import java.util.concurrent.Executor; +import java.util.function.Function; +import java.util.stream.Collectors; + +/** + * Utility class converting a workflow of nodes and edges into a topologically sorted list of Process Nodes. + */ +public class WorkflowProcessSorter { + + private static final Logger logger = LogManager.getLogger(WorkflowProcessSorter.class); + + private static WorkflowProcessSorter instance = null; + + private WorkflowStepFactory workflowStepFactory; + private Executor executor; + + /** + * Create the singleton instance of this class. Throws an {@link IllegalStateException} if already created. + * + * @param workflowStepFactory The singleton instance of {@link WorkflowStepFactory} + * @param executor A thread executor + * @return The created instance + */ + public static synchronized WorkflowProcessSorter create(WorkflowStepFactory workflowStepFactory, Executor executor) { + if (instance != null) { + throw new IllegalStateException("This class was already created."); + } + instance = new WorkflowProcessSorter(workflowStepFactory, executor); + return instance; + } + + /** + * Gets the singleton instance of this class. Throws an {@link IllegalStateException} if not yet created. + * + * @return The created instance + */ + public static synchronized WorkflowProcessSorter get() { + if (instance == null) { + throw new IllegalStateException("This factory has not yet been created."); + } + return instance; + } + + private WorkflowProcessSorter(WorkflowStepFactory workflowStepFactory, Executor executor) { + this.workflowStepFactory = workflowStepFactory; + this.executor = executor; + } + + /** + * Sort a workflow into a topologically sorted list of process nodes. + * @param workflow A workflow with (unsorted) nodes and edges which define predecessors and successors + * @return A list of Process Nodes sorted topologically. All predecessors of any node will occur prior to it in the list. + */ + public List sortProcessNodes(Workflow workflow) { + List sortedNodes = topologicalSort(workflow.nodes(), workflow.edges()); + + List nodes = new ArrayList<>(); + Map idToNodeMap = new HashMap<>(); + for (WorkflowNode node : sortedNodes) { + WorkflowStep step = workflowStepFactory.createStep(node.type()); + WorkflowData data = new WorkflowData(node.inputs(), workflow.userParams()); + List predecessorNodes = workflow.edges() + .stream() + .filter(e -> e.destination().equals(node.id())) + // since we are iterating in topological order we know all predecessors will be in the map + .map(e -> idToNodeMap.get(e.source())) + .collect(Collectors.toList()); + + ProcessNode processNode = new ProcessNode(node.id(), step, data, predecessorNodes, executor); + idToNodeMap.put(processNode.id(), processNode); + nodes.add(processNode); + } + + return nodes; + } + + private static List topologicalSort(List workflowNodes, List workflowEdges) { + // Basic validation + Set nodeIds = workflowNodes.stream().map(n -> n.id()).collect(Collectors.toSet()); + for (WorkflowEdge edge : workflowEdges) { + String source = edge.source(); + if (!nodeIds.contains(source)) { + throw new IllegalArgumentException("Edge source " + source + " does not correspond to a node."); + } + String dest = edge.destination(); + if (!nodeIds.contains(dest)) { + throw new IllegalArgumentException("Edge destination " + dest + " does not correspond to a node."); + } + if (source.equals(dest)) { + throw new IllegalArgumentException("Edge connects node " + source + " to itself."); + } + } + + // Build predecessor and successor maps + Map> predecessorEdges = new HashMap<>(); + Map> successorEdges = new HashMap<>(); + Map nodeMap = workflowNodes.stream().collect(Collectors.toMap(WorkflowNode::id, Function.identity())); + for (WorkflowEdge edge : workflowEdges) { + WorkflowNode source = nodeMap.get(edge.source()); + WorkflowNode dest = nodeMap.get(edge.destination()); + predecessorEdges.computeIfAbsent(dest, k -> new HashSet<>()).add(edge); + successorEdges.computeIfAbsent(source, k -> new HashSet<>()).add(edge); + } + + // See https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm + Set graph = new HashSet<>(workflowEdges); + // L <- Empty list that will contain the sorted elements + List sortedNodes = new ArrayList<>(); + // S <- Set of all nodes with no incoming edge + Queue sourceNodes = new ArrayDeque<>(); + workflowNodes.stream().filter(n -> !predecessorEdges.containsKey(n)).forEach(n -> sourceNodes.add(n)); + if (sourceNodes.isEmpty()) { + throw new IllegalArgumentException("No start node detected: all nodes have a predecessor."); + } + logger.debug("Start node(s): {}", sourceNodes); + + // while S is not empty do + while (!sourceNodes.isEmpty()) { + // remove a node n from S + WorkflowNode n = sourceNodes.poll(); + // add n to L + sortedNodes.add(n); + // for each node m with an edge e from n to m do + for (WorkflowEdge e : successorEdges.getOrDefault(n, Collections.emptySet())) { + WorkflowNode m = nodeMap.get(e.destination()); + // remove edge e from the graph + graph.remove(e); + // if m has no other incoming edges then + if (!predecessorEdges.get(m).stream().anyMatch(i -> graph.contains(i))) { + // insert m into S + sourceNodes.add(m); + } + } + } + if (!graph.isEmpty()) { + throw new IllegalArgumentException("Cycle detected: " + graph); + } + logger.debug("Execution sequence: {}", sortedNodes); + return sortedNodes; + } +} diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java new file mode 100644 index 000000000..dc0dc29a2 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java @@ -0,0 +1,96 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.workflow; + +import org.opensearch.client.Client; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; + +import demo.DemoWorkflowStep; + +/** + * Generates instances implementing {@link WorkflowStep}. + */ +public class WorkflowStepFactory { + + private static WorkflowStepFactory instance = null; + + private final Map stepMap = new HashMap<>(); + + /** + * Create the singleton instance of this class. Throws an {@link IllegalStateException} if already created. + * + * @param client The OpenSearch client steps can use + * @return The created instance + */ + public static synchronized WorkflowStepFactory create(Client client) { + if (instance != null) { + throw new IllegalStateException("This factory was already created."); + } + instance = new WorkflowStepFactory(client); + return instance; + } + + /** + * Gets the singleton instance of this class. Throws an {@link IllegalStateException} if not yet created. + * + * @return The created instance + */ + public static synchronized WorkflowStepFactory get() { + if (instance == null) { + throw new IllegalStateException("This factory has not yet been created."); + } + return instance; + } + + private WorkflowStepFactory(Client client) { + populateMap(client); + } + + private void populateMap(Client client) { + stepMap.put(CreateIndexStep.NAME, new CreateIndexStep(client)); + stepMap.put(CreateIngestPipelineStep.NAME, new CreateIngestPipelineStep(client)); + + // TODO: These are from the demo class as placeholders, remove when demos are deleted + stepMap.put("demo_delay_3", new DemoWorkflowStep(3000)); + stepMap.put("demo_delay_5", new DemoWorkflowStep(3000)); + + // Use as a default until all the actual implementations are ready + stepMap.put("placeholder", new WorkflowStep() { + @Override + public CompletableFuture execute(List data) { + CompletableFuture future = new CompletableFuture<>(); + future.complete(WorkflowData.EMPTY); + return future; + } + + @Override + public String getName() { + return "placeholder"; + } + }); + } + + /** + * Create a new instance of a {@link WorkflowStep}. + * @param type The type of instance to create + * @return an instance of the specified type + */ + public WorkflowStep createStep(String type) { + if (stepMap.containsKey(type)) { + return stepMap.get(type); + } + // TODO: replace this with a FlowFrameworkException + // https://github.com/opensearch-project/opensearch-ai-flow-framework/pull/43 + return stepMap.get("placeholder"); + } +} diff --git a/src/test/java/org/opensearch/flowframework/model/PipelineProcessorTests.java b/src/test/java/org/opensearch/flowframework/model/PipelineProcessorTests.java new file mode 100644 index 000000000..5e9a81d0d --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/model/PipelineProcessorTests.java @@ -0,0 +1,43 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.model; + +import org.opensearch.test.OpenSearchTestCase; + +import java.io.IOException; +import java.util.Map; + +public class PipelineProcessorTests extends OpenSearchTestCase { + + public void testProcessor() throws IOException { + PipelineProcessor processor = new PipelineProcessor("foo", Map.of("bar", "baz")); + + assertEquals("foo", processor.type()); + assertEquals(Map.of("bar", "baz"), processor.params()); + + String expectedJson = "{\"type\":\"foo\",\"params\":{\"bar\":\"baz\"}}"; + String json = TemplateTestJsonUtil.parseToJson(processor); + assertEquals(expectedJson, json); + + PipelineProcessor processorX = PipelineProcessor.parse(TemplateTestJsonUtil.jsonToParser(json)); + assertEquals("foo", processorX.type()); + assertEquals(Map.of("bar", "baz"), processorX.params()); + } + + public void testExceptions() throws IOException { + String badJson = "{\"badField\":\"foo\",\"params\":{\"bar\":\"baz\"}}"; + IOException e = assertThrows(IOException.class, () -> PipelineProcessor.parse(TemplateTestJsonUtil.jsonToParser(badJson))); + assertEquals("Unable to parse field [badField] in a pipeline processor object.", e.getMessage()); + + String noTypeJson = "{\"params\":{\"bar\":\"baz\"}}"; + e = assertThrows(IOException.class, () -> PipelineProcessor.parse(TemplateTestJsonUtil.jsonToParser(noTypeJson))); + assertEquals("A processor object requires a type field.", e.getMessage()); + } + +} diff --git a/src/test/java/org/opensearch/flowframework/model/TemplateTestJsonUtil.java b/src/test/java/org/opensearch/flowframework/model/TemplateTestJsonUtil.java new file mode 100644 index 000000000..247521084 --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/model/TemplateTestJsonUtil.java @@ -0,0 +1,58 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.model; + +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.json.JsonXContent; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentParser; + +import java.io.IOException; +import java.util.List; +import java.util.stream.Collectors; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +/** + * Utility methods for tests of template JSON + */ +public class TemplateTestJsonUtil { + + public static String node(String id) { + return "{\"" + WorkflowNode.ID_FIELD + "\": \"" + id + "\", \"" + WorkflowNode.TYPE_FIELD + "\": \"" + "placeholder" + "\"}"; + } + + public static String edge(String sourceId, String destId) { + return "{\"" + WorkflowEdge.SOURCE_FIELD + "\": \"" + sourceId + "\", \"" + WorkflowEdge.DEST_FIELD + "\": \"" + destId + "\"}"; + } + + public static String workflow(List nodes, List edges) { + return "{\"workflow\": {" + arrayField(Workflow.NODES_FIELD, nodes) + ", " + arrayField(Workflow.EDGES_FIELD, edges) + "}}"; + } + + private static String arrayField(String fieldName, List objects) { + return "\"" + fieldName + "\": [" + objects.stream().collect(Collectors.joining(", ")) + "]"; + } + + public static String parseToJson(ToXContentObject object) throws IOException { + return object.toXContent(JsonXContent.contentBuilder(), ToXContent.EMPTY_PARAMS).toString(); + } + + public static XContentParser jsonToParser(String json) throws IOException { + XContentParser parser = JsonXContent.jsonXContent.createParser( + NamedXContentRegistry.EMPTY, + LoggingDeprecationHandler.INSTANCE, + json + ); + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + return parser; + } +} diff --git a/src/test/java/org/opensearch/flowframework/model/TemplateTests.java b/src/test/java/org/opensearch/flowframework/model/TemplateTests.java new file mode 100644 index 000000000..69f14dfaf --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/model/TemplateTests.java @@ -0,0 +1,117 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.model; + +import org.opensearch.Version; +import org.opensearch.test.OpenSearchTestCase; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +public class TemplateTests extends OpenSearchTestCase { + + private String expectedPrefix = + "{\"name\":\"test\",\"description\":\"a test template\",\"use_case\":\"test use case\",\"operations\":[\"operation\"]," + + "\"version\":{\"template\":\"1.2.3\",\"compatibility\":[\"4.5.6\",\"7.8.9\"]},\"user_inputs\":{"; + private String expectedKV1 = "\"userKey\":\"userValue\""; + private String expectedKV2 = "\"userMapKey\":{\"nestedKey\":\"nestedValue\"}"; + private String expectedSuffix = "},\"workflows\":{\"workflow\":{\"user_params\":{\"key\":\"value\"}," + + "\"nodes\":[{\"id\":\"A\",\"type\":\"a-type\",\"inputs\":{\"foo\":\"bar\"}}," + + "{\"id\":\"B\",\"type\":\"b-type\",\"inputs\":{\"baz\":\"qux\"}}],\"edges\":[{\"source\":\"A\",\"dest\":\"B\"}]}}}"; + + @Override + public void setUp() throws Exception { + super.setUp(); + } + + public void testTemplate() throws IOException { + Version templateVersion = Version.fromString("1.2.3"); + List compatibilityVersion = List.of(Version.fromString("4.5.6"), Version.fromString("7.8.9")); + + WorkflowNode nodeA = new WorkflowNode("A", "a-type", Map.of("foo", "bar")); + WorkflowNode nodeB = new WorkflowNode("B", "b-type", Map.of("baz", "qux")); + WorkflowEdge edgeAB = new WorkflowEdge("A", "B"); + List nodes = List.of(nodeA, nodeB); + List edges = List.of(edgeAB); + Workflow workflow = new Workflow(Map.of("key", "value"), nodes, edges); + + Template template = new Template( + "test", + "a test template", + "test use case", + List.of("operation"), + templateVersion, + compatibilityVersion, + Map.ofEntries(Map.entry("userKey", "userValue"), Map.entry("userMapKey", Map.of("nestedKey", "nestedValue"))), + Map.of("workflow", workflow) + ); + + assertEquals("test", template.name()); + assertEquals("a test template", template.description()); + assertEquals("test use case", template.useCase()); + assertEquals(List.of("operation"), template.operations()); + assertEquals(templateVersion, template.templateVersion()); + assertEquals(compatibilityVersion, template.compatibilityVersion()); + Map inputsMap = template.userInputs(); + assertEquals("userValue", inputsMap.get("userKey")); + assertEquals(Map.of("nestedKey", "nestedValue"), inputsMap.get("userMapKey")); + Workflow wf = template.workflows().get("workflow"); + assertNotNull(wf); + assertEquals("Workflow [userParams={key=value}, nodes=[A, B], edges=[A->B]]", wf.toString()); + + String json = TemplateTestJsonUtil.parseToJson(template); + assertTrue(json.startsWith(expectedPrefix)); + assertTrue(json.contains(expectedKV1)); + assertTrue(json.contains(expectedKV2)); + assertTrue(json.endsWith(expectedSuffix)); + + Template templateX = Template.parse(json); + assertEquals("test", templateX.name()); + assertEquals("a test template", templateX.description()); + assertEquals("test use case", templateX.useCase()); + assertEquals(List.of("operation"), templateX.operations()); + assertEquals(templateVersion, templateX.templateVersion()); + assertEquals(compatibilityVersion, templateX.compatibilityVersion()); + Map inputsMapX = template.userInputs(); + assertEquals("userValue", inputsMapX.get("userKey")); + assertEquals(Map.of("nestedKey", "nestedValue"), inputsMapX.get("userMapKey")); + Workflow wfX = templateX.workflows().get("workflow"); + assertNotNull(wfX); + assertEquals("Workflow [userParams={key=value}, nodes=[A, B], edges=[A->B]]", wfX.toString()); + } + + public void testExceptions() throws IOException { + String json = expectedPrefix + expectedKV1 + "," + expectedKV2 + expectedSuffix; + IOException e; + + String badTemplateField = json.replace("use_case", "badField"); + e = assertThrows(IOException.class, () -> Template.parse(badTemplateField)); + assertEquals("Unable to parse field [badField] in a template object.", e.getMessage()); + + String badVersionField = json.replace("compatibility", "badField"); + e = assertThrows(IOException.class, () -> Template.parse(badVersionField)); + assertEquals("Unable to parse field [version] in a version object.", e.getMessage()); + + String badUserInputType = json.replace("{\"nestedKey\":\"nestedValue\"}},", "[]"); + e = assertThrows(IOException.class, () -> Template.parse(badUserInputType)); + assertEquals("Unable to parse field [userMapKey] in a user inputs object.", e.getMessage()); + } + + public void testStrings() throws IOException { + Template t = Template.parse(expectedPrefix + expectedKV1 + "," + expectedKV2 + expectedSuffix); + assertTrue(t.toJson().contains(expectedPrefix)); + assertTrue(t.toJson().contains(expectedKV1)); + assertTrue(t.toJson().contains(expectedKV2)); + assertTrue(t.toJson().contains(expectedSuffix)); + + assertTrue(t.toYaml().contains("a test template")); + assertTrue(t.toString().contains("a test template")); + } +} diff --git a/src/test/java/org/opensearch/flowframework/model/WorkflowEdgeTests.java b/src/test/java/org/opensearch/flowframework/model/WorkflowEdgeTests.java new file mode 100644 index 000000000..ffbd07bd1 --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/model/WorkflowEdgeTests.java @@ -0,0 +1,54 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.model; + +import org.opensearch.test.OpenSearchTestCase; + +import java.io.IOException; + +public class WorkflowEdgeTests extends OpenSearchTestCase { + + @Override + public void setUp() throws Exception { + super.setUp(); + } + + public void testEdge() throws IOException { + WorkflowEdge edgeAB = new WorkflowEdge("A", "B"); + assertEquals("A", edgeAB.source()); + assertEquals("B", edgeAB.destination()); + assertEquals("A->B", edgeAB.toString()); + + WorkflowEdge edgeAB2 = new WorkflowEdge("A", "B"); + assertEquals(edgeAB, edgeAB2); + + WorkflowEdge edgeAC = new WorkflowEdge("A", "C"); + assertNotEquals(edgeAB, edgeAC); + + String expectedJson = "{\"source\":\"A\",\"dest\":\"B\"}"; + String json = TemplateTestJsonUtil.parseToJson(edgeAB); + assertEquals(expectedJson, json); + + WorkflowEdge edgeX = WorkflowEdge.parse(TemplateTestJsonUtil.jsonToParser(json)); + assertEquals("A", edgeX.source()); + assertEquals("B", edgeX.destination()); + assertEquals("A->B", edgeX.toString()); + } + + public void testExceptions() throws IOException { + String badJson = "{\"badField\":\"A\",\"dest\":\"B\"}"; + IOException e = assertThrows(IOException.class, () -> WorkflowEdge.parse(TemplateTestJsonUtil.jsonToParser(badJson))); + assertEquals("Unable to parse field [badField] in an edge object.", e.getMessage()); + + String missingJson = "{\"dest\":\"B\"}"; + e = assertThrows(IOException.class, () -> WorkflowEdge.parse(TemplateTestJsonUtil.jsonToParser(missingJson))); + assertEquals("An edge object requires both a source and dest field.", e.getMessage()); + } + +} diff --git a/src/test/java/org/opensearch/flowframework/model/WorkflowNodeTests.java b/src/test/java/org/opensearch/flowframework/model/WorkflowNodeTests.java new file mode 100644 index 000000000..46d897b42 --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/model/WorkflowNodeTests.java @@ -0,0 +1,81 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.model; + +import org.opensearch.test.OpenSearchTestCase; + +import java.io.IOException; +import java.util.Map; + +public class WorkflowNodeTests extends OpenSearchTestCase { + + @Override + public void setUp() throws Exception { + super.setUp(); + } + + public void testNode() throws IOException { + WorkflowNode nodeA = new WorkflowNode( + "A", + "a-type", + Map.ofEntries( + Map.entry("foo", "a string"), + Map.entry("bar", Map.of("key", "value")), + Map.entry("baz", new Map[] { Map.of("A", "a"), Map.of("B", "b") }), + Map.entry("processors", new PipelineProcessor[] { new PipelineProcessor("test-type", Map.of("key2", "value2")) }) + ) + ); + assertEquals("A", nodeA.id()); + assertEquals("a-type", nodeA.type()); + Map map = nodeA.inputs(); + assertEquals("a string", (String) map.get("foo")); + assertEquals(Map.of("key", "value"), (Map) map.get("bar")); + assertArrayEquals(new Map[] { Map.of("A", "a"), Map.of("B", "b") }, (Map[]) map.get("baz")); + PipelineProcessor[] pp = (PipelineProcessor[]) map.get("processors"); + assertEquals(1, pp.length); + assertEquals("test-type", pp[0].type()); + assertEquals(Map.of("key2", "value2"), pp[0].params()); + + // node equality is based only on ID + WorkflowNode nodeA2 = new WorkflowNode("A", "a2-type", Map.of("bar", "baz")); + assertEquals(nodeA, nodeA2); + + WorkflowNode nodeB = new WorkflowNode("B", "b-type", Map.of("baz", "qux")); + assertNotEquals(nodeA, nodeB); + + String json = TemplateTestJsonUtil.parseToJson(nodeA); + assertTrue(json.startsWith("{\"id\":\"A\",\"type\":\"a-type\",\"inputs\":")); + assertTrue(json.contains("\"foo\":\"a string\"")); + assertTrue(json.contains("\"baz\":[{\"A\":\"a\"},{\"B\":\"b\"}]")); + assertTrue(json.contains("\"bar\":{\"key\":\"value\"}")); + assertTrue(json.contains("\"processors\":[{\"type\":\"test-type\",\"params\":{\"key2\":\"value2\"}}]")); + + WorkflowNode nodeX = WorkflowNode.parse(TemplateTestJsonUtil.jsonToParser(json)); + assertEquals("A", nodeX.id()); + assertEquals("a-type", nodeX.type()); + Map mapX = nodeX.inputs(); + assertEquals("a string", mapX.get("foo")); + assertEquals(Map.of("key", "value"), mapX.get("bar")); + assertArrayEquals(new Map[] { Map.of("A", "a"), Map.of("B", "b") }, (Map[]) map.get("baz")); + PipelineProcessor[] ppX = (PipelineProcessor[]) map.get("processors"); + assertEquals(1, ppX.length); + assertEquals("test-type", ppX[0].type()); + assertEquals(Map.of("key2", "value2"), ppX[0].params()); + } + + public void testExceptions() throws IOException { + String badJson = "{\"badField\":\"A\",\"type\":\"a-type\",\"inputs\":{\"foo\":\"bar\"}}"; + IOException e = assertThrows(IOException.class, () -> WorkflowNode.parse(TemplateTestJsonUtil.jsonToParser(badJson))); + assertEquals("Unable to parse field [badField] in a node object.", e.getMessage()); + + String missingJson = "{\"id\":\"A\",\"inputs\":{\"foo\":\"bar\"}}"; + e = assertThrows(IOException.class, () -> WorkflowNode.parse(TemplateTestJsonUtil.jsonToParser(missingJson))); + assertEquals("An node object requires both an id and type field.", e.getMessage()); + } +} diff --git a/src/test/java/org/opensearch/flowframework/model/WorkflowTests.java b/src/test/java/org/opensearch/flowframework/model/WorkflowTests.java new file mode 100644 index 000000000..db070da4b --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/model/WorkflowTests.java @@ -0,0 +1,50 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.model; + +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.test.OpenSearchTestCase; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +public class WorkflowTests extends OpenSearchTestCase { + + @Override + public void setUp() throws Exception { + super.setUp(); + } + + public void testWorkflow() throws IOException { + WorkflowNode nodeA = new WorkflowNode("A", "a-type", Map.of("foo", "bar")); + WorkflowNode nodeB = new WorkflowNode("B", "b-type", Map.of("baz", "qux")); + WorkflowEdge edgeAB = new WorkflowEdge("A", "B"); + List nodes = List.of(nodeA, nodeB); + List edges = List.of(edgeAB); + + Workflow workflow = new Workflow(Map.of("key", "value"), nodes, edges); + assertEquals(Map.of("key", "value"), workflow.userParams()); + assertEquals(List.of(nodeA, nodeB), workflow.nodes()); + assertEquals(List.of(edgeAB), workflow.edges()); + + String expectedJson = "{\"user_params\":{\"key\":\"value\"}," + + "\"nodes\":[{\"id\":\"A\",\"type\":\"a-type\",\"inputs\":{\"foo\":\"bar\"}}," + + "{\"id\":\"B\",\"type\":\"b-type\",\"inputs\":{\"baz\":\"qux\"}}]," + + "\"edges\":[{\"source\":\"A\",\"dest\":\"B\"}]}"; + String json = TemplateTestJsonUtil.parseToJson(workflow); + assertEquals(expectedJson, json); + + XContentParser parser = TemplateTestJsonUtil.jsonToParser(json); + Workflow workflowX = Workflow.parse(parser); + assertEquals(Map.of("key", "value"), workflowX.userParams()); + assertEquals(List.of(nodeA, nodeB), workflowX.nodes()); + assertEquals(List.of(edgeAB), workflowX.edges()); + } +} diff --git a/src/test/java/org/opensearch/flowframework/template/ProcessSequenceEdgeTests.java b/src/test/java/org/opensearch/flowframework/template/ProcessSequenceEdgeTests.java deleted file mode 100644 index 80cecd96e..000000000 --- a/src/test/java/org/opensearch/flowframework/template/ProcessSequenceEdgeTests.java +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - */ -package org.opensearch.flowframework.template; - -import org.opensearch.test.OpenSearchTestCase; - -public class ProcessSequenceEdgeTests extends OpenSearchTestCase { - - @Override - public void setUp() throws Exception { - super.setUp(); - } - - public void testEdge() { - ProcessSequenceEdge edgeAB = new ProcessSequenceEdge("A", "B"); - assertEquals("A", edgeAB.getSource()); - assertEquals("B", edgeAB.getDestination()); - assertEquals("A->B", edgeAB.toString()); - - ProcessSequenceEdge edgeAB2 = new ProcessSequenceEdge("A", "B"); - assertEquals(edgeAB, edgeAB2); - - ProcessSequenceEdge edgeAC = new ProcessSequenceEdge("A", "C"); - assertNotEquals(edgeAB, edgeAC); - } -} diff --git a/src/test/java/org/opensearch/flowframework/template/TemplateParserTests.java b/src/test/java/org/opensearch/flowframework/template/TemplateParserTests.java deleted file mode 100644 index 24dcf0640..000000000 --- a/src/test/java/org/opensearch/flowframework/template/TemplateParserTests.java +++ /dev/null @@ -1,153 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - */ -package org.opensearch.flowframework.template; - -import org.opensearch.test.OpenSearchTestCase; - -import java.util.Collections; -import java.util.List; -import java.util.stream.Collectors; - -import static org.opensearch.flowframework.template.TemplateParser.DESTINATION; -import static org.opensearch.flowframework.template.TemplateParser.EDGES; -import static org.opensearch.flowframework.template.TemplateParser.NODES; -import static org.opensearch.flowframework.template.TemplateParser.NODE_ID; -import static org.opensearch.flowframework.template.TemplateParser.SOURCE; -import static org.opensearch.flowframework.template.TemplateParser.WORKFLOW; - -public class TemplateParserTests extends OpenSearchTestCase { - - private static final String NO_START_NODE_DETECTED = "No start node detected: all nodes have a predecessor."; - private static final String CYCLE_DETECTED = "Cycle detected:"; - - // Input JSON generators - private static String node(String id) { - return "{\"" + NODE_ID + "\": \"" + id + "\"}"; - } - - private static String edge(String sourceId, String destId) { - return "{\"" + SOURCE + "\": \"" + sourceId + "\", \"" + DESTINATION + "\": \"" + destId + "\"}"; - } - - private static String workflow(List nodes, List edges) { - return "{\"" + WORKFLOW + "\": {" + arrayField(NODES, nodes) + ", " + arrayField(EDGES, edges) + "}}"; - } - - private static String arrayField(String fieldName, List objects) { - return "\"" + fieldName + "\": [" + objects.stream().collect(Collectors.joining(", ")) + "]"; - } - - // Output list elements - private static ProcessNode expectedNode(String id) { - return new ProcessNode(id, null, null); - } - - // Less verbose parser - private static List parse(String json) { - return TemplateParser.parseJsonGraphToSequence(json, Collections.emptyMap()); - } - - @Override - public void setUp() throws Exception { - super.setUp(); - } - - public void testOrdering() { - List workflow; - - workflow = parse(workflow(List.of(node("A"), node("B"), node("C")), List.of(edge("C", "B"), edge("B", "A")))); - assertEquals(0, workflow.indexOf(expectedNode("C"))); - assertEquals(1, workflow.indexOf(expectedNode("B"))); - assertEquals(2, workflow.indexOf(expectedNode("A"))); - - workflow = parse( - workflow( - List.of(node("A"), node("B"), node("C"), node("D")), - List.of(edge("A", "B"), edge("A", "C"), edge("B", "D"), edge("C", "D")) - ) - ); - assertEquals(0, workflow.indexOf(expectedNode("A"))); - int b = workflow.indexOf(expectedNode("B")); - int c = workflow.indexOf(expectedNode("C")); - assertTrue(b == 1 || b == 2); - assertTrue(c == 1 || c == 2); - assertEquals(3, workflow.indexOf(expectedNode("D"))); - - workflow = parse( - workflow( - List.of(node("A"), node("B"), node("C"), node("D"), node("E")), - List.of(edge("A", "B"), edge("A", "C"), edge("B", "D"), edge("D", "E"), edge("C", "E")) - ) - ); - assertEquals(0, workflow.indexOf(expectedNode("A"))); - b = workflow.indexOf(expectedNode("B")); - c = workflow.indexOf(expectedNode("C")); - int d = workflow.indexOf(expectedNode("D")); - assertTrue(b == 1 || b == 2); - assertTrue(c == 1 || c == 2); - assertTrue(d == 2 || d == 3); - assertEquals(4, workflow.indexOf(expectedNode("E"))); - } - - public void testCycles() { - Exception ex; - - ex = assertThrows(IllegalArgumentException.class, () -> parse(workflow(List.of(node("A")), List.of(edge("A", "A"))))); - assertEquals("Edge connects node A to itself.", ex.getMessage()); - - ex = assertThrows( - IllegalArgumentException.class, - () -> parse(workflow(List.of(node("A"), node("B")), List.of(edge("A", "B"), edge("B", "B")))) - ); - assertEquals("Edge connects node B to itself.", ex.getMessage()); - - ex = assertThrows( - IllegalArgumentException.class, - () -> parse(workflow(List.of(node("A"), node("B")), List.of(edge("A", "B"), edge("B", "A")))) - ); - assertEquals(NO_START_NODE_DETECTED, ex.getMessage()); - - ex = assertThrows( - IllegalArgumentException.class, - () -> parse(workflow(List.of(node("A"), node("B"), node("C")), List.of(edge("A", "B"), edge("B", "C"), edge("C", "B")))) - ); - assertTrue(ex.getMessage().startsWith(CYCLE_DETECTED)); - assertTrue(ex.getMessage().contains("B->C")); - assertTrue(ex.getMessage().contains("C->B")); - - ex = assertThrows( - IllegalArgumentException.class, - () -> parse( - workflow( - List.of(node("A"), node("B"), node("C"), node("D")), - List.of(edge("A", "B"), edge("B", "C"), edge("C", "D"), edge("D", "B")) - ) - ) - ); - assertTrue(ex.getMessage().startsWith(CYCLE_DETECTED)); - assertTrue(ex.getMessage().contains("B->C")); - assertTrue(ex.getMessage().contains("C->D")); - assertTrue(ex.getMessage().contains("D->B")); - } - - public void testNoEdges() { - Exception ex = assertThrows( - IllegalArgumentException.class, - () -> parse(workflow(Collections.emptyList(), Collections.emptyList())) - ); - assertEquals(NO_START_NODE_DETECTED, ex.getMessage()); - - assertEquals(List.of(expectedNode("A")), parse(workflow(List.of(node("A")), Collections.emptyList()))); - - List workflow = parse(workflow(List.of(node("A"), node("B")), Collections.emptyList())); - assertEquals(2, workflow.size()); - assertTrue(workflow.contains(expectedNode("A"))); - assertTrue(workflow.contains(expectedNode("B"))); - } -} diff --git a/src/test/java/org/opensearch/flowframework/workflow/CreateIndex/CreateIndexStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/CreateIndexStepTests.java similarity index 89% rename from src/test/java/org/opensearch/flowframework/workflow/CreateIndex/CreateIndexStepTests.java rename to src/test/java/org/opensearch/flowframework/workflow/CreateIndexStepTests.java index c5d680a94..0fdc05cbd 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/CreateIndex/CreateIndexStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/CreateIndexStepTests.java @@ -6,7 +6,7 @@ * this file be licensed under the Apache-2.0 license or a * compatible open source license. */ -package org.opensearch.flowframework.workflow.CreateIndex; +package org.opensearch.flowframework.workflow; import org.opensearch.action.admin.indices.create.CreateIndexRequest; import org.opensearch.action.admin.indices.create.CreateIndexResponse; @@ -14,7 +14,6 @@ import org.opensearch.client.Client; import org.opensearch.client.IndicesAdminClient; import org.opensearch.core.action.ActionListener; -import org.opensearch.flowframework.workflow.WorkflowData; import org.opensearch.test.OpenSearchTestCase; import java.io.IOException; @@ -59,7 +58,8 @@ public void testCreateIndexStep() throws ExecutionException, InterruptedExceptio CreateIndexStep createIndexStep = new CreateIndexStep(client); - ArgumentCaptor actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); + @SuppressWarnings("unchecked") + ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); CompletableFuture future = createIndexStep.execute(List.of(inputData)); assertFalse(future.isDone()); verify(indicesAdminClient, times(1)).create(any(CreateIndexRequest.class), actionListenerCaptor.capture()); @@ -76,7 +76,8 @@ public void testCreateIndexStepFailure() throws ExecutionException, InterruptedE CreateIndexStep createIndexStep = new CreateIndexStep(client); - ArgumentCaptor actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); + @SuppressWarnings("unchecked") + ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); CompletableFuture future = createIndexStep.execute(List.of(inputData)); assertFalse(future.isDone()); verify(indicesAdminClient, times(1)).create(any(CreateIndexRequest.class), actionListenerCaptor.capture()); diff --git a/src/test/java/org/opensearch/flowframework/workflow/CreateIngestPipeline/CreateIngestPipelineStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStepTests.java similarity index 92% rename from src/test/java/org/opensearch/flowframework/workflow/CreateIngestPipeline/CreateIngestPipelineStepTests.java rename to src/test/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStepTests.java index 286bc2de9..9dab2a8d7 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/CreateIngestPipeline/CreateIngestPipelineStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStepTests.java @@ -6,7 +6,7 @@ * this file be licensed under the Apache-2.0 license or a * compatible open source license. */ -package org.opensearch.flowframework.workflow.CreateIngestPipeline; +package org.opensearch.flowframework.workflow; import org.opensearch.action.ingest.PutPipelineRequest; import org.opensearch.action.support.master.AcknowledgedResponse; @@ -14,8 +14,6 @@ import org.opensearch.client.Client; import org.opensearch.client.ClusterAdminClient; import org.opensearch.core.action.ActionListener; -import org.opensearch.flowframework.workflow.CreateIngestPipelineStep; -import org.opensearch.flowframework.workflow.WorkflowData; import org.opensearch.test.OpenSearchTestCase; import java.util.List; @@ -31,6 +29,7 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +@SuppressWarnings("deprecation") public class CreateIngestPipelineStepTests extends OpenSearchTestCase { private WorkflowData inputData; @@ -69,7 +68,8 @@ public void testCreateIngestPipelineStep() throws InterruptedException, Executio CreateIngestPipelineStep createIngestPipelineStep = new CreateIngestPipelineStep(client); - ArgumentCaptor actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); + @SuppressWarnings("unchecked") + ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); CompletableFuture future = createIngestPipelineStep.execute(List.of(inputData)); assertFalse(future.isDone()); @@ -86,7 +86,8 @@ public void testCreateIngestPipelineStepFailure() throws InterruptedException { CreateIngestPipelineStep createIngestPipelineStep = new CreateIngestPipelineStep(client); - ArgumentCaptor actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); + @SuppressWarnings("unchecked") + ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); CompletableFuture future = createIngestPipelineStep.execute(List.of(inputData)); assertFalse(future.isDone()); diff --git a/src/test/java/org/opensearch/flowframework/template/ProcessNodeTests.java b/src/test/java/org/opensearch/flowframework/workflow/ProcessNodeTests.java similarity index 52% rename from src/test/java/org/opensearch/flowframework/template/ProcessNodeTests.java rename to src/test/java/org/opensearch/flowframework/workflow/ProcessNodeTests.java index d9f365708..1972d20eb 100644 --- a/src/test/java/org/opensearch/flowframework/template/ProcessNodeTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/ProcessNodeTests.java @@ -6,26 +6,31 @@ * this file be licensed under the Apache-2.0 license or a * compatible open source license. */ -package org.opensearch.flowframework.template; +package org.opensearch.flowframework.workflow; -import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; -import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope.Scope; - -import org.opensearch.flowframework.workflow.WorkflowData; -import org.opensearch.flowframework.workflow.WorkflowStep; import org.opensearch.test.OpenSearchTestCase; +import org.junit.After; +import org.junit.Before; import java.util.Collections; import java.util.List; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; -@ThreadLeakScope(Scope.NONE) public class ProcessNodeTests extends OpenSearchTestCase { - @Override - public void setUp() throws Exception { - super.setUp(); + private ExecutorService executor; + + @Before + public void setup() { + executor = Executors.newFixedThreadPool(10); + } + + @After + public void cleanup() { + executor.shutdown(); } public void testNode() throws InterruptedException, ExecutionException { @@ -41,25 +46,15 @@ public CompletableFuture execute(List data) { public String getName() { return "test"; } - }); + }, WorkflowData.EMPTY, Collections.emptyList(), executor); assertEquals("A", nodeA.id()); assertEquals("test", nodeA.workflowStep().getName()); assertEquals(WorkflowData.EMPTY, nodeA.input()); - assertEquals(Collections.emptySet(), nodeA.getPredecessors()); + assertEquals(Collections.emptyList(), nodeA.predecessors()); assertEquals("A", nodeA.toString()); - // TODO: This test is flaky on Windows. Disabling until thread pool is integrated - // https://github.com/opensearch-project/opensearch-ai-flow-framework/issues/42 - // CompletableFuture f = nodeA.execute(); - // assertEquals(f, nodeA.future()); - // f.orTimeout(5, TimeUnit.SECONDS); - // assertTrue(f.isDone()); - // assertEquals(WorkflowData.EMPTY, f.get()); - - ProcessNode nodeB = new ProcessNode("B", null); - assertNotEquals(nodeA, nodeB); - - ProcessNode nodeA2 = new ProcessNode("A", null); - assertEquals(nodeA, nodeA2); + CompletableFuture f = nodeA.execute(); + assertEquals(f, nodeA.future()); + assertEquals(WorkflowData.EMPTY, f.get()); } } diff --git a/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java b/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java new file mode 100644 index 000000000..1e9c8e808 --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java @@ -0,0 +1,168 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.workflow; + +import org.opensearch.client.AdminClient; +import org.opensearch.client.Client; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.flowframework.model.TemplateTestJsonUtil; +import org.opensearch.flowframework.model.Workflow; +import org.opensearch.test.OpenSearchTestCase; +import org.junit.AfterClass; +import org.junit.BeforeClass; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.stream.Collectors; + +import static org.opensearch.flowframework.model.TemplateTestJsonUtil.edge; +import static org.opensearch.flowframework.model.TemplateTestJsonUtil.node; +import static org.opensearch.flowframework.model.TemplateTestJsonUtil.workflow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class WorkflowProcessSorterTests extends OpenSearchTestCase { + + private static final String MUST_HAVE_AT_LEAST_ONE_NODE = "A workflow must have at least one node."; + private static final String NO_START_NODE_DETECTED = "No start node detected: all nodes have a predecessor."; + private static final String CYCLE_DETECTED = "Cycle detected:"; + + // Wrap parser into string list + private static List parse(String json) throws IOException { + XContentParser parser = TemplateTestJsonUtil.jsonToParser(json); + Workflow w = Workflow.parse(parser); + return workflowProcessSorter.sortProcessNodes(w).stream().map(ProcessNode::id).collect(Collectors.toList()); + } + + private static ExecutorService executor; + private static WorkflowProcessSorter workflowProcessSorter; + + @BeforeClass + public static void setup() { + AdminClient adminClient = mock(AdminClient.class); + Client client = mock(Client.class); + when(client.admin()).thenReturn(adminClient); + + executor = Executors.newFixedThreadPool(10); + WorkflowStepFactory factory = WorkflowStepFactory.create(client); + workflowProcessSorter = WorkflowProcessSorter.create(factory, executor); + } + + @AfterClass + public static void cleanup() { + executor.shutdown(); + } + + public void testOrdering() throws IOException { + List workflow; + + workflow = parse(workflow(List.of(node("A"), node("B"), node("C")), List.of(edge("C", "B"), edge("B", "A")))); + assertEquals(0, workflow.indexOf("C")); + assertEquals(1, workflow.indexOf("B")); + assertEquals(2, workflow.indexOf("A")); + + workflow = parse( + workflow( + List.of(node("A"), node("B"), node("C"), node("D")), + List.of(edge("A", "B"), edge("A", "C"), edge("B", "D"), edge("C", "D")) + ) + ); + assertEquals(0, workflow.indexOf("A")); + int b = workflow.indexOf("B"); + int c = workflow.indexOf("C"); + assertTrue(b == 1 || b == 2); + assertTrue(c == 1 || c == 2); + assertEquals(3, workflow.indexOf("D")); + + workflow = parse( + workflow( + List.of(node("A"), node("B"), node("C"), node("D"), node("E")), + List.of(edge("A", "B"), edge("A", "C"), edge("B", "D"), edge("D", "E"), edge("C", "E")) + ) + ); + assertEquals(0, workflow.indexOf("A")); + b = workflow.indexOf("B"); + c = workflow.indexOf("C"); + int d = workflow.indexOf("D"); + assertTrue(b == 1 || b == 2); + assertTrue(c == 1 || c == 2); + assertTrue(d == 2 || d == 3); + assertEquals(4, workflow.indexOf("E")); + } + + public void testCycles() { + Exception ex; + + ex = assertThrows(IllegalArgumentException.class, () -> parse(workflow(List.of(node("A")), List.of(edge("A", "A"))))); + assertEquals("Edge connects node A to itself.", ex.getMessage()); + + ex = assertThrows( + IllegalArgumentException.class, + () -> parse(workflow(List.of(node("A"), node("B")), List.of(edge("A", "B"), edge("B", "B")))) + ); + assertEquals("Edge connects node B to itself.", ex.getMessage()); + + ex = assertThrows( + IllegalArgumentException.class, + () -> parse(workflow(List.of(node("A"), node("B")), List.of(edge("A", "B"), edge("B", "A")))) + ); + assertEquals(NO_START_NODE_DETECTED, ex.getMessage()); + + ex = assertThrows( + IllegalArgumentException.class, + () -> parse(workflow(List.of(node("A"), node("B"), node("C")), List.of(edge("A", "B"), edge("B", "C"), edge("C", "B")))) + ); + assertTrue(ex.getMessage().startsWith(CYCLE_DETECTED)); + assertTrue(ex.getMessage().contains("B->C")); + assertTrue(ex.getMessage().contains("C->B")); + + ex = assertThrows( + IllegalArgumentException.class, + () -> parse( + workflow( + List.of(node("A"), node("B"), node("C"), node("D")), + List.of(edge("A", "B"), edge("B", "C"), edge("C", "D"), edge("D", "B")) + ) + ) + ); + assertTrue(ex.getMessage().startsWith(CYCLE_DETECTED)); + assertTrue(ex.getMessage().contains("B->C")); + assertTrue(ex.getMessage().contains("C->D")); + assertTrue(ex.getMessage().contains("D->B")); + } + + public void testNoEdges() throws IOException { + List workflow; + Exception ex = assertThrows(IOException.class, () -> parse(workflow(Collections.emptyList(), Collections.emptyList()))); + assertEquals(MUST_HAVE_AT_LEAST_ONE_NODE, ex.getMessage()); + + workflow = parse(workflow(List.of(node("A")), Collections.emptyList())); + assertEquals(1, workflow.size()); + assertEquals("A", workflow.get(0)); + + workflow = parse(workflow(List.of(node("A"), node("B")), Collections.emptyList())); + assertEquals(2, workflow.size()); + assertTrue(workflow.contains("A")); + assertTrue(workflow.contains("B")); + } + + public void testExceptions() throws IOException { + Exception ex = assertThrows( + IllegalArgumentException.class, + () -> parse(workflow(List.of(node("A"), node("B")), List.of(edge("C", "B")))) + ); + assertEquals("Edge source C does not correspond to a node.", ex.getMessage()); + + ex = assertThrows(IllegalArgumentException.class, () -> parse(workflow(List.of(node("A"), node("B")), List.of(edge("A", "C"))))); + assertEquals("Edge destination C does not correspond to a node.", ex.getMessage()); + } +} diff --git a/src/test/resources/template/datademo.json b/src/test/resources/template/datademo.json deleted file mode 100644 index a1323ed2c..000000000 --- a/src/test/resources/template/datademo.json +++ /dev/null @@ -1,20 +0,0 @@ -{ - "sequence": { - "nodes": [ - { - "id": "create_index", - "index_name": "demo" - }, - { - "id": "create_another_index", - "index_name": "second_demo" - } - ], - "edges": [ - { - "source": "create_index", - "dest": "create_another_index" - } - ] - } -} diff --git a/src/test/resources/template/demo.json b/src/test/resources/template/demo.json index 38f1d0644..e27158bff 100644 --- a/src/test/resources/template/demo.json +++ b/src/test/resources/template/demo.json @@ -1,36 +1,47 @@ { - "sequence": { - "nodes": [ - { - "id": "fetch_model" - }, - { - "id": "create_ingest_pipeline" - }, - { - "id": "create_search_pipeline" - }, - { - "id": "create_neural_search_index" - } - ], - "edges": [ - { - "source": "fetch_model", - "dest": "create_ingest_pipeline" - }, - { - "source": "fetch_model", - "dest": "create_search_pipeline" - }, - { - "source": "create_ingest_pipeline", - "dest": "create_neural_search_index" - }, - { - "source": "create_search_pipeline", - "dest": "create_neural_search_index" - } - ] + "name": "demo-template", + "description": "Demonstrates workflow steps and passing around of input/output", + "user_inputs": { + "knn_index_name": "my-knn-index" + }, + "workflows": { + "demo": { + "nodes": [ + { + "id": "fetch_model", + "type": "demo_delay_3" + }, + { + "id": "create_ingest_pipeline", + "type": "demo_delay_3" + }, + { + "id": "create_search_pipeline", + "type": "demo_delay_5" + }, + { + "id": "create_neural_search_index", + "type": "demo_delay_3" + } + ], + "edges": [ + { + "source": "fetch_model", + "dest": "create_ingest_pipeline" + }, + { + "source": "fetch_model", + "dest": "create_search_pipeline" + }, + { + "source": "create_ingest_pipeline", + "dest": "create_neural_search_index" + }, + { + "source": "create_search_pipeline", + "dest": "create_neural_search_index" + } + ] + } } } diff --git a/src/test/resources/template/finaltemplate.json b/src/test/resources/template/finaltemplate.json new file mode 100644 index 000000000..d8443c4c6 --- /dev/null +++ b/src/test/resources/template/finaltemplate.json @@ -0,0 +1,96 @@ +{ + "name": "semantic-search", + "description": "My semantic search use case", + "use_case": "SEMANTIC_SEARCH", + "operations": [ + "PROVISION", + "INGEST", + "QUERY" + ], + "version": { + "template": "1.0.0", + "compatibility": [ + "2.9.0", + "3.0.0" + ] + }, + "user_inputs": { + "index_name": "my-knn-index", + "index_settings": {} + }, + "workflows": { + "provision": { + "nodes": [{ + "id": "create_index", + "type": "create_index", + "inputs": { + "name": "user_inputs.index_name", + "settings": "user_inputs.index_settings" + } + }, + { + "id": "create_ingest_pipeline", + "type": "create_ingest_pipeline", + "inputs": { + "name": "my-ingest-pipeline", + "description": "some description", + "processors": [{ + "type": "text_embedding", + "params": { + "model_id": "my-existing-model-id", + "input_field": "text_passage", + "output_field": "text_embedding" + } + }] + } + } + ], + "edges": [{ + "source": "create_index", + "dest": "create_ingest_pipeline" + }] + }, + "ingest": { + "user_params": { + "document": "doc" + }, + "nodes": [{ + "id": "ingest_index", + "type": "ingest_index", + "inputs": { + "index": "user_inputs.index_name", + "ingest_pipeline": "my-ingest-pipeline", + "document": "user_params.document" + } + }] + }, + "query": { + "user_params": { + "plaintext": "string" + }, + "nodes": [{ + "id": "transform_query", + "type": "transform_query", + "inputs": { + "template": "neural-search-template-1", + "plaintext": "user_params.plaintext" + } + }, + { + "id": "query_index", + "type": "query_index", + "inputs": { + "index": "user_inputs.index_name", + "query": "{{output-from-prev-step}}.query", + "search_request_processors": [], + "search_response_processors": [] + } + } + ], + "edges": [{ + "source": "transform_query", + "dest": "query_index" + }] + } + } +}