Skip to content

Commit

Permalink
Add WorkflowStepFactory class
Browse files Browse the repository at this point in the history
Signed-off-by: Daniel Widdis <[email protected]>
  • Loading branch information
dbwiddis committed Sep 19, 2023
1 parent d1cbcbd commit a360cdb
Show file tree
Hide file tree
Showing 7 changed files with 106 additions and 61 deletions.
13 changes: 1 addition & 12 deletions src/main/java/demo/DataDemo.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,13 @@
import org.opensearch.common.io.PathUtils;
import org.opensearch.flowframework.template.ProcessNode;
import org.opensearch.flowframework.template.TemplateParser;
import org.opensearch.flowframework.workflow.WorkflowStep;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;
Expand All @@ -35,14 +32,6 @@ public class DataDemo {

private static final Logger logger = LogManager.getLogger(DataDemo.class);

// This is temporary. We need a factory class to generate these workflow steps
// based on a field in the JSON.
private static Map<String, WorkflowStep> workflowMap = new HashMap<>();
static {
workflowMap.put("create_index", new CreateIndexWorkflowStep());
workflowMap.put("create_another_index", new CreateIndexWorkflowStep());
}

/**
* Demonstrate parsing a JSON graph.
*
Expand All @@ -60,7 +49,7 @@ public static void main(String[] args) {
}

logger.info("Parsing graph to sequence...");
List<ProcessNode> processSequence = TemplateParser.parseJsonGraphToSequence(json, workflowMap);
List<ProcessNode> processSequence = TemplateParser.parseJsonGraphToSequence(json);
List<CompletableFuture<?>> futureList = new ArrayList<>();

for (ProcessNode n : processSequence) {
Expand Down
16 changes: 1 addition & 15 deletions src/main/java/demo/Demo.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,13 @@
import org.opensearch.common.io.PathUtils;
import org.opensearch.flowframework.template.ProcessNode;
import org.opensearch.flowframework.template.TemplateParser;
import org.opensearch.flowframework.workflow.WorkflowStep;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;
Expand All @@ -35,16 +32,6 @@ public class Demo {

private static final Logger logger = LogManager.getLogger(Demo.class);

// This is temporary. We need a factory class to generate these workflow steps
// based on a field in the JSON.
private static Map<String, WorkflowStep> workflowMap = new HashMap<>();
static {
workflowMap.put("fetch_model", new DemoWorkflowStep(3000));
workflowMap.put("create_ingest_pipeline", new DemoWorkflowStep(3000));
workflowMap.put("create_search_pipeline", new DemoWorkflowStep(5000));
workflowMap.put("create_neural_search_index", new DemoWorkflowStep(2000));
}

/**
* Demonstrate parsing a JSON graph.
*
Expand All @@ -62,7 +49,7 @@ public static void main(String[] args) {
}

logger.info("Parsing graph to sequence...");
List<ProcessNode> processSequence = TemplateParser.parseJsonGraphToSequence(json, workflowMap);
List<ProcessNode> processSequence = TemplateParser.parseJsonGraphToSequence(json);
List<CompletableFuture<?>> futureList = new ArrayList<>();

for (ProcessNode n : processSequence) {
Expand All @@ -78,7 +65,6 @@ public static void main(String[] args) {
predecessors.stream().map(p -> p.id()).collect(Collectors.joining(", "))
)
);
// TODO need to handle this better, passing an argument when we start them all at the beginning is silly
futureList.add(n.execute());
}
futureList.forEach(CompletableFuture::join);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
import com.google.gson.JsonObject;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.admin.indices.create.CreateIndexRequest;
import org.opensearch.flowframework.workflow.WorkflowData;
import org.opensearch.flowframework.workflow.WorkflowStep;
import org.opensearch.flowframework.workflow.WorkflowStepFactory;

import java.util.ArrayDeque;
import java.util.ArrayList;
Expand All @@ -36,13 +36,16 @@ public class TemplateParser {

private static final Logger logger = LogManager.getLogger(TemplateParser.class);

// Field names in the JSON. Package private for tests.
// Field names in the JSON.
// Currently package private for tests.
// These may eventually become part of the template definition in which case they might be better declared public
static final String WORKFLOW = "sequence";
static final String NODES = "nodes";
static final String NODE_ID = "id";
static final String EDGES = "edges";
static final String SOURCE = "source";
static final String DESTINATION = "dest";
static final String STEP_TYPE = "step_type";

/**
* Prevent instantiating this class.
Expand All @@ -52,10 +55,9 @@ private TemplateParser() {}
/**
* Parse a JSON representation of nodes and edges into a topologically sorted list of process nodes.
* @param json A string containing a JSON representation of nodes and edges
* @param workflowSteps A map linking JSON node names to Java objects implementing {@link WorkflowStep}
* @return A list of Process Nodes sorted topologically. All predecessors of any node will occur prior to it in the list.
*/
public static List<ProcessNode> parseJsonGraphToSequence(String json, Map<String, WorkflowStep> workflowSteps) {
public static List<ProcessNode> parseJsonGraphToSequence(String json) {
Gson gson = new Gson();
JsonObject jsonObject = gson.fromJson(json, JsonObject.class);

Expand All @@ -67,31 +69,13 @@ public static List<ProcessNode> parseJsonGraphToSequence(String json, Map<String
for (JsonElement nodeJson : graph.getAsJsonArray(NODES)) {
JsonObject nodeObject = nodeJson.getAsJsonObject();
String nodeId = nodeObject.get(NODE_ID).getAsString();
// The below steps will be replaced by a generator class that instantiates a WorkflowStep
// based on user_input data from the template.
// https://github.com/opensearch-project/opensearch-ai-flow-framework/issues/41
WorkflowStep workflowStep = workflowSteps.get(nodeId);
// temporary demo POC of getting from a request to input data
// this will be refactored into something pulling from user template as part of the above issue
WorkflowData inputData = WorkflowData.EMPTY;
if (List.of("create_index", "create_another_index").contains(nodeId)) {
CreateIndexRequest request = new CreateIndexRequest(nodeObject.get("index_name").getAsString());
inputData = new WorkflowData() {

@Override
public Map<String, Object> getContent() {
// See CreateIndexRequest ParseFields for source of content keys needed
return Map.of("mappings", request.mappings(), "settings", request.settings(), "aliases", request.aliases());
}

@Override
public Map<String, String> getParams() {
// See RestCreateIndexAction for source of param keys needed
return Map.of("index", request.index());
}

};
}
String stepType = nodeObject.get(STEP_TYPE).getAsString();
WorkflowStep workflowStep = WorkflowStepFactory.get().createStep(stepType);

// TODO as part of this PR: Fetch data from the template here
WorkflowData inputData = new WorkflowData() {
// TODO override params and content from user template
};
nodes.add(new ProcessNode(nodeId, workflowStep, inputData));
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.flowframework.workflow;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;

import demo.CreateIndexWorkflowStep;
import demo.DemoWorkflowStep;

/**
* Generates instances implementing {@link WorkflowStep}.
*/
public class WorkflowStepFactory {

private static final WorkflowStepFactory INSTANCE = new WorkflowStepFactory();

private final Map<String, WorkflowStep> stepMap = new HashMap<>();

private WorkflowStepFactory() {
populateMap();
}

/**
* Gets the singleton instance of this class
* @return The instance of this class
*/
public static WorkflowStepFactory get() {
return INSTANCE;
}

private void populateMap() {
// TODO: These are from the demo class as placeholders
// Replace with actual implementations such as
// https://github.com/opensearch-project/opensearch-ai-flow-framework/pull/38
// https://github.com/opensearch-project/opensearch-ai-flow-framework/pull/44
stepMap.put("create_index", new CreateIndexWorkflowStep());
stepMap.put("fetch_model", new DemoWorkflowStep(3000));
stepMap.put("create_ingest_pipeline", new DemoWorkflowStep(3000));
stepMap.put("create_search_pipeline", new DemoWorkflowStep(5000));
stepMap.put("create_neural_search_index", new DemoWorkflowStep(2000));

// Use until all the actual implementations are ready
stepMap.put("placeholder", new WorkflowStep() {
@Override
public CompletableFuture<WorkflowData> execute(List<WorkflowData> data) {
CompletableFuture<WorkflowData> future = new CompletableFuture<>();
future.complete(WorkflowData.EMPTY);
return future;
}

@Override
public String getName() {
return "placeholder";
}
});
}

/**
* Create a new instance of a {@link WorkflowStep}.
* @param type The type of instance to create
* @return an instance of the specified type
*/
public WorkflowStep createStep(String type) {
if (stepMap.containsKey(type)) {
return stepMap.get(type);
}
// TODO: replace this with a FlowFrameworkException
// https://github.com/opensearch-project/opensearch-ai-flow-framework/pull/43
throw new UnsupportedOperationException("No workflow steps of type [" + type + "] are implemented.");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ private static ProcessNode expectedNode(String id) {

// Less verbose parser
private static List<ProcessNode> parse(String json) {
return TemplateParser.parseJsonGraphToSequence(json, Collections.emptyMap());
return TemplateParser.parseJsonGraphToSequence(json);
}

@Override
Expand Down
2 changes: 2 additions & 0 deletions src/test/resources/template/datademo.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
"nodes": [
{
"id": "create_index",
"step_type": "create_index",
"index_name": "demo"
},
{
"id": "create_another_index",
"step_type": "create_index",
"index_name": "second_demo"
}
],
Expand Down
12 changes: 8 additions & 4 deletions src/test/resources/template/demo.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,20 @@
"sequence": {
"nodes": [
{
"id": "fetch_model"
"id": "fetch_model",
"step_type": "fetch_model"
},
{
"id": "create_ingest_pipeline"
"id": "create_ingest_pipeline",
"step_type": "create_ingest_pipeline"
},
{
"id": "create_search_pipeline"
"id": "create_search_pipeline",
"step_type": "create_search_pipeline"
},
{
"id": "create_neural_search_index"
"id": "create_neural_search_index",
"step_type": "create_neural_search_index"
}
],
"edges": [
Expand Down

0 comments on commit a360cdb

Please sign in to comment.