diff --git a/src/main/java/org/opensearch/flowframework/model/Template.java b/src/main/java/org/opensearch/flowframework/model/Template.java index a1a526153..72599e25e 100644 --- a/src/main/java/org/opensearch/flowframework/model/Template.java +++ b/src/main/java/org/opensearch/flowframework/model/Template.java @@ -32,19 +32,19 @@ public class Template implements ToXContentObject { /** The template field name for template name */ - public static final String NAME_FIELD = "name"; + private static final String NAME_FIELD = "name"; /** The template field name for template description */ - public static final String DESCRIPTION_FIELD = "description"; + private static final String DESCRIPTION_FIELD = "description"; /** The template field name for template use case */ - public static final String USE_CASE_FIELD = "use_case"; + private static final String USE_CASE_FIELD = "use_case"; /** The template field name for template version information */ - public static final String VERSION_FIELD = "version"; + private static final String VERSION_FIELD = "version"; /** The template field name for template version */ - public static final String TEMPLATE_FIELD = "template"; + private static final String TEMPLATE_FIELD = "template"; /** The template field name for template compatibility with OpenSearch versions */ - public static final String COMPATIBILITY_FIELD = "compatibility"; + private static final String COMPATIBILITY_FIELD = "compatibility"; /** The template field name for template workflows */ - public static final String WORKFLOWS_FIELD = "workflows"; + private static final String WORKFLOWS_FIELD = "workflows"; private final String name; private final String description; diff --git a/src/main/java/org/opensearch/flowframework/workflow/ModelGroupStep.java b/src/main/java/org/opensearch/flowframework/workflow/ModelGroupStep.java new file mode 100644 index 000000000..db77faf36 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/workflow/ModelGroupStep.java @@ -0,0 +1,113 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.workflow; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ExceptionsHelper; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.ml.client.MachineLearningNodeClient; +import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput; +import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupResponse; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.concurrent.CompletableFuture; +import java.util.stream.Stream; + +import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION; +import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD; + +/** + * Step to register a model group + */ +public class ModelGroupStep implements WorkflowStep { + + private static final Logger logger = LogManager.getLogger(RegisterModelStep.class); + + private MachineLearningNodeClient mlClient; + + static final String NAME = "model_group"; + + /** + * Instantiate this class + * @param mlClient client to instantiate MLClient + */ + public ModelGroupStep(MachineLearningNodeClient mlClient) { + this.mlClient = mlClient; + } + + @Override + public CompletableFuture execute(List data) throws IOException { + + CompletableFuture registerModelGroupFuture = new CompletableFuture<>(); + + ActionListener actionListener = new ActionListener<>() { + @Override + public void onResponse(MLRegisterModelGroupResponse mlRegisterModelGroupResponse) { + logger.info("Model group registration successful"); + registerModelGroupFuture.complete( + new WorkflowData( + Map.ofEntries( + Map.entry("model_group_id", mlRegisterModelGroupResponse.getModelGroupId()), + Map.entry("model_group_status", mlRegisterModelGroupResponse.getStatus()) + ) + ) + ); + } + + @Override + public void onFailure(Exception e) { + logger.error("Failed to register model group"); + registerModelGroupFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); + } + }; + + String modelGroupName = null; + String description = null; + + for (WorkflowData workflowData : data) { + Map content = workflowData.getContent(); + + for (Entry entry : content.entrySet()) { + switch (entry.getKey()) { + case NAME_FIELD: + modelGroupName = (String) content.get(NAME_FIELD); + break; + case DESCRIPTION: + description = (String) content.get(DESCRIPTION); + break; + default: + break; + } + } + } + + if (Stream.of(modelGroupName, description).allMatch(x -> x != null)) { + MLRegisterModelGroupInput mlInput = MLRegisterModelGroupInput.builder().name(modelGroupName).description(description).build(); + + mlClient.registerModelGroup(mlInput, actionListener); + } else { + registerModelGroupFuture.completeExceptionally( + new FlowFrameworkException("Required fields are not provided", RestStatus.BAD_REQUEST) + ); + } + + return registerModelGroupFuture; + } + + @Override + public String getName() { + return NAME; + } +} diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java index 5aabd679f..dace6c417 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java @@ -44,6 +44,7 @@ private void populateMap(ClusterService clusterService, Client client, MachineLe stepMap.put(RegisterModelStep.NAME, new RegisterModelStep(mlClient)); stepMap.put(DeployModelStep.NAME, new DeployModelStep(mlClient)); stepMap.put(CreateConnectorStep.NAME, new CreateConnectorStep(mlClient)); + stepMap.put(ModelGroupStep.NAME, new ModelGroupStep(mlClient)); // TODO: These are from the demo class as placeholders, remove when demos are deleted stepMap.put("demo_delay_3", new DemoWorkflowStep(3000)); diff --git a/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java index bf170ab9f..2c98db47b 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java @@ -34,9 +34,6 @@ public class CreateConnectorStepTests extends OpenSearchTestCase { private WorkflowData inputData = WorkflowData.EMPTY; - @Mock - ActionListener registerModelActionListener; - @Mock MachineLearningNodeClient machineLearningNodeClient; diff --git a/src/test/java/org/opensearch/flowframework/workflow/ModelGroupStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/ModelGroupStepTests.java new file mode 100644 index 000000000..98dacc4ea --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/workflow/ModelGroupStepTests.java @@ -0,0 +1,97 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.workflow; + +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.ml.client.MachineLearningNodeClient; +import org.opensearch.ml.common.MLTaskState; +import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput; +import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupResponse; +import org.opensearch.test.OpenSearchTestCase; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; + +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +import static org.junit.Assert.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.verify; + +public class ModelGroupStepTests extends OpenSearchTestCase { + private WorkflowData inputData = WorkflowData.EMPTY; + + @Mock + MachineLearningNodeClient machineLearningNodeClient; + + @Override + public void setUp() throws Exception { + super.setUp(); + + MockitoAnnotations.openMocks(this); + + inputData = new WorkflowData(Map.ofEntries(Map.entry("name", "test"), Map.entry("description", "description"))); + } + + public void testRegisterModelGroup() throws ExecutionException, InterruptedException, IOException { + String modelGroupId = "model_group_id"; + String status = MLTaskState.CREATED.name(); + + ModelGroupStep modelGroupStep = new ModelGroupStep(machineLearningNodeClient); + + ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + MLRegisterModelGroupResponse output = new MLRegisterModelGroupResponse(modelGroupId, status); + actionListener.onResponse(output); + return null; + }).when(machineLearningNodeClient).registerModelGroup(any(MLRegisterModelGroupInput.class), actionListenerCaptor.capture()); + + CompletableFuture future = modelGroupStep.execute(List.of(inputData)); + + verify(machineLearningNodeClient).registerModelGroup(any(MLRegisterModelGroupInput.class), actionListenerCaptor.capture()); + + assertTrue(future.isDone()); + assertEquals(modelGroupId, future.get().getContent().get("model_group_id")); + assertEquals(status, future.get().getContent().get("model_group_status")); + + } + + public void testRegisterModelGroupFailure() throws ExecutionException, InterruptedException, IOException { + ModelGroupStep modelGroupStep = new ModelGroupStep(machineLearningNodeClient); + + ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onFailure(new FlowFrameworkException("Failed to register model group", RestStatus.INTERNAL_SERVER_ERROR)); + return null; + }).when(machineLearningNodeClient).registerModelGroup(any(MLRegisterModelGroupInput.class), actionListenerCaptor.capture()); + + CompletableFuture future = modelGroupStep.execute(List.of(inputData)); + + verify(machineLearningNodeClient).registerModelGroup(any(MLRegisterModelGroupInput.class), actionListenerCaptor.capture()); + + assertTrue(future.isCompletedExceptionally()); + ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); + assertTrue(ex.getCause() instanceof FlowFrameworkException); + assertEquals("Failed to register model group", ex.getCause().getMessage()); + + } + +} diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterModelStepTests.java index 59fb1b173..0b0e406e1 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterModelStepTests.java @@ -39,9 +39,6 @@ public class RegisterModelStepTests extends OpenSearchTestCase { private WorkflowData inputData = WorkflowData.EMPTY; - @Mock - ActionListener registerModelActionListener; - @Mock MachineLearningNodeClient machineLearningNodeClient;