Skip to content

Commit

Permalink
optimize block wait by using GroupedActionListener
Browse files Browse the repository at this point in the history
  • Loading branch information
qianheng-aws committed Aug 5, 2024
1 parent e57da52 commit 352b353
Showing 1 changed file with 73 additions and 46 deletions.
119 changes: 73 additions & 46 deletions src/main/java/org/opensearch/agent/tools/RCATool.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,23 @@

package org.opensearch.agent.tools;

import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;

import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.text.StringSubstitutor;
import org.apache.logging.log4j.util.Strings;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.admin.cluster.allocation.ClusterAllocationExplainRequest;
import org.opensearch.action.admin.cluster.allocation.ClusterAllocationExplanation;
import org.opensearch.action.admin.cluster.allocation.ClusterAllocationExplainResponse;
import org.opensearch.action.support.GroupedActionListener;
import org.opensearch.client.Client;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.action.ActionResponse;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.ml.common.FunctionName;
Expand Down Expand Up @@ -140,63 +144,86 @@ public <T> void run(Map<String, String> parameters, ActionListener<T> listener)
String knowledge = parameters.getOrDefault("knowledge", mocked_knowledge);
Map<String, ?> knowledgeBase = StringUtils.gson.fromJson(knowledge, Map.class);
List<Map<String, String>> causes = (List<Map<String, String>>) knowledgeBase.get("potential_causes");
Map<String, String> apiToResponse = causes
.stream()
.map(c -> c.get("API_URL"))
.distinct()
.collect(Collectors.toMap(url -> url, url -> invokeAPI(url, parameters)));
causes.forEach(cause -> cause.put("response", apiToResponse.get(cause.get("API_URL"))));
Map<String, String> LLMParams = new java.util.HashMap<>(
Map.of("phenomenon", (String) knowledgeBase.get("phenomenon"), "causes", StringUtils.gson.toJson(causes))
List<String> apiList = causes.stream().map(c -> c.get("API_URL")).distinct().collect(Collectors.toList());
final GroupedActionListener<Pair<String, ActionResponse>> groupedListener = new GroupedActionListener<>(
ActionListener.wrap(responses -> {
Map<String, String> apiToResponse = responses
.stream()
.map(this::extractResponse)
.collect(Collectors.toMap(Pair::getKey, Pair::getValue));
causes.forEach(cause -> cause.put("response", apiToResponse.get(cause.get("API_URL"))));
Map<String, String> LLMParams = new java.util.HashMap<>(
Map.of("phenomenon", (String) knowledgeBase.get("phenomenon"), "causes", StringUtils.gson.toJson(causes))
);
StringSubstitutor substitute = new StringSubstitutor(LLMParams, "${parameters.", "}");
String finalToolPrompt = substitute.replace(TOOL_PROMPT);
log.error("qh finalToolPrompt: " + finalToolPrompt);
LLMParams.put("prompt", finalToolPrompt);
RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(LLMParams).build();
ActionRequest request = new MLPredictionTaskRequest(
modelId,
MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build()
);
client.execute(MLPredictionTaskAction.INSTANCE, request, ActionListener.wrap(r -> {
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) r.getOutput();
Map<String, ?> dataMap = Optional
.ofNullable(modelTensorOutput.getMlModelOutputs())
.flatMap(outputs -> outputs.stream().findFirst())
.flatMap(modelTensors -> modelTensors.getMlModelTensors().stream().findFirst())
.map(ModelTensor::getDataAsMap)
.orElse(null);
if (dataMap == null) {
throw new IllegalArgumentException("No dataMap returned from LLM.");
}
listener.onResponse((T) dataMap.get("response"));
}, listener::onFailure));
}, listener::onFailure),
apiList.size()
);
StringSubstitutor substitute = new StringSubstitutor(LLMParams, "${parameters.", "}");
String finalToolPrompt = substitute.replace(TOOL_PROMPT);
log.error("qh finalToolPrompt: " + finalToolPrompt);
LLMParams.put("prompt", finalToolPrompt);
RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(LLMParams).build();
ActionRequest request = new MLPredictionTaskRequest(
modelId,
MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build()
);
client.execute(MLPredictionTaskAction.INSTANCE, request, ActionListener.wrap(r -> {
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) r.getOutput();
Map<String, ?> dataMap = Optional
.ofNullable(modelTensorOutput.getMlModelOutputs())
.flatMap(outputs -> outputs.stream().findFirst())
.flatMap(modelTensors -> modelTensors.getMlModelTensors().stream().findFirst())
.map(ModelTensor::getDataAsMap)
.orElse(null);
if (dataMap == null) {
throw new IllegalArgumentException("No dataMap returned from LLM.");
}
listener.onResponse((T) dataMap.get("response"));
}, listener::onFailure));
apiList.forEach(api -> invokeAPI(api, parameters, groupedListener));
}

private String invokeAPI(String url, Map<String, String> parameters) {
private void invokeAPI(
String url,
Map<String, String> parameters,
GroupedActionListener<Pair<String, ActionResponse>> groupedListener
) {
switch (url) {
case "_cluster/allocation/explain":
ClusterAllocationExplainRequest request = new ClusterAllocationExplainRequest();
request.setIndex(parameters.get("index"));
request.setPrimary(parameters.getOrDefault("alert_type", "").equals("CLUSTER_RED"));
request.setShard(0);
try {
// TODO: need to be optimized to use listener to avoid block wait
ClusterAllocationExplanation clusterAllocationExplanation = client
.admin()
.cluster()
.allocationExplain(request)
.get()
.getExplanation();
XContentBuilder xContentBuilder = clusterAllocationExplanation
.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS);
return xContentBuilder.toString();
} catch (Exception e) {
return "";
client
.admin()
.cluster()
.allocationExplain(
request,
ActionListener
.wrap(r -> groupedListener.onResponse(Pair.of("_cluster/allocation/explain", r)), groupedListener::onFailure)
);

default:
}
}

/**
 * Converts a raw API response into a pair of (API url, response rendered as a JSON string).
 *
 * <p>The API url is always preserved as the key, including when the response cannot be rendered,
 * so that results collected into a map keyed by url never collide on a shared placeholder key.
 *
 * @param response pair of the API url and the response returned for it (the response may be {@code null})
 * @return pair of the same API url and the JSON text of the response, or an empty string when the
 *         API is not recognised or the response is not of the expected type
 * @throws RuntimeException if the response cannot be serialized to JSON
 */
private Pair<String, String> extractResponse(final Pair<String, ? extends ActionResponse> response) {
    final String api = response.getKey();
    if ("_cluster/allocation/explain".equals(api) && response.getValue() instanceof ClusterAllocationExplainResponse) {
        ClusterAllocationExplainResponse explainResponse = (ClusterAllocationExplainResponse) response.getValue();
        try {
            XContentBuilder xContentBuilder = explainResponse
                .getExplanation()
                .toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS);
            return Pair.of(api, xContentBuilder.toString());
        } catch (IOException e) {
            // Keep the cause and say which API failed so the error is diagnosable upstream.
            throw new RuntimeException("Failed to serialize response of API: " + api, e);
        }
    }
    // Unrecognised API or unexpected response type: empty payload, original key.
    return Pair.of(api, "");
}

Expand Down

0 comments on commit 352b353

Please sign in to comment.