Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: search index tool #61

Merged
merged 6 commits into from
Dec 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/main/java/org/opensearch/agent/ToolPlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import org.opensearch.agent.tools.NeuralSparseSearchTool;
import org.opensearch.agent.tools.PPLTool;
import org.opensearch.agent.tools.SearchIndexTool;
import org.opensearch.agent.tools.VectorDBTool;
import org.opensearch.client.Client;
import org.opensearch.cluster.metadata.IndexNameExpressionResolver;
Expand Down Expand Up @@ -58,6 +59,7 @@
PPLTool.Factory.getInstance().init(client);
NeuralSparseSearchTool.Factory.getInstance().init(client, xContentRegistry);
VectorDBTool.Factory.getInstance().init(client, xContentRegistry);
SearchIndexTool.Factory.getInstance().init(client, xContentRegistry);

Check warning on line 62 in src/main/java/org/opensearch/agent/ToolPlugin.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/agent/ToolPlugin.java#L62

Added line #L62 was not covered by tests
return Collections.emptyList();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,15 @@ protected AbstractRetrieverTool(

protected abstract String getQueryBody(String queryText);

public static Map<String, Object> processResponse(SearchHit hit) {
Map<String, Object> docContent = new HashMap<>();
docContent.put("_index", hit.getIndex());
docContent.put("_id", hit.getId());
docContent.put("_score", hit.getScore());
docContent.put("_source", hit.getSourceAsMap());
return docContent;
}

private <T> SearchRequest buildSearchRequest(Map<String, String> parameters) throws IOException {
String question = parameters.get(INPUT_FIELD);
if (StringUtils.isBlank(question)) {
Expand Down Expand Up @@ -98,13 +107,8 @@ public <T> void run(Map<String, String> parameters, ActionListener<T> listener)

if (hits != null && hits.length > 0) {
StringBuilder contextBuilder = new StringBuilder();
for (int i = 0; i < hits.length; i++) {
SearchHit hit = hits[i];
Map<String, Object> docContent = new HashMap<>();
docContent.put("_index", hit.getIndex());
docContent.put("_id", hit.getId());
docContent.put("_score", hit.getScore());
docContent.put("_source", hit.getSourceAsMap());
for (SearchHit hit : hits) {
Map<String, Object> docContent = processResponse(hit);
contextBuilder.append(gson.toJson(docContent)).append("\n");
}
listener.onResponse((T) contextBuilder.toString());
Expand Down
172 changes: 172 additions & 0 deletions src/main/java/org/opensearch/agent/tools/SearchIndexTool.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.agent.tools;

import static org.opensearch.ml.common.CommonValue.*;

import java.security.AccessController;
import java.security.PrivilegedExceptionAction;
import java.util.Map;
import java.util.Objects;

import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.client.Client;
import org.opensearch.common.xcontent.LoggingDeprecationHandler;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
import org.opensearch.ml.common.transport.connector.MLConnectorSearchAction;
import org.opensearch.ml.common.transport.model.MLModelSearchAction;
import org.opensearch.ml.common.transport.model_group.MLModelGroupSearchAction;
import org.opensearch.ml.common.utils.StringUtils;
import org.opensearch.search.SearchHit;
import org.opensearch.search.builder.SearchSourceBuilder;

import com.google.gson.JsonObject;

import lombok.Getter;
import lombok.Setter;
import lombok.extern.log4j.Log4j2;

@Getter
@Setter
@Log4j2
@ToolAnnotation(SearchIndexTool.TYPE)
public class SearchIndexTool implements Tool {

public static final String INPUT_FIELD = "input";
public static final String INDEX_FIELD = "index";
public static final String QUERY_FIELD = "query";

public static final String TYPE = "SearchIndexTool";
private static final String DEFAULT_DESCRIPTION =
"Use this tool to search an index by providing two parameters: 'index' for the index name, and 'query' for the OpenSearch DSL formatted query.";

private String name = TYPE;

private String description = DEFAULT_DESCRIPTION;

private Client client;

private NamedXContentRegistry xContentRegistry;

public SearchIndexTool(Client client, NamedXContentRegistry xContentRegistry) {
this.client = client;
this.xContentRegistry = xContentRegistry;
}

@Override
public String getType() {
return TYPE;
}

@Override
public String getVersion() {
return null;

Check warning on line 72 in src/main/java/org/opensearch/agent/tools/SearchIndexTool.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/agent/tools/SearchIndexTool.java#L72

Added line #L72 was not covered by tests
}

@Override
public boolean validate(Map<String, String> parameters) {
return parameters != null && parameters.containsKey(INPUT_FIELD) && parameters.get(INPUT_FIELD) != null;
}

@Override
public <T> void run(Map<String, String> parameters, ActionListener<T> listener) {
try {
String input = parameters.get(INPUT_FIELD);
JsonObject jsonObject = StringUtils.gson.fromJson(input, JsonObject.class);
String index = jsonObject.get(INDEX_FIELD).getAsString();
String query = jsonObject.get(QUERY_FIELD).toString();
query = "{\"query\": " + query + "}";

SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
XContentParser queryParser = XContentType.JSON
.xContent()
.createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, query);
searchSourceBuilder.parseXContent(queryParser);
SearchRequest searchRequest = new SearchRequest().source(searchSourceBuilder).indices(index);

ActionListener<SearchResponse> actionListener = ActionListener.<SearchResponse>wrap(r -> {
SearchHit[] hits = r.getHits().getHits();

if (hits != null && hits.length > 0) {
StringBuilder contextBuilder = new StringBuilder();
for (SearchHit hit : hits) {
String doc = AccessController.doPrivileged((PrivilegedExceptionAction<String>) () -> {
Map<String, Object> docContent = AbstractRetrieverTool.processResponse(hit);
return StringUtils.gson.toJson(docContent);
});
contextBuilder.append(doc).append("\n");
}
listener.onResponse((T) contextBuilder.toString());
} else {
listener.onResponse((T) "");

Check warning on line 110 in src/main/java/org/opensearch/agent/tools/SearchIndexTool.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/agent/tools/SearchIndexTool.java#L110

Added line #L110 was not covered by tests
}
}, e -> {
log.error("Failed to search index", e);
listener.onFailure(e);
});

Check warning on line 115 in src/main/java/org/opensearch/agent/tools/SearchIndexTool.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/agent/tools/SearchIndexTool.java#L113-L115

Added lines #L113 - L115 were not covered by tests

// since searching connector and model needs access control, we need
// to forward the request corresponding transport action
if (Objects.equals(index, ML_CONNECTOR_INDEX)) {
client.execute(MLConnectorSearchAction.INSTANCE, searchRequest, actionListener);
} else if (Objects.equals(index, ML_MODEL_INDEX)) {
zane-neo marked this conversation as resolved.
Show resolved Hide resolved
client.execute(MLModelSearchAction.INSTANCE, searchRequest, actionListener);
} else if (Objects.equals(index, ML_MODEL_GROUP_INDEX)) {
client.execute(MLModelGroupSearchAction.INSTANCE, searchRequest, actionListener);

Check warning on line 124 in src/main/java/org/opensearch/agent/tools/SearchIndexTool.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/agent/tools/SearchIndexTool.java#L124

Added line #L124 was not covered by tests
} else {
client.search(searchRequest, actionListener);
}
} catch (Exception e) {
log.error("Failed to search index", e);
listener.onFailure(e);
}
}

public static class Factory implements Tool.Factory<SearchIndexTool> {

private Client client;
private static Factory INSTANCE;

private NamedXContentRegistry xContentRegistry;

/**
* Create or return the singleton factory instance
*/
public static Factory getInstance() {
if (INSTANCE != null) {
return INSTANCE;

Check warning on line 146 in src/main/java/org/opensearch/agent/tools/SearchIndexTool.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/agent/tools/SearchIndexTool.java#L146

Added line #L146 was not covered by tests
}
synchronized (SearchIndexTool.class) {
if (INSTANCE != null) {
return INSTANCE;

Check warning on line 150 in src/main/java/org/opensearch/agent/tools/SearchIndexTool.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/agent/tools/SearchIndexTool.java#L150

Added line #L150 was not covered by tests
}
INSTANCE = new Factory();
return INSTANCE;
}
}

public void init(Client client, NamedXContentRegistry xContentRegistry) {
this.client = client;
this.xContentRegistry = xContentRegistry;
}

Check warning on line 160 in src/main/java/org/opensearch/agent/tools/SearchIndexTool.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/agent/tools/SearchIndexTool.java#L158-L160

Added lines #L158 - L160 were not covered by tests

@Override
public SearchIndexTool create(Map<String, Object> params) {
return new SearchIndexTool(client, xContentRegistry);
}

@Override
public String getDefaultDescription() {
return DEFAULT_DESCRIPTION;

Check warning on line 169 in src/main/java/org/opensearch/agent/tools/SearchIndexTool.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/agent/tools/SearchIndexTool.java#L169

Added line #L169 was not covered by tests
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ public void testRunAsyncWithSearchResults() {
future.join();
assertEquals(
"{\"_index\":\"hybrid-index\",\"_source\":{\"passage_text\":\"Company test_mock have a history of 100 years.\"},\"_id\":\"1\",\"_score\":89.2917}\n"
+ "{\"_index\":\"hybrid-index\",\"_source\":{\"passage_text\":\"the price of the api is 2$ per invokation\"},\"_id\":\"2\",\"_score\":0.10702579}\n",
+ "{\"_index\":\"hybrid-index\",\"_source\":{\"passage_text\":\"the price of the api is 2$ per invocation\"},\"_id\":\"2\",\"_score\":0.10702579}\n",
future.get()
);
}
Expand Down Expand Up @@ -218,7 +218,7 @@ public void testFactory() {
// Create a mock object of the abstract Factory class
Client client = mock(Client.class);
AbstractRetrieverTool.Factory<Tool> factoryMock = new AbstractRetrieverTool.Factory<>() {
public PPLTool create(Map<String, Object> params) {
public AbstractRetrieverTool create(Map<String, Object> params) {
return null;
}
};
Expand Down
Loading
Loading