Skip to content

Commit

Permalink
Add CreateAnomalyDetectorTool (opensearch-project#348)
Browse files Browse the repository at this point in the history
* Add CreateAnomalyDetectorTool

Signed-off-by: gaobinlong <[email protected]>

* Optimize some code

Signed-off-by: gaobinlong <[email protected]>

* Fix test failure

Signed-off-by: gaobinlong <[email protected]>

* Optimize exception

Signed-off-by: gaobinlong <[email protected]>

---------

Signed-off-by: gaobinlong <[email protected]>

Fix warning and format issue for CreateAnomalyDetectorTool (opensearch-project#358)

Signed-off-by: gaobinlong <[email protected]>

Add includeFields parameter to the method extractFieldNamesTypes (opensearch-project#376)

* Add includeFields parameter to the method extractFieldNamesTypes

Signed-off-by: gaobinlong <[email protected]>

* Remove empty line

Signed-off-by: gaobinlong <[email protected]>

---------

Signed-off-by: gaobinlong <[email protected]>

Optimize the prompt for create anomaly detector tool

Signed-off-by: gaobinlong <[email protected]>
  • Loading branch information
gaobinlong committed Sep 3, 2024
1 parent 9406558 commit 00fd8ce
Show file tree
Hide file tree
Showing 9 changed files with 1,241 additions and 24 deletions.
5 changes: 4 additions & 1 deletion src/main/java/org/opensearch/agent/ToolPlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import java.util.List;
import java.util.function.Supplier;

import org.opensearch.agent.tools.CreateAnomalyDetectorTool;
import org.opensearch.agent.tools.NeuralSparseSearchTool;
import org.opensearch.agent.tools.PPLTool;
import org.opensearch.agent.tools.RAGTool;
Expand Down Expand Up @@ -68,6 +69,7 @@ public Collection<Object> createComponents(
SearchAnomalyDetectorsTool.Factory.getInstance().init(client, namedWriteableRegistry);
SearchAnomalyResultsTool.Factory.getInstance().init(client, namedWriteableRegistry);
SearchMonitorsTool.Factory.getInstance().init(client);
CreateAnomalyDetectorTool.Factory.getInstance().init(client);
return Collections.emptyList();
}

Expand All @@ -82,7 +84,8 @@ public List<Tool.Factory<? extends Tool>> getToolFactories() {
SearchAlertsTool.Factory.getInstance(),
SearchAnomalyDetectorsTool.Factory.getInstance(),
SearchAnomalyResultsTool.Factory.getInstance(),
SearchMonitorsTool.Factory.getInstance()
SearchMonitorsTool.Factory.getInstance(),
CreateAnomalyDetectorTool.Factory.getInstance()
);
}
}

Large diffs are not rendered by default.

25 changes: 2 additions & 23 deletions src/main/java/org/opensearch/agent/tools/PPLTool.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import org.opensearch.action.admin.indices.mapping.get.GetMappingsResponse;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.agent.tools.utils.ToolHelper;
import org.opensearch.client.Client;
import org.opensearch.cluster.metadata.MappingMetadata;
import org.opensearch.core.action.ActionListener;
Expand Down Expand Up @@ -396,7 +397,7 @@ private String constructTableInfo(SearchHit[] searchHits, Map<String, MappingMet
);
}
Map<String, String> fieldsToType = new HashMap<>();
extractNamesTypes(mappingSource, fieldsToType, "");
ToolHelper.extractFieldNamesTypes(mappingSource, fieldsToType, "", false);
StringJoiner tableInfoJoiner = new StringJoiner("\n");
List<String> sortedKeys = new ArrayList<>(fieldsToType.keySet());
Collections.sort(sortedKeys);
Expand Down Expand Up @@ -436,28 +437,6 @@ private String constructPrompt(String tableInfo, String question, String indexNa
return finalPrompt;
}

/**
 * Walks an index mapping and records every typed field under its dotted path.
 *
 * <p>For each entry whose value is a map: when the map declares a {@code type}, the field is
 * stored in {@code fieldsToType} unless that type is {@code alias}; otherwise, when the map has
 * {@code properties}, those are visited recursively with the current field path as the new prefix.
 * Entries whose value is not a map are ignored.
 *
 * @param mappingSource the field-name to field-definition map for the current level
 * @param fieldsToType  output map receiving {@code dotted.field.path -> type}
 * @param prefix        path of the parent field, or an empty string at the root level
 */
private void extractNamesTypes(Map<String, Object> mappingSource, Map<String, String> fieldsToType, String prefix) {
    // Nested levels are addressed with dotted paths; the root level gets no leading dot.
    final String pathPrefix = prefix.isEmpty() ? prefix : prefix + ".";

    for (Map.Entry<String, Object> field : mappingSource.entrySet()) {
        final Object definition = field.getValue();
        if (!(definition instanceof Map)) {
            continue;
        }

        final Map<String, Object> attributes = (Map<String, Object>) definition;
        final String fullName = pathPrefix + field.getKey();

        if (attributes.containsKey("type")) {
            final Object declaredType = attributes.get("type");
            // Alias fields only point at another field, so they are left out of the result.
            if (!declaredType.equals("alias")) {
                fieldsToType.put(fullName, (String) declaredType);
            }
            continue;
        }

        // Only definitions without an explicit type are descended into.
        if (attributes.containsKey("properties")) {
            extractNamesTypes((Map<String, Object>) attributes.get("properties"), fieldsToType, fullName);
        }
    }
}

private static void extractSamples(Map<String, Object> sampleSource, Map<String, String> fieldsToSample, String prefix)
throws PrivilegedActionException {
if (prefix.length() > 0) {
Expand Down
51 changes: 51 additions & 0 deletions src/main/java/org/opensearch/agent/tools/utils/ToolHelper.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.agent.tools.utils;

import java.util.Map;

public class ToolHelper {
    /**
     * Flattens the fields of an index mapping into a map of dotted field path to field type.
     *
     * <p>For every entry of {@code mappingSource} whose value is a map:
     * <ul>
     *   <li>if it declares a string {@code type} other than {@code alias} or {@code object}, the
     *       field is added to {@code fieldsToType};</li>
     *   <li>if it has {@code properties}, those sub-fields are extracted recursively;</li>
     *   <li>if {@code includeFields} is {@code true} and it has {@code fields}, those
     *       multi-fields are extracted recursively as well.</li>
     * </ul>
     * Entries that are not maps, a {@code type} that is missing, null or not a string, and
     * {@code properties}/{@code fields} values that are null or not maps are skipped instead of
     * raising {@link NullPointerException} or {@link ClassCastException}.
     *
     * @param mappingSource the mappings of an index (field name to field definition); a null or
     *                      empty map is a no-op
     * @param fieldsToType  the result map receiving the field path to field type entries
     * @param prefix        the parent field path; null or empty means the root level
     * @param includeFields whether to include the {@code fields} (multi-fields) of a field. Some
     *                      use cases such as PPLTool must not include them, while
     *                      CreateAnomalyDetectorTool requires them.
     */
    public static void extractFieldNamesTypes(
        Map<String, Object> mappingSource,
        Map<String, String> fieldsToType,
        String prefix,
        boolean includeFields
    ) {
        if (mappingSource == null || mappingSource.isEmpty()) {
            return;
        }
        // Sub-fields are addressed as "parent.child"; the root level has no leading dot.
        final String pathPrefix = (prefix == null || prefix.isEmpty()) ? "" : prefix + ".";

        for (Map.Entry<String, Object> entry : mappingSource.entrySet()) {
            final Object definition = entry.getValue();
            if (!(definition instanceof Map)) {
                continue;
            }
            final String fieldPath = pathPrefix + entry.getKey();
            final Map<?, ?> fieldMapping = (Map<?, ?>) definition;

            // No need to extract aliases into the result; for object fields only the sub-fields are extracted.
            final Object fieldType = fieldMapping.get("type");
            if (fieldType instanceof String && !"alias".equals(fieldType) && !"object".equals(fieldType)) {
                fieldsToType.put(fieldPath, (String) fieldType);
            }

            extractChildren(fieldMapping.get("properties"), fieldsToType, fieldPath, includeFields);
            if (includeFields) {
                extractChildren(fieldMapping.get("fields"), fieldsToType, fieldPath, true);
            }
        }
    }

    /**
     * Recurses into a {@code properties} or {@code fields} section when it is actually a map.
     */
    private static void extractChildren(Object children, Map<String, String> fieldsToType, String parentPath, boolean includeFields) {
        if (children instanceof Map) {
            // Mapping sections are keyed by field name, so the keys are strings.
            @SuppressWarnings("unchecked")
            final Map<String, Object> childMappings = (Map<String, Object>) children;
            extractFieldNamesTypes(childMappings, fieldsToType, parentPath, includeFields);
        }
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"CLAUDE": "Human:\" turn\": Here is an example of the create anomaly detector API: POST _plugins/_anomaly_detection/detectors, {\"time_field\":\"timestamp\",\"indices\":[\"server_log*\"],\"feature_attributes\":[{\"feature_name\":\"test\",\"feature_enabled\":true,\"aggregation_query\":{\"test\":{\"sum\":{\"field\":\"value\"}}}}],\"category_field\":[\"ip\"]}, and here are the mapping info containing all the fields in the index ${indexInfo.indexName}: ${indexInfo.indexMapping}, and the optional aggregation methods are count, avg, min, max and sum. Please give me some suggestion about creating an anomaly detector for the index ${indexInfo.indexName}, you need to give the key information: the top 3 suitable aggregation fields which are numeric types(long, integer, double, float, short etc.) and the suitable aggregation method for each field, you should give at most 3 aggregation fields and corresponding aggregation methods, if there are no numeric type fields, both the aggregation field and method are empty string, and also give at most 1 category field if there exists a keyword type field like ip, address, host, city, country or region, if not exist, the category field is empty. Show me a format of keyed and pipe-delimited list wrapped in a curly bracket just like {category_field=the category field if exists|aggregation_field=comma-delimited list of all the aggregation field names|aggregation_method=comma-delimited list of all the aggregation methods}. \n\nAssistant:\" turn\"",
"OPENAI": "Here is an example of the create anomaly detector API: POST _plugins/_anomaly_detection/detectors, {\"time_field\":\"timestamp\",\"indices\":[\"server_log*\"],\"feature_attributes\":[{\"feature_name\":\"test\",\"feature_enabled\":true,\"aggregation_query\":{\"test\":{\"sum\":{\"field\":\"value\"}}}}],\"category_field\":[\"ip\"]}, and here are the mapping info containing all the fields in the index ${indexInfo.indexName}: ${indexInfo.indexMapping}, and the optional aggregation methods are count, avg, min, max and sum. Please give me some suggestion about creating an anomaly detector for the index ${indexInfo.indexName}, you need to give the key information: the top 3 suitable aggregation fields which are numeric types(long, integer, double, float, short etc.) and the suitable aggregation method for each field, you should give at most 3 aggregation fields and corresponding aggregation methods, if there are no numeric type fields, both the aggregation field and method are empty string, and also give at most 1 category field if there exists a keyword type field like ip, address, host, city, country or region, if not exist, the category field is empty. Show me a format of keyed and pipe-delimited list wrapped in a curly bracket just like {category_field=the category field if exists|aggregation_field=comma-delimited list of all the aggregation field names|aggregation_method=comma-delimited list of all the aggregation methods}. "
}
Loading

0 comments on commit 00fd8ce

Please sign in to comment.