diff --git a/src/main/java/org/opensearch/agent/ToolPlugin.java b/src/main/java/org/opensearch/agent/ToolPlugin.java index e1aaa046..823320f0 100644 --- a/src/main/java/org/opensearch/agent/ToolPlugin.java +++ b/src/main/java/org/opensearch/agent/ToolPlugin.java @@ -12,6 +12,7 @@ import org.opensearch.agent.common.SkillSettings; import org.opensearch.agent.tools.CreateAlertTool; +import org.opensearch.agent.tools.CreateAnomalyDetectorTool; import org.opensearch.agent.tools.NeuralSparseSearchTool; import org.opensearch.agent.tools.PPLTool; import org.opensearch.agent.tools.RAGTool; @@ -75,6 +76,7 @@ public Collection createComponents( SearchAnomalyResultsTool.Factory.getInstance().init(client, namedWriteableRegistry); SearchMonitorsTool.Factory.getInstance().init(client); CreateAlertTool.Factory.getInstance().init(client); + CreateAnomalyDetectorTool.Factory.getInstance().init(client); return Collections.emptyList(); } @@ -90,7 +92,8 @@ public List> getToolFactories() { SearchAnomalyDetectorsTool.Factory.getInstance(), SearchAnomalyResultsTool.Factory.getInstance(), SearchMonitorsTool.Factory.getInstance(), - CreateAlertTool.Factory.getInstance() + CreateAlertTool.Factory.getInstance(), + CreateAnomalyDetectorTool.Factory.getInstance() ); } diff --git a/src/main/java/org/opensearch/agent/tools/CreateAnomalyDetectorTool.java b/src/main/java/org/opensearch/agent/tools/CreateAnomalyDetectorTool.java new file mode 100644 index 00000000..8d4bdb7e --- /dev/null +++ b/src/main/java/org/opensearch/agent/tools/CreateAnomalyDetectorTool.java @@ -0,0 +1,451 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.tools; + +import static org.opensearch.ml.common.utils.StringUtils.gson; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.charset.StandardCharsets; +import java.security.AccessController; +import java.security.PrivilegedExceptionAction; +import java.util.Collections; +import 
java.util.HashMap; +import java.util.HashSet; +import java.util.Locale; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.StringJoiner; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import java.util.stream.Collectors; + +import org.apache.commons.text.StringSubstitutor; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.admin.indices.mapping.get.GetMappingsRequest; +import org.opensearch.agent.tools.utils.ToolHelper; +import org.opensearch.client.Client; +import org.opensearch.cluster.metadata.MappingMetadata; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.output.model.ModelTensor; +import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.ml.common.spi.tools.Tool; +import org.opensearch.ml.common.spi.tools.ToolAnnotation; +import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; +import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; + +import com.google.common.collect.ImmutableMap; + +import joptsimple.internal.Strings; +import lombok.Getter; +import lombok.Setter; +import lombok.extern.log4j.Log4j2; + +/** + * A tool used to help creating anomaly detector, the only one input parameter is the index name, this tool will get the mappings of the index + * in flight and let LLM give the suggested category field, aggregation field and correspond aggregation method which are required for the create + * anomaly detector API, the output of this tool is like: + *{ + * "index": "opensearch_dashboards_sample_data_ecommerce", + * "categoryField": "geoip.country_iso_code", + * "aggregationField": "total_quantity,total_unique_products,taxful_total_price", + 
* "aggregationMethod": "sum,count,sum", + * "dateFields": "customer_birth_date,order_date,products.created_on" + * } + */ +@Log4j2 +@Setter +@Getter +@ToolAnnotation(CreateAnomalyDetectorTool.TYPE) +public class CreateAnomalyDetectorTool implements Tool { + // the type of this tool + public static final String TYPE = "CreateAnomalyDetectorTool"; + + // the default description of this tool + private static final String DEFAULT_DESCRIPTION = + "This is a tool used to help creating anomaly detector. It takes a required argument which is the name of the index, extract the index mappings and let the LLM to give the suggested aggregation field, aggregation method, category field and the date field which are required to create an anomaly detector."; + // the regex used to extract the key information from the response of LLM + private static final String EXTRACT_INFORMATION_REGEX = + "(?s).*\\{category_field=([^|]*)\\|aggregation_field=([^|]*)\\|aggregation_method=([^}]*)}.*"; + // valid field types which support aggregation + private static final Set VALID_FIELD_TYPES = Set + .of( + "keyword", + "constant_keyword", + "wildcard", + "long", + "integer", + "short", + "byte", + "double", + "float", + "half_float", + "scaled_float", + "unsigned_long", + "ip" + ); + // the index name key in the output + private static final String OUTPUT_KEY_INDEX = "index"; + // the category field key in the output + private static final String OUTPUT_KEY_CATEGORY_FIELD = "categoryField"; + // the aggregation field key in the output + private static final String OUTPUT_KEY_AGGREGATION_FIELD = "aggregationField"; + // the aggregation method name key in the output + private static final String OUTPUT_KEY_AGGREGATION_METHOD = "aggregationMethod"; + // the date fields key in the output + private static final String OUTPUT_KEY_DATE_FIELDS = "dateFields"; + // the default prompt dictionary, includes claude and openai + private static final Map DEFAULT_PROMPT_DICT = loadDefaultPromptFromFile(); + // 
the name of this tool + @Setter + @Getter + private String name = TYPE; + // the description of this tool + @Getter + @Setter + private String description = DEFAULT_DESCRIPTION; + + // the version of this tool + @Getter + private String version; + + // the OpenSearch transport client + private Client client; + // the mode id of LLM + private String modelId; + // LLM model type, CLAUDE or OPENAI + private ModelType modelType; + // the default prompt for creating anomaly detector + private String contextPrompt; + + enum ModelType { + CLAUDE, + OPENAI; + + public static ModelType from(String value) { + return valueOf(value.toUpperCase(Locale.ROOT)); + } + + } + + /** + * + * @param client the OpenSearch transport client + * @param modelId the model ID of LLM + */ + public CreateAnomalyDetectorTool(Client client, String modelId, String modelType) { + this.client = client; + this.modelId = modelId; + if (!ModelType.OPENAI.toString().equalsIgnoreCase(modelType) && !ModelType.CLAUDE.toString().equalsIgnoreCase(modelType)) { + throw new IllegalArgumentException("Unsupported model_type: " + modelType); + } + this.modelType = ModelType.from(modelType); + this.contextPrompt = DEFAULT_PROMPT_DICT.getOrDefault(this.modelType.toString(), ""); + } + + /** + * The main running method of this tool + * @param parameters the input parameters + * @param listener the action listener + * + */ + @Override + public void run(Map parameters, ActionListener listener) { + Map enrichedParameters = enrichParameters(parameters); + String indexName = enrichedParameters.get("index"); + if (Strings.isNullOrEmpty(indexName)) { + throw new IllegalArgumentException( + "Return this final answer to human directly and do not use other tools: 'Please provide index name'. Please try to directly send this message to human to ask for index name" + ); + } + if (indexName.startsWith(".")) { + throw new IllegalArgumentException( + "CreateAnomalyDetectionTool doesn't support searching indices starting with '.' 
since it could be system index, current searching index name: " + + indexName + ); + } + + GetMappingsRequest getMappingsRequest = new GetMappingsRequest().indices(indexName); + client.admin().indices().getMappings(getMappingsRequest, ActionListener.wrap(response -> { + Map mappings = response.getMappings(); + if (mappings.size() == 0) { + throw new IllegalArgumentException("No mapping found for the index: " + indexName); + } + + MappingMetadata mappingMetadata; + // when the index name is wildcard pattern, we fetch the mappings of the first index + if (indexName.contains("*")) { + mappingMetadata = mappings.get((String) mappings.keySet().toArray()[0]); + } else { + mappingMetadata = mappings.get(indexName); + } + + Map mappingSource = (Map) mappingMetadata.getSourceAsMap().get("properties"); + if (Objects.isNull(mappingSource)) { + throw new IllegalArgumentException( + "The index " + indexName + " doesn't have mapping metadata, please add data to it or using another index." + ); + } + + // flatten all the fields in the mapping + Map fieldsToType = new HashMap<>(); + ToolHelper.extractFieldNamesTypes(mappingSource, fieldsToType, ""); + + // find all date type fields from the mapping + final Set dateFields = findDateTypeFields(fieldsToType); + if (dateFields.isEmpty()) { + throw new IllegalArgumentException( + "The index " + indexName + " doesn't have date type fields, cannot create an anomaly detector for it." 
+ ); + } + StringJoiner dateFieldsJoiner = new StringJoiner(","); + dateFields.forEach(dateFieldsJoiner::add); + + // filter the mapping to improve the accuracy of the result + // only fields support aggregation can be existed in the mapping and sent to LLM + Map filteredMapping = fieldsToType + .entrySet() + .stream() + .filter(entry -> VALID_FIELD_TYPES.contains(entry.getValue())) + .collect(Collectors.toUnmodifiableMap(Map.Entry::getKey, Map.Entry::getValue)); + + // construct the prompt + String prompt = constructPrompt(filteredMapping, indexName); + RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet + .builder() + .parameters(Collections.singletonMap("prompt", prompt)) + .build(); + ActionRequest request = new MLPredictionTaskRequest( + modelId, + MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build(), + null + ); + + client.execute(MLPredictionTaskAction.INSTANCE, request, ActionListener.wrap(mlTaskResponse -> { + ModelTensorOutput modelTensorOutput = (ModelTensorOutput) mlTaskResponse.getOutput(); + ModelTensors modelTensors = modelTensorOutput.getMlModelOutputs().get(0); + ModelTensor modelTensor = modelTensors.getMlModelTensors().get(0); + Map dataAsMap = (Map) modelTensor.getDataAsMap(); + if (dataAsMap == null) { + listener.onFailure(new IllegalStateException("Remote endpoint fails to inference.")); + return; + } + String finalResponse = dataAsMap.get("response"); + if (Strings.isNullOrEmpty(finalResponse)) { + listener.onFailure(new IllegalStateException("Remote endpoint fails to inference, no response found.")); + return; + } + + // use regex pattern to extract the suggested parameters for the create anomaly detector API + Pattern pattern = Pattern.compile(EXTRACT_INFORMATION_REGEX); + Matcher matcher = pattern.matcher(finalResponse); + if (!matcher.matches()) { + log + .error( + "The inference result from remote endpoint is not valid because the result: [" + + finalResponse + + "] cannot match the 
regex: " + + EXTRACT_INFORMATION_REGEX + ); + listener + .onFailure( + new IllegalStateException( + "The inference result from remote endpoint is not valid, cannot extract the key information from the result." + ) + ); + return; + } + + // remove double quotes or whitespace if exists + String categoryField = matcher.group(1).replaceAll("\"", "").strip(); + String aggregationField = matcher.group(2).replaceAll("\"", "").strip(); + String aggregationMethod = matcher.group(3).replaceAll("\"", "").strip(); + + Map result = ImmutableMap + .of( + OUTPUT_KEY_INDEX, + indexName, + OUTPUT_KEY_CATEGORY_FIELD, + categoryField, + OUTPUT_KEY_AGGREGATION_FIELD, + aggregationField, + OUTPUT_KEY_AGGREGATION_METHOD, + aggregationMethod, + OUTPUT_KEY_DATE_FIELDS, + dateFieldsJoiner.toString() + ); + listener.onResponse((T) AccessController.doPrivileged((PrivilegedExceptionAction) () -> gson.toJson(result))); + }, e -> { + log.error("fail to predict model: " + e); + listener.onFailure(e); + })); + }, e -> { + log.error("failed to get mapping: " + e); + if (e.toString().contains("IndexNotFoundException")) { + listener + .onFailure( + new IllegalArgumentException( + "Return this final answer to human directly and do not use other tools: 'The index doesn't exist, please provide another index and retry'. 
Please try to directly send this message to human to ask for index name" + ) + ); + } else { + listener.onFailure(e); + } + })); + } + + /** + * Enrich the parameters by adding the parameters extracted from the chat + * @param parameters the original parameters + * @return the enriched parameters with parameters extracting from the chat + */ + private Map enrichParameters(Map parameters) { + Map result = new HashMap<>(parameters); + try { + // input is a map + Map chatParameters = gson.fromJson(parameters.get("input"), Map.class); + result.putAll(chatParameters); + } catch (Exception e) { + // input is a string + String indexName = parameters.getOrDefault("input", ""); + if (!indexName.isEmpty()) { + result.put("index", indexName); + } + } + return result; + } + + /** + * + * @param fieldsToType the flattened field-> field type mapping + * @return a list containing all the date type fields + */ + private Set findDateTypeFields(final Map fieldsToType) { + Set result = new HashSet<>(); + for (Map.Entry entry : fieldsToType.entrySet()) { + String value = entry.getValue(); + if (value.equals("date") || value.equals("date_nanos")) { + result.add(entry.getKey()); + } + } + return result; + } + + @SuppressWarnings("unchecked") + private static Map loadDefaultPromptFromFile() { + try (InputStream inputStream = CreateAnomalyDetectorTool.class.getResourceAsStream("CreateAnomalyDetectorDefaultPrompt.json")) { + if (inputStream != null) { + return gson.fromJson(new String(inputStream.readAllBytes(), StandardCharsets.UTF_8), Map.class); + } + } catch (IOException e) { + log.error("Failed to load prompt from the file CreateAnomalyDetectorDefaultPrompt.json, error: ", e); + } + return new HashMap<>(); + } + + /** + * + * @param fieldsToType the flattened field-> field type mapping + * @param indexName the index name + * @return the prompt about creating anomaly detector + */ + private String constructPrompt(final Map fieldsToType, final String indexName) { + StringJoiner 
tableInfoJoiner = new StringJoiner("\n"); + for (Map.Entry entry : fieldsToType.entrySet()) { + tableInfoJoiner.add("- " + entry.getKey() + ": " + entry.getValue()); + } + + Map indexInfo = ImmutableMap.of("indexName", indexName, "indexMapping", tableInfoJoiner.toString()); + StringSubstitutor substitutor = new StringSubstitutor(indexInfo, "${indexInfo.", "}"); + return substitutor.replace(contextPrompt); + } + + /** + * + * @param parameters the input parameters + * @return false if the input parameters is null or empty + */ + @Override + public boolean validate(Map parameters) { + return parameters != null && parameters.size() != 0; + } + + /** + * + * @return the type of this tool + */ + @Override + public String getType() { + return TYPE; + } + + /** + * The tool factory + */ + public static class Factory implements Tool.Factory { + private Client client; + + private static CreateAnomalyDetectorTool.Factory INSTANCE; + + /** + * Create or return the singleton factory instance + */ + public static CreateAnomalyDetectorTool.Factory getInstance() { + if (INSTANCE != null) { + return INSTANCE; + } + synchronized (CreateAnomalyDetectorTool.class) { + if (INSTANCE != null) { + return INSTANCE; + } + INSTANCE = new CreateAnomalyDetectorTool.Factory(); + return INSTANCE; + } + } + + public void init(Client client) { + this.client = client; + } + + /** + * + * @param map the input parameters + * @return the instance of this tool + */ + @Override + public CreateAnomalyDetectorTool create(Map map) { + String modelId = (String) map.getOrDefault("model_id", ""); + if (modelId.isEmpty()) { + throw new IllegalArgumentException("model_id cannot be empty."); + } + String modelType = (String) map.getOrDefault("model_type", ModelType.CLAUDE.toString()); + if (!ModelType.OPENAI.toString().equalsIgnoreCase(modelType) && !ModelType.CLAUDE.toString().equalsIgnoreCase(modelType)) { + throw new IllegalArgumentException("Unsupported model_type: " + modelType); + } + return new 
CreateAnomalyDetectorTool(client, modelId, modelType); + } + + @Override + public String getDefaultDescription() { + return DEFAULT_DESCRIPTION; + } + + @Override + public String getDefaultType() { + return TYPE; + } + + @Override + public String getDefaultVersion() { + return null; + } + } +} diff --git a/src/main/java/org/opensearch/agent/tools/PPLTool.java b/src/main/java/org/opensearch/agent/tools/PPLTool.java index dc1d62b4..82787836 100644 --- a/src/main/java/org/opensearch/agent/tools/PPLTool.java +++ b/src/main/java/org/opensearch/agent/tools/PPLTool.java @@ -33,6 +33,7 @@ import org.opensearch.action.search.SearchRequest; import org.opensearch.agent.common.SkillSettings; import org.opensearch.agent.tools.utils.ClusterSettingHelper; +import org.opensearch.agent.tools.utils.ToolHelper; import org.opensearch.client.Client; import org.opensearch.cluster.metadata.MappingMetadata; import org.opensearch.core.action.ActionListener; @@ -401,7 +402,7 @@ private String constructTableInfo(SearchHit[] searchHits, Map fieldsToType = new HashMap<>(); - extractNamesTypes(mappingSource, fieldsToType, ""); + ToolHelper.extractFieldNamesTypes(mappingSource, fieldsToType, ""); StringJoiner tableInfoJoiner = new StringJoiner("\n"); List sortedKeys = new ArrayList<>(fieldsToType.keySet()); Collections.sort(sortedKeys); @@ -439,28 +440,6 @@ private String constructPrompt(String tableInfo, String question, String indexNa return substitutor.replace(contextPrompt); } - private void extractNamesTypes(Map mappingSource, Map fieldsToType, String prefix) { - if (!prefix.isEmpty()) { - prefix += "."; - } - - for (Map.Entry entry : mappingSource.entrySet()) { - String n = entry.getKey(); - Object v = entry.getValue(); - - if (v instanceof Map) { - Map vMap = (Map) v; - if (vMap.containsKey("type")) { - if (!((vMap.getOrDefault("type", "")).equals("alias"))) { - fieldsToType.put(prefix + n, (String) vMap.get("type")); - } - } else if (vMap.containsKey("properties")) { - 
extractNamesTypes((Map) vMap.get("properties"), fieldsToType, prefix + n); - } - } - } - } - private static void extractSamples(Map sampleSource, Map fieldsToSample, String prefix) throws PrivilegedActionException { if (!prefix.isEmpty()) { diff --git a/src/main/java/org/opensearch/agent/tools/utils/ToolHelper.java b/src/main/java/org/opensearch/agent/tools/utils/ToolHelper.java index c7d82d4b..82ad7172 100644 --- a/src/main/java/org/opensearch/agent/tools/utils/ToolHelper.java +++ b/src/main/java/org/opensearch/agent/tools/utils/ToolHelper.java @@ -30,4 +30,36 @@ public static Map loadDefaultPromptDictFromFile(Class source, } return new HashMap<>(); } + + /** + * Flatten all the fields in the mappings, insert the field->field type mapping to a map + * @param mappingSource the mappings of an index + * @param fieldsToType the result containing the field->field type mapping + * @param prefix the parent field path + */ + public static void extractFieldNamesTypes(Map mappingSource, Map fieldsToType, String prefix) { + if (prefix.length() > 0) { + prefix += "."; + } + + for (Map.Entry entry : mappingSource.entrySet()) { + String n = entry.getKey(); + Object v = entry.getValue(); + + if (v instanceof Map) { + Map vMap = (Map) v; + if (vMap.containsKey("type")) { + if (!((vMap.getOrDefault("type", "")).equals("alias"))) { + fieldsToType.put(prefix + n, (String) vMap.get("type")); + } + } + if (vMap.containsKey("properties")) { + extractFieldNamesTypes((Map) vMap.get("properties"), fieldsToType, prefix + n); + } + if (vMap.containsKey("fields")) { + extractFieldNamesTypes((Map) vMap.get("fields"), fieldsToType, prefix + n); + } + } + } + } } diff --git a/src/main/resources/org/opensearch/agent/tools/CreateAnomalyDetectorDefaultPrompt.json b/src/main/resources/org/opensearch/agent/tools/CreateAnomalyDetectorDefaultPrompt.json new file mode 100644 index 00000000..9b69bce7 --- /dev/null +++ 
b/src/main/resources/org/opensearch/agent/tools/CreateAnomalyDetectorDefaultPrompt.json @@ -0,0 +1,4 @@ +{ + "CLAUDE": "Human:\" turn\": Here is an example of the create anomaly detector API: POST _plugins/_anomaly_detection/detectors, {\"time_field\":\"timestamp\",\"indices\":[\"server_log*\"],\"feature_attributes\":[{\"feature_name\":\"test\",\"feature_enabled\":true,\"aggregation_query\":{\"test\":{\"sum\":{\"field\":\"value\"}}}}],\"category_field\":[\"ip\"]}, and here are the mapping info containing all the fields in the index ${indexInfo.indexName}: ${indexInfo.indexMapping}, and the optional aggregation methods are count, avg, min, max and sum. Please give me some suggestion about creating an anomaly detector for the index ${indexInfo.indexName}, you need to give the key information: the top 3 suitable aggregation fields which are numeric types and the suitable aggregation method for each field, if there are no numeric type fields, both the aggregation field and method are empty string, and also give the category field if there exists a keyword type field like ip, address, host, city, country or region, if not exist, the category field is empty. Show me a format of keyed and pipe-delimited list wrapped in a curly bracket just like {category_field=the category field if exists|aggregation_field=comma-delimited list of all the aggregation field names|aggregation_method=comma-delimited list of all the aggregation methods}. 
\n\nAssistant:\" turn\"", + "OPENAI": "Here is an example of the create anomaly detector API: POST _plugins/_anomaly_detection/detectors, {\"time_field\":\"timestamp\",\"indices\":[\"server_log*\"],\"feature_attributes\":[{\"feature_name\":\"test\",\"feature_enabled\":true,\"aggregation_query\":{\"test\":{\"sum\":{\"field\":\"value\"}}}}],\"category_field\":[\"ip\"]}, and here are the mapping info containing all the fields in the index ${indexInfo.indexName}: ${indexInfo.indexMapping}, and the optional aggregation methods are count, avg, min, max and sum. Please give me some suggestion about creating an anomaly detector for the index ${indexInfo.indexName}, you need to give the key information: the top 3 suitable aggregation fields which are numeric types and the suitable aggregation method for each field, if there are no numeric type fields, both the aggregation field and method are empty string, and also give the category field if there exists a keyword type field like ip, address, host, city, country or region, if not exist, the category field is empty. Show me a format of keyed and pipe-delimited list wrapped in a curly bracket just like {category_field=the category field if exists|aggregation_field=comma-delimited list of all the aggregation field names|aggregation_method=comma-delimited list of all the aggregation methods}. 
" +} diff --git a/src/test/java/org/opensearch/agent/tools/CreateAnomalyDetectorToolTests.java b/src/test/java/org/opensearch/agent/tools/CreateAnomalyDetectorToolTests.java new file mode 100644 index 00000000..0749ab70 --- /dev/null +++ b/src/test/java/org/opensearch/agent/tools/CreateAnomalyDetectorToolTests.java @@ -0,0 +1,280 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.tools; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.when; +import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; +import static org.opensearch.ml.common.utils.StringUtils.gson; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.admin.indices.mapping.get.GetMappingsResponse; +import org.opensearch.client.AdminClient; +import org.opensearch.client.Client; +import org.opensearch.client.IndicesAdminClient; +import org.opensearch.cluster.metadata.MappingMetadata; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.output.model.MLResultDataType; +import org.opensearch.ml.common.output.model.ModelTensor; +import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.ml.common.transport.MLTaskResponse; +import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; + +import com.google.common.collect.ImmutableMap; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class CreateAnomalyDetectorToolTests { + @Mock + private Client client; + @Mock + private AdminClient adminClient; + 
@Mock + private IndicesAdminClient indicesAdminClient; + @Mock + private GetMappingsResponse getMappingsResponse; + @Mock + private MappingMetadata mappingMetadata; + private Map mockedMappings; + private Map indexMappings; + + @Mock + private MLTaskResponse mlTaskResponse; + @Mock + private ModelTensorOutput modelTensorOutput; + @Mock + private ModelTensors modelTensors; + + private ModelTensor modelTensor; + + private Map modelReturns; + + private String mockedIndexName = "http_logs"; + private String mockedResponse = "{category_field=|aggregation_field=response,responseLatency|aggregation_method=count,avg}"; + private String mockedResult = + "{\"index\":\"http_logs\",\"categoryField\":\"\",\"aggregationField\":\"response,responseLatency\",\"aggregationMethod\":\"count,avg\",\"dateFields\":\"date\"}"; + + private String mockedResultForIndexPattern = + "{\"index\":\"http_logs*\",\"categoryField\":\"\",\"aggregationField\":\"response,responseLatency\",\"aggregationMethod\":\"count,avg\",\"dateFields\":\"date\"}"; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + createMappings(); + // get mapping + when(mappingMetadata.getSourceAsMap()).thenReturn(indexMappings); + when(getMappingsResponse.getMappings()).thenReturn(mockedMappings); + when(client.admin()).thenReturn(adminClient); + when(adminClient.indices()).thenReturn(indicesAdminClient); + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[1]; + listener.onResponse(getMappingsResponse); + return null; + }).when(indicesAdminClient).getMappings(any(), any()); + + initMLTensors(); + CreateAnomalyDetectorTool.Factory.getInstance().init(client); + } + + @Test + public void testModelIdIsNullOrEmpty() { + Exception exception = assertThrows( + IllegalArgumentException.class, + () -> CreateAnomalyDetectorTool.Factory.getInstance().create(ImmutableMap.of("model_id", "")) + ); + assertEquals("model_id cannot be empty.", exception.getMessage()); + } + + 
@Test + public void testModelType() { + Exception exception = assertThrows( + IllegalArgumentException.class, + () -> CreateAnomalyDetectorTool.Factory.getInstance().create(ImmutableMap.of("model_id", "modelId", "model_type", "unknown")) + ); + assertEquals("Unsupported model_type: unknown", exception.getMessage()); + + CreateAnomalyDetectorTool tool = CreateAnomalyDetectorTool.Factory + .getInstance() + .create(ImmutableMap.of("model_id", "modelId", "model_type", "openai")); + assertEquals(CreateAnomalyDetectorTool.TYPE, tool.getName()); + assertEquals("modelId", tool.getModelId()); + assertEquals("OPENAI", tool.getModelType().toString()); + + tool = CreateAnomalyDetectorTool.Factory.getInstance().create(ImmutableMap.of("model_id", "modelId", "model_type", "claude")); + assertEquals(CreateAnomalyDetectorTool.TYPE, tool.getName()); + assertEquals("modelId", tool.getModelId()); + assertEquals("CLAUDE", tool.getModelType().toString()); + } + + @Test + public void testTool() { + CreateAnomalyDetectorTool tool = CreateAnomalyDetectorTool.Factory.getInstance().create(ImmutableMap.of("model_id", "modelId")); + assertEquals(CreateAnomalyDetectorTool.TYPE, tool.getName()); + assertEquals("modelId", tool.getModelId()); + assertEquals("CLAUDE", tool.getModelType().toString()); + + tool + .run( + ImmutableMap.of("index", mockedIndexName), + ActionListener.wrap(response -> assertEquals(mockedResult, response), log::info) + ); + tool + .run( + ImmutableMap.of("index", mockedIndexName + "*"), + ActionListener.wrap(response -> assertEquals(mockedResultForIndexPattern, response), log::info) + ); + tool + .run( + ImmutableMap.of("input", mockedIndexName), + ActionListener.wrap(response -> assertEquals(mockedResult, response), log::info) + ); + tool + .run( + ImmutableMap.of("input", gson.toJson(ImmutableMap.of("index", mockedIndexName))), + ActionListener.wrap(response -> assertEquals(mockedResult, response), log::info) + ); + } + + @Test + public void testToolWithInvalidResponse() 
{ + CreateAnomalyDetectorTool tool = CreateAnomalyDetectorTool.Factory.getInstance().create(ImmutableMap.of("model_id", "modelId")); + + modelReturns = Collections.singletonMap("response", ""); + modelTensor = new ModelTensor("tensor", new Number[0], new long[0], MLResultDataType.STRING, null, null, modelReturns); + initMLTensors(); + + Exception exception = assertThrows( + IllegalStateException.class, + () -> tool + .run(ImmutableMap.of("index", mockedIndexName), ActionListener.wrap(response -> assertEquals(response, ""), e -> { + throw new IllegalStateException(e.getMessage()); + })) + ); + assertEquals("Remote endpoint fails to inference, no response found.", exception.getMessage()); + + modelReturns = Collections.singletonMap("response", "not valid response"); + modelTensor = new ModelTensor("tensor", new Number[0], new long[0], MLResultDataType.STRING, null, null, modelReturns); + initMLTensors(); + + exception = assertThrows( + IllegalStateException.class, + () -> tool + .run( + ImmutableMap.of("index", mockedIndexName), + ActionListener.wrap(response -> assertEquals(response, "not valid response"), e -> { + throw new IllegalStateException(e.getMessage()); + }) + ) + ); + assertEquals( + "The inference result from remote endpoint is not valid, cannot extract the key information from the result.", + exception.getMessage() + ); + + modelReturns = Collections.singletonMap("response", null); + modelTensor = new ModelTensor("tensor", new Number[0], new long[0], MLResultDataType.STRING, null, null, modelReturns); + initMLTensors(); + + exception = assertThrows( + IllegalStateException.class, + () -> tool + .run(ImmutableMap.of("index", mockedIndexName), ActionListener.wrap(response -> assertEquals(response, ""), e -> { + throw new IllegalStateException(e.getMessage()); + })) + ); + assertEquals("Remote endpoint fails to inference, no response found.", exception.getMessage()); + } + + @Test + public void testToolWithSystemIndex() { + CreateAnomalyDetectorTool tool = 
CreateAnomalyDetectorTool.Factory.getInstance().create(ImmutableMap.of("model_id", "modelId")); + Exception exception = assertThrows( + IllegalArgumentException.class, + () -> tool.run(ImmutableMap.of("index", ML_CONNECTOR_INDEX), ActionListener.wrap(result -> {}, e -> {})) + ); + assertEquals( + "CreateAnomalyDetectionTool doesn't support searching indices starting with '.' since it could be system index, current searching index name: " + + ML_CONNECTOR_INDEX, + exception.getMessage() + ); + } + + @Test + public void testToolWithGetMappingFailed() { + CreateAnomalyDetectorTool tool = CreateAnomalyDetectorTool.Factory.getInstance().create(ImmutableMap.of("model_id", "modelId")); + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[1]; + listener.onFailure(new Exception("No mapping found for the index: " + mockedIndexName)); + return null; + }).when(indicesAdminClient).getMappings(any(), any()); + + tool.run(ImmutableMap.of("index", mockedIndexName), ActionListener.wrap(result -> {}, e -> { + assertEquals("No mapping found for the index: " + mockedIndexName, e.getMessage()); + })); + } + + @Test + public void testToolWithPredictModelFailed() { + CreateAnomalyDetectorTool tool = CreateAnomalyDetectorTool.Factory.getInstance().create(ImmutableMap.of("model_id", "modelId")); + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[2]; + listener.onFailure(new Exception("predict model failed")); + return null; + }).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(), any()); + + tool.run(ImmutableMap.of("index", mockedIndexName), ActionListener.wrap(result -> {}, e -> { + assertEquals("predict model failed", e.getMessage()); + })); + } + + private void createMappings() { + indexMappings = new HashMap<>(); + indexMappings + .put( + "properties", + ImmutableMap + .of( + "response", + ImmutableMap.of("type", "integer"), + "responseLatency", + ImmutableMap.of("type", "float"), 
+ "date", + ImmutableMap.of("type", "date") + ) + ); + mockedMappings = new HashMap<>(); + mockedMappings.put(mockedIndexName, mappingMetadata); + + modelReturns = Collections.singletonMap("response", mockedResponse); + modelTensor = new ModelTensor("tensor", new Number[0], new long[0], MLResultDataType.STRING, null, null, modelReturns); + } + + private void initMLTensors() { + when(modelTensors.getMlModelTensors()).thenReturn(Collections.singletonList(modelTensor)); + when(modelTensorOutput.getMlModelOutputs()).thenReturn(Collections.singletonList(modelTensors)); + when(mlTaskResponse.getOutput()).thenReturn(modelTensorOutput); + + // call model + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[2]; + listener.onResponse(mlTaskResponse); + return null; + }).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(), any()); + } +} diff --git a/src/test/java/org/opensearch/integTest/CreateAnomalyDetectorToolIT.java b/src/test/java/org/opensearch/integTest/CreateAnomalyDetectorToolIT.java new file mode 100644 index 00000000..648a381b --- /dev/null +++ b/src/test/java/org/opensearch/integTest/CreateAnomalyDetectorToolIT.java @@ -0,0 +1,345 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.integTest; + +import static org.hamcrest.Matchers.allOf; +import static org.hamcrest.Matchers.containsString; + +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.List; + +import org.hamcrest.MatcherAssert; +import org.opensearch.agent.tools.CreateAnomalyDetectorTool; +import org.opensearch.client.ResponseException; + +import lombok.SneakyThrows; + +public class CreateAnomalyDetectorToolIT extends ToolIntegrationTest { + private final String NORMAL_INDEX = "http_logs"; + private final String NORMAL_INDEX_WITH_NO_AVAILABLE_FIELDS = "products"; + private final String NORMAL_INDEX_WITH_NO_DATE_FIELDS = "normal_index_with_no_date_fields"; + 
private final String NORMAL_INDEX_WITH_NO_MAPPING = "normal_index_with_no_mapping"; + private final String ABNORMAL_INDEX = "abnormal_index"; + + @Override + List promptHandlers() { + PromptHandler createAnomalyDetectorToolHandler = new PromptHandler() { + @Override + String response(String prompt) { + int flag; + if (prompt.contains(NORMAL_INDEX)) { + flag = randomIntBetween(0, 9); + switch (flag) { + case 0: + return "{category_field=|aggregation_field=response,responseLatency|aggregation_method=count,avg}"; + case 1: + return "{category_field=ip|aggregation_field=response,responseLatency|aggregation_method=count,avg}"; + case 2: + return "{category_field=|aggregation_field=responseLatency|aggregation_method=avg}"; + case 3: + return "{category_field=country.keyword|aggregation_field=response,responseLatency|aggregation_method=count,avg}"; + case 4: + return "{category_field=country.keyword|aggregation_field=response.keyword|aggregation_method=count}"; + case 5: + return "{category_field=\"country.keyword\"|aggregation_field=\"response,responseLatency\"|aggregation_method=\"count,avg\"}"; + case 6: + return "{category_field=ip|aggregation_field=responseLatency|aggregation_method=avg}"; + case 7: + return "{category_field=\"ip\"|aggregation_field=\"responseLatency\"|aggregation_method=\"avg\"}"; + case 8: + return "{category_field= ip |aggregation_field= responseLatency |aggregation_method= avg }"; + case 9: + return "{category_field=\" ip \"|aggregation_field=\" responseLatency \"|aggregation_method=\" avg \"}"; + default: + return "{category_field=|aggregation_field=response|aggregation_method=count}"; + } + } else if (prompt.contains(NORMAL_INDEX_WITH_NO_AVAILABLE_FIELDS)) { + flag = randomIntBetween(0, 9); + switch (flag) { + case 0: + return "{category_field=|aggregation_field=|aggregation_method=}"; + case 1: + return "{category_field= |aggregation_field= |aggregation_method= }"; + case 2: + return 
"{category_field=\"\"|aggregation_field=\"\"|aggregation_method=\"\"}"; + case 3: + return "{category_field=product|aggregation_field=|aggregation_method=sum}"; + case 4: + return "{category_field=product|aggregation_field=sales|aggregation_method=}"; + case 5: + return "{category_field=product|aggregation_field=\"\"|aggregation_method=sum}"; + case 6: + return "{category_field=product|aggregation_field=sales|aggregation_method=\"\"}"; + case 7: + return "{category_field=product|aggregation_field= |aggregation_method=sum}"; + case 8: + return "{category_field=product|aggregation_field=sales |aggregation_method= }"; + case 9: + return "{category_field=\"\"|aggregation_field= |aggregation_method=\"\" }"; + default: + return "{category_field=product|aggregation_field= |aggregation_method= }"; + } + } else { + flag = randomIntBetween(0, 1); + switch (flag) { + case 0: + return "wrong response"; + case 1: + return "{category_field=product}"; + default: + return "{category_field=}"; + } + } + } + + @Override + boolean apply(String prompt) { + return true; + } + }; + return List.of(createAnomalyDetectorToolHandler); + } + + @Override + String toolType() { + return CreateAnomalyDetectorTool.TYPE; + } + + public void testCreateAnomalyDetectorTool() { + prepareIndex(); + String agentId = registerAgent(); + String index; + if (randomIntBetween(0, 1) == 0) { + index = NORMAL_INDEX; + } else { + index = NORMAL_INDEX_WITH_NO_AVAILABLE_FIELDS; + } + String result = executeAgent(agentId, "{\"parameters\": {\"index\":\"" + index + "\"}}"); + assertTrue(result.contains("index")); + assertTrue(result.contains("categoryField")); + assertTrue(result.contains("aggregationField")); + assertTrue(result.contains("aggregationMethod")); + assertTrue(result.contains("dateFields")); + } + + public void testCreateAnomalyDetectorToolWithNonExistentModelId() { + prepareIndex(); + String agentId = registerAgentWithWrongModelId(); + Exception exception = assertThrows( + ResponseException.class, + 
() -> executeAgent(agentId, "{\"parameters\": {\"index\":\"" + ABNORMAL_INDEX + "\"}}") + ); + MatcherAssert.assertThat(exception.getMessage(), allOf(containsString("Failed to find model"))); + } + + public void testCreateAnomalyDetectorToolWithUnexpectedResult() { + prepareIndex(); + String agentId = registerAgent(); + + Exception exception = assertThrows( + ResponseException.class, + () -> executeAgent(agentId, "{\"parameters\": {\"index\":\"" + NORMAL_INDEX_WITH_NO_MAPPING + "\"}}") + ); + MatcherAssert + .assertThat( + exception.getMessage(), + allOf( + containsString( + "The index " + + NORMAL_INDEX_WITH_NO_MAPPING + + " doesn't have mapping metadata, please add data to it or using another index." + ) + ) + ); + + exception = assertThrows( + ResponseException.class, + () -> executeAgent(agentId, "{\"parameters\": {\"index\":\"" + NORMAL_INDEX_WITH_NO_DATE_FIELDS + "\"}}") + ); + MatcherAssert + .assertThat( + exception.getMessage(), + allOf( + containsString( + "The index " + + NORMAL_INDEX_WITH_NO_DATE_FIELDS + + " doesn't have date type fields, cannot create an anomaly detector for it." + ) + ) + ); + + exception = assertThrows( + ResponseException.class, + () -> executeAgent(agentId, "{\"parameters\": {\"index\":\"" + ABNORMAL_INDEX + "\"}}") + ); + MatcherAssert + .assertThat( + exception.getMessage(), + allOf( + containsString( + "The inference result from remote endpoint is not valid, cannot extract the key information from the result." + ) + ) + ); + } + + public void testCreateAnomalyDetectorToolWithSystemIndex() { + prepareIndex(); + String agentId = registerAgent(); + Exception exception = assertThrows( + ResponseException.class, + () -> executeAgent(agentId, "{\"parameters\": {\"index\": \".test\"}}") + ); + MatcherAssert + .assertThat( + exception.getMessage(), + allOf( + containsString( + "CreateAnomalyDetectionTool doesn't support searching indices starting with '.' 
since it could be system index, current searching index name: .test" + ) + ) + ); + } + + public void testCreateAnomalyDetectorToolWithMissingIndex() { + prepareIndex(); + String agentId = registerAgent(); + Exception exception = assertThrows( + ResponseException.class, + () -> executeAgent(agentId, "{\"parameters\": {\"index\": \"non-existent\"}}") + ); + MatcherAssert + .assertThat( + exception.getMessage(), + allOf( + containsString( + "Return this final answer to human directly and do not use other tools: 'The index doesn't exist, please provide another index and retry'. Please try to directly send this message to human to ask for index name" + ) + ) + ); + } + + public void testCreateAnomalyDetectorToolWithEmptyInput() { + prepareIndex(); + String agentId = registerAgent(); + Exception exception = assertThrows(ResponseException.class, () -> executeAgent(agentId, "{\"parameters\": {}}")); + MatcherAssert + .assertThat( + exception.getMessage(), + allOf( + containsString( + "Return this final answer to human directly and do not use other tools: 'Please provide index name'. 
Please try to directly send this message to human to ask for index name" + ) + ) + ); + } + + @SneakyThrows + private void prepareIndex() { + createIndexWithConfiguration( + NORMAL_INDEX, + "{\n" + + " \"mappings\": {\n" + + " \"properties\": {\n" + + " \"response\": {\n" + + " \"type\": \"keyword\"\n" + + " },\n" + + " \"responseLatency\": {\n" + + " \"type\": \"float\"\n" + + " }\n" + + " }\n" + + " }\n" + + "}" + ); + addDocToIndex(NORMAL_INDEX, "0", List.of("response", "responseLatency", "date"), List.of(200, 0.15, "2024-07-03T10:22:56,520")); + addDocToIndex(NORMAL_INDEX, "1", List.of("response", "responseLatency", "date"), List.of(200, 3.15, "2024-07-03T10:22:57,520")); + + createIndexWithConfiguration( + NORMAL_INDEX_WITH_NO_AVAILABLE_FIELDS, + "{\n" + + " \"mappings\": {\n" + + " \"properties\": {\n" + + " \"product\": {\n" + + " " + + " \"type\": \"keyword\"\n" + + " }\n" + + " }\n" + + " }\n" + + "}" + ); + addDocToIndex(NORMAL_INDEX_WITH_NO_AVAILABLE_FIELDS, "0", List.of("product", "date"), List.of(1, "2024-07-03T10:22:56,520")); + addDocToIndex(NORMAL_INDEX_WITH_NO_AVAILABLE_FIELDS, "1", List.of("product", "date"), List.of(2, "2024-07-03T10:22:57,520")); + + createIndexWithConfiguration( + NORMAL_INDEX_WITH_NO_DATE_FIELDS, + "{\n" + + " \"mappings\": {\n" + + " \"properties\": {\n" + + " \"product\": {\n" + + " " + + " \"type\": \"keyword\"\n" + + " }\n" + + " }\n" + + " }\n" + + "}" + ); + addDocToIndex(NORMAL_INDEX_WITH_NO_DATE_FIELDS, "0", List.of("product"), List.of(1)); + addDocToIndex(NORMAL_INDEX_WITH_NO_DATE_FIELDS, "1", List.of("product"), List.of(2)); + + createIndexWithConfiguration(NORMAL_INDEX_WITH_NO_MAPPING, "{}"); + + createIndexWithConfiguration( + ABNORMAL_INDEX, + "{\n" + + " \"mappings\": {\n" + + " \"properties\": {\n" + + " \"date\": {\n" + + " " + + " \"type\": \"date\"\n" + + " }\n" + + " }\n" + + " }\n" + + "}" + ); + addDocToIndex(ABNORMAL_INDEX, "0", List.of("date"), List.of(1, "2024-07-03T10:22:56,520")); + 
addDocToIndex(ABNORMAL_INDEX, "1", List.of("date"), List.of(2, "2024-07-03T10:22:57,520")); + } + + @SneakyThrows + private String registerAgentWithWrongModelId() { + String registerAgentRequestBody = Files + .readString( + Path + .of( + this + .getClass() + .getClassLoader() + .getResource("org/opensearch/agent/tools/register_flow_agent_of_create_anomaly_detector_tool_request_body.json") + .toURI() + ) + ); + registerAgentRequestBody = registerAgentRequestBody.replace("<MODEL_ID>", "non-existent"); + return createAgent(registerAgentRequestBody); + } + + @SneakyThrows + private String registerAgent() { + String registerAgentRequestBody = Files + .readString( + Path + .of( + this + .getClass() + .getClassLoader() + .getResource("org/opensearch/agent/tools/register_flow_agent_of_create_anomaly_detector_tool_request_body.json") + .toURI() + ) + ); + registerAgentRequestBody = registerAgentRequestBody.replace("<MODEL_ID>", modelId); + return createAgent(registerAgentRequestBody); + } +} diff --git a/src/test/resources/org/opensearch/agent/tools/register_flow_agent_of_create_anomaly_detector_tool_request_body.json b/src/test/resources/org/opensearch/agent/tools/register_flow_agent_of_create_anomaly_detector_tool_request_body.json new file mode 100644 index 00000000..3ad9477e --- /dev/null +++ b/src/test/resources/org/opensearch/agent/tools/register_flow_agent_of_create_anomaly_detector_tool_request_body.json @@ -0,0 +1,12 @@ +{ + "name": "Test_create_anomaly_detector_flow_agent", + "type": "flow", + "tools": [ + { + "type": "CreateAnomalyDetectorTool", + "parameters": { + "model_id": "<MODEL_ID>" + } + } + ] +}