diff --git a/CHANGELOG.md b/CHANGELOG.md index 02916436a..48d5d819a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.1.0/) - Added create ingest pipeline step ([#558](https://github.com/opensearch-project/flow-framework/pull/558)) - Added create search pipeline step ([#569](https://github.com/opensearch-project/flow-framework/pull/569)) - Added create index step ([#574](https://github.com/opensearch-project/flow-framework/pull/574)) +- Added default use cases ([#583](https://github.com/opensearch-project/flow-framework/pull/583)) ### Enhancements - Substitute REST path or body parameters in Workflow Steps ([#525](https://github.com/opensearch-project/flow-framework/pull/525)) diff --git a/build.gradle b/build.gradle index e81d2b5e7..45837ca39 100644 --- a/build.gradle +++ b/build.gradle @@ -180,6 +180,7 @@ dependencies { // ZipArchive dependencies used for integration tests zipArchive group: 'org.opensearch.plugin', name:'opensearch-ml-plugin', version: "${opensearch_build}" + secureIntegTestPluginArchive group: 'org.opensearch.plugin', name:'opensearch-security', version: "${opensearch_build}" configurations.all { diff --git a/src/main/java/org/opensearch/flowframework/common/CommonValue.java b/src/main/java/org/opensearch/flowframework/common/CommonValue.java index bde91b55d..d3960d90b 100644 --- a/src/main/java/org/opensearch/flowframework/common/CommonValue.java +++ b/src/main/java/org/opensearch/flowframework/common/CommonValue.java @@ -72,6 +72,8 @@ private CommonValue() {} public static final String PROVISION_WORKFLOW = "provision"; /** The field name for workflow steps. This field represents the name of the workflow steps to be fetched. 
*/ public static final String WORKFLOW_STEP = "workflow_step"; + /** The param name for default use case, used by the create workflow API */ + public static final String USE_CASE = "use_case"; /* * Constants associated with plugin configuration diff --git a/src/main/java/org/opensearch/flowframework/common/DefaultUseCases.java b/src/main/java/org/opensearch/flowframework/common/DefaultUseCases.java new file mode 100644 index 000000000..8c12f7c43 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/common/DefaultUseCases.java @@ -0,0 +1,110 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.common; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.flowframework.exception.FlowFrameworkException; + +/** + * Enum encapsulating the different default use cases and templates we have stored + */ +public enum DefaultUseCases { + + /** defaults file and substitution ready template for OpenAI embedding model */ + OPEN_AI_EMBEDDING_MODEL_DEPLOY( + "open_ai_embedding_model_deploy", + "defaults/open-ai-embedding-defaults.json", + "substitutionTemplates/deploy-remote-model-template.json" + ), + /** defaults file and substitution ready template for cohere embedding model */ + COHERE_EMBEDDING_MODEL_DEPLOY( + "cohere-embedding_model_deploy", + "defaults/cohere-embedding-defaults.json", + "substitutionTemplates/deploy-remote-model-template-extra-params.json" + ), + /** defaults file and substitution ready template for local neural sparse model and ingest pipeline*/ + LOCAL_NEURAL_SPARSE_SEARCH( + "local_neural_sparse_search", + "defaults/local-sparse-search-defaults.json", + "substitutionTemplates/neural-sparse-local-template.json" + 
); + + private final String useCaseName; + private final String defaultsFile; + private final String substitutionReadyFile; + private static final Logger logger = LogManager.getLogger(DefaultUseCases.class); + + DefaultUseCases(String useCaseName, String defaultsFile, String substitutionReadyFile) { + this.useCaseName = useCaseName; + this.defaultsFile = defaultsFile; + this.substitutionReadyFile = substitutionReadyFile; + } + + /** + * Returns the useCaseName for the given enum Constant + * @return the useCaseName of this use case. + */ + public String getUseCaseName() { + return useCaseName; + } + + /** + * Returns the defaultsFile for the given enum Constant + * @return the defaultsFile of this for the given useCase. + */ + public String getDefaultsFile() { + return defaultsFile; + } + + /** + * Returns the substitutionReadyFile for the given enum Constant + * @return the substitutionReadyFile of the given useCase + */ + public String getSubstitutionReadyFile() { + return substitutionReadyFile; + } + + /** + * Gets the defaultsFile based on the given use case. 
+ * @param useCaseName name of the given use case + * @return the defaultsFile for that usecase + * @throws FlowFrameworkException if the use case doesn't exist in enum + */ + public static String getDefaultsFileByUseCaseName(String useCaseName) throws FlowFrameworkException { + if (useCaseName != null && !useCaseName.isEmpty()) { + for (DefaultUseCases usecase : values()) { + if (useCaseName.equals(usecase.getUseCaseName())) { + return usecase.getDefaultsFile(); + } + } + } + logger.error("Unable to find defaults file for use case: {}", useCaseName); + throw new FlowFrameworkException("Unable to find defaults file for use case: " + useCaseName, RestStatus.BAD_REQUEST); + } + + /** + * Gets the substitutionReadyFile based on the given use case + * @param useCaseName name of the given use case + * @return the substitutionReadyFile which has the template + * @throws FlowFrameworkException if the use case doesn't exist in enum + */ + public static String getSubstitutionReadyFileByUseCaseName(String useCaseName) throws FlowFrameworkException { + if (useCaseName != null && !useCaseName.isEmpty()) { + for (DefaultUseCases useCase : values()) { + if (useCase.getUseCaseName().equals(useCaseName)) { + return useCase.getSubstitutionReadyFile(); + } + } + } + logger.error("Unable to find substitution ready file for use case: {}", useCaseName); + throw new FlowFrameworkException("Unable to find substitution ready file for use case: " + useCaseName, RestStatus.BAD_REQUEST); + } +} diff --git a/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java b/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java index 09fec81e1..7189d8962 100644 --- a/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java +++ b/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java @@ -17,16 +17,19 @@ import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import 
org.opensearch.core.xcontent.XContentParser; +import org.opensearch.flowframework.common.DefaultUseCases; import org.opensearch.flowframework.common.FlowFrameworkSettings; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.model.Template; import org.opensearch.flowframework.transport.CreateWorkflowAction; import org.opensearch.flowframework.transport.WorkflowRequest; +import org.opensearch.flowframework.util.ParseUtils; import org.opensearch.rest.BaseRestHandler; import org.opensearch.rest.BytesRestResponse; import org.opensearch.rest.RestRequest; import java.io.IOException; +import java.util.Collections; import java.util.List; import java.util.Locale; import java.util.Map; @@ -35,6 +38,7 @@ import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW; +import static org.opensearch.flowframework.common.CommonValue.USE_CASE; import static org.opensearch.flowframework.common.CommonValue.VALIDATION; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_ID; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI; @@ -78,6 +82,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli String workflowId = request.param(WORKFLOW_ID); String[] validation = request.paramAsStringArray(VALIDATION, new String[] { "all" }); boolean provision = request.paramAsBoolean(PROVISION_WORKFLOW, false); + String useCase = request.param(USE_CASE); // If provisioning, consume all other params and pass to provision transport action Map params = provision ? 
request.params() @@ -112,11 +117,63 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli ); } try { - XContentParser parser = request.contentParser(); - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - Template template = Template.parse(parser); - WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, template, validation, provision, params); + Template template; + Map useCaseDefaultsMap = Collections.emptyMap(); + if (useCase != null) { + String useCaseTemplateFileInStringFormat = ParseUtils.resourceToString( + "/" + DefaultUseCases.getSubstitutionReadyFileByUseCaseName(useCase) + ); + String defaultsFilePath = DefaultUseCases.getDefaultsFileByUseCaseName(useCase); + useCaseDefaultsMap = ParseUtils.parseJsonFileToStringToStringMap("/" + defaultsFilePath); + + if (request.hasContent()) { + try { + XContentParser parser = request.contentParser(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + Map userDefaults = ParseUtils.parseStringToStringMap(parser); + // updates the default params with anything user has given that matches + for (Map.Entry userDefaultsEntry : userDefaults.entrySet()) { + String key = userDefaultsEntry.getKey(); + String value = userDefaultsEntry.getValue(); + if (useCaseDefaultsMap.containsKey(key)) { + useCaseDefaultsMap.put(key, value); + } + } + } catch (Exception ex) { + RestStatus status = ex instanceof IOException ? 
RestStatus.BAD_REQUEST : ExceptionsHelper.status(ex); + String errorMessage = "failure parsing request body when a use case is given"; + logger.error(errorMessage, ex); + throw new FlowFrameworkException(errorMessage, status); + } + + } + + useCaseTemplateFileInStringFormat = (String) ParseUtils.conditionallySubstitute( + useCaseTemplateFileInStringFormat, + null, + useCaseDefaultsMap + ); + + XContentParser parserTestJson = ParseUtils.jsonToParser(useCaseTemplateFileInStringFormat); + ensureExpectedToken(XContentParser.Token.START_OBJECT, parserTestJson.currentToken(), parserTestJson); + template = Template.parse(parserTestJson); + + } else { + XContentParser parser = request.contentParser(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + template = Template.parse(parser); + } + + WorkflowRequest workflowRequest = new WorkflowRequest( + workflowId, + template, + validation, + provision, + params, + useCase, + useCaseDefaultsMap + ); return channel -> client.execute(CreateWorkflowAction.INSTANCE, workflowRequest, ActionListener.wrap(response -> { XContentBuilder builder = response.toXContent(channel.newBuilder(), ToXContent.EMPTY_PARAMS); @@ -134,11 +191,14 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli channel.sendResponse(new BytesRestResponse(ExceptionsHelper.status(e), errorMessage)); } })); + } catch (FlowFrameworkException e) { + logger.error("failed to prepare rest request", e); return channel -> channel.sendResponse( new BytesRestResponse(e.getRestStatus(), e.toXContent(channel.newErrorBuilder(), ToXContent.EMPTY_PARAMS)) ); - } catch (IOException e) { + } catch (Exception e) { + logger.error("failed to prepare rest request", e); FlowFrameworkException ex = new FlowFrameworkException( "IOException: template content invalid for specified Content-Type.", RestStatus.BAD_REQUEST diff --git a/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java 
b/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java index 5b3c3c0d8..da0b17239 100644 --- a/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java +++ b/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java @@ -49,13 +49,23 @@ public class WorkflowRequest extends ActionRequest { */ private Map params; + /** + * use case flag + */ + private String useCase; + + /** + * Default params map from use case + */ + private Map defaultParams; + /** * Instantiates a new WorkflowRequest, set validation to all, no provisioning * @param workflowId the documentId of the workflow * @param template the use case template which describes the workflow */ public WorkflowRequest(@Nullable String workflowId, @Nullable Template template) { - this(workflowId, template, new String[] { "all" }, false, Collections.emptyMap()); + this(workflowId, template, new String[] { "all" }, false, Collections.emptyMap(), null, Collections.emptyMap()); } /** @@ -65,7 +75,18 @@ public WorkflowRequest(@Nullable String workflowId, @Nullable Template template) * @param params The parameters from the REST path */ public WorkflowRequest(@Nullable String workflowId, @Nullable Template template, Map params) { - this(workflowId, template, new String[] { "all" }, true, params); + this(workflowId, template, new String[] { "all" }, true, params, null, Collections.emptyMap()); + } + + /** + * Instantiates a new WorkflowRequest with a default use case, set validation to all, no provisioning + * @param workflowId the documentId of the workflow + * @param template the use case template which describes the workflow + * @param useCase the default use case given by the user + * @param defaultParams The parameters from the REST body when a use case is given + */ + public WorkflowRequest(@Nullable String workflowId, @Nullable Template template, String useCase, Map defaultParams) { + this(workflowId, template, new String[] { "all" }, false, Collections.emptyMap(), useCase, 
defaultParams); } /** @@ -75,13 +96,17 @@ public WorkflowRequest(@Nullable String workflowId, @Nullable Template template, * @param validation flag to indicate if validation is necessary * @param provision flag to indicate if provision is necessary * @param params map of REST path params. If provision is false, must be an empty map. + * @param useCase default use case given + * @param defaultParams the params to be used in the substitution based on the default use case. */ public WorkflowRequest( @Nullable String workflowId, @Nullable Template template, String[] validation, boolean provision, - Map params + Map params, + String useCase, + Map defaultParams ) { this.workflowId = workflowId; this.template = template; @@ -91,6 +116,8 @@ public WorkflowRequest( throw new IllegalArgumentException("Params may only be included when provisioning."); } this.params = params; + this.useCase = useCase; + this.defaultParams = defaultParams; } /** @@ -150,6 +177,22 @@ public Map getParams() { return Map.copyOf(this.params); } + /** + * Gets the use case + * @return the use case + */ + public String getUseCase() { + return this.useCase; + } + + /** + * Gets the params map + * @return the params map + */ + public Map getDefaultParams() { + return Map.copyOf(this.defaultParams); + } + @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); diff --git a/src/main/java/org/opensearch/flowframework/util/ParseUtils.java b/src/main/java/org/opensearch/flowframework/util/ParseUtils.java index 140f0a4af..224304016 100644 --- a/src/main/java/org/opensearch/flowframework/util/ParseUtils.java +++ b/src/main/java/org/opensearch/flowframework/util/ParseUtils.java @@ -61,6 +61,8 @@ public class ParseUtils { private ParseUtils() {} + private static final ObjectMapper mapper = new ObjectMapper(); + /** * Converts a JSON string into an XContentParser * @@ -342,11 +344,18 @@ public static Map getInputsFromPreviousSteps( return inputs; } - private static Object 
conditionallySubstitute(Object value, Map outputs, Map params) { + /** + * Executes substitution on the given value by looking at any matching values in either the outputs or params map + * @param value the Object on which the substitution will be performed + * @param outputs potential location of values to be substituted in + * @param params potential location of values to be substituted in + * @return the substituted object back + */ + public static Object conditionallySubstitute(Object value, Map outputs, Map params) { if (value instanceof String) { Matcher m = SUBSTITUTION_PATTERN.matcher((String) value); StringBuilder result = new StringBuilder(); - while (m.find()) { + while (m.find() && outputs != null) { // outputs content map contains values for previous node input (e.g: deploy_openai_model.model_id) // Check first if the substitution is looking for the same key, value pair and if yes // then replace it with the key value pair in the inputs map @@ -364,10 +373,15 @@ private static Object conditionallySubstitute(Object value, Map e : params.entrySet()) { - String regex = "\\$\\{\\{\\s*" + Pattern.quote(e.getKey()) + "\\s*\\}\\}"; - value = ((String) value).replaceAll(regex, e.getValue()); + if (params != null) { + for (Map.Entry e : params.entrySet()) { + String regex = "\\$\\{\\{\\s*" + Pattern.quote(e.getKey()) + "\\s*\\}\\}"; + String replacement = e.getValue(); + + // Special handling for JSON strings that contain placeholders (connectors action) + replacement = Matcher.quoteReplacement(replacement.replace("\"", "\\\"")); + value = ((String) value).replaceAll(regex, replacement); + } } } return value; @@ -380,9 +394,20 @@ private static Object conditionallySubstitute(Object value, Map map) throws JsonProcessingException { - ObjectMapper mapper = new ObjectMapper(); // Convert the map to a JSON string String mappedString = mapper.writeValueAsString(map); return mappedString; } + + /** + * Generates a String to String map based on a Json File + * @param path 
file path + * @return instance of the string + * @throws JsonProcessingException JsonProcessingException from Jackson for issues processing map + */ + public static Map parseJsonFileToStringToStringMap(String path) throws IOException { + String jsonContent = resourceToString(path); + Map mappedJsonFile = mapper.readValue(jsonContent, Map.class); + return mappedJsonFile; + } } diff --git a/src/main/resources/defaults/cohere-embedding-defaults.json b/src/main/resources/defaults/cohere-embedding-defaults.json new file mode 100644 index 000000000..e36578b1c --- /dev/null +++ b/src/main/resources/defaults/cohere-embedding-defaults.json @@ -0,0 +1,18 @@ +{ + "template.name": "deploy-cohere-model", + "template.description": "deploying cohere embedding model", + "create_connector.name": "cohere-embedding-connector", + "create_connector.description": "The connector to Cohere's public embed API", + "create_connector.protocol": "http", + "create_connector.model": "embed-english-v3.0", + "create_connector.input_type": "search_document", + "create_connector.truncate": "end", + "create_connector.endpoint": "api.openai.com", + "create_connector.credential.key": "123", + "create_connector.actions.url": "https://api.cohere.ai/v1/embed", + "create_connector.actions.request_body": "{ \"texts\": ${parameters.texts}, \"truncate\": \"${parameters.truncate}\", \"model\": \"${parameters.model}\", \"input_type\": \"${parameters.input_type}\" }", + "create_connector.actions.pre_process_function": "connector.pre_process.cohere.embedding", + "create_connector.actions.post_process_function": "connector.post_process.cohere.embedding", + "register_remote_model.name": "Cohere english embed model", + "register_remote_model.description": "cohere-embedding-model" +} diff --git a/src/main/resources/defaults/local-sparse-search-defaults.json b/src/main/resources/defaults/local-sparse-search-defaults.json new file mode 100644 index 000000000..cde9291f2 --- /dev/null +++ 
b/src/main/resources/defaults/local-sparse-search-defaults.json @@ -0,0 +1,17 @@ +{ + "template.name": "local-model-neural-sparse-search", + "template.description": "setting up neural sparse search with local model", + "register_local_sparse_encoding_model.name": "neural-sparse/opensearch-neural-sparse-tokenizer-v1-v2", + "register_local_sparse_encoding_model.description": "This is a neural sparse tokenizer model: It tokenize input sentence into tokens and assign pre-defined weight from IDF to each. It serves only in query.", + "register_local_sparse_encoding_model.node_timeout": "60s", + "register_local_sparse_encoding_model.model_format": "TORCH_SCRIPT", + "register_local_sparse_encoding_model.function_name": "SPARSE_TOKENIZE", + "register_local_sparse_encoding_model.model_content_hash_value": "b3487da9c58ac90541b720f3b367084f271d280c7f3bdc3e6d9c9a269fb31950", + "register_local_sparse_encoding_model.url": "https://artifacts.opensearch.org/models/ml-models/amazon/neural-sparse/opensearch-neural-sparse-tokenizer-v1/1.0.0/torch_script/opensearch-neural-sparse-tokenizer-v1-1.0.0.zip", + "register_local_sparse_encoding_model.deploy": "true", + "create_ingest_pipeline.pipeline_id": "nlp-ingest-pipeline-sparse", + "create_ingest_pipeline.description": "A sparse encoding ingest pipeline", + "create_ingest_pipeline.text_embedding.field_map.input": "passage_text", + "create_ingest_pipeline.text_embedding.field_map.output": "passage_embedding", + "create_index.name": "my-nlp-index" +} diff --git a/src/main/resources/defaults/open-ai-embedding-defaults.json b/src/main/resources/defaults/open-ai-embedding-defaults.json new file mode 100644 index 000000000..59fed86de --- /dev/null +++ b/src/main/resources/defaults/open-ai-embedding-defaults.json @@ -0,0 +1,18 @@ +{ + "open_ai_embedding_deploy": { + "template.name": "deploy-openai-model", + "template.description": "deploying openAI embedding model", + "create_connector.name": "OpenAI-embedding-connector", + 
"create_connector.description": "Connector to public OpenAI model", + "create_connector.protocol": "http", + "create_connector.model": "text-embedding-ada-002", + "create_connector.endpoint": "api.openai.com", + "create_connector.credential.key": "123", + "create_connector.actions.url": "https://api.openai.com/v1/embeddings", + "create_connector.actions.request_body": "{ \"input\": ${parameters.input}, \"model\": \"${parameters.model}\" }", + "create_connector.actions.pre_process_function": "connector.pre_process.openai.embedding", + "create_connector.actions.post_process_function": "connector.post_process.openai.embedding", + "register_remote_model_1.name": "OpenAI embedding model", + "register_remote_model_1.description": "openai-embedding-model" + } +} diff --git a/src/main/resources/mappings/deploy-remote-model-template-draft.json b/src/main/resources/mappings/deploy-remote-model-template-draft.json new file mode 100644 index 000000000..a2f80a8c4 --- /dev/null +++ b/src/main/resources/mappings/deploy-remote-model-template-draft.json @@ -0,0 +1,77 @@ +{ + "name": "{template.name}", + "description": "{template.description}", + "use_case": "DEPLOY_MODEL", + "version": { + "template": "1.0.0", + "compatibility": [ + "2.12.0", + "3.0.0" + ] + }, + "workflows": { + "provision": { + "nodes": [ + { + "id": "create_connector", + "type": "create_connector", + "user_inputs": { + "name": "${{create_connector_1}}", + "description": "${{create_connector_1.description}}", + "version": "1", + "protocol": "${{create_connector_1.protocol}}", + "parameters": { + "endpoint": "${{create_connector_1.endpoint}}", + "model": "${{create_connector_1.model}}" + }, + "credential": { + "key": "${{create_connector_1.credential.key}}", + }, + "actions": [ + { + "action_type": "predict", + "method": "POST", + "url": "https://api.openai.com/v1/embeddings", + "headers": { + "Authorization": "Bearer ${credential.openAI_key}" + }, + "request_body": "{ \"input\": ${parameters.input}, \"model\": 
\"${parameters.model}\" }", + "pre_process_function": "connector.pre_process.openai.embedding", + "post_process_function": "connector.post_process.openai.embedding" + } + ] + } + }, + { + "id": "register_model", + "type": "register_remote_model", + "previous_node_inputs": { + "create_connector_step_1": "parameters" + }, + "user_inputs": { + "name": "${register_remote_model.name}", + "function_name": "remote", + "description": "${register_remote_model.description}" + } + }, + { + "id": "deploy_model", + "type": "deploy_model", + "previous_node_inputs": { + "register_model_1": "model_id" + } + } + ], + "edges": [ + { + "source": "create_connector", + "dest": "register_model" + }, + { + "source": "register_model", + "dest": "deploy_model" + } + ] + } + } +} diff --git a/src/main/resources/mappings/open-ai-defaults.json b/src/main/resources/mappings/open-ai-defaults.json new file mode 100644 index 000000000..88f200e32 --- /dev/null +++ b/src/main/resources/mappings/open-ai-defaults.json @@ -0,0 +1,36 @@ +{ + "deploy-remote-model-defaults": [ + { + "openai_embedding_deploy": { + "template.name": "deploy-openai-model", + "template.description": "deploying openAI embedding model", + "create_connector_1.name": "OpenAI-embedding-connector", + "create_connector_1.description": "Connector to public AI model service for GPT 3.5", + "create_connector_1.protocol": "http", + "create_connector_1.model": "gpt-3.5-turbo", + "create_connector_1.endpoint": "api.openai.com", + "create_connector_1.credential.key": "123", + "create_connector_1.request_body": "{ \"input\": ${parameters.input}, \"model\": \"${parameters.model}\" }", + "create_connector_1.pre_process_function": "connector.pre_process.openai.embedding", + "create_connector_1.post_process_function": "connector.post_process.openai.embedding", + "register_remote_model_1.name": "test-description" + } + }, + { + "cohere_embedding_deploy": { + "template.name": "deploy-cohere-embedding-model", + "template.description": "deploying 
cohere embedding model", + "create_connector_1.name": "cohere-embedding-connector", + "create_connector_1.description": "Connector to public AI model service for GPT 3.5", + "create_connector_1.protocol": "http", + "create_connector_1.model": "gpt-3.5-turbo", + "create_connector_1.endpoint": "api.openai.com", + "create_connector_1.credential.key": "123", + "create_connector_1.request_body": "{ \"input\": ${parameters.input}, \"model\": \"${parameters.model}\" }", + "create_connector_1.pre_process_function": "connector.pre_process.openai.embedding", + "create_connector_1.post_process_function": "connector.post_process.openai.embedding", + "register_remote_model_1.name": "test-description" + } + } + ] +} diff --git a/src/main/resources/substitutionTemplates/deploy-model-semantic-search-template-v1.json b/src/main/resources/substitutionTemplates/deploy-model-semantic-search-template-v1.json new file mode 100644 index 000000000..ae90693d3 --- /dev/null +++ b/src/main/resources/substitutionTemplates/deploy-model-semantic-search-template-v1.json @@ -0,0 +1,124 @@ +{ + "name": "${{template.name}}", + "description": "${{template.description}}", + "use_case": "DEPLOY_MODEL", + "version": { + "template": "1.0.0", + "compatibility": [ + "2.12.0", + "3.0.0" + ] + }, + "workflows": { + "provision": { + "nodes": [ + { + "id": "create_connector", + "type": "create_connector", + "user_inputs": { + "name": "${{create_connector_1}}", + "description": "${{create_connector_1.description}}", + "version": "1", + "protocol": "${{create_connector_1.protocol}}", + "parameters": { + "endpoint": "${{create_connector_1.endpoint}}", + "model": "${{create_connector_1.model}}", + "input_type": "search_document", + "truncate": "END" + }, + "credential": { + "key": "${{create_connector_1.credential.key}}" + }, + "actions": [ + { + "action_type": "predict", + "method": "POST", + "url": "${{create_connector.actions.url}}", + "headers": { + "Authorization": "Bearer ${credential.key}", + 
"Request-Source": "unspecified:opensearch" + }, + "request_body": "${{create_connector.actions.request_body}}", + "pre_process_function": "${{create_connector.actions.pre_process_function}}", + "post_process_function": "${{create_connector.actions.post_process_function}}" + } + ] + } + }, + { + "id": "register_model", + "type": "register_remote_model", + "previous_node_inputs": { + "create_connector_step_1": "parameters" + }, + "user_inputs": { + "name": "${register_remote_model.name}", + "function_name": "remote", + "description": "${register_remote_model.description}" + } + }, + { + "id": "deploy_model", + "type": "deploy_model", + "previous_node_inputs": { + "register_model_1": "model_id" + } + }, + { + "id": "create_ingest_pipeline", + "type": "create_ingest_pipeline", + "previous_node_inputs": { + "deploy_openai_model": "model_id" + }, + "user_inputs": { + "pipeline_id": "${{create_ingest_pipeline.pipeline_id}}", + "configurations": { + "description": "${{create_ingest_pipeline.description}}", + "processors": [ + { + "text_embedding": { + "model_id": "${{deploy_openai_model.model_id}}", + "field_map": { + "${{text_embedding.field_map.input}}": "${{text_embedding.field_map.input}}" + } + } + } + ] + } + } + }, + { + "id": "create_index", + "type": "create_index", + "previous_node_inputs": { + "create_ingest_pipeline": "pipeline_id" + }, + "user_inputs": { + "index_name": "${{create_index.name}}", + "configurations": { + "settings": { + "index": { + "number_of_shards": 2, + "number_of_replicas": 1, + "search.default_pipeline" : "${{create_ingest_pipeline.pipeline_id}}" + } + }, + "mappings": { + "_doc": { + "properties": { + "age": { + "type": "integer" + } + } + } + }, + "aliases": { + "sample-alias1": {} + } + } + } + } + ] + } + } +} diff --git a/src/main/resources/substitutionTemplates/deploy-remote-model-template-extra-params.json b/src/main/resources/substitutionTemplates/deploy-remote-model-template-extra-params.json new file mode 100644 index 
000000000..ade509666 --- /dev/null +++ b/src/main/resources/substitutionTemplates/deploy-remote-model-template-extra-params.json @@ -0,0 +1,80 @@ +{ + "name": "${{template.name}}", + "description": "${{template.description}}", + "use_case": "", + "version": { + "template": "1.0.0", + "compatibility": [ + "2.12.0", + "3.0.0" + ] + }, + "workflows": { + "provision": { + "nodes": [ + { + "id": "create_connector", + "type": "create_connector", + "user_inputs": { + "name": "${{create_connector.name}}", + "description": "${{create_connector.description}}", + "version": "1", + "protocol": "${{create_connector.protocol}}", + "parameters": { + "endpoint": "${{create_connector.endpoint}}", + "model": "${{create_connector.model}}", + "input_type": "search_document", + "truncate": "END" + }, + "credential": { + "key": "${{create_connector.credential.key}}" + }, + "actions": [ + { + "action_type": "predict", + "method": "POST", + "url": "${{create_connector.actions.url}}", + "headers": { + "Authorization": "Bearer ${credential.key}", + "Request-Source": "unspecified:opensearch" + }, + "request_body": "${{create_connector.actions.request_body}}", + "pre_process_function": "${{create_connector.actions.pre_process_function}}", + "post_process_function": "${{create_connector.actions.post_process_function}}" + } + ] + } + }, + { + "id": "register_model", + "type": "register_remote_model", + "previous_node_inputs": { + "create_connector": "parameters" + }, + "user_inputs": { + "name": "${{register_remote_model.name}}", + "function_name": "remote", + "description": "${{register_remote_model.description}}" + } + }, + { + "id": "deploy_model", + "type": "deploy_model", + "previous_node_inputs": { + "register_model": "model_id" + } + } + ], + "edges": [ + { + "source": "create_connector", + "dest": "register_model" + }, + { + "source": "register_model", + "dest": "deploy_model" + } + ] + } + } +} diff --git a/src/main/resources/substitutionTemplates/deploy-remote-model-template.json 
b/src/main/resources/substitutionTemplates/deploy-remote-model-template.json new file mode 100644 index 000000000..bc1c9eebc --- /dev/null +++ b/src/main/resources/substitutionTemplates/deploy-remote-model-template.json @@ -0,0 +1,77 @@ +{ + "name": "${{template.name}}", + "description": "${{template.description}}", + "use_case": "", + "version": { + "template": "1.0.0", + "compatibility": [ + "2.12.0", + "3.0.0" + ] + }, + "workflows": { + "provision": { + "nodes": [ + { + "id": "create_connector", + "type": "create_connector", + "user_inputs": { + "name": "${{create_connector}}", + "description": "${{create_connector.description}}", + "version": "1", + "protocol": "${{create_connector.protocol}}", + "parameters": { + "endpoint": "${{create_connector.endpoint}}", + "model": "${{create_connector.model}}" + }, + "credential": { + "key": "${{create_connector.credential.key}}" + }, + "actions": [ + { + "action_type": "predict", + "method": "POST", + "url": "${{create_connector.actions.url}}", + "headers": { + "Authorization": "Bearer ${credential.key}" + }, + "request_body": "${{create_connector.actions.request_body}}", + "pre_process_function": "${{create_connector.actions.pre_process_function}}", + "post_process_function": "${{create_connector.actions.post_process_function}}" + } + ] + } + }, + { + "id": "register_model", + "type": "register_remote_model", + "previous_node_inputs": { + "create_connector_step_1": "parameters" + }, + "user_inputs": { + "name": "${{register_remote_model.name}}", + "function_name": "remote", + "description": "${{register_remote_model.description}}" + } + }, + { + "id": "deploy_model", + "type": "deploy_model", + "previous_node_inputs": { + "register_model_1": "model_id" + } + } + ], + "edges": [ + { + "source": "create_connector", + "dest": "register_model" + }, + { + "source": "register_model", + "dest": "deploy_model" + } + ] + } + } +} diff --git a/src/main/resources/substitutionTemplates/neural-sparse-local-template.json 
b/src/main/resources/substitutionTemplates/neural-sparse-local-template.json new file mode 100644 index 000000000..372336bb8 --- /dev/null +++ b/src/main/resources/substitutionTemplates/neural-sparse-local-template.json @@ -0,0 +1,86 @@ +{ + "name": "${{template.name}}", + "description": "${{template.description}}", + "use_case": "", + "version": { + "template": "1.0.0", + "compatibility": [ + "2.12.0", + "3.0.0" + ] + }, + "workflows": { + "provision": { + "nodes": [ + { + "id": "register_local_sparse_encoding_model", + "type": "register_local_sparse_encoding_model", + "user_inputs": { + "node_timeout": "60s", + "name": "neural-sparse/opensearch-neural-sparse-tokenizer-v1-v2", + "version": "1.0.0", + "description": "This is a neural sparse tokenizer model: It tokenize input sentence into tokens and assign pre-defined weight from IDF to each. It serves only in query.", + "model_format": "TORCH_SCRIPT", + "function_name": "SPARSE_TOKENIZE", + "model_content_hash_value": "b3487da9c58ac90541b720f3b367084f271d280c7f3bdc3e6d9c9a269fb31950", + "url": "https://artifacts.opensearch.org/models/ml-models/amazon/neural-sparse/opensearch-neural-sparse-tokenizer-v1/1.0.0/torch_script/opensearch-neural-sparse-tokenizer-v1-1.0.0.zip", + "deploy": true + } + }, + { + "id": "create_ingest_pipeline", + "type": "create_ingest_pipeline", + "previous_node_inputs": { + "register_local_sparse_encoding_model": "model_id" + }, + "user_inputs": { + "pipeline_id": "${{create_ingest_pipeline.pipeline_id}}", + "configurations": { + "description": "${{create_ingest_pipeline.description}}", + "processors": [ + { + "sparse_encoding": { + "model_id": "${{register_local_sparse_encoding_model.model_id}}", + "field_map": { + "${{create_ingest_pipeline.text_embedding.field_map.input}}": "${{create_ingest_pipeline.text_embedding.field_map.output}}" + } + } + } + ] + } + } + }, + { + "id": "create_index", + "type": "create_index", + "previous_node_inputs": { + "create_ingest_pipeline": "pipeline_id" + 
}, + "user_inputs": { + "index_name": "${{create_index.name}}", + "configurations": { + "settings": { + "default_pipeline": "${{create_ingest_pipeline.pipeline_id}}" + }, + "mappings": { + "_doc": { + "properties": { + "id": { + "type": "text" + }, + "${{create_ingest_pipeline.text_embedding.field_map.output}}": { + "type": "rank_features" + }, + "${{create_ingest_pipeline.text_embedding.field_map.input}}": { + "type": "text" + } + } + } + } + } + } + } + ] + } + } +} diff --git a/src/test/java/org/opensearch/flowframework/FlowFrameworkRestTestCase.java b/src/test/java/org/opensearch/flowframework/FlowFrameworkRestTestCase.java index 37eeb14c7..326d382ee 100644 --- a/src/test/java/org/opensearch/flowframework/FlowFrameworkRestTestCase.java +++ b/src/test/java/org/opensearch/flowframework/FlowFrameworkRestTestCase.java @@ -339,6 +339,24 @@ protected Response createWorkflow(RestClient client, Template template) throws E return TestHelpers.makeRequest(client, "POST", WORKFLOW_URI + "?validation=off", Collections.emptyMap(), template.toJson(), null); } + /** + * Helper method to invoke the Create Workflow Rest Action without validation + * @param client the rest client + * @param useCase the usecase to create + * @throws Exception if the request fails + * @return a rest response + */ + protected Response createWorkflowWithUseCase(RestClient client, String useCase) throws Exception { + return TestHelpers.makeRequest( + client, + "POST", + WORKFLOW_URI + "?validation=off&use_case=" + useCase, + Collections.emptyMap(), + "{}", + null + ); + } + /** * Helper method to invoke the Create Workflow Rest Action with provision * @param client the rest client diff --git a/src/test/java/org/opensearch/flowframework/common/DefaultUseCasesTests.java b/src/test/java/org/opensearch/flowframework/common/DefaultUseCasesTests.java new file mode 100644 index 000000000..b6dc72ebb --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/common/DefaultUseCasesTests.java @@ -0,0 +1,44 
@@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.common; + +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.test.OpenSearchTestCase; + +public class DefaultUseCasesTests extends OpenSearchTestCase { + + @Override + public void setUp() throws Exception { + super.setUp(); + } + + public void testGetDefaultsFileByValidUseCaseName() throws FlowFrameworkException { + String defaultsFile = DefaultUseCases.getDefaultsFileByUseCaseName("open_ai_embedding_model_deploy"); + assertEquals("defaults/open-ai-embedding-defaults.json", defaultsFile); + } + + public void testGetDefaultsFileByInvalidUseCaseName() throws FlowFrameworkException { + FlowFrameworkException e = assertThrows( + FlowFrameworkException.class, + () -> DefaultUseCases.getDefaultsFileByUseCaseName("invalid_use_case") + ); + } + + public void testGetSubstitutionTemplateByValidUseCaseName() throws FlowFrameworkException { + String templateFile = DefaultUseCases.getSubstitutionReadyFileByUseCaseName("open_ai_embedding_model_deploy"); + assertEquals("substitutionTemplates/deploy-remote-model-template.json", templateFile); + } + + public void testGetSubstitutionTemplateByInvalidUseCaseName() throws FlowFrameworkException { + FlowFrameworkException e = assertThrows( + FlowFrameworkException.class, + () -> DefaultUseCases.getSubstitutionReadyFileByUseCaseName("invalid_use_case") + ); + } +} diff --git a/src/test/java/org/opensearch/flowframework/rest/FlowFrameworkRestApiIT.java b/src/test/java/org/opensearch/flowframework/rest/FlowFrameworkRestApiIT.java index 69391cbd1..8db37d83d 100644 --- a/src/test/java/org/opensearch/flowframework/rest/FlowFrameworkRestApiIT.java +++ 
b/src/test/java/org/opensearch/flowframework/rest/FlowFrameworkRestApiIT.java @@ -399,4 +399,45 @@ public void testCreateAndProvisionIngestAndSearchPipeline() throws Exception { } + public void testDefaultCohereUseCase() throws Exception { + + // Using a 3 step template to create a connector, register remote model and deploy model + Template template = TestHelpers.createTemplateFromFile("ingest-search-pipeline-template.json"); + + // Hit Create Workflow API with original template + Response response = createWorkflowWithUseCase(client(), "cohere-embedding_model_deploy"); + assertEquals(RestStatus.CREATED, TestHelpers.restStatus(response)); + + Map responseMap = entityAsMap(response); + String workflowId = (String) responseMap.get(WORKFLOW_ID); + getAndAssertWorkflowStatus(client(), workflowId, State.NOT_STARTED, ProvisioningProgress.NOT_STARTED); + + // Ensure Ml config index is initialized as creating a connector requires this, then hit Provision API and assert status + if (!indexExistsWithAdminClient(".plugins-ml-config")) { + assertBusy(() -> assertTrue(indexExistsWithAdminClient(".plugins-ml-config")), 40, TimeUnit.SECONDS); + response = provisionWorkflow(client(), workflowId); + } else { + response = provisionWorkflow(client(), workflowId); + } + + assertEquals(RestStatus.OK, TestHelpers.restStatus(response)); + getAndAssertWorkflowStatus(client(), workflowId, State.PROVISIONING, ProvisioningProgress.IN_PROGRESS); + + // Wait until provisioning has completed successfully before attempting to retrieve created resources + List resourcesCreated = getResourcesCreated(client(), workflowId, 30); + + List expectedStepNames = List.of("create_connector", "register_remote_model", "deploy_model"); + + List workflowStepNames = resourcesCreated.stream() + .peek(resourceCreated -> assertNotNull(resourceCreated.resourceId())) + .map(ResourceCreated::workflowStepName) + .collect(Collectors.toList()); + for (String expectedName : expectedStepNames) { + 
assertTrue(workflowStepNames.contains(expectedName)); + } + + // This template should create 3 resources: connector_id, registered model_id and deployed model_id + assertEquals(3, resourcesCreated.size()); + } + } diff --git a/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java b/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java index 250d00e87..3381bbbec 100644 --- a/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java +++ b/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java @@ -15,6 +15,7 @@ import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.MediaTypeRegistry; import org.opensearch.flowframework.TestHelpers; +import org.opensearch.flowframework.common.DefaultUseCases; import org.opensearch.flowframework.common.FlowFrameworkSettings; import org.opensearch.flowframework.model.Template; import org.opensearch.flowframework.model.Workflow; @@ -34,6 +35,7 @@ import java.util.Map; import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW; +import static org.opensearch.flowframework.common.CommonValue.USE_CASE; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doAnswer; @@ -134,6 +136,40 @@ public void testCreateWorkflowRequestWithParamsButNoProvision() throws Exception { ); } + public void testCreateWorkflowRequestWithUseCaseButNoProvision() throws Exception { + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) + .withPath(this.createWorkflowPath) + .withParams(Map.of(USE_CASE, DefaultUseCases.COHERE_EMBEDDING_MODEL_DEPLOY.getUseCaseName())) + .withContent(new BytesArray(""), MediaTypeRegistry.JSON) + .build(); + FakeRestChannel channel = new FakeRestChannel(request, false, 1); + doAnswer(invocation -> { + ActionListener 
actionListener = invocation.getArgument(2); + actionListener.onResponse(new WorkflowResponse("id-123")); + return null; + }).when(nodeClient).execute(any(), any(WorkflowRequest.class), any()); + createWorkflowRestAction.handleRequest(request, channel, nodeClient); + assertEquals(RestStatus.CREATED, channel.capturedResponse().status()); + assertTrue(channel.capturedResponse().content().utf8ToString().contains("id-123")); + } + + public void testCreateWorkflowRequestWithUseCaseAndContent() throws Exception { + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) + .withPath(this.createWorkflowPath) + .withParams(Map.of(USE_CASE, DefaultUseCases.COHERE_EMBEDDING_MODEL_DEPLOY.getUseCaseName())) + .withContent(new BytesArray("{\"key\":\"step\"}"), MediaTypeRegistry.JSON) + .build(); + FakeRestChannel channel = new FakeRestChannel(request, false, 1); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(new WorkflowResponse("id-123")); + return null; + }).when(nodeClient).execute(any(), any(WorkflowRequest.class), any()); + createWorkflowRestAction.handleRequest(request, channel, nodeClient); + assertEquals(RestStatus.CREATED, channel.capturedResponse().status()); + assertTrue(channel.capturedResponse().content().utf8ToString().contains("id-123")); + } + public void testInvalidCreateWorkflowRequest() throws Exception { RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) .withPath(this.createWorkflowPath) diff --git a/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java index f5e57c588..11b620f3d 100644 --- a/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java +++ 
b/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java @@ -211,7 +211,7 @@ public void testMaxWorkflow() { @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); - WorkflowRequest workflowRequest = new WorkflowRequest(null, template, new String[] { "off" }, false, Collections.emptyMap()); + WorkflowRequest workflowRequest = new WorkflowRequest(null, template, new String[] { "off" }, false, Collections.emptyMap(), null, Collections.emptyMap()); doAnswer(invocation -> { ActionListener searchListener = invocation.getArgument(1); @@ -248,7 +248,15 @@ public void onFailure(Exception e) { public void testFailedToCreateNewWorkflow() { @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); - WorkflowRequest workflowRequest = new WorkflowRequest(null, template, new String[] { "off" }, false, Collections.emptyMap()); + WorkflowRequest workflowRequest = new WorkflowRequest( + null, + template, + new String[] { "off" }, + false, + Collections.emptyMap(), + null, + Collections.emptyMap() + ); // Bypass checkMaxWorkflows and force onResponse doAnswer(invocation -> { @@ -279,7 +287,15 @@ public void testFailedToCreateNewWorkflow() { public void testCreateNewWorkflow() { @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); - WorkflowRequest workflowRequest = new WorkflowRequest(null, template, new String[] { "off" }, false, Collections.emptyMap()); + WorkflowRequest workflowRequest = new WorkflowRequest( + null, + template, + new String[] { "off" }, + false, + Collections.emptyMap(), + null, + Collections.emptyMap() + ); // Bypass checkMaxWorkflows and force onResponse doAnswer(invocation -> { @@ -410,7 +426,15 @@ public void testCreateWorkflow_withValidation_withProvision_Success() throws Exc ActionListener listener = mock(ActionListener.class); doNothing().when(workflowProcessSorter).validate(any(), any()); - WorkflowRequest workflowRequest = new 
WorkflowRequest(null, validTemplate, new String[] { "all" }, true, Collections.emptyMap()); + WorkflowRequest workflowRequest = new WorkflowRequest( + null, + validTemplate, + new String[] { "all" }, + true, + Collections.emptyMap(), + null, + Collections.emptyMap() + ); // Bypass checkMaxWorkflows and force onResponse doAnswer(invocation -> { @@ -463,7 +487,15 @@ public void testCreateWorkflow_withValidation_withProvision_FailedProvisioning() @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); doNothing().when(workflowProcessSorter).validate(any(), any()); - WorkflowRequest workflowRequest = new WorkflowRequest(null, validTemplate, new String[] { "all" }, true, Collections.emptyMap()); + WorkflowRequest workflowRequest = new WorkflowRequest( + null, + validTemplate, + new String[] { "all" }, + true, + Collections.emptyMap(), + null, + Collections.emptyMap() + ); // Bypass checkMaxWorkflows and force onResponse doAnswer(invocation -> { diff --git a/src/test/java/org/opensearch/flowframework/transport/WorkflowRequestResponseTests.java b/src/test/java/org/opensearch/flowframework/transport/WorkflowRequestResponseTests.java index 98f9a1499..119e55f46 100644 --- a/src/test/java/org/opensearch/flowframework/transport/WorkflowRequestResponseTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/WorkflowRequestResponseTests.java @@ -143,10 +143,55 @@ public void testWorkflowRequestWithParams() throws IOException { assertEquals("bar", workflowRequest.getParams().get("foo")); } + public void testWorkflowRequestWithUseCase() throws IOException { + WorkflowRequest workflowRequest = new WorkflowRequest("123", template, "cohere-embedding_model_deploy", Collections.emptyMap()); + assertNotNull(workflowRequest.getWorkflowId()); + assertEquals(template, workflowRequest.getTemplate()); + assertNull(workflowRequest.validate()); + assertFalse(workflowRequest.isProvision()); + assertTrue(workflowRequest.getDefaultParams().isEmpty()); + 
assertEquals(workflowRequest.getUseCase(), "cohere-embedding_model_deploy"); + + BytesStreamOutput out = new BytesStreamOutput(); + workflowRequest.writeTo(out); + BytesStreamInput in = new BytesStreamInput(BytesReference.toBytes(out.bytes())); + + WorkflowRequest streamInputRequest = new WorkflowRequest(in); + + assertEquals(workflowRequest.getWorkflowId(), streamInputRequest.getWorkflowId()); + assertEquals(workflowRequest.getTemplate().toString(), streamInputRequest.getTemplate().toString()); + assertNull(workflowRequest.validate()); + assertFalse(workflowRequest.isProvision()); + assertTrue(workflowRequest.getDefaultParams().isEmpty()); + assertEquals(workflowRequest.getUseCase(), "cohere-embedding_model_deploy"); + } + + public void testWorkflowRequestWithUseCaseAndParamsInBody() throws IOException { + WorkflowRequest workflowRequest = new WorkflowRequest("123", template, "cohere-embedding_model_deploy", Map.of("step", "model")); + assertNotNull(workflowRequest.getWorkflowId()); + assertEquals(template, workflowRequest.getTemplate()); + assertNull(workflowRequest.validate()); + assertFalse(workflowRequest.isProvision()); + assertEquals(workflowRequest.getDefaultParams().get("step"), "model"); + + BytesStreamOutput out = new BytesStreamOutput(); + workflowRequest.writeTo(out); + BytesStreamInput in = new BytesStreamInput(BytesReference.toBytes(out.bytes())); + + WorkflowRequest streamInputRequest = new WorkflowRequest(in); + + assertEquals(workflowRequest.getWorkflowId(), streamInputRequest.getWorkflowId()); + assertEquals(workflowRequest.getTemplate().toString(), streamInputRequest.getTemplate().toString()); + assertNull(workflowRequest.validate()); + assertFalse(workflowRequest.isProvision()); + assertEquals(workflowRequest.getDefaultParams().get("step"), "model"); + + } + public void testWorkflowRequestWithParamsNoProvision() throws IOException { IllegalArgumentException ex = assertThrows( IllegalArgumentException.class, - () -> new WorkflowRequest("123", 
template, new String[] { "all" }, false, Map.of("foo", "bar")) + () -> new WorkflowRequest("123", template, new String[] { "all" }, false, Map.of("foo", "bar"), null, Collections.emptyMap()) ); assertEquals("Params may only be included when provisioning.", ex.getMessage()); } diff --git a/src/test/java/org/opensearch/flowframework/util/ParseUtilsTests.java b/src/test/java/org/opensearch/flowframework/util/ParseUtilsTests.java index 06aaf45d9..7ae057d24 100644 --- a/src/test/java/org/opensearch/flowframework/util/ParseUtilsTests.java +++ b/src/test/java/org/opensearch/flowframework/util/ParseUtilsTests.java @@ -19,6 +19,8 @@ import java.io.IOException; import java.time.Instant; +import java.util.Collections; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Set; @@ -88,6 +90,50 @@ public void testParseArbitraryStringToObjectMapToString() throws IOException { assertEquals("{\"test-1\":{\"test-1\":\"test-1\"}}", parsedMap); } + public void testConditionallySubstituteWithNoPlaceholders() { + String input = "This string has no placeholders"; + Map outputs = new HashMap<>(); + Map params = new HashMap<>(); + + Object result = ParseUtils.conditionallySubstitute(input, outputs, params); + + assertEquals("This string has no placeholders", result); + } + + public void testConditionallySubstituteWithUnmatchedPlaceholders() { + String input = "This string has unmatched ${{placeholder}}"; + Map outputs = new HashMap<>(); + Map params = new HashMap<>(); + + Object result = ParseUtils.conditionallySubstitute(input, outputs, params); + + assertEquals("This string has unmatched ${{placeholder}}", result); + } + + public void testConditionallySubstituteWithOutputsSubstitution() { + String input = "This string contains ${{node.step}}"; + Map outputs = new HashMap<>(); + Map params = new HashMap<>(); + Map contents = new HashMap<>(Collections.emptyMap()); + contents.put("step", "model_id"); + WorkflowData data = new WorkflowData(contents, params, 
"test", "test"); + outputs.put("node", data); + Object result = ParseUtils.conditionallySubstitute(input, outputs, params); + assertEquals("This string contains model_id", result); + } + + public void testConditionallySubstituteWithParamsSubstitution() { + String input = "This string contains ${{node}}"; + Map outputs = new HashMap<>(); + Map params = new HashMap<>(); + params.put("node", "step"); + Map contents = new HashMap<>(Collections.emptyMap()); + WorkflowData data = new WorkflowData(contents, params, "test", "test"); + outputs.put("node", data); + Object result = ParseUtils.conditionallySubstitute(input, outputs, params); + assertEquals("This string contains step", result); + } + public void testGetInputsFromPreviousSteps() { WorkflowData currentNodeInputs = new WorkflowData( Map.ofEntries(