Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds transport request retry capability for GetMLTaskStep #179

Merged
merged 10 commits into from
Nov 23, 2023
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
import static org.opensearch.flowframework.common.CommonValue.FLOW_FRAMEWORK_THREAD_POOL_PREFIX;
import static org.opensearch.flowframework.common.CommonValue.PROVISION_THREAD_POOL;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_REQUEST_RETRY;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOWS;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.WORKFLOW_REQUEST_TIMEOUT;

Expand All @@ -69,6 +70,7 @@
public class FlowFrameworkPlugin extends Plugin implements ActionPlugin {

private FlowFrameworkFeatureEnabledSetting flowFrameworkFeatureEnabledSetting;

private ClusterService clusterService;

/**
Expand All @@ -93,10 +95,15 @@ public Collection<Object> createComponents(
Settings settings = environment.settings();
this.clusterService = clusterService;
flowFrameworkFeatureEnabledSetting = new FlowFrameworkFeatureEnabledSetting(clusterService, settings);

MachineLearningNodeClient mlClient = new MachineLearningNodeClient(client);
FlowFrameworkIndicesHandler flowFrameworkIndicesHandler = new FlowFrameworkIndicesHandler(client, clusterService);
WorkflowStepFactory workflowStepFactory = new WorkflowStepFactory(clusterService, client, mlClient, flowFrameworkIndicesHandler);
WorkflowStepFactory workflowStepFactory = new WorkflowStepFactory(
settings,
clusterService,
client,
mlClient,
flowFrameworkIndicesHandler
);
WorkflowProcessSorter workflowProcessSorter = new WorkflowProcessSorter(workflowStepFactory, threadPool);

return ImmutableList.of(workflowStepFactory, workflowProcessSorter, flowFrameworkIndicesHandler);
Expand Down Expand Up @@ -132,7 +139,7 @@ public List<RestHandler> getRestHandlers(

@Override
public List<Setting<?>> getSettings() {
List<Setting<?>> settings = ImmutableList.of(FLOW_FRAMEWORK_ENABLED, MAX_WORKFLOWS, WORKFLOW_REQUEST_TIMEOUT);
List<Setting<?>> settings = ImmutableList.of(FLOW_FRAMEWORK_ENABLED, MAX_WORKFLOWS, WORKFLOW_REQUEST_TIMEOUT, MAX_REQUEST_RETRY);
return settings;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,14 @@ private FlowFrameworkSettings() {}
Setting.Property.NodeScope,
Setting.Property.Dynamic
);

/** This setting sets the maximum number of transport request retries */
public static final Setting<Integer> MAX_REQUEST_RETRY = Setting.intSetting(
dbwiddis marked this conversation as resolved.
Show resolved Hide resolved
"plugins.flow_framework.max_request_retry",
5,
0,
20,
dbwiddis marked this conversation as resolved.
Show resolved Hide resolved
Setting.Property.NodeScope,
Setting.Property.Dynamic
owaiskazi19 marked this conversation as resolved.
Show resolved Hide resolved
);
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@
*/
public class GetWorkflowResponse extends ActionResponse implements ToXContentObject {

/** The workflow state */
public WorkflowState workflowState;
/** Flag to indicate if the entire state should be returned */
public boolean allStatus;

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,13 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.ExceptionsHelper;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.MLTaskState;

import java.util.List;
import java.util.Map;
Expand All @@ -29,17 +31,20 @@
/**
* Step to retrieve an ML Task
*/
public class GetMLTaskStep implements WorkflowStep {
public class GetMLTaskStep extends RetryableWorkflowStep {

private static final Logger logger = LogManager.getLogger(GetMLTaskStep.class);
private MachineLearningNodeClient mlClient;
static final String NAME = "get_ml_task";

/**
* Instantiate this class
* @param settings the OpenSearch settings
* @param clusterService the OpenSearch cluster service
* @param mlClient client to instantiate MLClient
*/
public GetMLTaskStep(MachineLearningNodeClient mlClient) {
public GetMLTaskStep(Settings settings, ClusterService clusterService, MachineLearningNodeClient mlClient) {
super(settings, clusterService);
this.mlClient = mlClient;
}

Expand All @@ -48,23 +53,6 @@ public CompletableFuture<WorkflowData> execute(List<WorkflowData> data) {

CompletableFuture<WorkflowData> getMLTaskFuture = new CompletableFuture<>();

ActionListener<MLTask> actionListener = ActionListener.wrap(response -> {

// TODO : Add retry capability if response status is not COMPLETED :
// https://github.com/opensearch-project/flow-framework/issues/158

logger.info("ML Task retrieval successful");
getMLTaskFuture.complete(
new WorkflowData(
Map.ofEntries(Map.entry(MODEL_ID, response.getModelId()), Map.entry(REGISTER_MODEL_STATUS, response.getState().name())),
data.get(0).getWorkflowId()
)
);
}, exception -> {
logger.error("Failed to retrieve ML Task");
getMLTaskFuture.completeExceptionally(new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception)));
});

String taskId = null;

for (WorkflowData workflowData : data) {
Expand All @@ -84,7 +72,7 @@ public CompletableFuture<WorkflowData> execute(List<WorkflowData> data) {
logger.error("Failed to retrieve ML Task");
getMLTaskFuture.completeExceptionally(new FlowFrameworkException("Required fields are not provided", RestStatus.BAD_REQUEST));
} else {
mlClient.getTask(taskId, actionListener);
retryableGetMlTask(data, getMLTaskFuture, taskId, 0);
}

return getMLTaskFuture;
Expand All @@ -95,4 +83,42 @@ public String getName() {
return NAME;
}

/**
 * Requests the ML task from the ML client and completes the supplied future once the task
 * reports {@code MLTaskState.COMPLETED}. Any failure is handed to {@code shouldRetry}; while it
 * allows another attempt the request is re-issued with an incremented attempt counter, otherwise
 * the future is completed exceptionally.
 *
 * @param data the workflow data; the workflow id of its first element is attached to the result
 * @param getMLTaskFuture the future to complete with the outcome
 * @param taskId the id of the ML task to retrieve
 * @param retries the number of retries already performed
 */
private void retryableGetMlTask(List<WorkflowData> data, CompletableFuture<WorkflowData> getMLTaskFuture, String taskId, int retries) {
    mlClient.getTask(taskId, ActionListener.wrap(mlTask -> {
        if (mlTask.getState() == MLTaskState.COMPLETED) {
            logger.info("ML Task retrieval successful");
            getMLTaskFuture.complete(
                new WorkflowData(
                    Map.ofEntries(
                        Map.entry(MODEL_ID, mlTask.getModelId()),
                        Map.entry(REGISTER_MODEL_STATUS, mlTask.getState().name())
                    ),
                    data.get(0).getWorkflowId()
                )
            );
            return;
        }
        // NOTE(review): this relies on ActionListener.wrap forwarding an exception thrown by the
        // response handler to the failure handler below, which then drives the retry - confirm.
        throw new IllegalStateException("MLTask is not yet completed");
    }, failure -> {
        if (!shouldRetry(getMLTaskFuture, retries)) {
            logger.error("Failed to retrieve ML Task, maximum retries exceeded");
            getMLTaskFuture.completeExceptionally(
                new FlowFrameworkException(failure.getMessage(), ExceptionsHelper.status(failure))
            );
            return;
        }
        retryableGetMlTask(data, getMLTaskFuture, taskId, retries + 1);
    }));
}

/**
 * Decides whether another attempt to retrieve the ML task should be made and, if so, waits
 * five seconds before allowing it.
 *
 * <p>The retry budget is checked first, so an exhausted budget is reported immediately rather
 * than after a pointless wait. If the wait is interrupted, the interrupt flag is restored, the
 * future is completed exceptionally, and no further retry is requested.
 *
 * @param getMLTaskFuture the future to fail if the wait is interrupted
 * @param retries the number of retries already performed
 * @return {@code true} if the caller should issue another request, {@code false} otherwise
 */
private boolean shouldRetry(CompletableFuture<WorkflowData> getMLTaskFuture, int retries) {
    // Budget exhausted: answer right away instead of sleeping before the inevitable failure.
    if (retries >= maxRetry) {
        return false;
    }
    try {
        Thread.sleep(5000);
    } catch (InterruptedException e) {
        // Restore the interrupt flag so code further up the stack can observe the interruption.
        Thread.currentThread().interrupt();
        getMLTaskFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e)));
        // The future is already failed; retrying would only issue a request nobody is waiting for.
        return false;
    }
    return true;
}
joshpalis marked this conversation as resolved.
Show resolved Hide resolved

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.flowframework.workflow;

import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;

import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_REQUEST_RETRY;

/**
* Abstract retryable workflow step
*/
public abstract class RetryableWorkflowStep implements WorkflowStep {
joshpalis marked this conversation as resolved.
Show resolved Hide resolved

/** The maximum number of transport request retries */
protected volatile Integer maxRetry;

/**
 * Creates a retryable workflow step. The retry limit is first read from the supplied settings
 * and is then kept current by a settings-update consumer registered with the cluster settings.
 * @param settings Environment settings
 * @param clusterService the cluster service
 */
public RetryableWorkflowStep(Settings settings, ClusterService clusterService) {
    maxRetry = MAX_REQUEST_RETRY.get(settings);
    clusterService.getClusterSettings()
        .addSettingsUpdateConsumer(MAX_REQUEST_RETRY, updatedLimit -> this.maxRetry = updatedLimit);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler;
Expand All @@ -29,22 +30,25 @@ public class WorkflowStepFactory {
/**
* Instantiate this class.
*
* @param settings The OpenSearch settings
* @param clusterService The OpenSearch cluster service
* @param client The OpenSearch client steps can use
* @param mlClient Machine Learning client to perform ml operations
* @param flowFrameworkIndicesHandler FlowFrameworkIndicesHandler class to update system indices
*/
public WorkflowStepFactory(
Settings settings,
ClusterService clusterService,
Client client,
MachineLearningNodeClient mlClient,
FlowFrameworkIndicesHandler flowFrameworkIndicesHandler
) {
this.flowFrameworkIndicesHandler = flowFrameworkIndicesHandler;
populateMap(clusterService, client, mlClient, flowFrameworkIndicesHandler);
populateMap(settings, clusterService, client, mlClient, flowFrameworkIndicesHandler);
}

private void populateMap(
Settings settings,
ClusterService clusterService,
Client client,
MachineLearningNodeClient mlClient,
Expand All @@ -58,7 +62,7 @@ private void populateMap(
stepMap.put(DeployModelStep.NAME, new DeployModelStep(mlClient));
stepMap.put(CreateConnectorStep.NAME, new CreateConnectorStep(mlClient, flowFrameworkIndicesHandler));
stepMap.put(ModelGroupStep.NAME, new ModelGroupStep(mlClient));
stepMap.put(GetMLTaskStep.NAME, new GetMLTaskStep(mlClient));
stepMap.put(GetMLTaskStep.NAME, new GetMLTaskStep(settings, clusterService, mlClient));
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import java.util.stream.Stream;

import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_REQUEST_RETRY;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOWS;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.WORKFLOW_REQUEST_TIMEOUT;
import static org.mockito.Mockito.mock;
Expand Down Expand Up @@ -61,7 +62,7 @@ public void setUp() throws Exception {

final Set<Setting<?>> settingsSet = Stream.concat(
ClusterSettings.BUILT_IN_CLUSTER_SETTINGS.stream(),
Stream.of(FLOW_FRAMEWORK_ENABLED, MAX_WORKFLOWS, WORKFLOW_REQUEST_TIMEOUT)
Stream.of(FLOW_FRAMEWORK_ENABLED, MAX_WORKFLOWS, WORKFLOW_REQUEST_TIMEOUT, MAX_REQUEST_RETRY)
).collect(Collectors.toSet());
clusterSettings = new ClusterSettings(settings, settingsSet);
clusterService = mock(ClusterService.class);
Expand All @@ -83,7 +84,7 @@ public void testPlugin() throws IOException {
assertEquals(4, ffp.getRestHandlers(settings, null, null, null, null, null, null).size());
assertEquals(4, ffp.getActions().size());
assertEquals(1, ffp.getExecutorBuilders(settings).size());
assertEquals(3, ffp.getSettings().size());
assertEquals(4, ffp.getSettings().size());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
import org.opensearch.client.Client;
import org.opensearch.client.ClusterAdminClient;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.settings.Setting;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler;
import org.opensearch.flowframework.workflow.WorkflowStepFactory;
Expand All @@ -21,7 +24,14 @@
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_REQUEST_RETRY;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOWS;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.WORKFLOW_REQUEST_TIMEOUT;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

Expand Down Expand Up @@ -70,7 +80,20 @@ public void testWorkflowStepFactoryHasValidators() throws IOException {
MachineLearningNodeClient mlClient = mock(MachineLearningNodeClient.class);
FlowFrameworkIndicesHandler flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class);

WorkflowStepFactory workflowStepFactory = new WorkflowStepFactory(clusterService, client, mlClient, flowFrameworkIndicesHandler);
final Set<Setting<?>> settingsSet = Stream.concat(
ClusterSettings.BUILT_IN_CLUSTER_SETTINGS.stream(),
Stream.of(FLOW_FRAMEWORK_ENABLED, MAX_WORKFLOWS, WORKFLOW_REQUEST_TIMEOUT, MAX_REQUEST_RETRY)
).collect(Collectors.toSet());
ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, settingsSet);
when(clusterService.getClusterSettings()).thenReturn(clusterSettings);

WorkflowStepFactory workflowStepFactory = new WorkflowStepFactory(
Settings.EMPTY,
clusterService,
client,
mlClient,
flowFrameworkIndicesHandler
);

// Read in workflow-steps.json
WorkflowValidator workflowValidator = WorkflowValidator.parse("mappings/workflow-steps.json");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@

import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope;

import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.settings.Setting;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.action.ActionListener;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.ml.client.MachineLearningNodeClient;
Expand All @@ -19,19 +23,25 @@

import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import org.mockito.Mock;
import org.mockito.MockitoAnnotations;

import static org.opensearch.flowframework.common.CommonValue.MODEL_ID;
import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS;
import static org.opensearch.flowframework.common.CommonValue.TASK_ID;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_REQUEST_RETRY;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

@ThreadLeakScope(ThreadLeakScope.Scope.NONE)
public class GetMLTaskStepTests extends OpenSearchTestCase {
Expand All @@ -47,7 +57,13 @@ public void setUp() throws Exception {
super.setUp();

MockitoAnnotations.openMocks(this);
this.getMLTaskStep = new GetMLTaskStep(mlNodeClient);
ClusterService clusterService = mock(ClusterService.class);
final Set<Setting<?>> settingsSet = Stream.concat(ClusterSettings.BUILT_IN_CLUSTER_SETTINGS.stream(), Stream.of(MAX_REQUEST_RETRY))
.collect(Collectors.toSet());
ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, settingsSet);
when(clusterService.getClusterSettings()).thenReturn(clusterSettings);

this.getMLTaskStep = new GetMLTaskStep(Settings.EMPTY, clusterService, mlNodeClient);
this.workflowData = new WorkflowData(Map.ofEntries(Map.entry(TASK_ID, "test")), "test-id");
}

Expand Down
Loading