Skip to content

Commit

Permalink
add plugin enabled setting check for get status api
Browse files Browse the repository at this point in the history
Signed-off-by: Amit Galitzky <[email protected]>
  • Loading branch information
amitgalitz committed Nov 9, 2023
1 parent b8b131c commit af1424f
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,8 @@ public List<RestHandler> getRestHandlers(
return ImmutableList.of(
new RestCreateWorkflowAction(flowFrameworkFeatureEnabledSetting),
new RestProvisionWorkflowAction(flowFrameworkFeatureEnabledSetting),
new RestSearchWorkflowAction(flowFrameworkFeatureEnabledSetting)
new RestSearchWorkflowAction(flowFrameworkFeatureEnabledSetting),
new RestGetWorkflowAction(flowFrameworkFeatureEnabledSetting)
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,25 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.client.node.NodeClient;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.flowframework.common.FlowFrameworkFeatureEnabledSetting;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.transport.GetWorkflowAction;
import org.opensearch.flowframework.transport.WorkflowRequest;
import org.opensearch.rest.BaseRestHandler;
import org.opensearch.rest.BytesRestResponse;
import org.opensearch.rest.RestRequest;
import org.opensearch.rest.action.RestToXContentListener;

import java.io.IOException;
import java.util.List;
import java.util.Locale;

import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_ID;
import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI;
import static org.opensearch.flowframework.common.FlowFrameworkFeatureEnabledSetting.FLOW_FRAMEWORK_ENABLED;

/**
* Rest Action to facilitate requests to get a workflow status
Expand All @@ -34,11 +39,14 @@ public class RestGetWorkflowAction extends BaseRestHandler {

private static final String GET_WORKFLOW_ACTION = "get_workflow";
private static final Logger logger = LogManager.getLogger(RestGetWorkflowAction.class);
private FlowFrameworkFeatureEnabledSetting flowFrameworkFeatureEnabledSetting;

/**
* Instantiates a new RestGetWorkflowAction
*/
public RestGetWorkflowAction() {}
public RestGetWorkflowAction(FlowFrameworkFeatureEnabledSetting flowFrameworkFeatureEnabledSetting) {
this.flowFrameworkFeatureEnabledSetting = flowFrameworkFeatureEnabledSetting;
}

@Override
public String getName() {
Expand All @@ -47,21 +55,45 @@ public String getName() {

@Override
protected BaseRestHandler.RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException {
// Validate content
if (request.hasContent()) {
throw new FlowFrameworkException("Invalid request format", RestStatus.BAD_REQUEST);
}
// Validate params
String workflowId = request.param(WORKFLOW_ID);
if (workflowId == null) {
throw new FlowFrameworkException("workflow_id cannot be null", RestStatus.BAD_REQUEST);
}

String rawPath = request.rawPath();
boolean all = request.paramAsBoolean("_all", false);
// Create request and provision
WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, null, all);
return channel -> client.execute(GetWorkflowAction.INSTANCE, workflowRequest, new RestToXContentListener<>(channel));
try {
if (!flowFrameworkFeatureEnabledSetting.isFlowFrameworkEnabled()) {
throw new FlowFrameworkException(
"This API is disabled. To enable it, update the setting [" + FLOW_FRAMEWORK_ENABLED.getKey() + "] to true.",
RestStatus.FORBIDDEN
);
}

// Validate content
if (request.hasContent()) {
throw new FlowFrameworkException("Invalid request format", RestStatus.BAD_REQUEST);
}
// Validate params
String workflowId = request.param(WORKFLOW_ID);
if (workflowId == null) {
throw new FlowFrameworkException("workflow_id cannot be null", RestStatus.BAD_REQUEST);
}

boolean all = request.paramAsBoolean("_all", false);
WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, null, all);
return channel -> client.execute(GetWorkflowAction.INSTANCE, workflowRequest, ActionListener.wrap(response -> {
XContentBuilder builder = response.toXContent(channel.newBuilder(), ToXContent.EMPTY_PARAMS);
channel.sendResponse(new BytesRestResponse(RestStatus.OK, builder));
}, exception -> {
try {
FlowFrameworkException ex = (FlowFrameworkException) exception;
XContentBuilder exceptionBuilder = ex.toXContent(channel.newErrorBuilder(), ToXContent.EMPTY_PARAMS);
channel.sendResponse(new BytesRestResponse(ex.getRestStatus(), exceptionBuilder));
} catch (IOException e) {
logger.error("Failed to send back provision workflow exception", e);
}
}));

} catch (FlowFrameworkException ex) {
return channel -> channel.sendResponse(
new BytesRestResponse(ex.getRestStatus(), ex.toXContent(channel.newErrorBuilder(), ToXContent.EMPTY_PARAMS))
);
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener<Work

if (request.isDryRun()) {
try {
//generating random workflowId only for validation purpose
// generating random workflowId only for validation purpose
String uniqueID = UUID.randomUUID().toString();
validateWorkflows(templateWithUser, uniqueID);
} catch (Exception e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ public GetWorkflowResponse(WorkflowState workflowState, boolean allStatus) {
this.workflowState = new WorkflowState.Builder().workflowId(workflowState.getWorkflowId())
.error(workflowState.getError())
.state(workflowState.getState())
.resourcesCreated(workflowState.resourcesCreated())
.build();
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ public WorkflowRequest(@Nullable String workflowId, @Nullable Template template)
* @param workflowId the documentId of the workflow
* @param template the use case template which describes the workflow
* @param dryRun flag to indicate if validation is necessary
* @param all whether the get request is looking for all fields in status
*/
public WorkflowRequest(@Nullable String workflowId, @Nullable Template template, boolean dryRun, boolean all) {
this.workflowId = workflowId;
Expand All @@ -71,9 +72,7 @@ public WorkflowRequest(@Nullable String workflowId, @Nullable Template template,
* @param all whether the get request is looking for all fields in status
*/
public WorkflowRequest(@Nullable String workflowId, @Nullable Template template, boolean all) {
this.workflowId = workflowId;
this.template = template;
this.all = all;
this(workflowId, template, all, false);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -278,8 +278,8 @@ public void testFailedGraphValidation() {

List<ProcessNode> sortedProcessNodes = workflowProcessSorter.sortProcessNodes(workflow, "123");
FlowFrameworkException ex = expectThrows(
FlowFrameworkException.class,
() -> workflowProcessSorter.validateGraph(sortedProcessNodes)
FlowFrameworkException.class,
() -> workflowProcessSorter.validateGraph(sortedProcessNodes)
);
assertEquals("Invalid graph, missing the following required inputs : [connector_id]", ex.getMessage());
assertEquals(RestStatus.BAD_REQUEST, ex.getRestStatus());
Expand Down

0 comments on commit af1424f

Please sign in to comment.