diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java b/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java index d685cbeaa..b00857ff6 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java @@ -74,22 +74,28 @@ public CreateConnectorStep(MachineLearningNodeClient mlClient, FlowFrameworkIndi // TODO: need to add retry conflicts here @Override - public CompletableFuture execute(List data) throws IOException { + public CompletableFuture execute( + String currentNodeId, + WorkflowData currentNodeInputs, + Map outputs, + Map previousNodeInputs + ) throws IOException { CompletableFuture createConnectorFuture = new CompletableFuture<>(); ActionListener actionListener = new ActionListener<>() { @Override public void onResponse(MLCreateConnectorResponse mlCreateConnectorResponse) { + String workflowId = currentNodeInputs.getWorkflowId(); createConnectorFuture.complete( new WorkflowData( Map.ofEntries(Map.entry("connector_id", mlCreateConnectorResponse.getConnectorId())), - data.get(0).getWorkflowId() + workflowId, + currentNodeInputs.getNodeId() ) ); try { logger.info("Created connector successfully"); - String workflowId = data.get(0).getWorkflowId(); String workflowStepName = getName(); ResourceCreated newResource = new ResourceCreated(workflowStepName, mlCreateConnectorResponse.getConnectorId()); XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); @@ -136,6 +142,12 @@ public void onFailure(Exception e) { Map credentials = Collections.emptyMap(); List actions = Collections.emptyList(); + // TODO: Recreating the list to get this compiling + // Need to refactor the below iteration to pull directly from the maps + List data = new ArrayList<>(); + data.add(currentNodeInputs); + data.addAll(outputs.values()); + try { for (WorkflowData workflowData : data) { for (Entry entry : workflowData.getContent().entrySet()) { diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java b/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java index f3a82b26c..f443e9c2c 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java @@ -19,6 +19,7 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; +import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -54,14 +55,25 @@ public CreateIndexStep(ClusterService clusterService, Client client) { } @Override - public CompletableFuture execute(List data) { + public CompletableFuture execute( + String currentNodeId, + WorkflowData currentNodeInputs, + Map outputs, + Map previousNodeInputs + ) { CompletableFuture future = new CompletableFuture<>(); ActionListener actionListener = new ActionListener<>() { @Override public void onResponse(CreateIndexResponse createIndexResponse) { logger.info("created index: {}", createIndexResponse.index()); - future.complete(new WorkflowData(Map.of(INDEX_NAME, createIndexResponse.index()), data.get(0).getWorkflowId())); + future.complete( + new WorkflowData( + Map.of(INDEX_NAME, createIndexResponse.index()), + currentNodeInputs.getWorkflowId(), + currentNodeInputs.getNodeId() + ) + ); } @Override @@ -75,6 +87,12 @@ public void onFailure(Exception e) { String type = null; Settings settings = null; + // TODO: Recreating the list to get this compiling + // Need to refactor the below iteration to pull directly from the maps + List data = new ArrayList<>(); + data.add(currentNodeInputs); + data.addAll(outputs.values()); + for (WorkflowData workflowData : data) { Map content = workflowData.getContent(); index = (String) content.get(INDEX_NAME); diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStep.java b/src/main/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStep.java index a63a800fd..77dae29eb 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStep.java @@ -20,6 +20,7 @@ import org.opensearch.core.xcontent.XContentBuilder; import java.io.IOException; +import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Map.Entry; @@ -59,7 +60,12 @@ public CreateIngestPipelineStep(Client client) { } @Override - public CompletableFuture execute(List data) { + public CompletableFuture execute( + String currentNodeId, + WorkflowData currentNodeInputs, + Map outputs, + Map previousNodeInputs + ) { CompletableFuture createIngestPipelineFuture = new CompletableFuture<>(); @@ -71,6 +77,12 @@ public CompletableFuture execute(List data) { String outputFieldName = null; BytesReference configuration = null; + // TODO: Recreating the list to get this compiling + // Need to refactor the below iteration to pull directly from the maps + List data = new ArrayList<>(); + data.add(currentNodeInputs); + data.addAll(outputs.values()); + // Extract required content from workflow data and generate the ingest pipeline configuration for (WorkflowData workflowData : data) { @@ -126,7 +138,11 @@ public CompletableFuture execute(List data) { // PutPipelineRequest returns only an AcknowledgeResponse, returning pipelineId instead createIngestPipelineFuture.complete( - new WorkflowData(Map.of(PIPELINE_ID, putPipelineRequest.getId()), data.get(0).getWorkflowId()) + new WorkflowData( + Map.of(PIPELINE_ID, putPipelineRequest.getId()), + currentNodeInputs.getWorkflowId(), + currentNodeInputs.getNodeId() + ) ); // TODO : Use node client to index response data to global context (pending global context index implementation) diff --git a/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java index 8ce89176c..aa6768605 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java @@ -17,6 +17,7 @@ import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse; +import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.concurrent.CompletableFuture; @@ -41,7 +42,12 @@ public DeployModelStep(MachineLearningNodeClient mlClient) { } @Override - public CompletableFuture execute(List data) { + public CompletableFuture execute( + String currentNodeId, + WorkflowData currentNodeInputs, + Map outputs, + Map previousNodeInputs + ) { CompletableFuture deployModelFuture = new CompletableFuture<>(); @@ -52,7 +58,8 @@ public void onResponse(MLDeployModelResponse mlDeployModelResponse) { deployModelFuture.complete( new WorkflowData( Map.ofEntries(Map.entry("deploy_model_status", mlDeployModelResponse.getStatus())), - data.get(0).getWorkflowId() + currentNodeInputs.getWorkflowId(), + currentNodeInputs.getNodeId() ) ); } @@ -66,6 +73,12 @@ public void onFailure(Exception e) { String modelId = null; + // TODO: Recreating the list to get this compiling + // Need to refactor the below iteration to pull directly from the maps + List data = new ArrayList<>(); + data.add(currentNodeInputs); + data.addAll(outputs.values()); + for (WorkflowData workflowData : data) { if (workflowData.getContent().containsKey(MODEL_ID)) { modelId = (String) workflowData.getContent().get(MODEL_ID); diff --git a/src/main/java/org/opensearch/flowframework/workflow/GetMLTaskStep.java b/src/main/java/org/opensearch/flowframework/workflow/GetMLTaskStep.java index bb57adbae..018783b19 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/GetMLTaskStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/GetMLTaskStep.java @@ -20,6 +20,7 @@ import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.MLTaskState; +import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Map.Entry; @@ -50,12 +51,23 @@ public GetMLTaskStep(Settings settings, ClusterService clusterService, MachineLe } @Override - public CompletableFuture execute(List data) { + public CompletableFuture execute( + String currentNodeId, + WorkflowData currentNodeInputs, + Map outputs, + Map previousNodeInputs + ) { CompletableFuture getMLTaskFuture = new CompletableFuture<>(); String taskId = null; + // TODO: Recreating the list to get this compiling + // Need to refactor the below iteration to pull directly from the maps + List data = new ArrayList<>(); + data.add(currentNodeInputs); + data.addAll(outputs.values()); + for (WorkflowData workflowData : data) { Map content = workflowData.getContent(); for (Entry entry : content.entrySet()) { @@ -73,7 +85,7 @@ public CompletableFuture execute(List data) { logger.error("Failed to retrieve ML Task"); getMLTaskFuture.completeExceptionally(new FlowFrameworkException("Required fields are not provided", RestStatus.BAD_REQUEST)); } else { - retryableGetMlTask(data.get(0).getWorkflowId(), getMLTaskFuture, taskId, 0); + retryableGetMlTask(currentNodeInputs.getWorkflowId(), currentNodeInputs.getNodeId(), getMLTaskFuture, taskId, 0); } return getMLTaskFuture; @@ -87,11 +99,18 @@ public String getName() { /** * Retryable GetMLTask * @param workflowId the workflow id + * @param nodeId the node id * @param getMLTaskFuture the workflow step future * @param taskId the ml task id * @param retries the current number of request retries */ - protected void retryableGetMlTask(String workflowId, CompletableFuture getMLTaskFuture, String taskId, int retries) { + protected void retryableGetMlTask( + String workflowId, + String nodeId, + CompletableFuture getMLTaskFuture, + String taskId, + int retries + ) { mlClient.getTask(taskId, ActionListener.wrap(response -> { if (response.getState() != MLTaskState.COMPLETED) { throw new IllegalStateException("MLTask is not yet completed"); @@ -103,7 +122,8 @@ protected void retryableGetMlTask(String workflowId, CompletableFuture execute(List data) throws IOException { + public CompletableFuture execute( + String currentNodeId, + WorkflowData currentNodeInputs, + Map outputs, + Map previousNodeInputs + ) throws IOException { CompletableFuture registerModelGroupFuture = new CompletableFuture<>(); @@ -67,7 +72,8 @@ public void onResponse(MLRegisterModelGroupResponse mlRegisterModelGroupResponse Map.entry("model_group_id", mlRegisterModelGroupResponse.getModelGroupId()), Map.entry("model_group_status", mlRegisterModelGroupResponse.getStatus()) ), - data.get(0).getWorkflowId() + currentNodeInputs.getWorkflowId(), + currentNodeInputs.getNodeId() ) ); } @@ -85,6 +91,12 @@ public void onFailure(Exception e) { AccessMode modelAccessMode = null; Boolean isAddAllBackendRoles = null; + // TODO: Recreating the list to get this compiling + // Need to refactor the below iteration to pull directly from the maps + List data = new ArrayList<>(); + data.add(currentNodeInputs); + data.addAll(outputs.values()); + for (WorkflowData workflowData : data) { Map content = workflowData.getContent(); diff --git a/src/main/java/org/opensearch/flowframework/workflow/NoOpStep.java b/src/main/java/org/opensearch/flowframework/workflow/NoOpStep.java index 098c5626c..bbf325e46 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/NoOpStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/NoOpStep.java @@ -9,7 +9,7 @@ package org.opensearch.flowframework.workflow; import java.io.IOException; -import java.util.List; +import java.util.Map; import java.util.concurrent.CompletableFuture; /** @@ -21,7 +21,12 @@ public class NoOpStep implements WorkflowStep { public static final String NAME = "noop"; @Override - public CompletableFuture execute(List data) throws IOException { + public CompletableFuture execute( + String currentNodeId, + WorkflowData currentNodeInputs, + Map outputs, + Map previousNodeInputs + ) throws IOException { return CompletableFuture.completedFuture(WorkflowData.EMPTY); } diff --git a/src/main/java/org/opensearch/flowframework/workflow/ProcessNode.java b/src/main/java/org/opensearch/flowframework/workflow/ProcessNode.java index 729043074..c6bdcc6a5 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/ProcessNode.java +++ b/src/main/java/org/opensearch/flowframework/workflow/ProcessNode.java @@ -14,7 +14,7 @@ import org.opensearch.threadpool.Scheduler.ScheduledCancellable; import org.opensearch.threadpool.ThreadPool; -import java.util.ArrayList; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.concurrent.CompletableFuture; @@ -152,10 +152,10 @@ public CompletableFuture execute() { logger.info("Starting {}.", this.id); // get the input data from predecessor(s) - List input = new ArrayList(); - input.add(this.input); + Map inputMap = new HashMap<>(); for (CompletableFuture cf : predFutures) { - input.add(cf.get()); + WorkflowData wd = cf.get(); + inputMap.put(wd.getNodeId(), wd); } ScheduledCancellable delayExec = null; @@ -167,7 +167,12 @@ public CompletableFuture execute() { }, this.nodeTimeout, ThreadPool.Names.SAME); } // record start time for this step. - CompletableFuture stepFuture = this.workflowStep.execute(input); + CompletableFuture stepFuture = this.workflowStep.execute( + this.id, + this.input, + inputMap, + this.previousNodeInputs + ); // If completed exceptionally, this is a no-op future.complete(stepFuture.get()); // record end time passing workflow steps diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java index ad6cbff8f..27aa5e537 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java @@ -24,6 +24,7 @@ import org.opensearch.ml.common.transport.register.MLRegisterModelInput.MLRegisterModelInputBuilder; import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; +import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Map.Entry; @@ -64,7 +65,12 @@ public RegisterLocalModelStep(MachineLearningNodeClient mlClient) { } @Override - public CompletableFuture execute(List data) { + public CompletableFuture execute( + String currentNodeId, + WorkflowData currentNodeInputs, + Map outputs, + Map previousNodeInputs + ) { CompletableFuture registerLocalModelFuture = new CompletableFuture<>(); @@ -78,7 +84,8 @@ public void onResponse(MLRegisterModelResponse mlRegisterModelResponse) { Map.entry(TASK_ID, mlRegisterModelResponse.getTaskId()), Map.entry(REGISTER_MODEL_STATUS, mlRegisterModelResponse.getStatus()) ), - data.get(0).getWorkflowId() + currentNodeInputs.getWorkflowId(), + currentNodeInputs.getNodeId() ) ); } @@ -102,6 +109,12 @@ public void onFailure(Exception e) { String allConfig = null; String url = null; + // TODO: Recreating the list to get this compiling + // Need to refactor the below iteration to pull directly from the maps + List data = new ArrayList<>(); + data.add(currentNodeInputs); + data.addAll(outputs.values()); + for (WorkflowData workflowData : data) { Map content = workflowData.getContent(); diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java index de889b720..e41323a14 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java @@ -20,6 +20,7 @@ import org.opensearch.ml.common.transport.register.MLRegisterModelInput.MLRegisterModelInputBuilder; import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; +import java.util.ArrayList; import java.util.List; import java.util.Locale; import java.util.Map; @@ -55,7 +56,12 @@ public RegisterRemoteModelStep(MachineLearningNodeClient mlClient) { } @Override - public CompletableFuture execute(List data) { + public CompletableFuture execute( + String currentNodeId, + WorkflowData currentNodeInputs, + Map outputs, + Map previousNodeInputs + ) { CompletableFuture registerRemoteModelFuture = new CompletableFuture<>(); @@ -69,7 +75,8 @@ public void onResponse(MLRegisterModelResponse mlRegisterModelResponse) { Map.entry(MODEL_ID, mlRegisterModelResponse.getModelId()), Map.entry(REGISTER_MODEL_STATUS, mlRegisterModelResponse.getStatus()) ), - data.get(0).getWorkflowId() + currentNodeInputs.getWorkflowId(), + currentNodeInputs.getNodeId() ) ); } @@ -87,6 +94,12 @@ public void onFailure(Exception e) { String description = null; String connectorId = null; + // TODO: Recreating the list to get this compiling + // Need to refactor the below iteration to pull directly from the maps + List data = new ArrayList<>(); + data.add(currentNodeInputs); + data.addAll(outputs.values()); + // TODO : Handle inline connector configuration : https://github.com/opensearch-project/flow-framework/issues/149 for (WorkflowData workflowData : data) { diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowData.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowData.java index 4f62885e9..a0d901f74 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowData.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowData.java @@ -28,18 +28,21 @@ public class WorkflowData { @Nullable private String workflowId; + @Nullable + private String nodeId; private WorkflowData() { - this(Collections.emptyMap(), Collections.emptyMap(), ""); + this(Collections.emptyMap(), Collections.emptyMap(), null, null); } /** * Instantiate this object with content and empty params. * @param content The content map * @param workflowId The workflow ID associated with this step + * @param nodeId The node ID associated with this step */ - public WorkflowData(Map content, @Nullable String workflowId) { - this(content, Collections.emptyMap(), workflowId); + public WorkflowData(Map content, @Nullable String workflowId, @Nullable String nodeId) { + this(content, Collections.emptyMap(), workflowId, nodeId); } /** @@ -47,11 +50,13 @@ public WorkflowData(Map content, @Nullable String workflowId) { * @param content The content map * @param params The params map * @param workflowId The workflow ID associated with this step + * @param nodeId The node ID associated with this step */ - public WorkflowData(Map content, Map params, @Nullable String workflowId) { + public WorkflowData(Map content, Map params, @Nullable String workflowId, @Nullable String nodeId) { this.content = Map.copyOf(content); this.params = Map.copyOf(params); this.workflowId = workflowId; + this.nodeId = nodeId; } /** @@ -72,11 +77,20 @@ public Map getParams() { }; /** - * Returns the workflowId associated with this workflow. + * Returns the workflowId associated with this data. * @return the workflowId of this data. */ @Nullable public String getWorkflowId() { return this.workflowId; }; + + /** + * Returns the nodeId associated with this data. + * @return the nodeId of this data. + */ + @Nullable + public String getNodeId() { + return this.nodeId; + }; } diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java index da0705eea..3e8b77f9d 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java @@ -70,7 +70,7 @@ public List sortProcessNodes(Workflow workflow, String workflowId) Map idToNodeMap = new HashMap<>(); for (WorkflowNode node : sortedNodes) { WorkflowStep step = workflowStepFactory.createStep(node.type()); - WorkflowData data = new WorkflowData(node.userInputs(), workflow.userParams(), workflowId); + WorkflowData data = new WorkflowData(node.userInputs(), workflow.userParams(), workflowId, node.id()); List predecessorNodes = workflow.edges() .stream() .filter(e -> e.destination().equals(node.id())) diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStep.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStep.java index 1f5545cdf..f106ee652 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStep.java @@ -9,7 +9,7 @@ package org.opensearch.flowframework.workflow; import java.io.IOException; -import java.util.List; +import java.util.Map; import java.util.concurrent.CompletableFuture; /** @@ -19,11 +19,19 @@ public interface WorkflowStep { /** * Triggers the actual processing of the building block. - * @param data representing input params and content, or output content of previous steps. The first element of the list is data (if any) provided from parsing the template, and may be {@link WorkflowData#EMPTY}. + * @param currentNodeId The id of the node executing this step + * @param currentNodeInputs Input params and content for this node, from workflow parsing + * @param previousNodeInputs Input params for this node that come from previous steps + * @param outputs WorkflowData content of previous steps. * @return A CompletableFuture of the building block. This block should return immediately, but not be completed until the step executes, containing either the step's output data or {@link WorkflowData#EMPTY} which may be passed to follow-on steps. * @throws IOException on a failure. */ - CompletableFuture execute(List data) throws IOException; + CompletableFuture execute( + String currentNodeId, + WorkflowData currentNodeInputs, + Map outputs, + Map previousNodeInputs + ) throws IOException; /** * Gets the name of the workflow step. diff --git a/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java index e26fdf0c4..de3add996 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java @@ -20,7 +20,7 @@ import org.opensearch.test.OpenSearchTestCase; import java.io.IOException; -import java.util.List; +import java.util.Collections; import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; @@ -71,7 +71,8 @@ public void setUp() throws Exception { Map.entry(CommonValue.CREDENTIAL_FIELD, credentials), Map.entry(CommonValue.ACTIONS_FIELD, actions) ), - "test-id" + "test-id", + "test-node-id" ); } @@ -90,7 +91,12 @@ public void testCreateConnector() throws IOException, ExecutionException, Interr return null; }).when(machineLearningNodeClient).createConnector(any(MLCreateConnectorInput.class), actionListenerCaptor.capture()); - CompletableFuture future = createConnectorStep.execute(List.of(inputData)); + CompletableFuture future = createConnectorStep.execute( + inputData.getNodeId(), + inputData, + Collections.emptyMap(), + Collections.emptyMap() + ); verify(machineLearningNodeClient).createConnector(any(MLCreateConnectorInput.class), actionListenerCaptor.capture()); @@ -111,7 +117,12 @@ public void testCreateConnectorFailure() throws IOException { return null; }).when(machineLearningNodeClient).createConnector(any(MLCreateConnectorInput.class), actionListenerCaptor.capture()); - CompletableFuture future = createConnectorStep.execute(List.of(inputData)); + CompletableFuture future = createConnectorStep.execute( + inputData.getNodeId(), + inputData, + Collections.emptyMap(), + Collections.emptyMap() + ); verify(machineLearningNodeClient).createConnector(any(MLCreateConnectorInput.class), actionListenerCaptor.capture()); diff --git a/src/test/java/org/opensearch/flowframework/workflow/CreateIndexStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/CreateIndexStepTests.java index 67cb6cb9b..8be5c5787 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/CreateIndexStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/CreateIndexStepTests.java @@ -24,8 +24,8 @@ import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; +import java.util.Collections; import java.util.HashMap; -import java.util.List; import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; @@ -69,7 +69,7 @@ public class CreateIndexStepTests extends OpenSearchTestCase { public void setUp() throws Exception { super.setUp(); MockitoAnnotations.openMocks(this); - inputData = new WorkflowData(Map.ofEntries(Map.entry("index_name", "demo"), Map.entry("type", "knn")), "test-id"); + inputData = new WorkflowData(Map.ofEntries(Map.entry("index_name", "demo"), Map.entry("type", "knn")), "test-id", "test-node-id"); clusterService = mock(ClusterService.class); client = mock(Client.class); adminClient = mock(AdminClient.class); @@ -91,7 +91,12 @@ public void setUp() throws Exception { public void testCreateIndexStep() throws ExecutionException, InterruptedException { @SuppressWarnings({ "unchecked" }) ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); - CompletableFuture future = createIndexStep.execute(List.of(inputData)); + CompletableFuture future = createIndexStep.execute( + inputData.getNodeId(), + inputData, + Collections.emptyMap(), + Collections.emptyMap() + ); assertFalse(future.isDone()); verify(indicesAdminClient, times(1)).create(any(CreateIndexRequest.class), actionListenerCaptor.capture()); actionListenerCaptor.getValue().onResponse(new CreateIndexResponse(true, true, "demo")); @@ -106,7 +111,12 @@ public void testCreateIndexStep() throws ExecutionException, InterruptedExceptio public void testCreateIndexStepFailure() throws ExecutionException, InterruptedException { @SuppressWarnings({ "unchecked" }) ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); - CompletableFuture future = createIndexStep.execute(List.of(inputData)); + CompletableFuture future = createIndexStep.execute( + inputData.getNodeId(), + inputData, + Collections.emptyMap(), + Collections.emptyMap() + ); assertFalse(future.isDone()); verify(indicesAdminClient, times(1)).create(any(CreateIndexRequest.class), actionListenerCaptor.capture()); diff --git a/src/test/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStepTests.java index 194c80eb0..f0a970758 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStepTests.java @@ -16,7 +16,7 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.test.OpenSearchTestCase; -import java.util.List; +import java.util.Collections; import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; @@ -51,11 +51,12 @@ public void setUp() throws Exception { Map.entry("input_field_name", "inputField"), Map.entry("output_field_name", "outputField") ), - "test-id" + "test-id", + "test-node-id" ); // Set output data to returned pipelineId - outpuData = new WorkflowData(Map.ofEntries(Map.entry("pipeline_id", "pipelineId")), "test-id"); + outpuData = new WorkflowData(Map.ofEntries(Map.entry("pipeline_id", "pipelineId")), "test-id", "test-node-id"); client = mock(Client.class); adminClient = mock(AdminClient.class); @@ -71,7 +72,12 @@ public void testCreateIngestPipelineStep() throws InterruptedException, Executio @SuppressWarnings("unchecked") ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); - CompletableFuture future = createIngestPipelineStep.execute(List.of(inputData)); + CompletableFuture future = createIngestPipelineStep.execute( + inputData.getNodeId(), + inputData, + Collections.emptyMap(), + Collections.emptyMap() + ); assertFalse(future.isDone()); @@ -89,7 +95,12 @@ public void testCreateIngestPipelineStepFailure() throws InterruptedException { @SuppressWarnings("unchecked") ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); - CompletableFuture future = createIngestPipelineStep.execute(List.of(inputData)); + CompletableFuture future = createIngestPipelineStep.execute( + inputData.getNodeId(), + inputData, + Collections.emptyMap(), + Collections.emptyMap() + ); assertFalse(future.isDone()); @@ -115,10 +126,16 @@ public void testMissingData() throws InterruptedException { Map.entry("type", "text_embedding"), Map.entry("model_id", "model_id") ), - "test-id" + "test-id", + "test-node-id" ); - CompletableFuture future = createIngestPipelineStep.execute(List.of(incorrectData)); + CompletableFuture future = createIngestPipelineStep.execute( + incorrectData.getNodeId(), + incorrectData, + Collections.emptyMap(), + Collections.emptyMap() + ); assertTrue(future.isDone() && future.isCompletedExceptionally()); ExecutionException exception = assertThrows(ExecutionException.class, () -> future.get()); diff --git a/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java index fd856b945..670933373 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java @@ -19,7 +19,7 @@ import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse; import org.opensearch.test.OpenSearchTestCase; -import java.util.List; +import java.util.Collections; import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; @@ -44,7 +44,7 @@ public class DeployModelStepTests extends OpenSearchTestCase { public void setUp() throws Exception { super.setUp(); - inputData = new WorkflowData(Map.ofEntries(Map.entry("model_id", "modelId")), "test-id"); + inputData = new WorkflowData(Map.ofEntries(Map.entry("model_id", "modelId")), "test-id", "test-node-id"); MockitoAnnotations.openMocks(this); @@ -67,7 +67,12 @@ public void testDeployModel() throws ExecutionException, InterruptedException { return null; }).when(machineLearningNodeClient).deploy(eq("modelId"), actionListenerCaptor.capture()); - CompletableFuture future = deployModel.execute(List.of(inputData)); + CompletableFuture future = deployModel.execute( + inputData.getNodeId(), + inputData, + Collections.emptyMap(), + Collections.emptyMap() + ); verify(machineLearningNodeClient).deploy(eq("modelId"), actionListenerCaptor.capture()); @@ -87,7 +92,12 @@ public void testDeployModelFailure() { return null; }).when(machineLearningNodeClient).deploy(eq("modelId"), actionListenerCaptor.capture()); - CompletableFuture future = deployModel.execute(List.of(inputData)); + CompletableFuture future = deployModel.execute( + inputData.getNodeId(), + inputData, + Collections.emptyMap(), + Collections.emptyMap() + ); verify(machineLearningNodeClient).deploy(eq("modelId"), actionListenerCaptor.capture()); diff --git a/src/test/java/org/opensearch/flowframework/workflow/GetMLTaskStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/GetMLTaskStepTests.java index efb59d42a..bd62ddfc7 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/GetMLTaskStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/GetMLTaskStepTests.java @@ -21,7 +21,7 @@ import org.opensearch.ml.common.MLTaskState; import org.opensearch.test.OpenSearchTestCase; -import java.util.List; +import java.util.Collections; import java.util.Map; import java.util.Set; import java.util.concurrent.CompletableFuture; @@ -70,7 +70,7 @@ public void setUp() throws Exception { when(clusterService.getClusterSettings()).thenReturn(clusterSettings); this.getMLTaskStep = spy(new GetMLTaskStep(testMaxRetrySetting, clusterService, mlNodeClient)); - this.workflowData = new WorkflowData(Map.ofEntries(Map.entry(TASK_ID, "test")), "test-id"); + this.workflowData = new WorkflowData(Map.ofEntries(Map.entry(TASK_ID, "test")), "test-id", "test-node-id"); } public void testGetMLTaskSuccess() throws Exception { @@ -85,7 +85,12 @@ public void testGetMLTaskSuccess() throws Exception { return null; }).when(mlNodeClient).getTask(any(), any()); - CompletableFuture future = this.getMLTaskStep.execute(List.of(workflowData)); + CompletableFuture future = this.getMLTaskStep.execute( + workflowData.getNodeId(), + workflowData, + Collections.emptyMap(), + Collections.emptyMap() + ); verify(mlNodeClient, times(1)).getTask(any(), any()); @@ -102,7 +107,12 @@ public void testGetMLTaskFailure() { return null; }).when(mlNodeClient).getTask(any(), any()); - CompletableFuture future = this.getMLTaskStep.execute(List.of(workflowData)); + CompletableFuture future = this.getMLTaskStep.execute( + workflowData.getNodeId(), + workflowData, + Collections.emptyMap(), + Collections.emptyMap() + ); assertTrue(future.isDone()); assertTrue(future.isCompletedExceptionally()); ExecutionException ex = expectThrows(ExecutionException.class, () -> future.get().getClass()); @@ -111,7 +121,12 @@ public void testGetMLTaskFailure() { } public void testMissingInputs() { - CompletableFuture future = this.getMLTaskStep.execute(List.of(WorkflowData.EMPTY)); + CompletableFuture future = this.getMLTaskStep.execute( + "nodeID", + WorkflowData.EMPTY, + Collections.emptyMap(), + Collections.emptyMap() + ); assertTrue(future.isDone()); assertTrue(future.isCompletedExceptionally()); ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); diff --git a/src/test/java/org/opensearch/flowframework/workflow/ModelGroupStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/ModelGroupStepTests.java index f763c8005..bc914baa7 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/ModelGroupStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/ModelGroupStepTests.java @@ -20,7 +20,7 @@ import org.opensearch.test.OpenSearchTestCase; import java.io.IOException; -import java.util.List; +import java.util.Collections; import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; @@ -29,7 +29,6 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; -import static org.junit.Assert.assertEquals; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.verify; @@ -54,7 +53,8 @@ public void setUp() throws Exception { Map.entry("access_mode", AccessMode.PUBLIC), Map.entry("add_all_backend_roles", false) ), - "test-id" + "test-id", + "test-node-id" ); } @@ -75,7 +75,12 @@ public void testRegisterModelGroup() throws ExecutionException, InterruptedExcep return null; }).when(machineLearningNodeClient).registerModelGroup(any(MLRegisterModelGroupInput.class), actionListenerCaptor.capture()); - CompletableFuture future = modelGroupStep.execute(List.of(inputData)); + CompletableFuture future = modelGroupStep.execute( + inputData.getNodeId(), + inputData, + Collections.emptyMap(), + Collections.emptyMap() + ); verify(machineLearningNodeClient).registerModelGroup(any(MLRegisterModelGroupInput.class), actionListenerCaptor.capture()); @@ -97,7 +102,12 @@ public void testRegisterModelGroupFailure() throws ExecutionException, Interrupt return null; }).when(machineLearningNodeClient).registerModelGroup(any(MLRegisterModelGroupInput.class), actionListenerCaptor.capture()); - CompletableFuture future = modelGroupStep.execute(List.of(inputData)); + CompletableFuture future = modelGroupStep.execute( + inputData.getNodeId(), + inputData, + Collections.emptyMap(), + Collections.emptyMap() + ); verify(machineLearningNodeClient).registerModelGroup(any(MLRegisterModelGroupInput.class), actionListenerCaptor.capture()); @@ -111,7 +121,12 @@ public void testRegisterModelGroupFailure() throws ExecutionException, Interrupt public void testRegisterModelGroupWithNoName() throws IOException { ModelGroupStep modelGroupStep = new ModelGroupStep(machineLearningNodeClient); - CompletableFuture future = modelGroupStep.execute(List.of(inputDataWithNoName)); + CompletableFuture future = modelGroupStep.execute( + inputDataWithNoName.getNodeId(), + inputDataWithNoName, + Collections.emptyMap(), + Collections.emptyMap() + ); assertTrue(future.isCompletedExceptionally()); ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); diff --git a/src/test/java/org/opensearch/flowframework/workflow/NoOpStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/NoOpStepTests.java index 6c03cd87d..1782375cc 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/NoOpStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/NoOpStepTests.java @@ -19,7 +19,12 @@ public class NoOpStepTests extends OpenSearchTestCase { public void testNoOpStep() throws IOException { NoOpStep noopStep = new NoOpStep(); assertEquals(NoOpStep.NAME, noopStep.getName()); - CompletableFuture future = noopStep.execute(Collections.emptyList()); + CompletableFuture future = noopStep.execute( + "nodeId", + WorkflowData.EMPTY, + Collections.emptyMap(), + Collections.emptyMap() + ); assertTrue(future.isDone()); assertFalse(future.isCompletedExceptionally()); } diff --git a/src/test/java/org/opensearch/flowframework/workflow/ProcessNodeTests.java b/src/test/java/org/opensearch/flowframework/workflow/ProcessNodeTests.java index 6aae139e4..f50250ea5 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/ProcessNodeTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/ProcessNodeTests.java @@ -56,9 +56,14 @@ public void testNode() throws InterruptedException, ExecutionException { // Tests where execute nas no timeout ProcessNode nodeA = new ProcessNode("A", new WorkflowStep() { @Override - public CompletableFuture execute(List data) { + public CompletableFuture execute( + String currentNodeId, + WorkflowData currentNodeInputs, + Map outputs, + Map previousNodeInputs + ) { CompletableFuture f = new CompletableFuture<>(); - f.complete(new WorkflowData(Map.of("test", "output"), "test-id")); + f.complete(new WorkflowData(Map.of("test", "output"), "test-id", "test-node-id")); return f; } @@ -68,7 +73,7 @@ public String getName() { } }, Map.of(), - new WorkflowData(Map.of("test", "input"), Map.of("foo", "bar"), "test-id"), + new WorkflowData(Map.of("test", "input"), Map.of("foo", "bar"), "test-id", "test-node-id"), List.of(successfulNode), testThreadPool, TimeValue.timeValueMillis(50) @@ -78,6 +83,7 @@ public String getName() { assertEquals("input", nodeA.input().getContent().get("test")); assertEquals("bar", nodeA.input().getParams().get("foo")); assertEquals("test-id", nodeA.input().getWorkflowId()); + assertEquals("test-node-id", nodeA.input().getNodeId()); assertEquals(1, nodeA.predecessors().size()); assertEquals(50, nodeA.nodeTimeout().millis()); assertEquals("A", nodeA.toString()); @@ -91,7 +97,12 @@ public void testNodeNoTimeout() throws InterruptedException, ExecutionException // Tests where execute finishes before timeout ProcessNode nodeB = new ProcessNode("B", new WorkflowStep() { @Override - public CompletableFuture execute(List data) { + public CompletableFuture execute( + String currentNodeId, + WorkflowData currentNodeInputs, + Map outputs, + Map previousNodeInputs + ) { CompletableFuture future = new CompletableFuture<>(); testThreadPool.schedule( () -> future.complete(WorkflowData.EMPTY), @@ -121,7 +132,12 @@ public void testNodeTimeout() throws InterruptedException, ExecutionException { // Tests where execute finishes after timeout ProcessNode nodeZ = new ProcessNode("Zzz", new WorkflowStep() { @Override - public CompletableFuture execute(List data) { + public CompletableFuture execute( + String currentNodeId, + WorkflowData currentNodeInputs, + Map outputs, + Map previousNodeInputs + ) { CompletableFuture future = new CompletableFuture<>(); testThreadPool.schedule(() -> future.complete(WorkflowData.EMPTY), TimeValue.timeValueMinutes(1), ThreadPool.Names.GENERIC); return future; @@ -148,7 +164,12 @@ public void testExceptions() { // Tests where a predecessor future completed exceptionally ProcessNode nodeE = new ProcessNode("E", new WorkflowStep() { @Override - public CompletableFuture execute(List data) { + public CompletableFuture execute( + String currentNodeId, + WorkflowData currentNodeInputs, + Map outputs, + Map previousNodeInputs + ) { CompletableFuture f = new CompletableFuture<>(); f.complete(WorkflowData.EMPTY); return f; diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalModelStepTests.java index bd40c50ad..b7b47de46 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalModelStepTests.java @@ -20,7 +20,7 @@ import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; import org.opensearch.test.OpenSearchTestCase; -import java.util.List; +import java.util.Collections; import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; @@ -70,7 +70,8 @@ public void setUp() throws Exception { Map.entry("framework_type", "sentence_transformers"), Map.entry("url", "something.com") ), - "test-id" + "test-id", + "test-node-id" ); } @@ -87,7 +88,12 @@ public void testRegisterLocalModelSuccess() throws Exception { return null; }).when(machineLearningNodeClient).register(any(MLRegisterModelInput.class), any()); - CompletableFuture future = registerLocalModelStep.execute(List.of(workflowData)); + CompletableFuture future = registerLocalModelStep.execute( + workflowData.getNodeId(), + workflowData, + Collections.emptyMap(), + Collections.emptyMap() + ); verify(machineLearningNodeClient).register(any(MLRegisterModelInput.class), any()); assertTrue(future.isDone()); @@ -105,7 +111,12 @@ public void testRegisterLocalModelFailure() { return null; }).when(machineLearningNodeClient).register(any(MLRegisterModelInput.class), any()); - CompletableFuture future = this.registerLocalModelStep.execute(List.of(workflowData)); + CompletableFuture future = this.registerLocalModelStep.execute( + workflowData.getNodeId(), + workflowData, + Collections.emptyMap(), + Collections.emptyMap() + ); assertTrue(future.isDone()); assertTrue(future.isCompletedExceptionally()); ExecutionException ex = expectThrows(ExecutionException.class, () -> future.get().getClass()); @@ -114,7 +125,12 @@ public void testRegisterLocalModelFailure() { } public void testMissingInputs() { - CompletableFuture future = registerLocalModelStep.execute(List.of(WorkflowData.EMPTY)); + CompletableFuture future = registerLocalModelStep.execute( + "nodeId", + WorkflowData.EMPTY, + Collections.emptyMap(), + Collections.emptyMap() + ); assertTrue(future.isDone()); assertTrue(future.isCompletedExceptionally()); ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java index e60707f67..cde194326 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java @@ -18,7 +18,7 @@ import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; import org.opensearch.test.OpenSearchTestCase; -import java.util.List; +import java.util.Collections; import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; @@ -55,7 +55,8 @@ public void setUp() throws Exception { Map.entry("description", "description"), Map.entry("connector_id", "abcdefg") ), - "test-id" + "test-id", + "test-node-id" ); } @@ -72,7 +73,12 @@ public void testRegisterRemoteModelSuccess() throws Exception { return null; }).when(mlNodeClient).register(any(MLRegisterModelInput.class), any()); - CompletableFuture future = this.registerRemoteModelStep.execute(List.of(workflowData)); + CompletableFuture future = this.registerRemoteModelStep.execute( + workflowData.getNodeId(), + workflowData, + Collections.emptyMap(), + Collections.emptyMap() + ); verify(mlNodeClient, times(1)).register(any(MLRegisterModelInput.class), any()); @@ -90,7 +96,12 @@ public void testRegisterRemoteModelFailure() { return null; }).when(mlNodeClient).register(any(MLRegisterModelInput.class), any()); - CompletableFuture future = this.registerRemoteModelStep.execute(List.of(workflowData)); + CompletableFuture future = this.registerRemoteModelStep.execute( + workflowData.getNodeId(), + workflowData, + Collections.emptyMap(), + Collections.emptyMap() + ); assertTrue(future.isDone()); assertTrue(future.isCompletedExceptionally()); ExecutionException ex = expectThrows(ExecutionException.class, () -> future.get().getClass()); @@ -100,7 +111,12 @@ public void testRegisterRemoteModelFailure() { } public void testMissingInputs() { - CompletableFuture future = this.registerRemoteModelStep.execute(List.of(WorkflowData.EMPTY)); + CompletableFuture future = this.registerRemoteModelStep.execute( + "nodeId", + WorkflowData.EMPTY, + Collections.emptyMap(), + Collections.emptyMap() + ); assertTrue(future.isDone()); assertTrue(future.isCompletedExceptionally()); ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); diff --git a/src/test/java/org/opensearch/flowframework/workflow/WorkflowDataTests.java b/src/test/java/org/opensearch/flowframework/workflow/WorkflowDataTests.java index 8a4a1fda9..39023c6b4 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/WorkflowDataTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/WorkflowDataTests.java @@ -26,14 +26,17 @@ public void testWorkflowData() { assertTrue(empty.getContent().isEmpty()); Map expectedContent = Map.of("baz", new String[] { "qux", "quxx" }); - WorkflowData contentOnly = new WorkflowData(expectedContent, "test-id-123"); + WorkflowData contentOnly = new WorkflowData(expectedContent, null, null); assertTrue(contentOnly.getParams().isEmpty()); assertEquals(expectedContent, contentOnly.getContent()); + assertNull(contentOnly.getWorkflowId()); + assertNull(contentOnly.getNodeId()); Map expectedParams = Map.of("foo", "bar"); - WorkflowData contentAndParams = new WorkflowData(expectedContent, expectedParams, "test-id-123"); + WorkflowData contentAndParams = new WorkflowData(expectedContent, expectedParams, "test-id-123", "test-node-id"); assertEquals(expectedParams, contentAndParams.getParams()); assertEquals(expectedContent, contentAndParams.getContent()); assertEquals("test-id-123", contentAndParams.getWorkflowId()); + assertEquals("test-node-id", contentAndParams.getNodeId()); } }