diff --git a/src/main/java/org/opensearch/agent/tools/RCATool.java b/src/main/java/org/opensearch/agent/tools/RCATool.java index 94d62ef5..fc0e3b80 100644 --- a/src/main/java/org/opensearch/agent/tools/RCATool.java +++ b/src/main/java/org/opensearch/agent/tools/RCATool.java @@ -7,7 +7,6 @@ import static org.apache.commons.text.StringEscapeUtils.unescapeJson; -import java.io.IOException; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -23,10 +22,9 @@ import org.opensearch.action.admin.cluster.allocation.ClusterAllocationExplainResponse; import org.opensearch.agent.tools.utils.ClusterStatsUtil; import org.opensearch.client.Client; -import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.cluster.routing.allocation.NodeAllocationResult; +import org.opensearch.cluster.routing.allocation.decider.Decision; import org.opensearch.core.action.ActionListener; -import org.opensearch.core.xcontent.ToXContent; -import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; @@ -86,9 +84,12 @@ public boolean validate(Map parameters) { + "Human: PHENOMENON\n" + "--------------------\n" + "${parameters.phenomenon} \n\n" - + "Human: POTENTIAL CAUSES AND RESPONSE\n" + + "Human: POTENTIAL CAUSES\n" + "--------------------\n" + "${parameters.causes} \n\n" + + "Human: API RESPONSES\n" + + "${parameters.responses} \n\n" + + "--------------------\n" + "Assistant: "; @SuppressWarnings("unchecked") @@ -101,9 +102,16 @@ public void runOption1(Map parameters, ActionListener lis ActionListener> apiListener = new ActionListener<>() { @Override public void onResponse(Map apiToResponse) { - causes.forEach(cause -> cause.put("response", apiToResponse.get(cause.get(API_URL_FIELD)))); Map LLMParams = new java.util.HashMap<>( - Map.of("phenomenon", (String) knowledgeBase.get("phenomenon"), "causes", StringUtils.gson.toJson(causes)) + Map + .of( + "phenomenon", + (String) knowledgeBase.get("phenomenon"), + "causes", + StringUtils.gson.toJson(causes), + "responses", + StringUtils.gson.toJson(apiToResponse) + ) ); StringSubstitutor substitute = new StringSubstitutor(LLMParams, "${parameters.", "}"); String finalToolPrompt = substitute.replace(TOOL_PROMPT); @@ -200,15 +208,19 @@ private void invokeAPI(String url, Map parameters, ActionListene ActionListener apiListener = new ActionListener<>() { @Override public void onResponse(ClusterAllocationExplainResponse allocationExplainResponse) { - try { - XContentBuilder xContentBuilder = allocationExplainResponse - .getExplanation() - .toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); - listener.onResponse(xContentBuilder.toString()); - } catch (IOException e) { - log.error("Failed to invoke api _cluster/allocation/explain due to exception:", e); - listener.onFailure(e); + List nodeDecisions = allocationExplainResponse + .getExplanation() + .getShardAllocationDecision() + .getAllocateDecision() + .getNodeDecisions(); + StringBuilder stringBuilder = new StringBuilder(); + for (NodeAllocationResult nodeDecision : nodeDecisions) { + List decisions = nodeDecision.getCanAllocateDecision().getDecisions(); + for (Decision decision : decisions) { + stringBuilder.append(decision.getExplanation()); + } } + listener.onResponse(stringBuilder.toString()); } @Override