From dfccb42ef35dfd549ce0483c5a34dd41de4a7f7b Mon Sep 17 00:00:00 2001 From: xinyual Date: Fri, 24 Jan 2025 15:02:58 +0800 Subject: [PATCH 1/8] add related model fields list Signed-off-by: xinyual --- .../agent/tools/AbstractRetrieverTool.java | 1 + .../org/opensearch/agent/tools/CreateAlertTool.java | 10 ++++++++-- .../agent/tools/CreateAnomalyDetectorTool.java | 13 +++++++++++-- .../agent/tools/NeuralSparseSearchTool.java | 11 +++++++++-- .../java/org/opensearch/agent/tools/PPLTool.java | 13 ++++++++++--- .../java/org/opensearch/agent/tools/RAGTool.java | 10 ++++++++-- .../org/opensearch/agent/tools/VectorDBTool.java | 11 +++++++++-- 7 files changed, 56 insertions(+), 13 deletions(-) diff --git a/src/main/java/org/opensearch/agent/tools/AbstractRetrieverTool.java b/src/main/java/org/opensearch/agent/tools/AbstractRetrieverTool.java index 5abcd758..be47c179 100644 --- a/src/main/java/org/opensearch/agent/tools/AbstractRetrieverTool.java +++ b/src/main/java/org/opensearch/agent/tools/AbstractRetrieverTool.java @@ -23,6 +23,7 @@ import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.spi.tools.Tool; +import org.opensearch.ml.common.spi.tools.WithModelTool; import org.opensearch.search.SearchHit; import org.opensearch.search.builder.SearchSourceBuilder; diff --git a/src/main/java/org/opensearch/agent/tools/CreateAlertTool.java b/src/main/java/org/opensearch/agent/tools/CreateAlertTool.java index f7bc3311..38b028c0 100644 --- a/src/main/java/org/opensearch/agent/tools/CreateAlertTool.java +++ b/src/main/java/org/opensearch/agent/tools/CreateAlertTool.java @@ -36,6 +36,7 @@ import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.ml.common.spi.tools.ToolAnnotation; +import org.opensearch.ml.common.spi.tools.WithModelTool; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; import org.opensearch.ml.common.utils.StringUtils; @@ -48,7 +49,7 @@ @Log4j2 @ToolAnnotation(CreateAlertTool.TYPE) -public class CreateAlertTool implements Tool { +public class CreateAlertTool implements WithModelTool { public static final String TYPE = "CreateAlertTool"; private static final String DEFAULT_DESCRIPTION = @@ -276,7 +277,7 @@ private static GetIndexRequest constructIndexRequest(String rawIndex, Boolean is return getIndexRequest; } - public static class Factory implements Tool.Factory { + public static class Factory implements WithModelTool.Factory { private Client client; @@ -324,5 +325,10 @@ public String getDefaultType() { public String getDefaultVersion() { return null; } + + @Override + public List getAllModelKeys() { + return List.of(MODEL_ID); + } } } diff --git a/src/main/java/org/opensearch/agent/tools/CreateAnomalyDetectorTool.java b/src/main/java/org/opensearch/agent/tools/CreateAnomalyDetectorTool.java index e1540669..1b8f6d35 100644 --- a/src/main/java/org/opensearch/agent/tools/CreateAnomalyDetectorTool.java +++ b/src/main/java/org/opensearch/agent/tools/CreateAnomalyDetectorTool.java @@ -15,6 +15,7 @@ import java.util.Collections; import java.util.HashMap; import java.util.HashSet; +import java.util.List; import java.util.Locale; import java.util.Map; import java.util.Objects; @@ -40,6 +41,7 @@ import org.opensearch.ml.common.output.model.ModelTensors; import org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.ml.common.spi.tools.ToolAnnotation; +import org.opensearch.ml.common.spi.tools.WithModelTool; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; @@ -65,10 +67,12 @@ @Setter @Getter @ToolAnnotation(CreateAnomalyDetectorTool.TYPE) -public class CreateAnomalyDetectorTool implements Tool { +public class CreateAnomalyDetectorTool implements WithModelTool { // the type of this tool public static final String TYPE = "CreateAnomalyDetectorTool"; + public static final String MODEL_ID_FIELD = "model_id"; + // the default description of this tool private static final String DEFAULT_DESCRIPTION = "This is a tool used to help creating anomaly detector. It takes a required argument which is the name of the index, extract the index mappings and let the LLM to give the suggested aggregation field, aggregation method, category field and the date field which are required to create an anomaly detector."; @@ -392,7 +396,7 @@ public String getType() { /** * The tool factory */ - public static class Factory implements Tool.Factory { + public static class Factory implements WithModelTool.Factory { private Client client; private static CreateAnomalyDetectorTool.Factory INSTANCE; @@ -454,5 +458,10 @@ public String getDefaultType() { public String getDefaultVersion() { return null; } + + @Override + public List getAllModelKeys() { + return List.of(MODEL_ID_FIELD); + } } } diff --git a/src/main/java/org/opensearch/agent/tools/NeuralSparseSearchTool.java b/src/main/java/org/opensearch/agent/tools/NeuralSparseSearchTool.java index 60168603..231261b9 100644 --- a/src/main/java/org/opensearch/agent/tools/NeuralSparseSearchTool.java +++ b/src/main/java/org/opensearch/agent/tools/NeuralSparseSearchTool.java @@ -10,6 +10,7 @@ import java.security.AccessController; import java.security.PrivilegedActionException; import java.security.PrivilegedExceptionAction; +import java.util.List; import java.util.Map; import org.apache.commons.lang3.StringUtils; @@ -21,6 +22,7 @@ import lombok.Getter; import lombok.Setter; import lombok.extern.log4j.Log4j2; +import org.opensearch.ml.common.spi.tools.WithModelTool; /** * This tool supports neural_sparse search with sparse encoding models and rank_features field. @@ -29,7 +31,7 @@ @Getter @Setter @ToolAnnotation(NeuralSparseSearchTool.TYPE) -public class NeuralSparseSearchTool extends AbstractRetrieverTool { +public class NeuralSparseSearchTool extends AbstractRetrieverTool implements WithModelTool { public static final String TYPE = "NeuralSparseSearchTool"; public static final String MODEL_ID_FIELD = "model_id"; public static final String EMBEDDING_FIELD = "embedding_field"; @@ -101,7 +103,7 @@ public String getType() { return TYPE; } - public static class Factory extends AbstractRetrieverTool.Factory { + public static class Factory extends AbstractRetrieverTool.Factory implements WithModelTool.Factory { private static Factory INSTANCE; public static Factory getInstance() { @@ -147,5 +149,10 @@ public String getDefaultType() { public String getDefaultVersion() { return null; } + + @Override + public List getAllModelKeys() { + return List.of(MODEL_ID_FIELD); + } } } diff --git a/src/main/java/org/opensearch/agent/tools/PPLTool.java b/src/main/java/org/opensearch/agent/tools/PPLTool.java index 30eaf555..c33a4c78 100644 --- a/src/main/java/org/opensearch/agent/tools/PPLTool.java +++ b/src/main/java/org/opensearch/agent/tools/PPLTool.java @@ -44,6 +44,7 @@ import org.opensearch.ml.common.output.model.ModelTensors; import org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.ml.common.spi.tools.ToolAnnotation; +import org.opensearch.ml.common.spi.tools.WithModelTool; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; import org.opensearch.search.SearchHit; @@ -63,10 +64,12 @@ @Setter @Getter @ToolAnnotation(PPLTool.TYPE) -public class PPLTool implements Tool { +public class PPLTool implements WithModelTool { public static final String TYPE = "PPLTool"; + public static final String MODEL_ID_FIELD = "model_id"; + @Setter private Client client; @@ -285,7 +288,7 @@ public boolean validate(Map parameters) { return parameters != null && !parameters.isEmpty(); } - public static class Factory implements Tool.Factory { + public static class Factory implements WithModelTool.Factory { private Client client; private static Factory INSTANCE; @@ -312,7 +315,7 @@ public PPLTool create(Map map) { validatePPLToolParameters(map); return new PPLTool( client, - (String) map.get("model_id"), + (String) map.get(MODEL_ID_FIELD), (String) map.getOrDefault("prompt", ""), (String) map.getOrDefault("model_type", ""), (String) map.getOrDefault("previous_tool_name", ""), @@ -336,6 +339,10 @@ public String getDefaultVersion() { return null; } + @Override + public List getAllModelKeys() { + return List.of(MODEL_ID_FIELD); + } } private SearchRequest buildSearchRequest(String indexName) { diff --git a/src/main/java/org/opensearch/agent/tools/RAGTool.java b/src/main/java/org/opensearch/agent/tools/RAGTool.java index 8b1193d1..b591d3e3 100644 --- a/src/main/java/org/opensearch/agent/tools/RAGTool.java +++ b/src/main/java/org/opensearch/agent/tools/RAGTool.java @@ -26,6 +26,7 @@ import org.opensearch.ml.common.spi.tools.Parser; import org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.ml.common.spi.tools.ToolAnnotation; +import org.opensearch.ml.common.spi.tools.WithModelTool; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; @@ -44,7 +45,7 @@ @Setter @Getter @ToolAnnotation(RAGTool.TYPE) -public class RAGTool implements Tool { +public class RAGTool implements WithModelTool { public static final String TYPE = "RAGTool"; public static String DEFAULT_DESCRIPTION = "Use this tool to retrieve helpful information to optimize the output of the large language model to answer questions."; @@ -194,7 +195,7 @@ public boolean validate(Map parameters) { /** * Factory class to create RAGTool */ - public static class Factory implements Tool.Factory { + public static class Factory implements WithModelTool.Factory { private Client client; private NamedXContentRegistry xContentRegistry; @@ -270,5 +271,10 @@ public String getDefaultType() { public String getDefaultVersion() { return null; } + + @Override + public List getAllModelKeys() { + return List.of(INFERENCE_MODEL_ID_FIELD, EMBEDDING_MODEL_ID_FIELD); + } } } diff --git a/src/main/java/org/opensearch/agent/tools/VectorDBTool.java b/src/main/java/org/opensearch/agent/tools/VectorDBTool.java index d397060e..32124d9e 100644 --- a/src/main/java/org/opensearch/agent/tools/VectorDBTool.java +++ b/src/main/java/org/opensearch/agent/tools/VectorDBTool.java @@ -10,6 +10,7 @@ import java.security.AccessController; import java.security.PrivilegedActionException; import java.security.PrivilegedExceptionAction; +import java.util.List; import java.util.Map; import org.apache.commons.lang3.StringUtils; @@ -21,6 +22,7 @@ import lombok.Getter; import lombok.Setter; import lombok.extern.log4j.Log4j2; +import org.opensearch.ml.common.spi.tools.WithModelTool; /** * This tool supports neural search with embedding models and knn index. @@ -29,7 +31,7 @@ @Getter @Setter @ToolAnnotation(VectorDBTool.TYPE) -public class VectorDBTool extends AbstractRetrieverTool { +public class VectorDBTool extends AbstractRetrieverTool implements WithModelTool { public static final String TYPE = "VectorDBTool"; public static String DEFAULT_DESCRIPTION = @@ -110,7 +112,7 @@ public String getType() { return TYPE; } - public static class Factory extends AbstractRetrieverTool.Factory { + public static class Factory extends AbstractRetrieverTool.Factory implements WithModelTool.Factory { private static VectorDBTool.Factory INSTANCE; public static VectorDBTool.Factory getInstance() { @@ -163,5 +165,10 @@ public String getDefaultVersion() { public String getDefaultDescription() { return DEFAULT_DESCRIPTION; } + + @Override + public List getAllModelKeys() { + return List.of(MODEL_ID_FIELD); + } } } From c18ef7e2064ce0857c5409d900c044cbd04dc805 Mon Sep 17 00:00:00 2001 From: xinyual Date: Fri, 24 Jan 2025 15:29:06 +0800 Subject: [PATCH 2/8] apply spotless Signed-off-by: xinyual --- .../org/opensearch/agent/tools/AbstractRetrieverTool.java | 1 - .../java/org/opensearch/agent/tools/CreateAlertTool.java | 1 - .../opensearch/agent/tools/CreateAnomalyDetectorTool.java | 1 - .../org/opensearch/agent/tools/NeuralSparseSearchTool.java | 6 ++++-- src/main/java/org/opensearch/agent/tools/PPLTool.java | 1 - src/main/java/org/opensearch/agent/tools/RAGTool.java | 1 - src/main/java/org/opensearch/agent/tools/VectorDBTool.java | 6 ++++-- 7 files changed, 8 insertions(+), 9 deletions(-) diff --git a/src/main/java/org/opensearch/agent/tools/AbstractRetrieverTool.java b/src/main/java/org/opensearch/agent/tools/AbstractRetrieverTool.java index be47c179..5abcd758 100644 --- a/src/main/java/org/opensearch/agent/tools/AbstractRetrieverTool.java +++ b/src/main/java/org/opensearch/agent/tools/AbstractRetrieverTool.java @@ -23,7 +23,6 @@ import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.spi.tools.Tool; -import org.opensearch.ml.common.spi.tools.WithModelTool; import org.opensearch.search.SearchHit; import org.opensearch.search.builder.SearchSourceBuilder; diff --git a/src/main/java/org/opensearch/agent/tools/CreateAlertTool.java b/src/main/java/org/opensearch/agent/tools/CreateAlertTool.java index 38b028c0..8d0ca847 100644 --- a/src/main/java/org/opensearch/agent/tools/CreateAlertTool.java +++ b/src/main/java/org/opensearch/agent/tools/CreateAlertTool.java @@ -34,7 +34,6 @@ import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensorOutput; -import org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.ml.common.spi.tools.ToolAnnotation; import org.opensearch.ml.common.spi.tools.WithModelTool; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; diff --git a/src/main/java/org/opensearch/agent/tools/CreateAnomalyDetectorTool.java b/src/main/java/org/opensearch/agent/tools/CreateAnomalyDetectorTool.java index 03e718eb..5ff0e2da 100644 --- a/src/main/java/org/opensearch/agent/tools/CreateAnomalyDetectorTool.java +++ b/src/main/java/org/opensearch/agent/tools/CreateAnomalyDetectorTool.java @@ -40,7 +40,6 @@ import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.output.model.ModelTensors; -import org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.ml.common.spi.tools.ToolAnnotation; import org.opensearch.ml.common.spi.tools.WithModelTool; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; diff --git a/src/main/java/org/opensearch/agent/tools/NeuralSparseSearchTool.java b/src/main/java/org/opensearch/agent/tools/NeuralSparseSearchTool.java index 231261b9..dbab1936 100644 --- a/src/main/java/org/opensearch/agent/tools/NeuralSparseSearchTool.java +++ b/src/main/java/org/opensearch/agent/tools/NeuralSparseSearchTool.java @@ -17,12 +17,12 @@ import org.opensearch.client.Client; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.ml.common.spi.tools.ToolAnnotation; +import org.opensearch.ml.common.spi.tools.WithModelTool; import lombok.Builder; import lombok.Getter; import lombok.Setter; import lombok.extern.log4j.Log4j2; -import org.opensearch.ml.common.spi.tools.WithModelTool; /** * This tool supports neural_sparse search with sparse encoding models and rank_features field. @@ -103,7 +103,9 @@ public String getType() { return TYPE; } - public static class Factory extends AbstractRetrieverTool.Factory implements WithModelTool.Factory { + public static class Factory extends AbstractRetrieverTool.Factory + implements + WithModelTool.Factory { private static Factory INSTANCE; public static Factory getInstance() { diff --git a/src/main/java/org/opensearch/agent/tools/PPLTool.java b/src/main/java/org/opensearch/agent/tools/PPLTool.java index d072200c..00e8f3e6 100644 --- a/src/main/java/org/opensearch/agent/tools/PPLTool.java +++ b/src/main/java/org/opensearch/agent/tools/PPLTool.java @@ -44,7 +44,6 @@ import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.output.model.ModelTensors; -import org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.ml.common.spi.tools.ToolAnnotation; import org.opensearch.ml.common.spi.tools.WithModelTool; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; diff --git a/src/main/java/org/opensearch/agent/tools/RAGTool.java b/src/main/java/org/opensearch/agent/tools/RAGTool.java index 23770b16..f0556419 100644 --- a/src/main/java/org/opensearch/agent/tools/RAGTool.java +++ b/src/main/java/org/opensearch/agent/tools/RAGTool.java @@ -25,7 +25,6 @@ import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.output.model.ModelTensors; import org.opensearch.ml.common.spi.tools.Parser; -import org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.ml.common.spi.tools.ToolAnnotation; import org.opensearch.ml.common.spi.tools.WithModelTool; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; diff --git a/src/main/java/org/opensearch/agent/tools/VectorDBTool.java b/src/main/java/org/opensearch/agent/tools/VectorDBTool.java index 32124d9e..24ee5d73 100644 --- a/src/main/java/org/opensearch/agent/tools/VectorDBTool.java +++ b/src/main/java/org/opensearch/agent/tools/VectorDBTool.java @@ -17,12 +17,12 @@ import org.opensearch.client.Client; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.ml.common.spi.tools.ToolAnnotation; +import org.opensearch.ml.common.spi.tools.WithModelTool; import lombok.Builder; import lombok.Getter; import lombok.Setter; import lombok.extern.log4j.Log4j2; -import org.opensearch.ml.common.spi.tools.WithModelTool; /** * This tool supports neural search with embedding models and knn index. @@ -112,7 +112,9 @@ public String getType() { return TYPE; } - public static class Factory extends AbstractRetrieverTool.Factory implements WithModelTool.Factory { + public static class Factory extends AbstractRetrieverTool.Factory + implements + WithModelTool.Factory { private static VectorDBTool.Factory INSTANCE; public static VectorDBTool.Factory getInstance() { From 328e41a02267038fece2382f523d2f1cd9b9b2b8 Mon Sep 17 00:00:00 2001 From: xinyual Date: Fri, 24 Jan 2025 15:49:14 +0800 Subject: [PATCH 3/8] fix compile error Signed-off-by: xinyual --- src/main/java/org/opensearch/agent/tools/VectorDBTool.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/java/org/opensearch/agent/tools/VectorDBTool.java b/src/main/java/org/opensearch/agent/tools/VectorDBTool.java index 24ee5d73..70e244de 100644 --- a/src/main/java/org/opensearch/agent/tools/VectorDBTool.java +++ b/src/main/java/org/opensearch/agent/tools/VectorDBTool.java @@ -114,7 +114,7 @@ public String getType() { public static class Factory extends AbstractRetrieverTool.Factory implements - WithModelTool.Factory { + WithModelTool.Factory { private static VectorDBTool.Factory INSTANCE; public static VectorDBTool.Factory getInstance() { From eb12f3acc523be8c58b057a0a7d953151e11b2ec Mon Sep 17 00:00:00 2001 From: xinyual Date: Fri, 24 Jan 2025 15:50:14 +0800 Subject: [PATCH 4/8] use static parameter Signed-off-by: xinyual --- .../org/opensearch/agent/tools/CreateAnomalyDetectorTool.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/java/org/opensearch/agent/tools/CreateAnomalyDetectorTool.java b/src/main/java/org/opensearch/agent/tools/CreateAnomalyDetectorTool.java index 5ff0e2da..36ed5fcc 100644 --- a/src/main/java/org/opensearch/agent/tools/CreateAnomalyDetectorTool.java +++ b/src/main/java/org/opensearch/agent/tools/CreateAnomalyDetectorTool.java @@ -430,7 +430,7 @@ public void init(Client client) { */ @Override public CreateAnomalyDetectorTool create(Map map) { - String modelId = (String) map.getOrDefault("model_id", ""); + String modelId = (String) map.getOrDefault(MODEL_ID_FIELD, ""); if (modelId.isEmpty()) { throw new IllegalArgumentException("model_id cannot be empty."); } From b278f21a3c0412eb91da8c3291e0886a5fd5888d Mon Sep 17 00:00:00 2001 From: xinyual Date: Fri, 24 Jan 2025 16:04:32 +0800 Subject: [PATCH 5/8] apply spotless Signed-off-by: xinyual --- src/main/java/org/opensearch/agent/tools/VectorDBTool.java | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/main/java/org/opensearch/agent/tools/VectorDBTool.java b/src/main/java/org/opensearch/agent/tools/VectorDBTool.java index 70e244de..de5ef68e 100644 --- a/src/main/java/org/opensearch/agent/tools/VectorDBTool.java +++ b/src/main/java/org/opensearch/agent/tools/VectorDBTool.java @@ -112,9 +112,7 @@ public String getType() { return TYPE; } - public static class Factory extends AbstractRetrieverTool.Factory - implements - WithModelTool.Factory { + public static class Factory extends AbstractRetrieverTool.Factory implements WithModelTool.Factory { private static VectorDBTool.Factory INSTANCE; public static VectorDBTool.Factory getInstance() { From ad1af2214d3419df064c946438ce6fc48d3fc291 Mon Sep 17 00:00:00 2001 From: xinyual Date: Sun, 26 Jan 2025 11:25:32 +0800 Subject: [PATCH 6/8] create a common constant for common model id Signed-off-by: xinyual --- .../org/opensearch/agent/tools/CreateAlertTool.java | 6 +++--- .../agent/tools/CreateAnomalyDetectorTool.java | 7 +++---- .../opensearch/agent/tools/NeuralSparseSearchTool.java | 8 ++++---- src/main/java/org/opensearch/agent/tools/PPLTool.java | 7 +++---- .../java/org/opensearch/agent/tools/VectorDBTool.java | 8 ++++---- .../opensearch/agent/tools/utils/CommonConstants.java | 10 ++++++++++ 6 files changed, 27 insertions(+), 19 deletions(-) create mode 100644 src/main/java/org/opensearch/agent/tools/utils/CommonConstants.java diff --git a/src/main/java/org/opensearch/agent/tools/CreateAlertTool.java b/src/main/java/org/opensearch/agent/tools/CreateAlertTool.java index 8d0ca847..57e306a3 100644 --- a/src/main/java/org/opensearch/agent/tools/CreateAlertTool.java +++ b/src/main/java/org/opensearch/agent/tools/CreateAlertTool.java @@ -6,6 +6,7 @@ package org.opensearch.agent.tools; import static org.opensearch.action.support.clustermanager.ClusterManagerNodeRequest.DEFAULT_CLUSTER_MANAGER_NODE_TIMEOUT; +import static org.opensearch.agent.tools.utils.CommonConstants.COMMON_MODEL_ID_FIELD; import static org.opensearch.ml.common.utils.StringUtils.isJson; import java.util.Arrays; @@ -70,7 +71,6 @@ public class CreateAlertTool implements WithModelTool { @Getter private final String toolPrompt; - private static final String MODEL_ID = "model_id"; private static final String PROMPT_FILE_PATH = "CreateAlertDefaultPrompt.json"; private static final String DEFAULT_QUESTION = "Create an alert as your recommendation based on the context"; private static final Map promptDict = ToolHelper.loadDefaultPromptDictFromFile(CreateAlertTool.class, PROMPT_FILE_PATH); @@ -301,7 +301,7 @@ public void init(Client client) { @Override public CreateAlertTool create(Map params) { - String modelId = (String) params.get(MODEL_ID); + String modelId = (String) params.get(COMMON_MODEL_ID_FIELD); if (org.apache.commons.lang3.StringUtils.isBlank(modelId)) { throw new IllegalArgumentException("model_id cannot be null or blank."); } @@ -327,7 +327,7 @@ public String getDefaultVersion() { @Override public List getAllModelKeys() { - return List.of(MODEL_ID); + return List.of(COMMON_MODEL_ID_FIELD); } } } diff --git a/src/main/java/org/opensearch/agent/tools/CreateAnomalyDetectorTool.java b/src/main/java/org/opensearch/agent/tools/CreateAnomalyDetectorTool.java index 36ed5fcc..0f33d38f 100644 --- a/src/main/java/org/opensearch/agent/tools/CreateAnomalyDetectorTool.java +++ b/src/main/java/org/opensearch/agent/tools/CreateAnomalyDetectorTool.java @@ -5,6 +5,7 @@ package org.opensearch.agent.tools; +import static org.opensearch.agent.tools.utils.CommonConstants.COMMON_MODEL_ID_FIELD; import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD; import static org.opensearch.ml.common.utils.StringUtils.gson; @@ -71,8 +72,6 @@ public class CreateAnomalyDetectorTool implements WithModelTool { // the type of this tool public static final String TYPE = "CreateAnomalyDetectorTool"; - public static final String MODEL_ID_FIELD = "model_id"; - // the default description of this tool private static final String DEFAULT_DESCRIPTION = "This is a tool used to help creating anomaly detector. It takes a required argument which is the name of the index, extract the index mappings and let the LLM to give the suggested aggregation field, aggregation method, category field and the date field which are required to create an anomaly detector."; @@ -430,7 +429,7 @@ public void init(Client client) { */ @Override public CreateAnomalyDetectorTool create(Map map) { - String modelId = (String) map.getOrDefault(MODEL_ID_FIELD, ""); + String modelId = (String) map.getOrDefault(COMMON_MODEL_ID_FIELD, ""); if (modelId.isEmpty()) { throw new IllegalArgumentException("model_id cannot be empty."); } @@ -463,7 +462,7 @@ public String getDefaultVersion() { @Override public List getAllModelKeys() { - return List.of(MODEL_ID_FIELD); + return List.of(COMMON_MODEL_ID_FIELD); } } } diff --git a/src/main/java/org/opensearch/agent/tools/NeuralSparseSearchTool.java b/src/main/java/org/opensearch/agent/tools/NeuralSparseSearchTool.java index dbab1936..0440fe6b 100644 --- a/src/main/java/org/opensearch/agent/tools/NeuralSparseSearchTool.java +++ b/src/main/java/org/opensearch/agent/tools/NeuralSparseSearchTool.java @@ -5,6 +5,7 @@ package org.opensearch.agent.tools; +import static org.opensearch.agent.tools.utils.CommonConstants.COMMON_MODEL_ID_FIELD; import static org.opensearch.ml.common.utils.StringUtils.gson; import java.security.AccessController; @@ -33,7 +34,6 @@ @ToolAnnotation(NeuralSparseSearchTool.TYPE) public class NeuralSparseSearchTool extends AbstractRetrieverTool implements WithModelTool { public static final String TYPE = "NeuralSparseSearchTool"; - public static final String MODEL_ID_FIELD = "model_id"; public static final String EMBEDDING_FIELD = "embedding_field"; public static final String NESTED_PATH_FIELD = "nested_path"; @@ -63,7 +63,7 @@ public NeuralSparseSearchTool( protected String getQueryBody(String queryText) { if (StringUtils.isBlank(embeddingField) || StringUtils.isBlank(modelId)) { throw new IllegalArgumentException( - "Parameter [" + EMBEDDING_FIELD + "] and [" + MODEL_ID_FIELD + "] can not be null or empty." + "Parameter [" + EMBEDDING_FIELD + "] and [" + COMMON_MODEL_ID_FIELD + "] can not be null or empty." ); } @@ -126,7 +126,7 @@ public NeuralSparseSearchTool create(Map params) { String index = (String) params.get(INDEX_FIELD); String embeddingField = (String) params.get(EMBEDDING_FIELD); String[] sourceFields = gson.fromJson((String) params.get(SOURCE_FIELD), String[].class); - String modelId = (String) params.get(MODEL_ID_FIELD); + String modelId = (String) params.get(COMMON_MODEL_ID_FIELD); Integer docSize = params.containsKey(DOC_SIZE_FIELD) ? Integer.parseInt((String) params.get(DOC_SIZE_FIELD)) : DEFAULT_DOC_SIZE; String nestedPath = (String) params.get(NESTED_PATH_FIELD); return NeuralSparseSearchTool @@ -154,7 +154,7 @@ public String getDefaultVersion() { @Override public List getAllModelKeys() { - return List.of(MODEL_ID_FIELD); + return List.of(COMMON_MODEL_ID_FIELD); } } } diff --git a/src/main/java/org/opensearch/agent/tools/PPLTool.java b/src/main/java/org/opensearch/agent/tools/PPLTool.java index 00e8f3e6..776aae6a 100644 --- a/src/main/java/org/opensearch/agent/tools/PPLTool.java +++ b/src/main/java/org/opensearch/agent/tools/PPLTool.java @@ -5,6 +5,7 @@ package org.opensearch.agent.tools; +import static org.opensearch.agent.tools.utils.CommonConstants.COMMON_MODEL_ID_FIELD; import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD; import java.io.IOException; @@ -69,8 +70,6 @@ public class PPLTool implements WithModelTool { public static final String TYPE = "PPLTool"; - public static final String MODEL_ID_FIELD = "model_id"; - @Setter private Client client; @@ -318,7 +317,7 @@ public PPLTool create(Map map) { validatePPLToolParameters(map); return new PPLTool( client, - (String) map.get(MODEL_ID_FIELD), + (String) map.get(COMMON_MODEL_ID_FIELD), (String) map.getOrDefault("prompt", ""), (String) map.getOrDefault("model_type", ""), (String) map.getOrDefault("previous_tool_name", ""), @@ -344,7 +343,7 @@ public String getDefaultVersion() { @Override public List getAllModelKeys() { - return List.of(MODEL_ID_FIELD); + return List.of(COMMON_MODEL_ID_FIELD); } } diff --git a/src/main/java/org/opensearch/agent/tools/VectorDBTool.java b/src/main/java/org/opensearch/agent/tools/VectorDBTool.java index de5ef68e..968ad1bd 100644 --- a/src/main/java/org/opensearch/agent/tools/VectorDBTool.java +++ b/src/main/java/org/opensearch/agent/tools/VectorDBTool.java @@ -5,6 +5,7 @@ package org.opensearch.agent.tools; +import static org.opensearch.agent.tools.utils.CommonConstants.COMMON_MODEL_ID_FIELD; import static org.opensearch.ml.common.utils.StringUtils.gson; import java.security.AccessController; @@ -36,7 +37,6 @@ public class VectorDBTool extends AbstractRetrieverTool implements WithModelTool public static String DEFAULT_DESCRIPTION = "Use this tool to performs knn-based dense retrieval. It takes 1 argument named input which is a string query for dense retrieval. The tool returns the dense retrieval results for the query."; - public static final String MODEL_ID_FIELD = "model_id"; public static final String EMBEDDING_FIELD = "embedding_field"; public static final String K_FIELD = "k"; public static final Integer DEFAULT_K = 10; @@ -71,7 +71,7 @@ public VectorDBTool( protected String getQueryBody(String queryText) { if (StringUtils.isBlank(embeddingField) || StringUtils.isBlank(modelId)) { throw new IllegalArgumentException( - "Parameter [" + EMBEDDING_FIELD + "] and [" + MODEL_ID_FIELD + "] can not be null or empty." + "Parameter [" + EMBEDDING_FIELD + "] and [" + COMMON_MODEL_ID_FIELD + "] can not be null or empty." ); } @@ -133,7 +133,7 @@ public VectorDBTool create(Map params) { String index = (String) params.get(INDEX_FIELD); String embeddingField = (String) params.get(EMBEDDING_FIELD); String[] sourceFields = gson.fromJson((String) params.get(SOURCE_FIELD), String[].class); - String modelId = (String) params.get(MODEL_ID_FIELD); + String modelId = (String) params.get(COMMON_MODEL_ID_FIELD); Integer docSize = params.containsKey(DOC_SIZE_FIELD) ? Integer.parseInt((String) params.get(DOC_SIZE_FIELD)) : DEFAULT_DOC_SIZE; Integer k = params.containsKey(K_FIELD) ? Integer.parseInt((String) params.get(K_FIELD)) : DEFAULT_K; String nestedPath = (String) params.get(NESTED_PATH_FIELD); @@ -168,7 +168,7 @@ public String getDefaultDescription() { @Override public List getAllModelKeys() { - return List.of(MODEL_ID_FIELD); + return List.of(COMMON_MODEL_ID_FIELD); } } } diff --git a/src/main/java/org/opensearch/agent/tools/utils/CommonConstants.java b/src/main/java/org/opensearch/agent/tools/utils/CommonConstants.java new file mode 100644 index 00000000..495b1f95 --- /dev/null +++ b/src/main/java/org/opensearch/agent/tools/utils/CommonConstants.java @@ -0,0 +1,10 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.tools.utils; + +public class CommonConstants { + public static final String COMMON_MODEL_ID_FIELD = "model_id"; +} From 3992dd52e8783ef0d7ad5d7c3168ca3fc7e3dce4 Mon Sep 17 00:00:00 2001 From: xinyual Date: Sun, 26 Jan 2025 11:36:24 +0800 Subject: [PATCH 7/8] fix error Signed-off-by: xinyual --- src/main/java/org/opensearch/agent/tools/RAGTool.java | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/main/java/org/opensearch/agent/tools/RAGTool.java b/src/main/java/org/opensearch/agent/tools/RAGTool.java index f0556419..b68dbb69 100644 --- a/src/main/java/org/opensearch/agent/tools/RAGTool.java +++ b/src/main/java/org/opensearch/agent/tools/RAGTool.java @@ -7,6 +7,7 @@ import static org.apache.commons.lang3.StringEscapeUtils.escapeJson; import static org.opensearch.agent.tools.AbstractRetrieverTool.*; +import static org.opensearch.agent.tools.utils.CommonConstants.COMMON_MODEL_ID_FIELD; import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD; import static org.opensearch.ml.common.utils.StringUtils.gson; import static org.opensearch.ml.common.utils.StringUtils.toJson; @@ -231,7 +232,7 @@ public RAGTool create(Map params) { String inferenceModelId = enableContentGeneration ? (String) params.get(INFERENCE_MODEL_ID_FIELD) : ""; switch (queryType) { case "neural_sparse": - params.put(NeuralSparseSearchTool.MODEL_ID_FIELD, embeddingModelId); + params.put(COMMON_MODEL_ID_FIELD, embeddingModelId); NeuralSparseSearchTool neuralSparseSearchTool = NeuralSparseSearchTool.Factory.getInstance().create(params); return RAGTool .builder() @@ -242,7 +243,7 @@ public RAGTool create(Map params) { .queryTool(neuralSparseSearchTool) .build(); case "neural": - params.put(VectorDBTool.MODEL_ID_FIELD, embeddingModelId); + params.put(COMMON_MODEL_ID_FIELD, embeddingModelId); VectorDBTool vectorDBTool = VectorDBTool.Factory.getInstance().create(params); return RAGTool .builder() From c67debe68d975c06f07e67ff508d64ae2006750a Mon Sep 17 00:00:00 2001 From: xinyual Date: Sun, 26 Jan 2025 11:41:23 +0800 Subject: [PATCH 8/8] fix UT error Signed-off-by: xinyual --- .../agent/tools/NeuralSparseSearchToolTests.java | 7 ++++--- src/test/java/org/opensearch/agent/tools/RAGToolTests.java | 3 ++- .../java/org/opensearch/agent/tools/VectorDBToolTests.java | 7 ++++--- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/src/test/java/org/opensearch/agent/tools/NeuralSparseSearchToolTests.java b/src/test/java/org/opensearch/agent/tools/NeuralSparseSearchToolTests.java index d6d14991..1d7a3fbd 100644 --- a/src/test/java/org/opensearch/agent/tools/NeuralSparseSearchToolTests.java +++ b/src/test/java/org/opensearch/agent/tools/NeuralSparseSearchToolTests.java @@ -7,6 +7,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertThrows; +import static org.opensearch.agent.tools.utils.CommonConstants.COMMON_MODEL_ID_FIELD; import static org.opensearch.ml.common.utils.StringUtils.gson; import java.util.HashMap; @@ -31,7 +32,7 @@ public void setup() { params.put(NeuralSparseSearchTool.INDEX_FIELD, AbstractRetrieverToolTests.TEST_INDEX); params.put(NeuralSparseSearchTool.EMBEDDING_FIELD, TEST_EMBEDDING_FIELD); params.put(NeuralSparseSearchTool.SOURCE_FIELD, gson.toJson(AbstractRetrieverToolTests.TEST_SOURCE_FIELDS)); - params.put(NeuralSparseSearchTool.MODEL_ID_FIELD, TEST_MODEL_ID); + params.put(COMMON_MODEL_ID_FIELD, TEST_MODEL_ID); params.put(NeuralSparseSearchTool.DOC_SIZE_FIELD, AbstractRetrieverToolTests.TEST_DOC_SIZE.toString()); } @@ -91,7 +92,7 @@ public void testGetQueryBodyWithJsonObjectString() { @SneakyThrows public void testGetQueryBodyWithIllegalParams() { Map illegalParams1 = new HashMap<>(params); - illegalParams1.remove(NeuralSparseSearchTool.MODEL_ID_FIELD); + illegalParams1.remove(COMMON_MODEL_ID_FIELD); NeuralSparseSearchTool tool1 = NeuralSparseSearchTool.Factory.getInstance().create(illegalParams1); Exception exception1 = assertThrows( IllegalArgumentException.class, @@ -124,7 +125,7 @@ public void testCreateToolsParseParams() { assertThrows( ClassCastException.class, - () -> NeuralSparseSearchTool.Factory.getInstance().create(Map.of(NeuralSparseSearchTool.MODEL_ID_FIELD, 123)) + () -> NeuralSparseSearchTool.Factory.getInstance().create(Map.of(COMMON_MODEL_ID_FIELD, 123)) ); assertThrows( diff --git a/src/test/java/org/opensearch/agent/tools/RAGToolTests.java b/src/test/java/org/opensearch/agent/tools/RAGToolTests.java index 0f19f91a..b6b85da1 100644 --- a/src/test/java/org/opensearch/agent/tools/RAGToolTests.java +++ b/src/test/java/org/opensearch/agent/tools/RAGToolTests.java @@ -17,6 +17,7 @@ import static org.opensearch.agent.tools.AbstractRetrieverTool.*; import static org.opensearch.agent.tools.AbstractRetrieverToolTests.*; import static org.opensearch.agent.tools.VectorDBTool.DEFAULT_K; +import static org.opensearch.agent.tools.utils.CommonConstants.COMMON_MODEL_ID_FIELD; import static org.opensearch.ml.common.utils.StringUtils.gson; import java.io.IOException; @@ -426,7 +427,7 @@ public void testFactoryNeuralQuery() { params.put(VectorDBTool.NESTED_PATH_FIELD, TEST_NESTED_PATH); RAGTool rAGtool1 = factoryMock.create(params); VectorDBTool.Factory.getInstance().init(client, TEST_XCONTENT_REGISTRY_FOR_NEURAL_QUERY); - params.put(VectorDBTool.MODEL_ID_FIELD, TEST_EMBEDDING_MODEL_ID); + params.put(COMMON_MODEL_ID_FIELD, TEST_EMBEDDING_MODEL_ID); VectorDBTool queryTool = VectorDBTool.Factory.getInstance().create(params); RAGTool rAGtool2 = new RAGTool(client, TEST_XCONTENT_REGISTRY_FOR_NEURAL_QUERY, TEST_INFERENCE_MODEL_ID, true, queryTool); diff --git a/src/test/java/org/opensearch/agent/tools/VectorDBToolTests.java b/src/test/java/org/opensearch/agent/tools/VectorDBToolTests.java index 635724a7..12d10171 100644 --- a/src/test/java/org/opensearch/agent/tools/VectorDBToolTests.java +++ b/src/test/java/org/opensearch/agent/tools/VectorDBToolTests.java @@ -7,6 +7,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertThrows; +import static org.opensearch.agent.tools.utils.CommonConstants.COMMON_MODEL_ID_FIELD; import static org.opensearch.ml.common.utils.StringUtils.gson; import java.util.HashMap; @@ -32,7 +33,7 @@ public void setup() { params.put(VectorDBTool.INDEX_FIELD, AbstractRetrieverToolTests.TEST_INDEX); params.put(VectorDBTool.EMBEDDING_FIELD, TEST_EMBEDDING_FIELD); params.put(VectorDBTool.SOURCE_FIELD, gson.toJson(AbstractRetrieverToolTests.TEST_SOURCE_FIELDS)); - params.put(VectorDBTool.MODEL_ID_FIELD, TEST_MODEL_ID); + params.put(COMMON_MODEL_ID_FIELD, TEST_MODEL_ID); params.put(VectorDBTool.DOC_SIZE_FIELD, AbstractRetrieverToolTests.TEST_DOC_SIZE.toString()); params.put(VectorDBTool.K_FIELD, TEST_K.toString()); } @@ -93,7 +94,7 @@ public void testGetQueryBodyWithJsonObjectString() { @SneakyThrows public void testGetQueryBodyWithIllegalParams() { Map illegalParams1 = new HashMap<>(params); - illegalParams1.remove(VectorDBTool.MODEL_ID_FIELD); + illegalParams1.remove(COMMON_MODEL_ID_FIELD); VectorDBTool tool1 = VectorDBTool.Factory.getInstance().create(illegalParams1); Exception exception1 = assertThrows( IllegalArgumentException.class, @@ -118,7 +119,7 @@ public void testCreateToolsParseParams() { assertThrows(ClassCastException.class, () -> VectorDBTool.Factory.getInstance().create(Map.of(VectorDBTool.EMBEDDING_FIELD, 123))); - assertThrows(ClassCastException.class, () -> VectorDBTool.Factory.getInstance().create(Map.of(VectorDBTool.MODEL_ID_FIELD, 123))); + assertThrows(ClassCastException.class, () -> VectorDBTool.Factory.getInstance().create(Map.of(COMMON_MODEL_ID_FIELD, 123))); assertThrows( ClassCastException.class,