From c0c9e3854437987bcc9dd72cc75b4e3bdbbeb8a2 Mon Sep 17 00:00:00 2001 From: "opensearch-trigger-bot[bot]" <98922864+opensearch-trigger-bot[bot]@users.noreply.github.com> Date: Fri, 26 Jan 2024 09:01:26 +0800 Subject: [PATCH] add execute field (#146) (#147) * add execute field * apply spotless --------- (cherry picked from commit 16a26ce06ffefc44010ccbe67f32ba09c8bb4e8c) Signed-off-by: xinyual Signed-off-by: github-actions[bot] Co-authored-by: github-actions[bot] Signed-off-by: yuye-aws --- .../java/org/opensearch/agent/tools/PPLTool.java | 12 ++++++++++-- .../org/opensearch/agent/tools/PPLToolTests.java | 13 +++++++++++++ 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/src/main/java/org/opensearch/agent/tools/PPLTool.java b/src/main/java/org/opensearch/agent/tools/PPLTool.java index cec1114d..c008d523 100644 --- a/src/main/java/org/opensearch/agent/tools/PPLTool.java +++ b/src/main/java/org/opensearch/agent/tools/PPLTool.java @@ -87,6 +87,8 @@ public class PPLTool implements Tool { private String contextPrompt; + private Boolean execute; + private PPLModelType pplModelType; private static Gson gson = new Gson(); @@ -120,7 +122,7 @@ public static PPLModelType from(String value) { } - public PPLTool(Client client, String modelId, String contextPrompt, String pplModelType) { + public PPLTool(Client client, String modelId, String contextPrompt, String pplModelType, boolean execute) { this.client = client; this.modelId = modelId; this.pplModelType = PPLModelType.from(pplModelType); @@ -129,6 +131,7 @@ public PPLTool(Client client, String modelId, String contextPrompt, String pplMo } else { this.contextPrompt = contextPrompt; } + this.execute = execute; } @Override @@ -169,6 +172,10 @@ public void run(Map parameters, ActionListener listener) ModelTensor modelTensor = modelTensors.getMlModelTensors().get(0); Map dataAsMap = (Map) modelTensor.getDataAsMap(); String ppl = parseOutput(dataAsMap.get("response"), indexName); + if (!this.execute) { + listener.onResponse((T) ppl); + return; + } JSONObject jsonContent = new JSONObject(ImmutableMap.of("query", ppl)); PPLQueryRequest pplQueryRequest = new PPLQueryRequest(ppl, jsonContent, null, "jdbc"); TransportPPLQueryRequest transportPPLQueryRequest = new TransportPPLQueryRequest(pplQueryRequest); @@ -253,7 +260,8 @@ public PPLTool create(Map map) { client, (String) map.get("model_id"), (String) map.getOrDefault("prompt", ""), - (String) map.getOrDefault("model_type", "") + (String) map.getOrDefault("model_type", ""), + (boolean) map.getOrDefault("execute", true) ); } diff --git a/src/test/java/org/opensearch/agent/tools/PPLToolTests.java b/src/test/java/org/opensearch/agent/tools/PPLToolTests.java index 680586d0..129c2411 100644 --- a/src/test/java/org/opensearch/agent/tools/PPLToolTests.java +++ b/src/test/java/org/opensearch/agent/tools/PPLToolTests.java @@ -138,6 +138,19 @@ public void testTool() { } + @Test + public void testTool_with_WithoutExecution() { + PPLTool tool = PPLTool.Factory + .getInstance() + .create(ImmutableMap.of("model_id", "modelId", "model_type", "claude", "execute", false)); + assertEquals(PPLTool.TYPE, tool.getName()); + + tool.run(ImmutableMap.of("index", "demo", "question", "demo"), ActionListener.wrap(executePPLResult -> { + assertEquals("source=demo| head 1", executePPLResult); + }, e -> { log.info(e); })); + + } + @Test public void testTool_with_DefaultPrompt() { PPLTool tool = PPLTool.Factory.getInstance().create(ImmutableMap.of("model_id", "modelId", "model_type", "claude"));