Skip to content

Commit

Permalink
Added Registering a group model
Browse files Browse the repository at this point in the history
Signed-off-by: Owais Kazi <[email protected]>
  • Loading branch information
owaiskazi19 committed Oct 26, 2023
1 parent 6ee3d53 commit c17327c
Show file tree
Hide file tree
Showing 6 changed files with 218 additions and 13 deletions.
14 changes: 7 additions & 7 deletions src/main/java/org/opensearch/flowframework/model/Template.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,19 +32,19 @@
public class Template implements ToXContentObject {

/** The template field name for template name */
public static final String NAME_FIELD = "name";
private static final String NAME_FIELD = "name";
/** The template field name for template description */
public static final String DESCRIPTION_FIELD = "description";
private static final String DESCRIPTION_FIELD = "description";
/** The template field name for template use case */
public static final String USE_CASE_FIELD = "use_case";
private static final String USE_CASE_FIELD = "use_case";
/** The template field name for template version information */
public static final String VERSION_FIELD = "version";
private static final String VERSION_FIELD = "version";
/** The template field name for template version */
public static final String TEMPLATE_FIELD = "template";
private static final String TEMPLATE_FIELD = "template";
/** The template field name for template compatibility with OpenSearch versions */
public static final String COMPATIBILITY_FIELD = "compatibility";
private static final String COMPATIBILITY_FIELD = "compatibility";
/** The template field name for template workflows */
public static final String WORKFLOWS_FIELD = "workflows";
private static final String WORKFLOWS_FIELD = "workflows";

private final String name;
private final String description;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.flowframework.workflow;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.ExceptionsHelper;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput;
import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupResponse;

import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Stream;

import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION;
import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD;

/**
* Step to register a model group
*/
public class ModelGroupStep implements WorkflowStep {

private static final Logger logger = LogManager.getLogger(RegisterModelStep.class);

private MachineLearningNodeClient mlClient;

static final String NAME = "model_group";

/**
* Instantiate this class
* @param mlClient client to instantiate MLClient
*/
public ModelGroupStep(MachineLearningNodeClient mlClient) {
this.mlClient = mlClient;
}

@Override
public CompletableFuture<WorkflowData> execute(List<WorkflowData> data) throws IOException {

CompletableFuture<WorkflowData> registerModelGroupFuture = new CompletableFuture<>();

ActionListener<MLRegisterModelGroupResponse> actionListener = new ActionListener<>() {
@Override
public void onResponse(MLRegisterModelGroupResponse mlRegisterModelGroupResponse) {
logger.info("Model group registration successful");
registerModelGroupFuture.complete(
new WorkflowData(
Map.ofEntries(
Map.entry("model_group_id", mlRegisterModelGroupResponse.getModelGroupId()),
Map.entry("model_group_status", mlRegisterModelGroupResponse.getStatus())
)
)
);
}

@Override
public void onFailure(Exception e) {
logger.error("Failed to register model group");
registerModelGroupFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e)));
}
};

String modelGroupName = null;
String description = null;

for (WorkflowData workflowData : data) {
Map<String, Object> content = workflowData.getContent();

for (Entry<String, Object> entry : content.entrySet()) {
switch (entry.getKey()) {
case NAME_FIELD:
modelGroupName = (String) content.get(NAME_FIELD);
break;
case DESCRIPTION:
description = (String) content.get(DESCRIPTION);
break;
default:
break;
}
}
}

if (Stream.of(modelGroupName, description).allMatch(x -> x != null)) {
MLRegisterModelGroupInput mlInput = MLRegisterModelGroupInput.builder().name(modelGroupName).description(description).build();

mlClient.registerModelGroup(mlInput, actionListener);
} else {
registerModelGroupFuture.completeExceptionally(

Check warning on line 101 in src/main/java/org/opensearch/flowframework/workflow/ModelGroupStep.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/workflow/ModelGroupStep.java#L101

Added line #L101 was not covered by tests
new FlowFrameworkException("Required fields are not provided", RestStatus.BAD_REQUEST)
);
}

return registerModelGroupFuture;
}

@Override
public String getName() {
return NAME;

Check warning on line 111 in src/main/java/org/opensearch/flowframework/workflow/ModelGroupStep.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/workflow/ModelGroupStep.java#L111

Added line #L111 was not covered by tests
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ private void populateMap(ClusterService clusterService, Client client, MachineLe
stepMap.put(RegisterModelStep.NAME, new RegisterModelStep(mlClient));
stepMap.put(DeployModelStep.NAME, new DeployModelStep(mlClient));
stepMap.put(CreateConnectorStep.NAME, new CreateConnectorStep(mlClient));
stepMap.put(ModelGroupStep.NAME, new ModelGroupStep(mlClient));

// TODO: These are from the demo class as placeholders, remove when demos are deleted
stepMap.put("demo_delay_3", new DemoWorkflowStep(3000));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,6 @@
public class CreateConnectorStepTests extends OpenSearchTestCase {
private WorkflowData inputData = WorkflowData.EMPTY;

@Mock
ActionListener<MLCreateConnectorResponse> registerModelActionListener;

@Mock
MachineLearningNodeClient machineLearningNodeClient;

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.flowframework.workflow;

import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.MLTaskState;
import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput;
import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupResponse;
import org.opensearch.test.OpenSearchTestCase;

import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;

import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;

import static org.junit.Assert.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.verify;

public class ModelGroupStepTests extends OpenSearchTestCase {
private WorkflowData inputData = WorkflowData.EMPTY;

@Mock
MachineLearningNodeClient machineLearningNodeClient;

@Override
public void setUp() throws Exception {
super.setUp();

MockitoAnnotations.openMocks(this);

inputData = new WorkflowData(Map.ofEntries(Map.entry("name", "test"), Map.entry("description", "description")));
}

public void testRegisterModelGroup() throws ExecutionException, InterruptedException, IOException {
String modelGroupId = "model_group_id";
String status = MLTaskState.CREATED.name();

ModelGroupStep modelGroupStep = new ModelGroupStep(machineLearningNodeClient);

ArgumentCaptor<ActionListener<MLRegisterModelGroupResponse>> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class);

doAnswer(invocation -> {
ActionListener<MLRegisterModelGroupResponse> actionListener = invocation.getArgument(1);
MLRegisterModelGroupResponse output = new MLRegisterModelGroupResponse(modelGroupId, status);
actionListener.onResponse(output);
return null;
}).when(machineLearningNodeClient).registerModelGroup(any(MLRegisterModelGroupInput.class), actionListenerCaptor.capture());

CompletableFuture<WorkflowData> future = modelGroupStep.execute(List.of(inputData));

verify(machineLearningNodeClient).registerModelGroup(any(MLRegisterModelGroupInput.class), actionListenerCaptor.capture());

assertTrue(future.isDone());
assertEquals(modelGroupId, future.get().getContent().get("model_group_id"));
assertEquals(status, future.get().getContent().get("model_group_status"));

}

public void testRegisterModelGroupFailure() throws ExecutionException, InterruptedException, IOException {
ModelGroupStep modelGroupStep = new ModelGroupStep(machineLearningNodeClient);

ArgumentCaptor<ActionListener<MLRegisterModelGroupResponse>> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class);

doAnswer(invocation -> {
ActionListener<MLRegisterModelGroupResponse> actionListener = invocation.getArgument(1);
actionListener.onFailure(new FlowFrameworkException("Failed to register model group", RestStatus.INTERNAL_SERVER_ERROR));
return null;
}).when(machineLearningNodeClient).registerModelGroup(any(MLRegisterModelGroupInput.class), actionListenerCaptor.capture());

CompletableFuture<WorkflowData> future = modelGroupStep.execute(List.of(inputData));

verify(machineLearningNodeClient).registerModelGroup(any(MLRegisterModelGroupInput.class), actionListenerCaptor.capture());

assertTrue(future.isCompletedExceptionally());
ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent());
assertTrue(ex.getCause() instanceof FlowFrameworkException);
assertEquals("Failed to register model group", ex.getCause().getMessage());

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,6 @@
public class RegisterModelStepTests extends OpenSearchTestCase {
private WorkflowData inputData = WorkflowData.EMPTY;

@Mock
ActionListener<MLRegisterModelResponse> registerModelActionListener;

@Mock
MachineLearningNodeClient machineLearningNodeClient;

Expand Down

0 comments on commit c17327c

Please sign in to comment.