diff --git a/src/main/java/org/opensearch/flowframework/common/CommonValue.java b/src/main/java/org/opensearch/flowframework/common/CommonValue.java index ce0df435a..94668a24c 100644 --- a/src/main/java/org/opensearch/flowframework/common/CommonValue.java +++ b/src/main/java/org/opensearch/flowframework/common/CommonValue.java @@ -27,6 +27,14 @@ private CommonValue() {} public static final String GLOBAL_CONTEXT_INDEX_MAPPING = "mappings/global-context.json"; /** Global Context index mapping version */ public static final Integer GLOBAL_CONTEXT_INDEX_VERSION = 1; + /** The template field name for template use case */ + public static final String USE_CASE_FIELD = "use_case"; + /** The template field name for template version */ + public static final String TEMPLATE_FIELD = "template"; + /** The template field name for template compatibility with OpenSearch versions */ + public static final String COMPATIBILITY_FIELD = "compatibility"; + /** The template field name for template workflows */ + public static final String WORKFLOWS_FIELD = "workflows"; /** The transport action name prefix */ public static final String TRANSPORT_ACION_NAME_PREFIX = "cluster:admin/opensearch/flow_framework/"; @@ -55,7 +63,7 @@ private CommonValue() {} /** Model Group Id field */ public static final String MODEL_GROUP_ID = "model_group_id"; /** Description field */ - public static final String DESCRIPTION = "description"; + public static final String DESCRIPTION_FIELD = "description"; /** Connector Id field */ public static final String CONNECTOR_ID = "connector_id"; /** Model format field */ @@ -72,4 +80,10 @@ private CommonValue() {} public static final String CREDENTIALS_FIELD = "credentials"; /** Connector actions field */ public static final String ACTIONS_FIELD = "actions"; + /** Backend roles for the model */ + public static final String BACKEND_ROLES_FIELD = "backend_roles"; + /** Access mode for the model */ + public static final String MODEL_ACCESS_MODE = "access_mode"; + /** Add all backend roles */ + public static final String ADD_ALL_BACKEND_ROLES = "add_all_backend_roles"; } diff --git a/src/main/java/org/opensearch/flowframework/model/Template.java b/src/main/java/org/opensearch/flowframework/model/Template.java index a1a526153..6dedb5db7 100644 --- a/src/main/java/org/opensearch/flowframework/model/Template.java +++ b/src/main/java/org/opensearch/flowframework/model/Template.java @@ -25,27 +25,19 @@ import java.util.Map.Entry; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.flowframework.common.CommonValue.COMPATIBILITY_FIELD; +import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD; +import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD; +import static org.opensearch.flowframework.common.CommonValue.TEMPLATE_FIELD; +import static org.opensearch.flowframework.common.CommonValue.USE_CASE_FIELD; +import static org.opensearch.flowframework.common.CommonValue.VERSION_FIELD; +import static org.opensearch.flowframework.common.CommonValue.WORKFLOWS_FIELD; /** * The Template is the central data structure which configures workflows. This object is used to parse JSON communicated via REST API. */ public class Template implements ToXContentObject { - /** The template field name for template name */ - public static final String NAME_FIELD = "name"; - /** The template field name for template description */ - public static final String DESCRIPTION_FIELD = "description"; - /** The template field name for template use case */ - public static final String USE_CASE_FIELD = "use_case"; - /** The template field name for template version information */ - public static final String VERSION_FIELD = "version"; - /** The template field name for template version */ - public static final String TEMPLATE_FIELD = "template"; - /** The template field name for template compatibility with OpenSearch versions */ - public static final String COMPATIBILITY_FIELD = "compatibility"; - /** The template field name for template workflows */ - public static final String WORKFLOWS_FIELD = "workflows"; - private final String name; private final String description; private final String useCase; // probably an ENUM actually diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java b/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java index dff6ac22a..e17bf2aa0 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java @@ -23,6 +23,7 @@ import java.security.AccessController; import java.security.PrivilegedActionException; import java.security.PrivilegedExceptionAction; +import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -32,7 +33,7 @@ import static org.opensearch.flowframework.common.CommonValue.ACTIONS_FIELD; import static org.opensearch.flowframework.common.CommonValue.CREDENTIALS_FIELD; -import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION; +import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD; import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD; import static org.opensearch.flowframework.common.CommonValue.PARAMETERS_FIELD; import static org.opensearch.flowframework.common.CommonValue.PROTOCOL_FIELD; @@ -85,7 +86,7 @@ public void onFailure(Exception e) { String protocol = null; Map parameters = new HashMap<>(); Map credentials = new HashMap<>(); - List actions = null; + List actions = new ArrayList<>(); for (WorkflowData workflowData : data) { Map content = workflowData.getContent(); @@ -95,8 +96,8 @@ public void onFailure(Exception e) { case NAME_FIELD: name = (String) content.get(NAME_FIELD); break; - case DESCRIPTION: - description = (String) content.get(DESCRIPTION); + case DESCRIPTION_FIELD: + description = (String) content.get(DESCRIPTION_FIELD); break; case VERSION_FIELD: version = (String) content.get(VERSION_FIELD); diff --git a/src/main/java/org/opensearch/flowframework/workflow/ModelGroupStep.java b/src/main/java/org/opensearch/flowframework/workflow/ModelGroupStep.java new file mode 100644 index 000000000..9e1010ec1 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/workflow/ModelGroupStep.java @@ -0,0 +1,149 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.workflow; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ExceptionsHelper; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.ml.client.MachineLearningNodeClient; +import org.opensearch.ml.common.AccessMode; +import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput; +import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput.MLRegisterModelGroupInputBuilder; +import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupResponse; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.concurrent.CompletableFuture; + +import static org.opensearch.flowframework.common.CommonValue.ADD_ALL_BACKEND_ROLES; +import static org.opensearch.flowframework.common.CommonValue.BACKEND_ROLES_FIELD; +import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD; +import static org.opensearch.flowframework.common.CommonValue.MODEL_ACCESS_MODE; +import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD; + +/** + * Step to register a model group + */ +public class ModelGroupStep implements WorkflowStep { + + private static final Logger logger = LogManager.getLogger(RegisterModelStep.class); + + private MachineLearningNodeClient mlClient; + + static final String NAME = "model_group"; + + /** + * Instantiate this class + * @param mlClient client to instantiate MLClient + */ + public ModelGroupStep(MachineLearningNodeClient mlClient) { + this.mlClient = mlClient; + } + + @Override + public CompletableFuture execute(List data) throws IOException { + + CompletableFuture registerModelGroupFuture = new CompletableFuture<>(); + + ActionListener actionListener = new ActionListener<>() { + @Override + public void onResponse(MLRegisterModelGroupResponse mlRegisterModelGroupResponse) { + logger.info("Model group registration successful"); + registerModelGroupFuture.complete( + new WorkflowData( + Map.ofEntries( + Map.entry("model_group_id", mlRegisterModelGroupResponse.getModelGroupId()), + Map.entry("model_group_status", mlRegisterModelGroupResponse.getStatus()) + ) + ) + ); + } + + @Override + public void onFailure(Exception e) { + logger.error("Failed to register model group"); + registerModelGroupFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); + } + }; + + String modelGroupName = null; + String description = null; + List backendRoles = new ArrayList<>(); + AccessMode modelAccessMode = null; + Boolean isAddAllBackendRoles = null; + + for (WorkflowData workflowData : data) { + Map content = workflowData.getContent(); + + for (Entry entry : content.entrySet()) { + switch (entry.getKey()) { + case NAME_FIELD: + modelGroupName = (String) content.get(NAME_FIELD); + break; + case DESCRIPTION_FIELD: + description = (String) content.get(DESCRIPTION_FIELD); + break; + case BACKEND_ROLES_FIELD: + backendRoles = getBackendRoles(content); + break; + case MODEL_ACCESS_MODE: + modelAccessMode = (AccessMode) content.get(MODEL_ACCESS_MODE); + break; + case ADD_ALL_BACKEND_ROLES: + isAddAllBackendRoles = (Boolean) content.get(ADD_ALL_BACKEND_ROLES); + break; + default: + break; + } + } + } + + if (modelGroupName == null) { + registerModelGroupFuture.completeExceptionally( + new FlowFrameworkException("Model group name is not provided", RestStatus.BAD_REQUEST) + ); + } else { + MLRegisterModelGroupInputBuilder builder = MLRegisterModelGroupInput.builder(); + builder.name(modelGroupName); + if (description != null) { + builder.description(description); + } + if (!backendRoles.isEmpty()) { + builder.backendRoles(backendRoles); + } + if (modelAccessMode != null) { + builder.modelAccessMode(modelAccessMode); + } + if (isAddAllBackendRoles != null) { + builder.isAddAllBackendRoles(isAddAllBackendRoles); + } + MLRegisterModelGroupInput mlInput = builder.build(); + + mlClient.registerModelGroup(mlInput, actionListener); + } + + return registerModelGroupFuture; + } + + @Override + public String getName() { + return NAME; + } + + @SuppressWarnings("unchecked") + private List getBackendRoles(Map content) { + return (List) content.get(BACKEND_ROLES_FIELD); + } +} diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java index df14d6c54..b6ae176d3 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java @@ -29,7 +29,7 @@ import java.util.stream.Stream; import static org.opensearch.flowframework.common.CommonValue.CONNECTOR_ID; -import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION; +import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD; import static org.opensearch.flowframework.common.CommonValue.FUNCTION_NAME; import static org.opensearch.flowframework.common.CommonValue.MODEL_CONFIG; import static org.opensearch.flowframework.common.CommonValue.MODEL_FORMAT; @@ -114,8 +114,8 @@ public void onFailure(Exception e) { case MODEL_CONFIG: modelConfig = (MLModelConfig) content.get(MODEL_CONFIG); break; - case DESCRIPTION: - description = (String) content.get(DESCRIPTION); + case DESCRIPTION_FIELD: + description = (String) content.get(DESCRIPTION_FIELD); break; case CONNECTOR_ID: connectorId = (String) content.get(CONNECTOR_ID); diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java index 48e26e5a6..7937f59db 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java @@ -40,6 +40,7 @@ private void populateMap(ClusterService clusterService, Client client, MachineLe stepMap.put(RegisterModelStep.NAME, new RegisterModelStep(mlClient)); stepMap.put(DeployModelStep.NAME, new DeployModelStep(mlClient)); stepMap.put(CreateConnectorStep.NAME, new CreateConnectorStep(mlClient)); + stepMap.put(ModelGroupStep.NAME, new ModelGroupStep(mlClient)); } /** diff --git a/src/main/resources/mappings/workflow-steps.json b/src/main/resources/mappings/workflow-steps.json index 4ef3e17c3..f836067b2 100644 --- a/src/main/resources/mappings/workflow-steps.json +++ b/src/main/resources/mappings/workflow-steps.json @@ -54,5 +54,14 @@ "outputs":[ "deploy_model_status" ] + }, + "model_group": { + "inputs":[ + "name" + ], + "outputs":[ + "model_group_id", + "model_group_status" + ] } } diff --git a/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java index bf170ab9f..b54b2a27c 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java @@ -12,6 +12,7 @@ import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.ml.client.MachineLearningNodeClient; +import org.opensearch.ml.common.connector.ConnectorAction; import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput; import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse; import org.opensearch.test.OpenSearchTestCase; @@ -34,9 +35,6 @@ public class CreateConnectorStepTests extends OpenSearchTestCase { private WorkflowData inputData = WorkflowData.EMPTY; - @Mock - ActionListener registerModelActionListener; - @Mock MachineLearningNodeClient machineLearningNodeClient; @@ -49,6 +47,10 @@ public void setUp() throws Exception { MockitoAnnotations.openMocks(this); + ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT; + String method = "post"; + String url = "foot.test"; + inputData = new WorkflowData( Map.ofEntries( Map.entry("name", "test"), @@ -57,7 +59,20 @@ public void setUp() throws Exception { Map.entry("protocol", "test"), Map.entry("params", params), Map.entry("credentials", credentials), - Map.entry("actions", List.of("actions")) + Map.entry( + "actions", + List.of( + new ConnectorAction( + actionType, + method, + url, + null, + "{ \"model\": \"${parameters.model}\", \"messages\": ${parameters.messages} }", + null, + null + ) + ) + ) ) ); diff --git a/src/test/java/org/opensearch/flowframework/workflow/ModelGroupStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/ModelGroupStepTests.java new file mode 100644 index 000000000..8868b628e --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/workflow/ModelGroupStepTests.java @@ -0,0 +1,121 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.workflow; + +import com.google.common.collect.ImmutableList; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.ml.client.MachineLearningNodeClient; +import org.opensearch.ml.common.AccessMode; +import org.opensearch.ml.common.MLTaskState; +import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput; +import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupResponse; +import org.opensearch.test.OpenSearchTestCase; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; + +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +import static org.junit.Assert.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.verify; + +public class ModelGroupStepTests extends OpenSearchTestCase { + private WorkflowData inputData = WorkflowData.EMPTY; + private WorkflowData inputDataWithNoName = WorkflowData.EMPTY; + + @Mock + MachineLearningNodeClient machineLearningNodeClient; + + @Override + public void setUp() throws Exception { + super.setUp(); + + MockitoAnnotations.openMocks(this); + inputData = new WorkflowData( + Map.ofEntries( + Map.entry("name", "test"), + Map.entry("description", "description"), + Map.entry("backend_roles", ImmutableList.of("role-1")), + Map.entry("access_mode", AccessMode.PUBLIC), + Map.entry("add_all_backend_roles", false) + ) + ); + + } + + public void testRegisterModelGroup() throws ExecutionException, InterruptedException, IOException { + String modelGroupId = "model_group_id"; + String status = MLTaskState.CREATED.name(); + + ModelGroupStep modelGroupStep = new ModelGroupStep(machineLearningNodeClient); + + @SuppressWarnings("unchecked") + ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + MLRegisterModelGroupResponse output = new MLRegisterModelGroupResponse(modelGroupId, status); + actionListener.onResponse(output); + return null; + }).when(machineLearningNodeClient).registerModelGroup(any(MLRegisterModelGroupInput.class), actionListenerCaptor.capture()); + + CompletableFuture future = modelGroupStep.execute(List.of(inputData)); + + verify(machineLearningNodeClient).registerModelGroup(any(MLRegisterModelGroupInput.class), actionListenerCaptor.capture()); + + assertTrue(future.isDone()); + assertEquals(modelGroupId, future.get().getContent().get("model_group_id")); + assertEquals(status, future.get().getContent().get("model_group_status")); + + } + + public void testRegisterModelGroupFailure() throws ExecutionException, InterruptedException, IOException { + ModelGroupStep modelGroupStep = new ModelGroupStep(machineLearningNodeClient); + + @SuppressWarnings("unchecked") + ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onFailure(new FlowFrameworkException("Failed to register model group", RestStatus.INTERNAL_SERVER_ERROR)); + return null; + }).when(machineLearningNodeClient).registerModelGroup(any(MLRegisterModelGroupInput.class), actionListenerCaptor.capture()); + + CompletableFuture future = modelGroupStep.execute(List.of(inputData)); + + verify(machineLearningNodeClient).registerModelGroup(any(MLRegisterModelGroupInput.class), actionListenerCaptor.capture()); + + assertTrue(future.isCompletedExceptionally()); + ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); + assertTrue(ex.getCause() instanceof FlowFrameworkException); + assertEquals("Failed to register model group", ex.getCause().getMessage()); + + } + + public void testRegisterModelGroupWithNoName() throws IOException { + ModelGroupStep modelGroupStep = new ModelGroupStep(machineLearningNodeClient); + + CompletableFuture future = modelGroupStep.execute(List.of(inputDataWithNoName)); + + assertTrue(future.isCompletedExceptionally()); + ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); + assertTrue(ex.getCause() instanceof FlowFrameworkException); + assertEquals("Model group name is not provided", ex.getCause().getMessage()); + } + +} diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterModelStepTests.java index 59fb1b173..ea1518d75 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterModelStepTests.java @@ -39,9 +39,6 @@ public class RegisterModelStepTests extends OpenSearchTestCase { private WorkflowData inputData = WorkflowData.EMPTY; - @Mock - ActionListener registerModelActionListener; - @Mock MachineLearningNodeClient machineLearningNodeClient; @@ -83,6 +80,7 @@ public void testRegisterModel() throws ExecutionException, InterruptedException RegisterModelStep registerModelStep = new RegisterModelStep(machineLearningNodeClient); + @SuppressWarnings("unchecked") ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); doAnswer(invocation -> { @@ -105,6 +103,7 @@ public void testRegisterModel() throws ExecutionException, InterruptedException public void testRegisterModelFailure() { RegisterModelStep registerModelStep = new RegisterModelStep(machineLearningNodeClient); + @SuppressWarnings("unchecked") ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); doAnswer(invocation -> {