From 10475fff5bef22a45f99151ddcf06f4e3cd9c7e0 Mon Sep 17 00:00:00 2001 From: Amit Galitzky Date: Thu, 9 Nov 2023 15:08:32 -0500 Subject: [PATCH] add plugin enabled setting check for get status api Signed-off-by: Amit Galitzky --- .../flowframework/FlowFrameworkPlugin.java | 3 +- .../rest/RestGetWorkflowAction.java | 65 ++++++++++++++----- .../CreateWorkflowTransportAction.java | 2 +- .../transport/GetWorkflowResponse.java | 1 + .../transport/WorkflowRequest.java | 5 +- .../FlowFrameworkPluginTests.java | 4 +- .../workflow/WorkflowProcessSorterTests.java | 4 +- 7 files changed, 59 insertions(+), 25 deletions(-) diff --git a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java index 318e1015f..225b392a0 100644 --- a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java +++ b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java @@ -110,7 +110,8 @@ public List getRestHandlers( return ImmutableList.of( new RestCreateWorkflowAction(flowFrameworkFeatureEnabledSetting), new RestProvisionWorkflowAction(flowFrameworkFeatureEnabledSetting), - new RestSearchWorkflowAction(flowFrameworkFeatureEnabledSetting) + new RestSearchWorkflowAction(flowFrameworkFeatureEnabledSetting), + new RestGetWorkflowAction(flowFrameworkFeatureEnabledSetting) ); } diff --git a/src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowAction.java b/src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowAction.java index 1f1a9295d..876b1f3d5 100644 --- a/src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowAction.java +++ b/src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowAction.java @@ -12,13 +12,17 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.client.node.NodeClient; +import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.flowframework.common.FlowFrameworkFeatureEnabledSetting; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.transport.GetWorkflowAction; import org.opensearch.flowframework.transport.WorkflowRequest; import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.BytesRestResponse; import org.opensearch.rest.RestRequest; -import org.opensearch.rest.action.RestToXContentListener; import java.io.IOException; import java.util.List; @@ -26,6 +30,7 @@ import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_ID; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI; +import static org.opensearch.flowframework.common.FlowFrameworkFeatureEnabledSetting.FLOW_FRAMEWORK_ENABLED; /** * Rest Action to facilitate requests to get a workflow status @@ -34,11 +39,15 @@ public class RestGetWorkflowAction extends BaseRestHandler { private static final String GET_WORKFLOW_ACTION = "get_workflow"; private static final Logger logger = LogManager.getLogger(RestGetWorkflowAction.class); + private FlowFrameworkFeatureEnabledSetting flowFrameworkFeatureEnabledSetting; /** * Instantiates a new RestGetWorkflowAction + * @param flowFrameworkFeatureEnabledSetting Whether this API is enabled */ - public RestGetWorkflowAction() {} + public RestGetWorkflowAction(FlowFrameworkFeatureEnabledSetting flowFrameworkFeatureEnabledSetting) { + this.flowFrameworkFeatureEnabledSetting = flowFrameworkFeatureEnabledSetting; + } @Override public String getName() { @@ -47,21 +56,45 @@ public String getName() { @Override protected BaseRestHandler.RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { - // Validate content - if (request.hasContent()) { - throw new FlowFrameworkException("Invalid request format", RestStatus.BAD_REQUEST); - } - // Validate params - String workflowId = request.param(WORKFLOW_ID); - if (workflowId == null) { - throw new FlowFrameworkException("workflow_id cannot be null", RestStatus.BAD_REQUEST); - } - String rawPath = request.rawPath(); - boolean all = request.paramAsBoolean("_all", false); - // Create request and provision - WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, null, all); - return channel -> client.execute(GetWorkflowAction.INSTANCE, workflowRequest, new RestToXContentListener<>(channel)); + try { + if (!flowFrameworkFeatureEnabledSetting.isFlowFrameworkEnabled()) { + throw new FlowFrameworkException( + "This API is disabled. To enable it, update the setting [" + FLOW_FRAMEWORK_ENABLED.getKey() + "] to true.", + RestStatus.FORBIDDEN + ); + } + + // Validate content + if (request.hasContent()) { + throw new FlowFrameworkException("Invalid request format", RestStatus.BAD_REQUEST); + } + // Validate params + String workflowId = request.param(WORKFLOW_ID); + if (workflowId == null) { + throw new FlowFrameworkException("workflow_id cannot be null", RestStatus.BAD_REQUEST); + } + + boolean all = request.paramAsBoolean("_all", false); + WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, null, all); + return channel -> client.execute(GetWorkflowAction.INSTANCE, workflowRequest, ActionListener.wrap(response -> { + XContentBuilder builder = response.toXContent(channel.newBuilder(), ToXContent.EMPTY_PARAMS); + channel.sendResponse(new BytesRestResponse(RestStatus.OK, builder)); + }, exception -> { + try { + FlowFrameworkException ex = (FlowFrameworkException) exception; + XContentBuilder exceptionBuilder = ex.toXContent(channel.newErrorBuilder(), ToXContent.EMPTY_PARAMS); + channel.sendResponse(new BytesRestResponse(ex.getRestStatus(), exceptionBuilder)); + } catch (IOException e) { + logger.error("Failed to send back provision workflow exception", e); + } + })); + + } catch (FlowFrameworkException ex) { + return channel -> channel.sendResponse( + new BytesRestResponse(ex.getRestStatus(), ex.toXContent(channel.newErrorBuilder(), ToXContent.EMPTY_PARAMS)) + ); + } } @Override diff --git a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java index 0beae6e04..8e9315cc4 100644 --- a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java @@ -87,7 +87,7 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener sortedProcessNodes = workflowProcessSorter.sortProcessNodes(workflow, "123"); FlowFrameworkException ex = expectThrows( - FlowFrameworkException.class, - () -> workflowProcessSorter.validateGraph(sortedProcessNodes) + FlowFrameworkException.class, + () -> workflowProcessSorter.validateGraph(sortedProcessNodes) ); assertEquals("Invalid graph, missing the following required inputs : [connector_id]", ex.getMessage()); assertEquals(RestStatus.BAD_REQUEST, ex.getRestStatus());