Skip to content

Commit

Permalink
Added register a group model step (#118)
Browse files Browse the repository at this point in the history
* Added Registering a group model

Signed-off-by: Owais Kazi <[email protected]>

* Updated create connector test

Signed-off-by: Owais Kazi <[email protected]>

* Added optional field and handled them for creating the builder

Signed-off-by: Owais Kazi <[email protected]>

* Added another test for name check

Signed-off-by: Owais Kazi <[email protected]>

* Addressed PR comments

Signed-off-by: Owais Kazi <[email protected]>

* Moved common fields to CommonValue

Signed-off-by: Owais Kazi <[email protected]>

---------

Signed-off-by: Owais Kazi <[email protected]>
  • Loading branch information
owaiskazi19 committed Oct 30, 2023
1 parent 6ee3d53 commit bcd53e1
Show file tree
Hide file tree
Showing 9 changed files with 322 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,14 @@ private CommonValue() {}
public static final String GLOBAL_CONTEXT_INDEX_MAPPING = "mappings/global-context.json";
/** Global Context index mapping version */
public static final Integer GLOBAL_CONTEXT_INDEX_VERSION = 1;
/** The template field name for template use case */
public static final String USE_CASE_FIELD = "use_case";
/** The template field name for template version */
public static final String TEMPLATE_FIELD = "template";
/** The template field name for template compatibility with OpenSearch versions */
public static final String COMPATIBILITY_FIELD = "compatibility";
/** The template field name for template workflows */
public static final String WORKFLOWS_FIELD = "workflows";

/** The transport action name prefix */
public static final String TRANSPORT_ACION_NAME_PREFIX = "cluster:admin/opensearch/flow_framework/";
Expand Down Expand Up @@ -55,7 +63,7 @@ private CommonValue() {}
/** Model Group Id field */
public static final String MODEL_GROUP_ID = "model_group_id";
/** Description field */
public static final String DESCRIPTION = "description";
public static final String DESCRIPTION_FIELD = "description";
/** Connector Id field */
public static final String CONNECTOR_ID = "connector_id";
/** Model format field */
Expand All @@ -72,4 +80,10 @@ private CommonValue() {}
public static final String CREDENTIALS_FIELD = "credentials";
/** Connector actions field */
public static final String ACTIONS_FIELD = "actions";
/** Backend roles for the model */
public static final String BACKEND_ROLES_FIELD = "backend_roles";
/** Access mode for the model */
public static final String MODEL_ACCESS_MODE = "access_mode";
/** Add all backend roles */
public static final String ADD_ALL_BACKEND_ROLES = "add_all_backend_roles";
}
22 changes: 7 additions & 15 deletions src/main/java/org/opensearch/flowframework/model/Template.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,27 +25,19 @@
import java.util.Map.Entry;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.flowframework.common.CommonValue.COMPATIBILITY_FIELD;
import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD;
import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD;
import static org.opensearch.flowframework.common.CommonValue.TEMPLATE_FIELD;
import static org.opensearch.flowframework.common.CommonValue.USE_CASE_FIELD;
import static org.opensearch.flowframework.common.CommonValue.VERSION_FIELD;
import static org.opensearch.flowframework.common.CommonValue.WORKFLOWS_FIELD;

/**
* The Template is the central data structure which configures workflows. This object is used to parse JSON communicated via REST API.
*/
public class Template implements ToXContentObject {

/** The template field name for template name */
public static final String NAME_FIELD = "name";
/** The template field name for template description */
public static final String DESCRIPTION_FIELD = "description";
/** The template field name for template use case */
public static final String USE_CASE_FIELD = "use_case";
/** The template field name for template version information */
public static final String VERSION_FIELD = "version";
/** The template field name for template version */
public static final String TEMPLATE_FIELD = "template";
/** The template field name for template compatibility with OpenSearch versions */
public static final String COMPATIBILITY_FIELD = "compatibility";
/** The template field name for template workflows */
public static final String WORKFLOWS_FIELD = "workflows";

private final String name;
private final String description;
private final String useCase; // probably an ENUM actually
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand All @@ -32,7 +33,7 @@

import static org.opensearch.flowframework.common.CommonValue.ACTIONS_FIELD;
import static org.opensearch.flowframework.common.CommonValue.CREDENTIALS_FIELD;
import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION;
import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD;
import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD;
import static org.opensearch.flowframework.common.CommonValue.PARAMETERS_FIELD;
import static org.opensearch.flowframework.common.CommonValue.PROTOCOL_FIELD;
Expand Down Expand Up @@ -85,7 +86,7 @@ public void onFailure(Exception e) {
String protocol = null;
Map<String, String> parameters = new HashMap<>();
Map<String, String> credentials = new HashMap<>();
List<ConnectorAction> actions = null;
List<ConnectorAction> actions = new ArrayList<>();

for (WorkflowData workflowData : data) {
Map<String, Object> content = workflowData.getContent();
Expand All @@ -95,8 +96,8 @@ public void onFailure(Exception e) {
case NAME_FIELD:
name = (String) content.get(NAME_FIELD);
break;
case DESCRIPTION:
description = (String) content.get(DESCRIPTION);
case DESCRIPTION_FIELD:
description = (String) content.get(DESCRIPTION_FIELD);
break;
case VERSION_FIELD:
version = (String) content.get(VERSION_FIELD);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.flowframework.workflow;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.ExceptionsHelper;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.AccessMode;
import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput;
import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput.MLRegisterModelGroupInputBuilder;
import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupResponse;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.concurrent.CompletableFuture;

import static org.opensearch.flowframework.common.CommonValue.ADD_ALL_BACKEND_ROLES;
import static org.opensearch.flowframework.common.CommonValue.BACKEND_ROLES_FIELD;
import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD;
import static org.opensearch.flowframework.common.CommonValue.MODEL_ACCESS_MODE;
import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD;

/**
* Step to register a model group
*/
public class ModelGroupStep implements WorkflowStep {

private static final Logger logger = LogManager.getLogger(RegisterModelStep.class);

private MachineLearningNodeClient mlClient;

static final String NAME = "model_group";

/**
* Instantiate this class
* @param mlClient client to instantiate MLClient
*/
public ModelGroupStep(MachineLearningNodeClient mlClient) {
this.mlClient = mlClient;
}

@Override
public CompletableFuture<WorkflowData> execute(List<WorkflowData> data) throws IOException {

CompletableFuture<WorkflowData> registerModelGroupFuture = new CompletableFuture<>();

ActionListener<MLRegisterModelGroupResponse> actionListener = new ActionListener<>() {
@Override
public void onResponse(MLRegisterModelGroupResponse mlRegisterModelGroupResponse) {
logger.info("Model group registration successful");
registerModelGroupFuture.complete(
new WorkflowData(
Map.ofEntries(
Map.entry("model_group_id", mlRegisterModelGroupResponse.getModelGroupId()),
Map.entry("model_group_status", mlRegisterModelGroupResponse.getStatus())
)
)
);
}

@Override
public void onFailure(Exception e) {
logger.error("Failed to register model group");
registerModelGroupFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e)));
}
};

String modelGroupName = null;
String description = null;
List<String> backendRoles = new ArrayList<>();
AccessMode modelAccessMode = null;
Boolean isAddAllBackendRoles = null;

for (WorkflowData workflowData : data) {
Map<String, Object> content = workflowData.getContent();

for (Entry<String, Object> entry : content.entrySet()) {
switch (entry.getKey()) {
case NAME_FIELD:
modelGroupName = (String) content.get(NAME_FIELD);
break;
case DESCRIPTION_FIELD:
description = (String) content.get(DESCRIPTION_FIELD);
break;
case BACKEND_ROLES_FIELD:
backendRoles = getBackendRoles(content);
break;
case MODEL_ACCESS_MODE:
modelAccessMode = (AccessMode) content.get(MODEL_ACCESS_MODE);
break;
case ADD_ALL_BACKEND_ROLES:
isAddAllBackendRoles = (Boolean) content.get(ADD_ALL_BACKEND_ROLES);
break;
default:
break;
}
}
}

if (modelGroupName == null) {
registerModelGroupFuture.completeExceptionally(
new FlowFrameworkException("Model group name is not provided", RestStatus.BAD_REQUEST)
);
} else {
MLRegisterModelGroupInputBuilder builder = MLRegisterModelGroupInput.builder();
builder.name(modelGroupName);
if (description != null) {
builder.description(description);
}
if (!backendRoles.isEmpty()) {
builder.backendRoles(backendRoles);
}
if (modelAccessMode != null) {
builder.modelAccessMode(modelAccessMode);
}
if (isAddAllBackendRoles != null) {
builder.isAddAllBackendRoles(isAddAllBackendRoles);
}
MLRegisterModelGroupInput mlInput = builder.build();

mlClient.registerModelGroup(mlInput, actionListener);
}

return registerModelGroupFuture;
}

@Override
public String getName() {
return NAME;
}

@SuppressWarnings("unchecked")
private List<String> getBackendRoles(Map<String, Object> content) {
return (List<String>) content.get(BACKEND_ROLES_FIELD);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
import java.util.stream.Stream;

import static org.opensearch.flowframework.common.CommonValue.CONNECTOR_ID;
import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION;
import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD;
import static org.opensearch.flowframework.common.CommonValue.FUNCTION_NAME;
import static org.opensearch.flowframework.common.CommonValue.MODEL_CONFIG;
import static org.opensearch.flowframework.common.CommonValue.MODEL_FORMAT;
Expand Down Expand Up @@ -114,8 +114,8 @@ public void onFailure(Exception e) {
case MODEL_CONFIG:
modelConfig = (MLModelConfig) content.get(MODEL_CONFIG);
break;
case DESCRIPTION:
description = (String) content.get(DESCRIPTION);
case DESCRIPTION_FIELD:
description = (String) content.get(DESCRIPTION_FIELD);
break;
case CONNECTOR_ID:
connectorId = (String) content.get(CONNECTOR_ID);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ private void populateMap(ClusterService clusterService, Client client, MachineLe
stepMap.put(RegisterModelStep.NAME, new RegisterModelStep(mlClient));
stepMap.put(DeployModelStep.NAME, new DeployModelStep(mlClient));
stepMap.put(CreateConnectorStep.NAME, new CreateConnectorStep(mlClient));
stepMap.put(ModelGroupStep.NAME, new ModelGroupStep(mlClient));

// TODO: These are from the demo class as placeholders, remove when demos are deleted
stepMap.put("demo_delay_3", new DemoWorkflowStep(3000));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.connector.ConnectorAction;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse;
import org.opensearch.test.OpenSearchTestCase;
Expand All @@ -34,9 +35,6 @@
public class CreateConnectorStepTests extends OpenSearchTestCase {
private WorkflowData inputData = WorkflowData.EMPTY;

@Mock
ActionListener<MLCreateConnectorResponse> registerModelActionListener;

@Mock
MachineLearningNodeClient machineLearningNodeClient;

Expand All @@ -49,6 +47,10 @@ public void setUp() throws Exception {

MockitoAnnotations.openMocks(this);

ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT;
String method = "post";
String url = "foot.test";

inputData = new WorkflowData(
Map.ofEntries(
Map.entry("name", "test"),
Expand All @@ -57,7 +59,20 @@ public void setUp() throws Exception {
Map.entry("protocol", "test"),
Map.entry("params", params),
Map.entry("credentials", credentials),
Map.entry("actions", List.of("actions"))
Map.entry(
"actions",
List.of(
new ConnectorAction(
actionType,
method,
url,
null,
"{ \"model\": \"${parameters.model}\", \"messages\": ${parameters.messages} }",
null,
null
)
)
)
)
);

Expand Down
Loading

0 comments on commit bcd53e1

Please sign in to comment.