Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: search index tool #1750

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -276,8 +276,8 @@ private static <T, S> S init(Map<T, Class<?>> map, T type,
} catch (Exception e) {
Throwable cause = e.getCause();
if (cause instanceof MLException) {
throw (MLException)cause;
} else if (cause instanceof IllegalArgumentException) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The two exceptions can be merged together, e.g. cause instanceof MLException || cause instanceof IllegalArgumentException.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

MLException and IllegalArgumentException are both runtime exceptions. We handle them in separate branches so that each can be rethrown with its own cast; this way we do not need to declare `throws Throwable` in the method signature.

throw (MLException) cause;
} else if (cause instanceof IllegalArgumentException) {
throw (IllegalArgumentException) cause;
} else {
log.error("Failed to init instance for type " + type, e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,15 @@ protected AbstractRetrieverTool(

protected abstract String getQueryBody(String queryText);

/**
 * Copies the identifying metadata and source of a search hit into a plain map.
 *
 * @param hit the search hit to convert
 * @return a new mutable map holding the hit's {@code _index}, {@code _id},
 *         {@code _score} and {@code _source} entries
 */
public static Map<String, Object> processResponse(SearchHit hit) {
    final Map<String, Object> hitFields = new HashMap<>();
    hitFields.put("_index", hit.getIndex());
    hitFields.put("_id", hit.getId());
    hitFields.put("_score", hit.getScore());
    hitFields.put("_source", hit.getSourceAsMap());
    return hitFields;
}

@Override
public <T> void run(Map<String, String> parameters, ActionListener<T> listener) {
try {
Expand All @@ -93,19 +102,14 @@ public <T> void run(Map<String, String> parameters, ActionListener<T> listener)
searchSourceBuilder.fetchSource(sourceFields, null);
searchSourceBuilder.size(docSize);
SearchRequest searchRequest = new SearchRequest().source(searchSourceBuilder).indices(index);
ActionListener actionListener = ActionListener.<SearchResponse>wrap(r -> {
ActionListener<SearchResponse> actionListener = ActionListener.<SearchResponse>wrap(r -> {
SearchHit[] hits = r.getHits().getHits();

if (hits != null && hits.length > 0) {
StringBuilder contextBuilder = new StringBuilder();
for (int i = 0; i < hits.length; i++) {
SearchHit hit = hits[i];
for (SearchHit hit : hits) {
String doc = AccessController.doPrivileged((PrivilegedExceptionAction<String>) () -> {
Map<String, Object> docContent = new HashMap<>();
docContent.put("_index", hit.getIndex());
docContent.put("_id", hit.getId());
docContent.put("_score", hit.getScore());
docContent.put("_source", hit.getSourceAsMap());
Map<String, Object> docContent = processResponse(hit);
return gson.toJson(docContent);
});
contextBuilder.append(doc).append("\n");
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.engine.tools;

import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX;
import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX;

import java.security.AccessController;
import java.security.PrivilegedExceptionAction;
import java.util.Map;
import java.util.Objects;

import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.client.Client;
import org.opensearch.common.xcontent.LoggingDeprecationHandler;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
import org.opensearch.ml.common.transport.connector.MLConnectorSearchAction;
import org.opensearch.ml.common.transport.model.MLModelSearchAction;
import org.opensearch.ml.common.utils.StringUtils;
import org.opensearch.search.SearchHit;
import org.opensearch.search.builder.SearchSourceBuilder;

import com.google.gson.JsonObject;

import lombok.Getter;
import lombok.Setter;
import lombok.extern.log4j.Log4j2;

@Getter
@Setter
@Log4j2
@ToolAnnotation(SearchIndexTool.TYPE)
public class SearchIndexTool implements Tool {

// Key of the run/validate parameter map whose value is the JSON input string.
public static final String INPUT_FIELD = "input";
// Member of the JSON input that names the index to search.
public static final String INDEX_FIELD = "index";
// Member of the JSON input that holds the query; run() wraps it in a {"query": ...} body.
public static final String QUERY_FIELD = "query";

// Tool type identifier; used for the @ToolAnnotation value, getType() and the default name.
public static final String TYPE = "SearchIndexTool";
// Description returned by Factory.getDefaultDescription() and used as the initial description.
private static final String DEFAULT_DESCRIPTION =
"Use this tool to search index with a query. You should pass in two parameters: index and query. Index is the index name and the query is an OpenSearch DSL query.";

// Tool name; defaults to TYPE.
private String name = TYPE;

// Tool description; defaults to DEFAULT_DESCRIPTION.
private String description = DEFAULT_DESCRIPTION;

// Client used by run() to execute the search request.
private Client client;

// Registry handed to the XContent parser that parses the query body in run().
private NamedXContentRegistry xContentRegistry;

/**
 * Creates a tool that searches through the given client.
 *
 * @param client           client used to execute search requests
 * @param xContentRegistry registry used when parsing the query body
 */
public SearchIndexTool(Client client, NamedXContentRegistry xContentRegistry) {
    this.xContentRegistry = xContentRegistry;
    this.client = client;
}

/**
 * Returns the type identifier of this tool.
 *
 * @return {@link #TYPE}
 */
@Override
public String getType() {
    return TYPE;
}

/**
 * Returns the version of this tool.
 *
 * @return always {@code null}; this tool does not report a version
 */
@Override
public String getVersion() {
    return null;
}

/**
 * Checks that the parameter map carries a non-null value under {@link #INPUT_FIELD}.
 *
 * @param parameters the parameters that would be passed to {@code run}; may be {@code null}
 * @return {@code true} only when {@code parameters} is non-null and maps
 *         {@link #INPUT_FIELD} to a non-null value
 */
@Override
public boolean validate(Map<String, String> parameters) {
    if (parameters == null || !parameters.containsKey(INPUT_FIELD)) {
        return false;
    }
    return parameters.get(INPUT_FIELD) != null;
}

/**
 * Runs a search described by the JSON string stored under {@link #INPUT_FIELD}.
 *
 * The input is parsed as a JSON object; its {@link #INDEX_FIELD} member names the
 * target index and its {@link #QUERY_FIELD} member is wrapped into a
 * {@code {"query": ...}} body that is parsed into a {@link SearchSourceBuilder}.
 *
 * On success the listener receives the hits serialized as one JSON document per
 * line, with the whole text then JSON-encoded once more as a string; when there
 * are no hits it receives an empty string. Any failure, whether raised while
 * building the request or reported by the search itself, is logged and passed
 * to {@code listener.onFailure}.
 *
 * @param parameters tool parameters; the value under {@link #INPUT_FIELD} is read
 * @param listener   receives the serialized result or the failure
 */
@Override
public <T> void run(Map<String, String> parameters, ActionListener<T> listener) {
try {
// Extract the target index and the raw query from the JSON input.
// NOTE(review): a missing "index" or "query" member presumably ends up in the
// catch block below as an exception — confirm against Gson's JsonObject.get.
String input = parameters.get(INPUT_FIELD);
JsonObject jsonObject = StringUtils.gson.fromJson(input, JsonObject.class);
String index = jsonObject.get(INDEX_FIELD).getAsString();
String query = jsonObject.get(QUERY_FIELD).toString();
// Wrap the bare query so the text is a full search body with a top-level "query" key.
query = "{\"query\": " + query + "}";

SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
// NOTE(review): queryParser is not closed in this method — confirm whether it
// should be managed with try-with-resources.
XContentParser queryParser = XContentType.JSON
.xContent()
.createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, query);
searchSourceBuilder.parseXContent(queryParser);
SearchRequest searchRequest = new SearchRequest().source(searchSourceBuilder).indices(index);

// Adapts the search response (or failure) to the caller's listener.
ActionListener<SearchResponse> actionListener = ActionListener.<SearchResponse>wrap(r -> {
SearchHit[] hits = r.getHits().getHits();

if (hits != null && hits.length > 0) {
StringBuilder contextBuilder = new StringBuilder();
for (SearchHit hit : hits) {
// Each hit is serialized inside doPrivileged.
// NOTE(review): presumably Gson needs elevated permissions here — confirm.
String doc = AccessController.doPrivileged((PrivilegedExceptionAction<String>) () -> {
Map<String, Object> docContent = AbstractRetrieverTool.processResponse(hit);
return StringUtils.gson.toJson(docContent);
});
contextBuilder.append(doc).append("\n");
}
// The newline-separated documents are JSON-encoded again as a single string.
listener.onResponse((T) StringUtils.gson.toJson(contextBuilder.toString()));
} else {
listener.onResponse((T) "");
}
}, e -> {
log.error("Failed to search index", e);
listener.onFailure(e);
});

// Searching the connector and model indices requires access control, so those
// requests are forwarded to their corresponding transport actions instead of
// going through a plain client.search call.
if (Objects.equals(index, ML_CONNECTOR_INDEX)) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why add searching connector and model index in this tool?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's keep this tool responsibility clear: it only support searching index directly.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why add searching connector and model index in this tool?

When the target index is ML Connector or ML Model, it requires additional access control to perform search operation. Therefore, I need to forward the request to corresponding transport action. I will add some comment here to clarify it.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ylwu-amzn If user use this tool only to create an agent, then use this agent to ask questions related to model/connector index, then user is able to bypass the permission control. And here we're missing another permission controlled index: model_group.

client.execute(MLConnectorSearchAction.INSTANCE, searchRequest, actionListener);
} else if (Objects.equals(index, ML_MODEL_INDEX)) {
client.execute(MLModelSearchAction.INSTANCE, searchRequest, actionListener);
} else {
client.search(searchRequest, actionListener);
}
} catch (Exception e) {
// Failures while parsing the input or building the request are reported the same
// way as search failures.
log.error("Failed to search index", e);
listener.onFailure(e);
}
}

/**
 * Lazily created singleton factory for {@link SearchIndexTool}.
 *
 * The factory must be given its collaborators through
 * {@link #init(Client, NamedXContentRegistry)} before {@link #create(Map)} is
 * used; otherwise the created tools carry {@code null} collaborators.
 */
public static class Factory implements Tool.Factory<SearchIndexTool> {

    private Client client;
    // volatile: the first null check in getInstance() runs outside the lock, so the
    // reference must be safely published for that read to be well-defined.
    private static volatile Factory INSTANCE;

    private NamedXContentRegistry xContentRegistry;

    /**
     * Create or return the singleton factory instance.
     *
     * @return the shared factory, created on first use
     */
    public static Factory getInstance() {
        Factory existing = INSTANCE;
        if (existing != null) {
            return existing;
        }
        // Lock on this factory's own class rather than an unrelated tool class, so
        // initialization neither contends with nor depends on other tools' locks.
        synchronized (Factory.class) {
            if (INSTANCE == null) {
                INSTANCE = new Factory();
            }
            return INSTANCE;
        }
    }

    /**
     * Stores the collaborators handed to every tool this factory creates.
     *
     * @param client           client the tools use to execute searches
     * @param xContentRegistry registry the tools use when parsing query bodies
     */
    public void init(Client client, NamedXContentRegistry xContentRegistry) {
        this.client = client;
        this.xContentRegistry = xContentRegistry;
    }

    /**
     * Creates a new tool bound to the collaborators supplied via {@code init}.
     *
     * @param params creation parameters; not read by this factory
     * @return a new {@link SearchIndexTool}
     */
    @Override
    public SearchIndexTool create(Map<String, Object> params) {
        return new SearchIndexTool(client, xContentRegistry);
    }

    /**
     * Returns the description used when none is configured.
     *
     * @return the tool's default description
     */
    @Override
    public String getDefaultDescription() {
        return DEFAULT_DESCRIPTION;
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ public void testRunAsyncWithSearchResults() {
future.join();
assertEquals(
"{\"_index\":\"hybrid-index\",\"_source\":{\"passage_text\":\"Company test_mock have a history of 100 years.\"},\"_id\":\"1\",\"_score\":89.2917}\n"
+ "{\"_index\":\"hybrid-index\",\"_source\":{\"passage_text\":\"the price of the api is 2$ per invokation\"},\"_id\":\"2\",\"_score\":0.10702579}\n",
+ "{\"_index\":\"hybrid-index\",\"_source\":{\"passage_text\":\"the price of the api is 2$ per invocation\"},\"_id\":\"2\",\"_score\":0.10702579}\n",
gson.fromJson(future.get(), String.class)
);
}
Expand Down Expand Up @@ -133,27 +133,23 @@ public void testRunAsyncWithEmptySearchResponse() {
}

@Test
@SneakyThrows
public void testRunAsyncWithIllegalQueryThenThrowException() {
Client client = mock(Client.class);
mockedImpl.setClient(client);

assertThrows(
"[input] is null or empty, can not process it.",
Exception exception = assertThrows(
IllegalArgumentException.class,
() -> mockedImpl.run(Map.of(AbstractRetrieverTool.INPUT_FIELD, ""), null)
);
assertEquals("[input] is null or empty, can not process it.", exception.getMessage());

assertThrows(
"[input] is null or empty, can not process it.",
exception = assertThrows(
IllegalArgumentException.class,
() -> mockedImpl.run(Map.of(AbstractRetrieverTool.INPUT_FIELD, " "), null)
);
assertEquals("[input] is null or empty, can not process it.", exception.getMessage());

assertThrows(
"[input] is null or empty, can not process it.",
IllegalArgumentException.class,
() -> mockedImpl.run(Map.of("test", "hello world"), null)
);
exception = assertThrows(IllegalArgumentException.class, () -> mockedImpl.run(Map.of("test", "hello world"), null));
assertEquals("[input] is null or empty, can not process it.", exception.getMessage());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ public void testGetQueryBodyWithIllegalParams() {
);

Map<String, Object> illegalParams2 = new HashMap<>(params);
illegalParams1.remove(NeuralSparseTool.EMBEDDING_FIELD);
NeuralSparseTool tool2 = NeuralSparseTool.Factory.getInstance().create(illegalParams1);
illegalParams2.remove(NeuralSparseTool.EMBEDDING_FIELD);
NeuralSparseTool tool2 = NeuralSparseTool.Factory.getInstance().create(illegalParams2);
assertThrows(
"Parameter [embedding_field] and [model_id] can not be null or empty.",
IllegalArgumentException.class,
Expand Down
Loading
Loading