Skip to content

Commit

Permalink
feat: search index tool (opensearch-project#61) (opensearch-project#77)
Browse files Browse the repository at this point in the history
* add search index tool

* run spotless apply

* remove unnecessary string util operation

* add test cases

* spotless apply

* update tool description and add model group search

---------

(cherry picked from commit 2e05330)

Signed-off-by: yuye-aws <[email protected]>
Signed-off-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Signed-off-by: yuye-aws <[email protected]>
  • Loading branch information
2 people authored and yuye-aws committed Apr 26, 2024
1 parent b27a664 commit 7ad82f6
Show file tree
Hide file tree
Showing 6 changed files with 357 additions and 10 deletions.
2 changes: 2 additions & 0 deletions src/main/java/org/opensearch/agent/ToolPlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import org.opensearch.agent.tools.NeuralSparseSearchTool;
import org.opensearch.agent.tools.PPLTool;
import org.opensearch.agent.tools.SearchIndexTool;
import org.opensearch.agent.tools.VectorDBTool;
import org.opensearch.agent.tools.VisualizationsTool;
import org.opensearch.client.Client;
Expand Down Expand Up @@ -60,6 +61,7 @@ public Collection<Object> createComponents(
VisualizationsTool.Factory.getInstance().init(client);
NeuralSparseSearchTool.Factory.getInstance().init(client, xContentRegistry);
VectorDBTool.Factory.getInstance().init(client, xContentRegistry);
SearchIndexTool.Factory.getInstance().init(client, xContentRegistry);
return Collections.emptyList();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,15 @@ protected AbstractRetrieverTool(

protected abstract String getQueryBody(String queryText);

/**
 * Flattens a single search hit into a plain map suitable for JSON serialization.
 *
 * <p>The resulting map carries four entries: {@code _index}, {@code _id},
 * {@code _score} and {@code _source}, each read straight from the hit.
 *
 * @param hit the search hit to convert
 * @return a new mutable map holding the hit's index, id, score and source map
 */
public static Map<String, Object> processResponse(SearchHit hit) {
    final Map<String, Object> hitFields = new HashMap<>();
    hitFields.put("_index", hit.getIndex());
    hitFields.put("_id", hit.getId());
    hitFields.put("_score", hit.getScore());
    hitFields.put("_source", hit.getSourceAsMap());
    return hitFields;
}

private <T> SearchRequest buildSearchRequest(Map<String, String> parameters) throws IOException {
String question = parameters.get(INPUT_FIELD);
if (StringUtils.isBlank(question)) {
Expand Down Expand Up @@ -98,13 +107,8 @@ public <T> void run(Map<String, String> parameters, ActionListener<T> listener)

if (hits != null && hits.length > 0) {
StringBuilder contextBuilder = new StringBuilder();
for (int i = 0; i < hits.length; i++) {
SearchHit hit = hits[i];
Map<String, Object> docContent = new HashMap<>();
docContent.put("_index", hit.getIndex());
docContent.put("_id", hit.getId());
docContent.put("_score", hit.getScore());
docContent.put("_source", hit.getSourceAsMap());
for (SearchHit hit : hits) {
Map<String, Object> docContent = processResponse(hit);
contextBuilder.append(gson.toJson(docContent)).append("\n");
}
listener.onResponse((T) contextBuilder.toString());
Expand Down
172 changes: 172 additions & 0 deletions src/main/java/org/opensearch/agent/tools/SearchIndexTool.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.agent.tools;

import static org.opensearch.ml.common.CommonValue.*;

import java.security.AccessController;
import java.security.PrivilegedExceptionAction;
import java.util.Map;
import java.util.Objects;

import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.client.Client;
import org.opensearch.common.xcontent.LoggingDeprecationHandler;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
import org.opensearch.ml.common.transport.connector.MLConnectorSearchAction;
import org.opensearch.ml.common.transport.model.MLModelSearchAction;
import org.opensearch.ml.common.transport.model_group.MLModelGroupSearchAction;
import org.opensearch.ml.common.utils.StringUtils;
import org.opensearch.search.SearchHit;
import org.opensearch.search.builder.SearchSourceBuilder;

import com.google.gson.JsonObject;

import lombok.Getter;
import lombok.Setter;
import lombok.extern.log4j.Log4j2;

@Getter
@Setter
@Log4j2
@ToolAnnotation(SearchIndexTool.TYPE)
public class SearchIndexTool implements Tool {

public static final String INPUT_FIELD = "input";
public static final String INDEX_FIELD = "index";
public static final String QUERY_FIELD = "query";

public static final String TYPE = "SearchIndexTool";
private static final String DEFAULT_DESCRIPTION =
"Use this tool to search an index by providing two parameters: 'index' for the index name, and 'query' for the OpenSearch DSL formatted query.";

private String name = TYPE;

private String description = DEFAULT_DESCRIPTION;

private Client client;

private NamedXContentRegistry xContentRegistry;

public SearchIndexTool(Client client, NamedXContentRegistry xContentRegistry) {
this.client = client;
this.xContentRegistry = xContentRegistry;
}

@Override
public String getType() {
return TYPE;
}

@Override
public String getVersion() {
return null;
}

@Override
public boolean validate(Map<String, String> parameters) {
return parameters != null && parameters.containsKey(INPUT_FIELD) && parameters.get(INPUT_FIELD) != null;
}

@Override
public <T> void run(Map<String, String> parameters, ActionListener<T> listener) {
try {
String input = parameters.get(INPUT_FIELD);
JsonObject jsonObject = StringUtils.gson.fromJson(input, JsonObject.class);
String index = jsonObject.get(INDEX_FIELD).getAsString();
String query = jsonObject.get(QUERY_FIELD).toString();
query = "{\"query\": " + query + "}";

SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
XContentParser queryParser = XContentType.JSON
.xContent()
.createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, query);
searchSourceBuilder.parseXContent(queryParser);
SearchRequest searchRequest = new SearchRequest().source(searchSourceBuilder).indices(index);

ActionListener<SearchResponse> actionListener = ActionListener.<SearchResponse>wrap(r -> {
SearchHit[] hits = r.getHits().getHits();

if (hits != null && hits.length > 0) {
StringBuilder contextBuilder = new StringBuilder();
for (SearchHit hit : hits) {
String doc = AccessController.doPrivileged((PrivilegedExceptionAction<String>) () -> {
Map<String, Object> docContent = AbstractRetrieverTool.processResponse(hit);
return StringUtils.gson.toJson(docContent);
});
contextBuilder.append(doc).append("\n");
}
listener.onResponse((T) contextBuilder.toString());
} else {
listener.onResponse((T) "");
}
}, e -> {
log.error("Failed to search index", e);
listener.onFailure(e);
});

// since searching connector and model needs access control, we need
// to forward the request corresponding transport action
if (Objects.equals(index, ML_CONNECTOR_INDEX)) {
client.execute(MLConnectorSearchAction.INSTANCE, searchRequest, actionListener);
} else if (Objects.equals(index, ML_MODEL_INDEX)) {
client.execute(MLModelSearchAction.INSTANCE, searchRequest, actionListener);
} else if (Objects.equals(index, ML_MODEL_GROUP_INDEX)) {
client.execute(MLModelGroupSearchAction.INSTANCE, searchRequest, actionListener);
} else {
client.search(searchRequest, actionListener);
}
} catch (Exception e) {
log.error("Failed to search index", e);
listener.onFailure(e);
}
}

public static class Factory implements Tool.Factory<SearchIndexTool> {

private Client client;
private static Factory INSTANCE;

private NamedXContentRegistry xContentRegistry;

/**
* Create or return the singleton factory instance
*/
public static Factory getInstance() {
if (INSTANCE != null) {
return INSTANCE;
}
synchronized (SearchIndexTool.class) {
if (INSTANCE != null) {
return INSTANCE;
}
INSTANCE = new Factory();
return INSTANCE;
}
}

public void init(Client client, NamedXContentRegistry xContentRegistry) {
this.client = client;
this.xContentRegistry = xContentRegistry;
}

@Override
public SearchIndexTool create(Map<String, Object> params) {
return new SearchIndexTool(client, xContentRegistry);
}

@Override
public String getDefaultDescription() {
return DEFAULT_DESCRIPTION;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ public void testRunAsyncWithSearchResults() {
future.join();
assertEquals(
"{\"_index\":\"hybrid-index\",\"_source\":{\"passage_text\":\"Company test_mock have a history of 100 years.\"},\"_id\":\"1\",\"_score\":89.2917}\n"
+ "{\"_index\":\"hybrid-index\",\"_source\":{\"passage_text\":\"the price of the api is 2$ per invokation\"},\"_id\":\"2\",\"_score\":0.10702579}\n",
+ "{\"_index\":\"hybrid-index\",\"_source\":{\"passage_text\":\"the price of the api is 2$ per invocation\"},\"_id\":\"2\",\"_score\":0.10702579}\n",
future.get()
);
}
Expand Down Expand Up @@ -218,7 +218,7 @@ public void testFactory() {
// Create a mock object of the abstract Factory class
Client client = mock(Client.class);
AbstractRetrieverTool.Factory<Tool> factoryMock = new AbstractRetrieverTool.Factory<>() {
public PPLTool create(Map<String, Object> params) {
public AbstractRetrieverTool create(Map<String, Object> params) {
return null;
}
};
Expand Down
Loading

0 comments on commit 7ad82f6

Please sign in to comment.