Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.13] Added new Guardrail field for remote model #624

Merged
merged 1 commit into from
Mar 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,8 @@ private CommonValue() {}
public static final String PIPELINE_ID = "pipeline_id";
/** Pipeline Configurations */
public static final String CONFIGURATIONS = "configurations";
/** Guardrails field */
public static final String GUARDRAILS_FIELD = "guardrails";

/*
* Constants associated with resource provisioning / state
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.opensearch.flowframework.workflow.ProcessNode;
import org.opensearch.flowframework.workflow.WorkflowData;
import org.opensearch.flowframework.workflow.WorkflowStep;
import org.opensearch.ml.common.model.Guardrails;

import java.io.IOException;
import java.util.ArrayList;
Expand All @@ -32,6 +33,7 @@
import static java.util.concurrent.TimeUnit.SECONDS;
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.flowframework.common.CommonValue.CONFIGURATIONS;
import static org.opensearch.flowframework.common.CommonValue.GUARDRAILS_FIELD;
import static org.opensearch.flowframework.common.CommonValue.TOOLS_ORDER_FIELD;
import static org.opensearch.flowframework.util.ParseUtils.buildStringToObjectMap;
import static org.opensearch.flowframework.util.ParseUtils.buildStringToStringMap;
Expand Down Expand Up @@ -95,6 +97,9 @@
xContentBuilder.field(e.getKey());
if (e.getValue() instanceof String || e.getValue() instanceof Number || e.getValue() instanceof Boolean) {
xContentBuilder.value(e.getValue());
} else if (GUARDRAILS_FIELD.equals(e.getKey())) {
Guardrails g = (Guardrails) e.getValue();
xContentBuilder.value(g);

Check warning on line 102 in src/main/java/org/opensearch/flowframework/model/WorkflowNode.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/model/WorkflowNode.java#L101-L102

Added lines #L101 - L102 were not covered by tests
} else if (e.getValue() instanceof Map<?, ?>) {
buildStringToStringMap(xContentBuilder, (Map<?, ?>) e.getValue());
} else if (e.getValue() instanceof Object[]) {
Expand Down Expand Up @@ -156,13 +161,16 @@
userInputs.put(inputFieldName, parser.text());
break;
case START_OBJECT:
if (CONFIGURATIONS.equals(inputFieldName)) {
if (GUARDRAILS_FIELD.equals(inputFieldName)) {
userInputs.put(inputFieldName, Guardrails.parse(parser));
break;

Check warning on line 166 in src/main/java/org/opensearch/flowframework/model/WorkflowNode.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/model/WorkflowNode.java#L165-L166

Added lines #L165 - L166 were not covered by tests
} else if (CONFIGURATIONS.equals(inputFieldName)) {
Map<String, Object> configurationsMap = parser.map();
try {
String configurationsString = ParseUtils.parseArbitraryStringToObjectMapToString(configurationsMap);
userInputs.put(inputFieldName, configurationsString);
} catch (Exception ex) {
String errorMessage = "Failed to parse configuration map";
String errorMessage = "Failed to parse" + inputFieldName + "map";

Check warning on line 173 in src/main/java/org/opensearch/flowframework/model/WorkflowNode.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/model/WorkflowNode.java#L173

Added line #L173 was not covered by tests
logger.error(errorMessage, ex);
throw new FlowFrameworkException(errorMessage, RestStatus.BAD_REQUEST);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.opensearch.flowframework.util.ParseUtils;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.model.Guardrails;
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
import org.opensearch.ml.common.transport.register.MLRegisterModelInput.MLRegisterModelInputBuilder;
import org.opensearch.ml.common.transport.register.MLRegisterModelResponse;
Expand All @@ -29,6 +30,7 @@

import static org.opensearch.flowframework.common.CommonValue.DEPLOY_FIELD;
import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD;
import static org.opensearch.flowframework.common.CommonValue.GUARDRAILS_FIELD;
import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD;
import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS;
import static org.opensearch.flowframework.common.WorkflowResources.CONNECTOR_ID;
Expand Down Expand Up @@ -71,7 +73,7 @@
PlainActionFuture<WorkflowData> registerRemoteModelFuture = PlainActionFuture.newFuture();

Set<String> requiredKeys = Set.of(NAME_FIELD, CONNECTOR_ID);
Set<String> optionalKeys = Set.of(MODEL_GROUP_ID, DESCRIPTION_FIELD, DEPLOY_FIELD);
Set<String> optionalKeys = Set.of(MODEL_GROUP_ID, DESCRIPTION_FIELD, DEPLOY_FIELD, GUARDRAILS_FIELD);

try {
Map<String, Object> inputs = ParseUtils.getInputsFromPreviousSteps(
Expand All @@ -87,6 +89,7 @@
String modelGroupId = (String) inputs.get(MODEL_GROUP_ID);
String description = (String) inputs.get(DESCRIPTION_FIELD);
String connectorId = (String) inputs.get(CONNECTOR_ID);
Guardrails guardRails = (Guardrails) inputs.get(GUARDRAILS_FIELD);
final Boolean deploy = (Boolean) inputs.get(DEPLOY_FIELD);

MLRegisterModelInputBuilder builder = MLRegisterModelInput.builder()
Expand All @@ -103,6 +106,11 @@
if (deploy != null) {
builder.deployModel(deploy);
}

if (guardRails != null) {
builder.guardrails(guardRails);

Check warning on line 111 in src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java#L111

Added line #L111 was not covered by tests
}

MLRegisterModelInput mlInput = builder.build();

mlClient.register(mlInput, new ActionListener<MLRegisterModelResponse>() {
Expand Down
Loading