Skip to content

Commit

Permalink
Support deploy=true on RegisterRemoteModelStep (#340)
Browse files Browse the repository at this point in the history
* Support deploy=true on RegisterRemoteModelStep

Signed-off-by: Daniel Widdis <[email protected]>

* Hardcode function_name to remote

Signed-off-by: Daniel Widdis <[email protected]>

---------

Signed-off-by: Daniel Widdis <[email protected]>
  • Loading branch information
dbwiddis committed Dec 29, 2023
1 parent c3ba8f0 commit a45b38c
Show file tree
Hide file tree
Showing 4 changed files with 135 additions and 56 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ private CommonValue() {}
public static final String MODEL_GROUP_STATUS = "model_group_status";
/** Description field */
public static final String DESCRIPTION_FIELD = "description";
/** Description field */
public static final String DEPLOY_FIELD = "deploy";
/** Model format field */
public static final String MODEL_FORMAT = "model_format";
/** Model content hash value field */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.ExceptionsHelper;
import org.opensearch.action.update.UpdateResponse;
import org.opensearch.core.action.ActionListener;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler;
Expand All @@ -21,13 +22,12 @@
import org.opensearch.ml.common.transport.register.MLRegisterModelInput.MLRegisterModelInputBuilder;
import org.opensearch.ml.common.transport.register.MLRegisterModelResponse;

import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CompletableFuture;

import static org.opensearch.flowframework.common.CommonValue.DEPLOY_FIELD;
import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD;
import static org.opensearch.flowframework.common.CommonValue.FUNCTION_NAME;
import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD;
import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS;
import static org.opensearch.flowframework.common.WorkflowResources.CONNECTOR_ID;
Expand Down Expand Up @@ -68,53 +68,8 @@ public CompletableFuture<WorkflowData> execute(

CompletableFuture<WorkflowData> registerRemoteModelFuture = new CompletableFuture<>();

ActionListener<MLRegisterModelResponse> actionListener = new ActionListener<>() {
@Override
public void onResponse(MLRegisterModelResponse mlRegisterModelResponse) {

try {
logger.info("Remote Model registration successful");
String resourceName = getResourceByWorkflowStep(getName());
flowFrameworkIndicesHandler.updateResourceInStateIndex(
currentNodeInputs.getWorkflowId(),
currentNodeId,
getName(),
mlRegisterModelResponse.getModelId(),
ActionListener.wrap(response -> {
logger.info("successfully updated resources created in state index: {}", response.getIndex());
registerRemoteModelFuture.complete(
new WorkflowData(
Map.ofEntries(
Map.entry(resourceName, mlRegisterModelResponse.getModelId()),
Map.entry(REGISTER_MODEL_STATUS, mlRegisterModelResponse.getStatus())
),
currentNodeInputs.getWorkflowId(),
currentNodeInputs.getNodeId()
)
);
}, exception -> {
logger.error("Failed to update new created resource", exception);
registerRemoteModelFuture.completeExceptionally(
new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception))
);
})
);

} catch (Exception e) {
logger.error("Failed to parse and update new created resource", e);
registerRemoteModelFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e)));
}
}

@Override
public void onFailure(Exception e) {
logger.error("Failed to register remote model");
registerRemoteModelFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e)));
}
};

Set<String> requiredKeys = Set.of(NAME_FIELD, FUNCTION_NAME, CONNECTOR_ID);
Set<String> optionalKeys = Set.of(MODEL_GROUP_ID, DESCRIPTION_FIELD);
Set<String> requiredKeys = Set.of(NAME_FIELD, CONNECTOR_ID);
Set<String> optionalKeys = Set.of(MODEL_GROUP_ID, DESCRIPTION_FIELD, DEPLOY_FIELD);

try {
Map<String, Object> inputs = ParseUtils.getInputsFromPreviousSteps(
Expand All @@ -126,13 +81,13 @@ public void onFailure(Exception e) {
);

String modelName = (String) inputs.get(NAME_FIELD);
FunctionName functionName = FunctionName.from(((String) inputs.get(FUNCTION_NAME)).toUpperCase(Locale.ROOT));
String modelGroupId = (String) inputs.get(MODEL_GROUP_ID);
String description = (String) inputs.get(DESCRIPTION_FIELD);
String connectorId = (String) inputs.get(CONNECTOR_ID);
final Boolean deploy = (Boolean) inputs.get(DEPLOY_FIELD);

MLRegisterModelInputBuilder builder = MLRegisterModelInput.builder()
.functionName(functionName)
.functionName(FunctionName.REMOTE)
.modelName(modelName)
.connectorId(connectorId);

Expand All @@ -142,9 +97,82 @@ public void onFailure(Exception e) {
if (description != null) {
builder.description(description);
}
if (deploy != null) {
builder.deployModel(deploy);
}
MLRegisterModelInput mlInput = builder.build();

mlClient.register(mlInput, actionListener);
mlClient.register(mlInput, new ActionListener<MLRegisterModelResponse>() {
@Override
public void onResponse(MLRegisterModelResponse mlRegisterModelResponse) {

try {
logger.info("Remote Model registration successful");
String resourceName = getResourceByWorkflowStep(getName());
flowFrameworkIndicesHandler.updateResourceInStateIndex(
currentNodeInputs.getWorkflowId(),
currentNodeId,
getName(),
mlRegisterModelResponse.getModelId(),
ActionListener.wrap(response -> {
// If we deployed, simulate the deploy step has been called
if (Boolean.TRUE.equals(deploy)) {
flowFrameworkIndicesHandler.updateResourceInStateIndex(
currentNodeInputs.getWorkflowId(),
currentNodeId,
DeployModelStep.NAME,
mlRegisterModelResponse.getModelId(),
ActionListener.wrap(deployUpdateResponse -> {
completeRegisterFuture(deployUpdateResponse, resourceName, mlRegisterModelResponse);
}, deployUpdateException -> {
logger.error("Failed to update simulated deploy step resource", deployUpdateException);
registerRemoteModelFuture.completeExceptionally(
new FlowFrameworkException(
deployUpdateException.getMessage(),
ExceptionsHelper.status(deployUpdateException)
)
);
})
);
} else {
completeRegisterFuture(response, resourceName, mlRegisterModelResponse);
}
}, exception -> {
logger.error("Failed to update new created resource", exception);
registerRemoteModelFuture.completeExceptionally(
new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception))
);
})
);

} catch (Exception e) {
logger.error("Failed to parse and update new created resource", e);
registerRemoteModelFuture.completeExceptionally(
new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))
);
}
}

void completeRegisterFuture(UpdateResponse response, String resourceName, MLRegisterModelResponse mlRegisterModelResponse) {
logger.info("successfully updated resources created in state index: {}", response.getIndex());
registerRemoteModelFuture.complete(
new WorkflowData(
Map.ofEntries(
Map.entry(resourceName, mlRegisterModelResponse.getModelId()),
Map.entry(REGISTER_MODEL_STATUS, mlRegisterModelResponse.getStatus())
),
currentNodeInputs.getWorkflowId(),
currentNodeInputs.getNodeId()
)
);
}

@Override
public void onFailure(Exception e) {
logger.error("Failed to register remote model");
registerRemoteModelFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e)));
}
});

} catch (FlowFrameworkException e) {
registerRemoteModelFuture.completeExceptionally(e);
Expand Down
1 change: 0 additions & 1 deletion src/main/resources/mappings/workflow-steps.json
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@
"register_remote_model": {
"inputs": [
"name",
"function_name",
"connector_id"
],
"outputs": [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import org.mockito.MockitoAnnotations;

import static org.opensearch.action.DocWriteResponse.Result.UPDATED;
import static org.opensearch.flowframework.common.CommonValue.DEPLOY_FIELD;
import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS;
import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX;
import static org.opensearch.flowframework.common.WorkflowResources.CONNECTOR_ID;
Expand Down Expand Up @@ -59,7 +60,7 @@ public void setUp() throws Exception {
this.registerRemoteModelStep = new RegisterRemoteModelStep(mlNodeClient, flowFrameworkIndicesHandler);
this.workflowData = new WorkflowData(
Map.ofEntries(
Map.entry("function_name", "remote"),
Map.entry("function_name", "ignored"),
Map.entry("name", "xyz"),
Map.entry("description", "description"),
Map.entry(CONNECTOR_ID, "abcdefg")
Expand Down Expand Up @@ -96,14 +97,63 @@ public void testRegisterRemoteModelSuccess() throws Exception {
);

verify(mlNodeClient, times(1)).register(any(MLRegisterModelInput.class), any());
// only updates register resource
verify(flowFrameworkIndicesHandler, times(1)).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any());

assertTrue(future.isDone());
assertTrue(!future.isCompletedExceptionally());
assertFalse(future.isCompletedExceptionally());
assertEquals(modelId, future.get().getContent().get(MODEL_ID));
assertEquals(status, future.get().getContent().get(REGISTER_MODEL_STATUS));

}

public void testRegisterAndDeployRemoteModelSuccess() throws Exception {

String taskId = "abcd";
String modelId = "efgh";
String status = MLTaskState.CREATED.name();

doAnswer(invocation -> {
ActionListener<MLRegisterModelResponse> actionListener = invocation.getArgument(1);
MLRegisterModelResponse output = new MLRegisterModelResponse(taskId, status, modelId);
actionListener.onResponse(output);
return null;
}).when(mlNodeClient).register(any(MLRegisterModelInput.class), any());

doAnswer(invocation -> {
ActionListener<UpdateResponse> updateResponseListener = invocation.getArgument(4);
updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, UPDATED));
return null;
}).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any());

WorkflowData deployWorkflowData = new WorkflowData(
Map.ofEntries(
Map.entry("name", "xyz"),
Map.entry("description", "description"),
Map.entry(CONNECTOR_ID, "abcdefg"),
Map.entry(DEPLOY_FIELD, true)
),
"test-id",
"test-node-id"
);

CompletableFuture<WorkflowData> future = this.registerRemoteModelStep.execute(
deployWorkflowData.getNodeId(),
deployWorkflowData,
Collections.emptyMap(),
Collections.emptyMap()
);

verify(mlNodeClient, times(1)).register(any(MLRegisterModelInput.class), any());
// updates both register and deploy resources
verify(flowFrameworkIndicesHandler, times(2)).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any());

assertTrue(future.isDone());
assertFalse(future.isCompletedExceptionally());
assertEquals(modelId, future.get().getContent().get(MODEL_ID));
assertEquals(status, future.get().getContent().get(REGISTER_MODEL_STATUS));
}

public void testRegisterRemoteModelFailure() {
doAnswer(invocation -> {
ActionListener<MLRegisterModelResponse> actionListener = invocation.getArgument(1);
Expand Down Expand Up @@ -137,7 +187,7 @@ public void testMissingInputs() {
ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent());
assertTrue(ex.getCause() instanceof FlowFrameworkException);
assertTrue(ex.getCause().getMessage().startsWith("Missing required inputs ["));
for (String s : new String[] { "name", "function_name", CONNECTOR_ID }) {
for (String s : new String[] { "name", CONNECTOR_ID }) {
assertTrue(ex.getCause().getMessage().contains(s));
}
assertTrue(ex.getCause().getMessage().endsWith("] in workflow [test-id] node [test-node-id]"));
Expand Down

0 comments on commit a45b38c

Please sign in to comment.