Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add model related field for tools #491

Merged
merged 9 commits into from
Jan 26, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions src/main/java/org/opensearch/agent/tools/CreateAlertTool.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.agent.tools;

import static org.opensearch.action.support.clustermanager.ClusterManagerNodeRequest.DEFAULT_CLUSTER_MANAGER_NODE_TIMEOUT;
import static org.opensearch.agent.tools.utils.CommonConstants.COMMON_MODEL_ID_FIELD;
import static org.opensearch.ml.common.utils.StringUtils.isJson;

import java.util.Arrays;
Expand Down Expand Up @@ -34,8 +35,8 @@
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
import org.opensearch.ml.common.spi.tools.WithModelTool;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
import org.opensearch.ml.common.utils.StringUtils;
Expand All @@ -48,7 +49,7 @@

@Log4j2
@ToolAnnotation(CreateAlertTool.TYPE)
public class CreateAlertTool implements Tool {
public class CreateAlertTool implements WithModelTool {
public static final String TYPE = "CreateAlertTool";

private static final String DEFAULT_DESCRIPTION =
Expand All @@ -70,7 +71,6 @@
@Getter
private final String toolPrompt;

private static final String MODEL_ID = "model_id";
private static final String PROMPT_FILE_PATH = "CreateAlertDefaultPrompt.json";
private static final String DEFAULT_QUESTION = "Create an alert as your recommendation based on the context";
private static final Map<String, String> promptDict = ToolHelper.loadDefaultPromptDictFromFile(CreateAlertTool.class, PROMPT_FILE_PATH);
Expand Down Expand Up @@ -276,7 +276,7 @@
return getIndexRequest;
}

public static class Factory implements Tool.Factory<CreateAlertTool> {
public static class Factory implements WithModelTool.Factory<CreateAlertTool> {

private Client client;

Expand All @@ -301,7 +301,7 @@

@Override
public CreateAlertTool create(Map<String, Object> params) {
String modelId = (String) params.get(MODEL_ID);
String modelId = (String) params.get(COMMON_MODEL_ID_FIELD);
if (org.apache.commons.lang3.StringUtils.isBlank(modelId)) {
throw new IllegalArgumentException("model_id cannot be null or blank.");
}
Expand All @@ -324,5 +324,10 @@
public String getDefaultVersion() {
return null;
}

@Override
public List<String> getAllModelKeys() {
return List.of(COMMON_MODEL_ID_FIELD);

Check warning on line 330 in src/main/java/org/opensearch/agent/tools/CreateAlertTool.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/agent/tools/CreateAlertTool.java#L330

Added line #L330 was not covered by tests
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.agent.tools;

import static org.opensearch.agent.tools.utils.CommonConstants.COMMON_MODEL_ID_FIELD;
import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD;
import static org.opensearch.ml.common.utils.StringUtils.gson;

Expand All @@ -16,6 +17,7 @@
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
Expand All @@ -39,8 +41,8 @@
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
import org.opensearch.ml.common.spi.tools.WithModelTool;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;

Expand All @@ -66,7 +68,7 @@
@Setter
@Getter
@ToolAnnotation(CreateAnomalyDetectorTool.TYPE)
public class CreateAnomalyDetectorTool implements Tool {
public class CreateAnomalyDetectorTool implements WithModelTool {
// the type of this tool
public static final String TYPE = "CreateAnomalyDetectorTool";

Expand Down Expand Up @@ -395,7 +397,7 @@
/**
* The tool factory
*/
public static class Factory implements Tool.Factory<CreateAnomalyDetectorTool> {
public static class Factory implements WithModelTool.Factory<CreateAnomalyDetectorTool> {
private Client client;

private static CreateAnomalyDetectorTool.Factory INSTANCE;
Expand Down Expand Up @@ -427,7 +429,7 @@
*/
@Override
public CreateAnomalyDetectorTool create(Map<String, Object> map) {
String modelId = (String) map.getOrDefault("model_id", "");
String modelId = (String) map.getOrDefault(COMMON_MODEL_ID_FIELD, "");
if (modelId.isEmpty()) {
throw new IllegalArgumentException("model_id cannot be empty.");
}
Expand Down Expand Up @@ -457,5 +459,10 @@
public String getDefaultVersion() {
return null;
}

@Override
public List<String> getAllModelKeys() {
return List.of(COMMON_MODEL_ID_FIELD);

Check warning on line 465 in src/main/java/org/opensearch/agent/tools/CreateAnomalyDetectorTool.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/agent/tools/CreateAnomalyDetectorTool.java#L465

Added line #L465 was not covered by tests
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,20 @@

package org.opensearch.agent.tools;

import static org.opensearch.agent.tools.utils.CommonConstants.COMMON_MODEL_ID_FIELD;
import static org.opensearch.ml.common.utils.StringUtils.gson;

import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.List;
import java.util.Map;

import org.apache.commons.lang3.StringUtils;
import org.opensearch.client.Client;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
import org.opensearch.ml.common.spi.tools.WithModelTool;

import lombok.Builder;
import lombok.Getter;
Expand All @@ -29,9 +32,8 @@
@Getter
@Setter
@ToolAnnotation(NeuralSparseSearchTool.TYPE)
public class NeuralSparseSearchTool extends AbstractRetrieverTool {
public class NeuralSparseSearchTool extends AbstractRetrieverTool implements WithModelTool {
public static final String TYPE = "NeuralSparseSearchTool";
public static final String MODEL_ID_FIELD = "model_id";
public static final String EMBEDDING_FIELD = "embedding_field";
public static final String NESTED_PATH_FIELD = "nested_path";

Expand Down Expand Up @@ -61,7 +63,7 @@
protected String getQueryBody(String queryText) {
if (StringUtils.isBlank(embeddingField) || StringUtils.isBlank(modelId)) {
throw new IllegalArgumentException(
"Parameter [" + EMBEDDING_FIELD + "] and [" + MODEL_ID_FIELD + "] can not be null or empty."
"Parameter [" + EMBEDDING_FIELD + "] and [" + COMMON_MODEL_ID_FIELD + "] can not be null or empty."
);
}

Expand Down Expand Up @@ -101,7 +103,9 @@
return TYPE;
}

public static class Factory extends AbstractRetrieverTool.Factory<NeuralSparseSearchTool> {
public static class Factory extends AbstractRetrieverTool.Factory<NeuralSparseSearchTool>
implements
WithModelTool.Factory<NeuralSparseSearchTool> {
private static Factory INSTANCE;

public static Factory getInstance() {
Expand All @@ -122,7 +126,7 @@
String index = (String) params.get(INDEX_FIELD);
String embeddingField = (String) params.get(EMBEDDING_FIELD);
String[] sourceFields = gson.fromJson((String) params.get(SOURCE_FIELD), String[].class);
String modelId = (String) params.get(MODEL_ID_FIELD);
String modelId = (String) params.get(COMMON_MODEL_ID_FIELD);
Integer docSize = params.containsKey(DOC_SIZE_FIELD) ? Integer.parseInt((String) params.get(DOC_SIZE_FIELD)) : DEFAULT_DOC_SIZE;
String nestedPath = (String) params.get(NESTED_PATH_FIELD);
return NeuralSparseSearchTool
Expand All @@ -147,5 +151,10 @@
public String getDefaultVersion() {
return null;
}

@Override
public List<String> getAllModelKeys() {
return List.of(COMMON_MODEL_ID_FIELD);

Check warning on line 157 in src/main/java/org/opensearch/agent/tools/NeuralSparseSearchTool.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/agent/tools/NeuralSparseSearchTool.java#L157

Added line #L157 was not covered by tests
}
}
}
13 changes: 9 additions & 4 deletions src/main/java/org/opensearch/agent/tools/PPLTool.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.agent.tools;

import static org.opensearch.agent.tools.utils.CommonConstants.COMMON_MODEL_ID_FIELD;
import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD;

import java.io.IOException;
Expand Down Expand Up @@ -44,8 +45,8 @@
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
import org.opensearch.ml.common.spi.tools.WithModelTool;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
import org.opensearch.search.SearchHit;
Expand All @@ -65,7 +66,7 @@
@Setter
@Getter
@ToolAnnotation(PPLTool.TYPE)
public class PPLTool implements Tool {
public class PPLTool implements WithModelTool {

public static final String TYPE = "PPLTool";

Expand Down Expand Up @@ -289,7 +290,7 @@
return parameters != null && !parameters.isEmpty();
}

public static class Factory implements Tool.Factory<PPLTool> {
public static class Factory implements WithModelTool.Factory<PPLTool> {
private Client client;

private static Factory INSTANCE;
Expand All @@ -316,7 +317,7 @@
validatePPLToolParameters(map);
return new PPLTool(
client,
(String) map.get("model_id"),
(String) map.get(COMMON_MODEL_ID_FIELD),
(String) map.getOrDefault("prompt", ""),
(String) map.getOrDefault("model_type", ""),
(String) map.getOrDefault("previous_tool_name", ""),
Expand All @@ -340,6 +341,10 @@
return null;
}

@Override
public List<String> getAllModelKeys() {
return List.of(COMMON_MODEL_ID_FIELD);

Check warning on line 346 in src/main/java/org/opensearch/agent/tools/PPLTool.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/agent/tools/PPLTool.java#L346

Added line #L346 was not covered by tests
}
}

private SearchRequest buildSearchRequest(String indexName) {
Expand Down
16 changes: 11 additions & 5 deletions src/main/java/org/opensearch/agent/tools/RAGTool.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import static org.apache.commons.lang3.StringEscapeUtils.escapeJson;
import static org.opensearch.agent.tools.AbstractRetrieverTool.*;
import static org.opensearch.agent.tools.utils.CommonConstants.COMMON_MODEL_ID_FIELD;
import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD;
import static org.opensearch.ml.common.utils.StringUtils.gson;
import static org.opensearch.ml.common.utils.StringUtils.toJson;
Expand All @@ -25,8 +26,8 @@
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.common.spi.tools.Parser;
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
import org.opensearch.ml.common.spi.tools.WithModelTool;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;

Expand All @@ -45,7 +46,7 @@
@Setter
@Getter
@ToolAnnotation(RAGTool.TYPE)
public class RAGTool implements Tool {
public class RAGTool implements WithModelTool {
public static final String TYPE = "RAGTool";
public static String DEFAULT_DESCRIPTION =
"Use this tool to retrieve helpful information to optimize the output of the large language model to answer questions.";
Expand Down Expand Up @@ -197,7 +198,7 @@
/**
* Factory class to create RAGTool
*/
public static class Factory implements Tool.Factory<RAGTool> {
public static class Factory implements WithModelTool.Factory<RAGTool> {
private Client client;
private NamedXContentRegistry xContentRegistry;

Expand Down Expand Up @@ -231,7 +232,7 @@
String inferenceModelId = enableContentGeneration ? (String) params.get(INFERENCE_MODEL_ID_FIELD) : "";
switch (queryType) {
case "neural_sparse":
params.put(NeuralSparseSearchTool.MODEL_ID_FIELD, embeddingModelId);
params.put(COMMON_MODEL_ID_FIELD, embeddingModelId);
NeuralSparseSearchTool neuralSparseSearchTool = NeuralSparseSearchTool.Factory.getInstance().create(params);
return RAGTool
.builder()
Expand All @@ -242,7 +243,7 @@
.queryTool(neuralSparseSearchTool)
.build();
case "neural":
params.put(VectorDBTool.MODEL_ID_FIELD, embeddingModelId);
params.put(COMMON_MODEL_ID_FIELD, embeddingModelId);
VectorDBTool vectorDBTool = VectorDBTool.Factory.getInstance().create(params);
return RAGTool
.builder()
Expand Down Expand Up @@ -273,5 +274,10 @@
public String getDefaultVersion() {
return null;
}

@Override
public List<String> getAllModelKeys() {
return List.of(INFERENCE_MODEL_ID_FIELD, EMBEDDING_MODEL_ID_FIELD);

Check warning on line 280 in src/main/java/org/opensearch/agent/tools/RAGTool.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/agent/tools/RAGTool.java#L280

Added line #L280 was not covered by tests
}
}
}
17 changes: 12 additions & 5 deletions src/main/java/org/opensearch/agent/tools/VectorDBTool.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,20 @@

package org.opensearch.agent.tools;

import static org.opensearch.agent.tools.utils.CommonConstants.COMMON_MODEL_ID_FIELD;
import static org.opensearch.ml.common.utils.StringUtils.gson;

import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.List;
import java.util.Map;

import org.apache.commons.lang3.StringUtils;
import org.opensearch.client.Client;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
import org.opensearch.ml.common.spi.tools.WithModelTool;

import lombok.Builder;
import lombok.Getter;
Expand All @@ -29,12 +32,11 @@
@Getter
@Setter
@ToolAnnotation(VectorDBTool.TYPE)
public class VectorDBTool extends AbstractRetrieverTool {
public class VectorDBTool extends AbstractRetrieverTool implements WithModelTool {
public static final String TYPE = "VectorDBTool";

public static String DEFAULT_DESCRIPTION =
"Use this tool to performs knn-based dense retrieval. It takes 1 argument named input which is a string query for dense retrieval. The tool returns the dense retrieval results for the query.";
public static final String MODEL_ID_FIELD = "model_id";
public static final String EMBEDDING_FIELD = "embedding_field";
public static final String K_FIELD = "k";
public static final Integer DEFAULT_K = 10;
Expand Down Expand Up @@ -69,7 +71,7 @@
protected String getQueryBody(String queryText) {
if (StringUtils.isBlank(embeddingField) || StringUtils.isBlank(modelId)) {
throw new IllegalArgumentException(
"Parameter [" + EMBEDDING_FIELD + "] and [" + MODEL_ID_FIELD + "] can not be null or empty."
"Parameter [" + EMBEDDING_FIELD + "] and [" + COMMON_MODEL_ID_FIELD + "] can not be null or empty."
);
}

Expand Down Expand Up @@ -110,7 +112,7 @@
return TYPE;
}

public static class Factory extends AbstractRetrieverTool.Factory<VectorDBTool> {
public static class Factory extends AbstractRetrieverTool.Factory<VectorDBTool> implements WithModelTool.Factory<VectorDBTool> {
private static VectorDBTool.Factory INSTANCE;

public static VectorDBTool.Factory getInstance() {
Expand All @@ -131,7 +133,7 @@
String index = (String) params.get(INDEX_FIELD);
String embeddingField = (String) params.get(EMBEDDING_FIELD);
String[] sourceFields = gson.fromJson((String) params.get(SOURCE_FIELD), String[].class);
String modelId = (String) params.get(MODEL_ID_FIELD);
String modelId = (String) params.get(COMMON_MODEL_ID_FIELD);
Integer docSize = params.containsKey(DOC_SIZE_FIELD) ? Integer.parseInt((String) params.get(DOC_SIZE_FIELD)) : DEFAULT_DOC_SIZE;
Integer k = params.containsKey(K_FIELD) ? Integer.parseInt((String) params.get(K_FIELD)) : DEFAULT_K;
String nestedPath = (String) params.get(NESTED_PATH_FIELD);
Expand Down Expand Up @@ -163,5 +165,10 @@
public String getDefaultDescription() {
return DEFAULT_DESCRIPTION;
}

@Override
public List<String> getAllModelKeys() {
return List.of(COMMON_MODEL_ID_FIELD);

Check warning on line 171 in src/main/java/org/opensearch/agent/tools/VectorDBTool.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/agent/tools/VectorDBTool.java#L171

Added line #L171 was not covered by tests
}
}
}
Loading
Loading