diff --git a/data-prepper-plugins/http-source-common/build.gradle b/data-prepper-plugins/http-source-common/build.gradle index 49b282a1f2..60002782c0 100644 --- a/data-prepper-plugins/http-source-common/build.gradle +++ b/data-prepper-plugins/http-source-common/build.gradle @@ -9,7 +9,10 @@ plugins { dependencies { implementation project(':data-prepper-plugins:common') + implementation project(':data-prepper-plugins:armeria-common') + implementation project(':data-prepper-plugins:blocking-buffer') implementation libs.armeria.core + implementation libs.commons.io implementation 'software.amazon.awssdk:acm' implementation 'software.amazon.awssdk:s3' implementation 'software.amazon.awssdk:apache-client' diff --git a/data-prepper-plugins/http-source-common/src/main/java/org/opensearch/dataprepper/http/BaseHttpService.java b/data-prepper-plugins/http-source-common/src/main/java/org/opensearch/dataprepper/http/BaseHttpService.java new file mode 100644 index 0000000000..0e23a035f6 --- /dev/null +++ b/data-prepper-plugins/http-source-common/src/main/java/org/opensearch/dataprepper/http/BaseHttpService.java @@ -0,0 +1,4 @@ +package org.opensearch.dataprepper.http; + +public interface BaseHttpService { +} diff --git a/data-prepper-plugins/http-source-common/src/main/java/org/opensearch/dataprepper/http/BaseHttpSource.java b/data-prepper-plugins/http-source-common/src/main/java/org/opensearch/dataprepper/http/BaseHttpSource.java new file mode 100644 index 0000000000..4e07dfac9c --- /dev/null +++ b/data-prepper-plugins/http-source-common/src/main/java/org/opensearch/dataprepper/http/BaseHttpSource.java @@ -0,0 +1,194 @@ +package org.opensearch.dataprepper.http; + +import com.linecorp.armeria.server.HttpService; +import com.linecorp.armeria.server.Server; +import com.linecorp.armeria.server.ServerBuilder; +import com.linecorp.armeria.server.encoding.DecodingService; +import com.linecorp.armeria.server.healthcheck.HealthCheckService; +import com.linecorp.armeria.server.throttling.ThrottlingService; +import org.opensearch.dataprepper.HttpRequestExceptionHandler; +import org.opensearch.dataprepper.armeria.authentication.ArmeriaHttpAuthenticationProvider; +import org.opensearch.dataprepper.http.certificate.CertificateProviderFactory; +import org.opensearch.dataprepper.metrics.PluginMetrics; +import org.opensearch.dataprepper.model.buffer.Buffer; +import org.opensearch.dataprepper.model.codec.ByteDecoder; +import org.opensearch.dataprepper.model.codec.JsonDecoder; +import org.opensearch.dataprepper.model.configuration.PipelineDescription; +import org.opensearch.dataprepper.model.configuration.PluginModel; +import org.opensearch.dataprepper.model.configuration.PluginSetting; +import org.opensearch.dataprepper.model.plugin.PluginFactory; +import org.opensearch.dataprepper.model.record.Record; +import org.opensearch.dataprepper.model.source.Source; +import org.opensearch.dataprepper.plugins.certificate.CertificateProvider; +import org.opensearch.dataprepper.plugins.certificate.model.Certificate; +import org.opensearch.dataprepper.plugins.codec.CompressionOption; +import org.slf4j.Logger; + +import java.io.ByteArrayInputStream; +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.util.Collections; +import java.util.Optional; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ScheduledThreadPoolExecutor; +import java.util.function.Function; + +/** + * BaseHttpSource class holds the common http related source functionality including starting the armeria server and authentication handling. + * HTTP based sources should use this functionality when implementing the respective source. + */ +public abstract class BaseHttpSource> implements Source { + public static final String REGEX_HEALTH = "regex:^/(?!health$).*$"; + public static final String SERVER_CONNECTIONS = "serverConnections"; + private static final String PIPELINE_NAME_PLACEHOLDER = "${pipelineName}"; + private static final String HTTP_HEALTH_CHECK_PATH = "/health"; + private final HttpServerConfig sourceConfig; + private final CertificateProviderFactory certificateProviderFactory; + private final ArmeriaHttpAuthenticationProvider authenticationProvider; + private final HttpRequestExceptionHandler httpRequestExceptionHandler; + private final String pipelineName; + private final String sourceName; + private final Logger logger; + private final PluginMetrics pluginMetrics; + private Server server; + private ByteDecoder byteDecoder; + + public BaseHttpSource(final HttpServerConfig sourceConfig, final PluginMetrics pluginMetrics, final PluginFactory pluginFactory, + final PipelineDescription pipelineDescription, final String sourceName, final Logger logger) { + this.sourceConfig = sourceConfig; + this.pluginMetrics = pluginMetrics; + this.pipelineName = pipelineDescription.getPipelineName(); + this.sourceName = sourceName; + this.logger = logger; + this.byteDecoder = new JsonDecoder(); + this.certificateProviderFactory = new CertificateProviderFactory(sourceConfig); + final PluginModel authenticationConfiguration = sourceConfig.getAuthentication(); + final PluginSetting authenticationPluginSetting; + + if (authenticationConfiguration == null || authenticationConfiguration.getPluginName().equals(ArmeriaHttpAuthenticationProvider.UNAUTHENTICATED_PLUGIN_NAME)) { + logger.warn("Creating {} source without authentication. This is not secure.", sourceName); + logger.warn("In order to set up Http Basic authentication for the {} source, go here: https://github.com/opensearch-project/data-prepper/tree/main/data-prepper-plugins/http-source#authentication-configurations", sourceName); + } + + if (authenticationConfiguration != null) { + authenticationPluginSetting = + new PluginSetting(authenticationConfiguration.getPluginName(), authenticationConfiguration.getPluginSettings()); + } else { + authenticationPluginSetting = + new PluginSetting(ArmeriaHttpAuthenticationProvider.UNAUTHENTICATED_PLUGIN_NAME, Collections.emptyMap()); + } + authenticationPluginSetting.setPipelineName(pipelineName); + authenticationProvider = pluginFactory.loadPlugin(ArmeriaHttpAuthenticationProvider.class, authenticationPluginSetting); + httpRequestExceptionHandler = new HttpRequestExceptionHandler(pluginMetrics); + } + + @Override + public void start(final Buffer buffer) { + if (buffer == null) { + throw new IllegalStateException("Buffer provided is null"); + } + if (server == null) { + final ServerBuilder sb = Server.builder(); + + sb.disableServerHeader(); + + if (sourceConfig.isSsl()) { + logger.info("Creating {} source with SSL/TLS enabled.", sourceName); + final CertificateProvider certificateProvider = certificateProviderFactory.getCertificateProvider(); + final Certificate certificate = certificateProvider.getCertificate(); + // TODO: enable encrypted key with password + sb.https(sourceConfig.getPort()).tls( + new ByteArrayInputStream(certificate.getCertificate().getBytes(StandardCharsets.UTF_8)), + new ByteArrayInputStream(certificate.getPrivateKey().getBytes(StandardCharsets.UTF_8) + ) + ); + } else { + logger.warn("Creating {} source without SSL/TLS. This is not secure.", sourceName); + logger.warn("In order to set up TLS for the {} source, go here: https://github.com/opensearch-project/data-prepper/tree/main/data-prepper-plugins/http-source#ssl", sourceName); + sb.http(sourceConfig.getPort()); + } + + if (sourceConfig.getAuthentication() != null) { + final Optional> optionalAuthDecorator = authenticationProvider.getAuthenticationDecorator(); + + if (sourceConfig.isUnauthenticatedHealthCheck()) { + optionalAuthDecorator.ifPresent(authDecorator -> sb.decorator(REGEX_HEALTH, authDecorator)); + } else { + optionalAuthDecorator.ifPresent(sb::decorator); + } + } + + sb.maxNumConnections(sourceConfig.getMaxConnectionCount()); + sb.requestTimeout(Duration.ofMillis(sourceConfig.getRequestTimeoutInMillis())); + if (sourceConfig.getMaxRequestLength() != null) { + sb.maxRequestLength(sourceConfig.getMaxRequestLength().getBytes()); + } + final int threads = sourceConfig.getThreadCount(); + final ScheduledThreadPoolExecutor blockingTaskExecutor = new ScheduledThreadPoolExecutor(threads); + sb.blockingTaskExecutor(blockingTaskExecutor, true); + final int maxPendingRequests = sourceConfig.getMaxPendingRequests(); + final LogThrottlingStrategy logThrottlingStrategy = new LogThrottlingStrategy( + maxPendingRequests, blockingTaskExecutor.getQueue()); + final LogThrottlingRejectHandler logThrottlingRejectHandler = new LogThrottlingRejectHandler(maxPendingRequests, pluginMetrics); + + final String httpSourcePath = sourceConfig.getPath().replace(PIPELINE_NAME_PLACEHOLDER, pipelineName); + sb.decorator(httpSourcePath, ThrottlingService.newDecorator(logThrottlingStrategy, logThrottlingRejectHandler)); + final BaseHttpService httpService = getHttpService(sourceConfig.getBufferTimeoutInMillis(), buffer, pluginMetrics); + + if (CompressionOption.NONE.equals(sourceConfig.getCompression())) { + sb.annotatedService(httpSourcePath, httpService, httpRequestExceptionHandler); + } else { + sb.annotatedService(httpSourcePath, httpService, DecodingService.newDecorator(), httpRequestExceptionHandler); + } + + if (sourceConfig.hasHealthCheckService()) { + logger.info("{} source health check is enabled", sourceName); + sb.service(HTTP_HEALTH_CHECK_PATH, HealthCheckService.builder().longPolling(0).build()); + } + + server = sb.build(); + pluginMetrics.gauge(SERVER_CONNECTIONS, server, Server::numConnections); + } + + try { + server.start().get(); + } catch (ExecutionException ex) { + if (ex.getCause() != null && ex.getCause() instanceof RuntimeException) { + throw (RuntimeException) ex.getCause(); + } else { + throw new RuntimeException(ex); + } + } catch (InterruptedException ex) { + Thread.currentThread().interrupt(); + throw new RuntimeException(ex); + } + logger.info("Started {} source on port {}", sourceName, sourceConfig.getPort()); + } + + @Override + public ByteDecoder getDecoder() { + return byteDecoder; + } + + @Override + public void stop() { + if (server != null) { + try { + server.stop().get(); + } catch (ExecutionException ex) { + if (ex.getCause() != null && ex.getCause() instanceof RuntimeException) { + throw (RuntimeException) ex.getCause(); + } else { + throw new RuntimeException(ex); + } + } catch (InterruptedException ex) { + Thread.currentThread().interrupt(); + throw new RuntimeException(ex); + } + } + logger.info("Stopped {} source.", sourceName); + } + + public abstract BaseHttpService getHttpService(int bufferTimeoutInMillis, Buffer buffer, PluginMetrics pluginMetrics); + +} diff --git a/data-prepper-plugins/http-source-common/src/main/java/org/opensearch/dataprepper/http/codec/MultiLineJsonCodec.java b/data-prepper-plugins/http-source-common/src/main/java/org/opensearch/dataprepper/http/codec/MultiLineJsonCodec.java index c0e1885f25..a1090835d3 100644 --- a/data-prepper-plugins/http-source-common/src/main/java/org/opensearch/dataprepper/http/codec/MultiLineJsonCodec.java +++ b/data-prepper-plugins/http-source-common/src/main/java/org/opensearch/dataprepper/http/codec/MultiLineJsonCodec.java @@ -4,29 +4,32 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.linecorp.armeria.common.HttpData; -import java.io.BufferedReader; import java.io.IOException; -import java.io.InputStreamReader; import java.nio.charset.StandardCharsets; import java.util.ArrayList; -import java.util.Arrays; import java.util.List; import java.util.Map; +import java.util.regex.Pattern; -public class MultiLineJsonCodec implements Codec>> { +public class MultiLineJsonCodec implements Codec>> { private static final ObjectMapper objectMapper = new ObjectMapper(); private static final String REGEX = "\\r?\\n"; - private static final TypeReference> MAP_TYPE_REFERENCE = - new TypeReference>() {}; + private static final TypeReference> MAP_TYPE_REFERENCE = new TypeReference<>() { + }; + private static final Pattern multiLineJsonSplitPattern = Pattern.compile(REGEX); + + private static boolean isInvalidLine(final String str) { + return str == null || str.isEmpty() || str.isBlank(); + } @Override public List> parse(HttpData httpData) throws IOException { List> jsonListData = new ArrayList<>(); String requestBody = new String(httpData.toInputStream().readAllBytes(), StandardCharsets.UTF_8); - List jsonLines = Arrays.asList(requestBody.split(REGEX)); + String[] jsonLines = multiLineJsonSplitPattern.split(requestBody); - for (String jsonLine: jsonLines) { + for (String jsonLine : jsonLines) { if (isInvalidLine(jsonLine)) { throw new IOException("Error processing request payload."); } @@ -34,8 +37,4 @@ public List> parse(HttpData httpData) throws IOException { } return jsonListData; } - - private static boolean isInvalidLine(final String str) { - return str == null || str.isEmpty() || str.isBlank(); - } } diff --git a/data-prepper-plugins/http-source-common/src/test/java/org/opensearch/dataprepper/http/BaseHttpSourceTest.java b/data-prepper-plugins/http-source-common/src/test/java/org/opensearch/dataprepper/http/BaseHttpSourceTest.java new file mode 100644 index 0000000000..675b84e3e7 --- /dev/null +++ b/data-prepper-plugins/http-source-common/src/test/java/org/opensearch/dataprepper/http/BaseHttpSourceTest.java @@ -0,0 +1,273 @@ +package org.opensearch.dataprepper.http; + +import com.linecorp.armeria.server.Server; +import com.linecorp.armeria.server.ServerBuilder; +import org.apache.commons.io.IOUtils; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockedStatic; +import org.mockito.Mockito; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.dataprepper.armeria.authentication.ArmeriaHttpAuthenticationProvider; +import org.opensearch.dataprepper.metrics.PluginMetrics; +import org.opensearch.dataprepper.model.buffer.Buffer; +import org.opensearch.dataprepper.model.configuration.PipelineDescription; +import org.opensearch.dataprepper.model.configuration.PluginSetting; +import org.opensearch.dataprepper.model.event.Event; +import org.opensearch.dataprepper.model.plugin.PluginFactory; +import org.opensearch.dataprepper.model.record.Record; +import org.opensearch.dataprepper.plugins.buffer.blockingbuffer.BlockingBuffer; +import org.opensearch.dataprepper.plugins.codec.CompressionOption; +import org.slf4j.LoggerFactory; + +import java.io.File; +import java.io.IOException; +import java.io.InputStream; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.Objects; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.is; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.lenient; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +@ExtendWith(MockitoExtension.class) +public class BaseHttpSourceTest { + private final String PLUGIN_NAME = "opensearch_api"; + private final String TEST_PIPELINE_NAME = "test_pipeline"; + private final int DEFAULT_REQUEST_TIMEOUT_MS = 10_000; + private final int DEFAULT_THREAD_COUNT = 200; + private final int MAX_CONNECTIONS_COUNT = 500; + private final int MAX_PENDING_REQUESTS_COUNT = 1024; + private final String sourceName = "basic-http-api-source"; + + private final String TEST_SSL_CERTIFICATE_FILE = + Objects.requireNonNull(getClass().getClassLoader().getResource("test_cert.crt")).getFile(); + private final String TEST_SSL_KEY_FILE = + Objects.requireNonNull(getClass().getClassLoader().getResource("test_decrypted_key.key")).getFile(); + + @Mock + private ServerBuilder serverBuilder; + + @Mock + private Server server; + + @Mock + private CompletableFuture completableFuture; + + @Mock + private BlockingBuffer> testBuffer; + + @Mock + private BaseHttpService BaseHttpService; + + private BaseHttpSource> httpApiSource; + private HttpServerConfig sourceConfig; + private PluginMetrics pluginMetrics; + private PluginFactory pluginFactory; + private PipelineDescription pipelineDescription; + + private BaseHttpSource> createObjectUnderTest() { + return new BaseHttpSource<>(sourceConfig, pluginMetrics, pluginFactory, pipelineDescription, sourceName, LoggerFactory.getLogger(BaseHttpService.class)) { + @Override + public BaseHttpService getHttpService(int bufferTimeoutInMillis, Buffer> buffer, PluginMetrics pluginMetrics) { + return BaseHttpService; + } + }; + } + + @BeforeEach + public void setUp() { + lenient().when(serverBuilder.annotatedService(any())).thenReturn(serverBuilder); + lenient().when(serverBuilder.http(anyInt())).thenReturn(serverBuilder); + lenient().when(serverBuilder.https(anyInt())).thenReturn(serverBuilder); + lenient().when(serverBuilder.build()).thenReturn(server); + lenient().when(server.start()).thenReturn(completableFuture); + + sourceConfig = mock(HttpServerConfig.class); + lenient().when(sourceConfig.getRequestTimeoutInMillis()).thenReturn(DEFAULT_REQUEST_TIMEOUT_MS); + lenient().when(sourceConfig.getPath()).thenReturn("/path"); + lenient().when(sourceConfig.getPort()).thenReturn(9092); + lenient().when(sourceConfig.getThreadCount()).thenReturn(DEFAULT_THREAD_COUNT); + lenient().when(sourceConfig.getMaxConnectionCount()).thenReturn(MAX_CONNECTIONS_COUNT); + lenient().when(sourceConfig.getMaxPendingRequests()).thenReturn(MAX_PENDING_REQUESTS_COUNT); + lenient().when(sourceConfig.hasHealthCheckService()).thenReturn(true); + lenient().when(sourceConfig.getCompression()).thenReturn(CompressionOption.NONE); + + pluginMetrics = PluginMetrics.fromNames(PLUGIN_NAME, TEST_PIPELINE_NAME); + + pluginFactory = mock(PluginFactory.class); + final ArmeriaHttpAuthenticationProvider authenticationProvider = mock(ArmeriaHttpAuthenticationProvider.class); + when(pluginFactory.loadPlugin(eq(ArmeriaHttpAuthenticationProvider.class), any(PluginSetting.class))) + .thenReturn(authenticationProvider); + + pipelineDescription = mock(PipelineDescription.class); + when(pipelineDescription.getPipelineName()).thenReturn(TEST_PIPELINE_NAME); + + httpApiSource = createObjectUnderTest(); + } + + @AfterEach + public void cleanUp() { + if (httpApiSource != null) { + httpApiSource.stop(); + } + } + + @Test + public void testServerStartCertFileSuccess() throws IOException { + try (MockedStatic armeriaServerMock = Mockito.mockStatic(Server.class)) { + armeriaServerMock.when(Server::builder).thenReturn(serverBuilder); + when(server.stop()).thenReturn(completableFuture); + + final Path certFilePath = new File(TEST_SSL_CERTIFICATE_FILE).toPath(); + final Path keyFilePath = new File(TEST_SSL_KEY_FILE).toPath(); + final String certAsString = Files.readString(certFilePath); + final String keyAsString = Files.readString(keyFilePath); + + when(sourceConfig.isSsl()).thenReturn(true); + when(sourceConfig.getSslCertificateFile()).thenReturn(TEST_SSL_CERTIFICATE_FILE); + when(sourceConfig.getSslKeyFile()).thenReturn(TEST_SSL_KEY_FILE); + final BaseHttpSource> objectUnderTest = createObjectUnderTest(); + objectUnderTest.start(testBuffer); + objectUnderTest.stop(); + + final ArgumentCaptor certificateIs = ArgumentCaptor.forClass(InputStream.class); + final ArgumentCaptor privateKeyIs = ArgumentCaptor.forClass(InputStream.class); + verify(serverBuilder).tls(certificateIs.capture(), privateKeyIs.capture()); + final String actualCertificate = IOUtils.toString(certificateIs.getValue(), StandardCharsets.UTF_8.name()); + final String actualPrivateKey = IOUtils.toString(privateKeyIs.getValue(), StandardCharsets.UTF_8.name()); + assertThat(actualCertificate, is(certAsString)); + assertThat(actualPrivateKey, is(keyAsString)); + } + } + + @Test + public void testDoubleStart() { + // starting server + httpApiSource.start(testBuffer); + // double start server + Assertions.assertThrows(IllegalStateException.class, () -> httpApiSource.start(testBuffer)); + } + + @Test + public void testStartWithEmptyBuffer() { + final BaseHttpSource> httpApiSource = createObjectUnderTest(); + Assertions.assertThrows(IllegalStateException.class, () -> httpApiSource.start(null)); + } + + @Test + public void testStartWithServerExecutionExceptionNoCause() throws ExecutionException, InterruptedException { + // Prepare + final BaseHttpSource> httpApiSource = createObjectUnderTest(); + try (MockedStatic armeriaServerMock = Mockito.mockStatic(Server.class)) { + armeriaServerMock.when(Server::builder).thenReturn(serverBuilder); + when(completableFuture.get()).thenThrow(new ExecutionException("", null)); + + // When/Then + Assertions.assertThrows(RuntimeException.class, () -> httpApiSource.start(testBuffer)); + } + } + + @Test + public void testStartWithServerExecutionExceptionWithCause() throws ExecutionException, InterruptedException { + // Prepare + final BaseHttpSource> httpApiSource = createObjectUnderTest(); + try (MockedStatic armeriaServerMock = Mockito.mockStatic(Server.class)) { + armeriaServerMock.when(Server::builder).thenReturn(serverBuilder); + final NullPointerException expCause = new NullPointerException(); + when(completableFuture.get()).thenThrow(new ExecutionException("", expCause)); + + // When/Then + final RuntimeException ex = Assertions.assertThrows(RuntimeException.class, () -> httpApiSource.start(testBuffer)); + Assertions.assertEquals(expCause, ex); + } + } + + @Test + public void testStartWithInterruptedException() throws ExecutionException, InterruptedException { + // Prepare + final BaseHttpSource> httpApiSource = createObjectUnderTest(); + try (MockedStatic armeriaServerMock = Mockito.mockStatic(Server.class)) { + armeriaServerMock.when(Server::builder).thenReturn(serverBuilder); + when(completableFuture.get()).thenThrow(new InterruptedException()); + + // When/Then + Assertions.assertThrows(RuntimeException.class, () -> httpApiSource.start(testBuffer)); + Assertions.assertTrue(Thread.interrupted()); + } + } + + @Test + public void testStopWithServerExecutionExceptionNoCause() throws ExecutionException, InterruptedException { + // Prepare + final BaseHttpSource> httpApiSource = createObjectUnderTest(); + try (MockedStatic armeriaServerMock = Mockito.mockStatic(Server.class)) { + armeriaServerMock.when(Server::builder).thenReturn(serverBuilder); + httpApiSource.start(testBuffer); + when(server.stop()).thenReturn(completableFuture); + + // When/Then + when(completableFuture.get()).thenThrow(new ExecutionException("", null)); + Assertions.assertThrows(RuntimeException.class, httpApiSource::stop); + } + } + + @Test + public void testStopWithServerExecutionExceptionWithCause() throws ExecutionException, InterruptedException { + // Prepare + final BaseHttpSource> httpApiSource = createObjectUnderTest(); + try (MockedStatic armeriaServerMock = Mockito.mockStatic(Server.class)) { + armeriaServerMock.when(Server::builder).thenReturn(serverBuilder); + httpApiSource.start(testBuffer); + when(server.stop()).thenReturn(completableFuture); + final NullPointerException expCause = new NullPointerException(); + when(completableFuture.get()).thenThrow(new ExecutionException("", expCause)); + + // When/Then + final RuntimeException ex = Assertions.assertThrows(RuntimeException.class, httpApiSource::stop); + Assertions.assertEquals(expCause, ex); + } + } + + @Test + public void testStopWithInterruptedException() throws ExecutionException, InterruptedException { + // Prepare + final BaseHttpSource> httpApiSource = createObjectUnderTest(); + try (MockedStatic armeriaServerMock = Mockito.mockStatic(Server.class)) { + armeriaServerMock.when(Server::builder).thenReturn(serverBuilder); + httpApiSource.start(testBuffer); + when(server.stop()).thenReturn(completableFuture); + when(completableFuture.get()).thenThrow(new InterruptedException()); + + // When/Then + Assertions.assertThrows(RuntimeException.class, httpApiSource::stop); + Assertions.assertTrue(Thread.interrupted()); + } + } + + @Test + public void testRunAnotherSourceWithSamePort() { + // starting server + httpApiSource.start(testBuffer); + + final BaseHttpSource> secondHttpAPISource = createObjectUnderTest(); + //Expect RuntimeException because when port is already in use, BindException is thrown which is not RuntimeException + Assertions.assertThrows(RuntimeException.class, () -> secondHttpAPISource.start(testBuffer)); + } + +} diff --git a/data-prepper-plugins/opensearch-api-source/build.gradle b/data-prepper-plugins/opensearch-api-source/build.gradle index 874cbc4781..3348be6034 100644 --- a/data-prepper-plugins/opensearch-api-source/build.gradle +++ b/data-prepper-plugins/opensearch-api-source/build.gradle @@ -15,6 +15,7 @@ dependencies { implementation project(':data-prepper-plugins:armeria-common') implementation libs.armeria.core implementation libs.commons.io + implementation libs.commons.lang3 implementation 'software.amazon.awssdk:acm' implementation 'software.amazon.awssdk:s3' implementation 'software.amazon.awssdk:apache-client' @@ -31,7 +32,7 @@ jacocoTestCoverageVerification { violationRules { rule { //in addition to core projects rule limit { - minimum = 0.90 + minimum = 1.0 } } } diff --git a/data-prepper-plugins/opensearch-api-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearchapi/OpenSearchAPIService.java b/data-prepper-plugins/opensearch-api-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearchapi/OpenSearchAPIService.java index d57da3632e..90cb77a21f 100644 --- a/data-prepper-plugins/opensearch-api-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearchapi/OpenSearchAPIService.java +++ b/data-prepper-plugins/opensearch-api-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearchapi/OpenSearchAPIService.java @@ -5,9 +5,20 @@ package org.opensearch.dataprepper.plugins.source.opensearchapi; +import com.linecorp.armeria.common.AggregatedHttpRequest; +import com.linecorp.armeria.common.HttpData; +import com.linecorp.armeria.common.HttpResponse; +import com.linecorp.armeria.common.HttpStatus; +import com.linecorp.armeria.common.annotation.Nullable; import com.linecorp.armeria.server.ServiceRequestContext; +import com.linecorp.armeria.server.annotation.Blocking; import com.linecorp.armeria.server.annotation.Param; -import io.micrometer.common.util.StringUtils; +import com.linecorp.armeria.server.annotation.Post; +import io.micrometer.core.instrument.Counter; +import io.micrometer.core.instrument.DistributionSummary; +import io.micrometer.core.instrument.Timer; +import org.apache.commons.lang3.StringUtils; +import org.opensearch.dataprepper.http.BaseHttpService; import org.opensearch.dataprepper.http.codec.MultiLineJsonCodec; import org.opensearch.dataprepper.metrics.PluginMetrics; import org.opensearch.dataprepper.model.buffer.Buffer; @@ -16,36 +27,26 @@ import org.opensearch.dataprepper.model.event.JacksonEvent; import org.opensearch.dataprepper.model.opensearch.OpenSearchBulkActions; import org.opensearch.dataprepper.model.record.Record; -import com.linecorp.armeria.common.AggregatedHttpRequest; -import com.linecorp.armeria.common.HttpData; -import com.linecorp.armeria.common.HttpResponse; -import com.linecorp.armeria.common.HttpStatus; -import com.linecorp.armeria.server.annotation.Blocking; -import com.linecorp.armeria.server.annotation.Post; -import io.micrometer.core.instrument.Counter; -import io.micrometer.core.instrument.DistributionSummary; -import io.micrometer.core.instrument.Timer; +import org.opensearch.dataprepper.plugins.source.opensearchapi.model.BulkAPIEventMetadataKeyAttributes; import org.opensearch.dataprepper.plugins.source.opensearchapi.model.BulkAPIRequestParams; import org.opensearch.dataprepper.plugins.source.opensearchapi.model.BulkActionAndMetadataObject; -import org.opensearch.dataprepper.plugins.source.opensearchapi.model.BulkAPIEventMetadataKeyAttributes; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; import java.util.Iterator; -import java.util.Optional; import java.util.List; import java.util.Map; -import java.util.Arrays; -import java.util.ArrayList; /* -* OpenSearch API Service class is responsible for handling bulk API requests. -* The bulk API is responsible for 1/ parsing the request body, 2/ validating against the schema for Document API (Bulk) and finally creating data prepper events. -* Bulk API supports query parameters "pipeline", "routing" and "refresh" -*/ + * OpenSearch API Service class is responsible for handling bulk API requests. + * The bulk API is responsible for 1/ parsing the request body, 2/ validating against the schema for Document API (Bulk) and finally creating data prepper events. + * Bulk API supports query parameters "pipeline", "routing" and "refresh" + */ @Blocking -public class OpenSearchAPIService { +public class OpenSearchAPIService implements BaseHttpService { //TODO: Will need to revisit the metrics per API endpoint public static final String REQUESTS_RECEIVED = "RequestsReceived"; @@ -77,39 +78,37 @@ public OpenSearchAPIService(final int bufferWriteTimeoutInMillis, final Buffer pipeline, @Param("routing") Optional routing) throws Exception { + @Param("pipeline") @Nullable String pipeline, + @Param("routing") @Nullable String routing) throws Exception { - requestsReceivedCounter.increment(); - payloadSizeSummary.record(aggregatedHttpRequest.content().length()); - - if(serviceRequestContext.isTimedOut()) { - return HttpResponse.of(HttpStatus.REQUEST_TIMEOUT); - } BulkAPIRequestParams bulkAPIRequestParams = BulkAPIRequestParams.builder() - .pipeline(pipeline.orElse("")) - .routing(routing.orElse("")) + .pipeline(pipeline) + .routing(routing) .build(); - return requestProcessDuration.recordCallable(() -> processBulkRequest(aggregatedHttpRequest, bulkAPIRequestParams)); + return requestProcessDuration.recordCallable(() -> processBulkRequest(serviceRequestContext, aggregatedHttpRequest, bulkAPIRequestParams)); } @Post("/{index}/_bulk") - public HttpResponse doPostBulkIndex(final ServiceRequestContext serviceRequestContext, final AggregatedHttpRequest aggregatedHttpRequest, @Param("index") Optional index, - @Param("pipeline") Optional pipeline, @Param("routing") Optional routing) throws Exception { + public HttpResponse doPostBulkIndex(final ServiceRequestContext serviceRequestContext, final AggregatedHttpRequest aggregatedHttpRequest, + @Param("index") String index, + @Param("pipeline") @Nullable String pipeline, + @Param("routing") @Nullable String routing) throws Exception { + BulkAPIRequestParams bulkAPIRequestParams = BulkAPIRequestParams.builder() + .index(index) + .pipeline(pipeline) + .routing(routing) + .build(); + return requestProcessDuration.recordCallable(() -> processBulkRequest(serviceRequestContext, aggregatedHttpRequest, bulkAPIRequestParams)); + } + + private HttpResponse processBulkRequest(final ServiceRequestContext serviceRequestContext, final AggregatedHttpRequest aggregatedHttpRequest, final BulkAPIRequestParams bulkAPIRequestParams) throws Exception { requestsReceivedCounter.increment(); payloadSizeSummary.record(aggregatedHttpRequest.content().length()); - if(serviceRequestContext.isTimedOut()) { + if (serviceRequestContext.isTimedOut()) { return HttpResponse.of(HttpStatus.REQUEST_TIMEOUT); } - BulkAPIRequestParams bulkAPIRequestParams = BulkAPIRequestParams.builder() - .index(index.orElse("")) - .pipeline(pipeline.orElse("")) - .routing(routing.orElse("")) - .build(); - return requestProcessDuration.recordCallable(() -> processBulkRequest(aggregatedHttpRequest, bulkAPIRequestParams)); - } - private HttpResponse processBulkRequest(final AggregatedHttpRequest aggregatedHttpRequest, final BulkAPIRequestParams bulkAPIRequestParams) throws Exception { final HttpData content = aggregatedHttpRequest.content(); List> bulkRequestPayloadList; @@ -142,10 +141,6 @@ private boolean isValidBulkAction(Map actionMap) { } private List> generateEventsFromBulkRequest(final List> bulkRequestPayloadList, final BulkAPIRequestParams bulkAPIRequestParams) throws Exception { - if (bulkRequestPayloadList.isEmpty()) { - throw new IOException("Invalid request data."); - } - List> records = new ArrayList<>(); Iterator> bulkRequestPayloadListIterator = bulkRequestPayloadList.iterator(); @@ -157,14 +152,14 @@ private List> generateEventsFromBulkRequest(final List> documentDataObject = Optional.empty(); + Map documentDataObject = null; if (!isDeleteAction) { if (!bulkRequestPayloadListIterator.hasNext()) { throw new IOException("Invalid request data."); } - documentDataObject = Optional.of(bulkRequestPayloadListIterator.next()); + documentDataObject = bulkRequestPayloadListIterator.next(); // Performing another validation check to make sure that the doc row is not a valid action row - if (!documentDataObject.isPresent() || isValidBulkAction(documentDataObject.get())) { + if (isValidBulkAction(documentDataObject)) { throw new IOException("Invalid request data."); } } @@ -176,15 +171,16 @@ private List> generateEventsFromBulkRequest(final List> optionalDocumentData) { + final BulkActionAndMetadataObject bulkActionAndMetadataObject, + final BulkAPIRequestParams bulkAPIRequestParams, Map optionalDocumentData) { final JacksonEvent.Builder eventBuilder = JacksonEvent.builder().withEventType(EventType.DOCUMENT.toString()); - optionalDocumentData.ifPresent(eventBuilder::withData); + if (optionalDocumentData != null) { + eventBuilder.withData(optionalDocumentData); + } final JacksonEvent event = eventBuilder.build(); - final String index = bulkActionAndMetadataObject.getIndex().isBlank() || bulkActionAndMetadataObject.getIndex().isEmpty() - ? bulkAPIRequestParams.getIndex() : bulkActionAndMetadataObject.getIndex(); + final String index = !StringUtils.isEmpty(bulkAPIRequestParams.getIndex()) ? bulkAPIRequestParams.getIndex() : bulkActionAndMetadataObject.getIndex(); event.getMetadata().setAttribute(BulkAPIEventMetadataKeyAttributes.BULK_API_EVENT_METADATA_ATTRIBUTE_ACTION, bulkActionAndMetadataObject.getAction()); event.getMetadata().setAttribute(BulkAPIEventMetadataKeyAttributes.BULK_API_EVENT_METADATA_ATTRIBUTE_INDEX, index); diff --git a/data-prepper-plugins/opensearch-api-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearchapi/OpenSearchAPISource.java b/data-prepper-plugins/opensearch-api-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearchapi/OpenSearchAPISource.java index c10b918bb1..1a6c9b4bd1 100644 --- a/data-prepper-plugins/opensearch-api-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearchapi/OpenSearchAPISource.java +++ b/data-prepper-plugins/opensearch-api-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearchapi/OpenSearchAPISource.java @@ -5,194 +5,31 @@ package org.opensearch.dataprepper.plugins.source.opensearchapi; -import com.linecorp.armeria.server.HttpService; -import com.linecorp.armeria.server.Server; -import com.linecorp.armeria.server.ServerBuilder; -import com.linecorp.armeria.server.encoding.DecodingService; -import com.linecorp.armeria.server.healthcheck.HealthCheckService; -import com.linecorp.armeria.server.throttling.ThrottlingService; -import org.opensearch.dataprepper.HttpRequestExceptionHandler; -import org.opensearch.dataprepper.armeria.authentication.ArmeriaHttpAuthenticationProvider; -import org.opensearch.dataprepper.http.LogThrottlingRejectHandler; -import org.opensearch.dataprepper.http.LogThrottlingStrategy; +import org.opensearch.dataprepper.http.BaseHttpService; +import org.opensearch.dataprepper.http.BaseHttpSource; import org.opensearch.dataprepper.metrics.PluginMetrics; import org.opensearch.dataprepper.model.annotations.DataPrepperPlugin; import org.opensearch.dataprepper.model.annotations.DataPrepperPluginConstructor; import org.opensearch.dataprepper.model.buffer.Buffer; -import org.opensearch.dataprepper.model.codec.ByteDecoder; -import org.opensearch.dataprepper.model.codec.JsonDecoder; import org.opensearch.dataprepper.model.configuration.PipelineDescription; -import org.opensearch.dataprepper.model.configuration.PluginModel; -import org.opensearch.dataprepper.model.configuration.PluginSetting; import org.opensearch.dataprepper.model.event.Event; import org.opensearch.dataprepper.model.plugin.PluginFactory; import org.opensearch.dataprepper.model.record.Record; import org.opensearch.dataprepper.model.source.Source; -import org.opensearch.dataprepper.plugins.certificate.CertificateProvider; -import org.opensearch.dataprepper.plugins.certificate.model.Certificate; -import org.opensearch.dataprepper.plugins.codec.CompressionOption; -import org.opensearch.dataprepper.http.certificate.CertificateProviderFactory; -import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.io.ByteArrayInputStream; -import java.nio.charset.StandardCharsets; -import java.time.Duration; -import java.util.Collections; -import java.util.Optional; -import java.util.concurrent.ExecutionException; -import java.util.concurrent.ScheduledThreadPoolExecutor; -import java.util.function.Function; - @DataPrepperPlugin(name = "opensearch_api", pluginType = Source.class, pluginConfigurationType = OpenSearchAPISourceConfig.class) -public class OpenSearchAPISource implements Source> { - private static final Logger LOG = LoggerFactory.getLogger(OpenSearchAPISource.class); - private static final String PIPELINE_NAME_PLACEHOLDER = "${pipelineName}"; - public static final String REGEX_HEALTH = "regex:^/(?!health$).*$"; - static final String SERVER_CONNECTIONS = "serverConnections"; - - private final OpenSearchAPISourceConfig sourceConfig; - private final CertificateProviderFactory certificateProviderFactory; - private final ArmeriaHttpAuthenticationProvider authenticationProvider; - private final HttpRequestExceptionHandler httpRequestExceptionHandler; - private final String pipelineName; - private Server server; - private final PluginMetrics pluginMetrics; - private static final String HTTP_HEALTH_CHECK_PATH = "/health"; - private ByteDecoder byteDecoder; +public class OpenSearchAPISource extends BaseHttpSource> { + private static final String SOURCE_NAME = "OpenSearch API"; @DataPrepperPluginConstructor public OpenSearchAPISource(final OpenSearchAPISourceConfig sourceConfig, final PluginMetrics pluginMetrics, final PluginFactory pluginFactory, - final PipelineDescription pipelineDescription) { - this.sourceConfig = sourceConfig; - this.pluginMetrics = pluginMetrics; - this.pipelineName = pipelineDescription.getPipelineName(); - this.byteDecoder = new JsonDecoder(); - this.certificateProviderFactory = new CertificateProviderFactory(sourceConfig); - final PluginModel authenticationConfiguration = sourceConfig.getAuthentication(); - final PluginSetting authenticationPluginSetting; - - if (authenticationConfiguration == null || authenticationConfiguration.getPluginName().equals(ArmeriaHttpAuthenticationProvider.UNAUTHENTICATED_PLUGIN_NAME)) { - LOG.warn("Creating OpenSearch API source without authentication. This is not secure."); - LOG.warn("In order to set up Http Basic authentication for the OpenSearch API source, go here: https://github.com/opensearch-project/data-prepper/tree/main/data-prepper-plugins/http-source#authentication-configurations"); - } - - if(authenticationConfiguration != null) { - authenticationPluginSetting = - new PluginSetting(authenticationConfiguration.getPluginName(), authenticationConfiguration.getPluginSettings()); - } else { - authenticationPluginSetting = - new PluginSetting(ArmeriaHttpAuthenticationProvider.UNAUTHENTICATED_PLUGIN_NAME, Collections.emptyMap()); - } - authenticationPluginSetting.setPipelineName(pipelineName); - authenticationProvider = pluginFactory.loadPlugin(ArmeriaHttpAuthenticationProvider.class, authenticationPluginSetting); - httpRequestExceptionHandler = new HttpRequestExceptionHandler(pluginMetrics); - } - - @Override - public void start(final Buffer> buffer) { - if (buffer == null) { - throw new IllegalStateException("Buffer provided is null"); - } - if (server == null) { - final ServerBuilder sb = Server.builder(); - - sb.disableServerHeader(); - - if (sourceConfig.isSsl()) { - LOG.info("Creating http source with SSL/TLS enabled."); - final CertificateProvider certificateProvider = certificateProviderFactory.getCertificateProvider(); - final Certificate certificate = certificateProvider.getCertificate(); - // TODO: enable encrypted key with password - sb.https(sourceConfig.getPort()).tls( - new ByteArrayInputStream(certificate.getCertificate().getBytes(StandardCharsets.UTF_8)), - new ByteArrayInputStream(certificate.getPrivateKey().getBytes(StandardCharsets.UTF_8) - ) - ); - } else { - LOG.warn("Creating OpenSearch API source without SSL/TLS. This is not secure."); - LOG.warn("In order to set up TLS for the OpenSearch API source, go here: https://github.com/opensearch-project/data-prepper/tree/main/data-prepper-plugins/http-source#ssl"); - sb.http(sourceConfig.getPort()); - } - - if(sourceConfig.getAuthentication() != null) { - final Optional> optionalAuthDecorator = authenticationProvider.getAuthenticationDecorator(); - - if (sourceConfig.isUnauthenticatedHealthCheck()) { - optionalAuthDecorator.ifPresent(authDecorator -> sb.decorator(REGEX_HEALTH, authDecorator)); - } else { - optionalAuthDecorator.ifPresent(sb::decorator); - } - } - - sb.maxNumConnections(sourceConfig.getMaxConnectionCount()); - sb.requestTimeout(Duration.ofMillis(sourceConfig.getRequestTimeoutInMillis())); - if(sourceConfig.getMaxRequestLength() != null) { - sb.maxRequestLength(sourceConfig.getMaxRequestLength().getBytes()); - } - final int threads = sourceConfig.getThreadCount(); - final ScheduledThreadPoolExecutor blockingTaskExecutor = new ScheduledThreadPoolExecutor(threads); - sb.blockingTaskExecutor(blockingTaskExecutor, true); - final int maxPendingRequests = sourceConfig.getMaxPendingRequests(); - final LogThrottlingStrategy logThrottlingStrategy = new LogThrottlingStrategy( - maxPendingRequests, blockingTaskExecutor.getQueue()); - final LogThrottlingRejectHandler logThrottlingRejectHandler = new LogThrottlingRejectHandler(maxPendingRequests, pluginMetrics); - - final String httpSourcePath = sourceConfig.getPath().replace(PIPELINE_NAME_PLACEHOLDER, pipelineName); - sb.decorator(httpSourcePath, ThrottlingService.newDecorator(logThrottlingStrategy, logThrottlingRejectHandler)); - final OpenSearchAPIService openSearchAPIService = new OpenSearchAPIService(sourceConfig.getBufferTimeoutInMillis(), buffer, pluginMetrics); - - if (CompressionOption.NONE.equals(sourceConfig.getCompression())) { - sb.annotatedService(httpSourcePath, openSearchAPIService, httpRequestExceptionHandler); - } else { - sb.annotatedService(httpSourcePath, openSearchAPIService, DecodingService.newDecorator(), httpRequestExceptionHandler); - } - - if (sourceConfig.hasHealthCheckService()) { - LOG.info("OpenSearch API source health check is enabled"); - sb.service(HTTP_HEALTH_CHECK_PATH, HealthCheckService.builder().longPolling(0).build()); - } - - server = sb.build(); - pluginMetrics.gauge(SERVER_CONNECTIONS, server, Server::numConnections); - } - - try { - server.start().get(); - } catch (ExecutionException ex) { - if (ex.getCause() != null && ex.getCause() instanceof RuntimeException) { - throw (RuntimeException) ex.getCause(); - } else { - throw new RuntimeException(ex); - } - } catch (InterruptedException ex) { - Thread.currentThread().interrupt(); - throw new RuntimeException(ex); - } - LOG.info("Started OpenSearch API source on port " + sourceConfig.getPort() + "..."); - } - - @Override - public ByteDecoder getDecoder() { - return byteDecoder; + final PipelineDescription pipelineDescription) { + super(sourceConfig, pluginMetrics, pluginFactory, pipelineDescription, SOURCE_NAME, LoggerFactory.getLogger(OpenSearchAPISource.class)); } @Override - public void stop() { - if (server != null) { - try { - server.stop().get(); - } catch (ExecutionException ex) { - if (ex.getCause() != null && ex.getCause() instanceof RuntimeException) { - throw (RuntimeException) ex.getCause(); - } else { - throw new RuntimeException(ex); - } - } catch (InterruptedException ex) { - Thread.currentThread().interrupt(); - throw new RuntimeException(ex); - } - } - LOG.info("Stopped OpenSearch API source."); + public BaseHttpService getHttpService(final int bufferWriteTimeoutInMillis, final Buffer> buffer, final PluginMetrics pluginMetrics) { + return new OpenSearchAPIService(bufferWriteTimeoutInMillis, buffer, pluginMetrics); } } diff --git a/data-prepper-plugins/opensearch-api-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearchapi/OpenSearchAPISourceConfig.java b/data-prepper-plugins/opensearch-api-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearchapi/OpenSearchAPISourceConfig.java index 646ecfe1a1..ed864df778 100644 --- a/data-prepper-plugins/opensearch-api-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearchapi/OpenSearchAPISourceConfig.java +++ b/data-prepper-plugins/opensearch-api-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearchapi/OpenSearchAPISourceConfig.java @@ -9,7 +9,7 @@ public class OpenSearchAPISourceConfig extends BaseHttpServerConfig { - static final String DEFAULT_ENDPOINT_URI = "/opensearch"; + static final String DEFAULT_ENDPOINT_URI = "/"; static final int DEFAULT_PORT = 9202; @Override diff --git a/data-prepper-plugins/opensearch-api-source/src/test/java/org/opensearch/dataprepper/plugins/source/opensearchapi/OpenSearchAPIServiceTest.java b/data-prepper-plugins/opensearch-api-source/src/test/java/org/opensearch/dataprepper/plugins/source/opensearchapi/OpenSearchAPIServiceTest.java index b6bfbc4932..28a567c4f4 100644 --- a/data-prepper-plugins/opensearch-api-source/src/test/java/org/opensearch/dataprepper/plugins/source/opensearchapi/OpenSearchAPIServiceTest.java +++ b/data-prepper-plugins/opensearch-api-source/src/test/java/org/opensearch/dataprepper/plugins/source/opensearchapi/OpenSearchAPIServiceTest.java @@ -5,14 +5,6 @@ package org.opensearch.dataprepper.plugins.source.opensearchapi; -import com.linecorp.armeria.server.ServiceRequestContext; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.ValueSource; -import org.opensearch.dataprepper.metrics.PluginMetrics; -import org.opensearch.dataprepper.model.buffer.Buffer; -import org.opensearch.dataprepper.model.buffer.SizeOverflowException; -import org.opensearch.dataprepper.model.event.Event; -import org.opensearch.dataprepper.model.record.Record; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; import com.linecorp.armeria.common.AggregatedHttpRequest; @@ -24,35 +16,46 @@ import com.linecorp.armeria.common.HttpStatus; import com.linecorp.armeria.common.MediaType; import com.linecorp.armeria.common.RequestHeaders; +import com.linecorp.armeria.server.ServiceRequestContext; import io.micrometer.core.instrument.Counter; import io.micrometer.core.instrument.DistributionSummary; import io.micrometer.core.instrument.Timer; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; +import org.junit.jupiter.params.provider.ValueSource; import org.mockito.ArgumentCaptor; import org.mockito.ArgumentMatchers; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; import org.mockito.stubbing.Answer; +import org.opensearch.dataprepper.metrics.PluginMetrics; +import org.opensearch.dataprepper.model.buffer.Buffer; +import org.opensearch.dataprepper.model.buffer.SizeOverflowException; +import org.opensearch.dataprepper.model.event.Event; +import org.opensearch.dataprepper.model.record.Record; import org.opensearch.dataprepper.plugins.buffer.blockingbuffer.BlockingBuffer; import java.io.IOException; +import java.util.ArrayList; import java.util.Collections; -import java.util.Map; -import java.util.Optional; import java.util.List; -import java.util.ArrayList; +import java.util.Map; import java.util.UUID; import java.util.concurrent.Callable; import java.util.concurrent.ExecutionException; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.lenient; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.times; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; @ExtendWith(MockitoExtension.class) class OpenSearchAPIServiceTest { @@ -90,8 +93,7 @@ public void setUp() throws Exception { lenient().when(requestProcessDuration.recordCallable(ArgumentMatchers.>any())).thenAnswer( (Answer) invocation -> { final Object[] args = invocation.getArguments(); - @SuppressWarnings("unchecked") - final Callable callable = (Callable) args[0]; + @SuppressWarnings("unchecked") final Callable callable = (Callable) args[0]; return callable.call(); } ); @@ -101,8 +103,19 @@ public void setUp() throws Exception { } @ParameterizedTest - @ValueSource(booleans = {false, true}) - public void testBulkRequestAPISuccess(boolean testBulkRequestAPIWithIndexInPath) throws Exception { + @CsvSource(value = { + "false, null, null, null", + "false, null, \"\", \"\"", + "false, null, \" \", \" \"", + "false, null, \"pipeline-1\", \"routing-1\"", + "false, \"index-1\", \"pipeline-1\", \"routing-1\"", + "true, null, null, null", + "true, null, \"\", \"\"", + "true, null, \" \", \" \"", + "true, null, \"pipeline-1\", \"routing-1\"", + "true, \"index-1\", \"pipeline-1\", \"routing-1\"" + }) + public void testBulkRequestAPIWithQueryParams(boolean testBulkRequestAPIWithIndexInPath, final String index, final String pipeline, final String routing) throws Exception { AggregatedHttpRequest testRequest; AggregatedHttpResponse postResponse; @@ -111,14 +124,14 @@ public void testBulkRequestAPISuccess(boolean testBulkRequestAPIWithIndexInPath) // Prepare testRequest = generateRandomValidBulkRequestWithNoIndexInBody(2); // When - postResponse = openSearchAPIService.doPostBulkIndex(serviceRequestContext, testRequest, Optional.empty(), - Optional.ofNullable("pipeline-1"), Optional.ofNullable("routing-1")).aggregate().get(); + postResponse = openSearchAPIService.doPostBulkIndex(serviceRequestContext, testRequest, index, + pipeline, routing).aggregate().get(); } else { // Prepare testRequest = generateRandomValidBulkRequest(2); // When postResponse = openSearchAPIService.doPostBulk(serviceRequestContext, testRequest, - Optional.empty(), Optional.empty()).aggregate().get(); + pipeline, routing).aggregate().get(); } // Then @@ -132,19 +145,31 @@ public void testBulkRequestAPISuccess(boolean testBulkRequestAPIWithIndexInPath) } @ParameterizedTest - @ValueSource(booleans = {false, true}) - public void testBulkRequestAPISuccessWithMultipleBulkActions(boolean testBulkRequestAPIWithIndexInPath) throws Exception { + @CsvSource(value = { + "false, null, null, null", + "false, null, \"\", \"\"", + "false, null, \" \", \" \"", + "false, null, \"pipeline-1\", \"routing-1\"", + "false, \"index-1\", \"pipeline-1\", \"routing-1\"", + "true, null, null, null", + "true, null, \"\", \"\"", + "true, null, \" \", \" \"", + "true, null, \"pipeline-1\", \"routing-1\"", + "true, \"index-1\", \"pipeline-1\", \"routing-1\"" + }) + public void testBulkRequestAPISuccessWithMultipleBulkActions(boolean testBulkRequestAPIWithIndexInPath, final String index, final String pipeline, final String routing) throws Exception { // Prepare AggregatedHttpRequest testRequest = generateGoodBulkRequestWithMultipleActions(2); // When + AggregatedHttpResponse postResponse; if (testBulkRequestAPIWithIndexInPath) { - postResponse = openSearchAPIService.doPostBulkIndex(serviceRequestContext, testRequest, Optional.empty(), - Optional.ofNullable("pipeline-1"), Optional.ofNullable("routing-1")).aggregate().get(); + postResponse = openSearchAPIService.doPostBulkIndex(serviceRequestContext, testRequest, index, + pipeline, routing).aggregate().get(); } else { postResponse = openSearchAPIService.doPostBulk(serviceRequestContext, testRequest, - Optional.empty(), Optional.empty()).aggregate().get(); + pipeline, routing).aggregate().get(); } // Then @@ -157,6 +182,80 @@ public void testBulkRequestAPISuccessWithMultipleBulkActions(boolean testBulkReq verify(requestProcessDuration, times(1)).recordCallable(ArgumentMatchers.>any()); } + @ParameterizedTest + @ValueSource(booleans = {false, true}) + public void testBulkRequestAPIWithByteBuffer(boolean testBulkRequestAPIWithIndexInPath) throws Exception { + Buffer> blockingBuffer = mock(BlockingBuffer.class); + OpenSearchAPIService openSearchAPIService = new OpenSearchAPIService(TEST_TIMEOUT_IN_MILLIS, blockingBuffer, pluginMetrics); + when(blockingBuffer.isByteBuffer()).thenReturn(true); + + AggregatedHttpRequest testRequest = generateGoodBulkRequestWithMultipleActions(2); + if (testBulkRequestAPIWithIndexInPath) { + openSearchAPIService.doPostBulkIndex(serviceRequestContext, testRequest, null, + null, null).aggregate().get(); + } else { + openSearchAPIService.doPostBulk(serviceRequestContext, testRequest, + null, null).aggregate().get(); + } + verify(blockingBuffer, times(1)).writeBytes(eq(testRequest.content().array()), eq(null), eq(TEST_TIMEOUT_IN_MILLIS)); + } + + @Test + public void testBulkRequestAPIEmptyIndex() throws Exception { + RequestHeaders requestHeaders = RequestHeaders.builder() + .contentType(MediaType.JSON) + .method(HttpMethod.POST) + .path("/") + .build(); + List jsonList = new ArrayList<>(); + jsonList.add(mapper.writeValueAsString(Collections.singletonMap("index", Map.of("_index", "\t\r\n", "_id", UUID.randomUUID().toString())))); + jsonList.add(mapper.writeValueAsString(Collections.singletonMap("log", UUID.randomUUID().toString()))); + HttpData httpData = HttpData.ofUtf8(String.join("\n", jsonList)); + AggregatedHttpRequest testRequest = HttpRequest.of(requestHeaders, httpData).aggregate().get(); + + Buffer> blockingBuffer = mock(BlockingBuffer.class); + OpenSearchAPIService openSearchAPIService = new OpenSearchAPIService(TEST_TIMEOUT_IN_MILLIS, blockingBuffer, pluginMetrics); + openSearchAPIService.doPostBulk(serviceRequestContext, testRequest, + null, null).aggregate().get(); + } + + @Test + public void testBulkRequestAPIInvalidRequestMissingDocRow() throws Exception { + RequestHeaders requestHeaders = RequestHeaders.builder() + .contentType(MediaType.JSON) + .method(HttpMethod.POST) + .path("/") + .build(); + List jsonList = new ArrayList<>(); + jsonList.add(mapper.writeValueAsString(Collections.singletonMap("index", Map.of("_index", "index", "_id", UUID.randomUUID().toString())))); + jsonList.add(mapper.writeValueAsString(Collections.singletonMap("index", Map.of("_index", "index", "_id", UUID.randomUUID().toString())))); + HttpData httpData = HttpData.ofUtf8(String.join("\n", jsonList)); + AggregatedHttpRequest testRequest = HttpRequest.of(requestHeaders, httpData).aggregate().get(); + + Buffer> blockingBuffer = mock(BlockingBuffer.class); + OpenSearchAPIService openSearchAPIService = new OpenSearchAPIService(TEST_TIMEOUT_IN_MILLIS, blockingBuffer, pluginMetrics); + assertThrows(IOException.class, () -> openSearchAPIService.doPostBulk(serviceRequestContext, testRequest, + null, null).aggregate().get()); + } + + @Test + public void testBulkRequestAPIInvalidRequestEmptyDocRow() throws Exception { + RequestHeaders requestHeaders = RequestHeaders.builder() + .contentType(MediaType.JSON) + .method(HttpMethod.POST) + .path("/") + .build(); + List jsonList = new ArrayList<>(); + jsonList.add(mapper.writeValueAsString(Collections.singletonMap("index", Map.of("_index", "index", "_id", UUID.randomUUID().toString())))); + HttpData httpData = HttpData.ofUtf8(String.join("\n", jsonList)); + AggregatedHttpRequest testRequest = HttpRequest.of(requestHeaders, httpData).aggregate().get(); + + Buffer> blockingBuffer = mock(BlockingBuffer.class); + OpenSearchAPIService openSearchAPIService = new OpenSearchAPIService(TEST_TIMEOUT_IN_MILLIS, blockingBuffer, pluginMetrics); + assertThrows(IOException.class, () -> openSearchAPIService.doPostBulk(serviceRequestContext, testRequest, + null, null).aggregate().get()); + } + @ParameterizedTest @ValueSource(booleans = {false, true}) public void testBulkRequestAPIBadRequestWithEmpty(boolean testBulkRequestAPIWithIndexInPath) throws Exception { @@ -173,6 +272,16 @@ public void testBulkRequestAPIBadRequestWithInvalidPayload(boolean testBulkReque testBadRequestWithPayload(testBulkRequestAPIWithIndexInPath, String.join("\n", jsonList)); } + @ParameterizedTest + @ValueSource(booleans = {false, true}) + public void testBulkRequestAPIBadRequestWithEmptyMap(boolean testBulkRequestAPIWithIndexInPath) throws Exception { + List jsonList = new ArrayList<>(); + for (int i = 0; i < 2; i++) { + jsonList.add(mapper.writeValueAsString(Collections.emptyMap())); + } + testBadRequestWithPayload(testBulkRequestAPIWithIndexInPath, String.join("\n", jsonList)); + } + @ParameterizedTest @ValueSource(booleans = {false, true}) public void testBulkRequestAPIBadRequestWithInvalidPayload2(boolean testBulkRequestAPIWithIndexInPath) throws Exception { @@ -202,11 +311,11 @@ public void testBulkRequestAPIEntityTooLarge(boolean testBulkRequestAPIWithIndex // When if (testBulkRequestAPIWithIndexInPath) { - assertThrows(SizeOverflowException.class, () -> openSearchAPIService.doPostBulkIndex(serviceRequestContext, testTooLargeRequest, Optional.empty(), - Optional.empty(), Optional.empty()).aggregate().get()); + assertThrows(SizeOverflowException.class, () -> openSearchAPIService.doPostBulkIndex(serviceRequestContext, testTooLargeRequest, null, + null, null).aggregate().get()); } else { - assertThrows(SizeOverflowException.class, () -> openSearchAPIService.doPostBulk(serviceRequestContext, testTooLargeRequest, Optional.empty(), - Optional.empty()).aggregate().get()); + assertThrows(SizeOverflowException.class, () -> openSearchAPIService.doPostBulk(serviceRequestContext, testTooLargeRequest, null, + null).aggregate().get()); } // Then @@ -225,8 +334,8 @@ public void testBulkRequestWithIndexAPIRequestTimeout() throws Exception { lenient().when(serviceRequestContext.isTimedOut()).thenReturn(true); - AggregatedHttpResponse response = openSearchAPIService.doPostBulkIndex(serviceRequestContext, populateDataRequest, Optional.empty(), - Optional.empty(), Optional.empty()).aggregate().get(); + AggregatedHttpResponse response = openSearchAPIService.doPostBulkIndex(serviceRequestContext, populateDataRequest, null, + null, null).aggregate().get(); assertEquals(HttpStatus.REQUEST_TIMEOUT, response.status()); // Then @@ -239,8 +348,8 @@ public void testBulkRequestAPIRequestTimeout() throws Exception { AggregatedHttpRequest populateDataRequest = generateRandomValidBulkRequest(3); lenient().when(serviceRequestContext.isTimedOut()).thenReturn(true); - AggregatedHttpResponse response = openSearchAPIService.doPostBulk(serviceRequestContext, populateDataRequest, Optional.empty(), - Optional.empty()).aggregate().get(); + AggregatedHttpResponse response = openSearchAPIService.doPostBulk(serviceRequestContext, populateDataRequest, null, + null).aggregate().get(); assertEquals(HttpStatus.REQUEST_TIMEOUT, response.status()); // Then @@ -252,18 +361,18 @@ private void testBadRequestWithPayload(boolean testBulkRequestAPIWithIndexInPath RequestHeaders requestHeaders = RequestHeaders.builder() .contentType(MediaType.JSON) .method(HttpMethod.POST) - .path("/opensearch") + .path("/") .build(); HttpData httpData = HttpData.ofUtf8(requestBody); AggregatedHttpRequest testBadRequest = HttpRequest.of(requestHeaders, httpData).aggregate().get(); // When if (testBulkRequestAPIWithIndexInPath) { - assertThrows(IOException.class, () -> openSearchAPIService.doPostBulkIndex(serviceRequestContext, testBadRequest, Optional.empty(), - Optional.empty(), Optional.empty()).aggregate().get()); + assertThrows(IOException.class, () -> openSearchAPIService.doPostBulkIndex(serviceRequestContext, testBadRequest, null, + null, null).aggregate().get()); } else { - assertThrows(IOException.class, () -> openSearchAPIService.doPostBulk(serviceRequestContext, testBadRequest, Optional.empty(), - Optional.empty()).aggregate().get()); + assertThrows(IOException.class, () -> openSearchAPIService.doPostBulk(serviceRequestContext, testBadRequest, null, + null).aggregate().get()); } // Then @@ -280,11 +389,11 @@ private AggregatedHttpRequest generateRandomValidBulkRequest(int numJson) throws RequestHeaders requestHeaders = RequestHeaders.builder() .contentType(MediaType.JSON) .method(HttpMethod.POST) - .path("/opensearch") + .path("/") .build(); List jsonList = new ArrayList<>(); for (int i = 0; i < numJson; i++) { - jsonList.add(mapper.writeValueAsString(Collections.singletonMap("index", Map.of("_index", "test-index", "_id", "123")))); + jsonList.add(mapper.writeValueAsString(Collections.singletonMap("index", Map.of("_index", "test-index", "_id", UUID.randomUUID().toString())))); jsonList.add(mapper.writeValueAsString(Collections.singletonMap("log", UUID.randomUUID().toString()))); } HttpData httpData = HttpData.ofUtf8(String.join("\n", jsonList)); @@ -296,11 +405,11 @@ private AggregatedHttpRequest generateRandomValidBulkRequestWithNoIndexInBody(in RequestHeaders requestHeaders = RequestHeaders.builder() .contentType(MediaType.JSON) .method(HttpMethod.POST) - .path("/opensearch") + .path("/") .build(); List jsonList = new ArrayList<>(); for (int i = 0; i < numJson; i++) { - jsonList.add(mapper.writeValueAsString(Collections.singletonMap("index", Map.of("_id", "123")))); + jsonList.add(mapper.writeValueAsString(Collections.singletonMap("index", Map.of("_id", UUID.randomUUID().toString())))); jsonList.add(mapper.writeValueAsString(Collections.singletonMap("log", UUID.randomUUID().toString()))); } HttpData httpData = HttpData.ofUtf8(String.join("\n", jsonList)); @@ -311,16 +420,20 @@ private AggregatedHttpRequest generateGoodBulkRequestWithMultipleActions(int num RequestHeaders requestHeaders = RequestHeaders.builder() .contentType(MediaType.JSON) .method(HttpMethod.POST) - .path("/opensearch") + .path("/") .build(); List jsonList = new ArrayList<>(); for (int i = 0; i < numJson; i++) { - jsonList.add(mapper.writeValueAsString(Collections.singletonMap("index", Map.of("_index", "test-index", "_id", "123")))); + jsonList.add(mapper.writeValueAsString(Collections.singletonMap("index", Map.of("_index", "test-index", "_id", UUID.randomUUID().toString())))); + jsonList.add(mapper.writeValueAsString(Collections.singletonMap("log", UUID.randomUUID().toString()))); + jsonList.add(mapper.writeValueAsString(Collections.singletonMap("index", Map.of("_index", "", "_id", UUID.randomUUID().toString())))); + jsonList.add(mapper.writeValueAsString(Collections.singletonMap("log", UUID.randomUUID().toString()))); + jsonList.add(mapper.writeValueAsString(Collections.singletonMap("index", Map.of("_index", " ", "_id", UUID.randomUUID().toString())))); jsonList.add(mapper.writeValueAsString(Collections.singletonMap("log", UUID.randomUUID().toString()))); - jsonList.add(mapper.writeValueAsString(Collections.singletonMap("delete", Map.of("_index", "test-index", "_id", "124")))); - jsonList.add(mapper.writeValueAsString(Collections.singletonMap("create", Map.of("_index", "test-index", "_id", "125")))); + jsonList.add(mapper.writeValueAsString(Collections.singletonMap("delete", Map.of("_index", "test-index", "_id", UUID.randomUUID().toString())))); + jsonList.add(mapper.writeValueAsString(Collections.singletonMap("create", Map.of("_index", "test-index", "_id", UUID.randomUUID().toString())))); jsonList.add(mapper.writeValueAsString(Collections.singletonMap("log", UUID.randomUUID().toString()))); - jsonList.add(mapper.writeValueAsString(Collections.singletonMap("update", Map.of("_index", "test-index", "_id", "126")))); + jsonList.add(mapper.writeValueAsString(Collections.singletonMap("update", Map.of("_index", "test-index", "_id", UUID.randomUUID().toString())))); jsonList.add(mapper.writeValueAsString(Collections.singletonMap("log", UUID.randomUUID().toString()))); } HttpData httpData = HttpData.ofUtf8(String.join("\n", jsonList)); diff --git a/data-prepper-plugins/opensearch-api-source/src/test/java/org/opensearch/dataprepper/plugins/source/opensearchapi/OpenSearchAPISourceTest.java b/data-prepper-plugins/opensearch-api-source/src/test/java/org/opensearch/dataprepper/plugins/source/opensearchapi/OpenSearchAPISourceTest.java index f09f5d37e9..0b4553646f 100644 --- a/data-prepper-plugins/opensearch-api-source/src/test/java/org/opensearch/dataprepper/plugins/source/opensearchapi/OpenSearchAPISourceTest.java +++ b/data-prepper-plugins/opensearch-api-source/src/test/java/org/opensearch/dataprepper/plugins/source/opensearchapi/OpenSearchAPISourceTest.java @@ -21,7 +21,6 @@ import com.linecorp.armeria.server.ServerBuilder; import io.micrometer.core.instrument.Measurement; import io.micrometer.core.instrument.Statistic; -import org.apache.commons.io.IOUtils; import io.netty.util.AsciiString; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Assertions; @@ -30,14 +29,12 @@ import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; -import org.mockito.ArgumentCaptor; import org.mockito.Mock; -import org.mockito.MockedStatic; -import org.mockito.Mockito; import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.dataprepper.HttpRequestExceptionHandler; import org.opensearch.dataprepper.armeria.authentication.ArmeriaHttpAuthenticationProvider; import org.opensearch.dataprepper.armeria.authentication.HttpBasicAuthenticationConfig; +import org.opensearch.dataprepper.http.BaseHttpSource; import org.opensearch.dataprepper.http.LogThrottlingRejectHandler; import org.opensearch.dataprepper.metrics.MetricNames; import org.opensearch.dataprepper.metrics.MetricsTestUtil; @@ -56,12 +53,8 @@ import org.opensearch.dataprepper.plugins.source.opensearchapi.model.BulkAPIEventMetadataKeyAttributes; import java.io.ByteArrayOutputStream; -import java.io.File; import java.io.IOException; -import java.io.InputStream; import java.nio.charset.StandardCharsets; -import java.nio.file.Files; -import java.nio.file.Path; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; @@ -70,7 +63,6 @@ import java.util.Map; import java.util.StringJoiner; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.ExecutionException; import java.util.stream.Collectors; import java.util.zip.GZIPOutputStream; @@ -78,7 +70,6 @@ import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThan; -import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.not; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyInt; @@ -86,11 +77,11 @@ import static org.mockito.Mockito.lenient; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.reset; -import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @ExtendWith(MockitoExtension.class) class OpenSearchAPISourceTest { + private static final ObjectMapper mapper = new ObjectMapper(); private final String PLUGIN_NAME = "opensearch_api"; private final String TEST_PIPELINE_NAME = "test_pipeline"; private final String TEST_INDEX = "test-index"; @@ -100,12 +91,8 @@ class OpenSearchAPISourceTest { private final int DEFAULT_THREAD_COUNT = 200; private final int MAX_CONNECTIONS_COUNT = 500; private final int MAX_PENDING_REQUESTS_COUNT = 1024; - private final String TEST_SSL_CERTIFICATE_FILE = getClass().getClassLoader().getResource("test_cert.crt").getFile(); private final String TEST_SSL_KEY_FILE = getClass().getClassLoader().getResource("test_decrypted_key.key").getFile(); - - private static final ObjectMapper mapper = new ObjectMapper(); - @Mock private ServerBuilder serverBuilder; @@ -172,7 +159,7 @@ private void refreshMeasurements() { .add(OpenSearchAPIService.PAYLOAD_SIZE).toString()); serverConnectionsMeasurements = MetricsTestUtil.getMeasurementList( new StringJoiner(MetricNames.DELIMITER).add(metricNamePrefix) - .add(OpenSearchAPISource.SERVER_CONNECTIONS).toString()); + .add(BaseHttpSource.SERVER_CONNECTIONS).toString()); } private byte[] createGZipCompressedPayload(final String payload) throws IOException { @@ -256,11 +243,11 @@ public void testHealthCheck() { // When WebClient.of().execute(RequestHeaders.builder() - .scheme(SessionProtocol.HTTP) - .authority(AUTHORITY) - .method(HttpMethod.GET) - .path("/health") - .build()) + .scheme(SessionProtocol.HTTP) + .authority(AUTHORITY) + .method(HttpMethod.GET) + .path("/health") + .build()) .aggregate() .whenComplete((i, ex) -> assertSecureResponseWithStatusCode(i, HttpStatus.OK)).join(); } @@ -271,8 +258,8 @@ public void testHealthCheckUnauthenticatedDisabled() { when(sourceConfig.isUnauthenticatedHealthCheck()).thenReturn(false); when(sourceConfig.getAuthentication()).thenReturn(new PluginModel("http_basic", Map.of( - "username", "test", - "password", "test" + "username", "test", + "password", "test" ))); pluginMetrics = PluginMetrics.fromNames(PLUGIN_NAME, TEST_PIPELINE_NAME); openSearchAPISource = new OpenSearchAPISource(sourceConfig, pluginMetrics, pluginFactory, pipelineDescription); @@ -300,13 +287,13 @@ public void testBulkRequestJsonResponse400WithEmptyPayload(boolean includeIndexI // When WebClient.of().execute(RequestHeaders.builder() - .scheme(SessionProtocol.HTTP) - .authority(AUTHORITY) - .method(HttpMethod.POST) - .path(includeIndexInPath ? "/opensearch/" + TEST_INDEX + "/_bulk" :"/opensearch/_bulk") - .contentType(MediaType.JSON_UTF_8) - .build(), - HttpData.ofUtf8(testBadData)) + .scheme(SessionProtocol.HTTP) + .authority(AUTHORITY) + .method(HttpMethod.POST) + .path(includeIndexInPath ? "/" + TEST_INDEX + "/_bulk" : "/_bulk") + .contentType(MediaType.JSON_UTF_8) + .build(), + HttpData.ofUtf8(testBadData)) .aggregate() .whenComplete((i, ex) -> assertSecureResponseWithStatusCode(i, HttpStatus.BAD_REQUEST)).join(); @@ -338,7 +325,7 @@ public void testBulkRequestJsonResponse400WithInvalidPayload(boolean includeInde .scheme(SessionProtocol.HTTP) .authority(AUTHORITY) .method(HttpMethod.POST) - .path(includeIndexInPath ? "/opensearch/" + TEST_INDEX + "/_bulk" :"/opensearch/_bulk") + .path(includeIndexInPath ? "/" + TEST_INDEX + "/_bulk" : "/_bulk") .contentType(MediaType.JSON_UTF_8) .build(), HttpData.ofUtf8(testBadData)) @@ -385,7 +372,7 @@ private void testBulkRequestAPI200(boolean includeIndexInPath, boolean useCompre .scheme(SessionProtocol.HTTP) .authority(AUTHORITY) .method(HttpMethod.POST) - .path(includeIndexInPath ? "/opensearch/" + TEST_INDEX + "/_bulk" :"/opensearch/_bulk") + .path(includeIndexInPath ? "/" + TEST_INDEX + "/_bulk" : "/_bulk") .add(HttpHeaderNames.CONTENT_ENCODING, "gzip") .build(), createGZipCompressedPayload(testData)) @@ -396,7 +383,7 @@ private void testBulkRequestAPI200(boolean includeIndexInPath, boolean useCompre .scheme(SessionProtocol.HTTP) .authority(AUTHORITY) .method(HttpMethod.POST) - .path(includeIndexInPath ? "/opensearch/" + TEST_INDEX + "/_bulk" : "/opensearch/_bulk") + .path(includeIndexInPath ? "/" + TEST_INDEX + "/_bulk" : "/_bulk") .contentType(MediaType.JSON_UTF_8) .build(), HttpData.ofUtf8(testData)) @@ -463,7 +450,7 @@ private void testBulkRequestJsonResponse408(boolean includeIndexInPath) throws J final RequestHeaders testRequestHeaders = RequestHeaders.builder().scheme(SessionProtocol.HTTP) .authority(AUTHORITY) .method(HttpMethod.POST) - .path(includeIndexInPath? "/opensearch/"+TEST_INDEX+"/_bulk" : "/opensearch/_bulk") + .path(includeIndexInPath ? "/" + TEST_INDEX + "/_bulk" : "/_bulk") .contentType(MediaType.JSON_UTF_8) .build(); final HttpData testHttpData = HttpData.ofUtf8(generateTestData(includeIndexInPath, 1)); @@ -507,7 +494,7 @@ private void testBulkRequestJsonResponse413(boolean includeIndexInPath) throws J .scheme(SessionProtocol.HTTP) .authority(AUTHORITY) .method(HttpMethod.POST) - .path(includeIndexInPath ? "/opensearch/"+TEST_INDEX+"/_bulk" : "/opensearch/_bulk") + .path(includeIndexInPath ? "/" + TEST_INDEX + "/_bulk" : "/_bulk") .contentType(MediaType.JSON_UTF_8) .build(), HttpData.ofUtf8(testData)) @@ -552,7 +539,7 @@ public void testOpenSearchAPISourceServerConnectionsMetric(boolean includeIndexI final RequestHeaders testRequestHeaders = RequestHeaders.builder().scheme(SessionProtocol.HTTP) .authority(AUTHORITY) .method(HttpMethod.POST) - .path(includeIndexInPath ? "/opensearch/"+TEST_INDEX+"/_bulk" : "/opensearch/_bulk") + .path(includeIndexInPath ? "/" + TEST_INDEX + "/_bulk" : "/_bulk") .contentType(MediaType.JSON_UTF_8) .build(); final HttpData testHttpData = HttpData.ofUtf8(generateTestData(includeIndexInPath, 1)); @@ -566,34 +553,6 @@ public void testOpenSearchAPISourceServerConnectionsMetric(boolean includeIndexI Assertions.assertEquals(1.0, serverConnectionsMeasurement.getValue()); } - @Test - public void testOpenSearchAPISourceServerStartCertFileSuccess() throws IOException { - try (MockedStatic armeriaServerMock = Mockito.mockStatic(Server.class)) { - armeriaServerMock.when(Server::builder).thenReturn(serverBuilder); - when(server.stop()).thenReturn(completableFuture); - - final Path certFilePath = new File(TEST_SSL_CERTIFICATE_FILE).toPath(); - final Path keyFilePath = new File(TEST_SSL_KEY_FILE).toPath(); - final String certAsString = Files.readString(certFilePath); - final String keyAsString = Files.readString(keyFilePath); - - when(sourceConfig.isSsl()).thenReturn(true); - when(sourceConfig.getSslCertificateFile()).thenReturn(TEST_SSL_CERTIFICATE_FILE); - when(sourceConfig.getSslKeyFile()).thenReturn(TEST_SSL_KEY_FILE); - openSearchAPISource = new OpenSearchAPISource(sourceConfig, pluginMetrics, pluginFactory, pipelineDescription); - openSearchAPISource.start(testBuffer); - openSearchAPISource.stop(); - - final ArgumentCaptor certificateIs = ArgumentCaptor.forClass(InputStream.class); - final ArgumentCaptor privateKeyIs = ArgumentCaptor.forClass(InputStream.class); - verify(serverBuilder).tls(certificateIs.capture(), privateKeyIs.capture()); - final String actualCertificate = IOUtils.toString(certificateIs.getValue(), StandardCharsets.UTF_8.name()); - final String actualPrivateKey = IOUtils.toString(privateKeyIs.getValue(), StandardCharsets.UTF_8.name()); - assertThat(actualCertificate, is(certAsString)); - assertThat(actualPrivateKey, is(keyAsString)); - } - } - @ParameterizedTest @ValueSource(booleans = {false, true}) void testBulkRequestAPIJsonResponse(boolean includeIndexInPath) throws JsonProcessingException { @@ -613,131 +572,17 @@ void testBulkRequestAPIJsonResponse(boolean includeIndexInPath) throws JsonProce openSearchAPISource.start(testBuffer); WebClient.builder().factory(ClientFactory.insecure()).build().execute(RequestHeaders.builder() - .scheme(SessionProtocol.HTTPS) - .authority(AUTHORITY) - .method(HttpMethod.POST) - .path(includeIndexInPath ? "/opensearch/"+TEST_INDEX+"/_bulk" : "/opensearch/_bulk") - .contentType(MediaType.JSON_UTF_8) - .build(), - HttpData.ofUtf8(generateTestData(includeIndexInPath, 1))) + .scheme(SessionProtocol.HTTPS) + .authority(AUTHORITY) + .method(HttpMethod.POST) + .path(includeIndexInPath ? "/" + TEST_INDEX + "/_bulk" : "/_bulk") + .contentType(MediaType.JSON_UTF_8) + .build(), + HttpData.ofUtf8(generateTestData(includeIndexInPath, 1))) .aggregate() .whenComplete((i, ex) -> assertSecureResponseWithStatusCode(i, HttpStatus.OK)).join(); } - @Test - public void testDoubleStart() { - // starting server - openSearchAPISource.start(testBuffer); - // double start server - Assertions.assertThrows(IllegalStateException.class, () -> openSearchAPISource.start(testBuffer)); - } - - @Test - public void testStartWithEmptyBuffer() { - final OpenSearchAPISource source = new OpenSearchAPISource(sourceConfig, pluginMetrics, pluginFactory, pipelineDescription); - Assertions.assertThrows(IllegalStateException.class, () -> source.start(null)); - } - - @Test - public void testStartWithServerExecutionExceptionNoCause() throws ExecutionException, InterruptedException { - // Prepare - final OpenSearchAPISource source = new OpenSearchAPISource(sourceConfig, pluginMetrics, pluginFactory, pipelineDescription); - try (MockedStatic armeriaServerMock = Mockito.mockStatic(Server.class)) { - armeriaServerMock.when(Server::builder).thenReturn(serverBuilder); - when(completableFuture.get()).thenThrow(new ExecutionException("", null)); - - // When/Then - Assertions.assertThrows(RuntimeException.class, () -> source.start(testBuffer)); - } - } - - @Test - public void testStartWithServerExecutionExceptionWithCause() throws ExecutionException, InterruptedException { - // Prepare - final OpenSearchAPISource source = new OpenSearchAPISource(sourceConfig, pluginMetrics, pluginFactory, pipelineDescription); - try (MockedStatic armeriaServerMock = Mockito.mockStatic(Server.class)) { - armeriaServerMock.when(Server::builder).thenReturn(serverBuilder); - final NullPointerException expCause = new NullPointerException(); - when(completableFuture.get()).thenThrow(new ExecutionException("", expCause)); - - // When/Then - final RuntimeException ex = Assertions.assertThrows(RuntimeException.class, () -> source.start(testBuffer)); - Assertions.assertEquals(expCause, ex); - } - } - - @Test - public void testStartWithInterruptedException() throws ExecutionException, InterruptedException { - // Prepare - final OpenSearchAPISource source = new OpenSearchAPISource(sourceConfig, pluginMetrics, pluginFactory, pipelineDescription); - try (MockedStatic armeriaServerMock = Mockito.mockStatic(Server.class)) { - armeriaServerMock.when(Server::builder).thenReturn(serverBuilder); - when(completableFuture.get()).thenThrow(new InterruptedException()); - - // When/Then - Assertions.assertThrows(RuntimeException.class, () -> source.start(testBuffer)); - Assertions.assertTrue(Thread.interrupted()); - } - } - - @Test - public void testStopWithServerExecutionExceptionNoCause() throws ExecutionException, InterruptedException { - // Prepare - final OpenSearchAPISource source = new OpenSearchAPISource(sourceConfig, pluginMetrics, pluginFactory, pipelineDescription); - try (MockedStatic armeriaServerMock = Mockito.mockStatic(Server.class)) { - armeriaServerMock.when(Server::builder).thenReturn(serverBuilder); - source.start(testBuffer); - when(server.stop()).thenReturn(completableFuture); - - // When/Then - when(completableFuture.get()).thenThrow(new ExecutionException("", null)); - Assertions.assertThrows(RuntimeException.class, source::stop); - } - } - - @Test - public void testStopWithServerExecutionExceptionWithCause() throws ExecutionException, InterruptedException { - // Prepare - final OpenSearchAPISource source = new OpenSearchAPISource(sourceConfig, pluginMetrics, pluginFactory, pipelineDescription); - try (MockedStatic armeriaServerMock = Mockito.mockStatic(Server.class)) { - armeriaServerMock.when(Server::builder).thenReturn(serverBuilder); - source.start(testBuffer); - when(server.stop()).thenReturn(completableFuture); - final NullPointerException expCause = new NullPointerException(); - when(completableFuture.get()).thenThrow(new ExecutionException("", expCause)); - - // When/Then - final RuntimeException ex = Assertions.assertThrows(RuntimeException.class, source::stop); - Assertions.assertEquals(expCause, ex); - } - } - - @Test - public void testStopWithInterruptedException() throws ExecutionException, InterruptedException { - // Prepare - final OpenSearchAPISource source = new OpenSearchAPISource(sourceConfig, pluginMetrics, pluginFactory, pipelineDescription); - try (MockedStatic armeriaServerMock = Mockito.mockStatic(Server.class)) { - armeriaServerMock.when(Server::builder).thenReturn(serverBuilder); - source.start(testBuffer); - when(server.stop()).thenReturn(completableFuture); - when(completableFuture.get()).thenThrow(new InterruptedException()); - - // When/Then - Assertions.assertThrows(RuntimeException.class, source::stop); - Assertions.assertTrue(Thread.interrupted()); - } - } - - @Test - public void testRunAnotherSourceWithSamePort() { - // starting server - openSearchAPISource.start(testBuffer); - - final OpenSearchAPISource secondSource = new OpenSearchAPISource(sourceConfig, pluginMetrics, pluginFactory, pipelineDescription); - //Expect RuntimeException because when port is already in use, BindException is thrown which is not RuntimeException - Assertions.assertThrows(RuntimeException.class, () -> secondSource.start(testBuffer)); - } - @Test public void request_that_exceeds_maxRequestLength_returns_413() throws JsonProcessingException { reset(sourceConfig); @@ -765,10 +610,10 @@ public void request_that_exceeds_maxRequestLength_returns_413() throws JsonProce .scheme(SessionProtocol.HTTP) .authority(AUTHORITY) .method(HttpMethod.POST) - .path("/opensearch") + .path("/") .contentType(MediaType.JSON_UTF_8) .build(), - HttpData.ofUtf8(testData)) + HttpData.ofUtf8(testData)) .aggregate() .whenComplete((i, ex) -> assertSecureResponseWithStatusCode(i, HttpStatus.REQUEST_ENTITY_TOO_LARGE)).join(); diff --git a/data-prepper-plugins/opensearch-api-source/src/test/java/org/opensearch/dataprepper/plugins/source/opensearchapi/model/BulkAPIRequestParamsTest.java b/data-prepper-plugins/opensearch-api-source/src/test/java/org/opensearch/dataprepper/plugins/source/opensearchapi/model/BulkAPIRequestParamsTest.java new file mode 100644 index 0000000000..a448b8af1f --- /dev/null +++ b/data-prepper-plugins/opensearch-api-source/src/test/java/org/opensearch/dataprepper/plugins/source/opensearchapi/model/BulkAPIRequestParamsTest.java @@ -0,0 +1,42 @@ +package org.opensearch.dataprepper.plugins.source.opensearchapi.model; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; + +public class BulkAPIRequestParamsTest { + + private static final String testIndex = "test-index"; + private static final String testPipeline = "test-pipeline"; + private static final String testRouting = "test-routing"; + + @Test + public void testValidObjectCreated() { + BulkAPIRequestParams bulkAPIRequestParams = BulkAPIRequestParams.builder().build(); + assertNull(bulkAPIRequestParams.getIndex()); + assertNull(bulkAPIRequestParams.getPipeline()); + assertNull(bulkAPIRequestParams.getRouting()); + } + + @Test + public void testValidObjectCreatedWithNonNullFields() { + BulkAPIRequestParams bulkAPIRequestParams = BulkAPIRequestParams.builder() + .index(testIndex) + .pipeline(testPipeline) + .routing(testRouting) + .build(); + assertEquals(testIndex, bulkAPIRequestParams.getIndex()); + assertEquals(testPipeline, bulkAPIRequestParams.getPipeline()); + assertEquals(testRouting, bulkAPIRequestParams.getRouting()); + } + + @Test + public void testValidObjectCreatedWithNonNullFields2() { + BulkAPIRequestParams bulkAPIRequestParams = new BulkAPIRequestParams(testIndex, testPipeline, testRouting); + assertEquals(testIndex, bulkAPIRequestParams.getIndex()); + assertEquals(testPipeline, bulkAPIRequestParams.getPipeline()); + assertEquals(testRouting, bulkAPIRequestParams.getRouting()); + } + +}