Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.16] [Backport] Remove ppl tool execution setting #384

Merged
merged 1 commit into from
Aug 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 1 addition & 11 deletions src/main/java/org/opensearch/agent/ToolPlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import java.util.List;
import java.util.function.Supplier;

import org.opensearch.agent.common.SkillSettings;
import org.opensearch.agent.tools.CreateAnomalyDetectorTool;
import org.opensearch.agent.tools.NeuralSparseSearchTool;
import org.opensearch.agent.tools.PPLTool;
Expand All @@ -20,12 +19,9 @@
import org.opensearch.agent.tools.SearchAnomalyResultsTool;
import org.opensearch.agent.tools.SearchMonitorsTool;
import org.opensearch.agent.tools.VectorDBTool;
import org.opensearch.agent.tools.utils.ClusterSettingHelper;
import org.opensearch.client.Client;
import org.opensearch.cluster.metadata.IndexNameExpressionResolver;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Setting;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.common.io.stream.NamedWriteableRegistry;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.env.Environment;
Expand Down Expand Up @@ -64,9 +60,7 @@ public Collection<Object> createComponents(
this.client = client;
this.clusterService = clusterService;
this.xContentRegistry = xContentRegistry;
Settings settings = environment.settings();
ClusterSettingHelper clusterSettingHelper = new ClusterSettingHelper(settings, clusterService);
PPLTool.Factory.getInstance().init(client, clusterSettingHelper);
PPLTool.Factory.getInstance().init(client);
NeuralSparseSearchTool.Factory.getInstance().init(client, xContentRegistry);
VectorDBTool.Factory.getInstance().init(client, xContentRegistry);
RAGTool.Factory.getInstance().init(client, xContentRegistry);
Expand Down Expand Up @@ -94,8 +88,4 @@ public List<Tool.Factory<? extends Tool>> getToolFactories() {
);
}

@Override
public List<Setting<?>> getSettings() {
return List.of(SkillSettings.PPL_EXECUTION_ENABLED);
}
}
22 changes: 0 additions & 22 deletions src/main/java/org/opensearch/agent/common/SkillSettings.java

This file was deleted.

23 changes: 3 additions & 20 deletions src/main/java/org/opensearch/agent/tools/PPLTool.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,6 @@
import org.opensearch.action.ActionRequest;
import org.opensearch.action.admin.indices.mapping.get.GetMappingsRequest;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.agent.common.SkillSettings;
import org.opensearch.agent.tools.utils.ClusterSettingHelper;
import org.opensearch.agent.tools.utils.ToolHelper;
import org.opensearch.client.Client;
import org.opensearch.cluster.metadata.MappingMetadata;
Expand Down Expand Up @@ -98,9 +96,7 @@ public class PPLTool implements Tool {

private int head;

private ClusterSettingHelper clusterSettingHelper;

private static Gson gson = new Gson();
private static Gson gson = org.opensearch.ml.common.utils.StringUtils.gson;

private static Map<String, String> DEFAULT_PROMPT_DICT;

Expand Down Expand Up @@ -153,7 +149,6 @@ public static PPLModelType from(String value) {

public PPLTool(
Client client,
ClusterSettingHelper clusterSettingHelper,
String modelId,
String contextPrompt,
String pplModelType,
Expand All @@ -172,7 +167,6 @@ public PPLTool(
this.previousToolKey = previousToolKey;
this.head = head;
this.execute = execute;
this.clusterSettingHelper = clusterSettingHelper;
}

@SuppressWarnings("unchecked")
Expand Down Expand Up @@ -222,14 +216,7 @@ public <T> void run(Map<String, String> parameters, ActionListener<T> listener)
ModelTensor modelTensor = modelTensors.getMlModelTensors().get(0);
Map<String, String> dataAsMap = (Map<String, String>) modelTensor.getDataAsMap();
String ppl = parseOutput(dataAsMap.get("response"), indexName);
boolean pplExecutedEnabled = clusterSettingHelper.getClusterSettings(SkillSettings.PPL_EXECUTION_ENABLED);
if (!pplExecutedEnabled || !this.execute) {
if (!pplExecutedEnabled) {
log
.debug(
"PPL execution is disabled, the query will be returned directly, to enable this, please set plugins.skills.ppl_execution_enabled to true"
);
}
if (!this.execute) {
Map<String, String> ret = ImmutableMap.of("ppl", ppl);
listener.onResponse((T) AccessController.doPrivileged((PrivilegedExceptionAction<String>) () -> gson.toJson(ret)));
return;
Expand Down Expand Up @@ -298,8 +285,6 @@ public boolean validate(Map<String, String> parameters) {
public static class Factory implements Tool.Factory<PPLTool> {
private Client client;

private ClusterSettingHelper clusterSettingHelper;

private static Factory INSTANCE;

public static Factory getInstance() {
Expand All @@ -315,17 +300,15 @@ public static Factory getInstance() {
}
}

public void init(Client client, ClusterSettingHelper clusterSettingHelper) {
public void init(Client client) {
this.client = client;
this.clusterSettingHelper = clusterSettingHelper;
}

@Override
public PPLTool create(Map<String, Object> map) {
validatePPLToolParameters(map);
return new PPLTool(
client,
clusterSettingHelper,
(String) map.get("model_id"),
(String) map.getOrDefault("prompt", ""),
(String) map.getOrDefault("model_type", ""),
Expand Down

This file was deleted.

35 changes: 1 addition & 34 deletions src/test/java/org/opensearch/agent/tools/PPLToolTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,13 @@
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX;
import static org.opensearch.ml.common.utils.StringUtils.gson;

import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;

import org.apache.lucene.search.TotalHits;
import org.junit.Before;
Expand All @@ -26,15 +24,10 @@
import org.mockito.MockitoAnnotations;
import org.opensearch.action.admin.indices.mapping.get.GetMappingsResponse;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.agent.common.SkillSettings;
import org.opensearch.agent.tools.utils.ClusterSettingHelper;
import org.opensearch.client.AdminClient;
import org.opensearch.client.Client;
import org.opensearch.client.IndicesAdminClient;
import org.opensearch.cluster.metadata.MappingMetadata;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.bytes.BytesArray;
import org.opensearch.core.common.bytes.BytesReference;
Expand Down Expand Up @@ -128,13 +121,7 @@ public void setup() {
listener.onResponse(transportPPLQueryResponse);
return null;
}).when(client).execute(eq(PPLQueryAction.INSTANCE), any(), any());

Settings settings = Settings.builder().put(SkillSettings.PPL_EXECUTION_ENABLED.getKey(), true).build();
ClusterService clusterService = mock(ClusterService.class);
when(clusterService.getSettings()).thenReturn(settings);
when(clusterService.getClusterSettings()).thenReturn(new ClusterSettings(settings, Set.of(SkillSettings.PPL_EXECUTION_ENABLED)));
ClusterSettingHelper clusterSettingHelper = new ClusterSettingHelper(settings, clusterService);
PPLTool.Factory.getInstance().init(client, clusterSettingHelper);
PPLTool.Factory.getInstance().init(client);
}

@Test
Expand Down Expand Up @@ -413,26 +400,6 @@ public void testTool_executePPLFailure() {
);
}

@Test
public void test_pplTool_whenPPLExecutionDisabled_returnOnlyContainsPPL() {
Settings settings = Settings.builder().put(SkillSettings.PPL_EXECUTION_ENABLED.getKey(), false).build();
ClusterService clusterService = mock(ClusterService.class);
when(clusterService.getSettings()).thenReturn(settings);
when(clusterService.getClusterSettings()).thenReturn(new ClusterSettings(settings, Set.of(SkillSettings.PPL_EXECUTION_ENABLED)));
ClusterSettingHelper clusterSettingHelper = new ClusterSettingHelper(settings, clusterService);
PPLTool.Factory.getInstance().init(client, clusterSettingHelper);
PPLTool tool = PPLTool.Factory
.getInstance()
.create(ImmutableMap.of("model_id", "modelId", "prompt", "contextPrompt", "head", "100"));
assertEquals(PPLTool.TYPE, tool.getName());

tool.run(ImmutableMap.of("index", "demo", "question", "demo"), ActionListener.<String>wrap(executePPLResult -> {
Map<String, String> returnResults = gson.fromJson(executePPLResult, Map.class);
assertNull(returnResults.get("executionResult"));
assertEquals("source=demo| head 1", returnResults.get("ppl"));
}, log::error));
}

private void createMappings() {
indexMappings = new HashMap<>();
indexMappings
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ public void updateClusterSettings() {
updateClusterSettings("plugins.ml_commons.jvm_heap_memory_threshold", 100);
updateClusterSettings("plugins.ml_commons.allow_registering_model_via_url", true);
updateClusterSettings("plugins.ml_commons.agent_framework_enabled", true);
updateClusterSettings("plugins.skills.ppl_execution_enabled", true);
}

@SneakyThrows
Expand Down
8 changes: 0 additions & 8 deletions src/test/java/org/opensearch/integTest/PPLToolIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,6 @@ public void testPPLTool() {
);
}

public void test_PPLTool_whenPPLExecutionDisabled_ResultOnlyContainsPPL() {
updateClusterSettings("plugins.skills.ppl_execution_enabled", false);
prepareIndex();
String agentId = registerAgent();
String result = executeAgent(agentId, "{\"parameters\": {\"question\": \"correct\", \"index\": \"employee\"}}");
assertEquals("{\"ppl\":\"source\\u003demployee| where age \\u003e 56 | stats COUNT() as cnt\"}", result);
}

public void testPPLTool_withWrongPPLGenerated_thenThrowException() {
prepareIndex();
String agentId = registerAgent();
Expand Down
Loading