diff --git a/src/test/java/org/opensearch/agent/tools/AbstractRetrieverToolTests.java b/src/test/java/org/opensearch/agent/tools/AbstractRetrieverToolTests.java index 5e0faa9c..5f5803f0 100644 --- a/src/test/java/org/opensearch/agent/tools/AbstractRetrieverToolTests.java +++ b/src/test/java/org/opensearch/agent/tools/AbstractRetrieverToolTests.java @@ -5,14 +5,10 @@ package org.opensearch.agent.tools; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertThrows; -import static org.junit.Assert.assertTrue; +import static org.junit.Assert.*; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; +import static org.mockito.Mockito.*; +import static org.opensearch.agent.tools.AbstractRetrieverTool.DEFAULT_DESCRIPTION; import java.io.InputStream; import java.nio.charset.StandardCharsets; @@ -23,6 +19,7 @@ import org.junit.Before; import org.junit.Test; +import org.mockito.ArgumentCaptor; import org.mockito.Mockito; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; @@ -32,6 +29,7 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.core.xcontent.DeprecationHandler; import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.search.SearchModule; import lombok.SneakyThrows; @@ -180,4 +178,59 @@ public void testValidate() { assertFalse(mockedImpl.validate(new HashMap<>())); assertFalse(mockedImpl.validate(null)); } + + @Test + public void testGetAttributes() { + assertEquals(mockedImpl.getVersion(), null); + assertEquals(mockedImpl.getIndex(), TEST_INDEX); + assertEquals(mockedImpl.getDocSize(), TEST_DOC_SIZE); + assertEquals(mockedImpl.getSourceFields(), TEST_SOURCE_FIELDS); + assertEquals(mockedImpl.getQueryBody(TEST_QUERY), TEST_QUERY); + } + + @Test + public void testGetQueryBodySuccess() { + assertEquals(mockedImpl.getQueryBody(TEST_QUERY), TEST_QUERY); + } + + @Test + @SneakyThrows + public void testRunWithRuntimeException() { + Client client = mock(Client.class); + mockedImpl.setClient(client); + ActionListener listener = mock(ActionListener.class); + doAnswer(invocation -> { + SearchRequest searchRequest = invocation.getArgument(0); + assertEquals((long) TEST_DOC_SIZE, (long) searchRequest.source().size()); + ActionListener actionListener = invocation.getArgument(1); + actionListener.onFailure(new RuntimeException("Failed to search index")); + return null; + }).when(client).search(any(), any()); + mockedImpl.run(Map.of(AbstractRetrieverTool.INPUT_FIELD, "hello world"), listener); + verify(listener).onFailure(any(RuntimeException.class)); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener).onFailure(argumentCaptor.capture()); + assertEquals("Failed to search index", argumentCaptor.getValue().getMessage()); + } + + @Test + public void testFactory() { + // Create a mock object of the abstract Factory class + Client client = mock(Client.class); + AbstractRetrieverTool.Factory factoryMock = new AbstractRetrieverTool.Factory<>() { + public PPLTool create(Map params) { + return null; + } + }; + + factoryMock.init(client, TEST_XCONTENT_REGISTRY_FOR_QUERY); + + assertNotNull(factoryMock.client); + assertNotNull(factoryMock.xContentRegistry); + assertEquals(client, factoryMock.client); + assertEquals(TEST_XCONTENT_REGISTRY_FOR_QUERY, factoryMock.xContentRegistry); + + String defaultDescription = factoryMock.getDefaultDescription(); + assertEquals(DEFAULT_DESCRIPTION, defaultDescription); + } }