Skip to content

Commit

Permalink
Add model interface support for remote and local custom models (#701)
Browse files Browse the repository at this point in the history
* Add model interface support for remote and local custom models

Signed-off-by: Joshua Palis <[email protected]>

* Adding changelog and javadocs

Signed-off-by: Joshua Palis <[email protected]>

* Addressing PR comments, adding to 2.14 release notes, updating MODEL_INTERFACE to INTERFACE_FIELD, adding parse util test

Signed-off-by: Joshua Palis <[email protected]>

* updating RegisterRemoteModelStepTests

Signed-off-by: Joshua Palis <[email protected]>

---------

Signed-off-by: Joshua Palis <[email protected]>
  • Loading branch information
joshpalis committed May 3, 2024
1 parent 40d9efc commit 3318d31
Show file tree
Hide file tree
Showing 10 changed files with 130 additions and 4 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.1.0/)
- Add guardrails to default use case params ([#658](https://github.com/opensearch-project/flow-framework/pull/658))
- Allow strings for boolean workflow step parameters ([#671](https://github.com/opensearch-project/flow-framework/pull/671))
- Add optional delay parameter to no-op step ([#674](https://github.com/opensearch-project/flow-framework/pull/674))
- Add model interface support for remote and local custom models ([#701](https://github.com/opensearch-project/flow-framework/pull/701))

### Bug Fixes
- Reset workflow state to initial state after successful deprovision ([#635](https://github.com/opensearch-project/flow-framework/pull/635))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ Compatible with OpenSearch 2.14.0
- Add guardrails to default use case params ([#658](https://github.com/opensearch-project/flow-framework/pull/658))
- Allow strings for boolean workflow step parameters ([#671](https://github.com/opensearch-project/flow-framework/pull/671))
- Add optional delay parameter to no-op step ([#674](https://github.com/opensearch-project/flow-framework/pull/674))
- Add model interface support for remote and local custom models ([#701](https://github.com/opensearch-project/flow-framework/pull/701))

### Bug Fixes
- Reset workflow state to initial state after successful deprovision ([#635](https://github.com/opensearch-project/flow-framework/pull/635))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,8 @@ private CommonValue() {}
public static final String GUARDRAILS_FIELD = "guardrails";
/** Delay field */
public static final String DELAY_FIELD = "delay";
/** Model Interface Field */
public static final String INTERFACE_FIELD = "interface";

/*
* Constants associated with resource provisioning / state
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.flowframework.common.CommonValue.CONFIGURATIONS;
import static org.opensearch.flowframework.common.CommonValue.GUARDRAILS_FIELD;
import static org.opensearch.flowframework.common.CommonValue.INTERFACE_FIELD;
import static org.opensearch.flowframework.common.CommonValue.TOOLS_ORDER_FIELD;
import static org.opensearch.flowframework.util.ParseUtils.buildStringToObjectMap;
import static org.opensearch.flowframework.util.ParseUtils.buildStringToStringMap;
Expand Down Expand Up @@ -164,7 +165,7 @@ public static WorkflowNode parse(XContentParser parser) throws IOException {
if (GUARDRAILS_FIELD.equals(inputFieldName)) {
userInputs.put(inputFieldName, Guardrails.parse(parser));
break;
} else if (CONFIGURATIONS.equals(inputFieldName)) {
} else if (CONFIGURATIONS.equals(inputFieldName) || INTERFACE_FIELD.equals(inputFieldName)) {
Map<String, Object> configurationsMap = parser.map();
try {
String configurationsString = ParseUtils.parseArbitraryStringToObjectMapToString(configurationsMap);
Expand Down
21 changes: 21 additions & 0 deletions src/main/java/org/opensearch/flowframework/util/ParseUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -451,4 +451,25 @@ public static String removingBackslashesAndQuotesInArrayInJsonString(String inpu
matcher.appendTail(result);
return result.toString();
}

/**
* Takes a String to json object map and converts this to a String to String map
* @param stringToObjectMap The string to object map to be transformed
* @return the transformed map
* @throws Exception for issues processing map
*/
public static Map<String, String> convertStringToObjectMapToStringToStringMap(Map<String, Object> stringToObjectMap) throws Exception {
try (Jsonb jsonb = JsonbBuilder.create()) {
Map<String, String> stringToStringMap = new HashMap<>();
for (Map.Entry<String, Object> entry : stringToObjectMap.entrySet()) {
Object value = entry.getValue();
if (value instanceof String) {
stringToStringMap.put(entry.getKey(), (String) value);
} else {
stringToStringMap.put(entry.getKey(), jsonb.toJson(value));
}
}
return stringToStringMap;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,12 @@
import org.opensearch.ExceptionsHelper;
import org.opensearch.action.support.PlainActionFuture;
import org.opensearch.common.Booleans;
import org.opensearch.common.xcontent.XContentHelper;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.bytes.BytesArray;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.MediaTypeRegistry;
import org.opensearch.flowframework.common.FlowFrameworkSettings;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.exception.WorkflowStepException;
Expand All @@ -30,6 +34,7 @@
import org.opensearch.ml.common.transport.register.MLRegisterModelInput.MLRegisterModelInputBuilder;
import org.opensearch.threadpool.ThreadPool;

import java.nio.charset.StandardCharsets;
import java.util.Map;
import java.util.Set;
import java.util.stream.Stream;
Expand All @@ -40,6 +45,7 @@
import static org.opensearch.flowframework.common.CommonValue.EMBEDDING_DIMENSION;
import static org.opensearch.flowframework.common.CommonValue.FRAMEWORK_TYPE;
import static org.opensearch.flowframework.common.CommonValue.FUNCTION_NAME;
import static org.opensearch.flowframework.common.CommonValue.INTERFACE_FIELD;
import static org.opensearch.flowframework.common.CommonValue.MODEL_CONTENT_HASH_VALUE;
import static org.opensearch.flowframework.common.CommonValue.MODEL_FORMAT;
import static org.opensearch.flowframework.common.CommonValue.MODEL_TYPE;
Expand Down Expand Up @@ -116,6 +122,7 @@ public PlainActionFuture<WorkflowData> execute(
String description = (String) inputs.get(DESCRIPTION_FIELD);
String modelGroupId = (String) inputs.get(MODEL_GROUP_ID);
String allConfig = (String) inputs.get(ALL_CONFIG);
String modelInterface = (String) inputs.get(INTERFACE_FIELD);
final Boolean deploy = inputs.containsKey(DEPLOY_FIELD) ? Booleans.parseBoolean(inputs.get(DEPLOY_FIELD).toString()) : null;

// Build register model input
Expand Down Expand Up @@ -149,6 +156,27 @@ public PlainActionFuture<WorkflowData> execute(
if (modelGroupId != null) {
mlInputBuilder.modelGroupId(modelGroupId);
}
if (modelInterface != null) {
try {
// Convert model interface string to map
BytesReference modelInterfaceBytes = new BytesArray(modelInterface.getBytes(StandardCharsets.UTF_8));
Map<String, Object> modelInterfaceAsMap = XContentHelper.convertToMap(
modelInterfaceBytes,
false,
MediaTypeRegistry.JSON
).v2();

// Convert to string to string map
Map<String, String> parameters = ParseUtils.convertStringToObjectMapToStringToStringMap(modelInterfaceAsMap);
mlInputBuilder.modelInterface(parameters);

} catch (Exception ex) {
String errorMessage = "Failed to create model interface";
logger.error(errorMessage, ex);
registerLocalModelFuture.onFailure(new WorkflowStepException(errorMessage, RestStatus.BAD_REQUEST));
}

}
if (deploy != null) {
mlInputBuilder.deployModel(deploy);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import static org.opensearch.flowframework.common.CommonValue.EMBEDDING_DIMENSION;
import static org.opensearch.flowframework.common.CommonValue.FRAMEWORK_TYPE;
import static org.opensearch.flowframework.common.CommonValue.FUNCTION_NAME;
import static org.opensearch.flowframework.common.CommonValue.INTERFACE_FIELD;
import static org.opensearch.flowframework.common.CommonValue.MODEL_CONTENT_HASH_VALUE;
import static org.opensearch.flowframework.common.CommonValue.MODEL_FORMAT;
import static org.opensearch.flowframework.common.CommonValue.MODEL_TYPE;
Expand Down Expand Up @@ -71,7 +72,7 @@ protected Set<String> getRequiredKeys() {

@Override
protected Set<String> getOptionalKeys() {
return Set.of(DESCRIPTION_FIELD, MODEL_GROUP_ID, ALL_CONFIG, DEPLOY_FIELD);
return Set.of(DESCRIPTION_FIELD, MODEL_GROUP_ID, ALL_CONFIG, DEPLOY_FIELD, INTERFACE_FIELD);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,12 @@
import org.opensearch.action.support.PlainActionFuture;
import org.opensearch.action.update.UpdateResponse;
import org.opensearch.common.Booleans;
import org.opensearch.common.xcontent.XContentHelper;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.bytes.BytesArray;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.MediaTypeRegistry;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.exception.WorkflowStepException;
import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler;
Expand All @@ -27,12 +31,14 @@
import org.opensearch.ml.common.transport.register.MLRegisterModelInput.MLRegisterModelInputBuilder;
import org.opensearch.ml.common.transport.register.MLRegisterModelResponse;

import java.nio.charset.StandardCharsets;
import java.util.Map;
import java.util.Set;

import static org.opensearch.flowframework.common.CommonValue.DEPLOY_FIELD;
import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD;
import static org.opensearch.flowframework.common.CommonValue.GUARDRAILS_FIELD;
import static org.opensearch.flowframework.common.CommonValue.INTERFACE_FIELD;
import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD;
import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS;
import static org.opensearch.flowframework.common.WorkflowResources.CONNECTOR_ID;
Expand Down Expand Up @@ -76,7 +82,7 @@ public PlainActionFuture<WorkflowData> execute(
PlainActionFuture<WorkflowData> registerRemoteModelFuture = PlainActionFuture.newFuture();

Set<String> requiredKeys = Set.of(NAME_FIELD, CONNECTOR_ID);
Set<String> optionalKeys = Set.of(MODEL_GROUP_ID, DESCRIPTION_FIELD, DEPLOY_FIELD, GUARDRAILS_FIELD);
Set<String> optionalKeys = Set.of(MODEL_GROUP_ID, DESCRIPTION_FIELD, DEPLOY_FIELD, GUARDRAILS_FIELD, INTERFACE_FIELD);

try {
Map<String, Object> inputs = ParseUtils.getInputsFromPreviousSteps(
Expand All @@ -93,6 +99,7 @@ public PlainActionFuture<WorkflowData> execute(
String description = (String) inputs.get(DESCRIPTION_FIELD);
String connectorId = (String) inputs.get(CONNECTOR_ID);
Guardrails guardRails = (Guardrails) inputs.get(GUARDRAILS_FIELD);
String modelInterface = (String) inputs.get(INTERFACE_FIELD);
final Boolean deploy = inputs.containsKey(DEPLOY_FIELD) ? Booleans.parseBoolean(inputs.get(DEPLOY_FIELD).toString()) : null;

MLRegisterModelInputBuilder builder = MLRegisterModelInput.builder()
Expand All @@ -112,6 +119,27 @@ public PlainActionFuture<WorkflowData> execute(
if (guardRails != null) {
builder.guardrails(guardRails);
}
if (modelInterface != null) {
try {
// Convert model interface string to map
BytesReference modelInterfaceBytes = new BytesArray(modelInterface.getBytes(StandardCharsets.UTF_8));
Map<String, Object> modelInterfaceAsMap = XContentHelper.convertToMap(
modelInterfaceBytes,
false,
MediaTypeRegistry.JSON
).v2();

// Convert to string to string map
Map<String, String> parameters = ParseUtils.convertStringToObjectMapToStringToStringMap(modelInterfaceAsMap);
builder.modelInterface(parameters);

} catch (Exception ex) {
String errorMessage = "Failed to create model interface";
logger.error(errorMessage, ex);
registerRemoteModelFuture.onFailure(new WorkflowStepException(errorMessage, RestStatus.BAD_REQUEST));
}

}

MLRegisterModelInput mlInput = builder.build();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,12 @@ public void testParseArbitraryStringToObjectMapToString() throws Exception {
assertEquals("{\"test-1\":{\"test-1\":\"test-1\"}}", parsedMap);
}

public void testConvertStringToObjectMapToStringToStringMap() throws Exception {
Map<String, Object> map = Map.ofEntries(Map.entry("test", Map.of("test-1", "{'test-2', 'test-3'}")));
Map<String, String> convertedMap = ParseUtils.convertStringToObjectMapToStringToStringMap(map);
assertEquals("{test={\"test-1\":\"{'test-2', 'test-3'}\"}}", convertedMap.toString());
}

public void testConditionallySubstituteWithNoPlaceholders() {
String input = "This string has no placeholders";
Map<String, WorkflowData> outputs = new HashMap<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@

import static org.opensearch.action.DocWriteResponse.Result.UPDATED;
import static org.opensearch.flowframework.common.CommonValue.DEPLOY_FIELD;
import static org.opensearch.flowframework.common.CommonValue.INTERFACE_FIELD;
import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS;
import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX;
import static org.opensearch.flowframework.common.WorkflowResources.CONNECTOR_ID;
Expand Down Expand Up @@ -68,7 +69,11 @@ public void setUp() throws Exception {
Map.entry("function_name", "ignored"),
Map.entry("name", "xyz"),
Map.entry("description", "description"),
Map.entry(CONNECTOR_ID, "abcdefg")
Map.entry(CONNECTOR_ID, "abcdefg"),
Map.entry(
INTERFACE_FIELD,
"{\"output\":{\"properties\":{\"inference_results\":{\"description\":\"This is a test description field\",\"type\":\"array\",\"items\":{\"type\":\"object\",\"properties\":{\"output\":{\"description\":\"This is a test description field\",\"type\":\"array\",\"items\":{\"properties\":{\"name\":{\"description\":\"This is a test description field\",\"type\":\"string\"},\"dataAsMap\":{\"description\":\"This is a test description field\",\"type\":\"object\"}}}},\"status_code\":{\"description\":\"This is a test description field\",\"type\":\"integer\"}}}}}},\"input\":{\"properties\":{\"parameters\":{\"properties\":{\"messages\":{\"description\":\"This is a test description field\",\"type\":\"string\"}}}}}}"
)
),
"test-id",
"test-node-id"
Expand Down Expand Up @@ -205,6 +210,38 @@ public void testRegisterRemoteModelFailure() {

}

public void testReisterRemoteModelInterfaceFailure() {
doAnswer(invocation -> {
ActionListener<MLRegisterModelResponse> actionListener = invocation.getArgument(1);
actionListener.onFailure(new IllegalArgumentException("Failed to register remote model"));
return null;
}).when(mlNodeClient).register(any(MLRegisterModelInput.class), any());

WorkflowData incorrectWorkflowData = new WorkflowData(
Map.ofEntries(
Map.entry("function_name", "ignored"),
Map.entry("name", "xyz"),
Map.entry("description", "description"),
Map.entry(CONNECTOR_ID, "abcdefg"),
Map.entry(INTERFACE_FIELD, "{\"output\":")
),
"test-id",
"test-node-id"
);

PlainActionFuture<WorkflowData> future = this.registerRemoteModelStep.execute(
incorrectWorkflowData.getNodeId(),
incorrectWorkflowData,
Collections.emptyMap(),
Collections.emptyMap(),
Collections.emptyMap()
);
assertTrue(future.isDone());
ExecutionException ex = expectThrows(ExecutionException.class, () -> future.get().getClass());
assertTrue(ex.getCause() instanceof FlowFrameworkException);
assertEquals("Failed to create model interface", ex.getCause().getMessage());
}

public void testRegisterRemoteModelUnSafeFailure() {
doAnswer(invocation -> {
ActionListener<MLRegisterModelResponse> actionListener = invocation.getArgument(1);
Expand Down

0 comments on commit 3318d31

Please sign in to comment.