Skip to content

Commit

Permalink
Adds sigv4 support to Elasticsearch client (opensearch-project#3305)
Browse files Browse the repository at this point in the history
Adds sigv4 support to Elasticsearch client. Move AwsRequestSigningApacheInterceptor to aws-plugin-api, use in os source and sink

Signed-off-by: Taylor Gray <[email protected]>
  • Loading branch information
graytaylor0 authored Sep 6, 2023
1 parent eff31fe commit 40980c1
Show file tree
Hide file tree
Showing 9 changed files with 226 additions and 24 deletions.
3 changes: 2 additions & 1 deletion data-prepper-plugins/aws-plugin-api/build.gradle
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@

dependencies {
implementation 'software.amazon.awssdk:auth'
implementation 'software.amazon.awssdk:apache-client'
}

test {
Expand All @@ -12,7 +13,7 @@ jacocoTestCoverageVerification {
violationRules {
rule {
limit {
minimum = 1.0
minimum = 0.99
}
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,8 @@
/*
* Copyright OpenSearch Contributors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance with
* the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
* CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.dataprepper.plugins.sink.opensearch;
package org.opensearch.dataprepper.aws.api;

import org.apache.http.Header;
import org.apache.http.HttpEntityEnclosingRequest;
Expand Down Expand Up @@ -48,7 +40,7 @@
* An {@link HttpRequestInterceptor} that signs requests using any AWS {@link Signer}
* and {@link AwsCredentialsProvider}.
*/
final class AwsRequestSigningApacheInterceptor implements HttpRequestInterceptor {
public final class AwsRequestSigningApache4Interceptor implements HttpRequestInterceptor {

/**
* Constant to check content-length
Expand Down Expand Up @@ -90,10 +82,10 @@ final class AwsRequestSigningApacheInterceptor implements HttpRequestInterceptor
* @param awsCredentialsProvider source of AWS credentials for signing
* @param region signing region
*/
public AwsRequestSigningApacheInterceptor(final String service,
final Signer signer,
final AwsCredentialsProvider awsCredentialsProvider,
final Region region) {
public AwsRequestSigningApache4Interceptor(final String service,
final Signer signer,
final AwsCredentialsProvider awsCredentialsProvider,
final Region region) {
this.service = Objects.requireNonNull(service);
this.signer = Objects.requireNonNull(signer);
this.awsCredentialsProvider = Objects.requireNonNull(awsCredentialsProvider);
Expand All @@ -107,10 +99,10 @@ public AwsRequestSigningApacheInterceptor(final String service,
* @param awsCredentialsProvider source of AWS credentials for signing
* @param region signing region
*/
public AwsRequestSigningApacheInterceptor(final String service,
final Signer signer,
final AwsCredentialsProvider awsCredentialsProvider,
final String region) {
public AwsRequestSigningApache4Interceptor(final String service,
final Signer signer,
final AwsCredentialsProvider awsCredentialsProvider,
final String region) {
this(service, signer, awsCredentialsProvider, Region.of(region));
}

Expand Down Expand Up @@ -177,7 +169,7 @@ private URI buildUri(final HttpContext context, URIBuilder uriBuilder) throws IO
}

return uriBuilder.build();
} catch (URISyntaxException e) {
} catch (final Exception e) {
throw new IOException("Invalid URI", e);
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.dataprepper.aws.api;

import org.apache.http.HttpEntity;
import org.apache.http.HttpEntityEnclosingRequest;
import org.apache.http.HttpHost;
import org.apache.http.RequestLine;
import org.apache.http.message.BasicHeader;
import org.apache.http.protocol.HttpContext;
import org.apache.http.protocol.HttpCoreContext;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.core.interceptor.ExecutionAttributes;
import software.amazon.awssdk.core.signer.Signer;
import software.amazon.awssdk.http.ContentStreamProvider;
import software.amazon.awssdk.http.SdkHttpFullRequest;
import software.amazon.awssdk.regions.Region;

import java.io.IOException;
import java.io.InputStream;
import java.util.List;
import java.util.Map;
import java.util.Optional;

import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

@ExtendWith(MockitoExtension.class)
public class AwsRequestSigningApache4InterceptorTest {

@Mock
private Signer signer;

@Mock
private AwsCredentialsProvider awsCredentialsProvider;

@Mock
private HttpEntityEnclosingRequest httpRequest;

@Mock
private HttpContext httpContext;

private AwsRequestSigningApache4Interceptor createObjectUnderTest() {
return new AwsRequestSigningApache4Interceptor("es", signer, awsCredentialsProvider, Region.US_EAST_1);
}

@Test
void invalidURI_throws_IOException() {

final RequestLine requestLine = mock(RequestLine.class);
when(requestLine.getUri()).thenReturn("http://invalid-uri.com/file[/].html\n");

when(httpRequest.getRequestLine()).thenReturn(requestLine);

final AwsRequestSigningApache4Interceptor objectUnderTest = new AwsRequestSigningApache4Interceptor("es", signer, awsCredentialsProvider, "us-east-1");

assertThrows(IOException.class, () -> objectUnderTest.process(httpRequest, httpContext));
}

@Test
void IOException_is_thrown_when_buildURI_throws_exception() {
final RequestLine requestLine = mock(RequestLine.class);
when(requestLine.getMethod()).thenReturn("GET");
when(requestLine.getUri()).thenReturn("http://localhost?param=test");
when(httpRequest.getRequestLine()).thenReturn(requestLine);

when(httpContext.getAttribute(HttpCoreContext.HTTP_TARGET_HOST)).thenThrow(RuntimeException.class);

final AwsRequestSigningApache4Interceptor objectUnderTest = createObjectUnderTest();

assertThrows(IOException.class, () -> objectUnderTest.process(httpRequest, httpContext));
}

@Test
void empty_contentStreamProvider_throws_IllegalStateException() throws IOException {
final RequestLine requestLine = mock(RequestLine.class);
when(requestLine.getMethod()).thenReturn("GET");
when(requestLine.getUri()).thenReturn("http://localhost?param=test");
when(httpRequest.getRequestLine()).thenReturn(requestLine);
when(httpRequest.getAllHeaders()).thenReturn(new BasicHeader[]{
new BasicHeader("test-name", "test-value"),
new BasicHeader("content-length", "0")
});

final HttpEntity httpEntity = mock(HttpEntity.class);
final InputStream inputStream = mock(InputStream.class);
when(httpEntity.getContent()).thenReturn(inputStream);

when((httpRequest).getEntity()).thenReturn(httpEntity);

final HttpHost httpHost = HttpHost.create("http://localhost?param=test");
when(httpContext.getAttribute(HttpCoreContext.HTTP_TARGET_HOST)).thenReturn(httpHost);

final SdkHttpFullRequest signedRequest = mock(SdkHttpFullRequest.class);
when(signedRequest.headers()).thenReturn(Map.of("test-name", List.of("test-value")));
when(signedRequest.contentStreamProvider()).thenReturn(Optional.empty());
when(signer.sign(any(SdkHttpFullRequest.class), any(ExecutionAttributes.class)))
.thenReturn(signedRequest);

final AwsRequestSigningApache4Interceptor objectUnderTest = createObjectUnderTest();

assertThrows(IllegalStateException.class, () -> objectUnderTest.process(httpRequest, httpContext));
}

@Test
void testHappyPath() throws IOException {
final RequestLine requestLine = mock(RequestLine.class);
when(requestLine.getMethod()).thenReturn("GET");
when(requestLine.getUri()).thenReturn("http://localhost?param=test");
when(httpRequest.getRequestLine()).thenReturn(requestLine);
when(httpRequest.getAllHeaders()).thenReturn(new BasicHeader[]{
new BasicHeader("test-name", "test-value"),
new BasicHeader("content-length", "0")
});

final HttpEntity httpEntity = mock(HttpEntity.class);
final InputStream inputStream = mock(InputStream.class);
when(httpEntity.getContent()).thenReturn(inputStream);

when((httpRequest).getEntity()).thenReturn(httpEntity);

final HttpHost httpHost = HttpHost.create("http://localhost?param=test");
when(httpContext.getAttribute(HttpCoreContext.HTTP_TARGET_HOST)).thenReturn(httpHost);

final SdkHttpFullRequest signedRequest = mock(SdkHttpFullRequest.class);
when(signedRequest.headers()).thenReturn(Map.of("test-name", List.of("test-value")));
final ContentStreamProvider contentStreamProvider = mock(ContentStreamProvider.class);
final InputStream contentInputStream = mock(InputStream.class);
when(contentStreamProvider.newStream()).thenReturn(contentInputStream);
when(signedRequest.contentStreamProvider()).thenReturn(Optional.of(contentStreamProvider));
when(signer.sign(any(SdkHttpFullRequest.class), any(ExecutionAttributes.class)))
.thenReturn(signedRequest);
createObjectUnderTest().process(httpRequest, httpContext);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ public void run() {
indexPartition.get(), sourceCoordinator);

openSearchSourcePluginMetrics.getIndicesProcessedCounter().increment();
LOG.info("Completed processing for index: '{}'", indexPartition.get().getPartitionKey());
} catch (final PartitionUpdateException | PartitionNotFoundException | PartitionNotOwnedException e) {
LOG.warn("The search_after worker received an exception from the source coordinator. There is a potential for duplicate data for index {}, giving up partition and getting next partition: {}", indexPartition.get().getPartitionKey(), e.getMessage());
sourceCoordinator.giveUpPartitions();
Expand Down Expand Up @@ -125,6 +126,8 @@ public void run() {
private void processIndex(final SourcePartition<OpenSearchIndexProgressState> openSearchIndexPartition,
final AcknowledgementSet acknowledgementSet) {
final String indexName = openSearchIndexPartition.getPartitionKey();
LOG.info("Started processing for index: '{}'", indexName);

Optional<OpenSearchIndexProgressState> openSearchIndexProgressStateOptional = openSearchIndexPartition.getPartitionState();

if (openSearchIndexProgressStateOptional.isEmpty()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ public void run() {
indexPartition.get(), sourceCoordinator);

openSearchSourcePluginMetrics.getIndicesProcessedCounter().increment();
LOG.info("Completed processing for index: '{}'", indexPartition.get().getPartitionKey());
} catch (final PartitionUpdateException | PartitionNotFoundException | PartitionNotOwnedException e) {
LOG.warn("PitWorker received an exception from the source coordinator. There is a potential for duplicate data for index {}, giving up partition and getting next partition: {}", indexPartition.get().getPartitionKey(), e.getMessage());
sourceCoordinator.giveUpPartitions();
Expand Down Expand Up @@ -149,6 +150,8 @@ public void run() {
private void processIndex(final SourcePartition<OpenSearchIndexProgressState> openSearchIndexPartition,
final AcknowledgementSet acknowledgementSet) {
final String indexName = openSearchIndexPartition.getPartitionKey();

LOG.info("Starting processing for index: '{}'", indexName);
Optional<OpenSearchIndexProgressState> openSearchIndexProgressStateOptional = openSearchIndexPartition.getPartitionState();

if (openSearchIndexProgressStateOptional.isEmpty()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ public void run() {
indexPartition.get(), sourceCoordinator);

openSearchSourcePluginMetrics.getIndicesProcessedCounter().increment();
LOG.info("Completed processing for index: '{}'", indexPartition.get().getPartitionKey());
} catch (final PartitionUpdateException | PartitionNotFoundException | PartitionNotOwnedException e) {
LOG.warn("ScrollWorker received an exception from the source coordinator. There is a potential for duplicate data for index {}, giving up partition and getting next partition: {}", indexPartition.get().getPartitionKey(), e.getMessage());
sourceCoordinator.giveUpPartitions();
Expand Down Expand Up @@ -142,6 +143,7 @@ public void run() {
private void processIndex(final SourcePartition<OpenSearchIndexProgressState> openSearchIndexPartition,
final AcknowledgementSet acknowledgementSet) {
final String indexName = openSearchIndexPartition.getPartitionKey();
LOG.info("Started processing for index: '{}'", indexName);

final Integer batchSize = openSearchSourceConfiguration.getSearchConfiguration().getBatchSize();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import co.elastic.clients.transport.ElasticsearchTransport;
import org.apache.http.Header;
import org.apache.http.HttpHost;
import org.apache.http.HttpRequestInterceptor;
import org.apache.http.HttpResponseInterceptor;
import org.apache.http.auth.AuthScope;
import org.apache.http.auth.UsernamePasswordCredentials;
Expand All @@ -31,11 +32,13 @@
import org.opensearch.client.transport.rest_client.RestClientTransport;
import org.opensearch.dataprepper.aws.api.AwsCredentialsOptions;
import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier;
import org.opensearch.dataprepper.aws.api.AwsRequestSigningApache4Interceptor;
import org.opensearch.dataprepper.plugins.source.opensearch.OpenSearchSourceConfiguration;
import org.opensearch.dataprepper.plugins.source.opensearch.configuration.ConnectionConfiguration;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.auth.signer.Aws4Signer;
import software.amazon.awssdk.http.SdkHttpClient;
import software.amazon.awssdk.http.apache.ApacheHttpClient;

Expand Down Expand Up @@ -165,12 +168,38 @@ private org.elasticsearch.client.RestClient createElasticSearchRestClient(final
new BasicHeader("Content-type", "application/json")
});

attachBasicAuth(restClientBuilder, openSearchSourceConfiguration);
if (Objects.nonNull(openSearchSourceConfiguration.getAwsAuthenticationOptions())) {
attachSigV4ForElasticsearchClient(restClientBuilder, openSearchSourceConfiguration);
} else {
attachBasicAuth(restClientBuilder, openSearchSourceConfiguration);
}
setConnectAndSocketTimeout(restClientBuilder, openSearchSourceConfiguration);

return restClientBuilder.build();
}

private void attachSigV4ForElasticsearchClient(final org.elasticsearch.client.RestClientBuilder restClientBuilder,
final OpenSearchSourceConfiguration openSearchSourceConfiguration) {
final AwsCredentialsProvider awsCredentialsProvider = awsCredentialsSupplier.getProvider(AwsCredentialsOptions.builder()
.withRegion(openSearchSourceConfiguration.getAwsAuthenticationOptions().getAwsRegion())
.withStsRoleArn(openSearchSourceConfiguration.getAwsAuthenticationOptions().getAwsStsRoleArn())
.withStsExternalId(openSearchSourceConfiguration.getAwsAuthenticationOptions().getAwsStsExternalId())
.withStsHeaderOverrides(openSearchSourceConfiguration.getAwsAuthenticationOptions().getAwsStsHeaderOverrides())
.build());
final Aws4Signer aws4Signer = Aws4Signer.create();
final HttpRequestInterceptor httpRequestInterceptor = new AwsRequestSigningApache4Interceptor(AOS_SERVICE_NAME, aws4Signer,
awsCredentialsProvider, openSearchSourceConfiguration.getAwsAuthenticationOptions().getAwsRegion());
restClientBuilder.setHttpClientConfigCallback(httpClientBuilder -> {
httpClientBuilder.addInterceptorLast(httpRequestInterceptor);
attachSSLContext(httpClientBuilder, openSearchSourceConfiguration);
httpClientBuilder.addInterceptorLast(
(HttpResponseInterceptor)
(response, context) ->
response.addHeader("X-Elastic-Product", "Elasticsearch"));
return httpClientBuilder;
});
}

private void attachBasicAuth(final RestClientBuilder restClientBuilder, final OpenSearchSourceConfiguration openSearchSourceConfiguration) {

restClientBuilder.setHttpClientConfigCallback(httpClientBuilder -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,33 @@ void provideElasticSearchClient_with_username_and_password() {
verifyNoInteractions(awsCredentialsSupplier);
}

@Test
void provideElasticSearchClient_with_aws_auth() {
when(connectionConfiguration.getCertPath()).thenReturn(null);
when(connectionConfiguration.getSocketTimeout()).thenReturn(null);
when(connectionConfiguration.getConnectTimeout()).thenReturn(null);

final AwsAuthenticationConfiguration awsAuthenticationConfiguration = mock(AwsAuthenticationConfiguration.class);
when(awsAuthenticationConfiguration.getAwsRegion()).thenReturn(Region.US_EAST_1);
final String stsRoleArn = "arn:aws:iam::123456789012:role/my-role";
when(awsAuthenticationConfiguration.getAwsStsRoleArn()).thenReturn(stsRoleArn);
when(awsAuthenticationConfiguration.getAwsStsHeaderOverrides()).thenReturn(Collections.emptyMap());
when(openSearchSourceConfiguration.getAwsAuthenticationOptions()).thenReturn(awsAuthenticationConfiguration);

final ArgumentCaptor<AwsCredentialsOptions> awsCredentialsOptionsArgumentCaptor = ArgumentCaptor.forClass(AwsCredentialsOptions.class);
final AwsCredentialsProvider awsCredentialsProvider = mock(AwsCredentialsProvider.class);
when(awsCredentialsSupplier.getProvider(awsCredentialsOptionsArgumentCaptor.capture())).thenReturn(awsCredentialsProvider);

final ElasticsearchClient elasticsearchClient = createObjectUnderTest().provideElasticSearchClient(openSearchSourceConfiguration);
assertThat(elasticsearchClient, notNullValue());

final AwsCredentialsOptions awsCredentialsOptions = awsCredentialsOptionsArgumentCaptor.getValue();
assertThat(awsCredentialsOptions, notNullValue());
assertThat(awsCredentialsOptions.getRegion(), equalTo(Region.US_EAST_1));
assertThat(awsCredentialsOptions.getStsHeaderOverrides(), equalTo(Collections.emptyMap()));
assertThat(awsCredentialsOptions.getStsRoleArn(), equalTo(stsRoleArn));
}

@Test
void provideOpenSearchClient_with_aws_auth() {
when(connectionConfiguration.getCertPath()).thenReturn(null);
Expand Down
Loading

0 comments on commit 40980c1

Please sign in to comment.