Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding guardrails to default use case params #658

Merged
merged 8 commits into from
Apr 11, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.1.0/)
## [Unreleased 2.x](https://github.com/opensearch-project/flow-framework/compare/2.13...2.x)
### Features
### Enhancements
- Adding guardrails to default use case params ([#658](https://github.com/opensearch-project/flow-framework/pull/658))
joshpalis marked this conversation as resolved.
Show resolved Hide resolved
### Bug Fixes
- Reset workflow state to initial state after successful deprovision ([#635](https://github.com/opensearch-project/flow-framework/pull/635))
- Silently ignore content on APIs that don't require it ([#639](https://github.com/opensearch-project/flow-framework/pull/639))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -202,4 +202,18 @@ private CommonValue() {}
public static final String RESOURCE_ID = "resource_id";
/** The field name for the opensearch-ml plugin */
public static final String OPENSEARCH_ML = "opensearch-ml";

/*
* Constants assoicated with substitution / default templates
*/
/** The field name for connector credential key substitution */
public static final String CREATE_CONNECTOR_CREDENTIAL_KEY = "create_connector.credential.key";
/** The field name for connector credential access key substitution */
public static final String CREATE_CONNECTOR_CREDENTIAL_ACCESS_KEY = "create_connector.credential.access_key";
/** The field name for connector credential secret key substitution */
public static final String CREATE_CONNECTOR_CREDENTIAL_SECRET_KEY = "create_connector.credential.secret_key";
/** The field name for connector credential session token substitution */
public static final String CREATE_CONNECTOR_CREDENTIAL_SESSION_TOKEN = "create_connector.credential.session_token";
/** The field name for ingest pipeline model ID substitution */
public static final String CREATE_INGEST_PIPELINE_MODEL_ID = "create_ingest_pipeline.model_id";
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,16 @@
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.exception.FlowFrameworkException;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

import static org.opensearch.flowframework.common.CommonValue.CREATE_CONNECTOR_CREDENTIAL_ACCESS_KEY;
import static org.opensearch.flowframework.common.CommonValue.CREATE_CONNECTOR_CREDENTIAL_KEY;
import static org.opensearch.flowframework.common.CommonValue.CREATE_CONNECTOR_CREDENTIAL_SECRET_KEY;
import static org.opensearch.flowframework.common.CommonValue.CREATE_CONNECTOR_CREDENTIAL_SESSION_TOKEN;
import static org.opensearch.flowframework.common.CommonValue.CREATE_INGEST_PIPELINE_MODEL_ID;

/**
* Enum encapsulating the different default use cases and templates we have stored
*/
Expand All @@ -22,94 +32,119 @@
OPEN_AI_EMBEDDING_MODEL_DEPLOY(
"open_ai_embedding_model_deploy",
"defaults/openai-embedding-defaults.json",
"substitutionTemplates/deploy-remote-model-template.json"
"substitutionTemplates/deploy-remote-model-template.json",
List.of(CREATE_CONNECTOR_CREDENTIAL_KEY)
),
/** defaults file and substitution ready template for Cohere embedding model */
COHERE_EMBEDDING_MODEL_DEPLOY(
"cohere_embedding_model_deploy",
"defaults/cohere-embedding-defaults.json",
"substitutionTemplates/deploy-remote-model-extra-params-template.json"
"substitutionTemplates/deploy-remote-model-extra-params-template.json",
List.of(CREATE_CONNECTOR_CREDENTIAL_KEY)
),
/** defaults file and substitution ready template for Bedrock Titan embedding model */
BEDROCK_TITAN_EMBEDDING_MODEL_DEPLOY(
"bedrock_titan_embedding_model_deploy",
"defaults/bedrock-titan-embedding-defaults.json",
"substitutionTemplates/deploy-remote-bedrock-model-template.json"
"substitutionTemplates/deploy-remote-bedrock-model-template.json",
List.of(CREATE_CONNECTOR_CREDENTIAL_ACCESS_KEY, CREATE_CONNECTOR_CREDENTIAL_SECRET_KEY, CREATE_CONNECTOR_CREDENTIAL_SESSION_TOKEN)
),
/** defaults file and substitution ready template for Bedrock Titan multimodal embedding model */
BEDROCK_TITAN_MULTIMODAL_MODEL_DEPLOY(
"bedrock_titan_multimodal_model_deploy",
"defaults/bedrock-titan-multimodal-defaults.json",
"substitutionTemplates/deploy-remote-bedrock-model-template.json"
"substitutionTemplates/deploy-remote-bedrock-model-template.json",
List.of(CREATE_CONNECTOR_CREDENTIAL_ACCESS_KEY, CREATE_CONNECTOR_CREDENTIAL_SECRET_KEY, CREATE_CONNECTOR_CREDENTIAL_SESSION_TOKEN)
),
/** defaults file and substitution ready template for Cohere chat model */
COHERE_CHAT_MODEL_DEPLOY(
"cohere_chat_model_deploy",
"defaults/cohere-chat-defaults.json",
"substitutionTemplates/deploy-remote-model-chat-template.json"
"substitutionTemplates/deploy-remote-model-chat-template.json",
List.of(CREATE_CONNECTOR_CREDENTIAL_KEY)
),
/** defaults file and substitution ready template for OpenAI chat model */
OPENAI_CHAT_MODEL_DEPLOY(
"openai_chat_model_deploy",
"defaults/openai-chat-defaults.json",
"substitutionTemplates/deploy-remote-model-chat-template.json"
"substitutionTemplates/deploy-remote-model-chat-template.json",
List.of(CREATE_CONNECTOR_CREDENTIAL_KEY)
),
/** defaults file and substitution ready template for local neural sparse model and ingest pipeline*/
LOCAL_NEURAL_SPARSE_SEARCH_BI_ENCODER(
"local_neural_sparse_search_bi_encoder",
"defaults/local-sparse-search-biencoder-defaults.json",
"substitutionTemplates/neural-sparse-local-biencoder-template.json"
"substitutionTemplates/neural-sparse-local-biencoder-template.json",
Collections.emptyList()
),
/** defaults file and substitution ready template for semantic search, no model creation*/
SEMANTIC_SEARCH("semantic_search", "defaults/semantic-search-defaults.json", "substitutionTemplates/semantic-search-template.json"),
SEMANTIC_SEARCH(
"semantic_search",
"defaults/semantic-search-defaults.json",
"substitutionTemplates/semantic-search-template.json",
List.of(CREATE_INGEST_PIPELINE_MODEL_ID)
owaiskazi19 marked this conversation as resolved.
Show resolved Hide resolved
),
/** defaults file and substitution ready template for multimodal search, no model creation*/
MULTI_MODAL_SEARCH(
"multimodal_search",
"defaults/multi-modal-search-defaults.json",
"substitutionTemplates/multi-modal-search-template.json"
"substitutionTemplates/multi-modal-search-template.json",
joshpalis marked this conversation as resolved.
Show resolved Hide resolved
List.of(CREATE_INGEST_PIPELINE_MODEL_ID)
),
/** defaults file and substitution ready template for multimodal search, no model creation*/
MULTI_MODAL_SEARCH_WITH_BEDROCK_TITAN(
"multimodal_search_with_bedrock_titan",
"defaults/multimodal-search-bedrock-titan-defaults.json",
"substitutionTemplates/multi-modal-search-with-bedrock-titan-template.json"
"substitutionTemplates/multi-modal-search-with-bedrock-titan-template.json",
List.of(CREATE_CONNECTOR_CREDENTIAL_ACCESS_KEY, CREATE_CONNECTOR_CREDENTIAL_SECRET_KEY, CREATE_CONNECTOR_CREDENTIAL_SESSION_TOKEN)
),
/** defaults file and substitution ready template for semantic search with query enricher processor attached, no model creation*/
SEMANTIC_SEARCH_WITH_QUERY_ENRICHER(
"semantic_search_with_query_enricher",
"defaults/semantic-search-query-enricher-defaults.json",
"substitutionTemplates/semantic-search-with-query-enricher-template.json"
"substitutionTemplates/semantic-search-with-query-enricher-template.json",
List.of(CREATE_INGEST_PIPELINE_MODEL_ID)
),
/** defaults file and substitution ready template for semantic search with cohere embedding model*/
SEMANTIC_SEARCH_WITH_COHERE_EMBEDDING(
"semantic_search_with_cohere_embedding",
"defaults/cohere-embedding-semantic-search-defaults.json",
"substitutionTemplates/semantic-search-with-model-template.json"
"substitutionTemplates/semantic-search-with-model-template.json",
List.of(CREATE_CONNECTOR_CREDENTIAL_KEY)
),
/** defaults file and substitution ready template for semantic search with query enricher processor attached and cohere embedding model*/
SEMANTIC_SEARCH_WITH_COHERE_EMBEDDING_AND_QUERY_ENRICHER(
"semantic_search_with_cohere_embedding_query_enricher",
"defaults/cohere-embedding-semantic-search-with-query-enricher-defaults.json",
"substitutionTemplates/semantic-search-with-model-and-query-enricher-template.json"
"substitutionTemplates/semantic-search-with-model-and-query-enricher-template.json",
List.of(CREATE_CONNECTOR_CREDENTIAL_KEY)
),
/** defaults file and substitution ready template for hybrid search, no model creation*/
HYBRID_SEARCH("hybrid_search", "defaults/hybrid-search-defaults.json", "substitutionTemplates/hybrid-search-template.json"),
HYBRID_SEARCH(
"hybrid_search",
"defaults/hybrid-search-defaults.json",
"substitutionTemplates/hybrid-search-template.json",
List.of(CREATE_INGEST_PIPELINE_MODEL_ID)
),
/** defaults file and substitution ready template for conversational search with cohere chat model*/
CONVERSATIONAL_SEARCH_WITH_COHERE_DEPLOY(
"conversational_search_with_llm_deploy",
"defaults/conversational-search-defaults.json",
"substitutionTemplates/conversational-search-with-cohere-model-template.json"
"substitutionTemplates/conversational-search-with-cohere-model-template.json",
List.of(CREATE_CONNECTOR_CREDENTIAL_KEY)
);

private final String useCaseName;
private final String defaultsFile;
private final String substitutionReadyFile;
private final List<String> requiredParams;
private static final Logger logger = LogManager.getLogger(DefaultUseCases.class);

DefaultUseCases(String useCaseName, String defaultsFile, String substitutionReadyFile) {
DefaultUseCases(String useCaseName, String defaultsFile, String substitutionReadyFile, List<String> requiredParams) {
this.useCaseName = useCaseName;
this.defaultsFile = defaultsFile;
this.substitutionReadyFile = substitutionReadyFile;
this.requiredParams = requiredParams;
}

/**
Expand All @@ -136,6 +171,14 @@
return substitutionReadyFile;
}

/**
* Returns the required params for the given enum Constant
* @return the required params of the given useCase
*/
public List<String> getRequiredParams() {
return requiredParams;
}

/**
* Gets the defaultsFile based on the given use case.
* @param useCaseName name of the given use case
Expand Down Expand Up @@ -171,4 +214,22 @@
logger.error("Unable to find substitution ready file for use case: {}", useCaseName);
throw new FlowFrameworkException("Unable to find substitution ready file for use case: " + useCaseName, RestStatus.BAD_REQUEST);
}

/**
* Gets the required parameters based on the given use case
* @param useCaseName name of the given use case
* @return the list of required params
* @throws FlowFrameworkException if the use case doesn't exist in enum
joshpalis marked this conversation as resolved.
Show resolved Hide resolved
*/
public static List<String> getRequiredParamsByUseCaseName(String useCaseName) throws FlowFrameworkException {
if (useCaseName != null && !useCaseName.isEmpty()) {
for (DefaultUseCases useCase : values()) {
if (useCase.getUseCaseName().equals(useCaseName)) {
return new ArrayList<String>(useCase.getRequiredParams());
}
}
}
logger.error("Unable to find required parameters for use case: {}", useCaseName);
throw new FlowFrameworkException("Unable to find required parameters for use case: " + useCaseName, RestStatus.BAD_REQUEST);

Check warning on line 233 in src/main/java/org/opensearch/flowframework/common/DefaultUseCases.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/common/DefaultUseCases.java#L232-L233

Added lines #L232 - L233 were not covered by tests
joshpalis marked this conversation as resolved.
Show resolved Hide resolved
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;

Expand Down Expand Up @@ -126,12 +127,31 @@
);
String defaultsFilePath = DefaultUseCases.getDefaultsFileByUseCaseName(useCase);
useCaseDefaultsMap = ParseUtils.parseJsonFileToStringToStringMap("/" + defaultsFilePath);

if (request.hasContent()) {
List<String> requiredParams = DefaultUseCases.getRequiredParamsByUseCaseName(useCase);

if (request.hasContent() == false) {
if (requiredParams.size() != 0) {
joshpalis marked this conversation as resolved.
Show resolved Hide resolved
joshpalis marked this conversation as resolved.
Show resolved Hide resolved
throw new FlowFrameworkException(
"Missing the following required parameters for use case [" + useCase + "] : " + requiredParams.toString(),

Check warning on line 135 in src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java#L134-L135

Added lines #L134 - L135 were not covered by tests
RestStatus.BAD_REQUEST
);
}
} else {
try {
XContentParser parser = request.contentParser();
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
Map<String, Object> userDefaults = ParseUtils.parseStringToObjectMap(parser);

joshpalis marked this conversation as resolved.
Show resolved Hide resolved
// Validate user defaults key set
Set<String> userDefaultKeys = userDefaults.keySet();
if (!userDefaultKeys.containsAll(requiredParams)) {
requiredParams.removeAll(userDefaultKeys);
dbwiddis marked this conversation as resolved.
Show resolved Hide resolved
throw new FlowFrameworkException(
"Missing the following required parameters for use case [" + useCase + "] : " + requiredParams.toString(),

Check warning on line 150 in src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java#L148-L150

Added lines #L148 - L150 were not covered by tests
RestStatus.BAD_REQUEST
);
}

owaiskazi19 marked this conversation as resolved.
Show resolved Hide resolved
// updates the default params with anything user has given that matches
for (Map.Entry<String, Object> userDefaultsEntry : userDefaults.entrySet()) {
String key = userDefaultsEntry.getKey();
Expand All @@ -141,13 +161,16 @@
}
}
} catch (Exception ex) {
RestStatus status = ex instanceof IOException ? RestStatus.BAD_REQUEST : ExceptionsHelper.status(ex);
String errorMessage =
"failure parsing request body when a use case is given, make sure to provide a map with values that are either Strings, Arrays, or Map of Strings to Strings";
logger.error(errorMessage, ex);
throw new FlowFrameworkException(errorMessage, status);
if (ex instanceof FlowFrameworkException) {
throw (FlowFrameworkException) ex;

Check warning on line 165 in src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java#L165

Added line #L165 was not covered by tests
joshpalis marked this conversation as resolved.
Show resolved Hide resolved
} else {
RestStatus status = ex instanceof IOException ? RestStatus.BAD_REQUEST : ExceptionsHelper.status(ex);
String errorMessage =

Check warning on line 168 in src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java#L168

Added line #L168 was not covered by tests
"failure parsing request body when a use case is given, make sure to provide a map with values that are either Strings, Arrays, or Map of Strings to Strings";
logger.error(errorMessage, ex);
throw new FlowFrameworkException(errorMessage, status);

Check warning on line 171 in src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java#L170-L171

Added lines #L170 - L171 were not covered by tests
}
}

}

useCaseTemplateFileInStringFormat = (String) ParseUtils.conditionallySubstitute(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -345,16 +345,24 @@ protected Response createWorkflow(RestClient client, Template template) throws E
* Helper method to invoke the Create Workflow Rest Action without validation
* @param client the rest client
* @param useCase the usecase to create
* @param the required params
* @throws Exception if the request fails
* @return a rest response
*/
protected Response createWorkflowWithUseCase(RestClient client, String useCase) throws Exception {
protected Response createWorkflowWithUseCase(RestClient client, String useCase, List<String> params) throws Exception {

StringBuilder sb = new StringBuilder();
for (String param : params) {
sb.append("\"" + param + "\" : \"\"").append(",");
joshpalis marked this conversation as resolved.
Show resolved Hide resolved
}
if (params.size() != 0) sb.deleteCharAt(sb.length() - 1);
joshpalis marked this conversation as resolved.
Show resolved Hide resolved

return TestHelpers.makeRequest(
client,
"POST",
WORKFLOW_URI + "?validation=off&use_case=" + useCase,
Collections.emptyMap(),
"{}",
"{" + sb.toString() + "}",
null
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.Collectors;

import static org.opensearch.flowframework.common.CommonValue.CREATE_CONNECTOR_CREDENTIAL_KEY;
import static org.opensearch.flowframework.common.CommonValue.CREATE_INGEST_PIPELINE_MODEL_ID;
import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW;
import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_ID;

Expand Down Expand Up @@ -406,7 +408,7 @@ public void testCreateAndProvisionIngestAndSearchPipeline() throws Exception {
public void testDefaultCohereUseCase() throws Exception {

// Hit Create Workflow API with original template
Response response = createWorkflowWithUseCase(client(), "cohere_embedding_model_deploy");
Response response = createWorkflowWithUseCase(client(), "cohere_embedding_model_deploy", List.of(CREATE_CONNECTOR_CREDENTIAL_KEY));
assertEquals(RestStatus.CREATED, TestHelpers.restStatus(response));

Map<String, Object> responseMap = entityAsMap(response);
Expand Down Expand Up @@ -442,8 +444,18 @@ public void testDefaultCohereUseCase() throws Exception {
}

public void testDefaultSemanticSearchUseCaseWithFailureExpected() throws Exception {
// Hit Create Workflow API with original template
Response response = createWorkflowWithUseCase(client(), "semantic_search");
// Hit Create Workflow API with original template without required params
ResponseException exception = expectThrows(
ResponseException.class,
() -> createWorkflowWithUseCase(client(), "semantic_search", Collections.emptyList())
);
assertTrue(
exception.getMessage()
.contains("Missing the following required parameters for use case [semantic_search] : [create_ingest_pipeline.model_id]")
);

// Pass in required params
Response response = createWorkflowWithUseCase(client(), "semantic_search", List.of(CREATE_INGEST_PIPELINE_MODEL_ID));
assertEquals(RestStatus.CREATED, TestHelpers.restStatus(response));

Map<String, Object> responseMap = entityAsMap(response);
Expand Down Expand Up @@ -483,7 +495,11 @@ public void testAllDefaultUseCasesCreation() throws Exception {
.collect(Collectors.toSet());

for (String useCaseName : allUseCaseNames) {
Response response = createWorkflowWithUseCase(client(), useCaseName);
Response response = createWorkflowWithUseCase(
client(),
useCaseName,
DefaultUseCases.getRequiredParamsByUseCaseName(useCaseName)
);
assertEquals(RestStatus.CREATED, TestHelpers.restStatus(response));

Map<String, Object> responseMap = entityAsMap(response);
Expand Down
Loading
Loading