From b8890a89b30336fe85077216e696d9284d086524 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Wed, 27 Dec 2023 02:16:02 +0000 Subject: [PATCH] Add visualization tool (#41) * Visualization Tool Signed-off-by: Hailong Cui * fix build failure due to forbiddenApis Signed-off-by: Hailong Cui * Address review comments Signed-off-by: Hailong Cui * spotlessApply Signed-off-by: Hailong Cui * update default tool name Signed-off-by: Hailong Cui * update number of visualization be dynamic Signed-off-by: Hailong Cui --------- Signed-off-by: Hailong Cui (cherry picked from commit 3774eb9e477d676b332c469dffc69720b9088d2b) Signed-off-by: github-actions[bot] --- build.gradle | 8 + .../java/org/opensearch/agent/ToolPlugin.java | 10 +- .../agent/tools/VisualizationsTool.java | 171 ++++++++++++++++++ .../agent/tools/VisualizationsToolTests.java | 161 +++++++++++++++++ .../opensearch/agent/tools/visualization.json | 58 ++++++ .../agent/tools/visualization_not_found.json | 18 ++ 6 files changed, 425 insertions(+), 1 deletion(-) create mode 100644 src/main/java/org/opensearch/agent/tools/VisualizationsTool.java create mode 100644 src/test/java/org/opensearch/agent/tools/VisualizationsToolTests.java create mode 100644 src/test/resources/org/opensearch/agent/tools/visualization.json create mode 100644 src/test/resources/org/opensearch/agent/tools/visualization_not_found.json diff --git a/build.gradle b/build.gradle index 48bf0ecd..62debb6f 100644 --- a/build.gradle +++ b/build.gradle @@ -179,6 +179,14 @@ test { systemProperty 'tests.security.manager', 'false' } +jacocoTestReport { + dependsOn test + reports { + html.required = true // human readable + xml.required = true // for coverlay + } +} + spotless { if (JavaVersion.current() >= JavaVersion.VERSION_17) { // Spotless configuration for Java files diff --git a/src/main/java/org/opensearch/agent/ToolPlugin.java b/src/main/java/org/opensearch/agent/ToolPlugin.java index 8e3d0844..5ac1ce57 100644 --- a/src/main/java/org/opensearch/agent/ToolPlugin.java +++ b/src/main/java/org/opensearch/agent/ToolPlugin.java @@ -13,6 +13,7 @@ import org.opensearch.agent.tools.NeuralSparseSearchTool; import org.opensearch.agent.tools.PPLTool; import org.opensearch.agent.tools.VectorDBTool; +import org.opensearch.agent.tools.VisualizationsTool; import org.opensearch.client.Client; import org.opensearch.cluster.metadata.IndexNameExpressionResolver; import org.opensearch.cluster.service.ClusterService; @@ -56,6 +57,7 @@ public Collection createComponents( this.xContentRegistry = xContentRegistry; PPLTool.Factory.getInstance().init(client); + VisualizationsTool.Factory.getInstance().init(client); NeuralSparseSearchTool.Factory.getInstance().init(client, xContentRegistry); VectorDBTool.Factory.getInstance().init(client, xContentRegistry); return Collections.emptyList(); @@ -63,6 +65,12 @@ public Collection createComponents( @Override public List> getToolFactories() { - return List.of(PPLTool.Factory.getInstance(), NeuralSparseSearchTool.Factory.getInstance(), VectorDBTool.Factory.getInstance()); + return List + .of( + PPLTool.Factory.getInstance(), + NeuralSparseSearchTool.Factory.getInstance(), + VectorDBTool.Factory.getInstance(), + VisualizationsTool.Factory.getInstance() + ); } } diff --git a/src/main/java/org/opensearch/agent/tools/VisualizationsTool.java b/src/main/java/org/opensearch/agent/tools/VisualizationsTool.java new file mode 100644 index 00000000..31f5cf09 --- /dev/null +++ b/src/main/java/org/opensearch/agent/tools/VisualizationsTool.java @@ -0,0 +1,171 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.tools; + +import java.util.Arrays; +import java.util.Locale; +import java.util.Map; +import java.util.Optional; + +import org.opensearch.ExceptionsHelper; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.client.Client; +import org.opensearch.client.Requests; +import org.opensearch.core.action.ActionListener; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.ml.common.spi.tools.Tool; +import org.opensearch.ml.common.spi.tools.ToolAnnotation; +import org.opensearch.search.SearchHits; +import org.opensearch.search.builder.SearchSourceBuilder; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Strings; + +import lombok.Builder; +import lombok.Getter; +import lombok.Setter; +import lombok.extern.log4j.Log4j2; + +@Log4j2 +@ToolAnnotation(VisualizationsTool.TYPE) +public class VisualizationsTool implements Tool { + public static final String NAME = "FindVisualizations"; + public static final String TYPE = "VisualizationTool"; + public static final String VERSION = "v1.0"; + + public static final String SAVED_OBJECT_TYPE = "visualization"; + + /** + * default number of visualizations returned + */ + private static final int DEFAULT_SIZE = 3; + private static final String DEFAULT_DESCRIPTION = + "Use this tool to find user created visualizations. This tool takes the visualization name as input and returns matching visualizations"; + @Setter + @Getter + private String description = DEFAULT_DESCRIPTION; + + @Getter + @Setter + private String name = NAME; + @Getter + @Setter + private String type = TYPE; + @Getter + private final String version = VERSION; + private final Client client; + @Getter + private final String index; + @Getter + private final int size; + + @Builder + public VisualizationsTool(Client client, String index, int size) { + this.client = client; + this.index = index; + this.size = size; + } + + @Override + public void run(Map parameters, ActionListener listener) { + BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery(); + boolQueryBuilder.must().add(QueryBuilders.termQuery("type", SAVED_OBJECT_TYPE)); + boolQueryBuilder.must().add(QueryBuilders.matchQuery(SAVED_OBJECT_TYPE + ".title", parameters.get("input"))); + + SearchSourceBuilder searchSourceBuilder = SearchSourceBuilder.searchSource().query(boolQueryBuilder); + searchSourceBuilder.from(0).size(3); + SearchRequest searchRequest = Requests.searchRequest(index).source(searchSourceBuilder); + + client.search(searchRequest, new ActionListener<>() { + @Override + public void onResponse(SearchResponse searchResponse) { + SearchHits hits = searchResponse.getHits(); + StringBuilder visBuilder = new StringBuilder(); + visBuilder.append("Title,Id\n"); + if (hits.getTotalHits().value > 0) { + Arrays.stream(hits.getHits()).forEach(h -> { + String id = trimIdPrefix(h.getId()); + Map visMap = (Map) h.getSourceAsMap().get(SAVED_OBJECT_TYPE); + String title = visMap.get("title"); + visBuilder.append(String.format(Locale.ROOT, "%s,%s\n", title, id)); + }); + + listener.onResponse((T) visBuilder.toString()); + } else { + listener.onResponse((T) "No Visualization found"); + } + } + + @Override + public void onFailure(Exception e) { + if (ExceptionsHelper.unwrapCause(e) instanceof IndexNotFoundException) { + listener.onResponse((T) "No Visualization found"); + } else { + listener.onFailure(e); + } + } + }); + } + + @VisibleForTesting + String trimIdPrefix(String id) { + id = Optional.ofNullable(id).orElse(""); + if (id.startsWith(SAVED_OBJECT_TYPE)) { + String prefix = String.format(Locale.ROOT, "%s:", SAVED_OBJECT_TYPE); + return id.substring(prefix.length()); + } + return id; + } + + @Override + public boolean validate(Map parameters) { + return parameters.containsKey("input") && !Strings.isNullOrEmpty(parameters.get("input")); + } + + public static class Factory implements Tool.Factory { + private Client client; + + private static VisualizationsTool.Factory INSTANCE; + + public static VisualizationsTool.Factory getInstance() { + if (INSTANCE != null) { + return INSTANCE; + } + synchronized (VisualizationsTool.class) { + if (INSTANCE != null) { + return INSTANCE; + } + INSTANCE = new VisualizationsTool.Factory(); + return INSTANCE; + } + } + + public void init(Client client) { + this.client = client; + } + + @Override + public VisualizationsTool create(Map params) { + String index = params.get("index") == null ? ".kibana" : (String) params.get("index"); + String sizeStr = params.get("size") == null ? "3" : (String) params.get("size"); + int size; + try { + size = Integer.parseInt(sizeStr); + } catch (NumberFormatException ignored) { + size = DEFAULT_SIZE; + } + return VisualizationsTool.builder().client(client).index(index).size(size).build(); + } + + @Override + public String getDefaultDescription() { + return DEFAULT_DESCRIPTION; + } + } +} diff --git a/src/test/java/org/opensearch/agent/tools/VisualizationsToolTests.java b/src/test/java/org/opensearch/agent/tools/VisualizationsToolTests.java new file mode 100644 index 00000000..9cd79ff9 --- /dev/null +++ b/src/test/java/org/opensearch/agent/tools/VisualizationsToolTests.java @@ -0,0 +1,161 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.tools; + +import static org.junit.Assert.assertEquals; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.charset.StandardCharsets; +import java.util.Collections; +import java.util.Map; +import java.util.concurrent.CompletableFuture; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.ArgumentMatchers; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.client.Client; +import org.opensearch.common.xcontent.json.JsonXContent; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.DeprecationHandler; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.ml.common.spi.tools.Tool; + +public class VisualizationsToolTests { + @Mock + private Client client; + + private String searchResponse = "{}"; + private String searchResponseNotFound = "{}"; + + @Before + public void setup() throws IOException { + MockitoAnnotations.openMocks(this); + VisualizationsTool.Factory.getInstance().init(client); + try (InputStream searchResponseIns = VisualizationsToolTests.class.getResourceAsStream("visualization.json")) { + if (searchResponseIns != null) { + searchResponse = new String(searchResponseIns.readAllBytes(), StandardCharsets.UTF_8); + } + } + try (InputStream searchResponseIns = VisualizationsToolTests.class.getResourceAsStream("visualization_not_found.json")) { + if (searchResponseIns != null) { + searchResponseNotFound = new String(searchResponseIns.readAllBytes(), StandardCharsets.UTF_8); + } + } + } + + @Test + public void testToolIndexName() { + VisualizationsTool tool1 = VisualizationsTool.Factory.getInstance().create(Collections.emptyMap()); + assertEquals(tool1.getIndex(), ".kibana"); + + VisualizationsTool tool2 = VisualizationsTool.Factory.getInstance().create(Map.of("index", "test-index")); + assertEquals(tool2.getIndex(), "test-index"); + } + + @Test + public void testNumberOfVisualizationReturned() { + VisualizationsTool tool1 = VisualizationsTool.Factory.getInstance().create(Collections.emptyMap()); + assertEquals(tool1.getSize(), 3); + + VisualizationsTool tool2 = VisualizationsTool.Factory.getInstance().create(Map.of("size", "1")); + assertEquals(tool2.getSize(), 1); + + VisualizationsTool tool3 = VisualizationsTool.Factory.getInstance().create(Map.of("size", "badString")); + assertEquals(tool3.getSize(), 3); + } + + @Test + public void testTrimPrefix() { + VisualizationsTool tool = VisualizationsTool.Factory.getInstance().create(Collections.emptyMap()); + assertEquals(tool.trimIdPrefix(null), ""); + assertEquals(tool.trimIdPrefix("abc"), "abc"); + assertEquals(tool.trimIdPrefix("visualization:abc"), "abc"); + } + + @Test + public void testParameterValidation() { + VisualizationsTool tool = VisualizationsTool.Factory.getInstance().create(Collections.emptyMap()); + Assert.assertFalse(tool.validate(Collections.emptyMap())); + Assert.assertFalse(tool.validate(Map.of("input", ""))); + Assert.assertTrue(tool.validate(Map.of("input", "question"))); + } + + @Test + public void testRunToolWithVisualizationFound() throws Exception { + Tool tool = VisualizationsTool.Factory.getInstance().create(Collections.emptyMap()); + final CompletableFuture future = new CompletableFuture<>(); + ActionListener listener = ActionListener.wrap(future::complete, future::completeExceptionally); + + ArgumentCaptor> searchResponseListener = ArgumentCaptor.forClass(ActionListener.class); + Mockito.doNothing().when(client).search(ArgumentMatchers.any(SearchRequest.class), searchResponseListener.capture()); + + Map params = Map.of("input", "Sales by gender"); + + tool.run(params, listener); + + SearchResponse response = SearchResponse + .fromXContent( + JsonXContent.jsonXContent.createParser(NamedXContentRegistry.EMPTY, DeprecationHandler.IGNORE_DEPRECATIONS, searchResponse) + ); + searchResponseListener.getValue().onResponse(response); + + future.join(); + assertEquals("Title,Id\n[Ecommerce]Sales by gender,aeb212e0-4c84-11e8-b3d7-01146121b73d\n", future.get()); + } + + @Test + public void testRunToolWithNoVisualizationFound() throws Exception { + Tool tool = VisualizationsTool.Factory.getInstance().create(Collections.emptyMap()); + final CompletableFuture future = new CompletableFuture<>(); + ActionListener listener = ActionListener.wrap(future::complete, future::completeExceptionally); + + ArgumentCaptor> searchResponseListener = ArgumentCaptor.forClass(ActionListener.class); + Mockito.doNothing().when(client).search(ArgumentMatchers.any(SearchRequest.class), searchResponseListener.capture()); + + Map params = Map.of("input", "Sales by gender"); + + tool.run(params, listener); + + SearchResponse response = SearchResponse + .fromXContent( + JsonXContent.jsonXContent + .createParser(NamedXContentRegistry.EMPTY, DeprecationHandler.IGNORE_DEPRECATIONS, searchResponseNotFound) + ); + searchResponseListener.getValue().onResponse(response); + + future.join(); + assertEquals("No Visualization found", future.get()); + } + + @Test + public void testRunToolWithIndexNotExists() throws Exception { + Tool tool = VisualizationsTool.Factory.getInstance().create(Collections.emptyMap()); + final CompletableFuture future = new CompletableFuture<>(); + ActionListener listener = ActionListener.wrap(future::complete, future::completeExceptionally); + + ArgumentCaptor> searchResponseListener = ArgumentCaptor.forClass(ActionListener.class); + Mockito.doNothing().when(client).search(ArgumentMatchers.any(SearchRequest.class), searchResponseListener.capture()); + + Map params = Map.of("input", "Sales by gender"); + + tool.run(params, listener); + + IndexNotFoundException notFoundException = new IndexNotFoundException("test-index"); + searchResponseListener.getValue().onFailure(notFoundException); + + future.join(); + assertEquals("No Visualization found", future.get()); + } +} diff --git a/src/test/resources/org/opensearch/agent/tools/visualization.json b/src/test/resources/org/opensearch/agent/tools/visualization.json new file mode 100644 index 00000000..8901706e --- /dev/null +++ b/src/test/resources/org/opensearch/agent/tools/visualization.json @@ -0,0 +1,58 @@ +{ + "took": 4, + "timed_out": false, + "_shards": { + "total": 1, + "successful": 1, + "skipped": 0, + "failed": 0 + }, + "hits": { + "total": { + "value": 1, + "relation": "eq" + }, + "max_score": 0.2847877, + "hits": [ + { + "_index": ".kibana_1", + "_id": "visualization:aeb212e0-4c84-11e8-b3d7-01146121b73d", + "_score": 0.2847877, + "_source": { + "visualization": { + "title": "[Ecommerce]Sales by gender", + "visState": "", + "uiStateJSON": "{}", + "description": "", + "version": 1, + "kibanaSavedObjectMeta": { + "searchSourceJSON": "{}" + } + }, + "type": "visualization", + "references": [ + { + "name": "control_0_index_pattern", + "type": "index-pattern", + "id": "d3d7af60-4c81-11e8-b3d7-01146121b73d" + }, + { + "name": "control_1_index_pattern", + "type": "index-pattern", + "id": "d3d7af60-4c81-11e8-b3d7-01146121b73d" + }, + { + "name": "control_2_index_pattern", + "type": "index-pattern", + "id": "d3d7af60-4c81-11e8-b3d7-01146121b73d" + } + ], + "migrationVersion": { + "visualization": "7.10.0" + }, + "updated_at": "2023-11-10T02:50:24.881Z" + } + } + ] + } +} diff --git a/src/test/resources/org/opensearch/agent/tools/visualization_not_found.json b/src/test/resources/org/opensearch/agent/tools/visualization_not_found.json new file mode 100644 index 00000000..40a0e9d3 --- /dev/null +++ b/src/test/resources/org/opensearch/agent/tools/visualization_not_found.json @@ -0,0 +1,18 @@ +{ + "took": 1, + "timed_out": false, + "_shards": { + "total": 1, + "successful": 1, + "skipped": 0, + "failed": 0 + }, + "hits": { + "total": { + "value": 0, + "relation": "eq" + }, + "max_score": null, + "hits": [] + } +}