From a08cdbcce2c31a6db5a6517489d0429070761075 Mon Sep 17 00:00:00 2001 From: Joshua Palis Date: Mon, 1 Apr 2024 12:00:27 -0700 Subject: [PATCH] [POC] Adding Orchestrate API and Search Response Processor Step (#619) * Fixing ingest pipeline integ test (#614) Signed-off-by: Joshua Palis * Adding initial rest, transport actions for orchestration. Search response processor step Signed-off-by: Joshua Palis * Fixing transport action Signed-off-by: Joshua Palis * reverting change to resttestcase Signed-off-by: Joshua Palis * Adding javadocs Signed-off-by: Joshua Palis * fixing checkstyle Signed-off-by: Joshua Palis * removing extra common value Signed-off-by: Joshua Palis * Fixing errors Signed-off-by: Joshua Palis --------- Signed-off-by: Joshua Palis --- .../flowframework/FlowFrameworkPlugin.java | 9 +- .../flowframework/common/CommonValue.java | 4 + .../flowframework/model/WorkflowNode.java | 3 +- .../rest/RestOrchestrateAction.java | 110 ++++++++++ .../CreateWorkflowTransportAction.java | 9 +- .../transport/OrchestrateAction.java | 30 +++ .../transport/OrchestrateRequest.java | 82 ++++++++ .../transport/OrchestrateTransportAction.java | 189 ++++++++++++++++++ .../workflow/SearchResponseProcessorStep.java | 127 ++++++++++++ .../workflow/WorkflowProcessSorter.java | 9 + .../workflow/WorkflowStepFactory.java | 21 ++ .../FlowFrameworkPluginTests.java | 4 +- .../model/WorkflowValidatorTests.java | 4 +- .../CreateWorkflowTransportActionTests.java | 4 +- 14 files changed, 597 insertions(+), 8 deletions(-) create mode 100644 src/main/java/org/opensearch/flowframework/rest/RestOrchestrateAction.java create mode 100644 src/main/java/org/opensearch/flowframework/transport/OrchestrateAction.java create mode 100644 src/main/java/org/opensearch/flowframework/transport/OrchestrateRequest.java create mode 100644 src/main/java/org/opensearch/flowframework/transport/OrchestrateTransportAction.java create mode 100644 src/main/java/org/opensearch/flowframework/workflow/SearchResponseProcessorStep.java diff --git a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java index 4c8486b7e..118bbdad1 100644 --- a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java +++ b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java @@ -33,6 +33,7 @@ import org.opensearch.flowframework.rest.RestGetWorkflowAction; import org.opensearch.flowframework.rest.RestGetWorkflowStateAction; import org.opensearch.flowframework.rest.RestGetWorkflowStepAction; +import org.opensearch.flowframework.rest.RestOrchestrateAction; import org.opensearch.flowframework.rest.RestProvisionWorkflowAction; import org.opensearch.flowframework.rest.RestSearchWorkflowAction; import org.opensearch.flowframework.rest.RestSearchWorkflowStateAction; @@ -48,6 +49,8 @@ import org.opensearch.flowframework.transport.GetWorkflowStepAction; import org.opensearch.flowframework.transport.GetWorkflowStepTransportAction; import org.opensearch.flowframework.transport.GetWorkflowTransportAction; +import org.opensearch.flowframework.transport.OrchestrateAction; +import org.opensearch.flowframework.transport.OrchestrateTransportAction; import org.opensearch.flowframework.transport.ProvisionWorkflowAction; import org.opensearch.flowframework.transport.ProvisionWorkflowTransportAction; import org.opensearch.flowframework.transport.SearchWorkflowAction; @@ -156,7 +159,8 @@ public List getRestHandlers( new RestGetWorkflowStateAction(flowFrameworkSettings), new RestGetWorkflowAction(flowFrameworkSettings), new RestGetWorkflowStepAction(flowFrameworkSettings), - new RestSearchWorkflowStateAction(flowFrameworkSettings) + new RestSearchWorkflowStateAction(flowFrameworkSettings), + new RestOrchestrateAction() ); } @@ -171,7 +175,8 @@ public List getRestHandlers( new ActionHandler<>(GetWorkflowStateAction.INSTANCE, GetWorkflowStateTransportAction.class), new ActionHandler<>(GetWorkflowAction.INSTANCE, GetWorkflowTransportAction.class), new ActionHandler<>(GetWorkflowStepAction.INSTANCE, GetWorkflowStepTransportAction.class), - new ActionHandler<>(SearchWorkflowStateAction.INSTANCE, SearchWorkflowStateTransportAction.class) + new ActionHandler<>(SearchWorkflowStateAction.INSTANCE, SearchWorkflowStateTransportAction.class), + new ActionHandler<>(OrchestrateAction.INSTANCE, OrchestrateTransportAction.class) ); } diff --git a/src/main/java/org/opensearch/flowframework/common/CommonValue.java b/src/main/java/org/opensearch/flowframework/common/CommonValue.java index 667c94819..90ba5b37f 100644 --- a/src/main/java/org/opensearch/flowframework/common/CommonValue.java +++ b/src/main/java/org/opensearch/flowframework/common/CommonValue.java @@ -168,6 +168,10 @@ private CommonValue() {} public static final String PIPELINE_ID = "pipeline_id"; /** Pipeline Configurations */ public static final String CONFIGURATIONS = "configurations"; + /** Processor Config*/ + public static final String PROCESSOR_CONFIG = "processor_config"; + /** Processor Tag */ + public static final String TAG = "tag"; /** Indexes for knn query **/ public static final String INPUT_INDEX = "input_index"; diff --git a/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java b/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java index 15d52ccd1..b466fd661 100644 --- a/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java +++ b/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java @@ -32,6 +32,7 @@ import static java.util.concurrent.TimeUnit.SECONDS; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.flowframework.common.CommonValue.CONFIGURATIONS; +import static org.opensearch.flowframework.common.CommonValue.PROCESSOR_CONFIG; import static org.opensearch.flowframework.common.CommonValue.TOOLS_ORDER_FIELD; import static org.opensearch.flowframework.util.ParseUtils.buildStringToObjectMap; import static org.opensearch.flowframework.util.ParseUtils.buildStringToStringMap; @@ -156,7 +157,7 @@ public static WorkflowNode parse(XContentParser parser) throws IOException { userInputs.put(inputFieldName, parser.text()); break; case START_OBJECT: - if (CONFIGURATIONS.equals(inputFieldName)) { + if (CONFIGURATIONS.equals(inputFieldName) || PROCESSOR_CONFIG.equals(inputFieldName)) { Map configurationsMap = parser.map(); try { String configurationsString = ParseUtils.parseArbitraryStringToObjectMapToString(configurationsMap); diff --git a/src/main/java/org/opensearch/flowframework/rest/RestOrchestrateAction.java b/src/main/java/org/opensearch/flowframework/rest/RestOrchestrateAction.java new file mode 100644 index 000000000..dfbd932e1 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/rest/RestOrchestrateAction.java @@ -0,0 +1,110 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.rest; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ExceptionsHelper; +import org.opensearch.client.node.NodeClient; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.transport.OrchestrateAction; +import org.opensearch.flowframework.transport.OrchestrateRequest; +import org.opensearch.flowframework.util.ParseUtils; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.BytesRestResponse; +import org.opensearch.rest.RestRequest; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Locale; +import java.util.Map; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_ID; +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI; + +/** + * Rest action to orchestate workflows + */ +public class RestOrchestrateAction extends BaseRestHandler { + + private static final Logger logger = LogManager.getLogger(RestProvisionWorkflowAction.class); + + private static final String ORCHESTRATE_ACTION = "orchestrate_action"; + + /** + * Creates a new RestOrchestrateAction instance + */ + public RestOrchestrateAction() {} + + @Override + public String getName() { + return ORCHESTRATE_ACTION; + } + + @Override + public List routes() { + return List.of( + new Route(RestRequest.Method.POST, String.format(Locale.ROOT, "%s/{%s}/%s", WORKFLOW_URI, WORKFLOW_ID, "_orchestrate")) + ); + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + + // Get workflow ID + String workflowId = request.param(WORKFLOW_ID); + + try { + + // Validate params + if (workflowId == null) { + throw new FlowFrameworkException("workflow_id cannot be null", RestStatus.BAD_REQUEST); + } + + // Retrieve string to string map from content + Map userInputs = Collections.emptyMap(); + if (request.hasContent()) { + XContentParser parser = request.contentParser(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + userInputs = ParseUtils.parseStringToStringMap(parser); + } + // Create Request object and execute transport action with client to pass ID and values + OrchestrateRequest orchestrateRequest = new OrchestrateRequest(workflowId, userInputs); + return channel -> client.execute(OrchestrateAction.INSTANCE, orchestrateRequest, ActionListener.wrap(response -> { + XContentBuilder builder = response.toXContent(channel.newBuilder(), ToXContent.EMPTY_PARAMS); + channel.sendResponse(new BytesRestResponse(RestStatus.OK, builder)); + }, exception -> { + try { + FlowFrameworkException ex = exception instanceof FlowFrameworkException + ? (FlowFrameworkException) exception + : new FlowFrameworkException("Failed to get workflow.", ExceptionsHelper.status(exception)); + XContentBuilder exceptionBuilder = ex.toXContent(channel.newErrorBuilder(), ToXContent.EMPTY_PARAMS); + channel.sendResponse(new BytesRestResponse(ex.getRestStatus(), exceptionBuilder)); + } catch (IOException e) { + String errorMessage = "IOException: Failed to send back orchestrate exception"; + logger.error(errorMessage, e); + channel.sendResponse(new BytesRestResponse(ExceptionsHelper.status(e), errorMessage)); + } + })); + + } catch (FlowFrameworkException ex) { + return channel -> channel.sendResponse( + new BytesRestResponse(ex.getRestStatus(), ex.toXContent(channel.newErrorBuilder(), ToXContent.EMPTY_PARAMS)) + ); + } + + } +} diff --git a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java index 9dabd0c4a..4c174fe36 100644 --- a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java @@ -36,6 +36,7 @@ import org.opensearch.index.query.QueryBuilders; import org.opensearch.plugins.PluginsService; import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.pipeline.SearchPipelineService; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; @@ -63,6 +64,7 @@ public class CreateWorkflowTransportAction extends HandledTransportAction sortedNodes = workflowProcessSorter.sortProcessNodes(workflow, null, Collections.emptyMap()); workflowProcessSorter.validate(sortedNodes, pluginsService); diff --git a/src/main/java/org/opensearch/flowframework/transport/OrchestrateAction.java b/src/main/java/org/opensearch/flowframework/transport/OrchestrateAction.java new file mode 100644 index 000000000..96bc3a2ee --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/transport/OrchestrateAction.java @@ -0,0 +1,30 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.action.search.SearchResponse; + +import static org.opensearch.flowframework.common.CommonValue.TRANSPORT_ACTION_NAME_PREFIX; + +/** + * External action for public facing RestOrchestrateAction + */ +public class OrchestrateAction extends ActionType { + + /** The name of this action */ + public static final String NAME = TRANSPORT_ACTION_NAME_PREFIX + "workflow/orchestrate"; + + /** An instance of this action */ + public static final OrchestrateAction INSTANCE = new OrchestrateAction(); + + private OrchestrateAction() { + super(NAME, SearchResponse::new); + } +} diff --git a/src/main/java/org/opensearch/flowframework/transport/OrchestrateRequest.java b/src/main/java/org/opensearch/flowframework/transport/OrchestrateRequest.java new file mode 100644 index 000000000..fd20bc1cb --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/transport/OrchestrateRequest.java @@ -0,0 +1,82 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.transport; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +import java.io.IOException; +import java.util.Map; + +/** + * Transport request for orchestrate API + */ +public class OrchestrateRequest extends ActionRequest { + + /** + * The documentId of the workflow entry within the Global Context index + */ + private String workflowId; + + /** + * User inputs map + */ + private Map userInputs; + + /** + * Creates a new orchestrate request + * @param workflowId the workflow ID + * @param userInputs the user inputs to substitute values for in the template + */ + public OrchestrateRequest(String workflowId, Map userInputs) { + this.workflowId = workflowId; + this.userInputs = userInputs; + } + + /** + * Creates a new orchestrate request from stream input + * @param in the stream input + * @throws IOException on error reading from the stream input + */ + public OrchestrateRequest(StreamInput in) throws IOException { + super(in); + this.workflowId = in.readString(); + this.userInputs = in.readMap(StreamInput::readString, StreamInput::readString); + } + + /** + * Returns the workflow ID + * @return the workflow ID + */ + public String getWorkflowId() { + return this.workflowId; + } + + /** + * Returns the user inputs map + * @return the user inputs map + */ + public Map getUserInputs() { + return this.userInputs; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(workflowId); + out.writeMap(userInputs, StreamOutput::writeString, StreamOutput::writeString); + } + + @Override + public ActionRequestValidationException validate() { + return null; + } +} diff --git a/src/main/java/org/opensearch/flowframework/transport/OrchestrateTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/OrchestrateTransportAction.java new file mode 100644 index 000000000..89002177b --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/transport/OrchestrateTransportAction.java @@ -0,0 +1,189 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.transport; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ExceptionsHelper; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.action.support.PlainActionFuture; +import org.opensearch.client.Client; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.model.Template; +import org.opensearch.flowframework.model.Workflow; +import org.opensearch.flowframework.workflow.ProcessNode; +import org.opensearch.flowframework.workflow.WorkflowData; +import org.opensearch.flowframework.workflow.WorkflowProcessSorter; +import org.opensearch.plugins.PluginsService; +import org.opensearch.search.pipeline.SearchPipelineService; +import org.opensearch.tasks.Task; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.stream.Collectors; + +import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX; +import static org.opensearch.flowframework.common.CommonValue.SEARCH_RESPONSE; + +/** + * Orchestrate transport action + */ +public class OrchestrateTransportAction extends HandledTransportAction { + + private final Logger logger = LogManager.getLogger(OrchestrateTransportAction.class); + + private final Client client; + private final WorkflowProcessSorter workflowProcessSorter; + private final PluginsService pluginsService; + private final SearchPipelineService searchPipelineService; + + /** + * Creates a new Orchestrate Transport Action instance + * @param transportService the transport service + * @param actionFilters action filters + * @param threadPool the thread pool + * @param client the opensearch client + * @param workflowProcessSorter the workflow process sorter + * @param pluginsService the plugins service + * @param searchPipelineService the search pipeline service + */ + @Inject + public OrchestrateTransportAction( + TransportService transportService, + ActionFilters actionFilters, + ThreadPool threadPool, + Client client, + WorkflowProcessSorter workflowProcessSorter, + PluginsService pluginsService, + SearchPipelineService searchPipelineService + ) { + super(OrchestrateAction.NAME, transportService, actionFilters, OrchestrateRequest::new); + this.client = client; + this.workflowProcessSorter = workflowProcessSorter; + this.pluginsService = pluginsService; + this.searchPipelineService = searchPipelineService; + } + + @Override + protected void doExecute(Task task, OrchestrateRequest request, ActionListener listener) { + + // Get Template + String workflowId = request.getWorkflowId(); + GetRequest getRequest = new GetRequest(GLOBAL_CONTEXT_INDEX, workflowId); + + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + logger.info("Querying workflow from global context: {}", workflowId); + client.get(getRequest, ActionListener.wrap(response -> { + context.restore(); + + if (!response.isExists()) { + String errorMessage = "Failed to retrieve template (" + workflowId + ") from global context."; + logger.error(errorMessage); + listener.onFailure(new FlowFrameworkException(errorMessage, RestStatus.NOT_FOUND)); + return; + } + + // Parse template from document source + Template template = Template.parse(response.getSourceAsString()); + + // Hacky way to get the search pipeline service into the workflow step factory + workflowProcessSorter.updateWorkflowStepFactory(searchPipelineService); + + // Sort and validate graph + Workflow searchWorkflow = template.workflows().get("search"); + List searchProcessSequence = workflowProcessSorter.sortProcessNodes( + searchWorkflow, + workflowId, + request.getUserInputs() // pass in request body conten to prcess nodes + ); + workflowProcessSorter.validate(searchProcessSequence, pluginsService); + + // Execute the workflow + String currentStepId = ""; + try { + Map> workflowFutureMap = new LinkedHashMap<>(); + for (ProcessNode processNode : searchProcessSequence) { + List predecessors = processNode.predecessors(); + + logger.info( + "Queueing process [{}].{}", + processNode.id(), + predecessors.isEmpty() + ? " Can start immediately!" + : String.format( + Locale.getDefault(), + " Must wait for [%s] to complete first.", + predecessors.stream().map(p -> p.id()).collect(Collectors.joining(", ")) + ) + ); + workflowFutureMap.put(processNode.id(), processNode.execute()); + } + + // Attempt to complete each workflow step future, may throw a ExecutionException if any step completes exceptionally + // Additionally track each returned object of type search response and return the last one + SearchResponse searchResponse = null; + for (Map.Entry> entry : workflowFutureMap.entrySet()) { + currentStepId = entry.getKey(); + WorkflowData result = (WorkflowData) entry.getValue().actionGet(); + if (result.getContent().containsKey(SEARCH_RESPONSE)) { + searchResponse = (SearchResponse) result.getContent().get(SEARCH_RESPONSE); + } + } + if (searchResponse == null) { + listener.onFailure( + new FlowFrameworkException("The search workflow did not return a search response", RestStatus.BAD_REQUEST) + ); + } else { + logger.info("Search completed successfully for workflow {}", workflowId); + listener.onResponse(searchResponse); + } + } catch (Exception ex) { + RestStatus status; + if (ex instanceof FlowFrameworkException) { + status = ((FlowFrameworkException) ex).getRestStatus(); + } else { + status = ExceptionsHelper.status(ex); + } + logger.error("Search failed for workflow {} during step {}.", workflowId, currentStepId, ex); + String errorMessage = (ex.getCause() == null ? ex.getClass().getName() : ex.getCause().getClass().getName()) + + " during step " + + currentStepId + + ", restStatus: " + + status.toString(); + listener.onFailure(new FlowFrameworkException(errorMessage, status)); + } + }, exception -> { + if (exception instanceof FlowFrameworkException) { + logger.error("Workflow validation failed for workflow {}", workflowId); + listener.onFailure(exception); + } else { + String errorMessage = "Failed to retrieve template from global context for workflow " + workflowId; + logger.error(errorMessage, exception); + listener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception))); + } + })); + } catch (Exception e) { + String errorMessage = "Failed to retrieve template from global context for workflow " + workflowId; + logger.error(errorMessage, e); + listener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(e))); + } + } + +} diff --git a/src/main/java/org/opensearch/flowframework/workflow/SearchResponseProcessorStep.java b/src/main/java/org/opensearch/flowframework/workflow/SearchResponseProcessorStep.java new file mode 100644 index 000000000..104a3d761 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/workflow/SearchResponseProcessorStep.java @@ -0,0 +1,127 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.workflow; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.support.PlainActionFuture; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.bytes.BytesArray; +import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.flowframework.util.ParseUtils; +import org.opensearch.search.pipeline.PipelineConfiguration; +import org.opensearch.search.pipeline.Processor; +import org.opensearch.search.pipeline.Processor.PipelineContext; +import org.opensearch.search.pipeline.Processor.PipelineSource; +import org.opensearch.search.pipeline.SearchPipelineService; +import org.opensearch.search.pipeline.SearchResponseProcessor; + +import java.util.Collections; +import java.util.Map; +import java.util.Set; + +import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD; +import static org.opensearch.flowframework.common.CommonValue.PROCESSOR_CONFIG; +import static org.opensearch.flowframework.common.CommonValue.SEARCH_REQUEST; +import static org.opensearch.flowframework.common.CommonValue.SEARCH_RESPONSE; +import static org.opensearch.flowframework.common.CommonValue.TAG; +import static org.opensearch.flowframework.common.CommonValue.TYPE; + +/** + * Step to create a search response processor + */ +public class SearchResponseProcessorStep implements WorkflowStep { + + private static final Logger logger = LogManager.getLogger(SearchResponseProcessorStep.class); + private final SearchPipelineService searchPipelineService; + /** + * The name of the SearchResponseProcessor step + */ + public static final String NAME = "search_response_processor"; + + /** + * Creates a new SearchResponseProcessorStep + * @param searchPipelineService the search pipeline service + */ + public SearchResponseProcessorStep(SearchPipelineService searchPipelineService) { + this.searchPipelineService = searchPipelineService; + } + + @SuppressWarnings("unchecked") + @Override + public PlainActionFuture execute( + String currentNodeId, + WorkflowData currentNodeInputs, + Map outputs, + Map previousNodeInputs, + Map params + ) { + PlainActionFuture searchResponseProcessorFuture = PlainActionFuture.newFuture(); + + Set requiredKeys = Set.of(TYPE, PROCESSOR_CONFIG, TAG, DESCRIPTION_FIELD, SEARCH_REQUEST, SEARCH_RESPONSE); + Set optionalKeys = Collections.emptySet(); + + Map inputs = ParseUtils.getInputsFromPreviousSteps( + requiredKeys, + optionalKeys, + currentNodeInputs, + outputs, + previousNodeInputs, + params + ); + String type = (String) inputs.get(TYPE); + PipelineConfiguration pipelineConfiguration = new PipelineConfiguration( + "id", + new BytesArray((String) inputs.get(PROCESSOR_CONFIG)), + MediaTypeRegistry.JSON + ); + String tag = (String) inputs.get(TAG); + String description = (String) inputs.get(DESCRIPTION_FIELD); + SearchRequest searchRequest = (SearchRequest) inputs.get(SEARCH_REQUEST); + SearchResponse searchResponse = (SearchResponse) inputs.get(SEARCH_RESPONSE); + + try { + // Retrieve response processor factories + Map> responseProcessors = searchPipelineService + .getResponseProcessorFactories(); + + // Create an instance of the processor + SearchResponseProcessor processor = responseProcessors.get(type) + .create( + responseProcessors, + tag, + description, + false, + pipelineConfiguration.getConfigAsMap(), + new PipelineContext(PipelineSource.UPDATE_PIPELINE) + ); + + // Temp : rerank processor only invokes processResponse async, doesnt need request context + processor.processResponseAsync(searchRequest, searchResponse, null, ActionListener.wrap(processedSearchResponse -> { + searchResponseProcessorFuture.onResponse( + new WorkflowData( + Map.ofEntries(Map.entry(SEARCH_RESPONSE, processedSearchResponse)), + currentNodeInputs.getWorkflowId(), + currentNodeInputs.getNodeId() + ) + ); + }, exception -> { searchResponseProcessorFuture.onFailure(exception); })); + + } catch (Exception e) { + searchResponseProcessorFuture.onFailure(e); + } + return searchResponseProcessorFuture; + } + + @Override + public String getName() { + return NAME; + } +} diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java index 884968abc..7a319815a 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java @@ -21,6 +21,7 @@ import org.opensearch.flowframework.model.WorkflowNode; import org.opensearch.plugins.PluginInfo; import org.opensearch.plugins.PluginsService; +import org.opensearch.search.pipeline.SearchPipelineService; import org.opensearch.threadpool.ThreadPool; import java.util.ArrayDeque; @@ -83,6 +84,14 @@ public WorkflowProcessSorter( this.client = client; } + /** + * Updates the workflow step factory + * @param searchPipelineService the search pipeline service + */ + public void updateWorkflowStepFactory(SearchPipelineService searchPipelineService) { + this.workflowStepFactory.updateWorkflowStepFactory(searchPipelineService); + } + /** * Sort a workflow into a topologically sorted list of process nodes. * @param workflow A workflow with (unsorted) nodes and edges which define predecessors and successors diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java index ad338a715..b3b42ac81 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java @@ -20,6 +20,7 @@ import org.opensearch.flowframework.model.WorkflowStepValidator; import org.opensearch.flowframework.model.WorkflowValidator; import org.opensearch.ml.client.MachineLearningNodeClient; +import org.opensearch.search.pipeline.SearchPipelineService; import org.opensearch.threadpool.ThreadPool; import java.util.Collections; @@ -47,11 +48,13 @@ import static org.opensearch.flowframework.common.CommonValue.OPENSEARCH_ML; import static org.opensearch.flowframework.common.CommonValue.PARAMETERS_FIELD; import static org.opensearch.flowframework.common.CommonValue.PIPELINE_ID; +import static org.opensearch.flowframework.common.CommonValue.PROCESSOR_CONFIG; import static org.opensearch.flowframework.common.CommonValue.PROTOCOL_FIELD; import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS; import static org.opensearch.flowframework.common.CommonValue.SEARCH_REQUEST; import static org.opensearch.flowframework.common.CommonValue.SEARCH_RESPONSE; import static org.opensearch.flowframework.common.CommonValue.SUCCESS; +import static org.opensearch.flowframework.common.CommonValue.TAG; import static org.opensearch.flowframework.common.CommonValue.TOOLS_FIELD; import static org.opensearch.flowframework.common.CommonValue.TYPE; import static org.opensearch.flowframework.common.CommonValue.URL; @@ -120,6 +123,16 @@ public WorkflowStepFactory( stepMap.put(CreateSearchPipelineStep.NAME, () -> new CreateSearchPipelineStep(client, flowFrameworkIndicesHandler)); } + /** + * Updates the step map with search response processors + * @param searchPipelineService the search pipeline service + */ + public void updateWorkflowStepFactory(SearchPipelineService searchPipelineService) { + if (!this.stepMap.containsKey(SearchResponseProcessorStep.NAME)) { + this.stepMap.put(SearchResponseProcessorStep.NAME, () -> new SearchResponseProcessorStep(searchPipelineService)); + } + } + /** * Enum encapsulating the different step names, their inputs, outputs, required plugin and timeout of the step */ @@ -246,7 +259,15 @@ public enum WorkflowSteps { List.of(SEARCH_REQUEST, SEARCH_RESPONSE), Collections.emptyList(), null + ), + /** Search Response Processor Step */ + SEARCH_RESPONSE_PROCESSOR( + SearchResponseProcessorStep.NAME, + List.of(TYPE, PROCESSOR_CONFIG, TAG, DESCRIPTION_FIELD, SEARCH_REQUEST, SEARCH_RESPONSE), + List.of(SEARCH_RESPONSE), + Collections.emptyList(), + null ); private final String workflowStepName; diff --git a/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java b/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java index 401ddbe9a..4623432d8 100644 --- a/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java +++ b/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java @@ -82,8 +82,8 @@ public void testPlugin() throws IOException { 5, ffp.createComponents(client, clusterService, threadPool, null, null, null, environment, null, null, null, null).size() ); - assertEquals(9, ffp.getRestHandlers(settings, null, null, null, null, null, null).size()); - assertEquals(9, ffp.getActions().size()); + assertEquals(10, ffp.getRestHandlers(settings, null, null, null, null, null, null).size()); + assertEquals(10, ffp.getActions().size()); assertEquals(3, ffp.getExecutorBuilders(settings).size()); assertEquals(5, ffp.getSettings().size()); } diff --git a/src/test/java/org/opensearch/flowframework/model/WorkflowValidatorTests.java b/src/test/java/org/opensearch/flowframework/model/WorkflowValidatorTests.java index 5b7848c9b..15a41d3bb 100644 --- a/src/test/java/org/opensearch/flowframework/model/WorkflowValidatorTests.java +++ b/src/test/java/org/opensearch/flowframework/model/WorkflowValidatorTests.java @@ -14,6 +14,7 @@ import org.opensearch.flowframework.workflow.WorkflowStepFactory; import org.opensearch.flowframework.workflow.WorkflowStepFactory.WorkflowSteps; import org.opensearch.ml.client.MachineLearningNodeClient; +import org.opensearch.search.pipeline.SearchPipelineService; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; @@ -46,7 +47,7 @@ public void testParseWorkflowValidator() throws IOException { WorkflowValidator validator = new WorkflowValidator(workflowStepValidators); - assertEquals(19, validator.getWorkflowStepValidators().size()); + assertEquals(20, validator.getWorkflowStepValidators().size()); assertTrue(validator.getWorkflowStepValidators().keySet().contains("create_connector")); assertEquals(7, validator.getWorkflowStepValidators().get("create_connector").getInputs().size()); @@ -118,6 +119,7 @@ public void testWorkflowStepFactoryHasValidators() throws IOException { flowFrameworkSettings, client ); + workflowStepFactory.updateWorkflowStepFactory(mock(SearchPipelineService.class)); WorkflowValidator workflowValidator = workflowStepFactory.getWorkflowValidator(); diff --git a/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java index 11b620f3d..ed9be0ef7 100644 --- a/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java @@ -34,6 +34,7 @@ import org.opensearch.plugins.PluginsService; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; +import org.opensearch.search.pipeline.SearchPipelineService; import org.opensearch.tasks.Task; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; @@ -102,7 +103,8 @@ public void setUp() throws Exception { flowFrameworkIndicesHandler, flowFrameworkSettings, client, - pluginsService + pluginsService, + mock(SearchPipelineService.class) ) ); // client = mock(Client.class);