Skip to content

Commit

Permalink
Update tests and changes per review comments
Browse files Browse the repository at this point in the history
Signed-off-by: Souvik Bose <[email protected]>
  • Loading branch information
sbose2k21 committed Aug 28, 2024
1 parent 93ababc commit 1ca0077
Show file tree
Hide file tree
Showing 9 changed files with 110 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@
import software.amazon.awssdk.services.kinesis.KinesisAsyncClient;
import software.amazon.kinesis.common.KinesisClientUtil;

public class ClientFactory {
public class KinesisClientFactory {
private final AwsCredentialsProvider awsCredentialsProvider;
private final AwsAuthenticationConfig awsAuthenticationConfig;

public ClientFactory(final AwsCredentialsSupplier awsCredentialsSupplier,
final AwsAuthenticationConfig awsAuthenticationConfig) {
public KinesisClientFactory(final AwsCredentialsSupplier awsCredentialsSupplier,
final AwsAuthenticationConfig awsAuthenticationConfig) {
awsCredentialsProvider = awsCredentialsSupplier.getProvider(AwsCredentialsOptions.builder()
.withRegion(awsAuthenticationConfig.getAwsRegion())
.withStsRoleArn(awsAuthenticationConfig.getAwsStsRoleArn())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ public class KinesisService {
private final ExecutorService executorService;

public KinesisService(final KinesisSourceConfig sourceConfig,
final ClientFactory clientFactory,
final KinesisClientFactory kinesisClientFactory,
final PluginMetrics pluginMetrics,
final PluginFactory pluginFactory,
final PipelineDescription pipelineDescription,
Expand All @@ -72,9 +72,9 @@ public KinesisService(final KinesisSourceConfig sourceConfig,
kinesisLeaseConfigSupplier.getKinesisExtensionLeaseConfig().get();
this.tableName = kinesisLeaseConfig.getLeaseCoordinationTable().getTableName();
this.kclMetricsNamespaceName = this.tableName;
this.dynamoDbClient = clientFactory.buildDynamoDBClient(kinesisLeaseConfig.getLeaseCoordinationTable().getAwsRegion());
this.kinesisClient = clientFactory.buildKinesisAsyncClient();
this.cloudWatchClient = clientFactory.buildCloudWatchAsyncClient(kinesisLeaseConfig.getLeaseCoordinationTable().getAwsRegion());
this.dynamoDbClient = kinesisClientFactory.buildDynamoDBClient(kinesisLeaseConfig.getLeaseCoordinationTable().getAwsRegion());
this.kinesisClient = kinesisClientFactory.buildKinesisAsyncClient();
this.cloudWatchClient = kinesisClientFactory.buildCloudWatchAsyncClient(kinesisLeaseConfig.getLeaseCoordinationTable().getAwsRegion());
this.pipelineName = pipelineDescription.getPipelineName();
this.applicationName = pipelineName;
this.executorService = Executors.newFixedThreadPool(1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ public KinesisSource(final KinesisSourceConfig kinesisSourceConfig,
final KinesisLeaseConfigSupplier kinesisLeaseConfigSupplier) {
this.kinesisSourceConfig = kinesisSourceConfig;
this.kinesisLeaseConfigSupplier = kinesisLeaseConfigSupplier;
ClientFactory clientFactory = new ClientFactory(awsCredentialsSupplier, kinesisSourceConfig.getAwsAuthenticationConfig());
this.kinesisService = new KinesisService(kinesisSourceConfig, clientFactory, pluginMetrics, pluginFactory,
KinesisClientFactory kinesisClientFactory = new KinesisClientFactory(awsCredentialsSupplier, kinesisSourceConfig.getAwsAuthenticationConfig());
this.kinesisService = new KinesisService(kinesisSourceConfig, kinesisClientFactory, pluginMetrics, pluginFactory,
pipelineDescription, acknowledgementSetManager, kinesisLeaseConfigSupplier);
}
@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import com.fasterxml.jackson.annotation.JsonProperty;
import jakarta.validation.constraints.Size;
import lombok.Getter;
import software.amazon.awssdk.arns.Arn;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider;
Expand All @@ -26,14 +27,17 @@ public class AwsAuthenticationConfig {
@Size(min = 1, message = "Region cannot be empty string")
private String awsRegion;

@Getter
@JsonProperty("sts_role_arn")
@Size(min = 20, max = 2048, message = "awsStsRoleArn length should be between 1 and 2048 characters")
private String awsStsRoleArn;

@Getter
@JsonProperty("sts_external_id")
@Size(min = 2, max = 1224, message = "awsStsExternalId length should be between 2 and 1224 characters")
private String awsStsExternalId;

@Getter
@JsonProperty("sts_header_overrides")
@Size(max = 5, message = "sts_header_overrides supports a maximum of 5 headers to override")
private Map<String, String> awsStsHeaderOverrides;
Expand All @@ -42,18 +46,6 @@ public Region getAwsRegion() {
return awsRegion != null ? Region.of(awsRegion) : null;
}

public String getAwsStsRoleArn() {
return awsStsRoleArn;
}

public String getAwsStsExternalId() {
return awsStsExternalId;
}

public Map<String, String> getAwsStsHeaderOverrides() {
return awsStsHeaderOverrides;
}

public AwsCredentialsProvider authenticateAwsConfiguration() {

final AwsCredentialsProvider awsCredentialsProvider;
Expand All @@ -67,7 +59,7 @@ public AwsCredentialsProvider authenticateAwsConfiguration() {
final StsClient stsClient = StsClient.builder().region(getAwsRegion()).build();

AssumeRoleRequest.Builder assumeRoleRequestBuilder = AssumeRoleRequest.builder()
.roleSessionName("GeoIP-Processor-" + UUID.randomUUID()).roleArn(awsStsRoleArn);
.roleSessionName("Kinesis-source-" + UUID.randomUUID()).roleArn(awsStsRoleArn);

if (awsStsHeaderOverrides != null && !awsStsHeaderOverrides.isEmpty()) {
assumeRoleRequestBuilder = assumeRoleRequestBuilder.overrideConfiguration(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.function.Consumer;

public class KinesisRecordProcessor implements ShardRecordProcessor {
Expand All @@ -52,7 +53,7 @@ public class KinesisRecordProcessor implements ShardRecordProcessor {
private final Counter recordProcessingErrors;
private final Counter checkpointFailures;
private static final Duration ACKNOWLEDGEMENT_SET_TIMEOUT = Duration.ofSeconds(20);
private static final String ACKNOWLEDGEMENT_SET_CALLBACK_METRIC_NAME = "acknowledgementSetCallbackCounter";
public static final String ACKNOWLEDGEMENT_SET_CALLBACK_METRIC_NAME = "acknowledgementSetCallbackCounter";
public static final String KINESIS_RECORD_PROCESSING_ERRORS = "recordProcessingErrors";
public static final String KINESIS_CHECKPOINT_FAILURES = "checkpointFailures";
public static final String KINESIS_STREAM_TAG_KEY = "stream";
Expand All @@ -71,7 +72,7 @@ public KinesisRecordProcessor(Buffer<Record<Event>> buffer,
final PluginSetting codecPluginSettings = new PluginSetting(codecConfiguration.getPluginName(), codecConfiguration.getPluginSettings());
this.codec = pluginFactory.loadPlugin(InputCodec.class, codecPluginSettings);
this.acknowledgementSetManager = acknowledgementSetManager;
this.acknowledgementSetCallbackCounter = pluginMetrics.counter(ACKNOWLEDGEMENT_SET_CALLBACK_METRIC_NAME);
this.acknowledgementSetCallbackCounter = pluginMetrics.counterWithTags(ACKNOWLEDGEMENT_SET_CALLBACK_METRIC_NAME, KINESIS_STREAM_TAG_KEY, streamIdentifier.streamName());
this.recordProcessingErrors = pluginMetrics.counterWithTags(KINESIS_RECORD_PROCESSING_ERRORS, KINESIS_STREAM_TAG_KEY, streamIdentifier.streamName());
this.checkpointFailures = pluginMetrics.counterWithTags(KINESIS_CHECKPOINT_FAILURES, KINESIS_STREAM_TAG_KEY, streamIdentifier.streamName());
this.checkpointIntervalMilliSeconds = kinesisStreamConfig.getCheckPointIntervalInMilliseconds();
Expand Down Expand Up @@ -108,29 +109,21 @@ public void processRecords(ProcessRecordsInput processRecordsInput) {
List<Record<Event>> records = new ArrayList<>();

try {
AcknowledgementSet acknowledgementSet;
Optional<AcknowledgementSet> acknowledgementSetOpt = Optional.empty();
boolean acknowledgementsEnabled = kinesisSourceConfig.isAcknowledgments();
if (acknowledgementsEnabled) {
acknowledgementSet = createAcknowledgmentSet(processRecordsInput);
} else {
acknowledgementSet = null;
acknowledgementSetOpt = Optional.of(createAcknowledgmentSet(processRecordsInput));
}

for (KinesisClientRecord record : processRecordsInput.records()) {
processRecord(record, records::add);
}

if (acknowledgementSet != null) {
records.forEach(record -> {
acknowledgementSet.add(record.getData());
});
}
acknowledgementSetOpt.ifPresent(acknowledgementSet -> records.forEach(record -> acknowledgementSet.add(record.getData())));

buffer.writeAll(records, bufferTimeoutMillis);

if (acknowledgementSet != null) {
acknowledgementSet.complete();
}
acknowledgementSetOpt.ifPresent(AcknowledgementSet::complete);

// Checkpoint for shard
if (kinesisStreamConfig.isEnableCheckPoint() && System.currentTimeMillis() - lastCheckpointTimeInMillis > checkpointIntervalMilliSeconds) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import org.opensearch.dataprepper.pipeline.parser.ByteCountDeserializer;
import org.opensearch.dataprepper.pipeline.parser.DataPrepperDurationDeserializer;
import com.fasterxml.jackson.dataformat.yaml.YAMLFactory;
import software.amazon.awssdk.regions.Region;

import java.io.File;
import java.io.IOException;
Expand Down Expand Up @@ -48,7 +49,7 @@ void testConfigWithTestExtension() throws IOException {
assertNotNull(kinesisLeaseConfig.getLeaseCoordinationTable());
assertEquals(kinesisLeaseConfig.getLeaseCoordinationTable().getTableName(), "kinesis-pipeline-kcl");
assertEquals(kinesisLeaseConfig.getLeaseCoordinationTable().getRegion(), "us-east-1");

assertEquals(kinesisLeaseConfig.getLeaseCoordinationTable().getAwsRegion(), Region.US_EAST_1);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.mockito.Mockito.mock;

public class ClientFactoryTest {
public class KinesisClientFactoryTest {
private Region region = Region.US_EAST_1;
private String roleArn;
private Map<String, String> stsHeader;
Expand All @@ -33,7 +33,7 @@ void testCreateClient() throws NoSuchFieldException, IllegalAccessException {
ReflectivelySetField.setField(AwsAuthenticationConfig.class, awsAuthenticationOptionsConfig, "awsRegion", "us-east-1");
ReflectivelySetField.setField(AwsAuthenticationConfig.class, awsAuthenticationOptionsConfig, "awsStsRoleArn", roleArn);

ClientFactory clientFactory = new ClientFactory(awsCredentialsSupplier, awsAuthenticationOptionsConfig);
KinesisClientFactory clientFactory = new KinesisClientFactory(awsCredentialsSupplier, awsAuthenticationOptionsConfig);

final DynamoDbAsyncClient dynamoDbAsyncClient = clientFactory.buildDynamoDBClient(Region.US_EAST_1);
assertNotNull(dynamoDbAsyncClient);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ public class KinesisServiceTest {
private PipelineDescription pipelineDescription;

@Mock
private ClientFactory clientFactory;
private KinesisClientFactory kinesisClientFactory;

@Mock
private KinesisAsyncClient kinesisClient;
Expand Down Expand Up @@ -120,7 +120,7 @@ void setup() {
kinesisClient = mock(KinesisAsyncClient.class);
dynamoDbClient = mock(DynamoDbAsyncClient.class);
cloudWatchClient = mock(CloudWatchAsyncClient.class);
clientFactory = mock(ClientFactory.class);
kinesisClientFactory = mock(KinesisClientFactory.class);
scheduler = mock(Scheduler.class);
pipelineDescription = mock(PipelineDescription.class);
buffer = mock(Buffer.class);
Expand Down Expand Up @@ -168,16 +168,16 @@ void setup() {
when(kinesisSourceConfig.getStreams()).thenReturn(streamConfigs);
when(kinesisSourceConfig.getNumberOfRecordsToAccumulate()).thenReturn(NUMBER_OF_RECORDS_TO_ACCUMULATE);

when(clientFactory.buildDynamoDBClient(kinesisLeaseCoordinationTableConfig.getAwsRegion())).thenReturn(dynamoDbClient);
when(clientFactory.buildKinesisAsyncClient()).thenReturn(kinesisClient);
when(clientFactory.buildCloudWatchAsyncClient(kinesisLeaseCoordinationTableConfig.getAwsRegion())).thenReturn(cloudWatchClient);
when(kinesisClientFactory.buildDynamoDBClient(kinesisLeaseCoordinationTableConfig.getAwsRegion())).thenReturn(dynamoDbClient);
when(kinesisClientFactory.buildKinesisAsyncClient()).thenReturn(kinesisClient);
when(kinesisClientFactory.buildCloudWatchAsyncClient(kinesisLeaseCoordinationTableConfig.getAwsRegion())).thenReturn(cloudWatchClient);
when(kinesisClient.serviceClientConfiguration()).thenReturn(KinesisServiceClientConfiguration.builder().region(Region.US_EAST_1).build());
when(scheduler.startGracefulShutdown()).thenReturn(CompletableFuture.completedFuture(true));
when(pipelineDescription.getPipelineName()).thenReturn(PIPELINE_NAME);
}

public KinesisService createObjectUnderTest() {
return new KinesisService(kinesisSourceConfig, clientFactory, pluginMetrics, pluginFactory,
return new KinesisService(kinesisSourceConfig, kinesisClientFactory, pluginMetrics, pluginFactory,
pipelineDescription, acknowledgementSetManager, kinesisLeaseConfigSupplier);
}

Expand All @@ -188,9 +188,16 @@ void testServiceStart() {
assertNotNull(kinesisService.getScheduler(buffer));
}

@Test
void testServiceThrowsWhenLeaseConfigIsInvalid() {
when(kinesisLeaseConfigSupplier.getKinesisExtensionLeaseConfig()).thenReturn(Optional.empty());
assertThrows(IllegalStateException.class, () -> new KinesisService(kinesisSourceConfig, kinesisClientFactory, pluginMetrics, pluginFactory,
pipelineDescription, acknowledgementSetManager, kinesisLeaseConfigSupplier));
}

@Test
void testCreateScheduler() {
KinesisService kinesisService = new KinesisService(kinesisSourceConfig, clientFactory, pluginMetrics, pluginFactory,
KinesisService kinesisService = new KinesisService(kinesisSourceConfig, kinesisClientFactory, pluginMetrics, pluginFactory,
pipelineDescription, acknowledgementSetManager, kinesisLeaseConfigSupplier);
Scheduler schedulerObjectUnderTest = kinesisService.createScheduler(buffer);

Expand All @@ -208,7 +215,7 @@ void testCreateScheduler() {
@Test
void testCreateSchedulerWithPollingStrategy() {
when(kinesisSourceConfig.getConsumerStrategy()).thenReturn(ConsumerStrategy.POLLING);
KinesisService kinesisService = new KinesisService(kinesisSourceConfig, clientFactory, pluginMetrics, pluginFactory,
KinesisService kinesisService = new KinesisService(kinesisSourceConfig, kinesisClientFactory, pluginMetrics, pluginFactory,
pipelineDescription, acknowledgementSetManager, kinesisLeaseConfigSupplier);
Scheduler schedulerObjectUnderTest = kinesisService.createScheduler(buffer);

Expand Down
Loading

0 comments on commit 1ca0077

Please sign in to comment.