Skip to content

Commit

Permalink
[Feature] support nested query in neural sparse tool, vectorDB tool and RAG tool (#350)
Browse files Browse the repository at this point in the history

* support nested query in neural sparse

Signed-off-by: zhichao-aws <[email protected]>

* support nested in vector db tool

Signed-off-by: zhichao-aws <[email protected]>

* add test for RAG tool pass nested path

Signed-off-by: zhichao-aws <[email protected]>

* keep the 1st digit for score

Signed-off-by: zhichao-aws <[email protected]>

* lint

Signed-off-by: zhichao-aws <[email protected]>

---------

Signed-off-by: zhichao-aws <[email protected]>
  • Loading branch information
zhichao-aws committed Jul 16, 2024
1 parent 14d9ef2 commit 7a5d0d8
Show file tree
Hide file tree
Showing 9 changed files with 259 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,12 @@ public class NeuralSparseSearchTool extends AbstractRetrieverTool {
public static final String TYPE = "NeuralSparseSearchTool";
public static final String MODEL_ID_FIELD = "model_id";
public static final String EMBEDDING_FIELD = "embedding_field";
public static final String NESTED_PATH_FIELD = "nested_path";

private String name = TYPE;
private String modelId;
private String embeddingField;
private String nestedPath;

@Builder
public NeuralSparseSearchTool(
Expand All @@ -46,11 +48,13 @@ public NeuralSparseSearchTool(
String embeddingField,
String[] sourceFields,
Integer docSize,
String modelId
String modelId,
String nestedPath
) {
super(client, xContentRegistry, index, sourceFields, docSize);
this.modelId = modelId;
this.embeddingField = embeddingField;
this.nestedPath = nestedPath;
}

@Override
Expand All @@ -61,8 +65,29 @@ protected String getQueryBody(String queryText) {
);
}

Map<String, Object> queryBody = Map
.of("query", Map.of("neural_sparse", Map.of(embeddingField, Map.of("query_text", queryText, "model_id", modelId))));
Map<String, Object> queryBody;
if (StringUtils.isBlank(nestedPath)) {
queryBody = Map
.of("query", Map.of("neural_sparse", Map.of(embeddingField, Map.of("query_text", queryText, "model_id", modelId))));
} else {
queryBody = Map
.of(
"query",
Map
.of(
"nested",
Map
.of(
"path",
nestedPath,
"score_mode",
"max",
"query",
Map.of("neural_sparse", Map.of(embeddingField, Map.of("query_text", queryText, "model_id", modelId)))
)
)
);
}

try {
return AccessController.doPrivileged((PrivilegedExceptionAction<String>) () -> gson.toJson(queryBody));
Expand Down Expand Up @@ -99,6 +124,7 @@ public NeuralSparseSearchTool create(Map<String, Object> params) {
String[] sourceFields = gson.fromJson((String) params.get(SOURCE_FIELD), String[].class);
String modelId = (String) params.get(MODEL_ID_FIELD);
Integer docSize = params.containsKey(DOC_SIZE_FIELD) ? Integer.parseInt((String) params.get(DOC_SIZE_FIELD)) : DEFAULT_DOC_SIZE;
String nestedPath = (String) params.get(NESTED_PATH_FIELD);
return NeuralSparseSearchTool
.builder()
.client(client)
Expand All @@ -108,6 +134,7 @@ public NeuralSparseSearchTool create(Map<String, Object> params) {
.sourceFields(sourceFields)
.modelId(modelId)
.docSize(docSize)
.nestedPath(nestedPath)
.build();
}

Expand Down
34 changes: 31 additions & 3 deletions src/main/java/org/opensearch/agent/tools/VectorDBTool.java
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,13 @@ public class VectorDBTool extends AbstractRetrieverTool {
public static final String EMBEDDING_FIELD = "embedding_field";
public static final String K_FIELD = "k";
public static final Integer DEFAULT_K = 10;
public static final String NESTED_PATH_FIELD = "nested_path";

private String name = TYPE;
private String modelId;
private String embeddingField;
private Integer k;
private String nestedPath;

@Builder
public VectorDBTool(
Expand All @@ -53,12 +55,14 @@ public VectorDBTool(
String[] sourceFields,
Integer docSize,
String modelId,
Integer k
Integer k,
String nestedPath
) {
super(client, xContentRegistry, index, sourceFields, docSize);
this.modelId = modelId;
this.embeddingField = embeddingField;
this.k = k;
this.nestedPath = nestedPath;
}

@Override
Expand All @@ -69,8 +73,30 @@ protected String getQueryBody(String queryText) {
);
}

Map<String, Object> queryBody = Map
.of("query", Map.of("neural", Map.of(embeddingField, Map.of("query_text", queryText, "model_id", modelId, "k", k))));
Map<String, Object> queryBody;
if (StringUtils.isBlank(nestedPath)) {
queryBody = Map
.of("query", Map.of("neural", Map.of(embeddingField, Map.of("query_text", queryText, "model_id", modelId, "k", k))));

} else {
queryBody = Map
.of(
"query",
Map
.of(
"nested",
Map
.of(
"path",
nestedPath,
"score_mode",
"max",
"query",
Map.of("neural", Map.of(embeddingField, Map.of("query_text", queryText, "model_id", modelId, "k", k)))
)
)
);
}

try {
return AccessController.doPrivileged((PrivilegedExceptionAction<String>) () -> gson.toJson(queryBody));
Expand Down Expand Up @@ -108,6 +134,7 @@ public VectorDBTool create(Map<String, Object> params) {
String modelId = (String) params.get(MODEL_ID_FIELD);
Integer docSize = params.containsKey(DOC_SIZE_FIELD) ? Integer.parseInt((String) params.get(DOC_SIZE_FIELD)) : DEFAULT_DOC_SIZE;
Integer k = params.containsKey(K_FIELD) ? Integer.parseInt((String) params.get(K_FIELD)) : DEFAULT_K;
String nestedPath = (String) params.get(NESTED_PATH_FIELD);
return VectorDBTool
.builder()
.client(client)
Expand All @@ -118,6 +145,7 @@ public VectorDBTool create(Map<String, Object> params) {
.modelId(modelId)
.docSize(docSize)
.k(k)
.nestedPath(nestedPath)
.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ public class NeuralSparseSearchToolTests {
public static final String TEST_QUERY_TEXT = "123fsd23134sdfouh";
public static final String TEST_EMBEDDING_FIELD = "test embedding";
public static final String TEST_MODEL_ID = "123fsd23134";
public static final String TEST_NESTED_PATH = "nested_path";
private Map<String, Object> params = new HashMap<>();

@Before
Expand Down Expand Up @@ -60,6 +61,22 @@ public void testGetQueryBody() {
assertEquals("123fsd23134", queryBody.get("query").get("neural_sparse").get("test embedding").get("model_id"));
}

/**
 * Verifies the query body produced when a nested path is configured: the
 * top-level query must be a "nested" clause carrying the configured path and
 * a "max" score mode, and the wrapped query must be the usual neural_sparse
 * clause with the expected query text and model id.
 */
@Test
@SneakyThrows
public void testGetQueryBodyWithNestedPath() {
    params.put(NeuralSparseSearchTool.NESTED_PATH_FIELD, TEST_NESTED_PATH);
    String rawBody = NeuralSparseSearchTool.Factory.getInstance().create(params).getQueryBody(TEST_QUERY_TEXT);

    // Outer layer: {"query": {"nested": {...}}}
    Map<String, Map<String, Map<String, Object>>> parsedBody = gson.fromJson(rawBody, Map.class);
    Map<String, Object> nestedClause = parsedBody.get("query").get("nested");
    assertEquals("nested_path", nestedClause.get("path"));
    assertEquals("max", nestedClause.get("score_mode"));

    // Inner layer: the neural_sparse clause wrapped by the nested query.
    Map<String, Map<String, Map<String, String>>> wrappedQuery = (Map<String, Map<String, Map<String, String>>>) nestedClause
        .get("query");
    Map<String, String> embeddingClause = wrappedQuery.get("neural_sparse").get("test embedding");
    assertEquals("123fsd23134sdfouh", embeddingClause.get("query_text"));
    assertEquals("123fsd23134", embeddingClause.get("model_id"));
}

@Test
@SneakyThrows
public void testGetQueryBodyWithJsonObjectString() {
Expand Down Expand Up @@ -110,6 +127,11 @@ public void testCreateToolsParseParams() {
() -> NeuralSparseSearchTool.Factory.getInstance().create(Map.of(NeuralSparseSearchTool.MODEL_ID_FIELD, 123))
);

assertThrows(
ClassCastException.class,
() -> NeuralSparseSearchTool.Factory.getInstance().create(Map.of(NeuralSparseSearchTool.NESTED_PATH_FIELD, 123))
);

assertThrows(
JsonSyntaxException.class,
() -> NeuralSparseSearchTool.Factory.getInstance().create(Map.of(NeuralSparseSearchTool.SOURCE_FIELD, "123"))
Expand Down
10 changes: 9 additions & 1 deletion src/test/java/org/opensearch/agent/tools/RAGToolTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ public class RAGToolTests {
public static final String TEST_INFERENCE_MODEL_ID = "1234";
public static final String TEST_NEURAL_QUERY_TYPE = "neural";
public static final String TEST_NEURAL_SPARSE_QUERY_TYPE = "neural_sparse";
public static final String TEST_NESTED_PATH = "nested_path";

static public final NamedXContentRegistry TEST_XCONTENT_REGISTRY_FOR_NEURAL_QUERY = getQueryNamedXContentRegistry();
private RAGTool ragTool;
Expand Down Expand Up @@ -422,6 +423,7 @@ public void testFactoryNeuralQuery() {
assertEquals(factoryMock.getDefaultVersion(), null);
assertNotNull(RAGTool.Factory.getInstance());

params.put(VectorDBTool.NESTED_PATH_FIELD, TEST_NESTED_PATH);
RAGTool rAGtool1 = factoryMock.create(params);
VectorDBTool.Factory.getInstance().init(client, TEST_XCONTENT_REGISTRY_FOR_NEURAL_QUERY);
params.put(VectorDBTool.MODEL_ID_FIELD, TEST_EMBEDDING_MODEL_ID);
Expand All @@ -436,6 +438,7 @@ public void testFactoryNeuralQuery() {
assertEquals(rAGtool1.getQueryTool().getSourceFields(), rAGtool2.getQueryTool().getSourceFields());
assertEquals(rAGtool1.getXContentRegistry(), rAGtool2.getXContentRegistry());
assertEquals(rAGtool1.getQueryType(), rAGtool2.getQueryType());
assertEquals(((VectorDBTool) rAGtool1.getQueryTool()).getNestedPath(), ((VectorDBTool) rAGtool2.getQueryTool()).getNestedPath());
}

@Test
Expand All @@ -450,6 +453,8 @@ public void testFactoryNeuralSparseQuery() {
assertEquals(factoryMock.getDefaultType(), RAGTool.TYPE);
assertEquals(factoryMock.getDefaultVersion(), null);

params.put(NeuralSparseSearchTool.NESTED_PATH_FIELD, TEST_NESTED_PATH);
params.put("query_type", "neural_sparse");
RAGTool rAGtool1 = factoryMock.create(params);
NeuralSparseSearchTool.Factory.getInstance().init(client, TEST_XCONTENT_REGISTRY_FOR_NEURAL_QUERY);
NeuralSparseSearchTool queryTool = NeuralSparseSearchTool.Factory.getInstance().create(params);
Expand All @@ -463,7 +468,10 @@ public void testFactoryNeuralSparseQuery() {
assertEquals(rAGtool1.getQueryTool().getSourceFields(), rAGtool2.getQueryTool().getSourceFields());
assertEquals(rAGtool1.getXContentRegistry(), rAGtool2.getXContentRegistry());
assertEquals(rAGtool1.getQueryType(), rAGtool2.getQueryType());

assertEquals(
((NeuralSparseSearchTool) rAGtool1.getQueryTool()).getNestedPath(),
((NeuralSparseSearchTool) rAGtool2.getQueryTool()).getNestedPath()
);
}

private static NamedXContentRegistry getQueryNamedXContentRegistry() {
Expand Down
22 changes: 22 additions & 0 deletions src/test/java/org/opensearch/agent/tools/VectorDBToolTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ public class VectorDBToolTests {
public static final String TEST_EMBEDDING_FIELD = "test embedding";
public static final String TEST_MODEL_ID = "123fsd23134";
public static final Integer TEST_K = 123;
public static final String TEST_NESTED_PATH = "nested_path";
private Map<String, Object> params = new HashMap<>();

@Before
Expand Down Expand Up @@ -61,6 +62,22 @@ public void testGetQueryBody() {
assertEquals(123.0, queryBody.get("query").get("neural").get("test embedding").get("k"));
}

/**
 * Verifies the query body produced when a nested path is configured: the
 * top-level query must be a "nested" clause carrying the configured path and
 * a "max" score mode, and the wrapped query must be the usual neural clause
 * with the expected query text and model id.
 */
@Test
@SneakyThrows
public void testGetQueryBodyWithNestedPath() {
    params.put(VectorDBTool.NESTED_PATH_FIELD, TEST_NESTED_PATH);
    String rawBody = VectorDBTool.Factory.getInstance().create(params).getQueryBody(TEST_QUERY_TEXT);

    // Outer layer: {"query": {"nested": {...}}}
    Map<String, Map<String, Map<String, Object>>> parsedBody = gson.fromJson(rawBody, Map.class);
    Map<String, Object> nestedClause = parsedBody.get("query").get("nested");
    assertEquals("nested_path", nestedClause.get("path"));
    assertEquals("max", nestedClause.get("score_mode"));

    // Inner layer: the neural clause wrapped by the nested query.
    Map<String, Map<String, Map<String, String>>> wrappedQuery = (Map<String, Map<String, Map<String, String>>>) nestedClause
        .get("query");
    Map<String, String> embeddingClause = wrappedQuery.get("neural").get("test embedding");
    assertEquals("123fsd23134sdfouh", embeddingClause.get("query_text"));
    assertEquals("123fsd23134", embeddingClause.get("model_id"));
}

@Test
@SneakyThrows
public void testGetQueryBodyWithJsonObjectString() {
Expand Down Expand Up @@ -103,6 +120,11 @@ public void testCreateToolsParseParams() {

assertThrows(ClassCastException.class, () -> VectorDBTool.Factory.getInstance().create(Map.of(VectorDBTool.MODEL_ID_FIELD, 123)));

assertThrows(
ClassCastException.class,
() -> VectorDBTool.Factory.getInstance().create(Map.of(VectorDBTool.NESTED_PATH_FIELD, 123))
);

assertThrows(JsonSyntaxException.class, () -> VectorDBTool.Factory.getInstance().create(Map.of(VectorDBTool.SOURCE_FIELD, "123")));

// Although the value could be parsed as an integer, parameter values should always be Strings
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import static org.hamcrest.Matchers.allOf;
import static org.hamcrest.Matchers.containsString;
import static org.opensearch.ml.common.utils.StringUtils.gson;

import java.nio.file.Files;
import java.nio.file.Path;
Expand All @@ -22,6 +21,7 @@

public class NeuralSparseSearchToolIT extends BaseAgentToolsIT {
public static String TEST_INDEX_NAME = "test_index";
public static String TEST_NESTED_INDEX_NAME = "test_index_nested";

private String modelId;
private String registerAgentRequestBody;
Expand Down Expand Up @@ -64,12 +64,55 @@ private void prepareIndex() {
addDocToIndex(TEST_INDEX_NAME, "2", List.of("text", "embedding"), List.of("text doc 3", Map.of("test", 5, "a", 6)));
}

/**
 * Creates the nested test index and seeds it with three documents.
 *
 * The mapping declares "text" as a text field and "embedding" as a nested
 * object whose "sparse" property is of type rank_features. Documents get ids
 * "0".."2" and texts "text doc 1".."text doc 3"; each carries one token-weight
 * map under embedding.sparse.
 */
@SneakyThrows
private void prepareNestedIndex() {
    String nestedIndexConfiguration = "{\n"
        + " \"mappings\": {\n"
        + " \"properties\": {\n"
        + " \"text\": {\n"
        + " \"type\": \"text\"\n"
        + " },\n"
        + " \"embedding\": {\n"
        + " \"type\": \"nested\",\n"
        + " \"properties\":{\n"
        + " \"sparse\":{\n"
        + " \"type\":\"rank_features\"\n"
        + " }\n"
        + " }\n"
        + " }\n"
        + " }\n"
        + " }\n"
        + "}";
    createIndexWithConfiguration(TEST_NESTED_INDEX_NAME, nestedIndexConfiguration);

    // Token weights for documents 0, 1 and 2, in insertion order.
    List<Map<String, Integer>> tokenWeightsPerDoc = List
        .of(Map.of("hello", 1, "world", 2), Map.of("a", 3, "b", 4), Map.of("test", 5, "a", 6));
    for (int docIndex = 0; docIndex < tokenWeightsPerDoc.size(); docIndex++) {
        addDocToIndex(
            TEST_NESTED_INDEX_NAME,
            String.valueOf(docIndex),
            List.of("text", "embedding"),
            // Document text is numbered from 1 while the id is numbered from 0.
            List.of("text doc " + (docIndex + 1), Map.of("sparse", List.of(tokenWeightsPerDoc.get(docIndex))))
        );
    }
}

@Before
@SneakyThrows
public void setUp() {
super.setUp();
prepareModel();
prepareIndex();
prepareNestedIndex();
registerAgentRequestBody = Files
.readString(
Path
Expand Down Expand Up @@ -127,6 +170,23 @@ public void testNeuralSparseSearchToolInFlowAgent() {
);
}

/**
 * Runs the flow agent against the nested index: the registration request is
 * rewritten to set nested_path to "embedding", the embedding field to
 * "embedding.sparse" and the index to "test_index_nested", then the question
 * "a" is executed and the response is compared with the expected two hits
 * (ids "2" and "1", in that order).
 */
public void testNeuralSparseSearchToolInFlowAgent_withNestedIndex() {
    // Order matters: the same three substitutions, applied in the same sequence.
    String nestedAgentRequest = registerAgentRequestBody
        .replace("\"nested_path\": \"\"", "\"nested_path\": \"embedding\"")
        .replace("\"embedding_field\": \"embedding\"", "\"embedding_field\": \"embedding.sparse\"")
        .replace("\"index\": \"test_index\"", "\"index\": \"test_index_nested\"");

    String agentId = createAgent(nestedAgentRequest);
    String actualResponse = executeAgent(agentId, "{\"parameters\": {\"question\": \"a\"}}");

    String expectedResponse = "{\"_index\":\"test_index_nested\",\"_source\":{\"text\":\"text doc 3\"},\"_id\":\"2\",\"_score\":2.4136734}\n"
        + "{\"_index\":\"test_index_nested\",\"_source\":{\"text\":\"text doc 2\"},\"_id\":\"1\",\"_score\":1.2068367}\n";
    assertEquals("The agent execute response not equal with expected.", expectedResponse, actualResponse);
}

public void testNeuralSparseSearchToolInFlowAgent_withIllegalSourceField_thenGetEmptySource() {
String agentId = createAgent(registerAgentRequestBody.replace("text", "text2"));
String result = executeAgent(agentId, "{\"parameters\": {\"question\": \"a\"}}");
Expand Down
Loading

0 comments on commit 7a5d0d8

Please sign in to comment.