Skip to content

Commit

Permalink
fixPPLAllowedFields
Browse files Browse the repository at this point in the history
Signed-off-by: xinyual <[email protected]>
  • Loading branch information
xinyual committed Feb 4, 2024
1 parent 940ac32 commit 6bb08f0
Showing 1 changed file with 34 additions and 5 deletions.
39 changes: 34 additions & 5 deletions src/main/java/org/opensearch/agent/tools/PPLTool.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.StringJoiner;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
Expand Down Expand Up @@ -95,7 +97,28 @@ public class PPLTool implements Tool {

private static Map<String, String> defaultPromptDict;

private static Set<String> allowedFieldsType;

static {
allowedFieldsType = new HashSet<>(); // from https://github.com/opensearch-project/sql/blob/2.x/docs/user/ppl/general/datatypes.rst#data-types-mapping
allowedFieldsType.add("boolean");
allowedFieldsType.add("byte");
allowedFieldsType.add("short");
allowedFieldsType.add("integer");
allowedFieldsType.add("long");
allowedFieldsType.add("float");
allowedFieldsType.add("half_float");
allowedFieldsType.add("scaled_float");
allowedFieldsType.add("double");
allowedFieldsType.add("keyword");
allowedFieldsType.add("text");
allowedFieldsType.add("date");
allowedFieldsType.add("ip");
allowedFieldsType.add("binary");
allowedFieldsType.add("object");
allowedFieldsType.add("nested");


try {
defaultPromptDict = loadDefaultPromptDict();
} catch (IOException e) {
Expand Down Expand Up @@ -148,13 +171,15 @@ public <T> void run(Map<String, String> parameters, ActionListener<T> listener)
+ indexName
);
}
SearchRequest searchRequest = buildSearchRequest(indexName);

GetMappingsRequest getMappingsRequest = buildGetMappingRequest(indexName);
client.admin().indices().getMappings(getMappingsRequest, ActionListener.<GetMappingsResponse>wrap(getMappingsResponse -> {
Map<String, MappingMetadata> mappings = getMappingsResponse.getMappings();
if (mappings.size() == 0) {
throw new IllegalArgumentException("No matching mapping with index name: " + indexName);
}
String firstIndexName = (String) mappings.keySet().toArray()[0];
SearchRequest searchRequest = buildSearchRequest(firstIndexName);
client.search(searchRequest, ActionListener.<SearchResponse>wrap(searchResponse -> {
SearchHit[] searchHits = searchResponse.getHits().getHits();
String tableInfo = constructTableInfo(searchHits, mappings);
Expand Down Expand Up @@ -320,13 +345,17 @@ private String constructTableInfo(SearchHit[] searchHits, Map<String, MappingMet
extractSamples(sampleSource, fieldsToSample, "");

for (String key : sortedKeys) {
String line = "- " + key + ": " + fieldsToType.get(key) + " (" + fieldsToSample.get(key) + ")";
tableInfoJoiner.add(line);
if (allowedFieldsType.contains(fieldsToType.get(key))) {
String line = "- " + key + ": " + fieldsToType.get(key) + " (" + fieldsToSample.get(key) + ")";
tableInfoJoiner.add(line);
}
}
} else {
for (String key : sortedKeys) {
String line = "- " + key + ": " + fieldsToType.get(key);
tableInfoJoiner.add(line);
if (allowedFieldsType.contains(fieldsToType.get(key))) {
String line = "- " + key + ": " + fieldsToType.get(key);
tableInfoJoiner.add(line);
}
}
}

Expand Down

0 comments on commit 6bb08f0

Please sign in to comment.