Skip to content

Commit

Permalink
Add default prompt to ppl tool (#125)
Browse files Browse the repository at this point in the history
* add default prompt for ppl tool

Signed-off-by: xinyual <[email protected]>

* fix Upper problem

Signed-off-by: xinyual <[email protected]>

* change wrong information

Signed-off-by: xinyual <[email protected]>

* remove uesless log

Signed-off-by: xinyual <[email protected]>

* add corresponding UTs

Signed-off-by: xinyual <[email protected]>

* apply spotless

Signed-off-by: xinyual <[email protected]>

* use locale instead

Signed-off-by: xinyual <[email protected]>

* move dict to static

Signed-off-by: xinyual <[email protected]>

* move dict to static

Signed-off-by: xinyual <[email protected]>

* replace throw error with error log

Signed-off-by: xinyual <[email protected]>

* add default value for PPL model type

Signed-off-by: xinyual <[email protected]>

* apply spotless

Signed-off-by: xinyual <[email protected]>

---------

Signed-off-by: xinyual <[email protected]>
  • Loading branch information
xinyual committed Jan 16, 2024
1 parent c8b6898 commit 7e4c8d5
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 11 deletions.
62 changes: 59 additions & 3 deletions src/main/java/org/opensearch/agent/tools/PPLTool.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,17 @@

import static org.opensearch.ml.common.CommonValue.*;

import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.StringJoiner;
import java.util.regex.Matcher;
Expand Down Expand Up @@ -83,12 +87,48 @@ public class PPLTool implements Tool {

private String contextPrompt;

private PPLModelType pplModelType;

private static Gson gson = new Gson();

public PPLTool(Client client, String modelId, String contextPrompt) {
private static Map<String, String> defaultPromptDict;

static {
try {
defaultPromptDict = loadDefaultPromptDict();
} catch (IOException e) {
log.error("fail to load default prompt dict" + e.getMessage());
defaultPromptDict = new HashMap<>();
}
}

public enum PPLModelType {
CLAUDE,
FINETUNE;

public static PPLModelType from(String value) {
if (value.isEmpty()) {
return PPLModelType.CLAUDE;
}
try {
return PPLModelType.valueOf(value.toUpperCase(Locale.ROOT));
} catch (Exception e) {
log.error("Wrong PPL Model type, should be CLAUDE or FINETUNE");
return PPLModelType.CLAUDE;
}
}

}

public PPLTool(Client client, String modelId, String contextPrompt, String pplModelType) {
this.client = client;
this.modelId = modelId;
this.contextPrompt = contextPrompt;
this.pplModelType = PPLModelType.from(pplModelType);
if (contextPrompt.isEmpty()) {
this.contextPrompt = this.defaultPromptDict.getOrDefault(this.pplModelType.toString(), "");
} else {
this.contextPrompt = contextPrompt;
}
}

@Override
Expand Down Expand Up @@ -208,7 +248,12 @@ public void init(Client client) {

@Override
public PPLTool create(Map<String, Object> map) {
return new PPLTool(client, (String) map.get("model_id"), (String) map.get("prompt"));
return new PPLTool(
client,
(String) map.get("model_id"),
(String) map.getOrDefault("prompt", ""),
(String) map.getOrDefault("model_type", "")
);
}

@Override
Expand All @@ -225,6 +270,7 @@ public String getDefaultType() {
public String getDefaultVersion() {
return null;
}

}

private SearchRequest buildSearchRequest(String indexName) {
Expand Down Expand Up @@ -373,4 +419,14 @@ private String parseOutput(String llmOutput, String indexName) {
return ppl;
}

private static Map<String, String> loadDefaultPromptDict() throws IOException {
InputStream searchResponseIns = PPLTool.class.getResourceAsStream("PPLDefaultPrompt.json");
if (searchResponseIns != null) {
String defaultPromptContent = new String(searchResponseIns.readAllBytes(), StandardCharsets.UTF_8);
Map<String, String> defaultPromptDict = gson.fromJson(defaultPromptContent, Map.class);
return defaultPromptDict;
}
return new HashMap<>();
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"CLAUDE": "\n\nHuman:You will be given a question about some metrics from a user.\nUse context provided to write a PPL query that can be used to retrieve the information.\n\nHere is a sample PPL query:\nsource=\\`<index>\\` | where \\`<field>\\` = '\\`<value>\\`'\n\nHere are some sample questions and the PPL query to retrieve the information. The format for fields is\n\\`\\`\\`\n- field_name: field_type (sample field value)\n\\`\\`\\`\n\nFor example, below is a field called \\`timestamp\\`, it has a field type of \\`date\\`, and a sample value of it could look like \\`1686000665919\\`.\n\\`\\`\\`\n- timestamp: date (1686000665919)\n\\`\\`\\`\n----------------\n\nThe following text contains fields and questions/answers for the 'accounts' index\n\nFields:\n- account_number: long (101)\n- address: text ('880 Holmes Lane')\n- age: long (32)\n- balance: long (39225)\n- city: text ('Brogan')\n- email: text ('[email protected]')\n- employer: text ('Pyrami')\n- firstname: text ('Amber')\n- gender: text ('M')\n- lastname: text ('Duke')\n- state: text ('IL')\n- registered_at: date (1686000665919)\n\nQuestion: Give me some documents in index 'accounts'\nPPL: source=\\`accounts\\` | head\n\nQuestion: Give me 5 oldest people in index 'accounts'\nPPL: source=\\`accounts\\` | sort -age | head 5\n\nQuestion: Give me first names of 5 youngest people in index 'accounts'\nPPL: source=\\`accounts\\` | sort +age | head 5 | fields \\`firstname\\`\n\nQuestion: Give me some addresses in index 'accounts'\nPPL: source=\\`accounts\\` | fields \\`address\\`\n\nQuestion: Find the documents in index 'accounts' where firstname is 'Hattie'\nPPL: source=\\`accounts\\` | where \\`firstname\\` = 'Hattie'\n\nQuestion: Find the emails where firstname is 'Hattie' or lastname is 'Frank' in index 'accounts'\nPPL: source=\\`accounts\\` | where \\`firstname\\` = 'Hattie' OR \\`lastname\\` = 'frank' | fields \\`email\\`\n\nQuestion: Find the documents in index 'accounts' where firstname is not 'Hattie' and lastname is not 'Frank'\nPPL: source=\\`accounts\\` | where \\`firstname\\` != 'Hattie' AND \\`lastname\\` != 'frank'\n\nQuestion: Find the emails that contain '.com' in index 'accounts'\nPPL: source=\\`accounts\\` | where QUERY_STRING(['email'], '.com') | fields \\`email\\`\n\nQuestion: Find the documents in index 'accounts' where there is an email\nPPL: source=\\`accounts\\` | where ISNOTNULL(\\`email\\`)\n\nQuestion: Count the number of documents in index 'accounts'\nPPL: source=\\`accounts\\` | stats COUNT() AS \\`count\\`\n\nQuestion: Count the number of people with firstnaQuestion: Count the number of people withe=\\`accounts\\` | where \\`firstname\\` ='Amber' | stats COUNT() AS \\`count\\`\n\nQuestion: How many people are older than 33? index is 'accounts'\nPPL: source=\\`accounts\\` | where \\`age\\` > 33 | stats COUNT() AS \\`count\\`\n\nQuestion: How many distinct ages? index is 'accounts'\nPPL: source=\\`accounts\\` | stats DISTINCT_COUNT(age) AS \\`distinct_count\\`\n\nQuestion: How many males and females in index 'accounts'?\nPPL: source=\\`accounts\\` | stats COUNT() AS \\`count\\` BY \\`gender\\`\n\nQuestion: What is the average, minimum, maximum age in 'accounts' index?\nPPL: source=\\`accounts\\` | stats AVG(\\`age\\`) AS \\`avg_age\\`, MIN(\\`age\\`) AS \\`min_age\\`, MAX(\\`age\\`) AS \\`max_age\\`\n\nQuestion: Show all states sorted by average balance. index is 'accounts'\nPPL: source=\\`accounts\\` | stats AVG(\\`balance\\`) AS \\`avg_balance\\` BY \\`state\\` | sort +avg_balance\n\n----------------\n\nThe following text contains fields and questions/answers for the 'ecommerce' index\n\nFields:\n- category: text ('Men's Clothing')\n- currency: keyword ('EUR')\n- customer_birth_date: date (null)\n- customer_first_name: text ('Eddie')\n- customer_full_name: text ('Eddie Underwood')\n- customer_gender: keyword ('MALE')\n- customer_id: keyword ('38')\n- customer_last_name: text ('Underwood')\n- customer_phone: keyword ('')\n- day_of_week: keyword ('Monday')\n- day_of_week_i: integer (0)\n- email: keyword ('[email protected]')\n- event.dataset: keyword ('sample_ecommerce')\n- geoip.city_name: keyword ('Cairo')\n- geoip.continent_name: keyword ('Africa')\n- geoip.country_iso_code: keyword ('EG')\n- geoip.location: geo_point ([object Object])\n- geoip.region_name: keyword ('Cairo Governorate')\n- manufacturer: text ('Elitelligence,Oceanavigations')\n- order_date: date (2023-06-05T09:28:48+00:00)\n- order_id: keyword ('584677')\n- products._id: text (null)\n- products.base_price: half_float (null)\n- products.base_unit_price: half_float (null)\n- products.category: text (null)\n- products.created_on: date (null)\n- products.discount_amount: half_float (null)\n- products.discount_percentage: half_float (null)\n- products.manufacturer: text (null)\n- products.min_price: half_float (null)\n- products.price: half_float (null)\n- products.product_id: long (null)\n- products.product_name: text (null)\n- products.quantity: integer (null)\n- products.sku: keyword (null)\n- products.tax_amount: half_float (null)\n- products.taxful_price: half_float (null)\n- products.taxless_price: half_float (null)\n- products.unit_discount_amount: half_float (null)\n- sku: keyword ('ZO0549605496,ZO0299602996')\n- taxful_total_price: half_float (36.98)\n- taxless_total_price: half_float (36.98)\n- total_quantity: integer (2)\n- total_unique_products: integer (2)\n- type: keyword ('order')\n- user: keyword ('eddie')\n\nQuestion: What is the average price of products in clothing category ordered in the last 7 days? index is 'ecommerce'\nPPL: source=\\`ecommerce\\` | where QUERY_STRING(['category'], 'clothing') AND \\`order_date\\` > DATE_SUB(NOW(), INTERVAL 7 DAY) | stats AVG(\\`taxful_total_price\\`) AS \\`avg_price\\`\n\nQuestion: What is the average price of products in each city ordered today by every 2 hours? index is 'ecommerce'\nPPL: source=\\`ecommerce\\` | where \\`order_date\\` > DATE_SUB(NOW(), INTERVAL 24 HOUR) | stats AVG(\\`taxful_total_price\\`) AS \\`avg_price\\` by SPAN(\\`order_date\\`, 2h) AS \\`span\\`, \\`geoip.city_name\\`\n\nQuestion: What is the total revenue of shoes each day in this week? index is 'ecommerce'\nPPL: source=\\`ecommerce\\` | where QUERY_STRING(['category'], 'shoes') AND \\`order_date\\` > DATE_SUB(NOW(), INTERVAL 1 WEEK) | stats SUM(\\`taxful_total_price\\`) AS \\`revenue\\` by SPAN(\\`order_date\\`, 1d) AS \\`span\\`\n\n----------------\n\nThe following text contains fields and questions/answers for the 'events' index\nFields:\n- timestamp: long (1686000665919)\n- attributes.data_stream.dataset: text ('nginx.access')\n- attributes.data_stream.namespace: text ('production')\n- attributes.data_stream.type: text ('logs')\n- body: text ('172.24.0.1 - - [02/Jun/2023:23:09:27 +0000] 'GET / HTTP/1.1' 200 4955 '-' 'Mozilla/5.0 zgrab/0.x'')\n- communication.source.address: text ('127.0.0.1')\n- communication.source.ip: text ('172.24.0.1')\n- container_id: text (null)\n- container_name: text (null)\n- event.category: text ('web')\n- event.domain: text ('nginx.access')\n- event.kind: text ('event')\n- event.name: text ('access')\n- event.result: text ('success')\n- event.type: text ('access')\n- http.flavor: text ('1.1')\n- http.request.method: text ('GET')\n- http.response.bytes: long (4955)\n- http.response.status_code: keyword ('200')\n- http.url: text ('/')\n- log: text (null)\n- observerTime: date (1686000665919)\n- source: text (null)\n- span_id: text ('abcdef1010')\n- trace_id: text ('102981ABCD2901')\n\nQuestion: What are recent logs with errors and contains word 'test'? index is 'events'\nPPL: source=\\`events\\` | where QUERY_STRING(['http.response.status_code'], '4* OR 5*') AND QUERY_STRING(['body'], 'test') AND \\`observerTime\\` > DATE_SUB(NOW(), INTERVAL 5 MINUTE)\n\nQuestion: What is the total number of log with a status code other than 200 in 2023 Feburary? index is 'events'\nPPL: source=\\`events\\` | where QUERY_STRING(['http.response.status_code'], '!200') AND \\`observerTime\\` >= '2023-03-01 00:00:00' AND \\`observerTime\\` < '2023-04-01 00:00:00' | stats COUNT() AS \\`count\\`\n\nQuestion: Count the number of business days that have web category logs last week? index is 'events'\nPPL: source=\\`events\\` | where \\`category\\` = 'web' AND \\`observerTime\\` > DATE_SUB(NOW(), INTERVAL 1 WEEK) AND DAY_OF_WEEK(\\`observerTime\\`) >= 2 AND DAY_OF_WEEK(\\`observerTime\\`) <= 6 | stats DISTINCT_COUNT(DATE_FORMAT(\\`observerTime\\`, 'yyyy-MM-dd')) AS \\`distinct_count\\`\n\nQuestion: What are the top traces with largest bytes? index is 'events'\nPPL: source=\\`events\\` | stats SUM(\\`http.response.bytes\\`) AS \\`sum_bytes\\` by \\`trace_id\\` | sort -sum_bytes | head\n\nQuestion: Give me log patterns? index is 'events'\nPPL: source=\\`events\\` | patterns \\`body\\` | stats take(\\`body\\`, 1) AS \\`sample_pattern\\` by \\`patterns_field\\` | fields \\`sample_pattern\\`\n\nQuestion: Give me log patterns for logs with errors? index is 'events'\nPPL: source=\\`events\\` | where QUERY_STRING(['http.response.status_code'], '4* OR 5*') | patterns \\`body\\` | stats take(\\`body\\`, 1) AS \\`sample_pattern\\` by \\`patterns_field\\` | fields \\`sample_pattern\\`\n\n----------------\n\nUse the following steps to generate the PPL query:\n\nStep 1. Find all field entities in the question.\n\nStep 2. Pick the fields that are relevant to the question from the provided fields list using entities. Rules:\n#01 Consider the field name, the field type, and the sample value when picking relevant fields. For example, if you need to filter flights departed from 'JFK', look for a \\`text\\` or \\`keyword\\` field with a field name such as 'departedAirport', and the sample value should be a 3 letter IATA airport code. Similarly, if you need a date field, look for a relevant field name with type \\`date\\` and not \\`long\\`.\n#02 You must pick a field with \\`date\\` type when filtering on date/time.\n#03 You must pick a field with \\`date\\` type when aggregating by time interval.\n#04 You must not use the sample value in PPL query, unless it is relevant to the question.\n#05 You must only pick fields that are relevant, and must pick the whole field name from the fields list.\n#06 You must not use fields that are not in the fields list.\n#07 You must not use the sample values unless relevant to the question.\n#08 You must pick the field that contains a log line when asked about log patterns. Usually it is one of \\`log\\`, \\`body\\`, \\`message\\`.\n\nStep 3. Use the choosen fields to write the PPL query. Rules:\n#01 Always use comparisons to filter date/time, eg. 'where \\`timestamp\\` > DATE_SUB(NOW(), INTERVAL 1 DAY)'; or by absolute time: 'where \\`timestamp\\` > 'yyyy-MM-dd HH:mm:ss'', eg. 'where \\`timestamp\\` < '2023-01-01 00:00:00''. Do not use \\`DATE_FORMAT()\\`.\n#02 Only use PPL syntax and keywords appeared in the question or in the examples.\n#03 If user asks for current or recent status, filter the time field for last 5 minutes.\n#04 The field used in 'SPAN(\\`<field>\\`, <interval>)' must have type \\`date\\`, not \\`long\\`.\n#05 When aggregating by \\`SPAN\\` and another field, put \\`SPAN\\` after \\`by\\` and before the other field, eg. 'stats COUNT() AS \\`count\\` by SPAN(\\`timestamp\\`, 1d) AS \\`span\\`, \\`category\\`'.\n#06 You must put values in quotes when filtering fields with \\`text\\` or \\`keyword\\` field type.\n#07 To find documents that contain certain phrases in string fields, use \\`QUERY_STRING\\` which supports multiple fields and wildcard, eg. 'where QUERY_STRING(['field1', 'field2'], 'prefix*')'.\n#08 To find 4xx and 5xx errors using status code, if the status code field type is numberic (eg. \\`integer\\`), then use 'where \\`status_code\\` >= 400'; if the field is a string (eg. \\`text\\` or \\`keyword\\`), then use 'where QUERY_STRING(['status_code'], '4* OR 5*')'.\n\n----------------\nPlease only contain PPL inside your response.\n----------------\nQuestion: ${indexInfo.question}? index is \\`${indexInfo.indexName}\\`\nFields:\n${indexInfo.mappingInfo}\n\nAssistant:",
"FINETUNE": "Below is an instruction that describes a task, paired with the index and corresponding fields that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\nI have an opensearch index with fields in the following. Now I have a question: ${indexInfo.question} Can you help me generate a PPL for that?\n\n### Index:\n${indexInfo.indexName}\n\n### Fields:\n${indexInfo.mappingInfo}\n\n### Response:\n"
}
Loading

0 comments on commit 7e4c8d5

Please sign in to comment.