From 41b33b9628ede91e8bef93b1bff65f2636d6df39 Mon Sep 17 00:00:00 2001 From: yuye-aws Date: Tue, 12 Dec 2023 13:55:44 +0800 Subject: [PATCH] unit tests for search index tool Signed-off-by: yuye-aws --- .../ml/engine/tools/SearchIndexTool.java | 8 +- .../ml/engine/tools/SearchIndexToolTests.java | 104 ++++++++++++++++++ 2 files changed, 107 insertions(+), 5 deletions(-) create mode 100644 ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/SearchIndexToolTests.java diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/SearchIndexTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/SearchIndexTool.java index 449c4e2d15..78dd030ad6 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/SearchIndexTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/SearchIndexTool.java @@ -8,6 +8,7 @@ import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; +import java.io.IOException; import java.security.AccessController; import java.security.PrivilegedExceptionAction; import java.util.HashMap; @@ -32,7 +33,6 @@ import com.google.gson.Gson; import com.google.gson.JsonObject; -import lombok.Builder; import lombok.Getter; import lombok.Setter; import lombok.extern.log4j.Log4j2; @@ -65,7 +65,6 @@ public class SearchIndexTool implements Tool { gson = new Gson(); } - @Builder public SearchIndexTool(Client client, NamedXContentRegistry xContentRegistry) { this.client = client; this.xContentRegistry = xContentRegistry; @@ -147,9 +146,8 @@ public void run(Map parameters, ActionListener listener) } else { client.search(searchRequest, actionListener); } - - } catch (Exception e) { - listener.onFailure(e); + } catch (IOException e) { + throw new RuntimeException(e); } } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/SearchIndexToolTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/SearchIndexToolTests.java new file mode 100644 index 0000000000..1229c6d9e9 --- /dev/null +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/SearchIndexToolTests.java @@ -0,0 +1,104 @@ +package org.opensearch.ml.engine.tools; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.*; + +import java.util.List; +import java.util.Map; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mockito; +import org.opensearch.client.Client; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.common.Strings; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.transport.connector.MLConnectorSearchAction; +import org.opensearch.ml.common.transport.model.MLModelSearchAction; +import org.opensearch.search.SearchModule; + +import lombok.SneakyThrows; + +public class SearchIndexToolTests { + static public final NamedXContentRegistry TEST_XCONTENT_REGISTRY_FOR_QUERY = new NamedXContentRegistry( + new SearchModule(Settings.EMPTY, List.of()).getNamedXContents() + ); + + private Client client; + private SearchIndexTool mockedSearchIndexTool; + + @Before + @SneakyThrows + public void setup() { + client = mock(Client.class); + mockedSearchIndexTool = Mockito + .mock( + SearchIndexTool.class, + Mockito.withSettings().useConstructor(client, TEST_XCONTENT_REGISTRY_FOR_QUERY).defaultAnswer(Mockito.CALLS_REAL_METHODS) + ); + } + + @Test + @SneakyThrows + public void testGetType() { + String type = mockedSearchIndexTool.getType(); + assertFalse(Strings.isNullOrEmpty(type)); + assertEquals("SearchIndexTool", type); + } + + @Test + @SneakyThrows + public void testValidate() { + Map parameters = Map.of("input", "{}"); + assertTrue(mockedSearchIndexTool.validate(parameters)); + } + + @Test + @SneakyThrows + public void testValidateWithEmptyInput() { + Map parameters = Map.of(); + assertFalse(mockedSearchIndexTool.validate(parameters)); + } + + @Test + public void testRunWithNormalIndex() { + String inputString = "{\"index\": \"test-index\", \"query\": {\"size\": 10000}}"; + Map parameters = Map.of("input", inputString); + mockedSearchIndexTool.run(parameters, null); + Mockito.verify(client, times(1)).search(any(), any()); + Mockito.verify(client, Mockito.never()).execute(any(), any(), any()); + } + + @Test + public void testRunWithConnectorIndex() { + String inputString = "{\"index\": \".plugins-ml-connector\", \"query\": {\"size\": 10000}}"; + Map parameters = Map.of("input", inputString); + mockedSearchIndexTool.run(parameters, null); + Mockito.verify(client, never()).search(any(), any()); + Mockito.verify(client, times(1)).execute(eq(MLConnectorSearchAction.INSTANCE), any(), any()); + } + + @Test + public void testRunWithModelIndex() { + String inputString = "{\"index\": \".plugins-ml-model\", \"query\": {\"size\": 10000}}"; + Map parameters = Map.of("input", inputString); + mockedSearchIndexTool.run(parameters, null); + Mockito.verify(client, never()).search(any(), any()); + Mockito.verify(client, times(1)).execute(eq(MLModelSearchAction.INSTANCE), any(), any()); + } + + @Test + @SneakyThrows + public void testRunWithEmptyQuery() { + String inputString = "{\"index\": \"test_index\"}"; + Map parameters = Map.of("input", inputString); + Exception exception = assertThrows(IllegalArgumentException.class, () -> mockedSearchIndexTool.run(parameters, null)); + assertEquals("[query] is null or empty, can not process it.", exception.getMessage()); + Mockito.verify(client, Mockito.never()).execute(any(), any(), any()); + Mockito.verify(client, Mockito.never()).search(any(), any()); + } +}