-
Notifications
You must be signed in to change notification settings - Fork 29
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add search index tool Signed-off-by: yuye-aws <[email protected]> * run spotless apply Signed-off-by: yuye-aws <[email protected]> * remove unncessary string util operation Signed-off-by: yuye-aws <[email protected]> * add test cases Signed-off-by: yuye-aws <[email protected]> * spotless apply Signed-off-by: yuye-aws <[email protected]> * update tool description and add model group search Signed-off-by: yuye-aws <[email protected]> --------- Signed-off-by: yuye-aws <[email protected]> (cherry picked from commit 2e05330) Signed-off-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
- Loading branch information
1 parent
b0db422
commit 4ffc2b7
Showing
6 changed files
with
357 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
172 changes: 172 additions & 0 deletions
172
src/main/java/org/opensearch/agent/tools/SearchIndexTool.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,172 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.agent.tools; | ||
|
||
import static org.opensearch.ml.common.CommonValue.*; | ||
|
||
import java.security.AccessController; | ||
import java.security.PrivilegedExceptionAction; | ||
import java.util.Map; | ||
import java.util.Objects; | ||
|
||
import org.opensearch.action.search.SearchRequest; | ||
import org.opensearch.action.search.SearchResponse; | ||
import org.opensearch.client.Client; | ||
import org.opensearch.common.xcontent.LoggingDeprecationHandler; | ||
import org.opensearch.common.xcontent.XContentType; | ||
import org.opensearch.core.action.ActionListener; | ||
import org.opensearch.core.xcontent.NamedXContentRegistry; | ||
import org.opensearch.core.xcontent.XContentParser; | ||
import org.opensearch.ml.common.spi.tools.Tool; | ||
import org.opensearch.ml.common.spi.tools.ToolAnnotation; | ||
import org.opensearch.ml.common.transport.connector.MLConnectorSearchAction; | ||
import org.opensearch.ml.common.transport.model.MLModelSearchAction; | ||
import org.opensearch.ml.common.transport.model_group.MLModelGroupSearchAction; | ||
import org.opensearch.ml.common.utils.StringUtils; | ||
import org.opensearch.search.SearchHit; | ||
import org.opensearch.search.builder.SearchSourceBuilder; | ||
|
||
import com.google.gson.JsonObject; | ||
|
||
import lombok.Getter; | ||
import lombok.Setter; | ||
import lombok.extern.log4j.Log4j2; | ||
|
||
@Getter | ||
@Setter | ||
@Log4j2 | ||
@ToolAnnotation(SearchIndexTool.TYPE) | ||
public class SearchIndexTool implements Tool { | ||
|
||
public static final String INPUT_FIELD = "input"; | ||
public static final String INDEX_FIELD = "index"; | ||
public static final String QUERY_FIELD = "query"; | ||
|
||
public static final String TYPE = "SearchIndexTool"; | ||
private static final String DEFAULT_DESCRIPTION = | ||
"Use this tool to search an index by providing two parameters: 'index' for the index name, and 'query' for the OpenSearch DSL formatted query."; | ||
|
||
private String name = TYPE; | ||
|
||
private String description = DEFAULT_DESCRIPTION; | ||
|
||
private Client client; | ||
|
||
private NamedXContentRegistry xContentRegistry; | ||
|
||
public SearchIndexTool(Client client, NamedXContentRegistry xContentRegistry) { | ||
this.client = client; | ||
this.xContentRegistry = xContentRegistry; | ||
} | ||
|
||
@Override | ||
public String getType() { | ||
return TYPE; | ||
} | ||
|
||
@Override | ||
public String getVersion() { | ||
return null; | ||
} | ||
|
||
@Override | ||
public boolean validate(Map<String, String> parameters) { | ||
return parameters != null && parameters.containsKey(INPUT_FIELD) && parameters.get(INPUT_FIELD) != null; | ||
} | ||
|
||
@Override | ||
public <T> void run(Map<String, String> parameters, ActionListener<T> listener) { | ||
try { | ||
String input = parameters.get(INPUT_FIELD); | ||
JsonObject jsonObject = StringUtils.gson.fromJson(input, JsonObject.class); | ||
String index = jsonObject.get(INDEX_FIELD).getAsString(); | ||
String query = jsonObject.get(QUERY_FIELD).toString(); | ||
query = "{\"query\": " + query + "}"; | ||
|
||
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); | ||
XContentParser queryParser = XContentType.JSON | ||
.xContent() | ||
.createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, query); | ||
searchSourceBuilder.parseXContent(queryParser); | ||
SearchRequest searchRequest = new SearchRequest().source(searchSourceBuilder).indices(index); | ||
|
||
ActionListener<SearchResponse> actionListener = ActionListener.<SearchResponse>wrap(r -> { | ||
SearchHit[] hits = r.getHits().getHits(); | ||
|
||
if (hits != null && hits.length > 0) { | ||
StringBuilder contextBuilder = new StringBuilder(); | ||
for (SearchHit hit : hits) { | ||
String doc = AccessController.doPrivileged((PrivilegedExceptionAction<String>) () -> { | ||
Map<String, Object> docContent = AbstractRetrieverTool.processResponse(hit); | ||
return StringUtils.gson.toJson(docContent); | ||
}); | ||
contextBuilder.append(doc).append("\n"); | ||
} | ||
listener.onResponse((T) contextBuilder.toString()); | ||
} else { | ||
listener.onResponse((T) ""); | ||
} | ||
}, e -> { | ||
log.error("Failed to search index", e); | ||
listener.onFailure(e); | ||
}); | ||
|
||
// since searching connector and model needs access control, we need | ||
// to forward the request corresponding transport action | ||
if (Objects.equals(index, ML_CONNECTOR_INDEX)) { | ||
client.execute(MLConnectorSearchAction.INSTANCE, searchRequest, actionListener); | ||
} else if (Objects.equals(index, ML_MODEL_INDEX)) { | ||
client.execute(MLModelSearchAction.INSTANCE, searchRequest, actionListener); | ||
} else if (Objects.equals(index, ML_MODEL_GROUP_INDEX)) { | ||
client.execute(MLModelGroupSearchAction.INSTANCE, searchRequest, actionListener); | ||
} else { | ||
client.search(searchRequest, actionListener); | ||
} | ||
} catch (Exception e) { | ||
log.error("Failed to search index", e); | ||
listener.onFailure(e); | ||
} | ||
} | ||
|
||
public static class Factory implements Tool.Factory<SearchIndexTool> { | ||
|
||
private Client client; | ||
private static Factory INSTANCE; | ||
|
||
private NamedXContentRegistry xContentRegistry; | ||
|
||
/** | ||
* Create or return the singleton factory instance | ||
*/ | ||
public static Factory getInstance() { | ||
if (INSTANCE != null) { | ||
return INSTANCE; | ||
} | ||
synchronized (SearchIndexTool.class) { | ||
if (INSTANCE != null) { | ||
return INSTANCE; | ||
} | ||
INSTANCE = new Factory(); | ||
return INSTANCE; | ||
} | ||
} | ||
|
||
public void init(Client client, NamedXContentRegistry xContentRegistry) { | ||
this.client = client; | ||
this.xContentRegistry = xContentRegistry; | ||
} | ||
|
||
@Override | ||
public SearchIndexTool create(Map<String, Object> params) { | ||
return new SearchIndexTool(client, xContentRegistry); | ||
} | ||
|
||
@Override | ||
public String getDefaultDescription() { | ||
return DEFAULT_DESCRIPTION; | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.