Skip to content

Commit

Permalink
Refactor RetryableWorkflowStep to avoid recursion and subclass specif…
Browse files Browse the repository at this point in the history
…ics (#298)

* Refactor RetryableWorkflowStep to avoid recursion and subclass specifics

Signed-off-by: Daniel Widdis <[email protected]>

* Fix tests

Signed-off-by: Daniel Widdis <[email protected]>

* Support function_name field in Register Local Model Step

Signed-off-by: Daniel Widdis <[email protected]>

* Better handling of provisioning exception message on cancellation

Signed-off-by: Daniel Widdis <[email protected]>

* Throw an exception with message on cancellation.

Signed-off-by: Daniel Widdis <[email protected]>

* Add test coverage

Signed-off-by: Daniel Widdis <[email protected]>

---------

Signed-off-by: Daniel Widdis <[email protected]>
  • Loading branch information
dbwiddis committed Jan 2, 2024
1 parent 0e9156b commit 39c3f48
Show file tree
Hide file tree
Showing 10 changed files with 216 additions and 98 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ public Collection<Object> createComponents(
FlowFrameworkIndicesHandler flowFrameworkIndicesHandler = new FlowFrameworkIndicesHandler(client, clusterService, encryptorUtils);
WorkflowStepFactory workflowStepFactory = new WorkflowStepFactory(
settings,
threadPool,
clusterService,
client,
mlClient,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.concurrent.CancellationException;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;

Expand Down Expand Up @@ -213,12 +214,19 @@ private void executeWorkflow(List<ProcessNode> workflowSequence, String workflow
);
} catch (Exception ex) {
logger.error("Provisioning failed for workflow: {}", workflowId, ex);
String errorMessage;
if (ex instanceof CancellationException) {
errorMessage = "A step in the workflow was cancelled.";
} else if (ex.getCause() != null) {
errorMessage = ex.getCause().getMessage();
} else {
errorMessage = ex.getMessage();
}
flowFrameworkIndicesHandler.updateFlowFrameworkSystemIndexDoc(
workflowId,
Map.ofEntries(
Map.entry(STATE_FIELD, State.FAILED),
// TODO: potentially improve the error message here
Map.entry(ERROR_FIELD, ex.getMessage()),
Map.entry(ERROR_FIELD, errorMessage),
Map.entry(PROVISIONING_PROGRESS_FIELD, ProvisioningProgress.FAILED),
Map.entry(PROVISION_END_TIME_FIELD, Instant.now().toEpochMilli())
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,16 @@
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.MLTaskState;
import org.opensearch.ml.common.MLTask;
import org.opensearch.threadpool.ThreadPool;

import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Stream;
import java.util.concurrent.atomic.AtomicInteger;

import static org.opensearch.flowframework.common.CommonValue.PROVISION_THREAD_POOL;
import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_GET_TASK_REQUEST_RETRY;
import static org.opensearch.flowframework.common.WorkflowResources.DEPLOY_MODEL;
import static org.opensearch.flowframework.common.WorkflowResources.getResourceByWorkflowStep;

/**
Expand All @@ -39,20 +40,24 @@ public abstract class AbstractRetryableWorkflowStep implements WorkflowStep {
protected volatile Integer maxRetry;
private final MachineLearningNodeClient mlClient;
private final FlowFrameworkIndicesHandler flowFrameworkIndicesHandler;
private ThreadPool threadPool;

/**
* Instantiates a new Retryable workflow step
* @param settings Environment settings
* @param threadPool The OpenSearch thread pool
* @param clusterService the cluster service
* @param mlClient machine learning client
* @param flowFrameworkIndicesHandler FlowFrameworkIndicesHandler class to update system indices
*/
protected AbstractRetryableWorkflowStep(
Settings settings,
ThreadPool threadPool,
ClusterService clusterService,
MachineLearningNodeClient mlClient,
FlowFrameworkIndicesHandler flowFrameworkIndicesHandler
) {
this.threadPool = threadPool;
this.maxRetry = MAX_GET_TASK_REQUEST_RETRY.get(settings);
clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_GET_TASK_REQUEST_RETRY, it -> maxRetry = it);
this.mlClient = mlClient;
Expand All @@ -65,82 +70,93 @@ protected AbstractRetryableWorkflowStep(
* @param nodeId the workflow node id
* @param future the workflow step future
* @param taskId the ml task id
* @param retries the current number of request retries
* @param workflowStep the workflow step which requires a retry get ml task functionality
*/
protected void retryableGetMlTask(
String workflowId,
String nodeId,
CompletableFuture<WorkflowData> future,
String taskId,
int retries,
String workflowStep
) {
mlClient.getTask(taskId, ActionListener.wrap(response -> {
MLTaskState currentState = response.getState();
if (currentState != MLTaskState.COMPLETED) {
if (Stream.of(MLTaskState.FAILED, MLTaskState.COMPLETED_WITH_ERROR).anyMatch(x -> x == currentState)) {
// Model registration failed or completed with errors
String errorMessage = workflowStep + " failed with error : " + response.getError();
AtomicInteger retries = new AtomicInteger();
CompletableFuture.runAsync(() -> {
while (retries.getAndIncrement() < this.maxRetry && !future.isDone()) {
mlClient.getTask(taskId, ActionListener.wrap(response -> {
switch (response.getState()) {
case COMPLETED:
try {
String resourceName = getResourceByWorkflowStep(getName());
String id = getResourceId(response);
logger.info("{} successful for {} and {} {}", workflowStep, workflowId, resourceName, id);
flowFrameworkIndicesHandler.updateResourceInStateIndex(
workflowId,
nodeId,
getName(),
id,
ActionListener.wrap(updateResponse -> {
logger.info("successfully updated resources created in state index: {}", updateResponse.getIndex());
future.complete(
new WorkflowData(
Map.ofEntries(
Map.entry(resourceName, id),
Map.entry(REGISTER_MODEL_STATUS, response.getState().name())
),
workflowId,
nodeId
)
);
}, exception -> {
logger.error("Failed to update new created resource", exception);
future.completeExceptionally(
new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception))
);
})
);
} catch (Exception e) {
logger.error("Failed to parse and update new created resource", e);
future.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e)));
}
break;
case FAILED:
case COMPLETED_WITH_ERROR:
String errorMessage = workflowStep + " failed with error : " + response.getError();
logger.error(errorMessage);
future.completeExceptionally(new FlowFrameworkException(errorMessage, RestStatus.BAD_REQUEST));
break;
case CANCELLED:
errorMessage = workflowStep + " task was cancelled.";
logger.error(errorMessage);
future.completeExceptionally(new FlowFrameworkException(errorMessage, RestStatus.REQUEST_TIMEOUT));
break;
default:
// Task started or running, do nothing
}
}, exception -> {
String errorMessage = workflowStep + " failed with error : " + exception.getMessage();
logger.error(errorMessage);
future.completeExceptionally(new FlowFrameworkException(errorMessage, RestStatus.BAD_REQUEST));
} else {
// Task still in progress, attempt retry
throw new IllegalStateException(workflowStep + " is not yet completed");
}
} else {
try {
logger.info(workflowStep + " successful for {} and modelId {}", workflowId, response.getModelId());
String resourceName = getResourceByWorkflowStep(getName());
String id;
if (getName().equals(DEPLOY_MODEL.getWorkflowStep())) {
id = response.getModelId();
} else {
id = response.getTaskId();
}
flowFrameworkIndicesHandler.updateResourceInStateIndex(
workflowId,
nodeId,
getName(),
id,
ActionListener.wrap(updateResponse -> {
logger.info("successfully updated resources created in state index: {}", updateResponse.getIndex());
future.complete(
new WorkflowData(
Map.ofEntries(
Map.entry(resourceName, response.getModelId()),
Map.entry(REGISTER_MODEL_STATUS, response.getState().name())
),
workflowId,
nodeId
)
);
}, exception -> {
logger.error("Failed to update new created resource", exception);
future.completeExceptionally(
new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception))
);
})
);
} catch (Exception e) {
logger.error("Failed to parse and update new created resource", e);
future.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e)));
}
}
}, exception -> {
if (retries < maxRetry) {
// Sleep thread prior to retrying request
}));
// Wait long enough for future to possibly complete
try {
Thread.sleep(5000);
} catch (Exception e) {
} catch (InterruptedException e) {
FutureUtils.cancel(future);
Thread.currentThread().interrupt();
}
retryableGetMlTask(workflowId, nodeId, future, taskId, retries + 1, workflowStep);
} else {
logger.error("Failed to retrieve" + workflowStep + ",maximum retries exceeded");
future.completeExceptionally(new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception)));
}
}));
if (!future.isDone()) {
String errorMessage = workflowStep + " did not complete after " + maxRetry + " retries";
logger.error(errorMessage);
future.completeExceptionally(new FlowFrameworkException(errorMessage, RestStatus.REQUEST_TIMEOUT));
}
}, threadPool.executor(PROVISION_THREAD_POOL));
}

/**
* Returns the resourceId associated with the task
* @param response The Task response
* @return the resource ID, such as a model id
*/
protected abstract String getResourceId(MLTask response);
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler;
import org.opensearch.flowframework.util.ParseUtils;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse;
import org.opensearch.threadpool.ThreadPool;

import java.util.Collections;
import java.util.Map;
Expand All @@ -42,17 +44,19 @@ public class DeployModelStep extends AbstractRetryableWorkflowStep {
/**
* Instantiate this class
* @param settings The OpenSearch settings
* @param threadPool The OpenSearch thread pool
* @param clusterService The cluster service
* @param mlClient client to instantiate MLClient
* @param flowFrameworkIndicesHandler FlowFrameworkIndicesHandler class to update system indices
*/
public DeployModelStep(
Settings settings,
ThreadPool threadPool,
ClusterService clusterService,
MachineLearningNodeClient mlClient,
FlowFrameworkIndicesHandler flowFrameworkIndicesHandler
) {
super(settings, clusterService, mlClient, flowFrameworkIndicesHandler);
super(settings, threadPool, clusterService, mlClient, flowFrameworkIndicesHandler);
this.mlClient = mlClient;
this.flowFrameworkIndicesHandler = flowFrameworkIndicesHandler;
}
Expand All @@ -74,7 +78,7 @@ public void onResponse(MLDeployModelResponse mlDeployModelResponse) {
String taskId = mlDeployModelResponse.getTaskId();

// Attempt to retrieve the model ID
retryableGetMlTask(currentNodeInputs.getWorkflowId(), currentNodeId, deployModelFuture, taskId, 0, "Deploy model");
retryableGetMlTask(currentNodeInputs.getWorkflowId(), currentNodeId, deployModelFuture, taskId, "Deploy model");
}

@Override
Expand Down Expand Up @@ -105,6 +109,11 @@ public void onFailure(Exception e) {
return deployModelFuture;
}

@Override
protected String getResourceId(MLTask response) {
return response.getModelId();
}

@Override
public String getName() {
return NAME;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler;
import org.opensearch.flowframework.util.ParseUtils;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.model.MLModelConfig;
import org.opensearch.ml.common.model.MLModelFormat;
import org.opensearch.ml.common.model.TextEmbeddingModelConfig;
Expand All @@ -26,6 +28,7 @@
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
import org.opensearch.ml.common.transport.register.MLRegisterModelInput.MLRegisterModelInputBuilder;
import org.opensearch.ml.common.transport.register.MLRegisterModelResponse;
import org.opensearch.threadpool.ThreadPool;

import java.util.Map;
import java.util.Set;
Expand All @@ -35,6 +38,7 @@
import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD;
import static org.opensearch.flowframework.common.CommonValue.EMBEDDING_DIMENSION;
import static org.opensearch.flowframework.common.CommonValue.FRAMEWORK_TYPE;
import static org.opensearch.flowframework.common.CommonValue.FUNCTION_NAME;
import static org.opensearch.flowframework.common.CommonValue.MODEL_CONTENT_HASH_VALUE;
import static org.opensearch.flowframework.common.CommonValue.MODEL_FORMAT;
import static org.opensearch.flowframework.common.CommonValue.MODEL_TYPE;
Expand All @@ -60,17 +64,19 @@ public class RegisterLocalModelStep extends AbstractRetryableWorkflowStep {
/**
* Instantiate this class
* @param settings The OpenSearch settings
* @param threadPool The OpenSearch thread pool
* @param clusterService The cluster service
* @param mlClient client to instantiate MLClient
* @param flowFrameworkIndicesHandler FlowFrameworkIndicesHandler class to update system indices
*/
public RegisterLocalModelStep(
Settings settings,
ThreadPool threadPool,
ClusterService clusterService,
MachineLearningNodeClient mlClient,
FlowFrameworkIndicesHandler flowFrameworkIndicesHandler
) {
super(settings, clusterService, mlClient, flowFrameworkIndicesHandler);
super(settings, threadPool, clusterService, mlClient, flowFrameworkIndicesHandler);
this.mlClient = mlClient;
this.flowFrameworkIndicesHandler = flowFrameworkIndicesHandler;
}
Expand Down Expand Up @@ -98,7 +104,6 @@ public void onResponse(MLRegisterModelResponse mlRegisterModelResponse) {
currentNodeId,
registerLocalModelFuture,
taskId,
0,
"Local model registration"
);
}
Expand All @@ -120,7 +125,7 @@ public void onFailure(Exception e) {
MODEL_CONTENT_HASH_VALUE,
URL
);
Set<String> optionalKeys = Set.of(DESCRIPTION_FIELD, MODEL_GROUP_ID, ALL_CONFIG);
Set<String> optionalKeys = Set.of(DESCRIPTION_FIELD, MODEL_GROUP_ID, ALL_CONFIG, FUNCTION_NAME);

try {
Map<String, Object> inputs = ParseUtils.getInputsFromPreviousSteps(
Expand All @@ -142,6 +147,7 @@ public void onFailure(Exception e) {
FrameworkType frameworkType = FrameworkType.from((String) inputs.get(FRAMEWORK_TYPE));
String allConfig = (String) inputs.get(ALL_CONFIG);
String url = (String) inputs.get(URL);
String functionName = (String) inputs.get(FUNCTION_NAME);

// Create Model configuration
TextEmbeddingModelConfigBuilder modelConfigBuilder = TextEmbeddingModelConfig.builder()
Expand All @@ -158,13 +164,18 @@ public void onFailure(Exception e) {
.modelName(modelName)
.version(modelVersion)
.modelFormat(modelFormat)
.modelGroupId(modelGroupId)
.hashValue(modelContentHashValue)
.modelConfig(modelConfig)
.url(url);
if (description != null) {
mlInputBuilder.description(description);
}
if (modelGroupId != null) {
mlInputBuilder.modelGroupId(modelGroupId);
}
if (functionName != null) {
mlInputBuilder.functionName(FunctionName.from(functionName));
}

MLRegisterModelInput mlInput = mlInputBuilder.build();

Expand All @@ -175,6 +186,11 @@ public void onFailure(Exception e) {
return registerLocalModelFuture;
}

@Override
protected String getResourceId(MLTask response) {
return response.getModelId();
}

@Override
public String getName() {
return NAME;
Expand Down
Loading

0 comments on commit 39c3f48

Please sign in to comment.