Skip to content

Commit

Permalink
Added optional field and handled them for creating the builder
Browse files Browse the repository at this point in the history
Signed-off-by: Owais Kazi <[email protected]>
  • Loading branch information
owaiskazi19 committed Oct 27, 2023
1 parent b043f11 commit 864d0d7
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -72,4 +72,10 @@ private CommonValue() {}
public static final String CREDENTIALS_FIELD = "credentials";
/** Connector actions field */
public static final String ACTIONS_FIELD = "actions";
/** Backend roles for the model */
public static final String BACKEND_ROLES_FIELD = "backend_roles";
/** Access mode for the model */
public static final String MODEL_ACCESS_MODE = "access_mode";
/** Add all backend roles */
public static final String ADD_ALL_BACKEND_ROLES = "add_all_backend_roles";
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,22 @@
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.AccessMode;
import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput;
import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput.MLRegisterModelGroupInputBuilder;
import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupResponse;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Stream;

import static org.opensearch.flowframework.common.CommonValue.ADD_ALL_BACKEND_ROLES;
import static org.opensearch.flowframework.common.CommonValue.BACKEND_ROLES_FIELD;
import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION;
import static org.opensearch.flowframework.common.CommonValue.MODEL_ACCESS_MODE;
import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD;

/**
Expand Down Expand Up @@ -75,6 +80,9 @@ public void onFailure(Exception e) {

String modelGroupName = null;
String description = null;
List<String> backendRoles = new ArrayList<>();
AccessMode modelAccessMode = null;
Boolean isAddAllBackendRoles = false;

for (WorkflowData workflowData : data) {
Map<String, Object> content = workflowData.getContent();
Expand All @@ -87,22 +95,42 @@ public void onFailure(Exception e) {
case DESCRIPTION:
description = (String) content.get(DESCRIPTION);
break;
case BACKEND_ROLES_FIELD:
backendRoles = (List<String>) content.get(BACKEND_ROLES_FIELD);
case MODEL_ACCESS_MODE:
modelAccessMode = (AccessMode) content.get(MODEL_ACCESS_MODE);
case ADD_ALL_BACKEND_ROLES:
isAddAllBackendRoles = (Boolean) content.get(ADD_ALL_BACKEND_ROLES);
default:
break;
}
}
}

if (Stream.of(modelGroupName, description).allMatch(x -> x != null)) {
MLRegisterModelGroupInput mlInput = MLRegisterModelGroupInput.builder().name(modelGroupName).description(description).build();

mlClient.registerModelGroup(mlInput, actionListener);
} else {
if (modelGroupName == null) {
registerModelGroupFuture.completeExceptionally(

Check warning on line 111 in src/main/java/org/opensearch/flowframework/workflow/ModelGroupStep.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/workflow/ModelGroupStep.java#L111

Added line #L111 was not covered by tests
new FlowFrameworkException("Required fields are not provided", RestStatus.BAD_REQUEST)
new FlowFrameworkException("Model group name is not provided", RestStatus.BAD_REQUEST)
);
}

MLRegisterModelGroupInputBuilder builder = MLRegisterModelGroupInput.builder();
builder.name(modelGroupName);
if (description != null) {
builder.description(description);
}
if (backendRoles != null && backendRoles.size() > 0) {
builder.backendRoles(backendRoles);
}
if (modelAccessMode != null) {
builder.modelAccessMode(modelAccessMode);
}
if (isAddAllBackendRoles != null) {
builder.isAddAllBackendRoles(isAddAllBackendRoles);
}
MLRegisterModelGroupInput mlInput = builder.build();

mlClient.registerModelGroup(mlInput, actionListener);

return registerModelGroupFuture;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@
*/
package org.opensearch.flowframework.workflow;

import com.google.common.collect.ImmutableList;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.AccessMode;
import org.opensearch.ml.common.MLTaskState;
import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput;
import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupResponse;
Expand Down Expand Up @@ -43,8 +45,15 @@ public void setUp() throws Exception {
super.setUp();

MockitoAnnotations.openMocks(this);

inputData = new WorkflowData(Map.ofEntries(Map.entry("name", "test"), Map.entry("description", "description")));
inputData = new WorkflowData(
Map.ofEntries(
Map.entry("name", "test"),
Map.entry("description", "description"),
Map.entry("backend_roles", ImmutableList.of("role-1")),
Map.entry("access_mode", AccessMode.PUBLIC),
Map.entry("add_all_backend_roles", false)
)
);
}

public void testRegisterModelGroup() throws ExecutionException, InterruptedException, IOException {
Expand Down

0 comments on commit 864d0d7

Please sign in to comment.