diff --git a/src/main/java/org/opensearch/agent/tools/RCATool.java b/src/main/java/org/opensearch/agent/tools/RCATool.java index 9730f58a..04d2f00d 100644 --- a/src/main/java/org/opensearch/agent/tools/RCATool.java +++ b/src/main/java/org/opensearch/agent/tools/RCATool.java @@ -5,19 +5,23 @@ package org.opensearch.agent.tools; +import java.io.IOException; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.stream.Collectors; +import org.apache.commons.lang3.tuple.Pair; import org.apache.commons.text.StringSubstitutor; import org.apache.logging.log4j.util.Strings; import org.opensearch.action.ActionRequest; import org.opensearch.action.admin.cluster.allocation.ClusterAllocationExplainRequest; -import org.opensearch.action.admin.cluster.allocation.ClusterAllocationExplanation; +import org.opensearch.action.admin.cluster.allocation.ClusterAllocationExplainResponse; +import org.opensearch.action.support.GroupedActionListener; import org.opensearch.client.Client; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.action.ActionResponse; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.ml.common.FunctionName; @@ -140,63 +144,86 @@ public void run(Map parameters, ActionListener listener) String knowledge = parameters.getOrDefault("knowledge", mocked_knowledge); Map knowledgeBase = StringUtils.gson.fromJson(knowledge, Map.class); List> causes = (List>) knowledgeBase.get("potential_causes"); - Map apiToResponse = causes - .stream() - .map(c -> c.get("API_URL")) - .distinct() - .collect(Collectors.toMap(url -> url, url -> invokeAPI(url, parameters))); - causes.forEach(cause -> cause.put("response", apiToResponse.get(cause.get("API_URL")))); - Map LLMParams = new java.util.HashMap<>( - Map.of("phenomenon", (String) knowledgeBase.get("phenomenon"), "causes", StringUtils.gson.toJson(causes)) + List apiList = causes.stream().map(c -> c.get("API_URL")).distinct().collect(Collectors.toList()); + final GroupedActionListener> groupedListener = new GroupedActionListener<>( + ActionListener.wrap(responses -> { + Map apiToResponse = responses + .stream() + .map(this::extractResponse) + .collect(Collectors.toMap(Pair::getKey, Pair::getValue)); + causes.forEach(cause -> cause.put("response", apiToResponse.get(cause.get("API_URL")))); + Map LLMParams = new java.util.HashMap<>( + Map.of("phenomenon", (String) knowledgeBase.get("phenomenon"), "causes", StringUtils.gson.toJson(causes)) + ); + StringSubstitutor substitute = new StringSubstitutor(LLMParams, "${parameters.", "}"); + String finalToolPrompt = substitute.replace(TOOL_PROMPT); + log.error("qh finalToolPrompt: " + finalToolPrompt); + LLMParams.put("prompt", finalToolPrompt); + RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(LLMParams).build(); + ActionRequest request = new MLPredictionTaskRequest( + modelId, + MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build() + ); + client.execute(MLPredictionTaskAction.INSTANCE, request, ActionListener.wrap(r -> { + ModelTensorOutput modelTensorOutput = (ModelTensorOutput) r.getOutput(); + Map dataMap = Optional + .ofNullable(modelTensorOutput.getMlModelOutputs()) + .flatMap(outputs -> outputs.stream().findFirst()) + .flatMap(modelTensors -> modelTensors.getMlModelTensors().stream().findFirst()) + .map(ModelTensor::getDataAsMap) + .orElse(null); + if (dataMap == null) { + throw new IllegalArgumentException("No dataMap returned from LLM."); + } + listener.onResponse((T) dataMap.get("response")); + }, listener::onFailure)); + }, listener::onFailure), + apiList.size() ); - StringSubstitutor substitute = new StringSubstitutor(LLMParams, "${parameters.", "}"); - String finalToolPrompt = substitute.replace(TOOL_PROMPT); - log.error("qh finalToolPrompt: " + finalToolPrompt); - LLMParams.put("prompt", finalToolPrompt); - RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(LLMParams).build(); - ActionRequest request = new MLPredictionTaskRequest( - modelId, - MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build() - ); - client.execute(MLPredictionTaskAction.INSTANCE, request, ActionListener.wrap(r -> { - ModelTensorOutput modelTensorOutput = (ModelTensorOutput) r.getOutput(); - Map dataMap = Optional - .ofNullable(modelTensorOutput.getMlModelOutputs()) - .flatMap(outputs -> outputs.stream().findFirst()) - .flatMap(modelTensors -> modelTensors.getMlModelTensors().stream().findFirst()) - .map(ModelTensor::getDataAsMap) - .orElse(null); - if (dataMap == null) { - throw new IllegalArgumentException("No dataMap returned from LLM."); - } - listener.onResponse((T) dataMap.get("response")); - }, listener::onFailure)); + apiList.forEach(api -> invokeAPI(api, parameters, groupedListener)); } - private String invokeAPI(String url, Map parameters) { + private void invokeAPI( + String url, + Map parameters, + GroupedActionListener> groupedListener + ) { switch (url) { case "_cluster/allocation/explain": ClusterAllocationExplainRequest request = new ClusterAllocationExplainRequest(); request.setIndex(parameters.get("index")); request.setPrimary(parameters.getOrDefault("alert_type", "").equals("CLUSTER_RED")); request.setShard(0); - try { - // TODO: need to be optimized to use listener to avoid block wait - ClusterAllocationExplanation clusterAllocationExplanation = client - .admin() - .cluster() - .allocationExplain(request) - .get() - .getExplanation(); - XContentBuilder xContentBuilder = clusterAllocationExplanation - .toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); - return xContentBuilder.toString(); - } catch (Exception e) { - return ""; + client + .admin() + .cluster() + .allocationExplain( + request, + ActionListener + .wrap(r -> groupedListener.onResponse(Pair.of("_cluster/allocation/explain", r)), groupedListener::onFailure) + ); + + default: + } + } + + private Pair extractResponse(final Pair response) { + switch (response.getKey()) { + case "_cluster/allocation/explain": + if (response.getValue() instanceof ClusterAllocationExplainResponse) { + ClusterAllocationExplainResponse explainResponse = (ClusterAllocationExplainResponse) response.getValue(); + try { + XContentBuilder xContentBuilder = explainResponse + .getExplanation() + .toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); + return Pair.of(response.getKey(), xContentBuilder.toString()); + } catch (IOException e) { + throw new RuntimeException(e); + } } default: - return ""; + return Pair.of("", ""); } }