Skip to content

Commit

Permalink
Merge pull request #39 from ohltyler/ad-tools
Browse files Browse the repository at this point in the history
Add basic search detectors tool; pull plugin deps in gradle run
  • Loading branch information
zane-neo committed Dec 22, 2023
2 parents c037b2a + 00f35b0 commit 9c08646
Show file tree
Hide file tree
Showing 3 changed files with 408 additions and 15 deletions.
67 changes: 52 additions & 15 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,6 @@ buildscript {
opensearch_version = System.getProperty("opensearch.version", "3.0.0-SNAPSHOT")
isSnapshot = "true" == System.getProperty("build.snapshot", "true")
buildVersionQualifier = System.getProperty("build.version_qualifier", "")
version_tokens = opensearch_version.tokenize('-')
opensearch_build = version_tokens[0] + '.0'
if (buildVersionQualifier) {
opensearch_build += "-${buildVersionQualifier}"
}
if (isSnapshot) {
opensearch_build += "-SNAPSHOT"
}
}

repositories {
Expand Down Expand Up @@ -77,6 +69,8 @@ apply plugin: 'opensearch.testclusters'
apply plugin: 'opensearch.pluginzip'

def sqlJarDirectory = "$buildDir/dependencies/opensearch-sql-plugin"
def jsJarDirectory = "$buildDir/dependencies/opensearch-job-scheduler"
def adJarDirectory = "$buildDir/dependencies/opensearch-time-series-analytics"

configurations {
zipArchive
Expand All @@ -96,28 +90,52 @@ task addJarsToClasspath(type: Copy) {
include "protocol-${version}.jar"
}
into("$buildDir/classes")

from(fileTree(dir: jsJarDirectory)) {
include "opensearch-job-scheduler-${version}.jar"
}
into("$buildDir/classes")

from(fileTree(dir: adJarDirectory)) {
include "opensearch-time-series-analytics-${version}.jar"
}
into("$buildDir/classes")
}

dependencies {
compileOnly group: 'org.opensearch', name:'opensearch-ml-client', version: "${version}"
// 3P dependencies
compileOnly group: 'com.google.code.gson', name: 'gson', version: '2.10.1'
compileOnly "org.apache.logging.log4j:log4j-slf4j-impl:2.22.0"
compileOnly group: 'org.json', name: 'json', version: '20231013'
zipArchive group: 'org.opensearch.plugin', name:'opensearch-sql-plugin', version: "${version}"
implementation("com.google.guava:guava:32.1.3-jre")
implementation group: 'org.apache.commons', name: 'commons-lang3', version: '3.13.0'

// Plugin dependencies
compileOnly group: 'org.opensearch', name:'opensearch-ml-client', version: "${version}"
implementation fileTree(dir: jsJarDirectory, include: ["opensearch-job-scheduler-${version}.jar"])
implementation fileTree(dir: adJarDirectory, include: ["opensearch-time-series-analytics-${version}.jar"])
implementation fileTree(dir: sqlJarDirectory, include: ["opensearch-sql-${version}.jar", "ppl-${version}.jar", "protocol-${version}.jar"])
compileOnly "org.opensearch:common-utils:${version}"

// ZipArchive dependencies used for integration tests
zipArchive group: 'org.opensearch.plugin', name:'opensearch-ml-plugin', version: "${version}"
zipArchive group: 'org.opensearch.plugin', name:'opensearch-job-scheduler', version: "${version}"
zipArchive "org.opensearch.plugin:opensearch-anomaly-detection:${version}"
zipArchive group: 'org.opensearch.plugin', name:'opensearch-sql-plugin', version: "${version}"

// Test dependencies
testImplementation "org.opensearch.test:framework:${opensearch_version}"
testImplementation "org.mockito:mockito-core:5.8.0"
testImplementation group: 'junit', name: 'junit', version: '4.13.2'
testImplementation group: 'org.mockito', name: 'mockito-core', version: '5.8.0'
testImplementation group: 'org.mockito', name: 'mockito-inline', version: '5.2.0'
testImplementation("net.bytebuddy:byte-buddy:1.14.7")
testImplementation("net.bytebuddy:byte-buddy-agent:1.14.7")
testImplementation 'org.junit.jupiter:junit-jupiter-api:5.10.1'
testImplementation 'org.mockito:mockito-junit-jupiter:5.8.0'
testImplementation "com.nhaarman.mockitokotlin2:mockito-kotlin:2.2.0"
testImplementation "com.cronutils:cron-utils:9.2.1"
testImplementation "commons-validator:commons-validator:1.8.0"
testRuntimeOnly 'org.junit.jupiter:junit-jupiter-engine:5.10.1'

// ZipArchive dependencies used for integration tests
zipArchive group: 'org.opensearch.plugin', name:'opensearch-ml-plugin', version: "${opensearch_build}"
}

task extractSqlJar(type: Copy) {
Expand All @@ -126,7 +144,21 @@ task extractSqlJar(type: Copy) {
into sqlJarDirectory
}

task extractJsJar(type: Copy) {
mustRunAfter()
from(zipTree(configurations.zipArchive.find { it.name.startsWith("opensearch-job-scheduler")}))
into jsJarDirectory
}

task extractAdJar(type: Copy) {
mustRunAfter()
from(zipTree(configurations.zipArchive.find { it.name.startsWith("opensearch-anomaly-detection")}))
into adJarDirectory
}

tasks.addJarsToClasspath.dependsOn(extractSqlJar)
tasks.addJarsToClasspath.dependsOn(extractJsJar)
tasks.addJarsToClasspath.dependsOn(extractAdJar)
project.tasks.delombok.dependsOn(addJarsToClasspath)
tasks.publishNebulaPublicationToMavenLocal.dependsOn ':generatePomFileForPluginZipPublication'
tasks.validateNebulaPom.dependsOn ':generatePomFileForPluginZipPublication'
Expand All @@ -137,12 +169,13 @@ testingConventions.enabled = false
thirdPartyAudit.enabled = false

test {
useJUnitPlatform()
testLogging {
exceptionFormat "full"
events "skipped", "passed", "failed" // "started"
showStandardStreams true
}
include '**/*Tests.class'
systemProperty 'tests.security.manager', 'false'
}

spotless {
Expand All @@ -161,6 +194,8 @@ spotless {

compileJava {
dependsOn extractSqlJar
dependsOn extractJsJar
dependsOn extractAdJar
dependsOn delombok
options.compilerArgs.addAll(["-processor", 'lombok.launch.AnnotationProcessorHider$AnnotationProcessor'])
}
Expand All @@ -169,6 +204,8 @@ compileTestJava {
options.compilerArgs.addAll(["-processor", 'lombok.launch.AnnotationProcessorHider$AnnotationProcessor'])
}

forbiddenApisTest.ignoreFailures = true


opensearchplugin {
name 'skills'
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.agent.tools;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;

import org.apache.commons.lang3.StringUtils;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.ad.client.AnomalyDetectionNodeClient;
import org.opensearch.client.Client;
import org.opensearch.core.action.ActionListener;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.RangeQueryBuilder;
import org.opensearch.index.query.TermQueryBuilder;
import org.opensearch.index.query.WildcardQueryBuilder;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.common.spi.tools.Parser;
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
import org.opensearch.search.SearchHit;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.search.sort.SortOrder;

import lombok.Getter;
import lombok.Setter;

@ToolAnnotation(SearchAnomalyDetectorsTool.TYPE)
public class SearchAnomalyDetectorsTool implements Tool {
public static final String TYPE = "SearchAnomalyDetectorsTool";
private static final String DEFAULT_DESCRIPTION = "Use this tool to search anomaly detectors.";

@Setter
@Getter
private String name = TYPE;
@Getter
@Setter
private String description = DEFAULT_DESCRIPTION;

@Getter
private String version;

private Client client;

private AnomalyDetectionNodeClient adClient;

@Setter
private Parser<?, ?> inputParser;
@Setter
private Parser<?, ?> outputParser;

public SearchAnomalyDetectorsTool(Client client) {
this.client = client;
this.adClient = new AnomalyDetectionNodeClient(client);

// probably keep this overridden output parser. need to ensure the output matches what's expected
outputParser = new Parser<>() {
@Override
public Object parse(Object o) {
@SuppressWarnings("unchecked")
List<ModelTensors> mlModelOutputs = (List<ModelTensors>) o;
return mlModelOutputs.get(0).getMlModelTensors().get(0).getDataAsMap().get("response");
}
};
}

// Response is currently in a simple string format including the list of anomaly detectors (only name and ID attached), and
// number of total detectors. The output will likely need to be updated, standardized, and include more fields in the
// future to cover a sufficient amount of potential questions the agent will need to handle.
@Override
public <T> void run(Map<String, String> parameters, ActionListener<T> listener) {
final String detectorName = parameters.getOrDefault("detectorName", null);
final String detectorNamePattern = parameters.getOrDefault("detectorNamePattern", null);
final String indices = parameters.getOrDefault("indices", null);
final Boolean highCardinality = parameters.containsKey("highCardinality")
? Boolean.parseBoolean(parameters.get("highCardinality"))
: null;
final Long lastUpdateTime = parameters.containsKey("lastUpdateTime") && StringUtils.isNumeric(parameters.get("lastUpdateTime"))
? Long.parseLong(parameters.get("lastUpdateTime")) : null;
final String sortOrderStr = parameters.getOrDefault("sortOrder", "asc");
final SortOrder sortOrder = sortOrderStr.equalsIgnoreCase("asc") ? SortOrder.ASC : SortOrder.DESC;
final String sortString = parameters.getOrDefault("sortString", "name.keyword");
final int size = parameters.containsKey("size") ? Integer.parseInt(parameters.get("size")) : 20;
final int startIndex = parameters.containsKey("startIndex") ? Integer.parseInt(parameters.get("startIndex")) : 0;
final Boolean running = parameters.containsKey("running") ? Boolean.parseBoolean(parameters.get("running")) : null;
final Boolean disabled = parameters.containsKey("disabled") ? Boolean.parseBoolean(parameters.get("disabled")) : null;
final Boolean failed = parameters.containsKey("failed") ? Boolean.parseBoolean(parameters.get("failed")) : null;

List<QueryBuilder> mustList = new ArrayList<QueryBuilder>();
if (detectorName != null) {
mustList.add(new TermQueryBuilder("name.keyword", detectorName));
}
if (detectorNamePattern != null) {
mustList.add(new WildcardQueryBuilder("name.keyword", detectorNamePattern));
}
if (indices != null) {
mustList.add(new TermQueryBuilder("indices", indices));
}
if (highCardinality != null) {
mustList.add(new TermQueryBuilder("detector_type", highCardinality ? "MULTI_ENTITY" : "SINGLE_ENTITY"));
}
if (lastUpdateTime != null) {
mustList.add(new BoolQueryBuilder().filter(new RangeQueryBuilder("last_update_time").gte(lastUpdateTime)));

}

BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
boolQueryBuilder.must().addAll(mustList);
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder()
.query(boolQueryBuilder)
.size(size)
.from(startIndex)
.sort(sortString, sortOrder);

SearchRequest searchDetectorRequest = new SearchRequest().source(searchSourceBuilder);

if (running != null || disabled != null || failed != null) {
// TODO: add a listener to trigger when the first response is received, to trigger the profile API call
// to fetch the detector state, etc.
// Will need AD client to onboard the profile API first.
}

ActionListener<SearchResponse> searchDetectorListener = ActionListener.<SearchResponse>wrap(response -> {
StringBuilder sb = new StringBuilder();
SearchHit[] hits = response.getHits().getHits();
sb.append("AnomalyDetectors=[");
for (SearchHit hit : hits) {
sb.append("{");
sb.append("id=").append(hit.getId()).append(",");
sb.append("name=").append(hit.getSourceAsMap().get("name"));
sb.append("}");
}
sb.append("]");
sb.append("TotalAnomalyDetectors=").append(response.getHits().getTotalHits().value);
listener.onResponse((T) sb.toString());
}, e -> { listener.onFailure(e); });

adClient.searchAnomalyDetectors(searchDetectorRequest, searchDetectorListener);
}

@Override
public boolean validate(Map<String, String> parameters) {
return true;
}

@Override
public String getType() {
return TYPE;
}

/**
* Factory for the {@link SearchAnomalyDetectorsTool}
*/
public static class Factory implements Tool.Factory<SearchAnomalyDetectorsTool> {
private Client client;

private AnomalyDetectionNodeClient adClient;

private static Factory INSTANCE;

/**
* Create or return the singleton factory instance
*/
public static Factory getInstance() {
if (INSTANCE != null) {
return INSTANCE;
}
synchronized (SearchAnomalyDetectorsTool.class) {
if (INSTANCE != null) {
return INSTANCE;
}
INSTANCE = new Factory();
return INSTANCE;
}
}

/**
* Initialize this factory
* @param client The OpenSearch client
*/
public void init(Client client) {
this.client = client;
this.adClient = new AnomalyDetectionNodeClient(client);
}

@Override
public SearchAnomalyDetectorsTool create(Map<String, Object> map) {
return new SearchAnomalyDetectorsTool(client);
}

@Override
public String getDefaultDescription() {
return DEFAULT_DESCRIPTION;
}
}

}
Loading

0 comments on commit 9c08646

Please sign in to comment.