Skip to content

Commit

Permalink
Add CreateAnomalyDetectorTool (opensearch-project#348)
Browse files Browse the repository at this point in the history
* Add CreateAnomalyDetectorTool

Signed-off-by: gaobinlong <[email protected]>

* Optimize some code

Signed-off-by: gaobinlong <[email protected]>

* Fix test failure

Signed-off-by: gaobinlong <[email protected]>

* Optimize exception

Signed-off-by: gaobinlong <[email protected]>

---------

Signed-off-by: gaobinlong <[email protected]>
  • Loading branch information
gaobinlong authored and qianheng-aws committed Jul 25, 2024
1 parent 728d778 commit 4c7d70d
Show file tree
Hide file tree
Showing 8 changed files with 1,130 additions and 24 deletions.
5 changes: 4 additions & 1 deletion src/main/java/org/opensearch/agent/ToolPlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import org.opensearch.agent.common.SkillSettings;
import org.opensearch.agent.tools.CreateAlertTool;
import org.opensearch.agent.tools.CreateAnomalyDetectorTool;
import org.opensearch.agent.tools.NeuralSparseSearchTool;
import org.opensearch.agent.tools.PPLTool;
import org.opensearch.agent.tools.RAGTool;
Expand Down Expand Up @@ -75,6 +76,7 @@ public Collection<Object> createComponents(
SearchAnomalyResultsTool.Factory.getInstance().init(client, namedWriteableRegistry);
SearchMonitorsTool.Factory.getInstance().init(client);
CreateAlertTool.Factory.getInstance().init(client);
CreateAnomalyDetectorTool.Factory.getInstance().init(client);
return Collections.emptyList();
}

Expand All @@ -90,7 +92,8 @@ public List<Tool.Factory<? extends Tool>> getToolFactories() {
SearchAnomalyDetectorsTool.Factory.getInstance(),
SearchAnomalyResultsTool.Factory.getInstance(),
SearchMonitorsTool.Factory.getInstance(),
CreateAlertTool.Factory.getInstance()
CreateAlertTool.Factory.getInstance(),
CreateAnomalyDetectorTool.Factory.getInstance()
);
}

Expand Down

Large diffs are not rendered by default.

25 changes: 2 additions & 23 deletions src/main/java/org/opensearch/agent/tools/PPLTool.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import org.opensearch.action.search.SearchRequest;
import org.opensearch.agent.common.SkillSettings;
import org.opensearch.agent.tools.utils.ClusterSettingHelper;
import org.opensearch.agent.tools.utils.ToolHelper;
import org.opensearch.client.Client;
import org.opensearch.cluster.metadata.MappingMetadata;
import org.opensearch.core.action.ActionListener;
Expand Down Expand Up @@ -401,7 +402,7 @@ private String constructTableInfo(SearchHit[] searchHits, Map<String, MappingMet
);
}
Map<String, String> fieldsToType = new HashMap<>();
extractNamesTypes(mappingSource, fieldsToType, "");
ToolHelper.extractFieldNamesTypes(mappingSource, fieldsToType, "");
StringJoiner tableInfoJoiner = new StringJoiner("\n");
List<String> sortedKeys = new ArrayList<>(fieldsToType.keySet());
Collections.sort(sortedKeys);
Expand Down Expand Up @@ -439,28 +440,6 @@ private String constructPrompt(String tableInfo, String question, String indexNa
return substitutor.replace(contextPrompt);
}

/**
 * Flattens an index mapping into {@code fieldsToType}, keyed by the dotted field path.
 * <p>
 * Entries whose value is not a map are ignored. A definition that declares a {@code type}
 * is recorded (unless the type is {@code alias}) and is not descended into; a definition
 * without a {@code type} but with {@code properties} is walked recursively.
 *
 * @param mappingSource the field-name to field-definition map to walk
 * @param fieldsToType  output map receiving full field path to field type
 * @param prefix        dotted path of the parent field, or an empty string at the root
 */
private void extractNamesTypes(Map<String, Object> mappingSource, Map<String, String> fieldsToType, String prefix) {
    // Children are addressed as "<parent>.<child>"; the root has no parent segment.
    final String parentPath = prefix.isEmpty() ? prefix : prefix + ".";

    for (Map.Entry<String, Object> field : mappingSource.entrySet()) {
        if (!(field.getValue() instanceof Map)) {
            continue;
        }
        Map<String, Object> definition = (Map<String, Object>) field.getValue();
        String fieldPath = parentPath + field.getKey();

        if (definition.containsKey("type")) {
            // Alias fields only point at another field, so they are left out of the result.
            boolean isAlias = (definition.getOrDefault("type", "")).equals("alias");
            if (!isAlias) {
                fieldsToType.put(fieldPath, (String) definition.get("type"));
            }
            continue;
        }
        if (definition.containsKey("properties")) {
            extractNamesTypes((Map<String, Object>) definition.get("properties"), fieldsToType, fieldPath);
        }
    }
}

private static void extractSamples(Map<String, Object> sampleSource, Map<String, String> fieldsToSample, String prefix)
throws PrivilegedActionException {
if (!prefix.isEmpty()) {
Expand Down
32 changes: 32 additions & 0 deletions src/main/java/org/opensearch/agent/tools/utils/ToolHelper.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,36 @@ public static Map<String, String> loadDefaultPromptDictFromFile(Class<?> source,
}
return new HashMap<>();
}

/**
 * Flatten all the fields in the mappings, insert the field->field type mapping to a map.
 * <p>
 * For every entry of {@code mappingSource} whose value is a map (a field definition):
 * <ul>
 *   <li>if it declares a string {@code type} other than {@code alias}, the full dotted field
 *       path is recorded together with that type;</li>
 *   <li>if it has {@code properties} (object / nested fields) or {@code fields} (multi-fields),
 *       those sub-definitions are flattened recursively under the current field path.</li>
 * </ul>
 * Entries that do not have the expected shape (non-map definitions, a non-string {@code type},
 * non-map {@code properties}/{@code fields}) are skipped instead of raising a
 * {@link ClassCastException}, so one unexpected entry does not abort the whole extraction.
 *
 * @param mappingSource the mappings of an index; {@code null} is treated as empty
 * @param fieldsToType the result containing the field->field type mapping
 * @param prefix the parent field path; {@code null} or empty for the root level
 */
@SuppressWarnings("unchecked") // every cast below is guarded by an instanceof Map check
public static void extractFieldNamesTypes(Map<String, Object> mappingSource, Map<String, String> fieldsToType, String prefix) {
    if (mappingSource == null) {
        return;
    }
    // Children are addressed as "<parent>.<child>"; the root level has no parent segment.
    final String pathPrefix = (prefix == null || prefix.isEmpty()) ? "" : prefix + ".";

    for (Map.Entry<String, Object> entry : mappingSource.entrySet()) {
        Object value = entry.getValue();
        if (!(value instanceof Map)) {
            continue;
        }
        Map<String, Object> fieldDefinition = (Map<String, Object>) value;
        String fieldPath = pathPrefix + entry.getKey();

        // Alias fields only point at another field, so they are not reported as real fields.
        Object type = fieldDefinition.get("type");
        if (type instanceof String && !"alias".equals(type)) {
            fieldsToType.put(fieldPath, (String) type);
        }

        // A typed field (e.g. "nested") may still carry sub-fields, so these checks are independent.
        Object properties = fieldDefinition.get("properties");
        if (properties instanceof Map) {
            extractFieldNamesTypes((Map<String, Object>) properties, fieldsToType, fieldPath);
        }
        Object multiFields = fieldDefinition.get("fields");
        if (multiFields instanceof Map) {
            extractFieldNamesTypes((Map<String, Object>) multiFields, fieldsToType, fieldPath);
        }
    }
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"CLAUDE": "Human:\" turn\": Here is an example of the create anomaly detector API: POST _plugins/_anomaly_detection/detectors, {\"time_field\":\"timestamp\",\"indices\":[\"server_log*\"],\"feature_attributes\":[{\"feature_name\":\"test\",\"feature_enabled\":true,\"aggregation_query\":{\"test\":{\"sum\":{\"field\":\"value\"}}}}],\"category_field\":[\"ip\"]}, and here are the mapping info containing all the fields in the index ${indexInfo.indexName}: ${indexInfo.indexMapping}, and the optional aggregation methods are count, avg, min, max and sum. Please give me some suggestion about creating an anomaly detector for the index ${indexInfo.indexName}, you need to give the key information: the top 3 suitable aggregation fields which are numeric types and the suitable aggregation method for each field, if there are no numeric type fields, both the aggregation field and method are empty string, and also give the category field if there exists a keyword type field like ip, address, host, city, country or region, if not exist, the category field is empty. Show me a format of keyed and pipe-delimited list wrapped in a curly bracket just like {category_field=the category field if exists|aggregation_field=comma-delimited list of all the aggregation field names|aggregation_method=comma-delimited list of all the aggregation methods}. \n\nAssistant:\" turn\"",
"OPENAI": "Here is an example of the create anomaly detector API: POST _plugins/_anomaly_detection/detectors, {\"time_field\":\"timestamp\",\"indices\":[\"server_log*\"],\"feature_attributes\":[{\"feature_name\":\"test\",\"feature_enabled\":true,\"aggregation_query\":{\"test\":{\"sum\":{\"field\":\"value\"}}}}],\"category_field\":[\"ip\"]}, and here are the mapping info containing all the fields in the index ${indexInfo.indexName}: ${indexInfo.indexMapping}, and the optional aggregation methods are count, avg, min, max and sum. Please give me some suggestion about creating an anomaly detector for the index ${indexInfo.indexName}, you need to give the key information: the top 3 suitable aggregation fields which are numeric types and the suitable aggregation method for each field, if there are no numeric type fields, both the aggregation field and method are empty string, and also give the category field if there exists a keyword type field like ip, address, host, city, country or region, if not exist, the category field is empty. Show me a format of keyed and pipe-delimited list wrapped in a curly bracket just like {category_field=the category field if exists|aggregation_field=comma-delimited list of all the aggregation field names|aggregation_method=comma-delimited list of all the aggregation methods}. "
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,280 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.agent.tools;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.when;
import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX;
import static org.opensearch.ml.common.utils.StringUtils.gson;

import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

import org.junit.Before;
import org.junit.Test;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.opensearch.action.admin.indices.mapping.get.GetMappingsResponse;
import org.opensearch.client.AdminClient;
import org.opensearch.client.Client;
import org.opensearch.client.IndicesAdminClient;
import org.opensearch.cluster.metadata.MappingMetadata;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.output.model.MLResultDataType;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;

import com.google.common.collect.ImmutableMap;

import lombok.extern.log4j.Log4j2;

/**
 * Unit tests for {@link CreateAnomalyDetectorTool}.
 * <p>
 * The OpenSearch {@link Client} is mocked so that both the get-mappings call and the ML
 * prediction call answer synchronously on the calling thread. Because of that, every callback
 * the tool makes has already happened by the time {@code tool.run(...)} returns, and exceptions
 * thrown from a listener propagate back to the test. The helper assertions at the bottom of the
 * class rely on this to verify that the expected callback was really invoked, so a test cannot
 * pass simply because no callback (or the wrong one) fired.
 */
@Log4j2
public class CreateAnomalyDetectorToolTests {
    @Mock
    private Client client;
    @Mock
    private AdminClient adminClient;
    @Mock
    private IndicesAdminClient indicesAdminClient;
    @Mock
    private GetMappingsResponse getMappingsResponse;
    @Mock
    private MappingMetadata mappingMetadata;
    private Map<String, MappingMetadata> mockedMappings;
    private Map<String, Object> indexMappings;

    @Mock
    private MLTaskResponse mlTaskResponse;
    @Mock
    private ModelTensorOutput modelTensorOutput;
    @Mock
    private ModelTensors modelTensors;

    private ModelTensor modelTensor;

    private Map<String, ?> modelReturns;

    private String mockedIndexName = "http_logs";
    private String mockedResponse = "{category_field=|aggregation_field=response,responseLatency|aggregation_method=count,avg}";
    private String mockedResult =
        "{\"index\":\"http_logs\",\"categoryField\":\"\",\"aggregationField\":\"response,responseLatency\",\"aggregationMethod\":\"count,avg\",\"dateFields\":\"date\"}";

    private String mockedResultForIndexPattern =
        "{\"index\":\"http_logs*\",\"categoryField\":\"\",\"aggregationField\":\"response,responseLatency\",\"aggregationMethod\":\"count,avg\",\"dateFields\":\"date\"}";

    /** Wires the mocked client so that mapping lookups and model predictions succeed by default. */
    @Before
    public void setup() {
        MockitoAnnotations.openMocks(this);
        createMappings();
        // get mapping
        when(mappingMetadata.getSourceAsMap()).thenReturn(indexMappings);
        when(getMappingsResponse.getMappings()).thenReturn(mockedMappings);
        when(client.admin()).thenReturn(adminClient);
        when(adminClient.indices()).thenReturn(indicesAdminClient);
        doAnswer(invocation -> {
            ActionListener<GetMappingsResponse> listener = invocation.getArgument(1);
            listener.onResponse(getMappingsResponse);
            return null;
        }).when(indicesAdminClient).getMappings(any(), any());

        initMLTensors();
        CreateAnomalyDetectorTool.Factory.getInstance().init(client);
    }

    @Test
    public void testModelIdIsNullOrEmpty() {
        Exception exception = assertThrows(
            IllegalArgumentException.class,
            () -> CreateAnomalyDetectorTool.Factory.getInstance().create(ImmutableMap.of("model_id", ""))
        );
        assertEquals("model_id cannot be empty.", exception.getMessage());
    }

    @Test
    public void testModelType() {
        Exception exception = assertThrows(
            IllegalArgumentException.class,
            () -> CreateAnomalyDetectorTool.Factory.getInstance().create(ImmutableMap.of("model_id", "modelId", "model_type", "unknown"))
        );
        assertEquals("Unsupported model_type: unknown", exception.getMessage());

        CreateAnomalyDetectorTool tool = CreateAnomalyDetectorTool.Factory
            .getInstance()
            .create(ImmutableMap.of("model_id", "modelId", "model_type", "openai"));
        assertEquals(CreateAnomalyDetectorTool.TYPE, tool.getName());
        assertEquals("modelId", tool.getModelId());
        assertEquals("OPENAI", tool.getModelType().toString());

        tool = CreateAnomalyDetectorTool.Factory.getInstance().create(ImmutableMap.of("model_id", "modelId", "model_type", "claude"));
        assertEquals(CreateAnomalyDetectorTool.TYPE, tool.getName());
        assertEquals("modelId", tool.getModelId());
        assertEquals("CLAUDE", tool.getModelType().toString());
    }

    @Test
    public void testTool() {
        CreateAnomalyDetectorTool tool = CreateAnomalyDetectorTool.Factory.getInstance().create(ImmutableMap.of("model_id", "modelId"));
        assertEquals(CreateAnomalyDetectorTool.TYPE, tool.getName());
        assertEquals("modelId", tool.getModelId());
        assertEquals("CLAUDE", tool.getModelType().toString());

        // index given directly, as an index pattern, as a plain "input" string and as a JSON "input"
        assertToolReturns(tool, ImmutableMap.of("index", mockedIndexName), mockedResult);
        assertToolReturns(tool, ImmutableMap.of("index", mockedIndexName + "*"), mockedResultForIndexPattern);
        assertToolReturns(tool, ImmutableMap.of("input", mockedIndexName), mockedResult);
        assertToolReturns(tool, ImmutableMap.of("input", gson.toJson(ImmutableMap.of("index", mockedIndexName))), mockedResult);
    }

    @Test
    public void testToolWithInvalidResponse() {
        CreateAnomalyDetectorTool tool = CreateAnomalyDetectorTool.Factory.getInstance().create(ImmutableMap.of("model_id", "modelId"));

        stubModelResponse("");
        assertToolFailsWith(
            tool,
            ImmutableMap.of("index", mockedIndexName),
            "Remote endpoint fails to inference, no response found."
        );

        stubModelResponse("not valid response");
        assertToolFailsWith(
            tool,
            ImmutableMap.of("index", mockedIndexName),
            "The inference result from remote endpoint is not valid, cannot extract the key information from the result."
        );

        stubModelResponse(null);
        assertToolFailsWith(
            tool,
            ImmutableMap.of("index", mockedIndexName),
            "Remote endpoint fails to inference, no response found."
        );
    }

    @Test
    public void testToolWithSystemIndex() {
        CreateAnomalyDetectorTool tool = CreateAnomalyDetectorTool.Factory.getInstance().create(ImmutableMap.of("model_id", "modelId"));
        // The rejection is expected to be thrown from run() itself, not reported through the listener.
        Exception exception = assertThrows(
            IllegalArgumentException.class,
            () -> tool.run(ImmutableMap.of("index", ML_CONNECTOR_INDEX), ActionListener.<String>wrap(result -> {}, e -> {}))
        );
        assertEquals(
            "CreateAnomalyDetectionTool doesn't support searching indices starting with '.' since it could be system index, current searching index name: "
                + ML_CONNECTOR_INDEX,
            exception.getMessage()
        );
    }

    @Test
    public void testToolWithGetMappingFailed() {
        CreateAnomalyDetectorTool tool = CreateAnomalyDetectorTool.Factory.getInstance().create(ImmutableMap.of("model_id", "modelId"));
        doAnswer(invocation -> {
            ActionListener<GetMappingsResponse> listener = invocation.getArgument(1);
            listener.onFailure(new Exception("No mapping found for the index: " + mockedIndexName));
            return null;
        }).when(indicesAdminClient).getMappings(any(), any());

        assertToolFailsWith(tool, ImmutableMap.of("index", mockedIndexName), "No mapping found for the index: " + mockedIndexName);
    }

    @Test
    public void testToolWithPredictModelFailed() {
        CreateAnomalyDetectorTool tool = CreateAnomalyDetectorTool.Factory.getInstance().create(ImmutableMap.of("model_id", "modelId"));
        doAnswer(invocation -> {
            ActionListener<MLTaskResponse> listener = invocation.getArgument(2);
            listener.onFailure(new Exception("predict model failed"));
            return null;
        }).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(), any());

        assertToolFailsWith(tool, ImmutableMap.of("index", mockedIndexName), "predict model failed");
    }

    /**
     * Runs the tool and asserts that it answered through {@code onResponse} with {@code expected}.
     * A failure callback is logged and turned into an {@link AssertionError}; a missing callback
     * leaves the captured value {@code null} and therefore fails the final comparison.
     */
    private void assertToolReturns(CreateAnomalyDetectorTool tool, Map<String, String> parameters, String expected) {
        String[] captured = new String[1];
        tool.run(parameters, ActionListener.<String>wrap(response -> captured[0] = response, e -> {
            log.error("Unexpected tool failure", e);
            throw new AssertionError("Expected a response but the tool failed: " + e.getMessage(), e);
        }));
        assertEquals(expected, captured[0]);
    }

    /**
     * Runs the tool and asserts that it reported a failure carrying {@code expectedMessage}.
     * The failure is rethrown from the listener (keeping the original exception as cause) so
     * that {@code assertThrows} can only pass when {@code onFailure} was actually invoked.
     */
    private void assertToolFailsWith(CreateAnomalyDetectorTool tool, Map<String, String> parameters, String expectedMessage) {
        Exception exception = assertThrows(
            IllegalStateException.class,
            () -> tool.run(parameters, ActionListener.<String>wrap(response -> {
                throw new AssertionError("Expected a failure but the tool returned: " + response);
            }, e -> {
                throw new IllegalStateException(e.getMessage(), e);
            }))
        );
        assertEquals(expectedMessage, exception.getMessage());
    }

    /** Makes the mocked model answer with the given {@code response} value (may be {@code null}). */
    private void stubModelResponse(String response) {
        modelReturns = Collections.singletonMap("response", response);
        modelTensor = new ModelTensor("tensor", new Number[0], new long[0], MLResultDataType.STRING, null, null, modelReturns);
        initMLTensors();
    }

    /** Builds the fake index mapping (two numeric fields and one date field) and the default model answer. */
    private void createMappings() {
        indexMappings = new HashMap<>();
        indexMappings
            .put(
                "properties",
                ImmutableMap
                    .of(
                        "response",
                        ImmutableMap.of("type", "integer"),
                        "responseLatency",
                        ImmutableMap.of("type", "float"),
                        "date",
                        ImmutableMap.of("type", "date")
                    )
            );
        mockedMappings = new HashMap<>();
        mockedMappings.put(mockedIndexName, mappingMetadata);

        modelReturns = Collections.singletonMap("response", mockedResponse);
        modelTensor = new ModelTensor("tensor", new Number[0], new long[0], MLResultDataType.STRING, null, null, modelReturns);
    }

    /** (Re-)stubs the prediction call so it synchronously answers with the current {@code modelTensor}. */
    private void initMLTensors() {
        when(modelTensors.getMlModelTensors()).thenReturn(Collections.singletonList(modelTensor));
        when(modelTensorOutput.getMlModelOutputs()).thenReturn(Collections.singletonList(modelTensors));
        when(mlTaskResponse.getOutput()).thenReturn(modelTensorOutput);

        // call model
        doAnswer(invocation -> {
            ActionListener<MLTaskResponse> listener = invocation.getArgument(2);
            listener.onResponse(mlTaskResponse);
            return null;
        }).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(), any());
    }
}
Loading

0 comments on commit 4c7d70d

Please sign in to comment.