Skip to content

Commit

Permalink
Add model related field for tools (#491)
Browse files Browse the repository at this point in the history
* add related model fields list

Signed-off-by: xinyual <[email protected]>

* apply spotless

Signed-off-by: xinyual <[email protected]>

* fix compile error

Signed-off-by: xinyual <[email protected]>

* use static parameter

Signed-off-by: xinyual <[email protected]>

* apply spotless

Signed-off-by: xinyual <[email protected]>

* create a common constant for common model id

Signed-off-by: xinyual <[email protected]>

* fix error

Signed-off-by: xinyual <[email protected]>

* fix UT error

Signed-off-by: xinyual <[email protected]>

---------

Signed-off-by: xinyual <[email protected]>
(cherry picked from commit a2f9e2c)
Signed-off-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
  • Loading branch information
github-actions[bot] committed Jan 26, 2025
1 parent b201795 commit 33681a9
Show file tree
Hide file tree
Showing 10 changed files with 87 additions and 35 deletions.
15 changes: 10 additions & 5 deletions src/main/java/org/opensearch/agent/tools/CreateAlertTool.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.agent.tools;

import static org.opensearch.action.support.clustermanager.ClusterManagerNodeRequest.DEFAULT_CLUSTER_MANAGER_NODE_TIMEOUT;
import static org.opensearch.agent.tools.utils.CommonConstants.COMMON_MODEL_ID_FIELD;
import static org.opensearch.ml.common.utils.StringUtils.isJson;

import java.util.Arrays;
Expand Down Expand Up @@ -34,8 +35,8 @@
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
import org.opensearch.ml.common.spi.tools.WithModelTool;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
import org.opensearch.ml.common.utils.StringUtils;
Expand All @@ -48,7 +49,7 @@

@Log4j2
@ToolAnnotation(CreateAlertTool.TYPE)
public class CreateAlertTool implements Tool {
public class CreateAlertTool implements WithModelTool {
public static final String TYPE = "CreateAlertTool";

private static final String DEFAULT_DESCRIPTION =
Expand All @@ -70,7 +71,6 @@ public class CreateAlertTool implements Tool {
@Getter
private final String toolPrompt;

private static final String MODEL_ID = "model_id";
private static final String PROMPT_FILE_PATH = "CreateAlertDefaultPrompt.json";
private static final String DEFAULT_QUESTION = "Create an alert as your recommendation based on the context";
private static final Map<String, String> promptDict = ToolHelper.loadDefaultPromptDictFromFile(CreateAlertTool.class, PROMPT_FILE_PATH);
Expand Down Expand Up @@ -276,7 +276,7 @@ private static GetIndexRequest constructIndexRequest(String rawIndex, Boolean is
return getIndexRequest;
}

public static class Factory implements Tool.Factory<CreateAlertTool> {
public static class Factory implements WithModelTool.Factory<CreateAlertTool> {

private Client client;

Expand All @@ -301,7 +301,7 @@ public void init(Client client) {

@Override
public CreateAlertTool create(Map<String, Object> params) {
String modelId = (String) params.get(MODEL_ID);
String modelId = (String) params.get(COMMON_MODEL_ID_FIELD);
if (org.apache.commons.lang3.StringUtils.isBlank(modelId)) {
throw new IllegalArgumentException("model_id cannot be null or blank.");
}
Expand All @@ -324,5 +324,10 @@ public String getDefaultType() {
public String getDefaultVersion() {
return null;
}

@Override
public List<String> getAllModelKeys() {
return List.of(COMMON_MODEL_ID_FIELD);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.agent.tools;

import static org.opensearch.agent.tools.utils.CommonConstants.COMMON_MODEL_ID_FIELD;
import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD;
import static org.opensearch.ml.common.utils.StringUtils.gson;

Expand All @@ -16,6 +17,7 @@
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
Expand All @@ -39,8 +41,8 @@
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
import org.opensearch.ml.common.spi.tools.WithModelTool;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;

Expand All @@ -66,7 +68,7 @@
@Setter
@Getter
@ToolAnnotation(CreateAnomalyDetectorTool.TYPE)
public class CreateAnomalyDetectorTool implements Tool {
public class CreateAnomalyDetectorTool implements WithModelTool {
// the type of this tool
public static final String TYPE = "CreateAnomalyDetectorTool";

Expand Down Expand Up @@ -395,7 +397,7 @@ public String getType() {
/**
* The tool factory
*/
public static class Factory implements Tool.Factory<CreateAnomalyDetectorTool> {
public static class Factory implements WithModelTool.Factory<CreateAnomalyDetectorTool> {
private Client client;

private static CreateAnomalyDetectorTool.Factory INSTANCE;
Expand Down Expand Up @@ -427,7 +429,7 @@ public void init(Client client) {
*/
@Override
public CreateAnomalyDetectorTool create(Map<String, Object> map) {
String modelId = (String) map.getOrDefault("model_id", "");
String modelId = (String) map.getOrDefault(COMMON_MODEL_ID_FIELD, "");
if (modelId.isEmpty()) {
throw new IllegalArgumentException("model_id cannot be empty.");
}
Expand Down Expand Up @@ -457,5 +459,10 @@ public String getDefaultType() {
public String getDefaultVersion() {
return null;
}

@Override
public List<String> getAllModelKeys() {
return List.of(COMMON_MODEL_ID_FIELD);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,20 @@

package org.opensearch.agent.tools;

import static org.opensearch.agent.tools.utils.CommonConstants.COMMON_MODEL_ID_FIELD;
import static org.opensearch.ml.common.utils.StringUtils.gson;

import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.List;
import java.util.Map;

import org.apache.commons.lang3.StringUtils;
import org.opensearch.client.Client;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
import org.opensearch.ml.common.spi.tools.WithModelTool;

import lombok.Builder;
import lombok.Getter;
Expand All @@ -29,9 +32,8 @@
@Getter
@Setter
@ToolAnnotation(NeuralSparseSearchTool.TYPE)
public class NeuralSparseSearchTool extends AbstractRetrieverTool {
public class NeuralSparseSearchTool extends AbstractRetrieverTool implements WithModelTool {
public static final String TYPE = "NeuralSparseSearchTool";
public static final String MODEL_ID_FIELD = "model_id";
public static final String EMBEDDING_FIELD = "embedding_field";
public static final String NESTED_PATH_FIELD = "nested_path";

Expand Down Expand Up @@ -61,7 +63,7 @@ public NeuralSparseSearchTool(
protected String getQueryBody(String queryText) {
if (StringUtils.isBlank(embeddingField) || StringUtils.isBlank(modelId)) {
throw new IllegalArgumentException(
"Parameter [" + EMBEDDING_FIELD + "] and [" + MODEL_ID_FIELD + "] can not be null or empty."
"Parameter [" + EMBEDDING_FIELD + "] and [" + COMMON_MODEL_ID_FIELD + "] can not be null or empty."
);
}

Expand Down Expand Up @@ -101,7 +103,9 @@ public String getType() {
return TYPE;
}

public static class Factory extends AbstractRetrieverTool.Factory<NeuralSparseSearchTool> {
public static class Factory extends AbstractRetrieverTool.Factory<NeuralSparseSearchTool>
implements
WithModelTool.Factory<NeuralSparseSearchTool> {
private static Factory INSTANCE;

public static Factory getInstance() {
Expand All @@ -122,7 +126,7 @@ public NeuralSparseSearchTool create(Map<String, Object> params) {
String index = (String) params.get(INDEX_FIELD);
String embeddingField = (String) params.get(EMBEDDING_FIELD);
String[] sourceFields = gson.fromJson((String) params.get(SOURCE_FIELD), String[].class);
String modelId = (String) params.get(MODEL_ID_FIELD);
String modelId = (String) params.get(COMMON_MODEL_ID_FIELD);
Integer docSize = params.containsKey(DOC_SIZE_FIELD) ? Integer.parseInt((String) params.get(DOC_SIZE_FIELD)) : DEFAULT_DOC_SIZE;
String nestedPath = (String) params.get(NESTED_PATH_FIELD);
return NeuralSparseSearchTool
Expand All @@ -147,5 +151,10 @@ public String getDefaultType() {
public String getDefaultVersion() {
return null;
}

@Override
public List<String> getAllModelKeys() {
return List.of(COMMON_MODEL_ID_FIELD);
}
}
}
13 changes: 9 additions & 4 deletions src/main/java/org/opensearch/agent/tools/PPLTool.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.agent.tools;

import static org.opensearch.agent.tools.utils.CommonConstants.COMMON_MODEL_ID_FIELD;
import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD;

import java.io.IOException;
Expand Down Expand Up @@ -44,8 +45,8 @@
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
import org.opensearch.ml.common.spi.tools.WithModelTool;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
import org.opensearch.ml.repackage.com.google.common.collect.ImmutableMap;
Expand All @@ -65,7 +66,7 @@
@Setter
@Getter
@ToolAnnotation(PPLTool.TYPE)
public class PPLTool implements Tool {
public class PPLTool implements WithModelTool {

public static final String TYPE = "PPLTool";

Expand Down Expand Up @@ -285,7 +286,7 @@ public boolean validate(Map<String, String> parameters) {
return parameters != null && !parameters.isEmpty();
}

public static class Factory implements Tool.Factory<PPLTool> {
public static class Factory implements WithModelTool.Factory<PPLTool> {
private Client client;

private static Factory INSTANCE;
Expand All @@ -312,7 +313,7 @@ public PPLTool create(Map<String, Object> map) {
validatePPLToolParameters(map);
return new PPLTool(
client,
(String) map.get("model_id"),
(String) map.get(COMMON_MODEL_ID_FIELD),
(String) map.getOrDefault("prompt", ""),
(String) map.getOrDefault("model_type", ""),
(String) map.getOrDefault("previous_tool_name", ""),
Expand All @@ -336,6 +337,10 @@ public String getDefaultVersion() {
return null;
}

@Override
public List<String> getAllModelKeys() {
return List.of(COMMON_MODEL_ID_FIELD);
}
}

private SearchRequest buildSearchRequest(String indexName) {
Expand Down
16 changes: 11 additions & 5 deletions src/main/java/org/opensearch/agent/tools/RAGTool.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import static org.apache.commons.lang3.StringEscapeUtils.escapeJson;
import static org.opensearch.agent.tools.AbstractRetrieverTool.*;
import static org.opensearch.agent.tools.utils.CommonConstants.COMMON_MODEL_ID_FIELD;
import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD;
import static org.opensearch.ml.common.utils.StringUtils.gson;
import static org.opensearch.ml.common.utils.StringUtils.toJson;
Expand All @@ -25,8 +26,8 @@
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.common.spi.tools.Parser;
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
import org.opensearch.ml.common.spi.tools.WithModelTool;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;

Expand All @@ -45,7 +46,7 @@
@Setter
@Getter
@ToolAnnotation(RAGTool.TYPE)
public class RAGTool implements Tool {
public class RAGTool implements WithModelTool {
public static final String TYPE = "RAGTool";
public static String DEFAULT_DESCRIPTION =
"Use this tool to retrieve helpful information to optimize the output of the large language model to answer questions.";
Expand Down Expand Up @@ -197,7 +198,7 @@ public boolean validate(Map<String, String> parameters) {
/**
* Factory class to create RAGTool
*/
public static class Factory implements Tool.Factory<RAGTool> {
public static class Factory implements WithModelTool.Factory<RAGTool> {
private Client client;
private NamedXContentRegistry xContentRegistry;

Expand Down Expand Up @@ -231,7 +232,7 @@ public RAGTool create(Map<String, Object> params) {
String inferenceModelId = enableContentGeneration ? (String) params.get(INFERENCE_MODEL_ID_FIELD) : "";
switch (queryType) {
case "neural_sparse":
params.put(NeuralSparseSearchTool.MODEL_ID_FIELD, embeddingModelId);
params.put(COMMON_MODEL_ID_FIELD, embeddingModelId);
NeuralSparseSearchTool neuralSparseSearchTool = NeuralSparseSearchTool.Factory.getInstance().create(params);
return RAGTool
.builder()
Expand All @@ -242,7 +243,7 @@ public RAGTool create(Map<String, Object> params) {
.queryTool(neuralSparseSearchTool)
.build();
case "neural":
params.put(VectorDBTool.MODEL_ID_FIELD, embeddingModelId);
params.put(COMMON_MODEL_ID_FIELD, embeddingModelId);
VectorDBTool vectorDBTool = VectorDBTool.Factory.getInstance().create(params);
return RAGTool
.builder()
Expand Down Expand Up @@ -273,5 +274,10 @@ public String getDefaultType() {
public String getDefaultVersion() {
return null;
}

@Override
public List<String> getAllModelKeys() {
return List.of(INFERENCE_MODEL_ID_FIELD, EMBEDDING_MODEL_ID_FIELD);
}
}
}
17 changes: 12 additions & 5 deletions src/main/java/org/opensearch/agent/tools/VectorDBTool.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,20 @@

package org.opensearch.agent.tools;

import static org.opensearch.agent.tools.utils.CommonConstants.COMMON_MODEL_ID_FIELD;
import static org.opensearch.ml.common.utils.StringUtils.gson;

import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.List;
import java.util.Map;

import org.apache.commons.lang3.StringUtils;
import org.opensearch.client.Client;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
import org.opensearch.ml.common.spi.tools.WithModelTool;

import lombok.Builder;
import lombok.Getter;
Expand All @@ -29,12 +32,11 @@
@Getter
@Setter
@ToolAnnotation(VectorDBTool.TYPE)
public class VectorDBTool extends AbstractRetrieverTool {
public class VectorDBTool extends AbstractRetrieverTool implements WithModelTool {
public static final String TYPE = "VectorDBTool";

public static String DEFAULT_DESCRIPTION =
"Use this tool to performs knn-based dense retrieval. It takes 1 argument named input which is a string query for dense retrieval. The tool returns the dense retrieval results for the query.";
public static final String MODEL_ID_FIELD = "model_id";
public static final String EMBEDDING_FIELD = "embedding_field";
public static final String K_FIELD = "k";
public static final Integer DEFAULT_K = 10;
Expand Down Expand Up @@ -69,7 +71,7 @@ public VectorDBTool(
protected String getQueryBody(String queryText) {
if (StringUtils.isBlank(embeddingField) || StringUtils.isBlank(modelId)) {
throw new IllegalArgumentException(
"Parameter [" + EMBEDDING_FIELD + "] and [" + MODEL_ID_FIELD + "] can not be null or empty."
"Parameter [" + EMBEDDING_FIELD + "] and [" + COMMON_MODEL_ID_FIELD + "] can not be null or empty."
);
}

Expand Down Expand Up @@ -110,7 +112,7 @@ public String getType() {
return TYPE;
}

public static class Factory extends AbstractRetrieverTool.Factory<VectorDBTool> {
public static class Factory extends AbstractRetrieverTool.Factory<VectorDBTool> implements WithModelTool.Factory<VectorDBTool> {
private static VectorDBTool.Factory INSTANCE;

public static VectorDBTool.Factory getInstance() {
Expand All @@ -131,7 +133,7 @@ public VectorDBTool create(Map<String, Object> params) {
String index = (String) params.get(INDEX_FIELD);
String embeddingField = (String) params.get(EMBEDDING_FIELD);
String[] sourceFields = gson.fromJson((String) params.get(SOURCE_FIELD), String[].class);
String modelId = (String) params.get(MODEL_ID_FIELD);
String modelId = (String) params.get(COMMON_MODEL_ID_FIELD);
Integer docSize = params.containsKey(DOC_SIZE_FIELD) ? Integer.parseInt((String) params.get(DOC_SIZE_FIELD)) : DEFAULT_DOC_SIZE;
Integer k = params.containsKey(K_FIELD) ? Integer.parseInt((String) params.get(K_FIELD)) : DEFAULT_K;
String nestedPath = (String) params.get(NESTED_PATH_FIELD);
Expand Down Expand Up @@ -163,5 +165,10 @@ public String getDefaultVersion() {
public String getDefaultDescription() {
return DEFAULT_DESCRIPTION;
}

@Override
public List<String> getAllModelKeys() {
return List.of(COMMON_MODEL_ID_FIELD);
}
}
}
Loading

0 comments on commit 33681a9

Please sign in to comment.