Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.x] Hides user and credential field from search response #681

Merged
merged 1 commit into from
Apr 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.1.0/)
### Bug Fixes
- Reset workflow state to initial state after successful deprovision ([#635](https://github.com/opensearch-project/flow-framework/pull/635))
- Silently ignore content on APIs that don't require it ([#639](https://github.com/opensearch-project/flow-framework/pull/639))
- Hide user and credential field from search response ([#680](https://github.com/opensearch-project/flow-framework/pull/680))

### Infrastructure
### Documentation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,14 @@
import org.opensearch.rest.RestRequest;
import org.opensearch.rest.RestResponse;
import org.opensearch.rest.action.RestResponseListener;
import org.opensearch.script.Script;
import org.opensearch.script.ScriptType;
import org.opensearch.search.builder.SearchSourceBuilder;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS;
import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED;
import static org.opensearch.flowframework.util.RestHandlerUtils.getSourceContext;

/**
* Abstract class to handle search request.
Expand Down Expand Up @@ -89,23 +84,9 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
}
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
searchSourceBuilder.parseXContent(request.contentOrSourceParamParser());
searchSourceBuilder.fetchSource(getSourceContext(request, searchSourceBuilder));
searchSourceBuilder.seqNoAndPrimaryTerm(true).version(true);
searchSourceBuilder.timeout(flowFrameworkSettings.getRequestTimeout());

// Apply credential filter when searching templates
if (index.equals(GLOBAL_CONTEXT_INDEX)) {
searchSourceBuilder.scriptField(
"filter",
new Script(
ScriptType.INLINE,
"painless",
"def filteredSource = new HashMap(params._source); def workflows = filteredSource.get(\"workflows\"); if (workflows != null) { def provision = workflows.get(\"provision\"); if (provision != null) { def nodes = provision.get(\"nodes\"); if (nodes != null) { for (node in nodes) { def userInputs = node.get(\"user_inputs\"); if (userInputs != null) { userInputs.remove(\"credential\"); } } } } } return filteredSource;",
Collections.emptyMap()
)
);
}

SearchRequest searchRequest = new SearchRequest().source(searchSourceBuilder).indices(index);
return channel -> client.execute(actionType, searchRequest, search(channel));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@
import org.opensearch.client.Client;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler;
import org.opensearch.flowframework.model.Template;
import org.opensearch.flowframework.util.EncryptorUtils;
import org.opensearch.flowframework.util.ParseUtils;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

Expand Down Expand Up @@ -79,8 +81,9 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener<GetW
logger.error(errorMessage);
listener.onFailure(new FlowFrameworkException(errorMessage, RestStatus.NOT_FOUND));
} else {
// Remove any credential from response
Template template = encryptorUtils.redactTemplateCredentials(Template.parse(response.getSourceAsString()));
// Remove any secured field from response
User user = ParseUtils.getUserContext(client);
Template template = encryptorUtils.redactTemplateSecuredFields(user, Template.parse(response.getSourceAsString()));
listener.onResponse(new GetWorkflowResponse(template));
}
}, exception -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,15 @@
import org.opensearch.client.Client;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.action.ActionListener;
import org.opensearch.flowframework.util.ParseUtils;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

import static org.opensearch.flowframework.util.RestHandlerUtils.getSourceContext;

/**
* Transport Action to search workflow states
*/
Expand All @@ -45,8 +50,10 @@ public SearchWorkflowStateTransportAction(TransportService transportService, Act
@Override
protected void doExecute(Task task, SearchRequest request, ActionListener<SearchResponse> actionListener) {
// AccessController should take care of letting the user with right permission to view the workflow
User user = ParseUtils.getUserContext(client);
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
logger.info("Searching workflow states in global context");
SearchSourceBuilder searchSourceBuilder = request.source();
searchSourceBuilder.fetchSource(getSourceContext(user, searchSourceBuilder));
client.search(request, ActionListener.runBefore(actionListener, context::restore));
} catch (Exception e) {
logger.error("Failed to search workflow states in global context", e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,15 @@
import org.opensearch.client.Client;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.action.ActionListener;
import org.opensearch.flowframework.util.ParseUtils;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

import static org.opensearch.flowframework.util.RestHandlerUtils.getSourceContext;

/**
* Transport Action to search workflows created
*/
Expand All @@ -45,8 +50,11 @@ public SearchWorkflowTransportAction(TransportService transportService, ActionFi
@Override
protected void doExecute(Task task, SearchRequest request, ActionListener<SearchResponse> actionListener) {
// AccessController should take care of letting the user with right permission to view the workflow
User user = ParseUtils.getUserContext(client);
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
logger.info("Searching workflows in global context");
SearchSourceBuilder searchSourceBuilder = request.source();
searchSourceBuilder.fetchSource(getSourceContext(user, searchSourceBuilder));
client.search(request, ActionListener.runBefore(actionListener, context::restore));
} catch (Exception e) {
logger.error("Failed to search workflows in global context", e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.exception.FlowFrameworkException;
Expand Down Expand Up @@ -122,7 +123,7 @@ public Template decryptTemplateCredentials(Template template) {
/**
* Applies the given cipher function on template credentials
* @param template the template to process
* @param cipher the encryption/decryption function to apply on credential values
* @param cipherFunction the encryption/decryption function to apply on credential values
* @return template with encrypted credentials
*/
private Template processTemplateCredentials(Template template, Function<String, String> cipherFunction) {
Expand Down Expand Up @@ -201,11 +202,13 @@ String decrypt(final String encryptedCredential) {
// TODO : Improve redactTemplateCredentials to redact different fields
/**
* Removes the credential fields from a template
* @param user User
* @param template the template
* @return the redacted template
*/
public Template redactTemplateCredentials(Template template) {
public Template redactTemplateSecuredFields(User user, Template template) {
Map<String, Workflow> processedWorkflows = new HashMap<>();

for (Map.Entry<String, Workflow> entry : template.workflows().entrySet()) {

List<WorkflowNode> processedNodes = new ArrayList<>();
Expand All @@ -227,7 +230,11 @@ public Template redactTemplateCredentials(Template template) {
processedWorkflows.put(entry.getKey(), new Workflow(entry.getValue().userParams(), processedNodes, entry.getValue().edges()));
}

return new Template.Builder(template).workflows(processedWorkflows).build();
if (ParseUtils.isAdmin(user)) {
return new Template.Builder(template).workflows(processedWorkflows).build();
}

return new Template.Builder(template).user(null).workflows(processedWorkflows).build();
}

/**
Expand Down
30 changes: 12 additions & 18 deletions src/main/java/org/opensearch/flowframework/util/ParseUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.workflow.WorkflowData;
import org.opensearch.ml.common.agent.LLMSpec;

import java.io.FileNotFoundException;
import java.io.IOException;
Expand All @@ -47,8 +46,6 @@
import jakarta.json.bind.JsonbBuilder;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.flowframework.common.CommonValue.PARAMETERS_FIELD;
import static org.opensearch.flowframework.common.WorkflowResources.MODEL_ID;

/**
* Utility methods for Template parsing
Expand Down Expand Up @@ -113,6 +110,18 @@ public static void buildStringToStringMap(XContentBuilder xContentBuilder, Map<?
xContentBuilder.endObject();
}

/**
* 'all_access' role users are treated as admins.
* @param user of the current role
* @return boolean if the role is admin
*/
public static boolean isAdmin(User user) {
if (user == null) {
return false;
}
return user.getRoles().contains("all_access");
}

/**
* Builds an XContent object representing a map of String keys to Object values.
*
Expand All @@ -132,21 +141,6 @@ public static void buildStringToObjectMap(XContentBuilder xContentBuilder, Map<?
xContentBuilder.endObject();
}

/**
* Builds an XContent object representing a LLMSpec.
*
* @param xContentBuilder An XContent builder whose position is at the start of the map object to build
* @param llm LLMSpec
* @throws IOException on a build failure
*/
public static void buildLLMMap(XContentBuilder xContentBuilder, LLMSpec llm) throws IOException {
String modelId = llm.getModelId();
Map<String, String> parameters = llm.getParameters();
xContentBuilder.field(MODEL_ID, modelId);
xContentBuilder.field(PARAMETERS_FIELD);
buildStringToStringMap(xContentBuilder, parameters);
}

/**
* Parses an XContent object representing a map of String keys to String values.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@
package org.opensearch.flowframework.util;

import org.apache.commons.lang3.ArrayUtils;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.common.Strings;
import org.opensearch.flowframework.common.CommonValue;
import org.opensearch.rest.RestRequest;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.search.fetch.subphase.FetchSourceContext;

Expand All @@ -19,26 +20,36 @@
*/
public class RestHandlerUtils {

/** Path to credential field **/
private static final String PATH_TO_CREDENTIAL_FIELD = "workflows.provision.nodes.user_inputs.credential";

/** Fields that need to be excluded from the Search Response*/
public static final String[] USER_EXCLUDE = new String[] { CommonValue.USER_FIELD, CommonValue.UI_METADATA_FIELD };
private static final String[] DASHBOARD_EXCLUDES = new String[] {
CommonValue.USER_FIELD,
CommonValue.UI_METADATA_FIELD,
PATH_TO_CREDENTIAL_FIELD };

private static final String[] EXCLUDES = new String[] { CommonValue.USER_FIELD, PATH_TO_CREDENTIAL_FIELD };

private RestHandlerUtils() {}

/**
* Creates a source context and include/exclude information to be shared based on the user
*
* @param request the REST request
* @param user User
* @param searchSourceBuilder the search request source builder
* @return modified sources
*/
public static FetchSourceContext getSourceContext(RestRequest request, SearchSourceBuilder searchSourceBuilder) {
// TODO
// 1. check if the request came from dashboard and exclude UI_METADATA
public static FetchSourceContext getSourceContext(User user, SearchSourceBuilder searchSourceBuilder) {
if (searchSourceBuilder.fetchSource() != null) {
String[] newArray = (String[]) ArrayUtils.addAll(searchSourceBuilder.fetchSource().excludes(), USER_EXCLUDE);
String[] newArray = (String[]) ArrayUtils.addAll(searchSourceBuilder.fetchSource().excludes(), DASHBOARD_EXCLUDES);
return new FetchSourceContext(true, searchSourceBuilder.fetchSource().includes(), newArray);
} else {
return null;
// When user does not set the _source field in search api request, searchSourceBuilder.fetchSource becomes null
if (ParseUtils.isAdmin(user)) {
return new FetchSourceContext(true, Strings.EMPTY_ARRAY, new String[] { PATH_TO_CREDENTIAL_FIELD });
}
return new FetchSourceContext(true, Strings.EMPTY_ARRAY, EXCLUDES);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.tasks.Task;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.ThreadPool;
Expand Down Expand Up @@ -71,6 +72,8 @@ public void testSearchWorkflow() {
@SuppressWarnings("unchecked")
ActionListener<SearchResponse> listener = mock(ActionListener.class);
SearchRequest searchRequest = new SearchRequest();
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
searchRequest.source(searchSourceBuilder);

searchWorkflowStateTransportAction.doExecute(mock(Task.class), searchRequest, listener);
verify(client, times(1)).search(any(SearchRequest.class), any());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.tasks.Task;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.ThreadPool;
Expand All @@ -34,12 +35,17 @@ public class SearchWorkflowTransportActionTests extends OpenSearchTestCase {
private SearchWorkflowTransportAction searchWorkflowTransportAction;
private Client client;
private ThreadPool threadPool;
ThreadContext threadContext;

@Override
public void setUp() throws Exception {
super.setUp();
this.client = mock(Client.class);

this.threadPool = mock(ThreadPool.class);
Settings settings = Settings.builder().build();
threadContext = new ThreadContext(settings);
when(client.threadPool()).thenReturn(threadPool);
when(threadPool.getThreadContext()).thenReturn(threadContext);
this.searchWorkflowTransportAction = new SearchWorkflowTransportAction(
mock(TransportService.class),
mock(ActionFilters.class),
Expand Down Expand Up @@ -73,6 +79,8 @@ public void testSearchWorkflow() {
@SuppressWarnings("unchecked")
ActionListener<SearchResponse> listener = mock(ActionListener.class);
SearchRequest searchRequest = new SearchRequest();
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
searchRequest.source(searchSourceBuilder);

searchWorkflowTransportAction.doExecute(mock(Task.class), searchRequest, listener);
verify(client, times(1)).search(any(SearchRequest.class), any());
Expand Down
Loading
Loading