diff --git a/src/main/java/org/opensearch/agent/tools/SearchAnomalyResultsTool.java b/src/main/java/org/opensearch/agent/tools/SearchAnomalyResultsTool.java new file mode 100644 index 00000000..ef1a44dd --- /dev/null +++ b/src/main/java/org/opensearch/agent/tools/SearchAnomalyResultsTool.java @@ -0,0 +1,210 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.tools; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +import org.apache.commons.lang3.StringUtils; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.ad.client.AnomalyDetectionNodeClient; +import org.opensearch.client.Client; +import org.opensearch.core.action.ActionListener; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.ExistsQueryBuilder; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.RangeQueryBuilder; +import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.ml.common.spi.tools.Parser; +import org.opensearch.ml.common.spi.tools.Tool; +import org.opensearch.ml.common.spi.tools.ToolAnnotation; +import org.opensearch.search.SearchHit; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.sort.SortOrder; + +import lombok.Getter; +import lombok.Setter; + +@ToolAnnotation(SearchAnomalyResultsTool.TYPE) +public class SearchAnomalyResultsTool implements Tool { + public static final String TYPE = "SearchAnomalyResultsTool"; + private static final String DEFAULT_DESCRIPTION = "Use this tool to search anomaly results."; + + @Setter + @Getter + private String name = TYPE; + @Getter + @Setter + private String description = DEFAULT_DESCRIPTION; + + @Getter + private String version; + + private Client client; + + private AnomalyDetectionNodeClient adClient; + + @Setter + private Parser inputParser; + @Setter + private Parser outputParser; + + public SearchAnomalyResultsTool(Client client) { + this.client = client; + this.adClient = new AnomalyDetectionNodeClient(client); + + // probably keep this overridden output parser. need to ensure the output matches what's expected + outputParser = new Parser<>() { + @Override + public Object parse(Object o) { + @SuppressWarnings("unchecked") + List mlModelOutputs = (List) o; + return mlModelOutputs.get(0).getMlModelTensors().get(0).getDataAsMap().get("response"); + } + }; + } + + // Response is currently in a simple string format including the list of anomaly results (only detector ID, grade, confidence), + // and total # of results. The output will likely need to be updated, standardized, and include more fields in the + // future to cover a sufficient amount of potential questions the agent will need to handle. + @Override + public void run(Map parameters, ActionListener listener) { + final String detectorId = parameters.getOrDefault("detectorId", null); + final Boolean realTime = parameters.containsKey("realTime") ? Boolean.parseBoolean(parameters.get("realTime")) : null; + final Double anomalyGradeThreshold = parameters.containsKey("anomalyGradeThreshold") + ? Double.parseDouble(parameters.get("anomalyGradeThreshold")) + : null; + final Long dataStartTime = parameters.containsKey("dataStartTime") && StringUtils.isNumeric(parameters.get("dataStartTime")) + ? Long.parseLong(parameters.get("dataStartTime")) + : null; + final Long dataEndTime = parameters.containsKey("dataEndTime") && StringUtils.isNumeric(parameters.get("dataEndTime")) + ? Long.parseLong(parameters.get("dataEndTime")) + : null; + final String sortOrderStr = parameters.getOrDefault("sortOrder", "asc"); + final SortOrder sortOrder = sortOrderStr.equalsIgnoreCase("asc") ? SortOrder.ASC : SortOrder.DESC; + final String sortString = parameters.getOrDefault("sortString", "name.keyword"); + final int size = parameters.containsKey("size") ? Integer.parseInt(parameters.get("size")) : 20; + final int startIndex = parameters.containsKey("startIndex") ? Integer.parseInt(parameters.get("startIndex")) : 0; + + List mustList = new ArrayList(); + if (detectorId != null) { + mustList.add(new TermQueryBuilder("detector_id", detectorId)); + } + // We include or exclude the task ID if fetching historical or real-time results, respectively. + // For more details, see https://opensearch.org/docs/latest/observing-your-data/ad/api/#search-detector-result + if (realTime != null) { + BoolQueryBuilder boolQuery = new BoolQueryBuilder(); + ExistsQueryBuilder existsQuery = new ExistsQueryBuilder("task_id"); + if (realTime) { + boolQuery.mustNot(existsQuery); + } else { + boolQuery.must(existsQuery); + } + mustList.add(boolQuery); + } + if (anomalyGradeThreshold != null) { + mustList.add(new RangeQueryBuilder("anomaly_grade").gte(anomalyGradeThreshold)); + } + if (dataStartTime != null || dataEndTime != null) { + RangeQueryBuilder rangeQuery = new RangeQueryBuilder("anomaly_grade"); + if (dataStartTime != null) { + rangeQuery.gte(dataStartTime); + } + if (dataEndTime != null) { + rangeQuery.lte(dataEndTime); + } + mustList.add(rangeQuery); + } + + BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); + boolQueryBuilder.must().addAll(mustList); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder() + .query(boolQueryBuilder) + .size(size) + .from(startIndex) + .sort(sortString, sortOrder); + + SearchRequest searchAnomalyResultsRequest = new SearchRequest().source(searchSourceBuilder); + + ActionListener searchAnomalyResultsListener = ActionListener.wrap(response -> { + StringBuilder sb = new StringBuilder(); + SearchHit[] hits = response.getHits().getHits(); + sb.append("AnomalyResults=["); + for (SearchHit hit : hits) { + sb.append("{"); + sb.append("detectorId=").append(hit.getSourceAsMap().get("detector_id")).append(","); + sb.append("grade=").append(hit.getSourceAsMap().get("anomaly_grade")).append(","); + sb.append("confidence=").append(hit.getSourceAsMap().get("confidence")); + sb.append("}"); + } + sb.append("]"); + sb.append("TotalAnomalyResults=").append(response.getHits().getTotalHits().value); + listener.onResponse((T) sb.toString()); + }, e -> { listener.onFailure(e); }); + + adClient.searchAnomalyResults(searchAnomalyResultsRequest, searchAnomalyResultsListener); + } + + @Override + public boolean validate(Map parameters) { + return true; + } + + @Override + public String getType() { + return TYPE; + } + + /** + * Factory for the {@link SearchAnomalyResultsTool} + */ + public static class Factory implements Tool.Factory { + private Client client; + + private AnomalyDetectionNodeClient adClient; + + private static Factory INSTANCE; + + /** + * Create or return the singleton factory instance + */ + public static Factory getInstance() { + if (INSTANCE != null) { + return INSTANCE; + } + synchronized (SearchAnomalyResultsTool.class) { + if (INSTANCE != null) { + return INSTANCE; + } + INSTANCE = new Factory(); + return INSTANCE; + } + } + + /** + * Initialize this factory + * @param client The OpenSearch client + */ + public void init(Client client) { + this.client = client; + this.adClient = new AnomalyDetectionNodeClient(client); + } + + @Override + public SearchAnomalyResultsTool create(Map map) { + return new SearchAnomalyResultsTool(client); + } + + @Override + public String getDefaultDescription() { + return DEFAULT_DESCRIPTION; + } + } + +} diff --git a/src/test/java/org/opensearch/agent/tools/SearchAnomalyResultsToolTests.java b/src/test/java/org/opensearch/agent/tools/SearchAnomalyResultsToolTests.java new file mode 100644 index 00000000..c9d83de2 --- /dev/null +++ b/src/test/java/org/opensearch/agent/tools/SearchAnomalyResultsToolTests.java @@ -0,0 +1,193 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.tools; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.Locale; +import java.util.Map; + +import org.apache.lucene.search.TotalHits; +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.ActionType; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.SearchResponseSections; +import org.opensearch.client.AdminClient; +import org.opensearch.client.ClusterAdminClient; +import org.opensearch.client.IndicesAdminClient; +import org.opensearch.client.node.NodeClient; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.spi.tools.Tool; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.search.aggregations.Aggregations; + +public class SearchAnomalyResultsToolTests { + @Mock + private NodeClient nodeClient; + @Mock + private AdminClient adminClient; + @Mock + private IndicesAdminClient indicesAdminClient; + @Mock + private ClusterAdminClient clusterAdminClient; + + private Map nullParams; + private Map emptyParams; + private Map nonEmptyParams; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + SearchAnomalyResultsTool.Factory.getInstance().init(nodeClient); + + nullParams = null; + emptyParams = Collections.emptyMap(); + nonEmptyParams = Map.of("detectorId", "foo"); + } + + @Test + public void testParseParams() throws Exception { + Tool tool = SearchAnomalyResultsTool.Factory.getInstance().create(Collections.emptyMap()); + Map validParams = new HashMap(); + validParams.put("detectorId", "foo"); + validParams.put("realTime", "true"); + validParams.put("anomalyGradethreshold", "-1"); + validParams.put("dataStartTime", "1234"); + validParams.put("dataEndTime", "5678"); + validParams.put("sortOrder", "AsC"); + validParams.put("sortString", "foo.bar"); + validParams.put("size", "10"); + validParams.put("startIndex", "0"); + + @SuppressWarnings("unchecked") + ActionListener listener = Mockito.mock(ActionListener.class); + assertDoesNotThrow(() -> tool.run(validParams, listener)); + } + + @Test + public void testRunWithInvalidAnomalyGradeParam() throws Exception { + Tool tool = SearchAnomalyResultsTool.Factory.getInstance().create(Collections.emptyMap()); + + @SuppressWarnings("unchecked") + ActionListener listener = Mockito.mock(ActionListener.class); + assertThrows(NumberFormatException.class, () -> tool.run(Map.of("anomalyGradeThreshold", "foo"), listener)); + } + + @Test + public void testRunWithNoResults() throws Exception { + Tool tool = SearchAnomalyResultsTool.Factory.getInstance().create(Collections.emptyMap()); + + SearchHit[] hits = new SearchHit[0]; + + TotalHits totalHits = new TotalHits(hits.length, TotalHits.Relation.EQUAL_TO); + + SearchResponse getResultsResponse = new SearchResponse( + new SearchResponseSections(new SearchHits(hits, totalHits, 0), new Aggregations(new ArrayList<>()), null, false, null, null, 0), + null, + 0, + 0, + 0, + 0, + null, + null + ); + String expectedResponseStr = String.format(Locale.getDefault(), "AnomalyResults=[]TotalAnomalyResults=%d", hits.length); + + @SuppressWarnings("unchecked") + ActionListener listener = Mockito.mock(ActionListener.class); + + doAnswer((invocation) -> { + ActionListener responseListener = invocation.getArgument(2); + responseListener.onResponse(getResultsResponse); + return null; + }).when(nodeClient).execute(any(ActionType.class), any(), any()); + + tool.run(emptyParams, listener); + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(String.class); + verify(listener, times(1)).onResponse(responseCaptor.capture()); + assertEquals(expectedResponseStr, responseCaptor.getValue()); + } + + @Test + public void testRunWithSingleResult() throws Exception { + final String detectorId = "detector-1-id"; + final double anomalyGrade = 0.5; + final double confidence = 0.9; + Tool tool = SearchAnomalyResultsTool.Factory.getInstance().create(Collections.emptyMap()); + + XContentBuilder content = XContentBuilder.builder(XContentType.JSON.xContent()); + content.startObject(); + content.field("detector_id", detectorId); + content.field("anomaly_grade", anomalyGrade); + content.field("confidence", confidence); + content.endObject(); + SearchHit[] hits = new SearchHit[1]; + hits[0] = new SearchHit(0, detectorId, null, null).sourceRef(BytesReference.bytes(content)); + + TotalHits totalHits = new TotalHits(hits.length, TotalHits.Relation.EQUAL_TO); + + SearchResponse getResultsResponse = new SearchResponse( + new SearchResponseSections(new SearchHits(hits, totalHits, 0), new Aggregations(new ArrayList<>()), null, false, null, null, 0), + null, + 0, + 0, + 0, + 0, + null, + null + ); + String expectedResponseStr = String + .format( + "AnomalyResults=[{detectorId=%s,grade=%2.1f,confidence=%2.1f}]TotalAnomalyResults=%d", + detectorId, + anomalyGrade, + confidence, + hits.length + ); + + @SuppressWarnings("unchecked") + ActionListener listener = Mockito.mock(ActionListener.class); + + doAnswer((invocation) -> { + ActionListener responseListener = invocation.getArgument(2); + responseListener.onResponse(getResultsResponse); + return null; + }).when(nodeClient).execute(any(ActionType.class), any(), any()); + + tool.run(emptyParams, listener); + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(String.class); + verify(listener, times(1)).onResponse(responseCaptor.capture()); + assertEquals(expectedResponseStr, responseCaptor.getValue()); + } + + @Test + public void testValidate() { + Tool tool = SearchAnomalyResultsTool.Factory.getInstance().create(Collections.emptyMap()); + assertEquals(SearchAnomalyResultsTool.TYPE, tool.getType()); + assertTrue(tool.validate(emptyParams)); + assertTrue(tool.validate(nonEmptyParams)); + assertTrue(tool.validate(nullParams)); + } +}