Skip to content

Commit

Permalink
unit tests for search index tool
Browse files Browse the repository at this point in the history
Signed-off-by: yuye-aws <[email protected]>
  • Loading branch information
yuye-aws committed Dec 12, 2023
1 parent 3987556 commit 41b33b9
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX;
import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX;

import java.io.IOException;
import java.security.AccessController;
import java.security.PrivilegedExceptionAction;
import java.util.HashMap;
Expand All @@ -32,7 +33,6 @@
import com.google.gson.Gson;
import com.google.gson.JsonObject;

import lombok.Builder;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.log4j.Log4j2;
Expand Down Expand Up @@ -65,7 +65,6 @@ public class SearchIndexTool implements Tool {
gson = new Gson();
}

@Builder
public SearchIndexTool(Client client, NamedXContentRegistry xContentRegistry) {
this.client = client;
this.xContentRegistry = xContentRegistry;
Expand Down Expand Up @@ -147,9 +146,8 @@ public <T> void run(Map<String, String> parameters, ActionListener<T> listener)
} else {
client.search(searchRequest, actionListener);
}

} catch (Exception e) {
listener.onFailure(e);
} catch (IOException e) {
throw new RuntimeException(e);
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
package org.opensearch.ml.engine.tools;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertThrows;
import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.*;

import java.util.List;
import java.util.Map;

import org.junit.Before;
import org.junit.Test;
import org.mockito.Mockito;
import org.opensearch.client.Client;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.common.Strings;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ml.common.transport.connector.MLConnectorSearchAction;
import org.opensearch.ml.common.transport.model.MLModelSearchAction;
import org.opensearch.search.SearchModule;

import lombok.SneakyThrows;

public class SearchIndexToolTests {
static public final NamedXContentRegistry TEST_XCONTENT_REGISTRY_FOR_QUERY = new NamedXContentRegistry(
new SearchModule(Settings.EMPTY, List.of()).getNamedXContents()
);

private Client client;
private SearchIndexTool mockedSearchIndexTool;

@Before
@SneakyThrows
public void setup() {
client = mock(Client.class);
mockedSearchIndexTool = Mockito
.mock(
SearchIndexTool.class,
Mockito.withSettings().useConstructor(client, TEST_XCONTENT_REGISTRY_FOR_QUERY).defaultAnswer(Mockito.CALLS_REAL_METHODS)
);
}

@Test
@SneakyThrows
public void testGetType() {
String type = mockedSearchIndexTool.getType();
assertFalse(Strings.isNullOrEmpty(type));
assertEquals("SearchIndexTool", type);
}

@Test
@SneakyThrows
public void testValidate() {
Map<String, String> parameters = Map.of("input", "{}");
assertTrue(mockedSearchIndexTool.validate(parameters));
}

@Test
@SneakyThrows
public void testValidateWithEmptyInput() {
Map<String, String> parameters = Map.of();
assertFalse(mockedSearchIndexTool.validate(parameters));
}

@Test
public void testRunWithNormalIndex() {
String inputString = "{\"index\": \"test-index\", \"query\": {\"size\": 10000}}";
Map<String, String> parameters = Map.of("input", inputString);
mockedSearchIndexTool.run(parameters, null);
Mockito.verify(client, times(1)).search(any(), any());
Mockito.verify(client, Mockito.never()).execute(any(), any(), any());
}

@Test
public void testRunWithConnectorIndex() {
String inputString = "{\"index\": \".plugins-ml-connector\", \"query\": {\"size\": 10000}}";
Map<String, String> parameters = Map.of("input", inputString);
mockedSearchIndexTool.run(parameters, null);
Mockito.verify(client, never()).search(any(), any());
Mockito.verify(client, times(1)).execute(eq(MLConnectorSearchAction.INSTANCE), any(), any());
}

@Test
public void testRunWithModelIndex() {
String inputString = "{\"index\": \".plugins-ml-model\", \"query\": {\"size\": 10000}}";
Map<String, String> parameters = Map.of("input", inputString);
mockedSearchIndexTool.run(parameters, null);
Mockito.verify(client, never()).search(any(), any());
Mockito.verify(client, times(1)).execute(eq(MLModelSearchAction.INSTANCE), any(), any());
}

@Test
@SneakyThrows
public void testRunWithEmptyQuery() {
String inputString = "{\"index\": \"test_index\"}";
Map<String, String> parameters = Map.of("input", inputString);
Exception exception = assertThrows(IllegalArgumentException.class, () -> mockedSearchIndexTool.run(parameters, null));
assertEquals("[query] is null or empty, can not process it.", exception.getMessage());
Mockito.verify(client, Mockito.never()).execute(any(), any(), any());
Mockito.verify(client, Mockito.never()).search(any(), any());
}
}

0 comments on commit 41b33b9

Please sign in to comment.