Skip to content

Commit

Permalink
add ut and parser (opensearch-project#64)
Browse files Browse the repository at this point in the history
Signed-off-by: xinyual <[email protected]>
Signed-off-by: yuye-aws <[email protected]>
  • Loading branch information
xinyual authored and yuye-aws committed Apr 26, 2024
1 parent 4fb8785 commit 964b2f0
Show file tree
Hide file tree
Showing 3 changed files with 339 additions and 23 deletions.
1 change: 1 addition & 0 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ dependencies {
compileOnly group: 'org.json', name: 'json', version: '20231013'
compileOnly("com.google.guava:guava:32.1.3-jre")
compileOnly group: 'org.apache.commons', name: 'commons-lang3', version: '3.10'
compileOnly group: 'org.apache.commons', name: 'commons-text', version: '1.10.0'

// Plugin dependencies
compileOnly group: 'org.opensearch', name:'opensearch-ml-client', version: "${version}"
Expand Down
82 changes: 59 additions & 23 deletions src/main/java/org/opensearch/agent/tools/PPLTool.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,20 @@

package org.opensearch.agent.tools;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Locale;
import java.util.List;
import java.util.Map;
import java.util.StringJoiner;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import org.apache.commons.lang3.StringUtils;
import org.apache.commons.text.StringSubstitutor;
import org.json.JSONObject;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.admin.indices.mapping.get.GetMappingsRequest;
Expand All @@ -28,9 +29,6 @@
import org.opensearch.cluster.metadata.MappingMetadata;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.action.ActionResponse;
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.index.query.MatchAllQueryBuilder;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
Expand Down Expand Up @@ -58,6 +56,8 @@
import lombok.extern.log4j.Log4j2;

@Log4j2
@Setter
@Getter
@ToolAnnotation(PPLTool.TYPE)
public class PPLTool implements Tool {

Expand Down Expand Up @@ -93,6 +93,9 @@ public PPLTool(Client client, String modelId, String contextPrompt) {
public <T> void run(Map<String, String> parameters, ActionListener<T> listener) {
String indexName = parameters.get("index");
String question = parameters.get("question");
if (StringUtils.isBlank(indexName) || StringUtils.isBlank(question)) {
throw new IllegalArgumentException("Parameter index and question can not be null or empty.");
}
SearchRequest searchRequest = buildSearchRequest(indexName);
GetMappingsRequest getMappingsRequest = buildGetMappingRequest(indexName);
client.admin().indices().getMappings(getMappingsRequest, ActionListener.<GetMappingsResponse>wrap(getMappingsResponse -> {
Expand All @@ -114,7 +117,7 @@ public <T> void run(Map<String, String> parameters, ActionListener<T> listener)
ModelTensors modelTensors = modelTensorOutput.getMlModelOutputs().get(0);
ModelTensor modelTensor = modelTensors.getMlModelTensors().get(0);
Map<String, String> dataAsMap = (Map<String, String>) modelTensor.getDataAsMap();
String ppl = dataAsMap.get("output");
String ppl = parseOutput(dataAsMap.get("response"), indexName);
JSONObject jsonContent = new JSONObject(ImmutableMap.of("query", ppl));
PPLQueryRequest pplQueryRequest = new PPLQueryRequest(ppl, jsonContent, null, "jdbc");
TransportPPLQueryRequest transportPPLQueryRequest = new TransportPPLQueryRequest(pplQueryRequest);
Expand Down Expand Up @@ -226,6 +229,8 @@ private String constructTableInfo(SearchHit[] searchHits, Map<String, MappingMet
Map<String, String> fieldsToType = new HashMap<>();
extractNamesTypes(mappingSource, fieldsToType, "");
StringJoiner tableInfoJoiner = new StringJoiner("\n");
List<String> sortedKeys = new ArrayList<>(fieldsToType.keySet());
Collections.sort(sortedKeys);

if (searchHits.length > 0) {
SearchHit hit = searchHits[0];
Expand All @@ -236,12 +241,12 @@ private String constructTableInfo(SearchHit[] searchHits, Map<String, MappingMet
}
extractSamples(sampleSource, fieldsToSample, "");

for (String key : fieldsToType.keySet()) {
for (String key : sortedKeys) {
String line = "- " + key + ": " + fieldsToType.get(key) + " (" + fieldsToSample.get(key) + ")";
tableInfoJoiner.add(line);
}
} else {
for (String key : fieldsToType.keySet()) {
for (String key : sortedKeys) {
String line = "- " + key + ": " + fieldsToType.get(key);
tableInfoJoiner.add(line);
}
Expand All @@ -252,7 +257,10 @@ private String constructTableInfo(SearchHit[] searchHits, Map<String, MappingMet
}

private String constructPrompt(String tableInfo, String question, String indexName) {
return String.format(Locale.getDefault(), contextPrompt, question.strip(), indexName, tableInfo.strip());
Map<String, String> indexInfo = ImmutableMap.of("mappingInfo", tableInfo, "question", question, "indexName", indexName);
StringSubstitutor substitutor = new StringSubstitutor(indexInfo, "${indexInfo.", "}");
String finalPrompt = substitutor.replace(contextPrompt);
return finalPrompt;
}

private void extractNamesTypes(Map<String, Object> mappingSource, Map<String, String> fieldsToType, String prefix) {
Expand Down Expand Up @@ -297,22 +305,50 @@ private static void extractSamples(Map<String, Object> sampleSource, Map<String,
}

private <T extends ActionResponse> ActionListener<T> getPPLTransportActionListener(ActionListener<TransportPPLQueryResponse> listener) {
return ActionListener.wrap(r -> { listener.onResponse(fromActionResponse(r)); }, listener::onFailure);
return ActionListener.wrap(r -> { listener.onResponse(TransportPPLQueryResponse.fromActionResponse(r)); }, listener::onFailure);
}

private static TransportPPLQueryResponse fromActionResponse(ActionResponse actionResponse) {
if (actionResponse instanceof TransportPPLQueryResponse) {
return (TransportPPLQueryResponse) actionResponse;
private Map<String, String> extractFromChatParameters(Map<String, String> parameters) {
if (parameters.containsKey("input")) {
try {
Map<String, String> chatParameters = gson.fromJson(parameters.get("input"), Map.class);
parameters.putAll(chatParameters);
} finally {
return parameters;
}
}
return parameters;
}

private String parseOutput(String llmOutput, String indexName) {
String ppl;
Pattern pattern = Pattern.compile("<ppl>((.|[\\r\\n])+?)</ppl>"); // For ppl like <ppl> source=a \n | fields b </ppl>
Matcher matcher = pattern.matcher(llmOutput);

if (matcher.find()) {
ppl = matcher.group(1).replaceAll("[\\r\\n]", "").replaceAll("ISNOTNULL", "isnotnull").trim();
} else { // logic for only ppl returned
int sourceIndex = llmOutput.indexOf("source=");
if (sourceIndex != -1) {
llmOutput = llmOutput.substring(sourceIndex);

try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
actionResponse.writeTo(osso);
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
return new TransportPPLQueryResponse(input);
// Splitting the string at "|"
String[] lists = llmOutput.split("\\|");

// Modifying the first element
if (lists.length > 0) {
lists[0] = "source=" + indexName;
}

// Joining the string back together
ppl = String.join("|", lists);
} else {
throw new IllegalArgumentException("The returned PPL: " + llmOutput + " has wrong format");
}
} catch (IOException e) {
throw new UncheckedIOException("failed to parse ActionResponse into TransportPPLQueryResponse", e);
}

ppl = ppl.replace("`", "");
ppl = ppl.replaceAll("\\bSPAN\\(", "span(");
return ppl;
}

}
Loading

0 comments on commit 964b2f0

Please sign in to comment.