diff --git a/CHANGELOG.md b/CHANGELOG.md index 85524f4bf..674a369c0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,6 +25,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Fix Flaky test reported in #433 ([#533](https://github.com/opensearch-project/neural-search/pull/533)) - Enable support for default model id on HybridQueryBuilder ([#541](https://github.com/opensearch-project/neural-search/pull/541)) - Fix Flaky test reported in #384 ([#559](https://github.com/opensearch-project/neural-search/pull/559)) +- Add validations for reranker requests per #555 ([#562](https://github.com/opensearch-project/neural-search/pull/562)) ### Infrastructure - BWC tests for Neural Search ([#515](https://github.com/opensearch-project/neural-search/pull/515)) - Github action to run integ tests in secure opensearch cluster ([#535](https://github.com/opensearch-project/neural-search/pull/535)) diff --git a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java index a77118f43..0182ff4d3 100644 --- a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java +++ b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java @@ -5,6 +5,7 @@ package org.opensearch.neuralsearch.plugin; import static org.opensearch.neuralsearch.settings.NeuralSearchSettings.NEURAL_SEARCH_HYBRID_SEARCH_DISABLED; +import static org.opensearch.neuralsearch.settings.NeuralSearchSettings.RERANKER_MAX_DOC_FIELDS; import java.util.Arrays; import java.util.Collection; @@ -145,7 +146,7 @@ public Map> getSettings() { - return List.of(NEURAL_SEARCH_HYBRID_SEARCH_DISABLED); + return List.of(NEURAL_SEARCH_HYBRID_SEARCH_DISABLED, RERANKER_MAX_DOC_FIELDS); } @Override @@ -159,7 +160,10 @@ public Map> getResponseProcessors( Parameters parameters ) { - return Map.of(RerankProcessor.TYPE, new RerankProcessorFactory(clientAccessor)); + return Map.of( + RerankProcessor.TYPE, + new RerankProcessorFactory(clientAccessor, parameters.searchPipelineService.getClusterService()) + ); } @Override diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java index b02666855..9b9715df5 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java @@ -11,6 +11,7 @@ import java.util.Set; import java.util.StringJoiner; +import org.opensearch.cluster.service.ClusterService; import org.opensearch.ingest.ConfigurationUtils; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; import org.opensearch.neuralsearch.processor.rerank.MLOpenSearchRerankProcessor; @@ -37,6 +38,7 @@ public class RerankProcessorFactory implements Processor.Factory contextFetchers = ContextFetcherFactory.createFetchers(config, includeQueryContextFetcher, tag); + List contextFetchers = ContextFetcherFactory.createFetchers( + config, + includeQueryContextFetcher, + tag, + clusterService + ); switch (type) { case ML_OPENSEARCH: Map rerankerConfig = ConfigurationUtils.readMap(RERANK_PROCESSOR_TYPE, tag, config, type.getLabel()); @@ -109,7 +116,8 @@ public static boolean shouldIncludeQueryContextFetcher(RerankType type) { public static List createFetchers( Map config, boolean includeQueryContextFetcher, - String tag + String tag, + final ClusterService clusterService ) { Map contextConfig = ConfigurationUtils.readMap(RERANK_PROCESSOR_TYPE, tag, config, CONTEXT_CONFIG_FIELD); List fetchers = new ArrayList<>(); @@ -117,14 +125,14 @@ public static List createFetchers( Object cfg = contextConfig.get(key); switch (key) { case DocumentContextSourceFetcher.NAME: - fetchers.add(DocumentContextSourceFetcher.create(cfg)); + fetchers.add(DocumentContextSourceFetcher.create(cfg, clusterService)); break; default: throw new IllegalArgumentException(String.format(Locale.ROOT, "unrecognized context field: %s", key)); } } if (includeQueryContextFetcher) { - fetchers.add(new QueryContextSourceFetcher()); + fetchers.add(new QueryContextSourceFetcher(clusterService)); } return fetchers; } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/context/DocumentContextSourceFetcher.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/context/DocumentContextSourceFetcher.java index 857c1dd46..a2f69e44f 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/context/DocumentContextSourceFetcher.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/context/DocumentContextSourceFetcher.java @@ -13,10 +13,13 @@ import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; +import org.opensearch.cluster.service.ClusterService; import org.opensearch.core.action.ActionListener; import org.opensearch.core.xcontent.ObjectPath; import org.opensearch.search.SearchHit; +import static org.opensearch.neuralsearch.settings.NeuralSearchSettings.RERANKER_MAX_DOC_FIELDS; + import lombok.AllArgsConstructor; import lombok.extern.log4j.Log4j2; @@ -87,7 +90,7 @@ public String getName() { * @param config configuration object grabbed from parsed API request. Should be a list of strings * @return a new DocumentContextSourceFetcher or throws IllegalArgumentException if config is malformed */ - public static DocumentContextSourceFetcher create(Object config) { + public static DocumentContextSourceFetcher create(Object config, ClusterService clusterService) { if (!(config instanceof List)) { throw new IllegalArgumentException(String.format(Locale.ROOT, "%s must be a list of field names", NAME)); } @@ -95,6 +98,17 @@ public static DocumentContextSourceFetcher create(Object config) { if (fields.size() == 0) { throw new IllegalArgumentException(String.format(Locale.ROOT, "%s must be nonempty", NAME)); } + if (fields.size() > RERANKER_MAX_DOC_FIELDS.get(clusterService.getSettings())) { + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "%s must not contain more than %d fields. Configure by setting %s", + NAME, + RERANKER_MAX_DOC_FIELDS.get(clusterService.getSettings()), + RERANKER_MAX_DOC_FIELDS.getKey() + ) + ); + } List fieldsAsStrings = fields.stream().map(field -> (String) field).collect(Collectors.toList()); return new DocumentContextSourceFetcher(fieldsAsStrings); } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/context/QueryContextSourceFetcher.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/context/QueryContextSourceFetcher.java index d7463bcd1..fa068ee88 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/context/QueryContextSourceFetcher.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/context/QueryContextSourceFetcher.java @@ -14,6 +14,7 @@ import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; +import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.action.ActionListener; import org.opensearch.core.xcontent.NamedXContentRegistry; @@ -21,18 +22,28 @@ import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.mapper.MapperService; import org.opensearch.neuralsearch.query.ext.RerankSearchExtBuilder; import org.opensearch.search.SearchExtBuilder; +import lombok.AllArgsConstructor; +import lombok.extern.log4j.Log4j2; + /** * Context Source Fetcher that gets context from the rerank query ext. */ +@Log4j2 +@AllArgsConstructor public class QueryContextSourceFetcher implements ContextSourceFetcher { public static final String NAME = "query_context"; public static final String QUERY_TEXT_FIELD = "query_text"; public static final String QUERY_TEXT_PATH_FIELD = "query_text_path"; + public static final Integer MAX_QUERY_PATH_STRLEN = 1000; + + private final ClusterService clusterService; + @Override public void fetchContext( final SearchRequest searchRequest, @@ -65,6 +76,7 @@ public void fetchContext( } else if (ctxMap.containsKey(QUERY_TEXT_PATH_FIELD)) { // Case "query_text_path": ser/de the query into a map and then find the text at the path specified String path = (String) ctxMap.get(QUERY_TEXT_PATH_FIELD); + validatePath(path); Map map = requestToMap(searchRequest); // Get the text at the path Object queryText = ObjectPath.eval(path, map); @@ -107,4 +119,32 @@ private static Map requestToMap(final SearchRequest request) thr Map map = parser.map(); return map; } + + private void validatePath(final String path) throws IllegalArgumentException { + if (path == null || path.isEmpty()) { + return; + } + if (path.length() > MAX_QUERY_PATH_STRLEN) { + log.error(String.format(Locale.ROOT, "invalid %s due to too many characters: %s", QUERY_TEXT_PATH_FIELD, path)); + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "%s exceeded the maximum path length of %d characters", + QUERY_TEXT_PATH_FIELD, + MAX_QUERY_PATH_STRLEN + ) + ); + } + if (path.split("\\.").length > MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(clusterService.getSettings())) { + log.error(String.format(Locale.ROOT, "invalid %s due to too many nested fields: %s", QUERY_TEXT_PATH_FIELD, path)); + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "%s exceeded the maximum path length of %d nested fields", + QUERY_TEXT_PATH_FIELD, + MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(clusterService.getSettings()) + ) + ); + } + } } diff --git a/src/main/java/org/opensearch/neuralsearch/settings/NeuralSearchSettings.java b/src/main/java/org/opensearch/neuralsearch/settings/NeuralSearchSettings.java index bf887830d..d6d3233ca 100644 --- a/src/main/java/org/opensearch/neuralsearch/settings/NeuralSearchSettings.java +++ b/src/main/java/org/opensearch/neuralsearch/settings/NeuralSearchSettings.java @@ -25,4 +25,13 @@ public final class NeuralSearchSettings { false, Setting.Property.NodeScope ); + + /** + * Limits the number of document fields that can be passed to the reranker. + */ + public static final Setting RERANKER_MAX_DOC_FIELDS = Setting.intSetting( + "plugins.neural_search.reranker_max_document_fields", + 50, + Setting.Property.NodeScope + ); } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactoryTests.java b/src/test/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactoryTests.java index ea37b2afb..c464f2826 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactoryTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactoryTests.java @@ -4,9 +4,12 @@ */ package org.opensearch.neuralsearch.processor.factory; +import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.mock; +import static org.opensearch.neuralsearch.settings.NeuralSearchSettings.RERANKER_MAX_DOC_FIELDS; import java.util.ArrayList; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Locale; @@ -15,6 +18,8 @@ import org.junit.Before; import org.mockito.Mock; import org.opensearch.OpenSearchParseException; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; import org.opensearch.neuralsearch.processor.rerank.MLOpenSearchRerankProcessor; import org.opensearch.neuralsearch.processor.rerank.RerankProcessor; @@ -37,11 +42,16 @@ public class RerankProcessorFactoryTests extends OpenSearchTestCase { @Mock private PipelineContext pipelineContext; + @Mock + private ClusterService clusterService; + @Before public void setup() { + clusterService = mock(ClusterService.class); pipelineContext = mock(PipelineContext.class); clientAccessor = mock(MLCommonsClientAccessor.class); - factory = new RerankProcessorFactory(clientAccessor); + factory = new RerankProcessorFactory(clientAccessor, clusterService); + doReturn(Settings.EMPTY).when(clusterService).getSettings(); } public void testRerankProcessorFactory_whenEmptyConfig_thenFail() { @@ -187,4 +197,26 @@ public void testCrossEncoder_whenEmptyContextDocField_thenFail() { ); } + public void testCrossEncoder_whenTooManyDocFields_thenFail() { + Map config = new HashMap<>( + Map.of( + RerankType.ML_OPENSEARCH.getLabel(), + new HashMap<>(Map.of(MLOpenSearchRerankProcessor.MODEL_ID_FIELD, "model-id")), + RerankProcessorFactory.CONTEXT_CONFIG_FIELD, + new HashMap<>(Map.of(DocumentContextSourceFetcher.NAME, Collections.nCopies(75, "field"))) + ) + ); + assertThrows( + String.format( + Locale.ROOT, + "%s must not contain more than %d fields. Configure by setting %s", + DocumentContextSourceFetcher.NAME, + RERANKER_MAX_DOC_FIELDS.get(clusterService.getSettings()), + RERANKER_MAX_DOC_FIELDS.getKey() + ), + IllegalArgumentException.class, + () -> factory.create(Map.of(), TAG, DESC, false, config, pipelineContext) + ); + } + } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorTests.java index 50d0cf2bc..dbd1c2bd6 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorTests.java @@ -28,13 +28,16 @@ import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.search.SearchResponse.Clusters; +import org.opensearch.cluster.service.ClusterService; import org.opensearch.action.search.SearchResponseSections; import org.opensearch.action.search.ShardSearchFailure; import org.opensearch.common.document.DocumentField; +import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.json.JsonXContent; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.index.mapper.MapperService; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; import org.opensearch.neuralsearch.processor.factory.RerankProcessorFactory; import org.opensearch.neuralsearch.processor.rerank.context.DocumentContextSourceFetcher; @@ -49,6 +52,8 @@ import org.opensearch.search.pipeline.Processor.PipelineContext; import org.opensearch.test.OpenSearchTestCase; +import lombok.SneakyThrows; + public class MLOpenSearchRerankProcessorTests extends OpenSearchTestCase { @Mock @@ -65,6 +70,9 @@ public class MLOpenSearchRerankProcessorTests extends OpenSearchTestCase { @Mock private PipelineProcessingContext ppctx; + @Mock + private ClusterService clusterService; + private RerankProcessorFactory factory; private MLOpenSearchRerankProcessor processor; @@ -72,7 +80,8 @@ public class MLOpenSearchRerankProcessorTests extends OpenSearchTestCase { @Before public void setup() { MockitoAnnotations.openMocks(this); - factory = new RerankProcessorFactory(mlCommonsClientAccessor); + doReturn(Settings.EMPTY).when(clusterService).getSettings(); + factory = new RerankProcessorFactory(mlCommonsClientAccessor, clusterService); Map config = new HashMap<>( Map.of( RerankType.ML_OPENSEARCH.getLabel(), @@ -223,6 +232,51 @@ public void testRerankContext_whenQueryTextPathIsBadPointer_thenFail() throws IO .equals(QueryContextSourceFetcher.QUERY_TEXT_PATH_FIELD + " must point to a string field")); } + @SneakyThrows + public void testRerankContext_whenQueryTextPathIsExceeedinglyManyCharacters_thenFail() { + // "eighteencharacters" * 60 = 1080 character string > max len of 1024 + setupParams(Map.of(QueryContextSourceFetcher.QUERY_TEXT_PATH_FIELD, "eighteencharacters".repeat(60))); + setupSearchResults(); + @SuppressWarnings("unchecked") + ActionListener> listener = mock(ActionListener.class); + processor.generateRerankingContext(request, response, listener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue() instanceof IllegalArgumentException); + assert (argCaptor.getValue() + .getMessage() + .equals( + String.format( + Locale.ROOT, + "%s exceeded the maximum path length of %d characters", + QueryContextSourceFetcher.QUERY_TEXT_PATH_FIELD, + QueryContextSourceFetcher.MAX_QUERY_PATH_STRLEN + ) + )); + } + + @SneakyThrows + public void textRerankContext_whenQueryTextPathIsExceeedinglyDeeplyNested_thenFail() { + setupParams(Map.of(QueryContextSourceFetcher.QUERY_TEXT_PATH_FIELD, "a.b.c.d.e.f.g.h.i.j.k.l.m.n.o.p.q.r.s.t.w.x.y.z")); + setupSearchResults(); + @SuppressWarnings("unchecked") + ActionListener> listener = mock(ActionListener.class); + processor.generateRerankingContext(request, response, listener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue() instanceof IllegalArgumentException); + assert (argCaptor.getValue() + .getMessage() + .equals( + String.format( + Locale.ROOT, + "%s exceeded the maximum path length of %d nested fields", + QueryContextSourceFetcher.QUERY_TEXT_PATH_FIELD, + MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(clusterService.getSettings()) + ) + )); + } + public void testRescoreSearchResponse_HappyPath() throws IOException { setupSimilarityRescoring(); setupSearchResults();