diff --git a/src/main/java/org/opensearch/agent/ToolPlugin.java b/src/main/java/org/opensearch/agent/ToolPlugin.java index 823320f0..fb3d63f7 100644 --- a/src/main/java/org/opensearch/agent/ToolPlugin.java +++ b/src/main/java/org/opensearch/agent/ToolPlugin.java @@ -15,6 +15,7 @@ import org.opensearch.agent.tools.CreateAnomalyDetectorTool; import org.opensearch.agent.tools.NeuralSparseSearchTool; import org.opensearch.agent.tools.PPLTool; +import org.opensearch.agent.tools.PainlessTool; import org.opensearch.agent.tools.RAGTool; import org.opensearch.agent.tools.SearchAlertsTool; import org.opensearch.agent.tools.SearchAnomalyDetectorsTool; @@ -77,6 +78,7 @@ public Collection createComponents( SearchMonitorsTool.Factory.getInstance().init(client); CreateAlertTool.Factory.getInstance().init(client); CreateAnomalyDetectorTool.Factory.getInstance().init(client); + PainlessTool.Factory.getInstance().init(scriptService); return Collections.emptyList(); } @@ -93,7 +95,8 @@ public List> getToolFactories() { SearchAnomalyResultsTool.Factory.getInstance(), SearchMonitorsTool.Factory.getInstance(), CreateAlertTool.Factory.getInstance(), - CreateAnomalyDetectorTool.Factory.getInstance() + CreateAnomalyDetectorTool.Factory.getInstance(), + PainlessTool.Factory.getInstance() ); } diff --git a/src/main/java/org/opensearch/agent/tools/PainlessTool.java b/src/main/java/org/opensearch/agent/tools/PainlessTool.java new file mode 100644 index 00000000..4f04fad9 --- /dev/null +++ b/src/main/java/org/opensearch/agent/tools/PainlessTool.java @@ -0,0 +1,150 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.tools; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.spi.tools.Tool; +import org.opensearch.ml.common.spi.tools.ToolAnnotation; +import org.opensearch.script.Script; +import org.opensearch.script.ScriptService; +import org.opensearch.script.ScriptType; +import org.opensearch.script.TemplateScript; + +import com.google.gson.Gson; + +import lombok.Getter; +import lombok.Setter; +import lombok.extern.log4j.Log4j2; + +/** + * use case for this tool will only focus on flow agent + */ +@Log4j2 +@Setter +@Getter +@ToolAnnotation(PainlessTool.TYPE) +public class PainlessTool implements Tool { + public static final String TYPE = "PainlessTool"; + private static final String DEFAULT_DESCRIPTION = "Use this tool to execute painless script"; + + @Setter + @Getter + private String name = TYPE; + + @Getter + private String type = TYPE; + + @Getter + @Setter + private String description = DEFAULT_DESCRIPTION; + + @Getter + private String version; + + private ScriptService scriptService; + @Setter + private String scriptCode; + + public PainlessTool(ScriptService scriptEngine, String script) { + this.scriptService = scriptEngine; + this.scriptCode = script; + } + + private Gson gson = new Gson(); + + @Override + public void run(Map parameters, ActionListener listener) { + Script script = new Script(ScriptType.INLINE, "painless", scriptCode, Collections.emptyMap()); + Map flattenedParameters = new HashMap<>(); + for (Map.Entry entry : parameters.entrySet()) { + // keep original values and flatten + flattenedParameters.put(entry.getKey(), entry.getValue()); + // TODO default is json parser. we may support format + try { + String value = org.apache.commons.text.StringEscapeUtils.unescapeJson(entry.getValue()); + Map map = gson.fromJson(value, Map.class); + flattenMap(map, flattenedParameters, entry.getKey()); + } catch (Throwable ignored) {} + } + TemplateScript templateScript = scriptService.compile(script, TemplateScript.CONTEXT).newInstance(flattenedParameters); + try { + String result = templateScript.execute(); + listener.onResponse(result == null ? (T) "" : (T) result); + } catch (Exception e) { + listener.onFailure(e); + } + } + + private void flattenMap(Map map, Map flatMap, String prefix) { + for (Map.Entry entry : map.entrySet()) { + String key = entry.getKey(); + if (prefix != null && !prefix.isEmpty()) { + key = prefix + "." + entry.getKey(); + } + Object value = entry.getValue(); + if (value instanceof Map) { + flattenMap((Map) value, flatMap, key); + } else { + flatMap.put(key, value); + } + } + } + + @Override + public boolean validate(Map map) { + return true; + } + + public static class Factory implements Tool.Factory { + private ScriptService scriptService; + + private static PainlessTool.Factory INSTANCE; + + public static PainlessTool.Factory getInstance() { + if (INSTANCE != null) { + return INSTANCE; + } + synchronized (PainlessTool.class) { + if (INSTANCE != null) { + return INSTANCE; + } + INSTANCE = new PainlessTool.Factory(); + return INSTANCE; + } + } + + public void init(ScriptService scriptService) { + this.scriptService = scriptService; + } + + @Override + public PainlessTool create(Map map) { + String script = (String) map.get("script"); + // TODO add script non null/empty check + return new PainlessTool(scriptService, script); + } + + @Override + public String getDefaultDescription() { + return DEFAULT_DESCRIPTION; + } + + @Override + public String getDefaultType() { + return TYPE; + } + + @Override + public String getDefaultVersion() { + return null; + } + + } +} diff --git a/src/test/java/org/opensearch/integTest/PainlessToolIT.java b/src/test/java/org/opensearch/integTest/PainlessToolIT.java new file mode 100644 index 00000000..84659d12 --- /dev/null +++ b/src/test/java/org/opensearch/integTest/PainlessToolIT.java @@ -0,0 +1,79 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.integTest; + +import java.io.IOException; +import java.net.URISyntaxException; +import java.nio.file.Files; +import java.nio.file.Path; + +import org.junit.Assert; +import org.junit.Before; + +import lombok.SneakyThrows; +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class PainlessToolIT extends BaseAgentToolsIT { + + private String registerAgentRequestBody; + + @Before + @SneakyThrows + public void setUp() { + super.setUp(); + registerAgentRequestBody = Files + .readString( + Path.of(this.getClass().getClassLoader().getResource("org/opensearch/agent/tools/register_painless_agent.json").toURI()) + ); + } + + public void test_execute() { + String script = "def x = new HashMap(); x.abc = '5'; return x.abc;"; + String agentRequestBody = registerAgentRequestBody.replaceAll("