Skip to content

Commit

Permalink
Add a search AD results tool (opensearch-project#52) (opensearch-proj…
Browse files Browse the repository at this point in the history
…ect#67)

(cherry picked from commit 18445e6)

Signed-off-by: Tyler Ohlsen <[email protected]>
Signed-off-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Signed-off-by: yuye-aws <[email protected]>
  • Loading branch information
2 people authored and yuye-aws committed Apr 26, 2024
1 parent 964b2f0 commit 7f383fe
Show file tree
Hide file tree
Showing 2 changed files with 403 additions and 0 deletions.
210 changes: 210 additions & 0 deletions src/main/java/org/opensearch/agent/tools/SearchAnomalyResultsTool.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.agent.tools;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;

import org.apache.commons.lang3.StringUtils;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.ad.client.AnomalyDetectionNodeClient;
import org.opensearch.client.Client;
import org.opensearch.core.action.ActionListener;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.ExistsQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.RangeQueryBuilder;
import org.opensearch.index.query.TermQueryBuilder;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.common.spi.tools.Parser;
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
import org.opensearch.search.SearchHit;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.search.sort.SortOrder;

import lombok.Getter;
import lombok.Setter;

@ToolAnnotation(SearchAnomalyResultsTool.TYPE)
public class SearchAnomalyResultsTool implements Tool {
public static final String TYPE = "SearchAnomalyResultsTool";
private static final String DEFAULT_DESCRIPTION = "Use this tool to search anomaly results.";

@Setter
@Getter
private String name = TYPE;
@Getter
@Setter
private String description = DEFAULT_DESCRIPTION;

@Getter
private String version;

private Client client;

private AnomalyDetectionNodeClient adClient;

@Setter
private Parser<?, ?> inputParser;
@Setter
private Parser<?, ?> outputParser;

public SearchAnomalyResultsTool(Client client) {
this.client = client;
this.adClient = new AnomalyDetectionNodeClient(client);

// probably keep this overridden output parser. need to ensure the output matches what's expected
outputParser = new Parser<>() {
@Override
public Object parse(Object o) {
@SuppressWarnings("unchecked")
List<ModelTensors> mlModelOutputs = (List<ModelTensors>) o;
return mlModelOutputs.get(0).getMlModelTensors().get(0).getDataAsMap().get("response");
}
};
}

// Response is currently in a simple string format including the list of anomaly results (only detector ID, grade, confidence),
// and total # of results. The output will likely need to be updated, standardized, and include more fields in the
// future to cover a sufficient amount of potential questions the agent will need to handle.
@Override
public <T> void run(Map<String, String> parameters, ActionListener<T> listener) {
final String detectorId = parameters.getOrDefault("detectorId", null);
final Boolean realTime = parameters.containsKey("realTime") ? Boolean.parseBoolean(parameters.get("realTime")) : null;
final Double anomalyGradeThreshold = parameters.containsKey("anomalyGradeThreshold")
? Double.parseDouble(parameters.get("anomalyGradeThreshold"))
: null;
final Long dataStartTime = parameters.containsKey("dataStartTime") && StringUtils.isNumeric(parameters.get("dataStartTime"))
? Long.parseLong(parameters.get("dataStartTime"))
: null;
final Long dataEndTime = parameters.containsKey("dataEndTime") && StringUtils.isNumeric(parameters.get("dataEndTime"))
? Long.parseLong(parameters.get("dataEndTime"))
: null;
final String sortOrderStr = parameters.getOrDefault("sortOrder", "asc");
final SortOrder sortOrder = sortOrderStr.equalsIgnoreCase("asc") ? SortOrder.ASC : SortOrder.DESC;
final String sortString = parameters.getOrDefault("sortString", "name.keyword");
final int size = parameters.containsKey("size") ? Integer.parseInt(parameters.get("size")) : 20;
final int startIndex = parameters.containsKey("startIndex") ? Integer.parseInt(parameters.get("startIndex")) : 0;

List<QueryBuilder> mustList = new ArrayList<QueryBuilder>();
if (detectorId != null) {
mustList.add(new TermQueryBuilder("detector_id", detectorId));
}
// We include or exclude the task ID if fetching historical or real-time results, respectively.
// For more details, see https://opensearch.org/docs/latest/observing-your-data/ad/api/#search-detector-result
if (realTime != null) {
BoolQueryBuilder boolQuery = new BoolQueryBuilder();
ExistsQueryBuilder existsQuery = new ExistsQueryBuilder("task_id");
if (realTime) {
boolQuery.mustNot(existsQuery);
} else {
boolQuery.must(existsQuery);
}
mustList.add(boolQuery);
}
if (anomalyGradeThreshold != null) {
mustList.add(new RangeQueryBuilder("anomaly_grade").gte(anomalyGradeThreshold));
}
if (dataStartTime != null || dataEndTime != null) {
RangeQueryBuilder rangeQuery = new RangeQueryBuilder("anomaly_grade");
if (dataStartTime != null) {
rangeQuery.gte(dataStartTime);
}
if (dataEndTime != null) {
rangeQuery.lte(dataEndTime);
}
mustList.add(rangeQuery);
}

BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
boolQueryBuilder.must().addAll(mustList);
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder()
.query(boolQueryBuilder)
.size(size)
.from(startIndex)
.sort(sortString, sortOrder);

SearchRequest searchAnomalyResultsRequest = new SearchRequest().source(searchSourceBuilder);

ActionListener<SearchResponse> searchAnomalyResultsListener = ActionListener.<SearchResponse>wrap(response -> {
StringBuilder sb = new StringBuilder();
SearchHit[] hits = response.getHits().getHits();
sb.append("AnomalyResults=[");
for (SearchHit hit : hits) {
sb.append("{");
sb.append("detectorId=").append(hit.getSourceAsMap().get("detector_id")).append(",");
sb.append("grade=").append(hit.getSourceAsMap().get("anomaly_grade")).append(",");
sb.append("confidence=").append(hit.getSourceAsMap().get("confidence"));
sb.append("}");
}
sb.append("]");
sb.append("TotalAnomalyResults=").append(response.getHits().getTotalHits().value);
listener.onResponse((T) sb.toString());
}, e -> { listener.onFailure(e); });

adClient.searchAnomalyResults(searchAnomalyResultsRequest, searchAnomalyResultsListener);
}

@Override
public boolean validate(Map<String, String> parameters) {
return true;
}

@Override
public String getType() {
return TYPE;
}

/**
* Factory for the {@link SearchAnomalyResultsTool}
*/
public static class Factory implements Tool.Factory<SearchAnomalyResultsTool> {
private Client client;

private AnomalyDetectionNodeClient adClient;

private static Factory INSTANCE;

/**
* Create or return the singleton factory instance
*/
public static Factory getInstance() {
if (INSTANCE != null) {
return INSTANCE;
}
synchronized (SearchAnomalyResultsTool.class) {
if (INSTANCE != null) {
return INSTANCE;
}
INSTANCE = new Factory();
return INSTANCE;
}
}

/**
* Initialize this factory
* @param client The OpenSearch client
*/
public void init(Client client) {
this.client = client;
this.adClient = new AnomalyDetectionNodeClient(client);
}

@Override
public SearchAnomalyResultsTool create(Map<String, Object> map) {
return new SearchAnomalyResultsTool(client);
}

@Override
public String getDefaultDescription() {
return DEFAULT_DESCRIPTION;
}
}

}
Loading

0 comments on commit 7f383fe

Please sign in to comment.