Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add PainlessScript tool #380

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/main/java/org/opensearch/agent/ToolPlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.opensearch.agent.tools.LogPatternTool;
import org.opensearch.agent.tools.NeuralSparseSearchTool;
import org.opensearch.agent.tools.PPLTool;
import org.opensearch.agent.tools.PainlessScriptTool;
import org.opensearch.agent.tools.RAGTool;
import org.opensearch.agent.tools.SearchAlertsTool;
import org.opensearch.agent.tools.SearchAnomalyDetectorsTool;
Expand Down Expand Up @@ -73,6 +74,7 @@ public Collection<Object> createComponents(
CreateAlertTool.Factory.getInstance().init(client);
CreateAnomalyDetectorTool.Factory.getInstance().init(client);
LogPatternTool.Factory.getInstance().init(client, xContentRegistry);
PainlessTool.Factory.getInstance().init(scriptService);
return Collections.emptyList();
}

Expand All @@ -91,6 +93,7 @@ public List<Tool.Factory<? extends Tool>> getToolFactories() {
CreateAlertTool.Factory.getInstance(),
CreateAnomalyDetectorTool.Factory.getInstance(),
LogPatternTool.Factory.getInstance()
PainlessTool.Factory.getInstance()
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ public void init(Client client) {
@Override
public CreateAlertTool create(Map<String, Object> params) {
String modelId = (String) params.get(COMMON_MODEL_ID_FIELD);
if (org.apache.commons.lang3.StringUtils.isBlank(modelId)) {
if (org.apache.commons.lang3.StringUtils.isNullOrEmpty(modelId) || modelId.isBlank()) {
throw new IllegalArgumentException("model_id cannot be null or blank.");
}
String modelType = (String) params.getOrDefault("model_type", ModelType.CLAUDE.toString());
Expand Down
154 changes: 154 additions & 0 deletions src/main/java/org/opensearch/agent/tools/PainlessScriptTool.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.agent.tools;

import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.Strings;
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
import org.opensearch.ml.common.utils.StringUtils;
import org.opensearch.script.Script;
import org.opensearch.script.ScriptService;
import org.opensearch.script.ScriptType;
import org.opensearch.script.TemplateScript;

import lombok.Getter;
import lombok.Setter;
import lombok.extern.log4j.Log4j2;

/**
* use case for this tool will only focus on flow agent
*/
@Log4j2
@Setter
@Getter
@ToolAnnotation(PainlessScriptTool.TYPE)
public class PainlessScriptTool implements Tool {
public static final String TYPE = "PainlessTool";
private static final String DEFAULT_DESCRIPTION = "Use this tool to execute painless script";

@Setter
@Getter
private String name = TYPE;

@Getter
private String type = TYPE;

@Getter
@Setter
private String description = DEFAULT_DESCRIPTION;

@Getter
private String version;

private ScriptService scriptService;
private String scriptCode;

public PainlessScriptTool(ScriptService scriptEngine, String script) {
this.scriptService = scriptEngine;
this.scriptCode = script;
}

@Override
public <T> void run(Map<String, String> parameters, ActionListener<T> listener) {
Script script = new Script(ScriptType.INLINE, "painless", scriptCode, Collections.emptyMap());
Map<String, Object> flattenedParameters = getFlattenedParameters(parameters);
TemplateScript templateScript = scriptService.compile(script, TemplateScript.CONTEXT).newInstance(flattenedParameters);
try {
String result = templateScript.execute();
listener.onResponse(result == null ? (T) "" : (T) result);
} catch (Exception e) {
listener.onFailure(e);
}
}

@Override
public boolean validate(Map<String, String> map) {
return true;
}

Map<String, Object> getFlattenedParameters(Map<String, String> parameters) {
Map<String, Object> flattenedParameters = new HashMap<>();
for (Map.Entry<String, String> entry : parameters.entrySet()) {
// keep both original values and flatten
flattenedParameters.put(entry.getKey(), entry.getValue());
try {
// default is json parser, we may add more...
String value = org.apache.commons.text.StringEscapeUtils.unescapeJson(entry.getValue());
Map<String, ?> map = StringUtils.fromJson(value, "");
flattenMap(map, flattenedParameters, entry.getKey());
} catch (Throwable ignored) {}
}
return flattenedParameters;
}

void flattenMap(Map<String, ?> map, Map<String, Object> flatMap, String prefix) {
for (Map.Entry<String, ?> entry : map.entrySet()) {
String key = entry.getKey();
if (prefix != null && !prefix.isEmpty()) {
key = prefix + "." + entry.getKey();
}
Object value = entry.getValue();
if (value instanceof Map) {
flattenMap((Map<String, ?>) value, flatMap, key);
} else {
flatMap.put(key, value);
}
}
}

public static class Factory implements Tool.Factory<PainlessScriptTool> {
private ScriptService scriptService;

private static PainlessScriptTool.Factory INSTANCE;

public static PainlessScriptTool.Factory getInstance() {
if (INSTANCE != null) {
return INSTANCE;
}
synchronized (PainlessScriptTool.class) {
if (INSTANCE != null) {
return INSTANCE;
}
INSTANCE = new PainlessScriptTool.Factory();
return INSTANCE;
}
}

public void init(ScriptService scriptService) {
this.scriptService = scriptService;
}

@Override
public PainlessScriptTool create(Map<String, Object> map) {
String script = (String) map.get("script");
if (Strings.isNullOrEmpty(script)) {
throw new IllegalArgumentException("script is required");
}
return new PainlessScriptTool(scriptService, script);
}

@Override
public String getDefaultDescription() {
return DEFAULT_DESCRIPTION;
}

@Override
public String getDefaultType() {
return TYPE;
}

@Override
public String getDefaultVersion() {
return null;
}

}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.agent.tools;

import static org.junit.Assert.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import java.util.HashMap;
import java.util.Map;

import org.apache.commons.text.StringEscapeUtils;
import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.opensearch.core.action.ActionListener;
import org.opensearch.script.ScriptService;
import org.opensearch.script.TemplateScript;

import com.google.gson.Gson;

/**
* this is a test file to test PainlessTool with junit
*/
public class PainlessScriptToolTests {
@Mock
private ScriptService scriptService;
@Mock
private TemplateScript templateScript;
@Mock
private ActionListener<String> actionListener;

@Before
public void setup() {
MockitoAnnotations.openMocks(this);
TemplateScript.Factory factory = new TemplateScript.Factory() {
@Override
public TemplateScript newInstance(Map<String, Object> params) {
return templateScript;
}
};

when(scriptService.compile(any(), any())).thenReturn(factory);

PainlessScriptTool.Factory.getInstance().init(scriptService);
}

@Test
public void testRun() {
String script = "return 'Hello World';";
PainlessScriptTool tool = PainlessScriptTool.Factory.getInstance().create(Map.of("script", script));
when(templateScript.execute()).thenReturn("hello");
tool.run(Map.of(), actionListener);

verify(templateScript).execute();
verify(scriptService).compile(any(), any());
ArgumentCaptor<String> responseCaptor = ArgumentCaptor.forClass(String.class);
verify(actionListener, times(1)).onResponse(responseCaptor.capture());
assertEquals("hello", responseCaptor.getValue());
}

// test run wit exception
@Test
public void testRun_with_exception() {
String script = "return 'Hello World';";
PainlessScriptTool tool = PainlessScriptTool.Factory.getInstance().create(Map.of("script", script));
when(templateScript.execute()).thenThrow(new RuntimeException("error"));
tool.run(Map.of(), actionListener);

verify(templateScript).execute();
verify(scriptService).compile(any(), any());
ArgumentCaptor<Exception> exceptionCaptor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener, times(1)).onFailure(exceptionCaptor.capture());
assertEquals("error", exceptionCaptor.getValue().getMessage());
}

// test factory create
@Test
public void testFactory_create() {
String script = "return 'Hello World';";
PainlessScriptTool tool = PainlessScriptTool.Factory.getInstance().create(Map.of("script", script));
assertEquals(PainlessScriptTool.TYPE, tool.getType());
assertEquals("PainlessTool", tool.getName());
assertEquals("Use this tool to execute painless script", tool.getDescription());
}

// test factory create with exception
@Test(expected = IllegalArgumentException.class)
public void testFactory_create_with_exception() {
PainlessScriptTool.Factory.getInstance().create(Map.of());
}

// test flattenMap
@Test
public void testFlattenMap_without_prefix() {
String script = "return 'Hello World';";
PainlessScriptTool tool = PainlessScriptTool.Factory.getInstance().create(Map.of("script", script));
Map<String, Object> map = Map.of("a", Map.of("b", "c"), "k", "v");
Map<String, Object> resultMap = new HashMap<>();
tool.flattenMap(map, resultMap, "");
assertEquals(Map.of("a.b", "c", "k", "v"), resultMap);
}

// with prefix
@Test
public void testFlattenMap_with_prefix() {
String script = "return 'Hello World';";
PainlessScriptTool tool = PainlessScriptTool.Factory.getInstance().create(Map.of("script", script));
Map<String, Object> map = Map.of("a", Map.of("b", "c"), "k", "v");
Map<String, Object> resultMap = new HashMap<>();
tool.flattenMap(map, resultMap, "prefix");
assertEquals(Map.of("prefix.a.b", "c", "prefix.k", "v"), resultMap);
}

// nest map with depth 3
@Test
public void testFlattenMap_with_depth_3() {
String script = "return 'Hello World';";
PainlessScriptTool tool = PainlessScriptTool.Factory.getInstance().create(Map.of("script", script));
Map<String, Object> map = Map.of("a", Map.of("b", Map.of("c", "d"), "k", "v"));
Gson gson = new Gson();
System.out.println(StringEscapeUtils.escapeJson(gson.toJson(map)));
Map<String, Object> resultMap = new HashMap<>();
tool.flattenMap(map, resultMap, "");
assertEquals(Map.of("a.b.c", "d", "a.k", "v"), resultMap);
}

// test getFlattenedParameters
@Test
public void testGetFlattenedParameters() {
String script = "return 'Hello World';";
PainlessScriptTool tool = PainlessScriptTool.Factory.getInstance().create(Map.of("script", script));
Map<String, String> map = Map.of("k", "{\\\"a\\\":{\\\"k\\\":\\\"v\\\",\\\"b\\\":{\\\"c\\\":\\\"d\\\"}}}");
Map<String, Object> resultMap = tool.getFlattenedParameters(map);
assertEquals(
Map.of("k.a.b.c", "d", "k.a.k", "v", "k", "{\\\"a\\\":{\\\"k\\\":\\\"v\\\",\\\"b\\\":{\\\"c\\\":\\\"d\\\"}}}"),
resultMap
);
}
}
Loading
Loading