Skip to content

Commit

Permalink
Added user permission and new tests
Browse files Browse the repository at this point in the history
Signed-off-by: Owais Kazi <[email protected]>
  • Loading branch information
owaiskazi19 committed Apr 25, 2024
1 parent 00f912b commit 025f21b
Show file tree
Hide file tree
Showing 11 changed files with 119 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@

import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED;
import static org.opensearch.flowframework.util.RestHandlerUtils.getSourceContext;

/**
* Abstract class to handle search request.
Expand Down Expand Up @@ -85,7 +84,6 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
}
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
searchSourceBuilder.parseXContent(request.contentOrSourceParamParser());
searchSourceBuilder.fetchSource(getSourceContext(request, searchSourceBuilder));
searchSourceBuilder.seqNoAndPrimaryTerm(true).version(true);
searchSourceBuilder.timeout(flowFrameworkSettings.getRequestTimeout());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@
import org.opensearch.client.Client;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler;
import org.opensearch.flowframework.model.Template;
import org.opensearch.flowframework.util.EncryptorUtils;
import org.opensearch.flowframework.util.ParseUtils;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

Expand Down Expand Up @@ -80,7 +82,8 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener<GetW
listener.onFailure(new FlowFrameworkException(errorMessage, RestStatus.NOT_FOUND));
} else {
// Remove any secured field from response
Template template = encryptorUtils.redactTemplateSecuredFields(Template.parse(response.getSourceAsString()));
User user = ParseUtils.getUserContext(client);
Template template = encryptorUtils.redactTemplateSecuredFields(user, Template.parse(response.getSourceAsString()));
listener.onResponse(new GetWorkflowResponse(template));
}
}, exception -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,15 @@
import org.opensearch.client.Client;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.action.ActionListener;
import org.opensearch.flowframework.util.ParseUtils;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

import static org.opensearch.flowframework.util.RestHandlerUtils.getSourceContext;

/**
* Transport Action to search workflow states
*/
Expand All @@ -45,8 +50,10 @@ public SearchWorkflowStateTransportAction(TransportService transportService, Act
@Override
protected void doExecute(Task task, SearchRequest request, ActionListener<SearchResponse> actionListener) {
// AccessController should take care of letting the user with right permission to view the workflow
User user = ParseUtils.getUserContext(client);
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
logger.info("Searching workflow states in global context");
SearchSourceBuilder searchSourceBuilder = request.source();
searchSourceBuilder.fetchSource(getSourceContext(user, searchSourceBuilder));
client.search(request, ActionListener.runBefore(actionListener, context::restore));
} catch (Exception e) {
logger.error("Failed to search workflow states in global context", e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,15 @@
import org.opensearch.client.Client;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.action.ActionListener;
import org.opensearch.flowframework.util.ParseUtils;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

import static org.opensearch.flowframework.util.RestHandlerUtils.getSourceContext;

/**
* Transport Action to search workflows created
*/
Expand All @@ -45,8 +50,11 @@ public SearchWorkflowTransportAction(TransportService transportService, ActionFi
@Override
protected void doExecute(Task task, SearchRequest request, ActionListener<SearchResponse> actionListener) {
// AccessController should take care of letting the user with right permission to view the workflow
User user = ParseUtils.getUserContext(client);
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
logger.info("Searching workflows in global context");
SearchSourceBuilder searchSourceBuilder = request.source();
searchSourceBuilder.fetchSource(getSourceContext(user, searchSourceBuilder));
client.search(request, ActionListener.runBefore(actionListener, context::restore));
} catch (Exception e) {
logger.error("Failed to search workflows in global context", e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.exception.FlowFrameworkException;
Expand Down Expand Up @@ -201,10 +202,11 @@ String decrypt(final String encryptedCredential) {
// TODO : Improve redactTemplateCredentials to redact different fields
/**
* Removes the credential fields from a template
* @param user User
* @param template the template
* @return the redacted template
*/
public Template redactTemplateSecuredFields(Template template) {
public Template redactTemplateSecuredFields(User user, Template template) {
Map<String, Workflow> processedWorkflows = new HashMap<>();

for (Map.Entry<String, Workflow> entry : template.workflows().entrySet()) {
Expand All @@ -228,6 +230,10 @@ public Template redactTemplateSecuredFields(Template template) {
processedWorkflows.put(entry.getKey(), new Workflow(entry.getValue().userParams(), processedNodes, entry.getValue().edges()));
}

if (ParseUtils.isAdmin(user)) {
return new Template.Builder(template).workflows(processedWorkflows).build();
}

return new Template.Builder(template).user(null).workflows(processedWorkflows).build();
}

Expand Down
30 changes: 12 additions & 18 deletions src/main/java/org/opensearch/flowframework/util/ParseUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.workflow.WorkflowData;
import org.opensearch.ml.common.agent.LLMSpec;

import java.io.FileNotFoundException;
import java.io.IOException;
Expand All @@ -47,8 +46,6 @@
import jakarta.json.bind.JsonbBuilder;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.flowframework.common.CommonValue.PARAMETERS_FIELD;
import static org.opensearch.flowframework.common.WorkflowResources.MODEL_ID;

/**
* Utility methods for Template parsing
Expand Down Expand Up @@ -113,6 +110,18 @@ public static void buildStringToStringMap(XContentBuilder xContentBuilder, Map<?
xContentBuilder.endObject();
}

/**
* 'all_access' role users are treated as admins.
* @param user of the current role
* @return boolean if the role is admin
*/
public static boolean isAdmin(User user) {
if (user == null) {
return false;
}
return user.getRoles().contains("all_access");
}

/**
* Builds an XContent object representing a map of String keys to Object values.
*
Expand All @@ -132,21 +141,6 @@ public static void buildStringToObjectMap(XContentBuilder xContentBuilder, Map<?
xContentBuilder.endObject();
}

/**
* Builds an XContent object representing a LLMSpec.
*
* @param xContentBuilder An XContent builder whose position is at the start of the map object to build
* @param llm LLMSpec
* @throws IOException on a build failure
*/
public static void buildLLMMap(XContentBuilder xContentBuilder, LLMSpec llm) throws IOException {
String modelId = llm.getModelId();
Map<String, String> parameters = llm.getParameters();
xContentBuilder.field(MODEL_ID, modelId);
xContentBuilder.field(PARAMETERS_FIELD);
buildStringToStringMap(xContentBuilder, parameters);
}

/**
* Parses an XContent object representing a map of String keys to String values.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
package org.opensearch.flowframework.util;

import org.apache.commons.lang3.ArrayUtils;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.common.Strings;
import org.opensearch.flowframework.common.CommonValue;
import org.opensearch.rest.RestRequest;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.search.fetch.subphase.FetchSourceContext;

Expand All @@ -36,18 +36,21 @@ private RestHandlerUtils() {}
/**
* Creates a source context and include/exclude information to be shared based on the user
*
* @param request the REST request
* @param user User
* @param searchSourceBuilder the search request source builder
* @return modified sources
*/
public static FetchSourceContext getSourceContext(RestRequest request, SearchSourceBuilder searchSourceBuilder) {
public static FetchSourceContext getSourceContext(User user, SearchSourceBuilder searchSourceBuilder) {
// TODO
// 1. check if the request came from dashboard and exclude UI_METADATA
if (searchSourceBuilder.fetchSource() != null) {
String[] newArray = (String[]) ArrayUtils.addAll(searchSourceBuilder.fetchSource().excludes(), DASHBOARD_EXCLUDES);
return new FetchSourceContext(true, searchSourceBuilder.fetchSource().includes(), newArray);
} else {
// When user does not set the _source field in search api request, searchSourceBuilder.fetchSource becomes null
if (ParseUtils.isAdmin(user)) {
return new FetchSourceContext(true, Strings.EMPTY_ARRAY, new String[] { PATH_TO_CREDENTIAL_FIELD });
}
return new FetchSourceContext(true, Strings.EMPTY_ARRAY, EXCLUDES);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.tasks.Task;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.ThreadPool;
Expand Down Expand Up @@ -71,6 +72,8 @@ public void testSearchWorkflow() {
@SuppressWarnings("unchecked")
ActionListener<SearchResponse> listener = mock(ActionListener.class);
SearchRequest searchRequest = new SearchRequest();
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
searchRequest.source(searchSourceBuilder);

searchWorkflowStateTransportAction.doExecute(mock(Task.class), searchRequest, listener);
verify(client, times(1)).search(any(SearchRequest.class), any());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.tasks.Task;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.ThreadPool;
Expand Down Expand Up @@ -73,6 +74,8 @@ public void testSearchWorkflow() {
@SuppressWarnings("unchecked")
ActionListener<SearchResponse> listener = mock(ActionListener.class);
SearchRequest searchRequest = new SearchRequest();
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
searchRequest.source(searchSourceBuilder);

searchWorkflowTransportAction.doExecute(mock(Task.class), searchRequest, listener);
verify(client, times(1)).search(any(SearchRequest.class), any());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.action.ActionListener;
import org.opensearch.flowframework.TestHelpers;
import org.opensearch.flowframework.exception.FlowFrameworkException;
Expand All @@ -26,6 +27,7 @@
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.ThreadPool;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -199,8 +201,10 @@ public void testRedactTemplateCredential() {
WorkflowNode node = testTemplate.workflows().get("provision").nodes().get(0);
assertNotNull(node.userInputs().get(CREDENTIAL_FIELD));

User user = new User("user", Collections.emptyList(), Collections.emptyList(), Collections.emptyList());

// Redact template with credential field
Template redactedTemplate = encryptorUtils.redactTemplateSecuredFields(testTemplate);
Template redactedTemplate = encryptorUtils.redactTemplateSecuredFields(user, testTemplate);

// Validate the credential field has been removed
WorkflowNode redactedNode = redactedTemplate.workflows().get("provision").nodes().get(0);
Expand All @@ -211,10 +215,25 @@ public void testRedactTemplateUserField() {
// Confirm user is present in the non-redacted template
assertNotNull(testTemplate.getUser());

User user = new User("user", Collections.emptyList(), Collections.emptyList(), Collections.emptyList());
// Redact template with user field
Template redactedTemplate = encryptorUtils.redactTemplateSecuredFields(testTemplate);
Template redactedTemplate = encryptorUtils.redactTemplateSecuredFields(user, testTemplate);

// Validate the user field has been removed
assertNull(redactedTemplate.getUser());
}

public void testAdminUserTemplate() {
// Confirm user is present in the non-redacted template
assertNotNull(testTemplate.getUser());

List<String> roles = new ArrayList<>();
roles.add("all_access");

User user = new User("admin", roles, roles, Collections.emptyList());

// Redact template with user field
Template redactedTemplate = encryptorUtils.redactTemplateSecuredFields(user, testTemplate);
assertNotNull(redactedTemplate.getUser());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.flowframework.util;

import org.opensearch.commons.authuser.User;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.search.fetch.subphase.FetchSourceContext;
import org.opensearch.test.OpenSearchTestCase;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

public class RestHandlerUtilsTests extends OpenSearchTestCase {

public void testGetSourceContextFromClientWithDashboardExcludes() {
SearchSourceBuilder testSearchSourceBuilder = new SearchSourceBuilder();
testSearchSourceBuilder.fetchSource(new String[] { "a" }, new String[] { "b" });
User user = new User("user", Collections.emptyList(), Collections.emptyList(), Collections.emptyList());
FetchSourceContext sourceContext = RestHandlerUtils.getSourceContext(user, testSearchSourceBuilder);
assertEquals(sourceContext.excludes().length, 4);
}

public void testGetSourceContextFromClientWithExcludes() {
SearchSourceBuilder testSearchSourceBuilder = new SearchSourceBuilder();
User user = new User("user", Collections.emptyList(), Collections.emptyList(), Collections.emptyList());
FetchSourceContext sourceContext = RestHandlerUtils.getSourceContext(user, testSearchSourceBuilder);
assertEquals(sourceContext.excludes().length, 2);
}

public void testGetSourceContextAdminUser() {
SearchSourceBuilder testSearchSourceBuilder = new SearchSourceBuilder();
List<String> roles = new ArrayList<>();
roles.add("all_access");

User user = new User("admin", roles, roles, Collections.emptyList());
FetchSourceContext sourceContext = RestHandlerUtils.getSourceContext(user, testSearchSourceBuilder);
assertEquals(sourceContext.excludes().length, 1);
}

}

0 comments on commit 025f21b

Please sign in to comment.