diff --git a/.run/Debug Skills.run.xml b/.run/Debug Skills.run.xml new file mode 100644 index 00000000..1dd23b84 --- /dev/null +++ b/.run/Debug Skills.run.xml @@ -0,0 +1,15 @@ + + + + \ No newline at end of file diff --git a/src/main/java/org/opensearch/agent/tools/RCATool.java b/src/main/java/org/opensearch/agent/tools/RCATool.java index 94d62ef5..ed662399 100644 --- a/src/main/java/org/opensearch/agent/tools/RCATool.java +++ b/src/main/java/org/opensearch/agent/tools/RCATool.java @@ -5,11 +5,8 @@ package org.opensearch.agent.tools; -import static org.apache.commons.text.StringEscapeUtils.unescapeJson; - import java.io.IOException; import java.util.HashMap; -import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; @@ -22,6 +19,7 @@ import org.opensearch.action.admin.cluster.allocation.ClusterAllocationExplainRequest; import org.opensearch.action.admin.cluster.allocation.ClusterAllocationExplainResponse; import org.opensearch.agent.tools.utils.ClusterStatsUtil; +import org.opensearch.agent.tools.utils.RCADoc; import org.opensearch.client.Client; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.action.ActionListener; @@ -36,7 +34,6 @@ import org.opensearch.ml.common.spi.tools.ToolAnnotation; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; -import org.opensearch.ml.common.utils.StringUtils; import lombok.Getter; import lombok.Setter; @@ -94,16 +91,14 @@ public boolean validate(Map parameters) { @SuppressWarnings("unchecked") public void runOption1(Map parameters, ActionListener listener) { String knowledge = parameters.get(KNOWLEDGE_BASE_TOOL_OUTPUT_FIELD); - knowledge = unescapeJson(knowledge); - Map knowledgeBase = StringUtils.gson.fromJson(knowledge, Map.class); - List> causes = (List>) knowledgeBase.get("causes"); - Set apis = causes.stream().map(c -> c.get(API_URL_FIELD)).collect(Collectors.toSet()); + RCADoc rcaDoc = new RCADoc(knowledge); + Set apis = rcaDoc.causes.stream().map(c -> c.apiUrl).collect(Collectors.toSet()); ActionListener> apiListener = new ActionListener<>() { @Override public void onResponse(Map apiToResponse) { - causes.forEach(cause -> cause.put("response", apiToResponse.get(cause.get(API_URL_FIELD)))); + rcaDoc.causes.forEach(cause -> cause.response = apiToResponse.get(cause.apiUrl)); Map LLMParams = new java.util.HashMap<>( - Map.of("phenomenon", (String) knowledgeBase.get("phenomenon"), "causes", StringUtils.gson.toJson(causes)) + Map.of("phenomenon", rcaDoc.phenomenon, "causes", rcaDoc.toString()) ); StringSubstitutor substitute = new StringSubstitutor(LLMParams, "${parameters.", "}"); String finalToolPrompt = substitute.replace(TOOL_PROMPT); diff --git a/src/main/java/org/opensearch/agent/tools/utils/RCACause.java b/src/main/java/org/opensearch/agent/tools/utils/RCACause.java index 3b79aa04..33cfedf0 100644 --- a/src/main/java/org/opensearch/agent/tools/utils/RCACause.java +++ b/src/main/java/org/opensearch/agent/tools/utils/RCACause.java @@ -5,6 +5,10 @@ package org.opensearch.agent.tools.utils; +import static org.apache.commons.text.StringEscapeUtils.escapeJson; + +import java.util.Map; + import lombok.Getter; @Getter @@ -15,4 +19,31 @@ public RCACause() {} public String apiUrl; public String expectedResponse; public String response; + + public static RCACause fromMap(Map map) { + RCACause cause = new RCACause(); + cause.reason = map.getOrDefault("reason", ""); + cause.apiUrl = map.getOrDefault("api_url", ""); + cause.expectedResponse = map.getOrDefault("expected_response", ""); + cause.response = map.getOrDefault("response", ""); + return cause; + } + + @Override + public String toString() { + return "{" + + "\"reason\": \"" + + escapeJson(reason) + + "\"," + + "\"apiUrl\": \"" + + escapeJson(apiUrl) + + "\"," + + "\"expectedResponse\": \"" + + escapeJson(expectedResponse) + + "\"," + + "\"response\": \"" + + escapeJson(response) + + "\"" + + "}"; + } } diff --git a/src/main/java/org/opensearch/agent/tools/utils/RCADoc.java b/src/main/java/org/opensearch/agent/tools/utils/RCADoc.java index a2428e83..fb960e10 100644 --- a/src/main/java/org/opensearch/agent/tools/utils/RCADoc.java +++ b/src/main/java/org/opensearch/agent/tools/utils/RCADoc.java @@ -5,7 +5,14 @@ package org.opensearch.agent.tools.utils; +import static org.apache.commons.text.StringEscapeUtils.escapeJson; +import static org.apache.commons.text.StringEscapeUtils.unescapeJson; + import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import org.opensearch.ml.common.utils.StringUtils; import lombok.Getter; @@ -15,5 +22,31 @@ public class RCADoc { public String phenomenon; public List causes; - public RCADoc() {} + @SuppressWarnings("unchecked") + public RCADoc(String knowledge) { + knowledge = unescapeJson(knowledge); + Map knowledgeBase = StringUtils.gson.fromJson(knowledge, Map.class); + this.phenomenon = (String) knowledgeBase.get("phenomenon"); + this.causes = ((List>) knowledgeBase.get("causes")) + .stream() + .map(RCACause::fromMap) + .collect(Collectors.toList()); + } + + @Override + public String toString() { + StringBuilder json = new StringBuilder(); + json.append("{"); + json.append("\"phenomenon\":\"").append(escapeJson(phenomenon)).append("\","); + json.append("\"causes\": ["); + for (int i = 0; i < causes.size(); i++) { + json.append(causes.get(i).toString()); + if (i < causes.size() - 1) { + json.append(", "); + } + } + + json.append("]}"); + return json.toString(); + } }