Skip to content

Commit

Permalink
Adds transport request retry capability for GetMLTaskStep (#179)
Browse files Browse the repository at this point in the history
* Added FlowFrameworkMaxRequestRetrySetting and applied this to GetMlTaskStep

Signed-off-by: Joshua Palis <[email protected]>

* Addressing PR comments, creating abstract class RetryableWorkflowStep to initialize the setting update consumer for max retry request setting, passing settings and clusterservice to retryable workflow step instead

Signed-off-by: Joshua Palis <[email protected]>

* Addressing PR comments

Signed-off-by: Joshua Palis <[email protected]>

* Removing retry utils class

Signed-off-by: Joshua Palis <[email protected]>

* Fixing failure unit tests to ensure thread sleep isnt invoked

Signed-off-by: Joshua Palis <[email protected]>

* Addressing PR comments, changing setting name, improving javadoc, removing max value from setting, cancelling future on thread interuppt

Signed-off-by: Joshua Palis <[email protected]>

---------

Signed-off-by: Joshua Palis <[email protected]>
  • Loading branch information
joshpalis committed Nov 23, 2023
1 parent 0a30431 commit 1092325
Show file tree
Hide file tree
Showing 10 changed files with 192 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
import static org.opensearch.flowframework.common.CommonValue.FLOW_FRAMEWORK_THREAD_POOL_PREFIX;
import static org.opensearch.flowframework.common.CommonValue.PROVISION_THREAD_POOL;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_GET_TASK_REQUEST_RETRY;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOWS;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.WORKFLOW_REQUEST_TIMEOUT;

Expand All @@ -69,6 +70,7 @@
public class FlowFrameworkPlugin extends Plugin implements ActionPlugin {

private FlowFrameworkFeatureEnabledSetting flowFrameworkFeatureEnabledSetting;

private ClusterService clusterService;

/**
Expand All @@ -93,10 +95,15 @@ public Collection<Object> createComponents(
Settings settings = environment.settings();
this.clusterService = clusterService;
flowFrameworkFeatureEnabledSetting = new FlowFrameworkFeatureEnabledSetting(clusterService, settings);

MachineLearningNodeClient mlClient = new MachineLearningNodeClient(client);
FlowFrameworkIndicesHandler flowFrameworkIndicesHandler = new FlowFrameworkIndicesHandler(client, clusterService);
WorkflowStepFactory workflowStepFactory = new WorkflowStepFactory(clusterService, client, mlClient, flowFrameworkIndicesHandler);
WorkflowStepFactory workflowStepFactory = new WorkflowStepFactory(
settings,
clusterService,
client,
mlClient,
flowFrameworkIndicesHandler
);
WorkflowProcessSorter workflowProcessSorter = new WorkflowProcessSorter(workflowStepFactory, threadPool);

return ImmutableList.of(workflowStepFactory, workflowProcessSorter, flowFrameworkIndicesHandler);
Expand Down Expand Up @@ -132,7 +139,12 @@ public List<RestHandler> getRestHandlers(

@Override
public List<Setting<?>> getSettings() {
List<Setting<?>> settings = ImmutableList.of(FLOW_FRAMEWORK_ENABLED, MAX_WORKFLOWS, WORKFLOW_REQUEST_TIMEOUT);
List<Setting<?>> settings = ImmutableList.of(
FLOW_FRAMEWORK_ENABLED,
MAX_WORKFLOWS,
WORKFLOW_REQUEST_TIMEOUT,
MAX_GET_TASK_REQUEST_RETRY
);
return settings;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,13 @@ private FlowFrameworkSettings() {}
Setting.Property.NodeScope,
Setting.Property.Dynamic
);

/** This setting sets the maximum number of get task request retries */
public static final Setting<Integer> MAX_GET_TASK_REQUEST_RETRY = Setting.intSetting(
"plugins.flow_framework.max_get_task_request_retry",
5,
0,
Setting.Property.NodeScope,
Setting.Property.Dynamic
);
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@
*/
public class GetWorkflowResponse extends ActionResponse implements ToXContentObject {

/** The workflow state */
public WorkflowState workflowState;
/** Flag to indicate if the entire state should be returned */
public boolean allStatus;

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.flowframework.workflow;

import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;

import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_GET_TASK_REQUEST_RETRY;

/**
* Abstract retryable workflow step
*/
public abstract class AbstractRetryableWorkflowStep implements WorkflowStep {

/** The maximum number of transport request retries */
protected volatile Integer maxRetry;

/**
* Instantiates a new Retryable workflow step
* @param settings Environment settings
* @param clusterService the cluster service
*/
public AbstractRetryableWorkflowStep(Settings settings, ClusterService clusterService) {
this.maxRetry = MAX_GET_TASK_REQUEST_RETRY.get(settings);
clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_GET_TASK_REQUEST_RETRY, it -> maxRetry = it);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,14 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.ExceptionsHelper;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.FutureUtils;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.MLTaskState;

import java.util.List;
import java.util.Map;
Expand All @@ -29,17 +32,20 @@
/**
* Step to retrieve an ML Task
*/
public class GetMLTaskStep implements WorkflowStep {
public class GetMLTaskStep extends AbstractRetryableWorkflowStep {

private static final Logger logger = LogManager.getLogger(GetMLTaskStep.class);
private MachineLearningNodeClient mlClient;
static final String NAME = "get_ml_task";

/**
* Instantiate this class
* @param settings the Opensearch settings
* @param clusterService the OpenSearch cluster service
* @param mlClient client to instantiate MLClient
*/
public GetMLTaskStep(MachineLearningNodeClient mlClient) {
public GetMLTaskStep(Settings settings, ClusterService clusterService, MachineLearningNodeClient mlClient) {
super(settings, clusterService);
this.mlClient = mlClient;
}

Expand All @@ -48,23 +54,6 @@ public CompletableFuture<WorkflowData> execute(List<WorkflowData> data) {

CompletableFuture<WorkflowData> getMLTaskFuture = new CompletableFuture<>();

ActionListener<MLTask> actionListener = ActionListener.wrap(response -> {

// TODO : Add retry capability if response status is not COMPLETED :
// https://github.com/opensearch-project/flow-framework/issues/158

logger.info("ML Task retrieval successful");
getMLTaskFuture.complete(
new WorkflowData(
Map.ofEntries(Map.entry(MODEL_ID, response.getModelId()), Map.entry(REGISTER_MODEL_STATUS, response.getState().name())),
data.get(0).getWorkflowId()
)
);
}, exception -> {
logger.error("Failed to retrieve ML Task");
getMLTaskFuture.completeExceptionally(new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception)));
});

String taskId = null;

for (WorkflowData workflowData : data) {
Expand All @@ -84,7 +73,7 @@ public CompletableFuture<WorkflowData> execute(List<WorkflowData> data) {
logger.error("Failed to retrieve ML Task");
getMLTaskFuture.completeExceptionally(new FlowFrameworkException("Required fields are not provided", RestStatus.BAD_REQUEST));
} else {
mlClient.getTask(taskId, actionListener);
retryableGetMlTask(data.get(0).getWorkflowId(), getMLTaskFuture, taskId, 0);
}

return getMLTaskFuture;
Expand All @@ -95,4 +84,46 @@ public String getName() {
return NAME;
}

/**
* Retryable GetMLTask
* @param workflowId the workflow id
* @param getMLTaskFuture the workflow step future
* @param taskId the ml task id
* @param retries the current number of request retries
*/
protected void retryableGetMlTask(String workflowId, CompletableFuture<WorkflowData> getMLTaskFuture, String taskId, int retries) {
mlClient.getTask(taskId, ActionListener.wrap(response -> {
if (response.getState() != MLTaskState.COMPLETED) {
throw new IllegalStateException("MLTask is not yet completed");
} else {
logger.info("ML Task retrieval successful");
getMLTaskFuture.complete(
new WorkflowData(
Map.ofEntries(
Map.entry(MODEL_ID, response.getModelId()),
Map.entry(REGISTER_MODEL_STATUS, response.getState().name())
),
workflowId
)
);
}
}, exception -> {
if (retries < maxRetry) {
// Sleep thread prior to retrying request
try {
Thread.sleep(5000);
} catch (Exception e) {
FutureUtils.cancel(getMLTaskFuture);
}
final int retryAdd = retries + 1;
retryableGetMlTask(workflowId, getMLTaskFuture, taskId, retryAdd);
} else {
logger.error("Failed to retrieve ML Task, maximum retries exceeded");
getMLTaskFuture.completeExceptionally(
new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception))
);
}
}));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler;
Expand All @@ -29,22 +30,25 @@ public class WorkflowStepFactory {
/**
* Instantiate this class.
*
* @param settings The OpenSearch settings
* @param clusterService The OpenSearch cluster service
* @param client The OpenSearch client steps can use
* @param mlClient Machine Learning client to perform ml operations
* @param flowFrameworkIndicesHandler FlowFrameworkIndicesHandler class to update system indices
*/
public WorkflowStepFactory(
Settings settings,
ClusterService clusterService,
Client client,
MachineLearningNodeClient mlClient,
FlowFrameworkIndicesHandler flowFrameworkIndicesHandler
) {
this.flowFrameworkIndicesHandler = flowFrameworkIndicesHandler;
populateMap(clusterService, client, mlClient, flowFrameworkIndicesHandler);
populateMap(settings, clusterService, client, mlClient, flowFrameworkIndicesHandler);
}

private void populateMap(
Settings settings,
ClusterService clusterService,
Client client,
MachineLearningNodeClient mlClient,
Expand All @@ -58,7 +62,7 @@ private void populateMap(
stepMap.put(DeployModelStep.NAME, new DeployModelStep(mlClient));
stepMap.put(CreateConnectorStep.NAME, new CreateConnectorStep(mlClient, flowFrameworkIndicesHandler));
stepMap.put(ModelGroupStep.NAME, new ModelGroupStep(mlClient));
stepMap.put(GetMLTaskStep.NAME, new GetMLTaskStep(mlClient));
stepMap.put(GetMLTaskStep.NAME, new GetMLTaskStep(settings, clusterService, mlClient));
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import java.util.stream.Stream;

import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_GET_TASK_REQUEST_RETRY;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOWS;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.WORKFLOW_REQUEST_TIMEOUT;
import static org.mockito.Mockito.mock;
Expand Down Expand Up @@ -61,7 +62,7 @@ public void setUp() throws Exception {

final Set<Setting<?>> settingsSet = Stream.concat(
ClusterSettings.BUILT_IN_CLUSTER_SETTINGS.stream(),
Stream.of(FLOW_FRAMEWORK_ENABLED, MAX_WORKFLOWS, WORKFLOW_REQUEST_TIMEOUT)
Stream.of(FLOW_FRAMEWORK_ENABLED, MAX_WORKFLOWS, WORKFLOW_REQUEST_TIMEOUT, MAX_GET_TASK_REQUEST_RETRY)
).collect(Collectors.toSet());
clusterSettings = new ClusterSettings(settings, settingsSet);
clusterService = mock(ClusterService.class);
Expand All @@ -83,7 +84,7 @@ public void testPlugin() throws IOException {
assertEquals(4, ffp.getRestHandlers(settings, null, null, null, null, null, null).size());
assertEquals(4, ffp.getActions().size());
assertEquals(1, ffp.getExecutorBuilders(settings).size());
assertEquals(3, ffp.getSettings().size());
assertEquals(4, ffp.getSettings().size());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
import org.opensearch.client.Client;
import org.opensearch.client.ClusterAdminClient;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.settings.Setting;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler;
import org.opensearch.flowframework.workflow.WorkflowStepFactory;
Expand All @@ -21,7 +24,14 @@
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_GET_TASK_REQUEST_RETRY;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOWS;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.WORKFLOW_REQUEST_TIMEOUT;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

Expand Down Expand Up @@ -70,7 +80,20 @@ public void testWorkflowStepFactoryHasValidators() throws IOException {
MachineLearningNodeClient mlClient = mock(MachineLearningNodeClient.class);
FlowFrameworkIndicesHandler flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class);

WorkflowStepFactory workflowStepFactory = new WorkflowStepFactory(clusterService, client, mlClient, flowFrameworkIndicesHandler);
final Set<Setting<?>> settingsSet = Stream.concat(
ClusterSettings.BUILT_IN_CLUSTER_SETTINGS.stream(),
Stream.of(FLOW_FRAMEWORK_ENABLED, MAX_WORKFLOWS, WORKFLOW_REQUEST_TIMEOUT, MAX_GET_TASK_REQUEST_RETRY)
).collect(Collectors.toSet());
ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, settingsSet);
when(clusterService.getClusterSettings()).thenReturn(clusterSettings);

WorkflowStepFactory workflowStepFactory = new WorkflowStepFactory(
Settings.EMPTY,
clusterService,
client,
mlClient,
flowFrameworkIndicesHandler
);

// Read in workflow-steps.json
WorkflowValidator workflowValidator = WorkflowValidator.parse("mappings/workflow-steps.json");
Expand Down
Loading

0 comments on commit 1092325

Please sign in to comment.