Skip to content

Commit

Permalink
Add validations from appsec (#562)
Browse files Browse the repository at this point in the history
* add validations from appsec

Signed-off-by: HenryL27 <[email protected]>
Co-authored-by: Heemin Kim <[email protected]>
  • Loading branch information
2 people authored and martin-gaievski committed Feb 6, 2024
1 parent a90784c commit fa8d80c
Show file tree
Hide file tree
Showing 8 changed files with 171 additions and 9 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Fix Flaky test reported in #433 ([#533](https://github.com/opensearch-project/neural-search/pull/533))
- Enable support for default model id on HybridQueryBuilder ([#541](https://github.com/opensearch-project/neural-search/pull/541))
- Fix Flaky test reported in #384 ([#559](https://github.com/opensearch-project/neural-search/pull/559))
- Add validations for reranker requests per #555 ([#562](https://github.com/opensearch-project/neural-search/pull/562))
### Infrastructure
- BWC tests for Neural Search ([#515](https://github.com/opensearch-project/neural-search/pull/515))
- Github action to run integ tests in secure opensearch cluster ([#535](https://github.com/opensearch-project/neural-search/pull/535))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package org.opensearch.neuralsearch.plugin;

import static org.opensearch.neuralsearch.settings.NeuralSearchSettings.NEURAL_SEARCH_HYBRID_SEARCH_DISABLED;
import static org.opensearch.neuralsearch.settings.NeuralSearchSettings.RERANKER_MAX_DOC_FIELDS;

import java.util.Arrays;
import java.util.Collection;
Expand Down Expand Up @@ -145,7 +146,7 @@ public Map<String, org.opensearch.search.pipeline.Processor.Factory<SearchPhaseR

@Override
public List<Setting<?>> getSettings() {
    // Registers both node-scoped settings: the hybrid-search kill switch and the
    // cap on document fields a reranker request may reference.
    return List.of(NEURAL_SEARCH_HYBRID_SEARCH_DISABLED, RERANKER_MAX_DOC_FIELDS);
}

@Override
Expand All @@ -159,7 +160,10 @@ public Map<String, org.opensearch.search.pipeline.Processor.Factory<SearchReques
public Map<String, org.opensearch.search.pipeline.Processor.Factory<SearchResponseProcessor>> getResponseProcessors(
    Parameters parameters
) {
    // The factory receives the ClusterService so reranker validations can read
    // node-scoped settings (e.g. RERANKER_MAX_DOC_FIELDS) at processor-creation time.
    return Map.of(
        RerankProcessor.TYPE,
        new RerankProcessorFactory(clientAccessor, parameters.searchPipelineService.getClusterService())
    );
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import java.util.Set;
import java.util.StringJoiner;

import org.opensearch.cluster.service.ClusterService;
import org.opensearch.ingest.ConfigurationUtils;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.processor.rerank.MLOpenSearchRerankProcessor;
Expand All @@ -37,6 +38,7 @@ public class RerankProcessorFactory implements Processor.Factory<SearchResponseP
public static final String CONTEXT_CONFIG_FIELD = "context";

private final MLCommonsClientAccessor clientAccessor;
private final ClusterService clusterService;

@Override
public SearchResponseProcessor create(
Expand All @@ -49,7 +51,12 @@ public SearchResponseProcessor create(
) {
RerankType type = findRerankType(config);
boolean includeQueryContextFetcher = ContextFetcherFactory.shouldIncludeQueryContextFetcher(type);
List<ContextSourceFetcher> contextFetchers = ContextFetcherFactory.createFetchers(config, includeQueryContextFetcher, tag);
List<ContextSourceFetcher> contextFetchers = ContextFetcherFactory.createFetchers(
config,
includeQueryContextFetcher,
tag,
clusterService
);
switch (type) {
case ML_OPENSEARCH:
Map<String, Object> rerankerConfig = ConfigurationUtils.readMap(RERANK_PROCESSOR_TYPE, tag, config, type.getLabel());
Expand Down Expand Up @@ -109,22 +116,23 @@ public static boolean shouldIncludeQueryContextFetcher(RerankType type) {
public static List<ContextSourceFetcher> createFetchers(
Map<String, Object> config,
boolean includeQueryContextFetcher,
String tag
String tag,
final ClusterService clusterService
) {
Map<String, Object> contextConfig = ConfigurationUtils.readMap(RERANK_PROCESSOR_TYPE, tag, config, CONTEXT_CONFIG_FIELD);
List<ContextSourceFetcher> fetchers = new ArrayList<>();
for (String key : contextConfig.keySet()) {
Object cfg = contextConfig.get(key);
switch (key) {
case DocumentContextSourceFetcher.NAME:
fetchers.add(DocumentContextSourceFetcher.create(cfg));
fetchers.add(DocumentContextSourceFetcher.create(cfg, clusterService));
break;
default:
throw new IllegalArgumentException(String.format(Locale.ROOT, "unrecognized context field: %s", key));
}
}
if (includeQueryContextFetcher) {
fetchers.add(new QueryContextSourceFetcher());
fetchers.add(new QueryContextSourceFetcher(clusterService));
}
return fetchers;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,13 @@

import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.xcontent.ObjectPath;
import org.opensearch.search.SearchHit;

import static org.opensearch.neuralsearch.settings.NeuralSearchSettings.RERANKER_MAX_DOC_FIELDS;

import lombok.AllArgsConstructor;
import lombok.extern.log4j.Log4j2;

Expand Down Expand Up @@ -87,14 +90,25 @@ public String getName() {
* @param config configuration object grabbed from parsed API request. Should be a list of strings
* @return a new DocumentContextSourceFetcher or throws IllegalArgumentException if config is malformed
*/
public static DocumentContextSourceFetcher create(Object config) {
public static DocumentContextSourceFetcher create(Object config, ClusterService clusterService) {
if (!(config instanceof List)) {
throw new IllegalArgumentException(String.format(Locale.ROOT, "%s must be a list of field names", NAME));
}
List<?> fields = (List<?>) config;
if (fields.size() == 0) {
throw new IllegalArgumentException(String.format(Locale.ROOT, "%s must be nonempty", NAME));
}
if (fields.size() > RERANKER_MAX_DOC_FIELDS.get(clusterService.getSettings())) {
throw new IllegalArgumentException(
String.format(
Locale.ROOT,
"%s must not contain more than %d fields. Configure by setting %s",
NAME,
RERANKER_MAX_DOC_FIELDS.get(clusterService.getSettings()),
RERANKER_MAX_DOC_FIELDS.getKey()
)
);
}
List<String> fieldsAsStrings = fields.stream().map(field -> (String) field).collect(Collectors.toList());
return new DocumentContextSourceFetcher(fieldsAsStrings);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,25 +14,36 @@

import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.ObjectPath;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.index.mapper.MapperService;
import org.opensearch.neuralsearch.query.ext.RerankSearchExtBuilder;
import org.opensearch.search.SearchExtBuilder;

import lombok.AllArgsConstructor;
import lombok.extern.log4j.Log4j2;

/**
* Context Source Fetcher that gets context from the rerank query ext.
*/
@Log4j2
@AllArgsConstructor
public class QueryContextSourceFetcher implements ContextSourceFetcher {

public static final String NAME = "query_context";
public static final String QUERY_TEXT_FIELD = "query_text";
public static final String QUERY_TEXT_PATH_FIELD = "query_text_path";

public static final Integer MAX_QUERY_PATH_STRLEN = 1000;

private final ClusterService clusterService;

@Override
public void fetchContext(
final SearchRequest searchRequest,
Expand Down Expand Up @@ -65,6 +76,7 @@ public void fetchContext(
} else if (ctxMap.containsKey(QUERY_TEXT_PATH_FIELD)) {
// Case "query_text_path": ser/de the query into a map and then find the text at the path specified
String path = (String) ctxMap.get(QUERY_TEXT_PATH_FIELD);
validatePath(path);
Map<String, Object> map = requestToMap(searchRequest);
// Get the text at the path
Object queryText = ObjectPath.eval(path, map);
Expand Down Expand Up @@ -107,4 +119,32 @@ private static Map<String, Object> requestToMap(final SearchRequest request) thr
Map<String, Object> map = parser.map();
return map;
}

/**
 * Rejects query-text paths that are excessively long or excessively nested before the
 * path is evaluated against the serialized search request. A null or empty path is a
 * no-op (absence is handled by the caller).
 *
 * @param path dot-separated path into the search request body; may be null or empty
 * @throws IllegalArgumentException when the path exceeds the character limit or the
 *         index-mapping depth limit from cluster settings
 */
private void validatePath(final String path) throws IllegalArgumentException {
    if (path == null || path.isEmpty()) {
        return;
    }
    if (path.length() > MAX_QUERY_PATH_STRLEN) {
        log.error(String.format(Locale.ROOT, "invalid %s due to too many characters: %s", QUERY_TEXT_PATH_FIELD, path));
        throw new IllegalArgumentException(
            String.format(Locale.ROOT, "%s exceeded the maximum path length of %d characters", QUERY_TEXT_PATH_FIELD, MAX_QUERY_PATH_STRLEN)
        );
    }
    // Nesting depth is bounded by the same limit the index mapper enforces.
    final int maxDepth = MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(clusterService.getSettings());
    if (path.split("\\.").length > maxDepth) {
        log.error(String.format(Locale.ROOT, "invalid %s due to too many nested fields: %s", QUERY_TEXT_PATH_FIELD, path));
        throw new IllegalArgumentException(
            String.format(Locale.ROOT, "%s exceeded the maximum path length of %d nested fields", QUERY_TEXT_PATH_FIELD, maxDepth)
        );
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,13 @@ public final class NeuralSearchSettings {
false,
Setting.Property.NodeScope
);

/**
* Limits the number of document fields that can be passed to the reranker.
*/
public static final Setting<Integer> RERANKER_MAX_DOC_FIELDS = Setting.intSetting(
"plugins.neural_search.reranker_max_document_fields",
50,
Setting.Property.NodeScope
);
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@
*/
package org.opensearch.neuralsearch.processor.factory;

import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
import static org.opensearch.neuralsearch.settings.NeuralSearchSettings.RERANKER_MAX_DOC_FIELDS;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
Expand All @@ -15,6 +18,8 @@
import org.junit.Before;
import org.mockito.Mock;
import org.opensearch.OpenSearchParseException;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.processor.rerank.MLOpenSearchRerankProcessor;
import org.opensearch.neuralsearch.processor.rerank.RerankProcessor;
Expand All @@ -37,11 +42,16 @@ public class RerankProcessorFactoryTests extends OpenSearchTestCase {
@Mock
private PipelineContext pipelineContext;

@Mock
private ClusterService clusterService;

@Before
public void setup() {
    clusterService = mock(ClusterService.class);
    pipelineContext = mock(PipelineContext.class);
    clientAccessor = mock(MLCommonsClientAccessor.class);
    factory = new RerankProcessorFactory(clientAccessor, clusterService);
    // Empty settings -> validations fall back to their declared defaults.
    doReturn(Settings.EMPTY).when(clusterService).getSettings();
}

public void testRerankProcessorFactory_whenEmptyConfig_thenFail() {
Expand Down Expand Up @@ -187,4 +197,26 @@ public void testCrossEncoder_whenEmptyContextDocField_thenFail() {
);
}

public void testCrossEncoder_whenTooManyDocFields_thenFail() {
    // 75 document fields exceeds the setting's default cap of 50, so creation must fail.
    Map<String, Object> config = new HashMap<>();
    config.put(RerankType.ML_OPENSEARCH.getLabel(), new HashMap<>(Map.of(MLOpenSearchRerankProcessor.MODEL_ID_FIELD, "model-id")));
    config.put(
        RerankProcessorFactory.CONTEXT_CONFIG_FIELD,
        new HashMap<>(Map.of(DocumentContextSourceFetcher.NAME, Collections.nCopies(75, "field")))
    );
    String expectedMessage = String.format(
        Locale.ROOT,
        "%s must not contain more than %d fields. Configure by setting %s",
        DocumentContextSourceFetcher.NAME,
        RERANKER_MAX_DOC_FIELDS.get(clusterService.getSettings()),
        RERANKER_MAX_DOC_FIELDS.getKey()
    );
    assertThrows(expectedMessage, IllegalArgumentException.class, () -> factory.create(Map.of(), TAG, DESC, false, config, pipelineContext));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,16 @@
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.search.SearchResponse.Clusters;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.action.search.SearchResponseSections;
import org.opensearch.action.search.ShardSearchFailure;
import org.opensearch.common.document.DocumentField;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.xcontent.json.JsonXContent;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.index.mapper.MapperService;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.processor.factory.RerankProcessorFactory;
import org.opensearch.neuralsearch.processor.rerank.context.DocumentContextSourceFetcher;
Expand All @@ -49,6 +52,8 @@
import org.opensearch.search.pipeline.Processor.PipelineContext;
import org.opensearch.test.OpenSearchTestCase;

import lombok.SneakyThrows;

public class MLOpenSearchRerankProcessorTests extends OpenSearchTestCase {

@Mock
Expand All @@ -65,14 +70,18 @@ public class MLOpenSearchRerankProcessorTests extends OpenSearchTestCase {
@Mock
private PipelineProcessingContext ppctx;

@Mock
private ClusterService clusterService;

private RerankProcessorFactory factory;

private MLOpenSearchRerankProcessor processor;

@Before
public void setup() {
MockitoAnnotations.openMocks(this);
factory = new RerankProcessorFactory(mlCommonsClientAccessor);
doReturn(Settings.EMPTY).when(clusterService).getSettings();
factory = new RerankProcessorFactory(mlCommonsClientAccessor, clusterService);
Map<String, Object> config = new HashMap<>(
Map.of(
RerankType.ML_OPENSEARCH.getLabel(),
Expand Down Expand Up @@ -223,6 +232,51 @@ public void testRerankContext_whenQueryTextPathIsBadPointer_thenFail() throws IO
.equals(QueryContextSourceFetcher.QUERY_TEXT_PATH_FIELD + " must point to a string field"));
}

@SneakyThrows
public void testRerankContext_whenQueryTextPathIsExceeedinglyManyCharacters_thenFail() {
    // "eighteencharacters".repeat(60) = 1080 characters > MAX_QUERY_PATH_STRLEN (1000),
    // so path validation must reject the request.
    setupParams(Map.of(QueryContextSourceFetcher.QUERY_TEXT_PATH_FIELD, "eighteencharacters".repeat(60)));
    setupSearchResults();
    @SuppressWarnings("unchecked")
    ActionListener<Map<String, Object>> listener = mock(ActionListener.class);
    processor.generateRerankingContext(request, response, listener);
    // The failure must surface through the listener's onFailure, not as a thrown exception.
    ArgumentCaptor<Exception> argCaptor = ArgumentCaptor.forClass(Exception.class);
    verify(listener, times(1)).onFailure(argCaptor.capture());
    assert (argCaptor.getValue() instanceof IllegalArgumentException);
    assert (argCaptor.getValue()
        .getMessage()
        .equals(
            String.format(
                Locale.ROOT,
                "%s exceeded the maximum path length of %d characters",
                QueryContextSourceFetcher.QUERY_TEXT_PATH_FIELD,
                QueryContextSourceFetcher.MAX_QUERY_PATH_STRLEN
            )
        ));
}

@SneakyThrows
// Renamed from "textRerankContext..." — the JUnit3-style runner used by
// OpenSearchTestCase only executes methods whose names start with "test", so the
// original method was silently never run.
public void testRerankContext_whenQueryTextPathIsExceeedinglyDeeplyNested_thenFail() {
    // 24 dot-separated segments is intended to exceed the index-mapping depth limit
    // read from empty settings (presumably the default of 20 — confirm in MapperService).
    setupParams(Map.of(QueryContextSourceFetcher.QUERY_TEXT_PATH_FIELD, "a.b.c.d.e.f.g.h.i.j.k.l.m.n.o.p.q.r.s.t.w.x.y.z"));
    setupSearchResults();
    @SuppressWarnings("unchecked")
    ActionListener<Map<String, Object>> listener = mock(ActionListener.class);
    processor.generateRerankingContext(request, response, listener);
    // The failure must surface through the listener's onFailure, not as a thrown exception.
    ArgumentCaptor<Exception> argCaptor = ArgumentCaptor.forClass(Exception.class);
    verify(listener, times(1)).onFailure(argCaptor.capture());
    assert (argCaptor.getValue() instanceof IllegalArgumentException);
    assert (argCaptor.getValue()
        .getMessage()
        .equals(
            String.format(
                Locale.ROOT,
                "%s exceeded the maximum path length of %d nested fields",
                QueryContextSourceFetcher.QUERY_TEXT_PATH_FIELD,
                MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(clusterService.getSettings())
            )
        ));
}

public void testRescoreSearchResponse_HappyPath() throws IOException {
setupSimilarityRescoring();
setupSearchResults();
Expand Down

0 comments on commit fa8d80c

Please sign in to comment.