Skip to content

Commit

Permalink
Add UT and parser for ppl tool (#57)
Browse files Browse the repository at this point in the history
* add output parser

Signed-off-by: xinyual <[email protected]>

* add unit test

Signed-off-by: xinyual <[email protected]>

* remove useless change

Signed-off-by: xinyual <[email protected]>

* apply spotless

Signed-off-by: xinyual <[email protected]>

* add ppl tag example

Signed-off-by: xinyual <[email protected]>

* add UT for reg expression and add logic 1. sort field 2.using substitute

Signed-off-by: xinyual <[email protected]>

* apply spotless

Signed-off-by: xinyual <[email protected]>

* remove redundant import

Signed-off-by: xinyual <[email protected]>

* change gradle dependency

Signed-off-by: xinyual <[email protected]>

* add illegeal exception

Signed-off-by: xinyual <[email protected]>

---------

Signed-off-by: xinyual <[email protected]>
  • Loading branch information
xinyual committed Dec 26, 2023
1 parent 37bbd3b commit 30e5255
Show file tree
Hide file tree
Showing 3 changed files with 333 additions and 29 deletions.
1 change: 1 addition & 0 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ dependencies {
compileOnly group: 'org.json', name: 'json', version: '20231013'
compileOnly("com.google.guava:guava:33.0.0-jre")
compileOnly group: 'org.apache.commons', name: 'commons-lang3', version: '3.10'
compileOnly group: 'org.apache.commons', name: 'commons-text', version: '1.10.0'

// Plugin dependencies
compileOnly group: 'org.opensearch', name:'opensearch-ml-client', version: "${version}"
Expand Down
82 changes: 53 additions & 29 deletions src/main/java/org/opensearch/agent/tools/PPLTool.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,20 @@

package org.opensearch.agent.tools;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Locale;
import java.util.List;
import java.util.Map;
import java.util.StringJoiner;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import org.apache.commons.lang3.StringUtils;
import org.apache.commons.text.StringSubstitutor;
import org.json.JSONObject;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.admin.indices.mapping.get.GetMappingsRequest;
Expand All @@ -28,9 +29,6 @@
import org.opensearch.cluster.metadata.MappingMetadata;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.action.ActionResponse;
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.index.query.MatchAllQueryBuilder;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
Expand Down Expand Up @@ -58,6 +56,8 @@
import lombok.extern.log4j.Log4j2;

@Log4j2
@Setter
@Getter
@ToolAnnotation(PPLTool.TYPE)
public class PPLTool implements Tool {

Expand Down Expand Up @@ -94,6 +94,9 @@ public <T> void run(Map<String, String> parameters, ActionListener<T> listener)
parameters = extractFromChatParameters(parameters);
String indexName = parameters.get("index");
String question = parameters.get("question");
if (StringUtils.isBlank(indexName) || StringUtils.isBlank(question)) {
throw new IllegalArgumentException("Parameter index and question can not be null or empty.");
}
SearchRequest searchRequest = buildSearchRequest(indexName);
GetMappingsRequest getMappingsRequest = buildGetMappingRequest(indexName);
client.admin().indices().getMappings(getMappingsRequest, ActionListener.<GetMappingsResponse>wrap(getMappingsResponse -> {
Expand All @@ -116,7 +119,7 @@ public <T> void run(Map<String, String> parameters, ActionListener<T> listener)
ModelTensors modelTensors = modelTensorOutput.getMlModelOutputs().get(0);
ModelTensor modelTensor = modelTensors.getMlModelTensors().get(0);
Map<String, String> dataAsMap = (Map<String, String>) modelTensor.getDataAsMap();
String ppl = dataAsMap.get("response");
String ppl = parseOutput(dataAsMap.get("response"), indexName);
JSONObject jsonContent = new JSONObject(ImmutableMap.of("query", ppl));
PPLQueryRequest pplQueryRequest = new PPLQueryRequest(ppl, jsonContent, null, "jdbc");
TransportPPLQueryRequest transportPPLQueryRequest = new TransportPPLQueryRequest(pplQueryRequest);
Expand Down Expand Up @@ -228,6 +231,8 @@ private String constructTableInfo(SearchHit[] searchHits, Map<String, MappingMet
Map<String, String> fieldsToType = new HashMap<>();
extractNamesTypes(mappingSource, fieldsToType, "");
StringJoiner tableInfoJoiner = new StringJoiner("\n");
List<String> sortedKeys = new ArrayList<>(fieldsToType.keySet());
Collections.sort(sortedKeys);

if (searchHits.length > 0) {
SearchHit hit = searchHits[0];
Expand All @@ -238,12 +243,12 @@ private String constructTableInfo(SearchHit[] searchHits, Map<String, MappingMet
}
extractSamples(sampleSource, fieldsToSample, "");

for (String key : fieldsToType.keySet()) {
for (String key : sortedKeys) {
String line = "- " + key + ": " + fieldsToType.get(key) + " (" + fieldsToSample.get(key) + ")";
tableInfoJoiner.add(line);
}
} else {
for (String key : fieldsToType.keySet()) {
for (String key : sortedKeys) {
String line = "- " + key + ": " + fieldsToType.get(key);
tableInfoJoiner.add(line);
}
Expand All @@ -254,7 +259,10 @@ private String constructTableInfo(SearchHit[] searchHits, Map<String, MappingMet
}

private String constructPrompt(String tableInfo, String question, String indexName) {
return String.format(Locale.getDefault(), contextPrompt, question.strip(), indexName, tableInfo.strip());
Map<String, String> indexInfo = ImmutableMap.of("mappingInfo", tableInfo, "question", question, "indexName", indexName);
StringSubstitutor substitutor = new StringSubstitutor(indexInfo, "${indexInfo.", "}");
String finalPrompt = substitutor.replace(contextPrompt);
return finalPrompt;
}

private void extractNamesTypes(Map<String, Object> mappingSource, Map<String, String> fieldsToType, String prefix) {
Expand Down Expand Up @@ -299,23 +307,7 @@ private static void extractSamples(Map<String, Object> sampleSource, Map<String,
}

private <T extends ActionResponse> ActionListener<T> getPPLTransportActionListener(ActionListener<TransportPPLQueryResponse> listener) {
return ActionListener.wrap(r -> { listener.onResponse(fromActionResponse(r)); }, listener::onFailure);
}

private static TransportPPLQueryResponse fromActionResponse(ActionResponse actionResponse) {
if (actionResponse instanceof TransportPPLQueryResponse) {
return (TransportPPLQueryResponse) actionResponse;
}

try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
actionResponse.writeTo(osso);
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
return new TransportPPLQueryResponse(input);
}
} catch (IOException e) {
throw new UncheckedIOException("failed to parse ActionResponse into TransportPPLQueryResponse", e);
}

return ActionListener.wrap(r -> { listener.onResponse(TransportPPLQueryResponse.fromActionResponse(r)); }, listener::onFailure);
}

private Map<String, String> extractFromChatParameters(Map<String, String> parameters) {
Expand All @@ -329,4 +321,36 @@ private Map<String, String> extractFromChatParameters(Map<String, String> parame
}
return parameters;
}

private String parseOutput(String llmOutput, String indexName) {
String ppl;
Pattern pattern = Pattern.compile("<ppl>((.|[\\r\\n])+?)</ppl>"); // For ppl like <ppl> source=a \n | fields b </ppl>
Matcher matcher = pattern.matcher(llmOutput);

if (matcher.find()) {
ppl = matcher.group(1).replaceAll("[\\r\\n]", "").replaceAll("ISNOTNULL", "isnotnull").trim();
} else { // logic for only ppl returned
int sourceIndex = llmOutput.indexOf("source=");
if (sourceIndex != -1) {
llmOutput = llmOutput.substring(sourceIndex);

// Splitting the string at "|"
String[] lists = llmOutput.split("\\|");

// Modifying the first element
if (lists.length > 0) {
lists[0] = "source=" + indexName;
}

// Joining the string back together
ppl = String.join("|", lists);
} else {
throw new IllegalArgumentException("The returned PPL: " + llmOutput + " has wrong format");
}
}
ppl = ppl.replace("`", "");
ppl = ppl.replaceAll("\\bSPAN\\(", "span(");
return ppl;
}

}
Loading

0 comments on commit 30e5255

Please sign in to comment.