Skip to content

Commit

Permalink
Substitute REST path or body parameters in Workflow Steps (#525)
Browse files Browse the repository at this point in the history
* Include params map in WorkflowRequest when provisioning

Signed-off-by: Daniel Widdis <[email protected]>

* Pass params to ProcessNode

Signed-off-by: Daniel Widdis <[email protected]>

* Pass params to WorkflowSteps

Signed-off-by: Daniel Widdis <[email protected]>

* Substitute params

Signed-off-by: Daniel Widdis <[email protected]>

* Add change log

Signed-off-by: Daniel Widdis <[email protected]>

* Improve param consuming checks, add coverage

Signed-off-by: Daniel Widdis <[email protected]>

* Allow specifying key-value pairs in body

Signed-off-by: Daniel Widdis <[email protected]>

* Update title in change log

Signed-off-by: Daniel Widdis <[email protected]>

* Refactor param and content map generation to a new method

Signed-off-by: Daniel Widdis <[email protected]>

---------

Signed-off-by: Daniel Widdis <[email protected]>
  • Loading branch information
dbwiddis committed Feb 22, 2024
1 parent 24bf51a commit 3019fb8
Show file tree
Hide file tree
Showing 49 changed files with 446 additions and 90 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.1.0/)
## [Unreleased 2.x](https://github.com/opensearch-project/flow-framework/compare/2.12...2.x)
### Features
### Enhancements
- Substitute REST path or body parameters in Workflow Steps ([#525](https://github.com/opensearch-project/flow-framework/pull/525))

### Bug Fixes
### Infrastructure
### Documentation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@
import java.io.IOException;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW;
Expand Down Expand Up @@ -75,6 +78,19 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
String workflowId = request.param(WORKFLOW_ID);
String[] validation = request.paramAsStringArray(VALIDATION, new String[] { "all" });
boolean provision = request.paramAsBoolean(PROVISION_WORKFLOW, false);
final List<String> validCreateParams = List.of(WORKFLOW_ID, VALIDATION, PROVISION_WORKFLOW);
// If provisioning, consume all other params and pass to provision transport action
Map<String, String> params = provision
? request.params()
.keySet()
.stream()
.filter(k -> !validCreateParams.contains(k))
.collect(Collectors.toMap(Function.identity(), request::param))
: request.params()
.entrySet()
.stream()
.filter(e -> !validCreateParams.contains(e.getKey()))
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
if (!flowFrameworkSettings.isFlowFrameworkEnabled()) {
FlowFrameworkException ffe = new FlowFrameworkException(
"This API is disabled. To enable it, set [" + FLOW_FRAMEWORK_ENABLED.getKey() + "] to true.",
Expand All @@ -84,12 +100,24 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
new BytesRestResponse(ffe.getRestStatus(), ffe.toXContent(channel.newErrorBuilder(), ToXContent.EMPTY_PARAMS))
);
}
if (!provision && !params.isEmpty()) {
// Consume params and content so custom exception is processed
params.keySet().stream().forEach(request::param);
request.content();
FlowFrameworkException ffe = new FlowFrameworkException(
"Only the parameters " + validCreateParams + " are permitted unless the provision parameter is set to true.",
RestStatus.BAD_REQUEST
);
return channel -> channel.sendResponse(
new BytesRestResponse(ffe.getRestStatus(), ffe.toXContent(channel.newErrorBuilder(), ToXContent.EMPTY_PARAMS))
);
}
try {
XContentParser parser = request.contentParser();
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
Template template = Template.parse(parser);

WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, template, validation, provision);
WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, template, validation, provision, params);

return channel -> client.execute(CreateWorkflowAction.INSTANCE, workflowRequest, ActionListener.wrap(response -> {
XContentBuilder builder = response.toXContent(channel.newBuilder(), ToXContent.EMPTY_PARAMS);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.flowframework.common.FlowFrameworkSettings;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.transport.ProvisionWorkflowAction;
Expand All @@ -27,7 +28,11 @@
import java.io.IOException;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_ID;
import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED;
Expand Down Expand Up @@ -69,23 +74,19 @@ public List<Route> routes() {
protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException {
String workflowId = request.param(WORKFLOW_ID);
try {
Map<String, String> params = parseParamsAndContent(request);
if (!flowFrameworkFeatureEnabledSetting.isFlowFrameworkEnabled()) {
throw new FlowFrameworkException(
"This API is disabled. To enable it, update the setting [" + FLOW_FRAMEWORK_ENABLED.getKey() + "] to true.",
RestStatus.FORBIDDEN
);
}
// Validate content
if (request.hasContent()) {
// BaseRestHandler will give appropriate error message
return channel -> channel.sendResponse(null);
}
// Validate params
if (workflowId == null) {
throw new FlowFrameworkException("workflow_id cannot be null", RestStatus.BAD_REQUEST);
}
// Create request and provision
WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, null);
WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, null, params);
return channel -> client.execute(ProvisionWorkflowAction.INSTANCE, workflowRequest, ActionListener.wrap(response -> {
XContentBuilder builder = response.toXContent(channel.newBuilder(), ToXContent.EMPTY_PARAMS);
channel.sendResponse(new BytesRestResponse(RestStatus.OK, builder));
Expand All @@ -108,4 +109,31 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
}
}

private Map<String, String> parseParamsAndContent(RestRequest request) {
// Get any other params from path
Map<String, String> params = request.params()
.keySet()
.stream()
.filter(k -> !WORKFLOW_ID.equals(k))
.collect(Collectors.toMap(Function.identity(), request::param));
// If body is included get any params from body
if (request.hasContent()) {
try (XContentParser parser = request.contentParser()) {
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
String key = parser.currentName();
if (params.containsKey(key)) {
throw new FlowFrameworkException("Duplicate key " + key, RestStatus.BAD_REQUEST);
}
if (parser.nextToken() != XContentParser.Token.VALUE_STRING) {
throw new FlowFrameworkException("Request body fields must have string values", RestStatus.BAD_REQUEST);
}
params.put(key, parser.text());
}
} catch (IOException e) {
throw new FlowFrameworkException("Request body parsing failed", RestStatus.BAD_REQUEST);
}
}
return params;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import org.opensearch.transport.TransportService;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;

Expand Down Expand Up @@ -282,7 +283,7 @@ void checkMaxWorkflows(TimeValue requestTimeOut, Integer maxWorkflow, ActionList

private void validateWorkflows(Template template) throws Exception {
for (Workflow workflow : template.workflows().values()) {
List<ProcessNode> sortedNodes = workflowProcessSorter.sortProcessNodes(workflow, null);
List<ProcessNode> sortedNodes = workflowProcessSorter.sortProcessNodes(workflow, null, Collections.emptyMap());
workflowProcessSorter.validate(sortedNodes, pluginsService);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ private void executeDeprovisionSequence(
deprovisionStepId,
workflowStepFactory.createStep(deprovisionStep),
Collections.emptyMap(),
Collections.emptyMap(),
new WorkflowData(Map.of(getResourceByWorkflowStep(stepName), resource.resourceId()), workflowId, deprovisionStepId),
Collections.emptyList(),
this.threadPool,
Expand Down Expand Up @@ -194,6 +195,7 @@ private void executeDeprovisionSequence(
pn.id(),
workflowStepFactory.createStep(pn.workflowStep().getName()),
pn.previousNodeInputs(),
pn.params(),
pn.input(),
pn.predecessors(),
this.threadPool,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,11 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener<Work

// Sort and validate graph
Workflow provisionWorkflow = template.workflows().get(PROVISION_WORKFLOW);
List<ProcessNode> provisionProcessSequence = workflowProcessSorter.sortProcessNodes(provisionWorkflow, workflowId);
List<ProcessNode> provisionProcessSequence = workflowProcessSorter.sortProcessNodes(
provisionWorkflow,
workflowId,
request.getParams()
);
workflowProcessSorter.validate(provisionProcessSequence, pluginsService);

flowFrameworkIndicesHandler.isWorkflowNotStarted(workflowId, workflowIsNotStarted -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import org.opensearch.flowframework.model.Template;

import java.io.IOException;
import java.util.Collections;
import java.util.Map;

/**
* Transport Request to create, provision, and deprovision a workflow
Expand Down Expand Up @@ -43,12 +45,27 @@ public class WorkflowRequest extends ActionRequest {
private boolean provision;

/**
* Instantiates a new WorkflowRequest, set validation to false and set requestTimeout and maxWorkflows to null
* Params map
*/
private Map<String, String> params;

/**
* Instantiates a new WorkflowRequest, set validation to all, no provisioning
* @param workflowId the documentId of the workflow
* @param template the use case template which describes the workflow
*/
public WorkflowRequest(@Nullable String workflowId, @Nullable Template template) {
this(workflowId, template, new String[] { "all" }, false);
this(workflowId, template, new String[] { "all" }, false, Collections.emptyMap());
}

/**
* Instantiates a new WorkflowRequest with params map, set validation to all, provisioning to true
* @param workflowId the documentId of the workflow
* @param template the use case template which describes the workflow
* @param params The parameters from the REST path
*/
public WorkflowRequest(String workflowId, @Nullable Template template, Map<String, String> params) {
this(workflowId, template, new String[] { "all" }, true, params);
}

/**
Expand All @@ -57,12 +74,23 @@ public WorkflowRequest(@Nullable String workflowId, @Nullable Template template)
* @param template the use case template which describes the workflow
* @param validation flag to indicate if validation is necessary
* @param provision flag to indicate if provision is necessary
* @param params map of REST path params. If provision is false, must be an empty map.
*/
public WorkflowRequest(@Nullable String workflowId, @Nullable Template template, String[] validation, boolean provision) {
public WorkflowRequest(
@Nullable String workflowId,
@Nullable Template template,
String[] validation,
boolean provision,
Map<String, String> params
) {
this.workflowId = workflowId;
this.template = template;
this.validation = validation;
this.provision = provision;
if (!provision && !params.isEmpty()) {
throw new IllegalArgumentException("Params may only be included when provisioning.");
}
this.params = params;
}

/**
Expand All @@ -77,6 +105,7 @@ public WorkflowRequest(StreamInput in) throws IOException {
this.template = templateJson == null ? null : Template.parse(templateJson);
this.validation = in.readStringArray();
this.provision = in.readBoolean();
this.params = this.provision ? in.readMap(StreamInput::readString, StreamInput::readString) : Collections.emptyMap();
}

/**
Expand Down Expand Up @@ -113,13 +142,24 @@ public boolean isProvision() {
return this.provision;
}

/**
* Gets the params map
* @return the params map
*/
public Map<String, String> getParams() {
return Map.copyOf(this.params);
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeOptionalString(workflowId);
out.writeOptionalString(template == null ? null : template.toJson());
out.writeStringArray(validation);
out.writeBoolean(provision);
if (provision) {
out.writeMap(params, StreamOutput::writeString, StreamOutput::writeString);
}
}

@Override
Expand Down
18 changes: 13 additions & 5 deletions src/main/java/org/opensearch/flowframework/util/ParseUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,7 @@ public static Map<String, String> getStringToStringMap(Object map, String fieldN
* @param currentNodeInputs Input params and content for this node, from workflow parsing
* @param outputs WorkflowData content of previous steps
* @param previousNodeInputs Input params for this node that come from previous steps
* @param params Params that came from REST path
* @return A map containing the requiredInputKeys with their corresponding values,
* and optionalInputKeys with their corresponding values if present.
* Throws a {@link FlowFrameworkException} if a required key is not present.
Expand All @@ -257,7 +258,8 @@ public static Map<String, Object> getInputsFromPreviousSteps(
Set<String> optionalInputKeys,
WorkflowData currentNodeInputs,
Map<String, WorkflowData> outputs,
Map<String, String> previousNodeInputs
Map<String, String> previousNodeInputs,
Map<String, String> params
) {
// Mutable set to ensure all required keys are used
Set<String> requiredKeys = new HashSet<>(requiredInputKeys);
Expand Down Expand Up @@ -308,11 +310,11 @@ public static Map<String, Object> getInputsFromPreviousSteps(
Map<String, Object> valueMap = (Map<String, Object>) value;
value = valueMap.entrySet()
.stream()
.collect(Collectors.toMap(Map.Entry::getKey, e -> conditionallySubstitute(e.getValue(), outputs)));
.collect(Collectors.toMap(Map.Entry::getKey, e -> conditionallySubstitute(e.getValue(), outputs, params)));
} else if (value instanceof List) {
value = ((List<?>) value).stream().map(v -> conditionallySubstitute(v, outputs)).collect(Collectors.toList());
value = ((List<?>) value).stream().map(v -> conditionallySubstitute(v, outputs, params)).collect(Collectors.toList());
} else {
value = conditionallySubstitute(value, outputs);
value = conditionallySubstitute(value, outputs, params);
}
// Add value to inputs and mark that a required key was present
inputs.put(key, value);
Expand All @@ -336,15 +338,21 @@ public static Map<String, Object> getInputsFromPreviousSteps(
return inputs;
}

private static Object conditionallySubstitute(Object value, Map<String, WorkflowData> outputs) {
private static Object conditionallySubstitute(Object value, Map<String, WorkflowData> outputs, Map<String, String> params) {
if (value instanceof String) {
Matcher m = SUBSTITUTION_PATTERN.matcher((String) value);
if (m.matches()) {
// Try matching a previous step+value pair
WorkflowData data = outputs.get(m.group(1));
if (data != null && data.getContent().containsKey(m.group(2))) {
return data.getContent().get(m.group(2));
}
}
// Replace all params if present
for (Entry<String, String> e : params.entrySet()) {
String regex = "\\$\\{\\{\\s*" + Pattern.quote(e.getKey()) + "\\s*\\}\\}";
value = ((String) value).replaceAll(regex, e.getValue());
}
}
return value;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ public PlainActionFuture<WorkflowData> execute(
String currentNodeId,
WorkflowData currentNodeInputs,
Map<String, WorkflowData> outputs,
Map<String, String> previousNodeInputs
Map<String, String> previousNodeInputs,
Map<String, String> params
) {

PlainActionFuture<WorkflowData> registerLocalModelFuture = PlainActionFuture.newFuture();
Expand All @@ -90,7 +91,8 @@ public PlainActionFuture<WorkflowData> execute(
getOptionalKeys(),
currentNodeInputs,
outputs,
previousNodeInputs
previousNodeInputs,
params
);

// Extract common fields of OS provided text-embedding, sparse encoding and custom models
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ public PlainActionFuture<WorkflowData> execute(
String currentNodeId,
WorkflowData currentNodeInputs,
Map<String, WorkflowData> outputs,
Map<String, String> previousNodeInputs
Map<String, String> previousNodeInputs,
Map<String, String> params
) {
PlainActionFuture<WorkflowData> createConnectorFuture = PlainActionFuture.newFuture();

Expand Down Expand Up @@ -138,7 +139,8 @@ public void onFailure(Exception e) {
optionalKeys,
currentNodeInputs,
outputs,
previousNodeInputs
previousNodeInputs,
params
);

String name = (String) inputs.get(NAME_FIELD);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ public PlainActionFuture<WorkflowData> execute(
String currentNodeId,
WorkflowData currentNodeInputs,
Map<String, WorkflowData> outputs,
Map<String, String> previousNodeInputs
Map<String, String> previousNodeInputs,
Map<String, String> params
) {
PlainActionFuture<WorkflowData> createIndexFuture = PlainActionFuture.newFuture();
ActionListener<CreateIndexResponse> actionListener = new ActionListener<>() {
Expand Down
Loading

0 comments on commit 3019fb8

Please sign in to comment.