Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds transport request retry capability for GetMLTaskStep #179

Merged
merged 10 commits into from
Nov 23, 2023
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
import static org.opensearch.flowframework.common.CommonValue.FLOW_FRAMEWORK_THREAD_POOL_PREFIX;
import static org.opensearch.flowframework.common.CommonValue.PROVISION_THREAD_POOL;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_GET_TASK_REQUEST_RETRY;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOWS;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.WORKFLOW_REQUEST_TIMEOUT;

Expand All @@ -69,6 +70,7 @@
public class FlowFrameworkPlugin extends Plugin implements ActionPlugin {

private FlowFrameworkFeatureEnabledSetting flowFrameworkFeatureEnabledSetting;

private ClusterService clusterService;

/**
Expand All @@ -93,10 +95,15 @@ public Collection<Object> createComponents(
Settings settings = environment.settings();
this.clusterService = clusterService;
flowFrameworkFeatureEnabledSetting = new FlowFrameworkFeatureEnabledSetting(clusterService, settings);

MachineLearningNodeClient mlClient = new MachineLearningNodeClient(client);
FlowFrameworkIndicesHandler flowFrameworkIndicesHandler = new FlowFrameworkIndicesHandler(client, clusterService);
WorkflowStepFactory workflowStepFactory = new WorkflowStepFactory(clusterService, client, mlClient, flowFrameworkIndicesHandler);
WorkflowStepFactory workflowStepFactory = new WorkflowStepFactory(
settings,
clusterService,
client,
mlClient,
flowFrameworkIndicesHandler
);
WorkflowProcessSorter workflowProcessSorter = new WorkflowProcessSorter(workflowStepFactory, threadPool);

return ImmutableList.of(workflowStepFactory, workflowProcessSorter, flowFrameworkIndicesHandler);
Expand Down Expand Up @@ -132,7 +139,12 @@ public List<RestHandler> getRestHandlers(

@Override
public List<Setting<?>> getSettings() {
List<Setting<?>> settings = ImmutableList.of(FLOW_FRAMEWORK_ENABLED, MAX_WORKFLOWS, WORKFLOW_REQUEST_TIMEOUT);
List<Setting<?>> settings = ImmutableList.of(
FLOW_FRAMEWORK_ENABLED,
MAX_WORKFLOWS,
WORKFLOW_REQUEST_TIMEOUT,
MAX_GET_TASK_REQUEST_RETRY
);
return settings;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,13 @@ private FlowFrameworkSettings() {}
Setting.Property.NodeScope,
Setting.Property.Dynamic
);

/** This setting sets the maximum number of get task request retries */
public static final Setting<Integer> MAX_GET_TASK_REQUEST_RETRY = Setting.intSetting(
"plugins.flow_framework.max_get_task_request_retry",
5,
0,
Setting.Property.NodeScope,
Setting.Property.Dynamic
owaiskazi19 marked this conversation as resolved.
Show resolved Hide resolved
);
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@
*/
public class GetWorkflowResponse extends ActionResponse implements ToXContentObject {

/** The workflow state */
public WorkflowState workflowState;
/** Flag to indicate if the entire state should be returned */
public boolean allStatus;

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.flowframework.workflow;

import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;

import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_GET_TASK_REQUEST_RETRY;

/**
 * Base class for workflow steps that retry their transport requests.
 * <p>
 * Holds the retry limit read from {@code MAX_GET_TASK_REQUEST_RETRY} and keeps it current
 * by subscribing to updates of that setting on the cluster settings.
 */
public abstract class AbstractRetryableWorkflowStep implements WorkflowStep {

    /** Upper bound on transport request retries; rewritten by the settings update consumer */
    protected volatile Integer maxRetry;

    /**
     * Reads the initial retry limit and registers for later changes to it.
     * @param settings Environment settings supplying the initial value
     * @param clusterService the cluster service whose cluster settings report updates
     */
    public AbstractRetryableWorkflowStep(Settings settings, ClusterService clusterService) {
        final Integer configuredLimit = MAX_GET_TASK_REQUEST_RETRY.get(settings);
        this.maxRetry = configuredLimit;
        clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_GET_TASK_REQUEST_RETRY, updatedLimit -> {
            this.maxRetry = updatedLimit;
        });
    }

}
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,14 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.ExceptionsHelper;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.FutureUtils;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.MLTaskState;

import java.util.List;
import java.util.Map;
Expand All @@ -29,17 +32,20 @@
/**
* Step to retrieve an ML Task
*/
public class GetMLTaskStep implements WorkflowStep {
public class GetMLTaskStep extends AbstractRetryableWorkflowStep {

private static final Logger logger = LogManager.getLogger(GetMLTaskStep.class);
private MachineLearningNodeClient mlClient;
static final String NAME = "get_ml_task";

/**
* Instantiate this class
* @param settings the OpenSearch settings
* @param clusterService the OpenSearch cluster service
* @param mlClient client to instantiate MLClient
*/
public GetMLTaskStep(MachineLearningNodeClient mlClient) {
public GetMLTaskStep(Settings settings, ClusterService clusterService, MachineLearningNodeClient mlClient) {
super(settings, clusterService);
this.mlClient = mlClient;
}

Expand All @@ -48,23 +54,6 @@ public CompletableFuture<WorkflowData> execute(List<WorkflowData> data) {

CompletableFuture<WorkflowData> getMLTaskFuture = new CompletableFuture<>();

ActionListener<MLTask> actionListener = ActionListener.wrap(response -> {

// TODO : Add retry capability if response status is not COMPLETED :
// https://github.com/opensearch-project/flow-framework/issues/158

logger.info("ML Task retrieval successful");
getMLTaskFuture.complete(
new WorkflowData(
Map.ofEntries(Map.entry(MODEL_ID, response.getModelId()), Map.entry(REGISTER_MODEL_STATUS, response.getState().name())),
data.get(0).getWorkflowId()
)
);
}, exception -> {
logger.error("Failed to retrieve ML Task");
getMLTaskFuture.completeExceptionally(new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception)));
});

String taskId = null;

for (WorkflowData workflowData : data) {
Expand All @@ -84,7 +73,7 @@ public CompletableFuture<WorkflowData> execute(List<WorkflowData> data) {
logger.error("Failed to retrieve ML Task");
getMLTaskFuture.completeExceptionally(new FlowFrameworkException("Required fields are not provided", RestStatus.BAD_REQUEST));
} else {
mlClient.getTask(taskId, actionListener);
retryableGetMlTask(data.get(0).getWorkflowId(), getMLTaskFuture, taskId, 0);
}

return getMLTaskFuture;
Expand All @@ -95,4 +84,46 @@ public String getName() {
return NAME;
}

/**
* Retryable GetMLTask
* @param workflowId the workflow id
* @param getMLTaskFuture the workflow step future
* @param taskId the ml task id
* @param retries the current number of request retries
*/
protected void retryableGetMlTask(String workflowId, CompletableFuture<WorkflowData> getMLTaskFuture, String taskId, int retries) {
mlClient.getTask(taskId, ActionListener.wrap(response -> {
if (response.getState() != MLTaskState.COMPLETED) {
throw new IllegalStateException("MLTask is not yet completed");
} else {
logger.info("ML Task retrieval successful");
getMLTaskFuture.complete(
new WorkflowData(
Map.ofEntries(
Map.entry(MODEL_ID, response.getModelId()),
Map.entry(REGISTER_MODEL_STATUS, response.getState().name())
),
workflowId
)
);
}
}, exception -> {
if (retries < maxRetry) {
// Sleep thread prior to retrying request
try {
Thread.sleep(5000);
dbwiddis marked this conversation as resolved.
Show resolved Hide resolved
} catch (Exception e) {
dbwiddis marked this conversation as resolved.
Show resolved Hide resolved
FutureUtils.cancel(getMLTaskFuture);
}
final int retryAdd = retries + 1;
retryableGetMlTask(workflowId, getMLTaskFuture, taskId, retryAdd);
} else {
logger.error("Failed to retrieve ML Task, maximum retries exceeded");
getMLTaskFuture.completeExceptionally(
new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception))
);
}
}));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler;
Expand All @@ -29,22 +30,25 @@ public class WorkflowStepFactory {
/**
* Instantiate this class.
*
* @param settings The OpenSearch settings
* @param clusterService The OpenSearch cluster service
* @param client The OpenSearch client steps can use
* @param mlClient Machine Learning client to perform ml operations
* @param flowFrameworkIndicesHandler FlowFrameworkIndicesHandler class to update system indices
*/
public WorkflowStepFactory(
Settings settings,
ClusterService clusterService,
Client client,
MachineLearningNodeClient mlClient,
FlowFrameworkIndicesHandler flowFrameworkIndicesHandler
) {
this.flowFrameworkIndicesHandler = flowFrameworkIndicesHandler;
populateMap(clusterService, client, mlClient, flowFrameworkIndicesHandler);
populateMap(settings, clusterService, client, mlClient, flowFrameworkIndicesHandler);
}

private void populateMap(
Settings settings,
ClusterService clusterService,
Client client,
MachineLearningNodeClient mlClient,
Expand All @@ -58,7 +62,7 @@ private void populateMap(
stepMap.put(DeployModelStep.NAME, new DeployModelStep(mlClient));
stepMap.put(CreateConnectorStep.NAME, new CreateConnectorStep(mlClient, flowFrameworkIndicesHandler));
stepMap.put(ModelGroupStep.NAME, new ModelGroupStep(mlClient));
stepMap.put(GetMLTaskStep.NAME, new GetMLTaskStep(mlClient));
stepMap.put(GetMLTaskStep.NAME, new GetMLTaskStep(settings, clusterService, mlClient));
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import java.util.stream.Stream;

import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_GET_TASK_REQUEST_RETRY;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOWS;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.WORKFLOW_REQUEST_TIMEOUT;
import static org.mockito.Mockito.mock;
Expand Down Expand Up @@ -61,7 +62,7 @@ public void setUp() throws Exception {

final Set<Setting<?>> settingsSet = Stream.concat(
ClusterSettings.BUILT_IN_CLUSTER_SETTINGS.stream(),
Stream.of(FLOW_FRAMEWORK_ENABLED, MAX_WORKFLOWS, WORKFLOW_REQUEST_TIMEOUT)
Stream.of(FLOW_FRAMEWORK_ENABLED, MAX_WORKFLOWS, WORKFLOW_REQUEST_TIMEOUT, MAX_GET_TASK_REQUEST_RETRY)
).collect(Collectors.toSet());
clusterSettings = new ClusterSettings(settings, settingsSet);
clusterService = mock(ClusterService.class);
Expand All @@ -83,7 +84,7 @@ public void testPlugin() throws IOException {
assertEquals(4, ffp.getRestHandlers(settings, null, null, null, null, null, null).size());
assertEquals(4, ffp.getActions().size());
assertEquals(1, ffp.getExecutorBuilders(settings).size());
assertEquals(3, ffp.getSettings().size());
assertEquals(4, ffp.getSettings().size());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
import org.opensearch.client.Client;
import org.opensearch.client.ClusterAdminClient;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.settings.Setting;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler;
import org.opensearch.flowframework.workflow.WorkflowStepFactory;
Expand All @@ -21,7 +24,14 @@
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_GET_TASK_REQUEST_RETRY;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOWS;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.WORKFLOW_REQUEST_TIMEOUT;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

Expand Down Expand Up @@ -70,7 +80,20 @@ public void testWorkflowStepFactoryHasValidators() throws IOException {
MachineLearningNodeClient mlClient = mock(MachineLearningNodeClient.class);
FlowFrameworkIndicesHandler flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class);

WorkflowStepFactory workflowStepFactory = new WorkflowStepFactory(clusterService, client, mlClient, flowFrameworkIndicesHandler);
final Set<Setting<?>> settingsSet = Stream.concat(
ClusterSettings.BUILT_IN_CLUSTER_SETTINGS.stream(),
Stream.of(FLOW_FRAMEWORK_ENABLED, MAX_WORKFLOWS, WORKFLOW_REQUEST_TIMEOUT, MAX_GET_TASK_REQUEST_RETRY)
).collect(Collectors.toSet());
ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, settingsSet);
when(clusterService.getClusterSettings()).thenReturn(clusterSettings);

WorkflowStepFactory workflowStepFactory = new WorkflowStepFactory(
Settings.EMPTY,
clusterService,
client,
mlClient,
flowFrameworkIndicesHandler
);

// Read in workflow-steps.json
WorkflowValidator workflowValidator = WorkflowValidator.parse("mappings/workflow-steps.json");
Expand Down
Loading
Loading