From b27a66497efe24017999ebeed4728f241d8916ed Mon Sep 17 00:00:00 2001 From: "opensearch-trigger-bot[bot]" <98922864+opensearch-trigger-bot[bot]@users.noreply.github.com> Date: Thu, 28 Dec 2023 16:37:41 -0800 Subject: [PATCH] Add RAGTool (#78) (#79) * increase AbstractRetrieverToolTests code coverage * add RAGTool * Change exception handling from input field in RAGTool --------- (cherry picked from commit 24d5cf985dec87d2f0797759c2765a18d9d6cbd9) Signed-off-by: Mingshi Liu Signed-off-by: github-actions[bot] Co-authored-by: github-actions[bot] Signed-off-by: yuye-aws --- .../org/opensearch/agent/tools/RAGTool.java | 279 +++++++++++++++ .../opensearch/agent/tools/RAGToolTests.java | 330 ++++++++++++++++++ 2 files changed, 609 insertions(+) create mode 100644 src/main/java/org/opensearch/agent/tools/RAGTool.java create mode 100644 src/test/java/org/opensearch/agent/tools/RAGToolTests.java diff --git a/src/main/java/org/opensearch/agent/tools/RAGTool.java b/src/main/java/org/opensearch/agent/tools/RAGTool.java new file mode 100644 index 00000000..7c9c26c5 --- /dev/null +++ b/src/main/java/org/opensearch/agent/tools/RAGTool.java @@ -0,0 +1,279 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.tools; + +import static org.apache.commons.lang3.StringEscapeUtils.escapeJson; +import static org.opensearch.agent.tools.VectorDBTool.DEFAULT_K; +import static org.opensearch.ml.common.utils.StringUtils.gson; +import static org.opensearch.ml.common.utils.StringUtils.toJson; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.opensearch.action.ActionRequest; +import org.opensearch.client.Client; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.output.model.ModelTensor; +import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.ml.common.spi.tools.Parser; +import org.opensearch.ml.common.spi.tools.ToolAnnotation; +import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; +import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; + +import com.google.gson.Gson; +import com.google.gson.JsonObject; + +import lombok.Builder; +import lombok.Getter; +import lombok.Setter; +import lombok.extern.log4j.Log4j2; + +/** + * This tool supports retrieving helpful information to optimize the output of the large language model to answer questions.. + */ +@Log4j2 +@Setter +@Getter +@ToolAnnotation(RAGTool.TYPE) +public class RAGTool extends AbstractRetrieverTool { + public static final String TYPE = "RAGTool"; + public static String DEFAULT_DESCRIPTION = + "Use this tool to retrieve helpful information to optimize the output of the large language model to answer questions."; + public static final String INFERENCE_MODEL_ID_FIELD = "inference_model_id"; + public static final String EMBEDDING_MODEL_ID_FIELD = "embedding_model_id"; + public static final String EMBEDDING_FIELD = "embedding_field"; + public static final String OUTPUT_FIELD = "output_field"; + private String name = TYPE; + private String description = DEFAULT_DESCRIPTION; + private Client client; + private String inferenceModelId; + private NamedXContentRegistry xContentRegistry; + private String index; + private String embeddingField; + private String[] sourceFields; + private String embeddingModelId; + private Integer docSize; + private Integer k; + @Setter + private Parser inputParser; + @Setter + private Parser outputParser; + + @Builder + public RAGTool( + Client client, + NamedXContentRegistry xContentRegistry, + String index, + String embeddingField, + String[] sourceFields, + Integer k, + Integer docSize, + String embeddingModelId, + String inferenceModelId + ) { + super(client, xContentRegistry, index, sourceFields, docSize); + this.client = client; + this.xContentRegistry = xContentRegistry; + this.index = index; + this.embeddingField = embeddingField; + this.sourceFields = sourceFields; + this.embeddingModelId = embeddingModelId; + this.docSize = docSize == null ? DEFAULT_DOC_SIZE : docSize; + this.k = k == null ? DEFAULT_K : k; + this.inferenceModelId = inferenceModelId; + + outputParser = new Parser() { + @Override + public Object parse(Object o) { + List mlModelOutputs = (List) o; + return mlModelOutputs.get(0).getMlModelTensors().get(0).getDataAsMap().get("response"); + } + }; + } + + // getQueryBody is not used in RAGTool + @Override + protected String getQueryBody(String queryText) { + return queryText; + } + + @Override + public void run(Map parameters, ActionListener listener) { + String input = null; + + if (!this.validate(parameters)) { + throw new IllegalArgumentException("[" + INPUT_FIELD + "] is null or empty, can not process it."); + } + + try { + String question = parameters.get(INPUT_FIELD); + input = gson.fromJson(question, String.class); + } catch (Exception e) { + log.error("Failed to read question from " + INPUT_FIELD, e); + listener.onFailure(new IllegalArgumentException("Failed to read question from " + INPUT_FIELD)); + return; + } + + Map params = new HashMap<>(); + VectorDBTool.Factory.getInstance().init(client, xContentRegistry); + params.put(VectorDBTool.INDEX_FIELD, this.index); + params.put(VectorDBTool.EMBEDDING_FIELD, this.embeddingField); + params.put(VectorDBTool.SOURCE_FIELD, gson.toJson(this.sourceFields)); + params.put(VectorDBTool.MODEL_ID_FIELD, this.embeddingModelId); + params.put(VectorDBTool.DOC_SIZE_FIELD, String.valueOf(this.docSize)); + params.put(VectorDBTool.K_FIELD, String.valueOf(this.k)); + VectorDBTool vectorDBTool = VectorDBTool.Factory.getInstance().create(params); + + String embeddingInput = input; + ActionListener actionListener = ActionListener.wrap(r -> { + T vectorDBToolOutput; + + if (r.equals("Can not get any match from search result.")) { + vectorDBToolOutput = (T) ""; + } else { + Gson gson = new Gson(); + String[] hits = r.toString().split("\n"); + + StringBuilder resultBuilder = new StringBuilder(); + for (String hit : hits) { + JsonObject jsonObject = gson.fromJson(hit, JsonObject.class); + String id = jsonObject.get("_id").getAsString(); + JsonObject source = jsonObject.getAsJsonObject("_source"); + + resultBuilder.append("_id: ").append(id).append("\n"); + resultBuilder.append("_source: ").append(source.toString()).append("\n"); + } + + vectorDBToolOutput = (T) gson.toJson(resultBuilder.toString()); + } + + Map tmpParameters = new HashMap<>(); + tmpParameters.putAll(parameters); + + if (vectorDBToolOutput instanceof List + && !((List) vectorDBToolOutput).isEmpty() + && ((List) vectorDBToolOutput).get(0) instanceof ModelTensors) { + ModelTensors tensors = (ModelTensors) ((List) vectorDBToolOutput).get(0); + Object response = tensors.getMlModelTensors().get(0).getDataAsMap().get("response"); + tmpParameters.put(OUTPUT_FIELD, response + ""); + } else if (vectorDBToolOutput instanceof ModelTensor) { + tmpParameters.put(OUTPUT_FIELD, escapeJson(toJson(((ModelTensor) vectorDBToolOutput).getDataAsMap()))); + } else { + if (vectorDBToolOutput instanceof String) { + tmpParameters.put(OUTPUT_FIELD, (String) vectorDBToolOutput); + } else { + tmpParameters.put(OUTPUT_FIELD, escapeJson(toJson(vectorDBToolOutput.toString()))); + } + } + + RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(tmpParameters).build(); + MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build(); + ActionRequest request = new MLPredictionTaskRequest(inferenceModelId, mlInput, null); + + client.execute(MLPredictionTaskAction.INSTANCE, request, ActionListener.wrap(resp -> { + ModelTensorOutput modelTensorOutput = (ModelTensorOutput) resp.getOutput(); + modelTensorOutput.getMlModelOutputs(); + if (outputParser == null) { + listener.onResponse((T) modelTensorOutput.getMlModelOutputs()); + } else { + listener.onResponse((T) outputParser.parse(modelTensorOutput.getMlModelOutputs())); + } + }, e -> { + log.error("Failed to run model " + inferenceModelId, e); + listener.onFailure(e); + })); + }, e -> { + log.error("Failed to search index.", e); + listener.onFailure(e); + }); + vectorDBTool.run(Map.of(VectorDBTool.INPUT_FIELD, embeddingInput), actionListener); + + } + + @Override + public String getType() { + return TYPE; + } + + @Override + public String getName() { + return this.name; + } + + @Override + public void setName(String s) { + this.name = s; + } + + @Override + public boolean validate(Map parameters) { + if (parameters == null || parameters.size() == 0) { + return false; + } + String question = parameters.get(INPUT_FIELD); + return question != null && !question.trim().isEmpty(); + } + + /** + * Factory class to create RAGTool + */ + public static class Factory extends AbstractRetrieverTool.Factory { + private Client client; + private NamedXContentRegistry xContentRegistry; + + private static Factory INSTANCE; + + public static Factory getInstance() { + if (INSTANCE != null) { + return INSTANCE; + } + synchronized (RAGTool.class) { + if (INSTANCE != null) { + return INSTANCE; + } + INSTANCE = new Factory(); + return INSTANCE; + } + } + + public void init(Client client, NamedXContentRegistry xContentRegistry) { + this.client = client; + this.xContentRegistry = xContentRegistry; + } + + @Override + public RAGTool create(Map params) { + String embeddingModelId = (String) params.get(EMBEDDING_MODEL_ID_FIELD); + String index = (String) params.get(INDEX_FIELD); + String embeddingField = (String) params.get(EMBEDDING_FIELD); + String[] sourceFields = gson.fromJson((String) params.get(SOURCE_FIELD), String[].class); + String inferenceModelId = (String) params.get(INFERENCE_MODEL_ID_FIELD); + Integer docSize = params.containsKey(DOC_SIZE_FIELD) ? Integer.parseInt((String) params.get(DOC_SIZE_FIELD)) : 2; + return RAGTool + .builder() + .client(client) + .xContentRegistry(xContentRegistry) + .index(index) + .embeddingField(embeddingField) + .sourceFields(sourceFields) + .embeddingModelId(embeddingModelId) + .docSize(docSize) + .inferenceModelId(inferenceModelId) + .build(); + } + + @Override + public String getDefaultDescription() { + return DEFAULT_DESCRIPTION; + } + } +} diff --git a/src/test/java/org/opensearch/agent/tools/RAGToolTests.java b/src/test/java/org/opensearch/agent/tools/RAGToolTests.java new file mode 100644 index 00000000..79bfcebf --- /dev/null +++ b/src/test/java/org/opensearch/agent/tools/RAGToolTests.java @@ -0,0 +1,330 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.tools; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.opensearch.agent.tools.AbstractRetrieverTool.*; +import static org.opensearch.agent.tools.AbstractRetrieverToolTests.*; +import static org.opensearch.agent.tools.VectorDBTool.DEFAULT_K; +import static org.opensearch.ml.common.utils.StringUtils.gson; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.charset.StandardCharsets; +import java.util.*; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.client.Client; +import org.opensearch.common.xcontent.json.JsonXContent; +import org.opensearch.core.ParseField; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.DeprecationHandler; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.index.query.MatchAllQueryBuilder; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.ml.common.output.model.ModelTensor; +import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.ml.common.spi.tools.Parser; +import org.opensearch.ml.common.transport.MLTaskResponse; +import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; +import org.opensearch.ml.repackage.com.google.common.collect.ImmutableMap; + +import lombok.SneakyThrows; + +public class RAGToolTests { + public static final String TEST_QUERY_TEXT = "hello?"; + public static final String TEST_EMBEDDING_FIELD = "test_embedding"; + public static final String TEST_EMBEDDING_MODEL_ID = "1234"; + public static final String TEST_INFERENCE_MODEL_ID = "1234"; + + public static final String TEST_NEURAL_QUERY = "{\"query\":{\"neural\":{\"" + + TEST_EMBEDDING_FIELD + + "\":{\"query_text\":\"" + + TEST_QUERY_TEXT + + "\",\"model_id\":\"" + + TEST_EMBEDDING_MODEL_ID + + "\",\"k\":" + + DEFAULT_K + + "}}}" + + " }";; + private RAGTool ragTool; + private String mockedSearchResponseString; + private String mockedEmptySearchResponseString; + @Mock + private Parser mockOutputParser; + @Mock + private Client client; + @Mock + private ActionListener listener; + private Map params; + + @Before + @SneakyThrows + public void setup() { + try (InputStream searchResponseIns = AbstractRetrieverTool.class.getResourceAsStream("retrieval_tool_search_response.json")) { + if (searchResponseIns != null) { + mockedSearchResponseString = new String(searchResponseIns.readAllBytes(), StandardCharsets.UTF_8); + } + } + try (InputStream searchResponseIns = AbstractRetrieverTool.class.getResourceAsStream("retrieval_tool_empty_search_response.json")) { + if (searchResponseIns != null) { + mockedEmptySearchResponseString = new String(searchResponseIns.readAllBytes(), StandardCharsets.UTF_8); + } + } + + client = mock(Client.class); + listener = mock(ActionListener.class); + RAGTool.Factory.getInstance().init(client, TEST_XCONTENT_REGISTRY_FOR_QUERY); + + params = new HashMap<>(); + params.put(RAGTool.INDEX_FIELD, TEST_INDEX); + params.put(RAGTool.EMBEDDING_FIELD, TEST_EMBEDDING_FIELD); + params.put(RAGTool.SOURCE_FIELD, gson.toJson(TEST_SOURCE_FIELDS)); + params.put(RAGTool.EMBEDDING_MODEL_ID_FIELD, TEST_EMBEDDING_MODEL_ID); + params.put(RAGTool.INFERENCE_MODEL_ID_FIELD, TEST_INFERENCE_MODEL_ID); + params.put(RAGTool.DOC_SIZE_FIELD, AbstractRetrieverToolTests.TEST_DOC_SIZE.toString()); + params.put(VectorDBTool.K_FIELD, DEFAULT_K); + ragTool = RAGTool.Factory.getInstance().create(params); + } + + @Test + public void testValidate() { + assertTrue(ragTool.validate(Map.of(AbstractRetrieverTool.INPUT_FIELD, "hi"))); + assertFalse(ragTool.validate(Map.of(AbstractRetrieverTool.INPUT_FIELD, ""))); + assertFalse(ragTool.validate(Map.of(AbstractRetrieverTool.INPUT_FIELD, " "))); + assertFalse(ragTool.validate(Map.of("test", " "))); + assertFalse(ragTool.validate(new HashMap<>())); + assertFalse(ragTool.validate(null)); + } + + @Test + public void testGetAttributes() { + assertEquals(ragTool.getVersion(), null); + assertEquals(ragTool.getType(), RAGTool.TYPE); + assertEquals(ragTool.getIndex(), TEST_INDEX); + assertEquals(ragTool.getDocSize(), TEST_DOC_SIZE); + assertEquals(ragTool.getSourceFields(), TEST_SOURCE_FIELDS); + assertEquals(ragTool.getEmbeddingField(), TEST_EMBEDDING_FIELD); + assertEquals(ragTool.getEmbeddingModelId(), TEST_EMBEDDING_MODEL_ID); + assertEquals(ragTool.getK(), DEFAULT_K); + assertEquals(ragTool.getInferenceModelId(), TEST_INFERENCE_MODEL_ID); + } + + @Test + public void testSetName() { + assertEquals(ragTool.getName(), RAGTool.TYPE); + ragTool.setName("test-tool"); + assertEquals(ragTool.getName(), "test-tool"); + } + + @Test + public void testGetQueryBodySuccess() { + assertEquals(ragTool.getQueryBody(TEST_QUERY_TEXT), TEST_QUERY_TEXT); + } + + @Test + public void testOutputParser() throws IOException { + + NamedXContentRegistry mockNamedXContentRegistry = getNeuralQueryNamedXContentRegistry(); + ragTool.setXContentRegistry(mockNamedXContentRegistry); + + ModelTensorOutput mlModelTensorOutput = getMlModelTensorOutput(); + SearchResponse mockedSearchResponse = SearchResponse + .fromXContent( + JsonXContent.jsonXContent + .createParser(NamedXContentRegistry.EMPTY, DeprecationHandler.IGNORE_DEPRECATIONS, mockedSearchResponseString) + ); + + doAnswer(invocation -> { + SearchRequest searchRequest = invocation.getArgument(0); + assertEquals((long) TEST_DOC_SIZE, (long) searchRequest.source().size()); + ActionListener listener = invocation.getArgument(1); + listener.onResponse(mockedSearchResponse); + return null; + }).when(client).search(any(), any()); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(), any()); + + ragTool.setOutputParser(mockOutputParser); + ragTool.run(Map.of(INPUT_FIELD, "hello?"), listener); + + verify(client).search(any(), any()); + verify(client).execute(any(), any(), any()); + } + + @Test + public void testRunWithEmptySearchResponse() throws IOException { + NamedXContentRegistry mockNamedXContentRegistry = getNeuralQueryNamedXContentRegistry(); + ragTool.setXContentRegistry(mockNamedXContentRegistry); + + ModelTensorOutput mlModelTensorOutput = getMlModelTensorOutput(); + SearchResponse mockedEmptySearchResponse = SearchResponse + .fromXContent( + JsonXContent.jsonXContent + .createParser(NamedXContentRegistry.EMPTY, DeprecationHandler.IGNORE_DEPRECATIONS, mockedEmptySearchResponseString) + ); + + doAnswer(invocation -> { + SearchRequest searchRequest = invocation.getArgument(0); + assertEquals((long) TEST_DOC_SIZE, (long) searchRequest.source().size()); + ActionListener listener = invocation.getArgument(1); + listener.onResponse(mockedEmptySearchResponse); + return null; + }).when(client).search(any(), any()); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(), any()); + ragTool.run(Map.of(INPUT_FIELD, "hello?"), listener); + verify(client).search(any(), any()); + verify(client).execute(any(), any(), any()); + } + + @Test + @SneakyThrows + public void testRunWithRuntimeExceptionDuringSearch() { + NamedXContentRegistry mockNamedXContentRegistry = getNeuralQueryNamedXContentRegistry(); + ragTool.setXContentRegistry(mockNamedXContentRegistry); + doAnswer(invocation -> { + SearchRequest searchRequest = invocation.getArgument(0); + assertEquals((long) TEST_DOC_SIZE, (long) searchRequest.source().size()); + ActionListener actionListener = invocation.getArgument(1); + actionListener.onFailure(new RuntimeException("Failed to search index")); + return null; + }).when(client).search(any(), any()); + ragTool.run(Map.of(INPUT_FIELD, "hello?"), listener); + verify(listener).onFailure(any(RuntimeException.class)); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener).onFailure(argumentCaptor.capture()); + assertEquals("Failed to search index", argumentCaptor.getValue().getMessage()); + } + + @Test + @SneakyThrows + public void testRunWithRuntimeExceptionDuringExecute() { + NamedXContentRegistry mockNamedXContentRegistry = getNeuralQueryNamedXContentRegistry(); + ragTool.setXContentRegistry(mockNamedXContentRegistry); + + SearchResponse mockedSearchResponse = SearchResponse + .fromXContent( + JsonXContent.jsonXContent + .createParser(NamedXContentRegistry.EMPTY, DeprecationHandler.IGNORE_DEPRECATIONS, mockedSearchResponseString) + ); + + doAnswer(invocation -> { + SearchRequest searchRequest = invocation.getArgument(0); + assertEquals((long) TEST_DOC_SIZE, (long) searchRequest.source().size()); + ActionListener listener = invocation.getArgument(1); + listener.onResponse(mockedSearchResponse); + return null; + }).when(client).search(any(), any()); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onFailure(new RuntimeException("Failed to run model " + TEST_INFERENCE_MODEL_ID)); + return null; + }).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(), any()); + + ragTool.run(Map.of(INPUT_FIELD, "hello?"), listener); + verify(listener).onFailure(any(RuntimeException.class)); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener).onFailure(argumentCaptor.capture()); + assertEquals("Failed to run model " + TEST_INFERENCE_MODEL_ID, argumentCaptor.getValue().getMessage()); + } + + @Test(expected = IllegalArgumentException.class) + public void testRunWithEmptyInput() { + ActionListener listener = mock(ActionListener.class); + ragTool.run(Map.of(INPUT_FIELD, ""), listener); + } + + @Test + public void testRunWithMalformedInput() throws IOException { + ActionListener listener = mock(ActionListener.class); + ragTool.run(Map.of(INPUT_FIELD, "{hello?"), listener); + verify(listener).onFailure(any(RuntimeException.class)); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener).onFailure(argumentCaptor.capture()); + assertEquals("Failed to read question from " + INPUT_FIELD, argumentCaptor.getValue().getMessage()); + + } + + @Test + public void testFactory() { + RAGTool.Factory factoryMock = new RAGTool.Factory(); + RAGTool.Factory.getInstance().init(client, TEST_XCONTENT_REGISTRY_FOR_QUERY); + factoryMock.init(client, TEST_XCONTENT_REGISTRY_FOR_QUERY); + + String defaultDescription = factoryMock.getDefaultDescription(); + assertEquals(RAGTool.DEFAULT_DESCRIPTION, defaultDescription); + assertNotNull(RAGTool.Factory.getInstance()); + RAGTool rAGtool1 = factoryMock.create(params); + + RAGTool rAGtool2 = new RAGTool( + client, + TEST_XCONTENT_REGISTRY_FOR_QUERY, + TEST_INDEX, + TEST_EMBEDDING_FIELD, + TEST_SOURCE_FIELDS, + DEFAULT_K, + TEST_DOC_SIZE, + TEST_EMBEDDING_MODEL_ID, + TEST_INFERENCE_MODEL_ID + ); + + assertEquals(rAGtool1.getClient(), rAGtool2.getClient()); + assertEquals(rAGtool1.getK(), rAGtool2.getK()); + assertEquals(rAGtool1.getInferenceModelId(), rAGtool2.getInferenceModelId()); + assertEquals(rAGtool1.getName(), rAGtool2.getName()); + assertEquals(rAGtool1.getDocSize(), rAGtool2.getDocSize()); + assertEquals(rAGtool1.getIndex(), rAGtool2.getIndex()); + assertEquals(rAGtool1.getEmbeddingModelId(), rAGtool2.getEmbeddingModelId()); + assertEquals(rAGtool1.getEmbeddingField(), rAGtool2.getEmbeddingField()); + assertEquals(rAGtool1.getSourceFields(), rAGtool2.getSourceFields()); + assertEquals(rAGtool1.getXContentRegistry(), rAGtool2.getXContentRegistry()); + + } + + private static NamedXContentRegistry getNeuralQueryNamedXContentRegistry() { + QueryBuilder matchAllQueryBuilder = new MatchAllQueryBuilder(); + + List entries = new ArrayList<>(); + NamedXContentRegistry.Entry entry = new NamedXContentRegistry.Entry(QueryBuilder.class, new ParseField("neural"), (p, c) -> { + p.map(); + return matchAllQueryBuilder; + }); + entries.add(entry); + NamedXContentRegistry mockNamedXContentRegistry = new NamedXContentRegistry(entries); + return mockNamedXContentRegistry; + } + + private static ModelTensorOutput getMlModelTensorOutput() { + ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("thought", "thought 1", "action", "action1")).build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + return mlModelTensorOutput; + } +}