Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add default prompt to ppl tool #125

Merged
merged 12 commits into from
Jan 16, 2024
61 changes: 58 additions & 3 deletions src/main/java/org/opensearch/agent/tools/PPLTool.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,17 @@

import static org.opensearch.ml.common.CommonValue.*;

import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.StringJoiner;
import java.util.regex.Matcher;
Expand Down Expand Up @@ -83,12 +87,46 @@

private String contextPrompt;

private PPLModelType pplModelType;

private static Gson gson = new Gson();

public PPLTool(Client client, String modelId, String contextPrompt) {
private static Map<String, String> defaultPromptDict;

static {
try {
defaultPromptDict = loadDefaultPromptDict();
} catch (IOException e) {
throw new RuntimeException("fail to load default prompt dict" + e.getMessage());

Check warning on line 100 in src/main/java/org/opensearch/agent/tools/PPLTool.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/agent/tools/PPLTool.java#L99-L100

Added lines #L99 - L100 were not covered by tests
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible that this would break the plugin initialization process with this error thrown?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're right. Already change the logic to init the dict to an empty dict and log an error here.

}
}

private enum PPLModelType {
CLAUDE,
FINETUNE;

public static PPLModelType from(String value) {
if (value.isEmpty()) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is value possibly be null?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, when create ppltool object, we use (String) map.getOrDefault("model_type", "") So model type string will have default value.

return PPLModelType.CLAUDE;
}
try {
return PPLModelType.valueOf(value.toUpperCase(Locale.ROOT));
} catch (Exception e) {
throw new IllegalArgumentException("Wrong PPL Model type, should be CLAUDE or FINETUNE");
}
}

}

public PPLTool(Client client, String modelId, String contextPrompt, String pplModelType) {
this.client = client;
this.modelId = modelId;
this.contextPrompt = contextPrompt;
this.pplModelType = PPLModelType.from(pplModelType);
if (contextPrompt.isEmpty()) {
this.contextPrompt = this.defaultPromptDict.getOrDefault(this.pplModelType.toString(), "");
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible this pplModelType.toString() produce NPE?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we init the modeltype with the default value, the model type should be an object or throw an error. So it shouldn't be a NPE.

} else {
this.contextPrompt = contextPrompt;
}
}

@Override
Expand Down Expand Up @@ -208,7 +246,12 @@

@Override
public PPLTool create(Map<String, Object> map) {
return new PPLTool(client, (String) map.get("model_id"), (String) map.get("prompt"));
return new PPLTool(
client,
(String) map.get("model_id"),
(String) map.getOrDefault("prompt", ""),
(String) map.getOrDefault("model_type", "")
);
}

@Override
Expand All @@ -225,6 +268,7 @@
public String getDefaultVersion() {
return null;
}

}

private SearchRequest buildSearchRequest(String indexName) {
Expand Down Expand Up @@ -373,4 +417,15 @@
return ppl;
}

private static Map<String, String> loadDefaultPromptDict() throws IOException {
InputStream searchResponseIns = PPLTool.class.getResourceAsStream("PPLDefaultPrompt.json");
if (searchResponseIns != null) {
String defaultPromptContent = new String(searchResponseIns.readAllBytes(), StandardCharsets.UTF_8);
log.info(defaultPromptContent);
Map<String, String> defaultPromptDict = gson.fromJson(defaultPromptContent, Map.class);
return defaultPromptDict;
}
return new HashMap<>();

Check warning on line 428 in src/main/java/org/opensearch/agent/tools/PPLTool.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/agent/tools/PPLTool.java#L428

Added line #L428 was not covered by tests
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"CLAUDE": "\n\nHuman:You will be given a question about some metrics from a user.\nUse context provided to write a PPL query that can be used to retrieve the information.\n\nHere is a sample PPL query:\nsource=\\`<index>\\` | where \\`<field>\\` = '\\`<value>\\`'\n\nHere are some sample questions and the PPL query to retrieve the information. The format for fields is\n\\`\\`\\`\n- field_name: field_type (sample field value)\n\\`\\`\\`\n\nFor example, below is a field called \\`timestamp\\`, it has a field type of \\`date\\`, and a sample value of it could look like \\`1686000665919\\`.\n\\`\\`\\`\n- timestamp: date (1686000665919)\n\\`\\`\\`\n----------------\n\nThe following text contains fields and questions/answers for the 'accounts' index\n\nFields:\n- account_number: long (101)\n- address: text ('880 Holmes Lane')\n- age: long (32)\n- balance: long (39225)\n- city: text ('Brogan')\n- email: text ('[email protected]')\n- employer: text ('Pyrami')\n- firstname: text ('Amber')\n- gender: text ('M')\n- lastname: text ('Duke')\n- state: text ('IL')\n- registered_at: date (1686000665919)\n\nQuestion: Give me some documents in index 'accounts'\nPPL: source=\\`accounts\\` | head\n\nQuestion: Give me 5 oldest people in index 'accounts'\nPPL: source=\\`accounts\\` | sort -age | head 5\n\nQuestion: Give me first names of 5 youngest people in index 'accounts'\nPPL: source=\\`accounts\\` | sort +age | head 5 | fields \\`firstname\\`\n\nQuestion: Give me some addresses in index 'accounts'\nPPL: source=\\`accounts\\` | fields \\`address\\`\n\nQuestion: Find the documents in index 'accounts' where firstname is 'Hattie'\nPPL: source=\\`accounts\\` | where \\`firstname\\` = 'Hattie'\n\nQuestion: Find the emails where firstname is 'Hattie' or lastname is 'Frank' in index 'accounts'\nPPL: source=\\`accounts\\` | where \\`firstname\\` = 'Hattie' OR \\`lastname\\` = 'frank' | fields \\`email\\`\n\nQuestion: Find the documents in index 'accounts' where firstname is not 'Hattie' and lastname is not 'Frank'\nPPL: source=\\`accounts\\` | where \\`firstname\\` != 'Hattie' AND \\`lastname\\` != 'frank'\n\nQuestion: Find the emails that contain '.com' in index 'accounts'\nPPL: source=\\`accounts\\` | where QUERY_STRING(['email'], '.com') | fields \\`email\\`\n\nQuestion: Find the documents in index 'accounts' where there is an email\nPPL: source=\\`accounts\\` | where ISNOTNULL(\\`email\\`)\n\nQuestion: Count the number of documents in index 'accounts'\nPPL: source=\\`accounts\\` | stats COUNT() AS \\`count\\`\n\nQuestion: Count the number of people with firstnaQuestion: Count the number of people withe=\\`accounts\\` | where \\`firstname\\` ='Amber' | stats COUNT() AS \\`count\\`\n\nQuestion: How many people are older than 33? index is 'accounts'\nPPL: source=\\`accounts\\` | where \\`age\\` > 33 | stats COUNT() AS \\`count\\`\n\nQuestion: How many distinct ages? index is 'accounts'\nPPL: source=\\`accounts\\` | stats DISTINCT_COUNT(age) AS \\`distinct_count\\`\n\nQuestion: How many males and females in index 'accounts'?\nPPL: source=\\`accounts\\` | stats COUNT() AS \\`count\\` BY \\`gender\\`\n\nQuestion: What is the average, minimum, maximum age in 'accounts' index?\nPPL: source=\\`accounts\\` | stats AVG(\\`age\\`) AS \\`avg_age\\`, MIN(\\`age\\`) AS \\`min_age\\`, MAX(\\`age\\`) AS \\`max_age\\`\n\nQuestion: Show all states sorted by average balance. index is 'accounts'\nPPL: source=\\`accounts\\` | stats AVG(\\`balance\\`) AS \\`avg_balance\\` BY \\`state\\` | sort +avg_balance\n\n----------------\n\nThe following text contains fields and questions/answers for the 'ecommerce' index\n\nFields:\n- category: text ('Men's Clothing')\n- currency: keyword ('EUR')\n- customer_birth_date: date (null)\n- customer_first_name: text ('Eddie')\n- customer_full_name: text ('Eddie Underwood')\n- customer_gender: keyword ('MALE')\n- customer_id: keyword ('38')\n- customer_last_name: text ('Underwood')\n- customer_phone: keyword ('')\n- day_of_week: keyword ('Monday')\n- day_of_week_i: integer (0)\n- email: keyword ('[email protected]')\n- event.dataset: keyword ('sample_ecommerce')\n- geoip.city_name: keyword ('Cairo')\n- geoip.continent_name: keyword ('Africa')\n- geoip.country_iso_code: keyword ('EG')\n- geoip.location: geo_point ([object Object])\n- geoip.region_name: keyword ('Cairo Governorate')\n- manufacturer: text ('Elitelligence,Oceanavigations')\n- order_date: date (2023-06-05T09:28:48+00:00)\n- order_id: keyword ('584677')\n- products._id: text (null)\n- products.base_price: half_float (null)\n- products.base_unit_price: half_float (null)\n- products.category: text (null)\n- products.created_on: date (null)\n- products.discount_amount: half_float (null)\n- products.discount_percentage: half_float (null)\n- products.manufacturer: text (null)\n- products.min_price: half_float (null)\n- products.price: half_float (null)\n- products.product_id: long (null)\n- products.product_name: text (null)\n- products.quantity: integer (null)\n- products.sku: keyword (null)\n- products.tax_amount: half_float (null)\n- products.taxful_price: half_float (null)\n- products.taxless_price: half_float (null)\n- products.unit_discount_amount: half_float (null)\n- sku: keyword ('ZO0549605496,ZO0299602996')\n- taxful_total_price: half_float (36.98)\n- taxless_total_price: half_float (36.98)\n- total_quantity: integer (2)\n- total_unique_products: integer (2)\n- type: keyword ('order')\n- user: keyword ('eddie')\n\nQuestion: What is the average price of products in clothing category ordered in the last 7 days? index is 'ecommerce'\nPPL: source=\\`ecommerce\\` | where QUERY_STRING(['category'], 'clothing') AND \\`order_date\\` > DATE_SUB(NOW(), INTERVAL 7 DAY) | stats AVG(\\`taxful_total_price\\`) AS \\`avg_price\\`\n\nQuestion: What is the average price of products in each city ordered today by every 2 hours? index is 'ecommerce'\nPPL: source=\\`ecommerce\\` | where \\`order_date\\` > DATE_SUB(NOW(), INTERVAL 24 HOUR) | stats AVG(\\`taxful_total_price\\`) AS \\`avg_price\\` by SPAN(\\`order_date\\`, 2h) AS \\`span\\`, \\`geoip.city_name\\`\n\nQuestion: What is the total revenue of shoes each day in this week? index is 'ecommerce'\nPPL: source=\\`ecommerce\\` | where QUERY_STRING(['category'], 'shoes') AND \\`order_date\\` > DATE_SUB(NOW(), INTERVAL 1 WEEK) | stats SUM(\\`taxful_total_price\\`) AS \\`revenue\\` by SPAN(\\`order_date\\`, 1d) AS \\`span\\`\n\n----------------\n\nThe following text contains fields and questions/answers for the 'events' index\nFields:\n- timestamp: long (1686000665919)\n- attributes.data_stream.dataset: text ('nginx.access')\n- attributes.data_stream.namespace: text ('production')\n- attributes.data_stream.type: text ('logs')\n- body: text ('172.24.0.1 - - [02/Jun/2023:23:09:27 +0000] 'GET / HTTP/1.1' 200 4955 '-' 'Mozilla/5.0 zgrab/0.x'')\n- communication.source.address: text ('127.0.0.1')\n- communication.source.ip: text ('172.24.0.1')\n- container_id: text (null)\n- container_name: text (null)\n- event.category: text ('web')\n- event.domain: text ('nginx.access')\n- event.kind: text ('event')\n- event.name: text ('access')\n- event.result: text ('success')\n- event.type: text ('access')\n- http.flavor: text ('1.1')\n- http.request.method: text ('GET')\n- http.response.bytes: long (4955)\n- http.response.status_code: keyword ('200')\n- http.url: text ('/')\n- log: text (null)\n- observerTime: date (1686000665919)\n- source: text (null)\n- span_id: text ('abcdef1010')\n- trace_id: text ('102981ABCD2901')\n\nQuestion: What are recent logs with errors and contains word 'test'? index is 'events'\nPPL: source=\\`events\\` | where QUERY_STRING(['http.response.status_code'], '4* OR 5*') AND QUERY_STRING(['body'], 'test') AND \\`observerTime\\` > DATE_SUB(NOW(), INTERVAL 5 MINUTE)\n\nQuestion: What is the total number of log with a status code other than 200 in 2023 Feburary? index is 'events'\nPPL: source=\\`events\\` | where QUERY_STRING(['http.response.status_code'], '!200') AND \\`observerTime\\` >= '2023-03-01 00:00:00' AND \\`observerTime\\` < '2023-04-01 00:00:00' | stats COUNT() AS \\`count\\`\n\nQuestion: Count the number of business days that have web category logs last week? index is 'events'\nPPL: source=\\`events\\` | where \\`category\\` = 'web' AND \\`observerTime\\` > DATE_SUB(NOW(), INTERVAL 1 WEEK) AND DAY_OF_WEEK(\\`observerTime\\`) >= 2 AND DAY_OF_WEEK(\\`observerTime\\`) <= 6 | stats DISTINCT_COUNT(DATE_FORMAT(\\`observerTime\\`, 'yyyy-MM-dd')) AS \\`distinct_count\\`\n\nQuestion: What are the top traces with largest bytes? index is 'events'\nPPL: source=\\`events\\` | stats SUM(\\`http.response.bytes\\`) AS \\`sum_bytes\\` by \\`trace_id\\` | sort -sum_bytes | head\n\nQuestion: Give me log patterns? index is 'events'\nPPL: source=\\`events\\` | patterns \\`body\\` | stats take(\\`body\\`, 1) AS \\`sample_pattern\\` by \\`patterns_field\\` | fields \\`sample_pattern\\`\n\nQuestion: Give me log patterns for logs with errors? index is 'events'\nPPL: source=\\`events\\` | where QUERY_STRING(['http.response.status_code'], '4* OR 5*') | patterns \\`body\\` | stats take(\\`body\\`, 1) AS \\`sample_pattern\\` by \\`patterns_field\\` | fields \\`sample_pattern\\`\n\n----------------\n\nUse the following steps to generate the PPL query:\n\nStep 1. Find all field entities in the question.\n\nStep 2. Pick the fields that are relevant to the question from the provided fields list using entities. Rules:\n#01 Consider the field name, the field type, and the sample value when picking relevant fields. For example, if you need to filter flights departed from 'JFK', look for a \\`text\\` or \\`keyword\\` field with a field name such as 'departedAirport', and the sample value should be a 3 letter IATA airport code. Similarly, if you need a date field, look for a relevant field name with type \\`date\\` and not \\`long\\`.\n#02 You must pick a field with \\`date\\` type when filtering on date/time.\n#03 You must pick a field with \\`date\\` type when aggregating by time interval.\n#04 You must not use the sample value in PPL query, unless it is relevant to the question.\n#05 You must only pick fields that are relevant, and must pick the whole field name from the fields list.\n#06 You must not use fields that are not in the fields list.\n#07 You must not use the sample values unless relevant to the question.\n#08 You must pick the field that contains a log line when asked about log patterns. Usually it is one of \\`log\\`, \\`body\\`, \\`message\\`.\n\nStep 3. Use the choosen fields to write the PPL query. Rules:\n#01 Always use comparisons to filter date/time, eg. 'where \\`timestamp\\` > DATE_SUB(NOW(), INTERVAL 1 DAY)'; or by absolute time: 'where \\`timestamp\\` > 'yyyy-MM-dd HH:mm:ss'', eg. 'where \\`timestamp\\` < '2023-01-01 00:00:00''. Do not use \\`DATE_FORMAT()\\`.\n#02 Only use PPL syntax and keywords appeared in the question or in the examples.\n#03 If user asks for current or recent status, filter the time field for last 5 minutes.\n#04 The field used in 'SPAN(\\`<field>\\`, <interval>)' must have type \\`date\\`, not \\`long\\`.\n#05 When aggregating by \\`SPAN\\` and another field, put \\`SPAN\\` after \\`by\\` and before the other field, eg. 'stats COUNT() AS \\`count\\` by SPAN(\\`timestamp\\`, 1d) AS \\`span\\`, \\`category\\`'.\n#06 You must put values in quotes when filtering fields with \\`text\\` or \\`keyword\\` field type.\n#07 To find documents that contain certain phrases in string fields, use \\`QUERY_STRING\\` which supports multiple fields and wildcard, eg. 'where QUERY_STRING(['field1', 'field2'], 'prefix*')'.\n#08 To find 4xx and 5xx errors using status code, if the status code field type is numberic (eg. \\`integer\\`), then use 'where \\`status_code\\` >= 400'; if the field is a string (eg. \\`text\\` or \\`keyword\\`), then use 'where QUERY_STRING(['status_code'], '4* OR 5*')'.\n\n----------------\nPlease only contain PPL inside your response.\n----------------\nQuestion: ${indexInfo.question}? index is \\`${indexInfo.indexName}\\`\nFields:\n${indexInfo.mappingInfo}\n\nAssistant:",
"FINETUNE": "Below is an instruction that describes a task, paired with the index and corresponding fields that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\nI have an opensearch index with fields in the following. Now I have a question: ${indexInfo.question} Can you help me generate a PPL for that?\n\n### Index:\n${indexInfo.indexName}\n\n### Fields:\n${indexInfo.mappingInfo}\n\n### Response:\n"
}
22 changes: 22 additions & 0 deletions src/test/java/org/opensearch/agent/tools/PPLToolTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,19 @@ public void testTool() {

}

@Test
public void testTool_with_DefaultPrompt() {
Tool tool = PPLTool.Factory.getInstance().create(ImmutableMap.of("model_id", "modelId", "model_type", "claude"));
assertEquals(PPLTool.TYPE, tool.getName());

tool.run(ImmutableMap.of("index", "demo", "question", "demo"), ActionListener.<String>wrap(executePPLResult -> {
Map<String, String> returnResults = gson.fromJson(executePPLResult, Map.class);
assertEquals("ppl result", returnResults.get("executionResult"));
assertEquals("source=demo| head 1", returnResults.get("ppl"));
}, e -> { log.info(e); }));

}

@Test
public void testTool_withPPLTag() {
Tool tool = PPLTool.Factory.getInstance().create(ImmutableMap.of("model_id", "modelId", "prompt", "contextPrompt"));
Expand Down Expand Up @@ -173,6 +186,15 @@ public void testTool_querySystemIndex() {
);
}

@Test
public void testTool_WrongModelType() {
Exception exception = assertThrows(
IllegalArgumentException.class,
() -> PPLTool.Factory.getInstance().create(ImmutableMap.of("model_id", "modelId", "model_type", "wrong_model_type"))
);
assertEquals("Wrong PPL Model type, should be CLAUDE or FINETUNE", exception.getMessage());
}

@Test
public void testTool_getMappingFailure() {
Tool tool = PPLTool.Factory.getInstance().create(ImmutableMap.of("model_id", "modelId", "prompt", "contextPrompt"));
Expand Down