From 4ffc2b7135e9cbf4d9d46c0e27874ec4b29fc23d Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Thu, 28 Dec 2023 03:32:11 +0000 Subject: [PATCH] feat: search index tool (#61) * add search index tool Signed-off-by: yuye-aws * run spotless apply Signed-off-by: yuye-aws * remove unnecessary string util operation Signed-off-by: yuye-aws * add test cases Signed-off-by: yuye-aws * spotless apply Signed-off-by: yuye-aws * update tool description and add model group search Signed-off-by: yuye-aws --------- Signed-off-by: yuye-aws (cherry picked from commit 2e053305ddf806e29c5913f31d4584dba1d8ad77) Signed-off-by: github-actions[bot] --- .../java/org/opensearch/agent/ToolPlugin.java | 2 + .../agent/tools/AbstractRetrieverTool.java | 18 +- .../agent/tools/SearchIndexTool.java | 172 ++++++++++++++++++ .../tools/AbstractRetrieverToolTests.java | 4 +- .../agent/tools/SearchIndexToolTests.java | 169 +++++++++++++++++ .../tools/retrieval_tool_search_response.json | 2 +- 6 files changed, 357 insertions(+), 10 deletions(-) create mode 100644 src/main/java/org/opensearch/agent/tools/SearchIndexTool.java create mode 100644 src/test/java/org/opensearch/agent/tools/SearchIndexToolTests.java diff --git a/src/main/java/org/opensearch/agent/ToolPlugin.java b/src/main/java/org/opensearch/agent/ToolPlugin.java index 5ac1ce57..a411acbb 100644 --- a/src/main/java/org/opensearch/agent/ToolPlugin.java +++ b/src/main/java/org/opensearch/agent/ToolPlugin.java @@ -12,6 +12,7 @@ import org.opensearch.agent.tools.NeuralSparseSearchTool; import org.opensearch.agent.tools.PPLTool; +import org.opensearch.agent.tools.SearchIndexTool; import org.opensearch.agent.tools.VectorDBTool; import org.opensearch.agent.tools.VisualizationsTool; import org.opensearch.client.Client; @@ -60,6 +61,7 @@ public Collection createComponents( VisualizationsTool.Factory.getInstance().init(client); NeuralSparseSearchTool.Factory.getInstance().init(client, xContentRegistry); 
VectorDBTool.Factory.getInstance().init(client, xContentRegistry); + SearchIndexTool.Factory.getInstance().init(client, xContentRegistry); return Collections.emptyList(); } diff --git a/src/main/java/org/opensearch/agent/tools/AbstractRetrieverTool.java b/src/main/java/org/opensearch/agent/tools/AbstractRetrieverTool.java index dba48070..b2a0860c 100644 --- a/src/main/java/org/opensearch/agent/tools/AbstractRetrieverTool.java +++ b/src/main/java/org/opensearch/agent/tools/AbstractRetrieverTool.java @@ -66,6 +66,15 @@ protected AbstractRetrieverTool( protected abstract String getQueryBody(String queryText); + public static Map processResponse(SearchHit hit) { + Map docContent = new HashMap<>(); + docContent.put("_index", hit.getIndex()); + docContent.put("_id", hit.getId()); + docContent.put("_score", hit.getScore()); + docContent.put("_source", hit.getSourceAsMap()); + return docContent; + } + private SearchRequest buildSearchRequest(Map parameters) throws IOException { String question = parameters.get(INPUT_FIELD); if (StringUtils.isBlank(question)) { @@ -98,13 +107,8 @@ public void run(Map parameters, ActionListener listener) if (hits != null && hits.length > 0) { StringBuilder contextBuilder = new StringBuilder(); - for (int i = 0; i < hits.length; i++) { - SearchHit hit = hits[i]; - Map docContent = new HashMap<>(); - docContent.put("_index", hit.getIndex()); - docContent.put("_id", hit.getId()); - docContent.put("_score", hit.getScore()); - docContent.put("_source", hit.getSourceAsMap()); + for (SearchHit hit : hits) { + Map docContent = processResponse(hit); contextBuilder.append(gson.toJson(docContent)).append("\n"); } listener.onResponse((T) contextBuilder.toString()); diff --git a/src/main/java/org/opensearch/agent/tools/SearchIndexTool.java b/src/main/java/org/opensearch/agent/tools/SearchIndexTool.java new file mode 100644 index 00000000..5dd10759 --- /dev/null +++ b/src/main/java/org/opensearch/agent/tools/SearchIndexTool.java @@ -0,0 +1,172 @@ +/* + * 
Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.tools; + +import static org.opensearch.ml.common.CommonValue.*; + +import java.security.AccessController; +import java.security.PrivilegedExceptionAction; +import java.util.Map; +import java.util.Objects; + +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.client.Client; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.spi.tools.Tool; +import org.opensearch.ml.common.spi.tools.ToolAnnotation; +import org.opensearch.ml.common.transport.connector.MLConnectorSearchAction; +import org.opensearch.ml.common.transport.model.MLModelSearchAction; +import org.opensearch.ml.common.transport.model_group.MLModelGroupSearchAction; +import org.opensearch.ml.common.utils.StringUtils; +import org.opensearch.search.SearchHit; +import org.opensearch.search.builder.SearchSourceBuilder; + +import com.google.gson.JsonObject; + +import lombok.Getter; +import lombok.Setter; +import lombok.extern.log4j.Log4j2; + +@Getter +@Setter +@Log4j2 +@ToolAnnotation(SearchIndexTool.TYPE) +public class SearchIndexTool implements Tool { + + public static final String INPUT_FIELD = "input"; + public static final String INDEX_FIELD = "index"; + public static final String QUERY_FIELD = "query"; + + public static final String TYPE = "SearchIndexTool"; + private static final String DEFAULT_DESCRIPTION = + "Use this tool to search an index by providing two parameters: 'index' for the index name, and 'query' for the OpenSearch DSL formatted query."; + + private String name = TYPE; + + private String description = DEFAULT_DESCRIPTION; + + private Client client; + 
+ private NamedXContentRegistry xContentRegistry; + + public SearchIndexTool(Client client, NamedXContentRegistry xContentRegistry) { + this.client = client; + this.xContentRegistry = xContentRegistry; + } + + @Override + public String getType() { + return TYPE; + } + + @Override + public String getVersion() { + return null; + } + + @Override + public boolean validate(Map parameters) { + return parameters != null && parameters.containsKey(INPUT_FIELD) && parameters.get(INPUT_FIELD) != null; + } + + @Override + public void run(Map parameters, ActionListener listener) { + try { + String input = parameters.get(INPUT_FIELD); + JsonObject jsonObject = StringUtils.gson.fromJson(input, JsonObject.class); + String index = jsonObject.get(INDEX_FIELD).getAsString(); + String query = jsonObject.get(QUERY_FIELD).toString(); + query = "{\"query\": " + query + "}"; + + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + XContentParser queryParser = XContentType.JSON + .xContent() + .createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, query); + searchSourceBuilder.parseXContent(queryParser); + SearchRequest searchRequest = new SearchRequest().source(searchSourceBuilder).indices(index); + + ActionListener actionListener = ActionListener.wrap(r -> { + SearchHit[] hits = r.getHits().getHits(); + + if (hits != null && hits.length > 0) { + StringBuilder contextBuilder = new StringBuilder(); + for (SearchHit hit : hits) { + String doc = AccessController.doPrivileged((PrivilegedExceptionAction) () -> { + Map docContent = AbstractRetrieverTool.processResponse(hit); + return StringUtils.gson.toJson(docContent); + }); + contextBuilder.append(doc).append("\n"); + } + listener.onResponse((T) contextBuilder.toString()); + } else { + listener.onResponse((T) ""); + } + }, e -> { + log.error("Failed to search index", e); + listener.onFailure(e); + }); + + // since searching connector and model needs access control, we need + // to forward the request 
to the corresponding transport action + if (Objects.equals(index, ML_CONNECTOR_INDEX)) { + client.execute(MLConnectorSearchAction.INSTANCE, searchRequest, actionListener); + } else if (Objects.equals(index, ML_MODEL_INDEX)) { + client.execute(MLModelSearchAction.INSTANCE, searchRequest, actionListener); + } else if (Objects.equals(index, ML_MODEL_GROUP_INDEX)) { + client.execute(MLModelGroupSearchAction.INSTANCE, searchRequest, actionListener); + } else { + client.search(searchRequest, actionListener); + } + } catch (Exception e) { + log.error("Failed to search index", e); + listener.onFailure(e); + } + } + + public static class Factory implements Tool.Factory { + + private Client client; + private static Factory INSTANCE; + + private NamedXContentRegistry xContentRegistry; + + /** + * Create or return the singleton factory instance + */ + public static Factory getInstance() { + if (INSTANCE != null) { + return INSTANCE; + } + synchronized (SearchIndexTool.class) { + if (INSTANCE != null) { + return INSTANCE; + } + INSTANCE = new Factory(); + return INSTANCE; + } + } + + public void init(Client client, NamedXContentRegistry xContentRegistry) { + this.client = client; + this.xContentRegistry = xContentRegistry; + } + + @Override + public SearchIndexTool create(Map params) { + return new SearchIndexTool(client, xContentRegistry); + } + + @Override + public String getDefaultDescription() { + return DEFAULT_DESCRIPTION; + } + } +} diff --git a/src/test/java/org/opensearch/agent/tools/AbstractRetrieverToolTests.java b/src/test/java/org/opensearch/agent/tools/AbstractRetrieverToolTests.java index 5f5803f0..e55a2d8f 100644 --- a/src/test/java/org/opensearch/agent/tools/AbstractRetrieverToolTests.java +++ b/src/test/java/org/opensearch/agent/tools/AbstractRetrieverToolTests.java @@ -98,7 +98,7 @@ public void testRunAsyncWithSearchResults() { future.join(); assertEquals( "{\"_index\":\"hybrid-index\",\"_source\":{\"passage_text\":\"Company test_mock have a history of 100 
years.\"},\"_id\":\"1\",\"_score\":89.2917}\n" - + "{\"_index\":\"hybrid-index\",\"_source\":{\"passage_text\":\"the price of the api is 2$ per invokation\"},\"_id\":\"2\",\"_score\":0.10702579}\n", + + "{\"_index\":\"hybrid-index\",\"_source\":{\"passage_text\":\"the price of the api is 2$ per invocation\"},\"_id\":\"2\",\"_score\":0.10702579}\n", future.get() ); } @@ -218,7 +218,7 @@ public void testFactory() { // Create a mock object of the abstract Factory class Client client = mock(Client.class); AbstractRetrieverTool.Factory factoryMock = new AbstractRetrieverTool.Factory<>() { - public PPLTool create(Map params) { + public AbstractRetrieverTool create(Map params) { return null; } }; diff --git a/src/test/java/org/opensearch/agent/tools/SearchIndexToolTests.java b/src/test/java/org/opensearch/agent/tools/SearchIndexToolTests.java new file mode 100644 index 00000000..f90d2155 --- /dev/null +++ b/src/test/java/org/opensearch/agent/tools/SearchIndexToolTests.java @@ -0,0 +1,169 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.tools; + +import static org.junit.Assert.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.*; + +import java.io.InputStream; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mockito; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.client.Client; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.json.JsonXContent; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.ParsingException; +import org.opensearch.core.common.Strings; +import org.opensearch.core.xcontent.DeprecationHandler; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import 
org.opensearch.ml.common.transport.connector.MLConnectorSearchAction; +import org.opensearch.ml.common.transport.model.MLModelSearchAction; +import org.opensearch.search.SearchModule; + +import lombok.SneakyThrows; + +public class SearchIndexToolTests { + static public final NamedXContentRegistry TEST_XCONTENT_REGISTRY_FOR_QUERY = new NamedXContentRegistry( + new SearchModule(Settings.EMPTY, List.of()).getNamedXContents() + ); + + private Client client; + + private SearchIndexTool mockedSearchIndexTool; + + private String mockedSearchResponseString; + + @Before + @SneakyThrows + public void setup() { + client = mock(Client.class); + mockedSearchIndexTool = Mockito + .mock( + SearchIndexTool.class, + Mockito.withSettings().useConstructor(client, TEST_XCONTENT_REGISTRY_FOR_QUERY).defaultAnswer(Mockito.CALLS_REAL_METHODS) + ); + + try (InputStream searchResponseIns = SearchIndexTool.class.getResourceAsStream("retrieval_tool_search_response.json")) { + if (searchResponseIns != null) { + mockedSearchResponseString = new String(searchResponseIns.readAllBytes()); + } + } + } + + @Test + @SneakyThrows + public void testGetType() { + String type = mockedSearchIndexTool.getType(); + assertFalse(Strings.isNullOrEmpty(type)); + assertEquals("SearchIndexTool", type); + } + + @Test + @SneakyThrows + public void testValidate() { + Map parameters = Map.of("input", "{}"); + assertTrue(mockedSearchIndexTool.validate(parameters)); + } + + @Test + @SneakyThrows + public void testValidateWithEmptyInput() { + Map parameters = Map.of(); + assertFalse(mockedSearchIndexTool.validate(parameters)); + } + + @Test + public void testRunWithNormalIndex() { + String inputString = "{\"index\": \"test-index\", \"query\": {\"match_all\": {}}}"; + Map parameters = Map.of("input", inputString); + mockedSearchIndexTool.run(parameters, null); + Mockito.verify(client, times(1)).search(any(), any()); + Mockito.verify(client, Mockito.never()).execute(any(), any(), any()); + } + + @Test + public void 
testRunWithConnectorIndex() { + String inputString = "{\"index\": \".plugins-ml-connector\", \"query\": {\"match_all\": {}}}"; + Map parameters = Map.of("input", inputString); + mockedSearchIndexTool.run(parameters, null); + Mockito.verify(client, never()).search(any(), any()); + Mockito.verify(client, times(1)).execute(eq(MLConnectorSearchAction.INSTANCE), any(), any()); + } + + @Test + public void testRunWithModelIndex() { + String inputString = "{\"index\": \".plugins-ml-model\", \"query\": {\"match_all\": {}}}"; + Map parameters = Map.of("input", inputString); + mockedSearchIndexTool.run(parameters, null); + Mockito.verify(client, never()).search(any(), any()); + Mockito.verify(client, times(1)).execute(eq(MLModelSearchAction.INSTANCE), any(), any()); + } + + @Test + @SneakyThrows + public void testRunWithSearchResults() { + SearchResponse mockedSearchResponse = SearchResponse + .fromXContent( + JsonXContent.jsonXContent + .createParser(NamedXContentRegistry.EMPTY, DeprecationHandler.IGNORE_DEPRECATIONS, mockedSearchResponseString) + ); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(mockedSearchResponse); + return null; + }).when(client).search(any(), any()); + + String inputString = "{\"index\": \"test-index\", \"query\": {\"match_all\": {}}}"; + final CompletableFuture future = new CompletableFuture<>(); + ActionListener listener = ActionListener.wrap(r -> { future.complete(r); }, e -> { future.completeExceptionally(e); }); + Map parameters = Map.of("input", inputString); + mockedSearchIndexTool.run(parameters, listener); + + future.join(); + + Mockito.verify(client, times(1)).search(any(), any()); + Mockito.verify(client, Mockito.never()).execute(any(), any(), any()); + } + + @Test + @SneakyThrows + public void testRunWithEmptyQuery() { + String inputString = "{\"index\": \"test_index\"}"; + Map parameters = Map.of("input", inputString); + ActionListener listener = mock(ActionListener.class); + 
mockedSearchIndexTool.run(parameters, listener); + Mockito.verify(client, Mockito.never()).execute(any(), any(), any()); + Mockito.verify(client, Mockito.never()).search(any(), any()); + } + + @Test + public void testRunWithInvalidQuery() { + String inputString = "{\"index\": \"test-index\", \"query\": \"invalid query\"}"; + Map parameters = Map.of("input", inputString); + ActionListener listener = mock(ActionListener.class); + mockedSearchIndexTool.run(parameters, listener); + Mockito.verify(client, Mockito.never()).execute(any(), any(), any()); + Mockito.verify(client, Mockito.never()).search(any(), any()); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(ParsingException.class); + // since error message for ParsingException is different, we only need to expect ParsingException to be thrown + verify(listener).onFailure(argumentCaptor.capture()); + } + + @Test + public void testFactory() { + SearchIndexTool searchIndexTool = SearchIndexTool.Factory.getInstance().create(Collections.emptyMap()); + assertEquals(SearchIndexTool.TYPE, searchIndexTool.getType()); + } +} diff --git a/src/test/resources/org/opensearch/agent/tools/retrieval_tool_search_response.json b/src/test/resources/org/opensearch/agent/tools/retrieval_tool_search_response.json index 7e66dd60..d89ad3b0 100644 --- a/src/test/resources/org/opensearch/agent/tools/retrieval_tool_search_response.json +++ b/src/test/resources/org/opensearch/agent/tools/retrieval_tool_search_response.json @@ -27,7 +27,7 @@ "_id": "2", "_score": 0.10702579, "_source": { - "passage_text": "the price of the api is 2$ per invokation" + "passage_text": "the price of the api is 2$ per invocation" } } ]