Skip to content

Commit

Permalink
Added test for create connector
Browse files Browse the repository at this point in the history
Signed-off-by: Owais Kazi <[email protected]>
  • Loading branch information
owaiskazi19 committed Oct 19, 2023
1 parent d65fd5e commit 8cebc8c
Show file tree
Hide file tree
Showing 11 changed files with 52 additions and 68 deletions.
4 changes: 3 additions & 1 deletion src/main/java/demo/Demo.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.opensearch.flowframework.workflow.ProcessNode;
import org.opensearch.flowframework.workflow.WorkflowProcessSorter;
import org.opensearch.flowframework.workflow.WorkflowStepFactory;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.threadpool.ThreadPool;

import java.io.IOException;
Expand Down Expand Up @@ -59,7 +60,8 @@ public static void main(String[] args) throws IOException {
}
ClusterService clusterService = new ClusterService(null, null, null);
Client client = new NodeClient(null, null);
WorkflowStepFactory factory = new WorkflowStepFactory(clusterService, client);
MachineLearningNodeClient mlClient = new MachineLearningNodeClient(client);
WorkflowStepFactory factory = new WorkflowStepFactory(clusterService, client, mlClient);

ThreadPool threadPool = new ThreadPool(Settings.EMPTY);
WorkflowProcessSorter workflowProcessSorter = new WorkflowProcessSorter(factory, threadPool);
Expand Down
4 changes: 3 additions & 1 deletion src/main/java/demo/TemplateParseDemo.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.opensearch.flowframework.model.Workflow;
import org.opensearch.flowframework.workflow.WorkflowProcessSorter;
import org.opensearch.flowframework.workflow.WorkflowStepFactory;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.threadpool.ThreadPool;

import java.io.IOException;
Expand Down Expand Up @@ -55,7 +56,8 @@ public static void main(String[] args) throws IOException {
}
ClusterService clusterService = new ClusterService(null, null, null);
Client client = new NodeClient(null, null);
WorkflowStepFactory factory = new WorkflowStepFactory(clusterService, client);
MachineLearningNodeClient mlClient = new MachineLearningNodeClient(client);
WorkflowStepFactory factory = new WorkflowStepFactory(clusterService, client, mlClient);
ThreadPool threadPool = new ThreadPool(Settings.EMPTY);
WorkflowProcessSorter workflowProcessSorter = new WorkflowProcessSorter(factory, threadPool);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.opensearch.flowframework.workflow.CreateIndexStep;
import org.opensearch.flowframework.workflow.WorkflowProcessSorter;
import org.opensearch.flowframework.workflow.WorkflowStepFactory;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.plugins.ActionPlugin;
import org.opensearch.plugins.Plugin;
import org.opensearch.repositories.RepositoriesService;
Expand Down Expand Up @@ -76,7 +77,8 @@ public Collection<Object> createComponents(
IndexNameExpressionResolver indexNameExpressionResolver,
Supplier<RepositoriesService> repositoriesServiceSupplier
) {
WorkflowStepFactory workflowStepFactory = new WorkflowStepFactory(clusterService, client);
MachineLearningNodeClient mlClient = new MachineLearningNodeClient(client);
WorkflowStepFactory workflowStepFactory = new WorkflowStepFactory(clusterService, client, mlClient);
WorkflowProcessSorter workflowProcessSorter = new WorkflowProcessSorter(workflowStepFactory, threadPool);

// TODO : Refactor, move system index creation/associated methods outside of the CreateIndexStep
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,8 @@

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.client.Client;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.client.MLClient;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.connector.ConnectorAction;
Expand All @@ -37,24 +35,22 @@ public class CreateConnectorStep implements WorkflowStep {

private static final Logger logger = LogManager.getLogger(CreateConnectorStep.class);

private Client client;
private MachineLearningNodeClient mlClient;

static final String NAME = "create_connector";

/**
* Instantiate this class
* @param client client to instantiate MLClient
* @param mlClient client to instantiate MLClient
*/
public CreateConnectorStep(Client client) {
this.client = client;
public CreateConnectorStep(MachineLearningNodeClient mlClient) {
this.mlClient = mlClient;
}

@Override
public CompletableFuture<WorkflowData> execute(List<WorkflowData> data) throws IOException {
CompletableFuture<WorkflowData> createConnectorFuture = new CompletableFuture<>();

MachineLearningNodeClient machineLearningNodeClient = MLClient.createMLClient(client);

ActionListener<MLCreateConnectorResponse> actionListener = new ActionListener<>() {

@Override
Expand Down Expand Up @@ -96,12 +92,16 @@ public void onFailure(Exception e) {
break;
case PROTOCOL_FIELD:
protocol = (String) content.get(PROTOCOL_FIELD);
break;
case PARAMETERS_FIELD:
parameters = getParameterMap((Map<String, String>) content.get(PARAMETERS_FIELD));
break;
case CREDENTIALS_FIELD:
credentials = (Map<String, String>) content.get(CREDENTIALS_FIELD);
break;
case ACTIONS_FIELD:
actions = (List<ConnectorAction>) content.get(ACTIONS_FIELD);
break;
}

}
Expand All @@ -118,7 +118,7 @@ public void onFailure(Exception e) {
.actions(actions)
.build();

machineLearningNodeClient.createConnector(mlInput, actionListener);
mlClient.createConnector(mlInput, actionListener);
}

return createConnectorFuture;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,7 @@

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.client.Client;
import org.opensearch.core.action.ActionListener;
import org.opensearch.flowframework.client.MLClient;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse;

Expand All @@ -28,24 +26,22 @@
public class DeployModelStep implements WorkflowStep {
private static final Logger logger = LogManager.getLogger(DeployModelStep.class);

private Client client;
private MachineLearningNodeClient mlClient;
static final String NAME = "deploy_model";

/**
* Instantiate this class
* @param client client to instantiate MLClient
* @param mlClient client to instantiate MLClient
*/
public DeployModelStep(Client client) {
this.client = client;
public DeployModelStep(MachineLearningNodeClient mlClient) {
this.mlClient = mlClient;
}

@Override
public CompletableFuture<WorkflowData> execute(List<WorkflowData> data) {

CompletableFuture<WorkflowData> deployModelFuture = new CompletableFuture<>();

MachineLearningNodeClient machineLearningNodeClient = MLClient.createMLClient(client);

ActionListener<MLDeployModelResponse> actionListener = new ActionListener<>() {
@Override
public void onResponse(MLDeployModelResponse mlDeployModelResponse) {
Expand All @@ -70,7 +66,7 @@ public void onFailure(Exception e) {
break;
}
}
machineLearningNodeClient.deploy(modelId, actionListener);
mlClient.deploy(modelId, actionListener);
return deployModelFuture;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,7 @@

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.client.Client;
import org.opensearch.core.action.ActionListener;
import org.opensearch.flowframework.client.MLClient;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.model.MLModelConfig;
Expand Down Expand Up @@ -44,25 +42,23 @@ public class RegisterModelStep implements WorkflowStep {

private static final Logger logger = LogManager.getLogger(RegisterModelStep.class);

private Client client;
private MachineLearningNodeClient mlClient;

static final String NAME = "register_model";

/**
* Instantiate this class
* @param client client to instantiate MLClient
* @param mlClient client to instantiate MLClient
*/
public RegisterModelStep(Client client) {
this.client = client;
public RegisterModelStep(MachineLearningNodeClient mlClient) {
this.mlClient = mlClient;
}

@Override
public CompletableFuture<WorkflowData> execute(List<WorkflowData> data) {

CompletableFuture<WorkflowData> registerModelFuture = new CompletableFuture<>();

MachineLearningNodeClient machineLearningNodeClient = MLClient.createMLClient(client);

ActionListener<MLRegisterModelResponse> actionListener = new ActionListener<>() {
@Override
public void onResponse(MLRegisterModelResponse mlRegisterModelResponse) {
Expand Down Expand Up @@ -139,7 +135,7 @@ public void onFailure(Exception e) {
.connectorId(connectorId)
.build();

machineLearningNodeClient.register(mlInput, actionListener);
mlClient.register(mlInput, actionListener);
}

return registerModelFuture;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.ml.client.MachineLearningNodeClient;

import java.util.HashMap;
import java.util.List;
Expand All @@ -32,15 +33,16 @@ public class WorkflowStepFactory {
* @param client The OpenSearch client steps can use
*/

public WorkflowStepFactory(ClusterService clusterService, Client client) {
populateMap(clusterService, client);
public WorkflowStepFactory(ClusterService clusterService, Client client, MachineLearningNodeClient mlClient) {
populateMap(clusterService, client, mlClient);
}

private void populateMap(ClusterService clusterService, Client client) {
private void populateMap(ClusterService clusterService, Client client, MachineLearningNodeClient mlClient) {
stepMap.put(CreateIndexStep.NAME, new CreateIndexStep(clusterService, client));
stepMap.put(CreateIngestPipelineStep.NAME, new CreateIngestPipelineStep(client));
stepMap.put(RegisterModelStep.NAME, new RegisterModelStep(client));
stepMap.put(DeployModelStep.NAME, new DeployModelStep(client));
stepMap.put(RegisterModelStep.NAME, new RegisterModelStep(mlClient));
stepMap.put(DeployModelStep.NAME, new DeployModelStep(mlClient));
stepMap.put(CreateConnectorStep.NAME, new CreateConnectorStep(mlClient));

// TODO: These are from the demo class as placeholders, remove when demos are deleted
stepMap.put("demo_delay_3", new DemoWorkflowStep(3000));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
*/
package org.opensearch.flowframework.workflow;

import org.opensearch.client.Client;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput;
Expand All @@ -20,7 +19,6 @@
import java.util.Map;
import java.util.concurrent.CompletableFuture;

import org.mockito.Answers;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
Expand All @@ -32,9 +30,6 @@
public class CreateConnectorStepTests extends OpenSearchTestCase {
private WorkflowData inputData = WorkflowData.EMPTY;

@Mock(answer = Answers.RETURNS_DEEP_STUBS)
private Client client;

@Mock
ActionListener<MLCreateConnectorResponse> registerModelActionListener;

Expand Down Expand Up @@ -67,12 +62,12 @@ public void setUp() throws Exception {
public void testCreateConnector() throws IOException {

String connectorId = "connect";
CreateConnectorStep createConnectorStep = new CreateConnectorStep(client);
CreateConnectorStep createConnectorStep = new CreateConnectorStep(machineLearningNodeClient);

ArgumentCaptor<ActionListener<MLCreateConnectorResponse>> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class);

doAnswer(invocation -> {
ActionListener<MLCreateConnectorResponse> actionListener = invocation.getArgument(2);
ActionListener<MLCreateConnectorResponse> actionListener = invocation.getArgument(1);
MLCreateConnectorResponse output = new MLCreateConnectorResponse(connectorId);
actionListener.onResponse(output);
return null;
Expand All @@ -82,6 +77,8 @@ public void testCreateConnector() throws IOException {

verify(machineLearningNodeClient).createConnector(any(MLCreateConnectorInput.class), actionListenerCaptor.capture());

assertTrue(future.isDone());

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -10,35 +10,30 @@

import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope;

import org.opensearch.client.node.NodeClient;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.MLTaskState;
import org.opensearch.ml.common.MLTaskType;
import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.test.client.NoOpNodeClient;

import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;

import org.mockito.Answers;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;

import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.verify;

@ThreadLeakScope(ThreadLeakScope.Scope.NONE)
public class DeployModelStepTests extends OpenSearchTestCase {

private WorkflowData inputData = WorkflowData.EMPTY;

@Mock(answer = Answers.RETURNS_DEEP_STUBS)
private NodeClient nodeClient;

@Mock
MachineLearningNodeClient machineLearningNodeClient;

Expand All @@ -50,8 +45,6 @@ public void setUp() throws Exception {

MockitoAnnotations.openMocks(this);

nodeClient = new NoOpNodeClient("xyz");

}

public void testDeployModel() {
Expand All @@ -60,24 +53,23 @@ public void testDeployModel() {
String status = MLTaskState.CREATED.name();
MLTaskType mlTaskType = MLTaskType.DEPLOY_MODEL;

DeployModelStep deployModel = new DeployModelStep(nodeClient);
DeployModelStep deployModel = new DeployModelStep(machineLearningNodeClient);

@SuppressWarnings("unchecked")
ArgumentCaptor<ActionListener<MLDeployModelResponse>> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class);

doAnswer(invocation -> {
ActionListener<MLDeployModelResponse> actionListener = invocation.getArgument(2);
ActionListener<MLDeployModelResponse> actionListener = invocation.getArgument(1);
MLDeployModelResponse output = new MLDeployModelResponse(taskId, mlTaskType, status);
actionListener.onResponse(output);
return null;
}).when(machineLearningNodeClient).deploy(eq("modelId"), actionListenerCaptor.capture());

CompletableFuture<WorkflowData> future = deployModel.execute(List.of(inputData));

// TODO: Find a way to verify the below
// verify(machineLearningNodeClient).deploy(eq("modelId"), actionListenerCaptor.capture());
verify(machineLearningNodeClient).deploy(eq("modelId"), actionListenerCaptor.capture());

assertTrue(future.isCompletedExceptionally());
assertTrue(future.isDone());

}
}
Loading

0 comments on commit 8cebc8c

Please sign in to comment.