Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhance RagTool to choose neural sparse query type #140

Merged
merged 6 commits into from
Jan 26, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 30 additions & 12 deletions src/main/java/org/opensearch/agent/tools/RAGTool.java
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ public class RAGTool extends AbstractRetrieverTool {
public static final String EMBEDDING_MODEL_ID_FIELD = "embedding_model_id";
public static final String EMBEDDING_FIELD = "embedding_field";
public static final String OUTPUT_FIELD = "output_field";
public static final String QUERY_TYPE = "query_type";
zhichao-aws marked this conversation as resolved.
Show resolved Hide resolved
private String name = TYPE;
private String description = DEFAULT_DESCRIPTION;
private Client client;
Expand All @@ -61,6 +62,7 @@ public class RAGTool extends AbstractRetrieverTool {
private String embeddingField;
private String[] sourceFields;
private String embeddingModelId;
private String queryType;
private Integer docSize;
private Integer k;
@Setter
Expand All @@ -78,7 +80,8 @@ public RAGTool(
Integer k,
Integer docSize,
String embeddingModelId,
String inferenceModelId
String inferenceModelId,
String queryType
) {
super(client, xContentRegistry, index, sourceFields, docSize);
this.client = client;
Expand All @@ -90,6 +93,7 @@ public RAGTool(
this.docSize = docSize == null ? DEFAULT_DOC_SIZE : docSize;
this.k = k == null ? DEFAULT_K : k;
this.inferenceModelId = inferenceModelId;
this.queryType = queryType;
outputParser = new Parser() {
@Override
public Object parse(Object o) {
Expand Down Expand Up @@ -121,16 +125,6 @@ public <T> void run(Map<String, String> parameters, ActionListener<T> listener)
return;
}

Map<String, Object> params = new HashMap<>();
VectorDBTool.Factory.getInstance().init(client, xContentRegistry);
params.put(VectorDBTool.INDEX_FIELD, this.index);
params.put(VectorDBTool.EMBEDDING_FIELD, this.embeddingField);
params.put(VectorDBTool.SOURCE_FIELD, gson.toJson(this.sourceFields));
params.put(VectorDBTool.MODEL_ID_FIELD, this.embeddingModelId);
params.put(VectorDBTool.DOC_SIZE_FIELD, String.valueOf(this.docSize));
params.put(VectorDBTool.K_FIELD, String.valueOf(this.k));
VectorDBTool vectorDBTool = VectorDBTool.Factory.getInstance().create(params);

String embeddingInput = input;
ActionListener actionListener = ActionListener.<T>wrap(r -> {
T vectorDBToolOutput;
Expand Down Expand Up @@ -193,8 +187,30 @@ public <T> void run(Map<String, String> parameters, ActionListener<T> listener)
log.error("Failed to search index.", e);
listener.onFailure(e);
});
vectorDBTool.run(Map.of(VectorDBTool.INPUT_FIELD, embeddingInput), actionListener);

Map<String, Object> params = new HashMap<>();
params.put(VectorDBTool.INDEX_FIELD, this.index);
zhichao-aws marked this conversation as resolved.
Show resolved Hide resolved
params.put(VectorDBTool.EMBEDDING_FIELD, this.embeddingField);
params.put(VectorDBTool.SOURCE_FIELD, gson.toJson(this.sourceFields));
params.put(VectorDBTool.MODEL_ID_FIELD, this.embeddingModelId);
params.put(VectorDBTool.DOC_SIZE_FIELD, String.valueOf(this.docSize));

switch (this.queryType) {
mingshl marked this conversation as resolved.
Show resolved Hide resolved
case "neural_sparse":
NeuralSparseSearchTool.Factory.getInstance().init(client, xContentRegistry);
NeuralSparseSearchTool neuralSparseSearchTool = NeuralSparseSearchTool.Factory.getInstance().create(params);
neuralSparseSearchTool.run(Map.of(VectorDBTool.INPUT_FIELD, embeddingInput), actionListener);
break;
case "neural":
VectorDBTool.Factory.getInstance().init(client, xContentRegistry);
params.put(VectorDBTool.K_FIELD, String.valueOf(this.k));
VectorDBTool vectorDBTool = VectorDBTool.Factory.getInstance().create(params);
vectorDBTool.run(Map.of(VectorDBTool.INPUT_FIELD, embeddingInput), actionListener);
break;
default:
log.error("Failed to read queryType, please input neural_sparse or neural.");
listener.onFailure(new IllegalArgumentException("Failed to read queryType, please input neural_sparse or neural."));
}
}

@Override
Expand Down Expand Up @@ -256,6 +272,7 @@ public RAGTool create(Map<String, Object> params) {
String[] sourceFields = gson.fromJson((String) params.get(SOURCE_FIELD), String[].class);
String inferenceModelId = (String) params.get(INFERENCE_MODEL_ID_FIELD);
Integer docSize = params.containsKey(DOC_SIZE_FIELD) ? Integer.parseInt((String) params.get(DOC_SIZE_FIELD)) : 2;
String queryType = params.containsKey(QUERY_TYPE) ? (String) params.get(QUERY_TYPE) : "neural";
return RAGTool
.builder()
.client(client)
Expand All @@ -266,6 +283,7 @@ public RAGTool create(Map<String, Object> params) {
.embeddingModelId(embeddingModelId)
.docSize(docSize)
.inferenceModelId(inferenceModelId)
.queryType(queryType)
.build();
}

Expand Down
5 changes: 5 additions & 0 deletions src/main/java/org/opensearch/agent/tools/VectorDBTool.java
Original file line number Diff line number Diff line change
Expand Up @@ -128,5 +128,10 @@ public String getDefaultType() {
public String getDefaultVersion() {
return null;
}

/**
 * Returns {@code DEFAULT_DESCRIPTION} as the default description.
 *
 * @return the {@code DEFAULT_DESCRIPTION} constant
 */
@Override
public String getDefaultDescription() {
return DEFAULT_DESCRIPTION;
}
}
}
123 changes: 111 additions & 12 deletions src/test/java/org/opensearch/agent/tools/RAGToolTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ public class RAGToolTests {
public static final String TEST_EMBEDDING_FIELD = "test_embedding";
public static final String TEST_EMBEDDING_MODEL_ID = "1234";
public static final String TEST_INFERENCE_MODEL_ID = "1234";
public static final String TEST_NEURAL_QUERY_TYPE = "neural";
public static final String TEST_NEURAL_SPARSE_QUERY_TYPE = "neural_sparse";

public static final String TEST_NEURAL_QUERY = "{\"query\":{\"neural\":{\""
+ TEST_EMBEDDING_FIELD
Expand All @@ -67,6 +69,7 @@ public class RAGToolTests {
private RAGTool ragTool;
private String mockedSearchResponseString;
private String mockedEmptySearchResponseString;
private String mockedNeuralSparseSearchResponseString;
@Mock
private Parser mockOutputParser;
@Mock
Expand All @@ -89,6 +92,11 @@ public void setup() {
}
}

try (InputStream searchResponseIns = AbstractRetrieverTool.class.getResourceAsStream("neural_sparse_tool_search_response.json")) {
if (searchResponseIns != null) {
mockedNeuralSparseSearchResponseString = new String(searchResponseIns.readAllBytes(), StandardCharsets.UTF_8);
}
}
client = mock(Client.class);
listener = mock(ActionListener.class);
RAGTool.Factory.getInstance().init(client, TEST_XCONTENT_REGISTRY_FOR_QUERY);
Expand Down Expand Up @@ -142,7 +150,7 @@ public void testGetQueryBodySuccess() {
@Test
public void testOutputParser() throws IOException {

NamedXContentRegistry mockNamedXContentRegistry = getNeuralQueryNamedXContentRegistry();
NamedXContentRegistry mockNamedXContentRegistry = getQueryNamedXContentRegistry();
ragTool.setXContentRegistry(mockNamedXContentRegistry);

ModelTensorOutput mlModelTensorOutput = getMlModelTensorOutput();
Expand Down Expand Up @@ -175,7 +183,7 @@ public void testOutputParser() throws IOException {

@Test
public void testRunWithEmptySearchResponse() throws IOException {
NamedXContentRegistry mockNamedXContentRegistry = getNeuralQueryNamedXContentRegistry();
NamedXContentRegistry mockNamedXContentRegistry = getQueryNamedXContentRegistry();
ragTool.setXContentRegistry(mockNamedXContentRegistry);

ModelTensorOutput mlModelTensorOutput = getMlModelTensorOutput();
Expand Down Expand Up @@ -203,9 +211,86 @@ public void testRunWithEmptySearchResponse() throws IOException {
verify(client).execute(any(), any(), any());
}

/**
 * Runs a RAGTool constructed with the {@code neural_sparse} query type and checks that the
 * client is asked both to search and to execute the prediction action.
 */
@Test
public void testRunWithNeuralSparseQueryType() throws IOException {
    // Tool under test, explicitly configured for the sparse query type.
    RAGTool sparseRagTool = new RAGTool(
        client,
        TEST_XCONTENT_REGISTRY_FOR_QUERY,
        TEST_INDEX,
        TEST_EMBEDDING_FIELD,
        TEST_SOURCE_FIELDS,
        DEFAULT_K,
        TEST_DOC_SIZE,
        TEST_EMBEDDING_MODEL_ID,
        TEST_INFERENCE_MODEL_ID,
        TEST_NEURAL_SPARSE_QUERY_TYPE
    );
    sparseRagTool.setXContentRegistry(getQueryNamedXContentRegistry());

    // Canned outputs handed back by the stubbed client below.
    ModelTensorOutput inferenceOutput = getMlModelTensorOutput();
    SearchResponse sparseSearchResponse = SearchResponse
        .fromXContent(
            JsonXContent.jsonXContent
                .createParser(NamedXContentRegistry.EMPTY, DeprecationHandler.IGNORE_DEPRECATIONS, mockedNeuralSparseSearchResponseString)
        );

    // Prediction stub: answer the third argument (the listener) with the canned model output.
    doAnswer(invocation -> {
        ActionListener<MLTaskResponse> predictionListener = invocation.getArgument(2);
        MLTaskResponse predictionResponse = MLTaskResponse.builder().output(inferenceOutput).build();
        predictionListener.onResponse(predictionResponse);
        return null;
    }).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(), any());

    // Search stub: check the requested size, then answer with the canned search response.
    doAnswer(invocation -> {
        SearchRequest capturedRequest = invocation.getArgument(0);
        ActionListener<SearchResponse> searchListener = invocation.getArgument(1);
        assertEquals((long) TEST_DOC_SIZE, (long) capturedRequest.source().size());
        searchListener.onResponse(sparseSearchResponse);
        return null;
    }).when(client).search(any(), any());

    sparseRagTool.run(Map.of(INPUT_FIELD, "hello?"), listener);

    verify(client).search(any(), any());
    verify(client).execute(any(), any(), any());
}

/**
 * Verifies that running a RAGTool built with an unsupported query type ({@code "sparse"})
 * reports an {@link IllegalArgumentException} with the expected message to the listener.
 *
 * <p>No {@code client.search} stub is installed on purpose. A stub that fails the search with
 * the same exception type and message would let this test pass even if the tool's own
 * query-type validation were missing, so the failure asserted here must come from the tool.
 */
@Test
public void testRunWithInvalidQueryType() throws IOException {
    RAGTool rAGtoolWithInvalidQueryType = new RAGTool(
        client,
        TEST_XCONTENT_REGISTRY_FOR_QUERY,
        TEST_INDEX,
        TEST_EMBEDDING_FIELD,
        TEST_SOURCE_FIELDS,
        DEFAULT_K,
        TEST_DOC_SIZE,
        TEST_EMBEDDING_MODEL_ID,
        TEST_INFERENCE_MODEL_ID,
        "sparse"
    );

    rAGtoolWithInvalidQueryType.run(Map.of(INPUT_FIELD, "hello?"), listener);

    // The listener must be failed with an IllegalArgumentException...
    verify(listener).onFailure(any(IllegalArgumentException.class));
    // ...carrying the message that names the accepted query types.
    ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
    verify(listener).onFailure(argumentCaptor.capture());
    assertEquals("Failed to read queryType, please input neural_sparse or neural.", argumentCaptor.getValue().getMessage());
}

@Test
public void testRunWithQuestionJson() throws IOException {
NamedXContentRegistry mockNamedXContentRegistry = getNeuralQueryNamedXContentRegistry();
NamedXContentRegistry mockNamedXContentRegistry = getQueryNamedXContentRegistry();
ragTool.setXContentRegistry(mockNamedXContentRegistry);

ModelTensorOutput mlModelTensorOutput = getMlModelTensorOutput();
Expand Down Expand Up @@ -236,7 +321,7 @@ public void testRunWithQuestionJson() throws IOException {
@Test
@SneakyThrows
public void testRunWithRuntimeExceptionDuringSearch() {
NamedXContentRegistry mockNamedXContentRegistry = getNeuralQueryNamedXContentRegistry();
NamedXContentRegistry mockNamedXContentRegistry = getQueryNamedXContentRegistry();
ragTool.setXContentRegistry(mockNamedXContentRegistry);
doAnswer(invocation -> {
SearchRequest searchRequest = invocation.getArgument(0);
Expand All @@ -255,7 +340,7 @@ public void testRunWithRuntimeExceptionDuringSearch() {
@Test
@SneakyThrows
public void testRunWithRuntimeExceptionDuringExecute() {
NamedXContentRegistry mockNamedXContentRegistry = getNeuralQueryNamedXContentRegistry();
NamedXContentRegistry mockNamedXContentRegistry = getQueryNamedXContentRegistry();
ragTool.setXContentRegistry(mockNamedXContentRegistry);

SearchResponse mockedSearchResponse = SearchResponse
Expand Down Expand Up @@ -311,7 +396,8 @@ public void testFactory() {
DEFAULT_K,
TEST_DOC_SIZE,
TEST_EMBEDDING_MODEL_ID,
TEST_INFERENCE_MODEL_ID
TEST_INFERENCE_MODEL_ID,
TEST_NEURAL_QUERY_TYPE
);

assertEquals(rAGtool1.getClient(), rAGtool2.getClient());
Expand All @@ -327,15 +413,28 @@ public void testFactory() {

}

private static NamedXContentRegistry getNeuralQueryNamedXContentRegistry() {
private static NamedXContentRegistry getQueryNamedXContentRegistry() {
QueryBuilder matchAllQueryBuilder = new MatchAllQueryBuilder();

List<NamedXContentRegistry.Entry> entries = new ArrayList<>();
NamedXContentRegistry.Entry entry = new NamedXContentRegistry.Entry(QueryBuilder.class, new ParseField("neural"), (p, c) -> {
p.map();
return matchAllQueryBuilder;
});
entries.add(entry);
NamedXContentRegistry.Entry neural_query_entry = new NamedXContentRegistry.Entry(
QueryBuilder.class,
new ParseField("neural"),
(p, c) -> {
p.map();
return matchAllQueryBuilder;
}
);
entries.add(neural_query_entry);
NamedXContentRegistry.Entry neural_sparse_query_entry = new NamedXContentRegistry.Entry(
QueryBuilder.class,
new ParseField("neural_sparse"),
(p, c) -> {
p.map();
return matchAllQueryBuilder;
}
);
entries.add(neural_sparse_query_entry);
NamedXContentRegistry mockNamedXContentRegistry = new NamedXContentRegistry(entries);
return mockNamedXContentRegistry;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ public void testCreateTool() {
assertEquals(TEST_K, tool.getK());
assertEquals("VectorDBTool", tool.getType());
assertEquals("VectorDBTool", tool.getName());
assertEquals("Use this tool to search data in OpenSearch index.", VectorDBTool.Factory.getInstance().getDefaultDescription());
assertEquals(VectorDBTool.DEFAULT_DESCRIPTION, VectorDBTool.Factory.getInstance().getDefaultDescription());
}

@Test
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
{
"took" : 688,
"timed_out" : false,
"_shards" : {
"total" : 1,
"successful" : 1,
"skipped" : 0,
"failed" : 0
},
"hits" : {
"total" : {
"value" : 2,
"relation" : "eq"
},
"max_score" : 30.0029,
"hits" : [
{
"_index" : "my-nlp-index",
"_id" : "1",
"_score" : 30.0029,
"_source" : {
"passage_text" : "Hello world",
"passage_embedding" : {
"!" : 0.8708904,
"door" : 0.8587369,
"hi" : 2.3929274,
"worlds" : 2.7839446,
"yes" : 0.75845814,
"##world" : 2.5432441,
"born" : 0.2682308,
"nothing" : 0.8625516,
"goodbye" : 0.17146169,
"greeting" : 0.96817183,
"birth" : 1.2788506,
"come" : 0.1623208,
"global" : 0.4371151,
"it" : 0.42951578,
"life" : 1.5750692,
"thanks" : 0.26481047,
"world" : 4.7300377,
"tiny" : 0.5462298,
"earth" : 2.6555297,
"universe" : 2.0308156,
"worldwide" : 1.3903781,
"hello" : 6.696973,
"so" : 0.20279501,
"?" : 0.67785245
},
"id" : "s1"
}
},
{
"_index" : "my-nlp-index",
"_id" : "2",
"_score" : 16.480486,
"_source" : {
"passage_text" : "Hi planet",
"passage_embedding" : {
"hi" : 4.338913,
"planets" : 2.7755864,
"planet" : 5.0969057,
"mars" : 1.7405145,
"earth" : 2.6087382,
"hello" : 3.3210192
},
"id" : "s2"
}
}
]
}
}
Loading