From 8cebc8c98cf8acb6c8395dfb1148d26f688754dd Mon Sep 17 00:00:00 2001 From: Owais Kazi Date: Thu, 19 Oct 2023 13:02:37 -0700 Subject: [PATCH] Added test for create connector Signed-off-by: Owais Kazi --- src/main/java/demo/Demo.java | 4 +++- src/main/java/demo/TemplateParseDemo.java | 4 +++- .../flowframework/FlowFrameworkPlugin.java | 4 +++- .../workflow/CreateConnectorStep.java | 18 +++++++++--------- .../workflow/DeployModelStep.java | 14 +++++--------- .../workflow/RegisterModelStep.java | 14 +++++--------- .../workflow/WorkflowStepFactory.java | 12 +++++++----- .../workflow/CreateConnectorStepTests.java | 11 ++++------- .../workflow/DeployModelStepTests.java | 18 +++++------------- .../workflow/RegisterModelStepTests.java | 17 +++++------------ .../workflow/WorkflowProcessSorterTests.java | 4 +++- 11 files changed, 52 insertions(+), 68 deletions(-) diff --git a/src/main/java/demo/Demo.java b/src/main/java/demo/Demo.java index 910f22b14..e4d2aa8f8 100644 --- a/src/main/java/demo/Demo.java +++ b/src/main/java/demo/Demo.java @@ -20,6 +20,7 @@ import org.opensearch.flowframework.workflow.ProcessNode; import org.opensearch.flowframework.workflow.WorkflowProcessSorter; import org.opensearch.flowframework.workflow.WorkflowStepFactory; +import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.threadpool.ThreadPool; import java.io.IOException; @@ -59,7 +60,8 @@ public static void main(String[] args) throws IOException { } ClusterService clusterService = new ClusterService(null, null, null); Client client = new NodeClient(null, null); - WorkflowStepFactory factory = new WorkflowStepFactory(clusterService, client); + MachineLearningNodeClient mlClient = new MachineLearningNodeClient(client); + WorkflowStepFactory factory = new WorkflowStepFactory(clusterService, client, mlClient); ThreadPool threadPool = new ThreadPool(Settings.EMPTY); WorkflowProcessSorter workflowProcessSorter = new WorkflowProcessSorter(factory, threadPool); diff --git a/src/main/java/demo/TemplateParseDemo.java b/src/main/java/demo/TemplateParseDemo.java index a2d0db443..e284764da 100644 --- a/src/main/java/demo/TemplateParseDemo.java +++ b/src/main/java/demo/TemplateParseDemo.java @@ -20,6 +20,7 @@ import org.opensearch.flowframework.model.Workflow; import org.opensearch.flowframework.workflow.WorkflowProcessSorter; import org.opensearch.flowframework.workflow.WorkflowStepFactory; +import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.threadpool.ThreadPool; import java.io.IOException; @@ -55,7 +56,8 @@ public static void main(String[] args) throws IOException { } ClusterService clusterService = new ClusterService(null, null, null); Client client = new NodeClient(null, null); - WorkflowStepFactory factory = new WorkflowStepFactory(clusterService, client); + MachineLearningNodeClient mlClient = new MachineLearningNodeClient(client); + WorkflowStepFactory factory = new WorkflowStepFactory(clusterService, client, mlClient); ThreadPool threadPool = new ThreadPool(Settings.EMPTY); WorkflowProcessSorter workflowProcessSorter = new WorkflowProcessSorter(factory, threadPool); diff --git a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java index 0bac15c61..b9a35c083 100644 --- a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java +++ b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java @@ -34,6 +34,7 @@ import org.opensearch.flowframework.workflow.CreateIndexStep; import org.opensearch.flowframework.workflow.WorkflowProcessSorter; import org.opensearch.flowframework.workflow.WorkflowStepFactory; +import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.plugins.ActionPlugin; import org.opensearch.plugins.Plugin; import org.opensearch.repositories.RepositoriesService; @@ -76,7 +77,8 @@ public Collection createComponents( IndexNameExpressionResolver indexNameExpressionResolver, Supplier repositoriesServiceSupplier ) { - WorkflowStepFactory workflowStepFactory = new WorkflowStepFactory(clusterService, client); + MachineLearningNodeClient mlClient = new MachineLearningNodeClient(client); + WorkflowStepFactory workflowStepFactory = new WorkflowStepFactory(clusterService, client, mlClient); WorkflowProcessSorter workflowProcessSorter = new WorkflowProcessSorter(workflowStepFactory, threadPool); // TODO : Refactor, move system index creation/associated methods outside of the CreateIndexStep diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java b/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java index 71096d4e0..afae055bc 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java @@ -10,10 +10,8 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.opensearch.client.Client; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; -import org.opensearch.flowframework.client.MLClient; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.connector.ConnectorAction; @@ -37,24 +35,22 @@ public class CreateConnectorStep implements WorkflowStep { private static final Logger logger = LogManager.getLogger(CreateConnectorStep.class); - private Client client; + private MachineLearningNodeClient mlClient; static final String NAME = "create_connector"; /** * Instantiate this class - * @param client client to instantiate MLClient + * @param mlClient client to instantiate MLClient */ - public CreateConnectorStep(Client client) { - this.client = client; + public CreateConnectorStep(MachineLearningNodeClient mlClient) { + this.mlClient = mlClient; } @Override public CompletableFuture execute(List data) throws IOException { CompletableFuture createConnectorFuture = new CompletableFuture<>(); - MachineLearningNodeClient machineLearningNodeClient = MLClient.createMLClient(client); - ActionListener actionListener = new ActionListener<>() { @Override @@ -96,12 +92,16 @@ public void onFailure(Exception e) { break; case PROTOCOL_FIELD: protocol = (String) content.get(PROTOCOL_FIELD); + break; case PARAMETERS_FIELD: parameters = getParameterMap((Map) content.get(PARAMETERS_FIELD)); + break; case CREDENTIALS_FIELD: credentials = (Map) content.get(CREDENTIALS_FIELD); + break; case ACTIONS_FIELD: actions = (List) content.get(ACTIONS_FIELD); + break; } } @@ -118,7 +118,7 @@ public void onFailure(Exception e) { .actions(actions) .build(); - machineLearningNodeClient.createConnector(mlInput, actionListener); + mlClient.createConnector(mlInput, actionListener); } return createConnectorFuture; diff --git a/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java index e4c9b1a14..07558fe0c 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java @@ -10,9 +10,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.opensearch.client.Client; import org.opensearch.core.action.ActionListener; -import org.opensearch.flowframework.client.MLClient; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse; @@ -28,15 +26,15 @@ public class DeployModelStep implements WorkflowStep { private static final Logger logger = LogManager.getLogger(DeployModelStep.class); - private Client client; + private MachineLearningNodeClient mlClient; static final String NAME = "deploy_model"; /** * Instantiate this class - * @param client client to instantiate MLClient + * @param mlClient client to instantiate MLClient */ - public DeployModelStep(Client client) { - this.client = client; + public DeployModelStep(MachineLearningNodeClient mlClient) { + this.mlClient = mlClient; } @Override @@ -44,8 +42,6 @@ public CompletableFuture execute(List data) { CompletableFuture deployModelFuture = new CompletableFuture<>(); - MachineLearningNodeClient machineLearningNodeClient = MLClient.createMLClient(client); - ActionListener actionListener = new ActionListener<>() { @Override public void onResponse(MLDeployModelResponse mlDeployModelResponse) { @@ -70,7 +66,7 @@ public void onFailure(Exception e) { break; } } - machineLearningNodeClient.deploy(modelId, actionListener); + mlClient.deploy(modelId, actionListener); return deployModelFuture; } diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java index cf880626b..bdbda66e4 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java @@ -10,9 +10,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.opensearch.client.Client; import org.opensearch.core.action.ActionListener; -import org.opensearch.flowframework.client.MLClient; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.model.MLModelConfig; @@ -44,16 +42,16 @@ public class RegisterModelStep implements WorkflowStep { private static final Logger logger = LogManager.getLogger(RegisterModelStep.class); - private Client client; + private MachineLearningNodeClient mlClient; static final String NAME = "register_model"; /** * Instantiate this class - * @param client client to instantiate MLClient + * @param mlClient client to instantiate MLClient */ - public RegisterModelStep(Client client) { - this.client = client; + public RegisterModelStep(MachineLearningNodeClient mlClient) { + this.mlClient = mlClient; } @Override @@ -61,8 +59,6 @@ public CompletableFuture execute(List data) { CompletableFuture registerModelFuture = new CompletableFuture<>(); - MachineLearningNodeClient machineLearningNodeClient = MLClient.createMLClient(client); - ActionListener actionListener = new ActionListener<>() { @Override public void onResponse(MLRegisterModelResponse mlRegisterModelResponse) { @@ -139,7 +135,7 @@ public void onFailure(Exception e) { .connectorId(connectorId) .build(); - machineLearningNodeClient.register(mlInput, actionListener); + mlClient.register(mlInput, actionListener); } return registerModelFuture; diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java index fdb82ef0b..be52a5fcd 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java @@ -10,6 +10,7 @@ import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.ml.client.MachineLearningNodeClient; import java.util.HashMap; import java.util.List; @@ -32,15 +33,16 @@ public class WorkflowStepFactory { * @param client The OpenSearch client steps can use */ - public WorkflowStepFactory(ClusterService clusterService, Client client) { - populateMap(clusterService, client); + public WorkflowStepFactory(ClusterService clusterService, Client client, MachineLearningNodeClient mlClient) { + populateMap(clusterService, client, mlClient); } - private void populateMap(ClusterService clusterService, Client client) { + private void populateMap(ClusterService clusterService, Client client, MachineLearningNodeClient mlClient) { stepMap.put(CreateIndexStep.NAME, new CreateIndexStep(clusterService, client)); stepMap.put(CreateIngestPipelineStep.NAME, new CreateIngestPipelineStep(client)); - stepMap.put(RegisterModelStep.NAME, new RegisterModelStep(client)); - stepMap.put(DeployModelStep.NAME, new DeployModelStep(client)); + stepMap.put(RegisterModelStep.NAME, new RegisterModelStep(mlClient)); + stepMap.put(DeployModelStep.NAME, new DeployModelStep(mlClient)); + stepMap.put(CreateConnectorStep.NAME, new CreateConnectorStep(mlClient)); // TODO: These are from the demo class as placeholders, remove when demos are deleted stepMap.put("demo_delay_3", new DemoWorkflowStep(3000)); diff --git a/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java index 98c31470e..65329661f 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java @@ -8,7 +8,6 @@ */ package org.opensearch.flowframework.workflow; -import org.opensearch.client.Client; import org.opensearch.core.action.ActionListener; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput; @@ -20,7 +19,6 @@ import java.util.Map; import java.util.concurrent.CompletableFuture; -import org.mockito.Answers; import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; @@ -32,9 +30,6 @@ public class CreateConnectorStepTests extends OpenSearchTestCase { private WorkflowData inputData = WorkflowData.EMPTY; - @Mock(answer = Answers.RETURNS_DEEP_STUBS) - private Client client; - @Mock ActionListener registerModelActionListener; @@ -67,12 +62,12 @@ public void setUp() throws Exception { public void testCreateConnector() throws IOException { String connectorId = "connect"; - CreateConnectorStep createConnectorStep = new CreateConnectorStep(client); + CreateConnectorStep createConnectorStep = new CreateConnectorStep(machineLearningNodeClient); ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); doAnswer(invocation -> { - ActionListener actionListener = invocation.getArgument(2); + ActionListener actionListener = invocation.getArgument(1); MLCreateConnectorResponse output = new MLCreateConnectorResponse(connectorId); actionListener.onResponse(output); return null; @@ -82,6 +77,8 @@ public void testCreateConnector() throws IOException { verify(machineLearningNodeClient).createConnector(any(MLCreateConnectorInput.class), actionListenerCaptor.capture()); + assertTrue(future.isDone()); + } } diff --git a/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java index fc7c695f8..e32a7c75f 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java @@ -10,35 +10,30 @@ import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; -import org.opensearch.client.node.NodeClient; import org.opensearch.core.action.ActionListener; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.MLTaskState; import org.opensearch.ml.common.MLTaskType; import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse; import org.opensearch.test.OpenSearchTestCase; -import org.opensearch.test.client.NoOpNodeClient; import java.util.List; import java.util.Map; import java.util.concurrent.CompletableFuture; -import org.mockito.Answers; import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.verify; @ThreadLeakScope(ThreadLeakScope.Scope.NONE) public class DeployModelStepTests extends OpenSearchTestCase { private WorkflowData inputData = WorkflowData.EMPTY; - @Mock(answer = Answers.RETURNS_DEEP_STUBS) - private NodeClient nodeClient; - @Mock MachineLearningNodeClient machineLearningNodeClient; @@ -50,8 +45,6 @@ public void setUp() throws Exception { MockitoAnnotations.openMocks(this); - nodeClient = new NoOpNodeClient("xyz"); - } public void testDeployModel() { @@ -60,13 +53,13 @@ public void testDeployModel() { String status = MLTaskState.CREATED.name(); MLTaskType mlTaskType = MLTaskType.DEPLOY_MODEL; - DeployModelStep deployModel = new DeployModelStep(nodeClient); + DeployModelStep deployModel = new DeployModelStep(machineLearningNodeClient); @SuppressWarnings("unchecked") ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); doAnswer(invocation -> { - ActionListener actionListener = invocation.getArgument(2); + ActionListener actionListener = invocation.getArgument(1); MLDeployModelResponse output = new MLDeployModelResponse(taskId, mlTaskType, status); actionListener.onResponse(output); return null; @@ -74,10 +67,9 @@ public void testDeployModel() { CompletableFuture future = deployModel.execute(List.of(inputData)); - // TODO: Find a way to verify the below - // verify(machineLearningNodeClient).deploy(eq("modelId"), actionListenerCaptor.capture()); + verify(machineLearningNodeClient).deploy(eq("modelId"), actionListenerCaptor.capture()); - assertTrue(future.isCompletedExceptionally()); + assertTrue(future.isDone()); } } diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterModelStepTests.java index b1a2b2fc0..1e8026a15 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterModelStepTests.java @@ -10,7 +10,6 @@ import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; -import org.opensearch.client.node.NodeClient; import org.opensearch.core.action.ActionListener; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.FunctionName; @@ -20,28 +19,24 @@ import org.opensearch.ml.common.transport.register.MLRegisterModelInput; import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; import org.opensearch.test.OpenSearchTestCase; -import org.opensearch.test.client.NoOpNodeClient; import java.util.List; import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; -import org.mockito.Answers; import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.verify; @ThreadLeakScope(ThreadLeakScope.Scope.NONE) public class RegisterModelStepTests extends OpenSearchTestCase { private WorkflowData inputData = WorkflowData.EMPTY; - @Mock(answer = Answers.RETURNS_DEEP_STUBS) - private NodeClient nodeClient; - @Mock ActionListener registerModelActionListener; @@ -70,7 +65,6 @@ public void setUp() throws Exception { ) ); - nodeClient = new NoOpNodeClient("xyz"); } public void testRegisterModel() throws ExecutionException, InterruptedException { @@ -85,12 +79,12 @@ public void testRegisterModel() throws ExecutionException, InterruptedException .connectorId("abcdefgh") .build(); - RegisterModelStep registerModelStep = new RegisterModelStep(nodeClient); + RegisterModelStep registerModelStep = new RegisterModelStep(machineLearningNodeClient); ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); doAnswer(invocation -> { - ActionListener actionListener = invocation.getArgument(2); + ActionListener actionListener = invocation.getArgument(1); MLRegisterModelResponse output = new MLRegisterModelResponse(taskId, status, modelId); actionListener.onResponse(output); return null; @@ -98,10 +92,9 @@ public void testRegisterModel() throws ExecutionException, InterruptedException CompletableFuture future = registerModelStep.execute(List.of(inputData)); - // TODO: Find a way to verify the below - // verify(machineLearningNodeClient).register(any(MLRegisterModelInput.class), actionListenerCaptor.capture()); + verify(machineLearningNodeClient).register(any(MLRegisterModelInput.class), actionListenerCaptor.capture()); - assertTrue(future.isCompletedExceptionally()); + assertTrue(future.isDone()); } diff --git a/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java b/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java index e8ada0e15..f728dd7b1 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java @@ -14,6 +14,7 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.flowframework.model.TemplateTestJsonUtil; import org.opensearch.flowframework.model.Workflow; +import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; @@ -60,11 +61,12 @@ public static void setup() { AdminClient adminClient = mock(AdminClient.class); ClusterService clusterService = mock(ClusterService.class); Client client = mock(Client.class); + MachineLearningNodeClient mlClient = mock(MachineLearningNodeClient.class); when(client.admin()).thenReturn(adminClient); testThreadPool = new TestThreadPool(WorkflowProcessSorterTests.class.getName()); - WorkflowStepFactory factory = new WorkflowStepFactory(clusterService, client); + WorkflowStepFactory factory = new WorkflowStepFactory(clusterService, client, mlClient); workflowProcessSorter = new WorkflowProcessSorter(factory, testThreadPool); }