Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.x] feat: search index tool #77

Merged
merged 1 commit into from
Dec 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/main/java/org/opensearch/agent/ToolPlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import org.opensearch.agent.tools.NeuralSparseSearchTool;
import org.opensearch.agent.tools.PPLTool;
import org.opensearch.agent.tools.SearchIndexTool;
import org.opensearch.agent.tools.VectorDBTool;
import org.opensearch.agent.tools.VisualizationsTool;
import org.opensearch.client.Client;
Expand Down Expand Up @@ -60,6 +61,7 @@ public Collection<Object> createComponents(
VisualizationsTool.Factory.getInstance().init(client);
NeuralSparseSearchTool.Factory.getInstance().init(client, xContentRegistry);
VectorDBTool.Factory.getInstance().init(client, xContentRegistry);
SearchIndexTool.Factory.getInstance().init(client, xContentRegistry);
return Collections.emptyList();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,15 @@ protected AbstractRetrieverTool(

protected abstract String getQueryBody(String queryText);

public static Map<String, Object> processResponse(SearchHit hit) {
Map<String, Object> docContent = new HashMap<>();
docContent.put("_index", hit.getIndex());
docContent.put("_id", hit.getId());
docContent.put("_score", hit.getScore());
docContent.put("_source", hit.getSourceAsMap());
return docContent;
}

private <T> SearchRequest buildSearchRequest(Map<String, String> parameters) throws IOException {
String question = parameters.get(INPUT_FIELD);
if (StringUtils.isBlank(question)) {
Expand Down Expand Up @@ -98,13 +107,8 @@ public <T> void run(Map<String, String> parameters, ActionListener<T> listener)

if (hits != null && hits.length > 0) {
StringBuilder contextBuilder = new StringBuilder();
for (int i = 0; i < hits.length; i++) {
SearchHit hit = hits[i];
Map<String, Object> docContent = new HashMap<>();
docContent.put("_index", hit.getIndex());
docContent.put("_id", hit.getId());
docContent.put("_score", hit.getScore());
docContent.put("_source", hit.getSourceAsMap());
for (SearchHit hit : hits) {
Map<String, Object> docContent = processResponse(hit);
contextBuilder.append(gson.toJson(docContent)).append("\n");
}
listener.onResponse((T) contextBuilder.toString());
Expand Down
172 changes: 172 additions & 0 deletions src/main/java/org/opensearch/agent/tools/SearchIndexTool.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.agent.tools;

import static org.opensearch.ml.common.CommonValue.*;

import java.security.AccessController;
import java.security.PrivilegedExceptionAction;
import java.util.Map;
import java.util.Objects;

import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.client.Client;
import org.opensearch.common.xcontent.LoggingDeprecationHandler;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
import org.opensearch.ml.common.transport.connector.MLConnectorSearchAction;
import org.opensearch.ml.common.transport.model.MLModelSearchAction;
import org.opensearch.ml.common.transport.model_group.MLModelGroupSearchAction;
import org.opensearch.ml.common.utils.StringUtils;
import org.opensearch.search.SearchHit;
import org.opensearch.search.builder.SearchSourceBuilder;

import com.google.gson.JsonObject;

import lombok.Getter;
import lombok.Setter;
import lombok.extern.log4j.Log4j2;

@Getter
@Setter
@Log4j2
@ToolAnnotation(SearchIndexTool.TYPE)
public class SearchIndexTool implements Tool {

public static final String INPUT_FIELD = "input";
public static final String INDEX_FIELD = "index";
public static final String QUERY_FIELD = "query";

public static final String TYPE = "SearchIndexTool";
private static final String DEFAULT_DESCRIPTION =
"Use this tool to search an index by providing two parameters: 'index' for the index name, and 'query' for the OpenSearch DSL formatted query.";

private String name = TYPE;

private String description = DEFAULT_DESCRIPTION;

private Client client;

private NamedXContentRegistry xContentRegistry;

public SearchIndexTool(Client client, NamedXContentRegistry xContentRegistry) {
this.client = client;
this.xContentRegistry = xContentRegistry;
}

@Override
public String getType() {
return TYPE;
}

@Override
public String getVersion() {
return null;
}

@Override
public boolean validate(Map<String, String> parameters) {
return parameters != null && parameters.containsKey(INPUT_FIELD) && parameters.get(INPUT_FIELD) != null;
}

@Override
public <T> void run(Map<String, String> parameters, ActionListener<T> listener) {
try {
String input = parameters.get(INPUT_FIELD);
JsonObject jsonObject = StringUtils.gson.fromJson(input, JsonObject.class);
String index = jsonObject.get(INDEX_FIELD).getAsString();
String query = jsonObject.get(QUERY_FIELD).toString();
query = "{\"query\": " + query + "}";

SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
XContentParser queryParser = XContentType.JSON
.xContent()
.createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, query);
searchSourceBuilder.parseXContent(queryParser);
SearchRequest searchRequest = new SearchRequest().source(searchSourceBuilder).indices(index);

ActionListener<SearchResponse> actionListener = ActionListener.<SearchResponse>wrap(r -> {
SearchHit[] hits = r.getHits().getHits();

if (hits != null && hits.length > 0) {
StringBuilder contextBuilder = new StringBuilder();
for (SearchHit hit : hits) {
String doc = AccessController.doPrivileged((PrivilegedExceptionAction<String>) () -> {
Map<String, Object> docContent = AbstractRetrieverTool.processResponse(hit);
return StringUtils.gson.toJson(docContent);
});
contextBuilder.append(doc).append("\n");
}
listener.onResponse((T) contextBuilder.toString());
} else {
listener.onResponse((T) "");
}
}, e -> {
log.error("Failed to search index", e);
listener.onFailure(e);
});

// since searching connector and model needs access control, we need
// to forward the request corresponding transport action
if (Objects.equals(index, ML_CONNECTOR_INDEX)) {
client.execute(MLConnectorSearchAction.INSTANCE, searchRequest, actionListener);
} else if (Objects.equals(index, ML_MODEL_INDEX)) {
client.execute(MLModelSearchAction.INSTANCE, searchRequest, actionListener);
} else if (Objects.equals(index, ML_MODEL_GROUP_INDEX)) {
client.execute(MLModelGroupSearchAction.INSTANCE, searchRequest, actionListener);
} else {
client.search(searchRequest, actionListener);
}
} catch (Exception e) {
log.error("Failed to search index", e);
listener.onFailure(e);
}
}

public static class Factory implements Tool.Factory<SearchIndexTool> {

private Client client;
private static Factory INSTANCE;

private NamedXContentRegistry xContentRegistry;

/**
* Create or return the singleton factory instance
*/
public static Factory getInstance() {
if (INSTANCE != null) {
return INSTANCE;
}
synchronized (SearchIndexTool.class) {
if (INSTANCE != null) {
return INSTANCE;
}
INSTANCE = new Factory();
return INSTANCE;
}
}

public void init(Client client, NamedXContentRegistry xContentRegistry) {
this.client = client;
this.xContentRegistry = xContentRegistry;
}

@Override
public SearchIndexTool create(Map<String, Object> params) {
return new SearchIndexTool(client, xContentRegistry);
}

@Override
public String getDefaultDescription() {
return DEFAULT_DESCRIPTION;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ public void testRunAsyncWithSearchResults() {
future.join();
assertEquals(
"{\"_index\":\"hybrid-index\",\"_source\":{\"passage_text\":\"Company test_mock have a history of 100 years.\"},\"_id\":\"1\",\"_score\":89.2917}\n"
+ "{\"_index\":\"hybrid-index\",\"_source\":{\"passage_text\":\"the price of the api is 2$ per invokation\"},\"_id\":\"2\",\"_score\":0.10702579}\n",
+ "{\"_index\":\"hybrid-index\",\"_source\":{\"passage_text\":\"the price of the api is 2$ per invocation\"},\"_id\":\"2\",\"_score\":0.10702579}\n",
future.get()
);
}
Expand Down Expand Up @@ -218,7 +218,7 @@ public void testFactory() {
// Create a mock object of the abstract Factory class
Client client = mock(Client.class);
AbstractRetrieverTool.Factory<Tool> factoryMock = new AbstractRetrieverTool.Factory<>() {
public PPLTool create(Map<String, Object> params) {
public AbstractRetrieverTool create(Map<String, Object> params) {
return null;
}
};
Expand Down
Loading
Loading