Skip to content

Commit

Permalink
[Backport 2.x] Adds transport request retry capability for GetMLTaskS…
Browse files Browse the repository at this point in the history
…tep (#193)
  • Loading branch information
opensearch-trigger-bot[bot] committed Nov 23, 2023
1 parent db30b87 commit 862fefe
Show file tree
Hide file tree
Showing 10 changed files with 192 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
import static org.opensearch.flowframework.common.CommonValue.FLOW_FRAMEWORK_THREAD_POOL_PREFIX;
import static org.opensearch.flowframework.common.CommonValue.PROVISION_THREAD_POOL;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_GET_TASK_REQUEST_RETRY;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOWS;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.WORKFLOW_REQUEST_TIMEOUT;

Expand All @@ -69,6 +70,7 @@
public class FlowFrameworkPlugin extends Plugin implements ActionPlugin {

private FlowFrameworkFeatureEnabledSetting flowFrameworkFeatureEnabledSetting;

private ClusterService clusterService;

/**
Expand All @@ -93,10 +95,15 @@ public Collection<Object> createComponents(
Settings settings = environment.settings();
this.clusterService = clusterService;
flowFrameworkFeatureEnabledSetting = new FlowFrameworkFeatureEnabledSetting(clusterService, settings);

MachineLearningNodeClient mlClient = new MachineLearningNodeClient(client);
FlowFrameworkIndicesHandler flowFrameworkIndicesHandler = new FlowFrameworkIndicesHandler(client, clusterService);
WorkflowStepFactory workflowStepFactory = new WorkflowStepFactory(clusterService, client, mlClient, flowFrameworkIndicesHandler);
WorkflowStepFactory workflowStepFactory = new WorkflowStepFactory(
settings,
clusterService,
client,
mlClient,
flowFrameworkIndicesHandler
);
WorkflowProcessSorter workflowProcessSorter = new WorkflowProcessSorter(workflowStepFactory, threadPool);

return ImmutableList.of(workflowStepFactory, workflowProcessSorter, flowFrameworkIndicesHandler);
Expand Down Expand Up @@ -132,7 +139,12 @@ public List<RestHandler> getRestHandlers(

@Override
public List<Setting<?>> getSettings() {
List<Setting<?>> settings = ImmutableList.of(FLOW_FRAMEWORK_ENABLED, MAX_WORKFLOWS, WORKFLOW_REQUEST_TIMEOUT);
List<Setting<?>> settings = ImmutableList.of(
FLOW_FRAMEWORK_ENABLED,
MAX_WORKFLOWS,
WORKFLOW_REQUEST_TIMEOUT,
MAX_GET_TASK_REQUEST_RETRY
);
return settings;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,13 @@ private FlowFrameworkSettings() {}
Setting.Property.NodeScope,
Setting.Property.Dynamic
);

/** This setting sets the maximum number of get task request retries */
public static final Setting<Integer> MAX_GET_TASK_REQUEST_RETRY = Setting.intSetting(
"plugins.flow_framework.max_get_task_request_retry",
5,
0,
Setting.Property.NodeScope,
Setting.Property.Dynamic
);
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@
*/
public class GetWorkflowResponse extends ActionResponse implements ToXContentObject {

/** The workflow state */
public WorkflowState workflowState;
/** Flag to indicate if the entire state should be returned */
public boolean allStatus;

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.flowframework.workflow;

import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;

import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_GET_TASK_REQUEST_RETRY;

/**
* Abstract retryable workflow step
*/
public abstract class AbstractRetryableWorkflowStep implements WorkflowStep {

/** The maximum number of transport request retries */
protected volatile Integer maxRetry;

/**
* Instantiates a new Retryable workflow step
* @param settings Environment settings
* @param clusterService the cluster service
*/
public AbstractRetryableWorkflowStep(Settings settings, ClusterService clusterService) {
this.maxRetry = MAX_GET_TASK_REQUEST_RETRY.get(settings);
clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_GET_TASK_REQUEST_RETRY, it -> maxRetry = it);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,14 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.ExceptionsHelper;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.FutureUtils;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.MLTaskState;

import java.util.List;
import java.util.Map;
Expand All @@ -29,17 +32,20 @@
/**
* Step to retrieve an ML Task
*/
public class GetMLTaskStep implements WorkflowStep {
public class GetMLTaskStep extends AbstractRetryableWorkflowStep {

private static final Logger logger = LogManager.getLogger(GetMLTaskStep.class);
private MachineLearningNodeClient mlClient;
static final String NAME = "get_ml_task";

/**
* Instantiate this class
* @param settings the Opensearch settings
* @param clusterService the OpenSearch cluster service
* @param mlClient client to instantiate MLClient
*/
public GetMLTaskStep(MachineLearningNodeClient mlClient) {
public GetMLTaskStep(Settings settings, ClusterService clusterService, MachineLearningNodeClient mlClient) {
super(settings, clusterService);
this.mlClient = mlClient;
}

Expand All @@ -48,23 +54,6 @@ public CompletableFuture<WorkflowData> execute(List<WorkflowData> data) {

CompletableFuture<WorkflowData> getMLTaskFuture = new CompletableFuture<>();

ActionListener<MLTask> actionListener = ActionListener.wrap(response -> {

// TODO : Add retry capability if response status is not COMPLETED :
// https://github.com/opensearch-project/flow-framework/issues/158

logger.info("ML Task retrieval successful");
getMLTaskFuture.complete(
new WorkflowData(
Map.ofEntries(Map.entry(MODEL_ID, response.getModelId()), Map.entry(REGISTER_MODEL_STATUS, response.getState().name())),
data.get(0).getWorkflowId()
)
);
}, exception -> {
logger.error("Failed to retrieve ML Task");
getMLTaskFuture.completeExceptionally(new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception)));
});

String taskId = null;

for (WorkflowData workflowData : data) {
Expand All @@ -84,7 +73,7 @@ public CompletableFuture<WorkflowData> execute(List<WorkflowData> data) {
logger.error("Failed to retrieve ML Task");
getMLTaskFuture.completeExceptionally(new FlowFrameworkException("Required fields are not provided", RestStatus.BAD_REQUEST));
} else {
mlClient.getTask(taskId, actionListener);
retryableGetMlTask(data.get(0).getWorkflowId(), getMLTaskFuture, taskId, 0);
}

return getMLTaskFuture;
Expand All @@ -95,4 +84,46 @@ public String getName() {
return NAME;
}

/**
* Retryable GetMLTask
* @param workflowId the workflow id
* @param getMLTaskFuture the workflow step future
* @param taskId the ml task id
* @param retries the current number of request retries
*/
protected void retryableGetMlTask(String workflowId, CompletableFuture<WorkflowData> getMLTaskFuture, String taskId, int retries) {
mlClient.getTask(taskId, ActionListener.wrap(response -> {
if (response.getState() != MLTaskState.COMPLETED) {
throw new IllegalStateException("MLTask is not yet completed");
} else {
logger.info("ML Task retrieval successful");
getMLTaskFuture.complete(
new WorkflowData(
Map.ofEntries(
Map.entry(MODEL_ID, response.getModelId()),
Map.entry(REGISTER_MODEL_STATUS, response.getState().name())
),
workflowId
)
);
}
}, exception -> {
if (retries < maxRetry) {
// Sleep thread prior to retrying request
try {
Thread.sleep(5000);
} catch (Exception e) {
FutureUtils.cancel(getMLTaskFuture);
}
final int retryAdd = retries + 1;
retryableGetMlTask(workflowId, getMLTaskFuture, taskId, retryAdd);
} else {
logger.error("Failed to retrieve ML Task, maximum retries exceeded");
getMLTaskFuture.completeExceptionally(
new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception))
);
}
}));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler;
Expand All @@ -29,22 +30,25 @@ public class WorkflowStepFactory {
/**
* Instantiate this class.
*
* @param settings The OpenSearch settings
* @param clusterService The OpenSearch cluster service
* @param client The OpenSearch client steps can use
* @param mlClient Machine Learning client to perform ml operations
* @param flowFrameworkIndicesHandler FlowFrameworkIndicesHandler class to update system indices
*/
public WorkflowStepFactory(
Settings settings,
ClusterService clusterService,
Client client,
MachineLearningNodeClient mlClient,
FlowFrameworkIndicesHandler flowFrameworkIndicesHandler
) {
this.flowFrameworkIndicesHandler = flowFrameworkIndicesHandler;
populateMap(clusterService, client, mlClient, flowFrameworkIndicesHandler);
populateMap(settings, clusterService, client, mlClient, flowFrameworkIndicesHandler);
}

private void populateMap(
Settings settings,
ClusterService clusterService,
Client client,
MachineLearningNodeClient mlClient,
Expand All @@ -58,7 +62,7 @@ private void populateMap(
stepMap.put(DeployModelStep.NAME, new DeployModelStep(mlClient));
stepMap.put(CreateConnectorStep.NAME, new CreateConnectorStep(mlClient, flowFrameworkIndicesHandler));
stepMap.put(ModelGroupStep.NAME, new ModelGroupStep(mlClient));
stepMap.put(GetMLTaskStep.NAME, new GetMLTaskStep(mlClient));
stepMap.put(GetMLTaskStep.NAME, new GetMLTaskStep(settings, clusterService, mlClient));
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import java.util.stream.Stream;

import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_GET_TASK_REQUEST_RETRY;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOWS;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.WORKFLOW_REQUEST_TIMEOUT;
import static org.mockito.Mockito.mock;
Expand Down Expand Up @@ -61,7 +62,7 @@ public void setUp() throws Exception {

final Set<Setting<?>> settingsSet = Stream.concat(
ClusterSettings.BUILT_IN_CLUSTER_SETTINGS.stream(),
Stream.of(FLOW_FRAMEWORK_ENABLED, MAX_WORKFLOWS, WORKFLOW_REQUEST_TIMEOUT)
Stream.of(FLOW_FRAMEWORK_ENABLED, MAX_WORKFLOWS, WORKFLOW_REQUEST_TIMEOUT, MAX_GET_TASK_REQUEST_RETRY)
).collect(Collectors.toSet());
clusterSettings = new ClusterSettings(settings, settingsSet);
clusterService = mock(ClusterService.class);
Expand All @@ -83,7 +84,7 @@ public void testPlugin() throws IOException {
assertEquals(4, ffp.getRestHandlers(settings, null, null, null, null, null, null).size());
assertEquals(4, ffp.getActions().size());
assertEquals(1, ffp.getExecutorBuilders(settings).size());
assertEquals(3, ffp.getSettings().size());
assertEquals(4, ffp.getSettings().size());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
import org.opensearch.client.Client;
import org.opensearch.client.ClusterAdminClient;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.settings.Setting;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler;
import org.opensearch.flowframework.workflow.WorkflowStepFactory;
Expand All @@ -21,7 +24,14 @@
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_GET_TASK_REQUEST_RETRY;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOWS;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.WORKFLOW_REQUEST_TIMEOUT;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

Expand Down Expand Up @@ -70,7 +80,20 @@ public void testWorkflowStepFactoryHasValidators() throws IOException {
MachineLearningNodeClient mlClient = mock(MachineLearningNodeClient.class);
FlowFrameworkIndicesHandler flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class);

WorkflowStepFactory workflowStepFactory = new WorkflowStepFactory(clusterService, client, mlClient, flowFrameworkIndicesHandler);
final Set<Setting<?>> settingsSet = Stream.concat(
ClusterSettings.BUILT_IN_CLUSTER_SETTINGS.stream(),
Stream.of(FLOW_FRAMEWORK_ENABLED, MAX_WORKFLOWS, WORKFLOW_REQUEST_TIMEOUT, MAX_GET_TASK_REQUEST_RETRY)
).collect(Collectors.toSet());
ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, settingsSet);
when(clusterService.getClusterSettings()).thenReturn(clusterSettings);

WorkflowStepFactory workflowStepFactory = new WorkflowStepFactory(
Settings.EMPTY,
clusterService,
client,
mlClient,
flowFrameworkIndicesHandler
);

// Read in workflow-steps.json
WorkflowValidator workflowValidator = WorkflowValidator.parse("mappings/workflow-steps.json");
Expand Down
Loading

0 comments on commit 862fefe

Please sign in to comment.