Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhance RagTool to choose neural sparse query type #140

Merged
merged 6 commits into from
Jan 26, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 31 additions & 74 deletions src/main/java/org/opensearch/agent/tools/RAGTool.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
package org.opensearch.agent.tools;

import static org.apache.commons.lang3.StringEscapeUtils.escapeJson;
import static org.opensearch.agent.tools.VectorDBTool.DEFAULT_K;
import static org.opensearch.agent.tools.AbstractRetrieverTool.*;
import static org.opensearch.ml.common.utils.StringUtils.gson;
import static org.opensearch.ml.common.utils.StringUtils.toJson;

Expand All @@ -21,10 +21,10 @@
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.common.spi.tools.Parser;
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
Expand All @@ -44,29 +44,28 @@
@Setter
@Getter
@ToolAnnotation(RAGTool.TYPE)
public class RAGTool extends AbstractRetrieverTool {
public class RAGTool implements Tool {
public static final String TYPE = "RAGTool";
public static String DEFAULT_DESCRIPTION =
"Use this tool to retrieve helpful information to optimize the output of the large language model to answer questions.";
public static final String INFERENCE_MODEL_ID_FIELD = "inference_model_id";
public static final String EMBEDDING_MODEL_ID_FIELD = "embedding_model_id";
public static final String INDEX_FIELD = "index";
public static final String SOURCE_FIELD = "source_field";
public static final String DOC_SIZE_FIELD = "doc_size";
public static final String EMBEDDING_FIELD = "embedding_field";
public static final String OUTPUT_FIELD = "output_field";
public static final String QUERY_TYPE = "query_type";
zhichao-aws marked this conversation as resolved.
Show resolved Hide resolved
public static final String CONTENT_GENERATION_FIELD = "enable_Content_Generation";
public static final String K_FIELD = "k";
private final AbstractRetrieverTool queryTool;
private String name = TYPE;
private String description = DEFAULT_DESCRIPTION;
private Client client;
private String inferenceModelId;
private Boolean enableContentGeneration;
private NamedXContentRegistry xContentRegistry;
private String index;
private String embeddingField;
private String[] sourceFields;
private String embeddingModelId;
private String queryType;
private Integer docSize;
private Integer k;
@Setter
private Parser inputParser;
@Setter
Expand All @@ -76,27 +75,14 @@ public class RAGTool extends AbstractRetrieverTool {
public RAGTool(
Client client,
NamedXContentRegistry xContentRegistry,
String index,
String embeddingField,
String[] sourceFields,
Integer k,
Integer docSize,
String embeddingModelId,
String inferenceModelId,
String queryType,
Boolean enableContentGeneration,
AbstractRetrieverTool queryTool
) {
super(client, xContentRegistry, index, sourceFields, docSize);
this.client = client;
this.xContentRegistry = xContentRegistry;
this.index = index;
this.embeddingField = embeddingField;
this.sourceFields = sourceFields;
this.embeddingModelId = embeddingModelId;
this.docSize = docSize == null ? DEFAULT_DOC_SIZE : docSize;
this.k = k == null ? DEFAULT_K : k;
this.inferenceModelId = inferenceModelId;
this.queryType = queryType;
this.enableContentGeneration = enableContentGeneration;
this.queryTool = queryTool;
outputParser = new Parser() {
@Override
Expand All @@ -107,13 +93,6 @@ public Object parse(Object o) {
};
}

// getQueryBody is not used in RAGTool
@Override
protected String getQueryBody(String queryText) {
return queryText;
}

@Override
public <T> void run(Map<String, String> parameters, ActionListener<T> listener) {
String input = null;

Expand All @@ -132,7 +111,9 @@ public <T> void run(Map<String, String> parameters, ActionListener<T> listener)
String embeddingInput = input;
ActionListener actionListener = ActionListener.<T>wrap(r -> {
T queryToolOutput;

if (!this.enableContentGeneration) {
listener.onResponse(r);
}
if (r.equals("Can not get any match from search result.")) {
queryToolOutput = (T) "";
} else {
Expand All @@ -155,25 +136,15 @@ public <T> void run(Map<String, String> parameters, ActionListener<T> listener)
Map<String, String> tmpParameters = new HashMap<>();
tmpParameters.putAll(parameters);

if (queryToolOutput instanceof List
&& !((List) queryToolOutput).isEmpty()
&& ((List) queryToolOutput).get(0) instanceof ModelTensors) {
ModelTensors tensors = (ModelTensors) ((List) queryToolOutput).get(0);
Object response = tensors.getMlModelTensors().get(0).getDataAsMap().get("response");
tmpParameters.put(OUTPUT_FIELD, response + "");
} else if (queryToolOutput instanceof ModelTensor) {
tmpParameters.put(OUTPUT_FIELD, escapeJson(toJson(((ModelTensor) queryToolOutput).getDataAsMap())));
if (queryToolOutput instanceof String) {
tmpParameters.put(OUTPUT_FIELD, (String) queryToolOutput);
} else {
if (queryToolOutput instanceof String) {
tmpParameters.put(OUTPUT_FIELD, (String) queryToolOutput);
} else {
tmpParameters.put(OUTPUT_FIELD, escapeJson(toJson(queryToolOutput.toString())));
}
tmpParameters.put(OUTPUT_FIELD, escapeJson(toJson(queryToolOutput.toString())));
}

RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(tmpParameters).build();
MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build();
ActionRequest request = new MLPredictionTaskRequest(inferenceModelId, mlInput, null);
ActionRequest request = new MLPredictionTaskRequest(this.inferenceModelId, mlInput, null);

client.execute(MLPredictionTaskAction.INSTANCE, request, ActionListener.wrap(resp -> {
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) resp.getOutput();
Expand All @@ -184,32 +155,33 @@ public <T> void run(Map<String, String> parameters, ActionListener<T> listener)
listener.onResponse((T) outputParser.parse(modelTensorOutput.getMlModelOutputs()));
}
}, e -> {
log.error("Failed to run model " + inferenceModelId, e);
log.error("Failed to run model " + this.inferenceModelId, e);
listener.onFailure(e);
}));
}, e -> {
log.error("Failed to search index.", e);
listener.onFailure(e);
});
this.queryTool.run(Map.of(VectorDBTool.INPUT_FIELD, embeddingInput), actionListener);
this.queryTool.run(Map.of(INPUT_FIELD, embeddingInput), actionListener);
}

@Override
public String getType() {
return TYPE;
}

@Override
public String getVersion() {
return null;
}

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can also remove the extends logic in below lines. public static class Factory extends AbstractRetrieverTool.Factory<RAGTool>

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

used Factory implements Tool.Factory

public String getName() {
return this.name;
}

@Override
public void setName(String s) {
this.name = s;
}

@Override
public boolean validate(Map<String, String> parameters) {
if (parameters == null || parameters.size() == 0) {
return false;
Expand Down Expand Up @@ -247,15 +219,12 @@ public void init(Client client, NamedXContentRegistry xContentRegistry) {

@Override
public RAGTool create(Map<String, Object> params) {
String embeddingModelId = (String) params.get(EMBEDDING_MODEL_ID_FIELD);
String index = (String) params.get(INDEX_FIELD);
String embeddingField = (String) params.get(EMBEDDING_FIELD);
String[] sourceFields = gson.fromJson((String) params.get(SOURCE_FIELD), String[].class);
String inferenceModelId = (String) params.get(INFERENCE_MODEL_ID_FIELD);
Integer docSize = params.containsKey(DOC_SIZE_FIELD) ? Integer.parseInt((String) params.get(DOC_SIZE_FIELD)) : 2;
String queryType = params.containsKey(QUERY_TYPE) ? (String) params.get(QUERY_TYPE) : "neural";
Integer k = params.containsKey(K_FIELD) ? Integer.parseInt((String) params.get(K_FIELD)) : DEFAULT_K;
;
String embeddingModelId = (String) params.get(EMBEDDING_MODEL_ID_FIELD);
Boolean enableContentGeneration = params.containsKey(CONTENT_GENERATION_FIELD)
? Boolean.parseBoolean((String) params.get(CONTENT_GENERATION_FIELD))
: true;
String inferenceModelId = enableContentGeneration ? (String) params.get(INFERENCE_MODEL_ID_FIELD) : "";
switch (queryType) {
case "neural_sparse":
NeuralSparseSearchTool.Factory.getInstance().init(client, xContentRegistry);
Expand All @@ -264,32 +233,20 @@ public RAGTool create(Map<String, Object> params) {
.builder()
.client(client)
.xContentRegistry(xContentRegistry)
.index(index)
.embeddingField(embeddingField)
.sourceFields(sourceFields)
.embeddingModelId(embeddingModelId)
.k(k)
.docSize(docSize)
.inferenceModelId(inferenceModelId)
.queryType(queryType)
.enableContentGeneration(enableContentGeneration)
.queryTool(neuralSparseSearchTool)
.build();
case "neural":
VectorDBTool.Factory.getInstance().init(client, xContentRegistry);
params.put(VectorDBTool.MODEL_ID_FIELD, EMBEDDING_MODEL_ID_FIELD);
params.put(VectorDBTool.MODEL_ID_FIELD, embeddingModelId);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should also do this for NeuralSparseTool

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't need to call VectorDBTool.Factory.getInstance().init(client, xContentRegistry) or NeuralSparseSearchTool.Factory.getInstance().init(client, xContentRegistry) here. This has been done during the initialization of skills plugin

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added the initiation to the test to simulate the initiations of tools in ToolPlugins.java

mapped the inference_model_id to model_id for NeuralSparseTool

VectorDBTool vectorDBTool = VectorDBTool.Factory.getInstance().create(params);
return RAGTool
.builder()
.client(client)
.xContentRegistry(xContentRegistry)
.index(index)
.embeddingField(embeddingField)
.sourceFields(sourceFields)
.embeddingModelId(embeddingModelId)
.k(k)
.docSize(docSize)
.inferenceModelId(inferenceModelId)
.queryType(queryType)
.enableContentGeneration(enableContentGeneration)
.queryTool(vectorDBTool)
.build();
default:
Expand Down
Loading
Loading