-
Notifications
You must be signed in to change notification settings - Fork 31
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: Owais Kazi <[email protected]>
- Loading branch information
1 parent
6ee3d53
commit c17327c
Showing
6 changed files
with
218 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
113 changes: 113 additions & 0 deletions
113
src/main/java/org/opensearch/flowframework/workflow/ModelGroupStep.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
* | ||
* The OpenSearch Contributors require contributions made to | ||
* this file be licensed under the Apache-2.0 license or a | ||
* compatible open source license. | ||
*/ | ||
package org.opensearch.flowframework.workflow; | ||
|
||
import org.apache.logging.log4j.LogManager; | ||
import org.apache.logging.log4j.Logger; | ||
import org.opensearch.ExceptionsHelper; | ||
import org.opensearch.core.action.ActionListener; | ||
import org.opensearch.core.rest.RestStatus; | ||
import org.opensearch.flowframework.exception.FlowFrameworkException; | ||
import org.opensearch.ml.client.MachineLearningNodeClient; | ||
import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput; | ||
import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupResponse; | ||
|
||
import java.io.IOException; | ||
import java.util.List; | ||
import java.util.Map; | ||
import java.util.Map.Entry; | ||
import java.util.concurrent.CompletableFuture; | ||
import java.util.stream.Stream; | ||
|
||
import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION; | ||
import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD; | ||
|
||
/** | ||
* Step to register a model group | ||
*/ | ||
public class ModelGroupStep implements WorkflowStep { | ||
|
||
private static final Logger logger = LogManager.getLogger(RegisterModelStep.class); | ||
|
||
private MachineLearningNodeClient mlClient; | ||
|
||
static final String NAME = "model_group"; | ||
|
||
/** | ||
* Instantiate this class | ||
* @param mlClient client to instantiate MLClient | ||
*/ | ||
public ModelGroupStep(MachineLearningNodeClient mlClient) { | ||
this.mlClient = mlClient; | ||
} | ||
|
||
@Override | ||
public CompletableFuture<WorkflowData> execute(List<WorkflowData> data) throws IOException { | ||
|
||
CompletableFuture<WorkflowData> registerModelGroupFuture = new CompletableFuture<>(); | ||
|
||
ActionListener<MLRegisterModelGroupResponse> actionListener = new ActionListener<>() { | ||
@Override | ||
public void onResponse(MLRegisterModelGroupResponse mlRegisterModelGroupResponse) { | ||
logger.info("Model group registration successful"); | ||
registerModelGroupFuture.complete( | ||
new WorkflowData( | ||
Map.ofEntries( | ||
Map.entry("model_group_id", mlRegisterModelGroupResponse.getModelGroupId()), | ||
Map.entry("model_group_status", mlRegisterModelGroupResponse.getStatus()) | ||
) | ||
) | ||
); | ||
} | ||
|
||
@Override | ||
public void onFailure(Exception e) { | ||
logger.error("Failed to register model group"); | ||
registerModelGroupFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); | ||
} | ||
}; | ||
|
||
String modelGroupName = null; | ||
String description = null; | ||
|
||
for (WorkflowData workflowData : data) { | ||
Map<String, Object> content = workflowData.getContent(); | ||
|
||
for (Entry<String, Object> entry : content.entrySet()) { | ||
switch (entry.getKey()) { | ||
case NAME_FIELD: | ||
modelGroupName = (String) content.get(NAME_FIELD); | ||
break; | ||
case DESCRIPTION: | ||
description = (String) content.get(DESCRIPTION); | ||
break; | ||
default: | ||
break; | ||
} | ||
} | ||
} | ||
|
||
if (Stream.of(modelGroupName, description).allMatch(x -> x != null)) { | ||
MLRegisterModelGroupInput mlInput = MLRegisterModelGroupInput.builder().name(modelGroupName).description(description).build(); | ||
|
||
mlClient.registerModelGroup(mlInput, actionListener); | ||
} else { | ||
registerModelGroupFuture.completeExceptionally( | ||
new FlowFrameworkException("Required fields are not provided", RestStatus.BAD_REQUEST) | ||
); | ||
} | ||
|
||
return registerModelGroupFuture; | ||
} | ||
|
||
@Override | ||
public String getName() { | ||
return NAME; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
97 changes: 97 additions & 0 deletions
97
src/test/java/org/opensearch/flowframework/workflow/ModelGroupStepTests.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
* | ||
* The OpenSearch Contributors require contributions made to | ||
* this file be licensed under the Apache-2.0 license or a | ||
* compatible open source license. | ||
*/ | ||
package org.opensearch.flowframework.workflow; | ||
|
||
import org.opensearch.core.action.ActionListener; | ||
import org.opensearch.core.rest.RestStatus; | ||
import org.opensearch.flowframework.exception.FlowFrameworkException; | ||
import org.opensearch.ml.client.MachineLearningNodeClient; | ||
import org.opensearch.ml.common.MLTaskState; | ||
import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput; | ||
import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupResponse; | ||
import org.opensearch.test.OpenSearchTestCase; | ||
|
||
import java.io.IOException; | ||
import java.util.List; | ||
import java.util.Map; | ||
import java.util.concurrent.CompletableFuture; | ||
import java.util.concurrent.ExecutionException; | ||
|
||
import org.mockito.ArgumentCaptor; | ||
import org.mockito.Mock; | ||
import org.mockito.MockitoAnnotations; | ||
|
||
import static org.junit.Assert.assertEquals; | ||
import static org.mockito.ArgumentMatchers.any; | ||
import static org.mockito.Mockito.doAnswer; | ||
import static org.mockito.Mockito.verify; | ||
|
||
public class ModelGroupStepTests extends OpenSearchTestCase { | ||
private WorkflowData inputData = WorkflowData.EMPTY; | ||
|
||
@Mock | ||
MachineLearningNodeClient machineLearningNodeClient; | ||
|
||
@Override | ||
public void setUp() throws Exception { | ||
super.setUp(); | ||
|
||
MockitoAnnotations.openMocks(this); | ||
|
||
inputData = new WorkflowData(Map.ofEntries(Map.entry("name", "test"), Map.entry("description", "description"))); | ||
} | ||
|
||
public void testRegisterModelGroup() throws ExecutionException, InterruptedException, IOException { | ||
String modelGroupId = "model_group_id"; | ||
String status = MLTaskState.CREATED.name(); | ||
|
||
ModelGroupStep modelGroupStep = new ModelGroupStep(machineLearningNodeClient); | ||
|
||
ArgumentCaptor<ActionListener<MLRegisterModelGroupResponse>> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); | ||
|
||
doAnswer(invocation -> { | ||
ActionListener<MLRegisterModelGroupResponse> actionListener = invocation.getArgument(1); | ||
MLRegisterModelGroupResponse output = new MLRegisterModelGroupResponse(modelGroupId, status); | ||
actionListener.onResponse(output); | ||
return null; | ||
}).when(machineLearningNodeClient).registerModelGroup(any(MLRegisterModelGroupInput.class), actionListenerCaptor.capture()); | ||
|
||
CompletableFuture<WorkflowData> future = modelGroupStep.execute(List.of(inputData)); | ||
|
||
verify(machineLearningNodeClient).registerModelGroup(any(MLRegisterModelGroupInput.class), actionListenerCaptor.capture()); | ||
|
||
assertTrue(future.isDone()); | ||
assertEquals(modelGroupId, future.get().getContent().get("model_group_id")); | ||
assertEquals(status, future.get().getContent().get("model_group_status")); | ||
|
||
} | ||
|
||
public void testRegisterModelGroupFailure() throws ExecutionException, InterruptedException, IOException { | ||
ModelGroupStep modelGroupStep = new ModelGroupStep(machineLearningNodeClient); | ||
|
||
ArgumentCaptor<ActionListener<MLRegisterModelGroupResponse>> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); | ||
|
||
doAnswer(invocation -> { | ||
ActionListener<MLRegisterModelGroupResponse> actionListener = invocation.getArgument(1); | ||
actionListener.onFailure(new FlowFrameworkException("Failed to register model group", RestStatus.INTERNAL_SERVER_ERROR)); | ||
return null; | ||
}).when(machineLearningNodeClient).registerModelGroup(any(MLRegisterModelGroupInput.class), actionListenerCaptor.capture()); | ||
|
||
CompletableFuture<WorkflowData> future = modelGroupStep.execute(List.of(inputData)); | ||
|
||
verify(machineLearningNodeClient).registerModelGroup(any(MLRegisterModelGroupInput.class), actionListenerCaptor.capture()); | ||
|
||
assertTrue(future.isCompletedExceptionally()); | ||
ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); | ||
assertTrue(ex.getCause() instanceof FlowFrameworkException); | ||
assertEquals("Failed to register model group", ex.getCause().getMessage()); | ||
|
||
} | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters