-
Notifications
You must be signed in to change notification settings - Fork 129
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: search index tool #1750
feat: search index tool #1750
Changes from all commits
d73d3a6
ca9d7a9
3fcf57f
eb7388c
dc22288
93633d8
6d248bc
97f6d8c
82bc27f
b5e5059
06e4699
35bbf24
16a6a87
fe39362
25d0948
b3ad0ca
d88bf20
5dcd28f
d0c9843
9d1bab8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,170 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.engine.tools; | ||
|
||
import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; | ||
import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; | ||
|
||
import java.security.AccessController; | ||
import java.security.PrivilegedExceptionAction; | ||
import java.util.Map; | ||
import java.util.Objects; | ||
|
||
import org.opensearch.action.search.SearchRequest; | ||
import org.opensearch.action.search.SearchResponse; | ||
import org.opensearch.client.Client; | ||
import org.opensearch.common.xcontent.LoggingDeprecationHandler; | ||
import org.opensearch.common.xcontent.XContentType; | ||
import org.opensearch.core.action.ActionListener; | ||
import org.opensearch.core.xcontent.NamedXContentRegistry; | ||
import org.opensearch.core.xcontent.XContentParser; | ||
import org.opensearch.ml.common.spi.tools.Tool; | ||
import org.opensearch.ml.common.spi.tools.ToolAnnotation; | ||
import org.opensearch.ml.common.transport.connector.MLConnectorSearchAction; | ||
import org.opensearch.ml.common.transport.model.MLModelSearchAction; | ||
import org.opensearch.ml.common.utils.StringUtils; | ||
import org.opensearch.search.SearchHit; | ||
import org.opensearch.search.builder.SearchSourceBuilder; | ||
|
||
import com.google.gson.JsonObject; | ||
|
||
import lombok.Getter; | ||
import lombok.Setter; | ||
import lombok.extern.log4j.Log4j2; | ||
|
||
@Getter | ||
@Setter | ||
@Log4j2 | ||
@ToolAnnotation(SearchIndexTool.TYPE) | ||
public class SearchIndexTool implements Tool { | ||
|
||
public static final String INPUT_FIELD = "input"; | ||
public static final String INDEX_FIELD = "index"; | ||
public static final String QUERY_FIELD = "query"; | ||
|
||
public static final String TYPE = "SearchIndexTool"; | ||
private static final String DEFAULT_DESCRIPTION = | ||
"Use this tool to search index with a query. You should pass in two parameters: index and query. Index is the index name and the query is an OpenSearch DSL query."; | ||
|
||
private String name = TYPE; | ||
|
||
private String description = DEFAULT_DESCRIPTION; | ||
|
||
private Client client; | ||
|
||
private NamedXContentRegistry xContentRegistry; | ||
|
||
public SearchIndexTool(Client client, NamedXContentRegistry xContentRegistry) { | ||
this.client = client; | ||
this.xContentRegistry = xContentRegistry; | ||
} | ||
|
||
@Override | ||
public String getType() { | ||
return TYPE; | ||
} | ||
|
||
@Override | ||
public String getVersion() { | ||
return null; | ||
} | ||
|
||
@Override | ||
public boolean validate(Map<String, String> parameters) { | ||
return parameters != null && parameters.containsKey(INPUT_FIELD) && parameters.get(INPUT_FIELD) != null; | ||
} | ||
|
||
@Override | ||
public <T> void run(Map<String, String> parameters, ActionListener<T> listener) { | ||
try { | ||
String input = parameters.get(INPUT_FIELD); | ||
JsonObject jsonObject = StringUtils.gson.fromJson(input, JsonObject.class); | ||
String index = jsonObject.get(INDEX_FIELD).getAsString(); | ||
String query = jsonObject.get(QUERY_FIELD).toString(); | ||
query = "{\"query\": " + query + "}"; | ||
|
||
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); | ||
XContentParser queryParser = XContentType.JSON | ||
.xContent() | ||
.createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, query); | ||
searchSourceBuilder.parseXContent(queryParser); | ||
SearchRequest searchRequest = new SearchRequest().source(searchSourceBuilder).indices(index); | ||
|
||
ActionListener<SearchResponse> actionListener = ActionListener.<SearchResponse>wrap(r -> { | ||
SearchHit[] hits = r.getHits().getHits(); | ||
|
||
if (hits != null && hits.length > 0) { | ||
StringBuilder contextBuilder = new StringBuilder(); | ||
for (SearchHit hit : hits) { | ||
String doc = AccessController.doPrivileged((PrivilegedExceptionAction<String>) () -> { | ||
Map<String, Object> docContent = AbstractRetrieverTool.processResponse(hit); | ||
return StringUtils.gson.toJson(docContent); | ||
}); | ||
contextBuilder.append(doc).append("\n"); | ||
} | ||
listener.onResponse((T) StringUtils.gson.toJson(contextBuilder.toString())); | ||
} else { | ||
listener.onResponse((T) ""); | ||
} | ||
}, e -> { | ||
log.error("Failed to search index", e); | ||
listener.onFailure(e); | ||
}); | ||
|
||
// since searching connector and model needs access control, we need | ||
// to forward the request corresponding transport action | ||
if (Objects.equals(index, ML_CONNECTOR_INDEX)) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why add searching connector and model index in this tool? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's keep this tool responsibility clear: it only support searching index directly. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
When the target index is ML Connector or ML Model, it requires additional access control to perform search operation. Therefore, I need to forward the request to corresponding transport action. I will add some comment here to clarify it. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @ylwu-amzn If user use this tool only to create an agent, then use this agent to ask questions related to model/connector index, then user is able to bypass the permission control. And here we're missing another permission controlled index: model_group. |
||
client.execute(MLConnectorSearchAction.INSTANCE, searchRequest, actionListener); | ||
} else if (Objects.equals(index, ML_MODEL_INDEX)) { | ||
client.execute(MLModelSearchAction.INSTANCE, searchRequest, actionListener); | ||
} else { | ||
client.search(searchRequest, actionListener); | ||
} | ||
} catch (Exception e) { | ||
log.error("Failed to search index", e); | ||
listener.onFailure(e); | ||
} | ||
} | ||
|
||
public static class Factory implements Tool.Factory<SearchIndexTool> { | ||
|
||
private Client client; | ||
private static Factory INSTANCE; | ||
|
||
private NamedXContentRegistry xContentRegistry; | ||
|
||
/** | ||
* Create or return the singleton factory instance | ||
*/ | ||
public static Factory getInstance() { | ||
if (INSTANCE != null) { | ||
return INSTANCE; | ||
} | ||
synchronized (MLModelTool.class) { | ||
if (INSTANCE != null) { | ||
return INSTANCE; | ||
} | ||
INSTANCE = new Factory(); | ||
return INSTANCE; | ||
} | ||
} | ||
|
||
public void init(Client client, NamedXContentRegistry xContentRegistry) { | ||
this.client = client; | ||
this.xContentRegistry = xContentRegistry; | ||
} | ||
|
||
@Override | ||
public SearchIndexTool create(Map<String, Object> params) { | ||
return new SearchIndexTool(client, xContentRegistry); | ||
} | ||
|
||
@Override | ||
public String getDefaultDescription() { | ||
return DEFAULT_DESCRIPTION; | ||
} | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The two exceptions can be merged together, e.g.
cause instanceof MLException || cause instanceof IllegalArgumentException
.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
MLException and IllegalArgumentException are both runtime exceptions. We treat them separately so that we do not need to explicitly declare throw Throwable in function signature.