Skip to content

Commit

Permalink
Add s3 sink client options (#4959)
Browse files Browse the repository at this point in the history
* Add s3 sink client options

Signed-off-by: Hai Yan <[email protected]>

* Add upper bounds for new options

Signed-off-by: Hai Yan <[email protected]>

---------

Signed-off-by: Hai Yan <[email protected]>
  • Loading branch information
oeyh committed Sep 20, 2024
1 parent a3bd538 commit aaef847
Show file tree
Hide file tree
Showing 6 changed files with 140 additions and 3 deletions.
1 change: 1 addition & 0 deletions data-prepper-plugins/s3-sink/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ dependencies {
implementation 'joda-time:joda-time:2.12.7'
implementation 'org.hibernate.validator:hibernate-validator:8.0.1.Final'
implementation 'com.fasterxml.jackson.dataformat:jackson-dataformat-csv'
implementation 'software.amazon.awssdk:netty-nio-client'
implementation 'software.amazon.awssdk:s3'
implementation 'software.amazon.awssdk:sts'
implementation 'software.amazon.awssdk:securitylake:2.26.18'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,14 @@
import org.opensearch.dataprepper.aws.api.AwsCredentialsOptions;
import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier;
import org.opensearch.dataprepper.plugins.sink.s3.configuration.AwsAuthenticationOptions;
import org.opensearch.dataprepper.plugins.sink.s3.configuration.ClientOptions;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration;
import software.amazon.awssdk.core.retry.RetryPolicy;
import software.amazon.awssdk.http.async.SdkAsyncHttpClient;
import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient;
import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.services.s3.S3AsyncClientBuilder;
import software.amazon.awssdk.services.s3.S3Client;

public final class ClientFactory {
Expand All @@ -31,10 +35,21 @@ static S3AsyncClient createS3AsyncClient(final S3SinkConfig s3SinkConfig, final
final AwsCredentialsOptions awsCredentialsOptions = convertToCredentialsOptions(s3SinkConfig.getAwsAuthenticationOptions());
final AwsCredentialsProvider awsCredentialsProvider = awsCredentialsSupplier.getProvider(awsCredentialsOptions);

return S3AsyncClient.builder()
S3AsyncClientBuilder s3AsyncClientBuilder = S3AsyncClient.builder()
.region(s3SinkConfig.getAwsAuthenticationOptions().getAwsRegion())
.credentialsProvider(awsCredentialsProvider)
.overrideConfiguration(createOverrideConfiguration(s3SinkConfig)).build();
.overrideConfiguration(createOverrideConfiguration(s3SinkConfig));

if (s3SinkConfig.getClientOptions() != null) {
final ClientOptions clientOptions = s3SinkConfig.getClientOptions();
SdkAsyncHttpClient httpClient = NettyNioAsyncHttpClient.builder()
.connectionAcquisitionTimeout(clientOptions.getAcquireTimeout())
.maxConcurrency(clientOptions.getMaxConnections())
.build();
s3AsyncClientBuilder.httpClient(httpClient);
}

return s3AsyncClientBuilder.build();
}

private static ClientOverrideConfiguration createOverrideConfiguration(final S3SinkConfig s3SinkConfig) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.opensearch.dataprepper.plugins.sink.s3.compression.CompressionOption;
import org.opensearch.dataprepper.plugins.sink.s3.configuration.AggregateThresholdOptions;
import org.opensearch.dataprepper.plugins.sink.s3.configuration.AwsAuthenticationOptions;
import org.opensearch.dataprepper.plugins.sink.s3.configuration.ClientOptions;
import org.opensearch.dataprepper.plugins.sink.s3.configuration.ObjectKeyOptions;
import org.opensearch.dataprepper.plugins.sink.s3.configuration.ThresholdOptions;

Expand Down Expand Up @@ -95,6 +96,9 @@ private boolean isValidBucketConfig() {
@AwsAccountId
private String defaultBucketOwner;

/**
 * Optional HTTP client settings bound from the {@code client} key.
 * No default is assigned here, so the value is {@code null} unless it is populated.
 */
@JsonProperty("client")
private ClientOptions clientOptions;

/**
* Aws Authentication configuration Options.
* @return aws authentication options.
Expand Down Expand Up @@ -195,4 +199,8 @@ public Map<String, String> getBucketOwners() {
/**
 * Default bucket owner configuration option.
 * @return the default bucket owner held by this configuration.
 */
public String getDefaultBucketOwner() {
    return this.defaultBucketOwner;
}

/**
 * Client configuration options.
 * @return the client options, or {@code null} when none have been set.
 */
public ClientOptions getClientOptions() {
    return this.clientOptions;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.dataprepper.plugins.sink.s3.configuration;

import com.fasterxml.jackson.annotation.JsonProperty;
import jakarta.validation.constraints.Max;
import jakarta.validation.constraints.Min;
import org.hibernate.validator.constraints.time.DurationMax;
import org.hibernate.validator.constraints.time.DurationMin;

import java.time.Duration;

/**
 * HTTP client tuning options for the S3 sink.
 *
 * <p>Every field is initialized to a default, so a value only changes when the
 * corresponding key is supplied.
 */
public class ClientOptions {
    // Defaults applied through the field initializers below.
    private static final int DEFAULT_MAX_CONNECTIONS = 50;
    private static final Duration DEFAULT_ACQUIRE_TIMEOUT = Duration.ofSeconds(10);

    /** Bound from {@code max_connections}; constrained to 1..5000, default 50. */
    @JsonProperty("max_connections")
    @Min(1)
    @Max(5000)
    private int maxConnections = DEFAULT_MAX_CONNECTIONS;

    /** Bound from {@code acquire_timeout}; constrained to between 1 and 3600 seconds, default 10 seconds. */
    @JsonProperty("acquire_timeout")
    @DurationMin(seconds = 1)
    @DurationMax(seconds = 3600)
    private Duration acquireTimeout = DEFAULT_ACQUIRE_TIMEOUT;

    /**
     * @return the value of the {@code max_connections} option.
     */
    public int getMaxConnections() {
        return maxConnections;
    }

    /**
     * @return the value of the {@code acquire_timeout} option.
     */
    public Duration getAcquireTimeout() {
        return acquireTimeout;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,26 @@
import org.opensearch.dataprepper.aws.api.AwsCredentialsOptions;
import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier;
import org.opensearch.dataprepper.plugins.sink.s3.configuration.AwsAuthenticationOptions;
import org.opensearch.dataprepper.plugins.sink.s3.configuration.ClientOptions;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration;
import software.amazon.awssdk.http.async.SdkAsyncHttpClient;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient;
import software.amazon.awssdk.services.s3.S3AsyncClientBuilder;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.S3ClientBuilder;

import java.time.Duration;
import java.util.Map;
import java.util.UUID;

import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.notNullValue;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.mockStatic;
import static org.mockito.Mockito.verify;
Expand All @@ -44,14 +51,16 @@ class ClientFactoryTest {

@Mock
private AwsAuthenticationOptions awsAuthenticationOptions;
@Mock
private ClientOptions clientOptions;

@BeforeEach
void setUp() {
when(s3SinkConfig.getAwsAuthenticationOptions()).thenReturn(awsAuthenticationOptions);
}

@Test
void createS3Client_with_real_S3Client() {
void createS3AsyncClient_with_real_S3AsyncClient() {
when(awsAuthenticationOptions.getAwsRegion()).thenReturn(Region.US_EAST_1);
final S3Client s3Client = ClientFactory.createS3Client(s3SinkConfig, awsCredentialsSupplier);

Expand Down Expand Up @@ -99,4 +108,66 @@ void createS3Client_provides_correct_inputs(final String regionString) {
assertThat(actualCredentialsOptions.getStsExternalId(), equalTo(externalId));
assertThat(actualCredentialsOptions.getStsHeaderOverrides(), equalTo(stsHeaderOverrides));
}

/**
 * Verifies that, when client options are present on the sink configuration,
 * {@code ClientFactory.createS3AsyncClient} builds a Netty HTTP client from those
 * options, attaches it to the S3 async client builder, and returns the client
 * produced by that builder.
 */
@Test
void createS3AsyncClient_with_client_options_returns_expected_client() {
    final Region region = Region.of("us-east-1");
    final String stsRoleArn = UUID.randomUUID().toString();
    final String externalId = UUID.randomUUID().toString();
    final Map<String, String> stsHeaderOverrides = Map.of(UUID.randomUUID().toString(), UUID.randomUUID().toString());
    when(awsAuthenticationOptions.getAwsRegion()).thenReturn(region);
    when(awsAuthenticationOptions.getAwsStsRoleArn()).thenReturn(stsRoleArn);
    when(awsAuthenticationOptions.getAwsStsExternalId()).thenReturn(externalId);
    when(awsAuthenticationOptions.getAwsStsHeaderOverrides()).thenReturn(stsHeaderOverrides);

    final AwsCredentialsProvider expectedCredentialsProvider = mock(AwsCredentialsProvider.class);
    when(awsCredentialsSupplier.getProvider(any())).thenReturn(expectedCredentialsProvider);

    // The builder is stubbed to hand back a known client so the factory's return value can be checked.
    final S3AsyncClient expectedS3AsyncClient = mock(S3AsyncClient.class);
    final S3AsyncClientBuilder s3AsyncClientBuilder = mock(S3AsyncClientBuilder.class);
    when(s3AsyncClientBuilder.region(region)).thenReturn(s3AsyncClientBuilder);
    when(s3AsyncClientBuilder.credentialsProvider(any())).thenReturn(s3AsyncClientBuilder);
    when(s3AsyncClientBuilder.overrideConfiguration(any(ClientOverrideConfiguration.class))).thenReturn(s3AsyncClientBuilder);
    when(s3AsyncClientBuilder.build()).thenReturn(expectedS3AsyncClient);

    when(s3SinkConfig.getClientOptions()).thenReturn(clientOptions);

    final int maxConnections = 100;
    final Duration acquireTimeout = Duration.ofSeconds(30);
    when(clientOptions.getMaxConnections()).thenReturn(maxConnections);
    when(clientOptions.getAcquireTimeout()).thenReturn(acquireTimeout);

    final NettyNioAsyncHttpClient.Builder httpClientBuilder = mock(NettyNioAsyncHttpClient.Builder.class);
    final SdkAsyncHttpClient httpClient = mock(SdkAsyncHttpClient.class);
    when(httpClientBuilder.connectionAcquisitionTimeout(any(Duration.class))).thenReturn(httpClientBuilder);
    when(httpClientBuilder.maxConcurrency(anyInt())).thenReturn(httpClientBuilder);
    when(httpClientBuilder.build()).thenReturn(httpClient);

    // Static factories are only intercepted inside this scope; the instance mocks stay verifiable afterwards.
    final S3AsyncClient actualS3AsyncClient;
    try (final MockedStatic<S3AsyncClient> s3AsyncClientMockedStatic = mockStatic(S3AsyncClient.class);
         final MockedStatic<NettyNioAsyncHttpClient> httpClientMockedStatic = mockStatic(NettyNioAsyncHttpClient.class)) {
        s3AsyncClientMockedStatic.when(S3AsyncClient::builder)
                .thenReturn(s3AsyncClientBuilder);
        httpClientMockedStatic.when(NettyNioAsyncHttpClient::builder)
                .thenReturn(httpClientBuilder);
        actualS3AsyncClient = ClientFactory.createS3AsyncClient(s3SinkConfig, awsCredentialsSupplier);
    }

    // The test name promises the expected client, so assert on what the factory returned.
    assertThat(actualS3AsyncClient, equalTo(expectedS3AsyncClient));

    final ArgumentCaptor<AwsCredentialsProvider> credentialsProviderArgumentCaptor = ArgumentCaptor.forClass(AwsCredentialsProvider.class);
    verify(s3AsyncClientBuilder).credentialsProvider(credentialsProviderArgumentCaptor.capture());

    final AwsCredentialsProvider actualCredentialsProvider = credentialsProviderArgumentCaptor.getValue();

    assertThat(actualCredentialsProvider, equalTo(expectedCredentialsProvider));

    final ArgumentCaptor<AwsCredentialsOptions> optionsArgumentCaptor = ArgumentCaptor.forClass(AwsCredentialsOptions.class);
    verify(awsCredentialsSupplier).getProvider(optionsArgumentCaptor.capture());

    final AwsCredentialsOptions actualCredentialsOptions = optionsArgumentCaptor.getValue();
    assertThat(actualCredentialsOptions.getRegion(), equalTo(region));
    assertThat(actualCredentialsOptions.getStsRoleArn(), equalTo(stsRoleArn));
    assertThat(actualCredentialsOptions.getStsExternalId(), equalTo(externalId));
    assertThat(actualCredentialsOptions.getStsHeaderOverrides(), equalTo(stsHeaderOverrides));

    // The configured option values must reach the HTTP client builder unchanged.
    verify(httpClientBuilder).connectionAcquisitionTimeout(acquireTimeout);
    verify(httpClientBuilder).maxConcurrency(maxConnections);
    verify(s3AsyncClientBuilder).httpClient(httpClient);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -76,4 +76,9 @@ void get_AWS_Auth_options_in_sinkconfig_exception() {
void get_json_codec_test() {
    // A freshly constructed config has no codec assigned.
    final S3SinkConfig defaultConfig = new S3SinkConfig();
    assertNull(defaultConfig.getCodec());
}

@Test
void get_client_option_test() {
    // A freshly constructed config has no client options assigned.
    final S3SinkConfig defaultConfig = new S3SinkConfig();
    assertNull(defaultConfig.getClientOptions());
}
}

0 comments on commit aaef847

Please sign in to comment.