Skip to content

Commit

Permalink
Kafka Source - Cleanup and Enhancements for MSK (#3029)
Browse files Browse the repository at this point in the history
* Kafka Source - Cleanup and Enhancements for MSK

Signed-off-by: Krishna Kondaka <[email protected]>

* Addressed review comments

Signed-off-by: Krishna Kondaka <[email protected]>

* Addressed review comments

Signed-off-by: Krishna Kondaka <[email protected]>

* addressed review comments

Signed-off-by: Krishna Kondaka <[email protected]>

* Fixed checkstyle error

Signed-off-by: Krishna Kondaka <[email protected]>

---------

Signed-off-by: Krishna Kondaka <[email protected]>
Co-authored-by: Krishna Kondaka <[email protected]>
  • Loading branch information
kkondaka and Krishna Kondaka committed Jul 18, 2023
1 parent 1edf97c commit 351845b
Show file tree
Hide file tree
Showing 7 changed files with 135 additions and 38 deletions.
3 changes: 3 additions & 0 deletions data-prepper-plugins/kafka-plugins/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ dependencies {
implementation 'io.confluent:kafka-schema-registry-client:7.3.3'
implementation 'io.confluent:kafka-schema-registry:7.3.3:tests'
implementation 'software.amazon.msk:aws-msk-iam-auth:1.1.6'
implementation 'software.amazon.awssdk:sts:2.20.103'
implementation 'software.amazon.awssdk:auth:2.20.103'
implementation 'software.amazon.awssdk:kafka:2.20.103'
testImplementation 'org.mockito:mockito-inline:4.1.0'
testImplementation 'org.yaml:snakeyaml:2.0'
testImplementation testLibs.spring.test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ public static class AwsMskConfig {
private String arn;

@JsonProperty("broker_connection_type")
private MskBrokerConnectionType brokerConnectionType;
private MskBrokerConnectionType brokerConnectionType = MskBrokerConnectionType.SINGLE_VPC;

public String getArn() {
return arn;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import jakarta.validation.constraints.Size;

import java.util.List;
import java.util.Objects;
import java.time.Duration;

/**
Expand All @@ -22,8 +23,6 @@ public class KafkaSourceConfig {
public static final Duration DEFAULT_ACKNOWLEDGEMENTS_TIMEOUT = Duration.ofSeconds(30);

@JsonProperty("bootstrap_servers")
@NotNull
@Size(min = 1, message = "Bootstrap servers can't be empty")
private List<String> bootStrapServers;

@JsonProperty("topics")
Expand Down Expand Up @@ -68,8 +67,11 @@ public void setTopics(List<TopicConfig> topics) {
this.topics = topics;
}

public List<String> getBootStrapServers() {
return bootStrapServers;
public String getBootStrapServers() {
if (Objects.nonNull(bootStrapServers)) {
return String.join(",", bootStrapServers);
}
return null;
}

public void setBootStrapServers(List<String> bootStrapServers) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,35 +40,44 @@ public class TopicConfig {
private String name;

@JsonProperty("group_id")
@Valid
@Size(min = 1, max = 255, message = "size of group id should be between 1 and 255")
private String groupId;

@JsonProperty("workers")
@Valid
@Size(min = 1, max = 200, message = "Number of worker threads should lies between 1 and 200")
private Integer workers = NUM_OF_WORKERS;

@JsonProperty("max_retry_attempts")
@Valid
@Size(min = 1, max = Integer.MAX_VALUE, message = " Max retry attempts should lies between 1 and Integer.MAX_VALUE")
private Integer maxRetryAttempts = MAX_RETRY_ATTEMPT;

@JsonProperty("max_retry_delay")
@Valid
@Size(min = 1)
private Duration maxRetryDelay = MAX_RETRY_DELAY;

@JsonProperty("auto_commit")
private Boolean autoCommit = false;

@JsonProperty("auto_commit_interval")
@Valid
@Size(min = 1)
private Duration autoCommitInterval = AUTOCOMMIT_INTERVAL;

@JsonProperty("session_timeout")
@Valid
@Size(min = 1)
private Duration sessionTimeOut = SESSION_TIMEOUT;

@JsonProperty("auto_offset_reset")
private String autoOffsetReset = AUTO_OFFSET_RESET;

@JsonProperty("group_name")
@Valid
@Size(min = 1, max = 255, message = "size of group name should be between 1 and 255")
private String groupName;

@JsonProperty("thread_waiting_time")
Expand All @@ -78,19 +87,23 @@ public class TopicConfig {
private Duration maxRecordFetchTime = MAX_RECORD_FETCH_TIME;

@JsonProperty("buffer_default_timeout")
@Valid
@Size(min = 1)
private Duration bufferDefaultTimeout = BUFFER_DEFAULT_TIMEOUT;

@JsonProperty("fetch_max_bytes")
@Valid
@Size(min = 1, max = 52428800)
private Long fetchMaxBytes = FETCH_MAX_BYTES;

@JsonProperty("fetch_max_wait")
@Valid
@Size(min = 1)
private Long fetchMaxWait = FETCH_MAX_WAIT;

@JsonProperty("fetch_min_bytes")
@Size(min = 1)
@Valid
private Long fetchMinBytes = FETCH_MIN_BYTES;

@JsonProperty("retry_backoff")
Expand All @@ -103,6 +116,7 @@ public class TopicConfig {
private Integer consumerMaxPollRecords = CONSUMER_MAX_POLL_RECORDS;

@JsonProperty("heart_beat_interval")
@Valid
@Size(min = 1)
private Duration heartBeatInterval= HEART_BEAT_INTERVAL_DURATION;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,19 @@
import org.opensearch.dataprepper.plugins.kafka.consumer.KafkaSourceCustomConsumer;
import org.opensearch.dataprepper.plugins.kafka.util.KafkaSourceJsonDeserializer;
import org.opensearch.dataprepper.plugins.kafka.util.MessageFormat;

import software.amazon.awssdk.services.kafka.KafkaClient;
import software.amazon.awssdk.services.kafka.model.GetBootstrapBrokersRequest;
import software.amazon.awssdk.services.kafka.model.GetBootstrapBrokersResponse;
import software.amazon.awssdk.services.kafka.model.InternalServerErrorException;
import software.amazon.awssdk.services.kafka.model.ConflictException;
import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider;
import software.amazon.awssdk.services.sts.auth.StsAssumeRoleCredentialsProvider;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.services.sts.model.AssumeRoleRequest;
import software.amazon.awssdk.services.sts.StsClient;
import software.amazon.awssdk.regions.Region;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -44,8 +57,10 @@
import java.net.URISyntaxException;
import java.net.URL;
import java.util.Comparator;
import java.util.Objects;
import java.util.List;
import java.util.Properties;
import java.util.UUID;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
Expand All @@ -60,11 +75,12 @@
@SuppressWarnings("deprecation")
@DataPrepperPlugin(name = "kafka", pluginType = Source.class, pluginConfigurationType = KafkaSourceConfig.class)
public class KafkaSource implements Source<Record<Event>> {
private static final String KAFKA_WORKER_THREAD_PROCESSING_ERRORS = "kafkaWorkerThreadProcessingErrors";
private static final int MAX_KAFKA_CLIENT_RETRIES = 10;
private static final Logger LOG = LoggerFactory.getLogger(KafkaSource.class);
private final KafkaSourceConfig sourceConfig;
private AtomicBoolean shutdownInProgress;
private ExecutorService executorService;
private static final String KAFKA_WORKER_THREAD_PROCESSING_ERRORS = "kafkaWorkerThreadProcessingErrors";
private final Counter kafkaWorkerThreadProcessingErrors;
private final PluginMetrics pluginMetrics;
private KafkaSourceCustomConsumer consumer;
Expand Down Expand Up @@ -154,13 +170,99 @@ private long calculateLongestThreadWaitingTime() {
orElse(1L);
}

private Properties getConsumerProperties(TopicConfig topicConfig) {
/**
 * Resolves the Kafka bootstrap-broker string for an Amazon MSK cluster by calling the
 * MSK {@code GetBootstrapBrokers} API using IAM credentials.
 *
 * @param awsIamAuthConfig IAM auth mode; when {@code ROLE}, credentials are obtained by
 *        assuming {@code awsConfig.getStsRoleArn()} via STS, otherwise this method throws
 * @param awsConfig AWS settings (region, STS role ARN, nested MSK cluster config)
 * @return the broker string matching the configured {@code MskBrokerConnectionType}; falls
 *         back to the user-configured bootstrap servers (which may be {@code null}) when
 *         the MSK API call does not succeed
 * @throws RuntimeException when {@code awsIamAuthConfig} is not {@code ROLE}
 */
public String getBootStrapServersForMsk(final AwsIamAuthConfig awsIamAuthConfig, final AwsConfig awsConfig) {
// Start from the default AWS credential chain; replaced below when role assumption is configured.
AwsCredentialsProvider credentialsProvider = DefaultCredentialsProvider.create();
if (awsIamAuthConfig == AwsIamAuthConfig.ROLE) {
// Randomized session name so concurrent pipelines assuming the same role are distinguishable.
// NOTE(review): no separator between the prefix and the UUID — confirm this is intended.
String sessionName = "data-prepper-kafka-session"+UUID.randomUUID();
StsClient stsClient = StsClient.builder()
.region(Region.of(awsConfig.getRegion()))
.credentialsProvider(credentialsProvider)
.build();
// Auto-refreshing provider that re-assumes the role before the session credentials expire.
credentialsProvider = StsAssumeRoleCredentialsProvider
.builder()
.stsClient(stsClient)
.refreshRequest(
AssumeRoleRequest
.builder()
.roleArn(awsConfig.getStsRoleArn())
.roleSessionName(sessionName)
.build()
).build();
} else {
// NOTE(review): every mode other than ROLE lands here (including any DEFAULT-style mode the
// enum may define) — confirm callers can never reach this method with another valid mode.
throw new RuntimeException("Unknown AWS IAM auth mode");
}
final AwsConfig.AwsMskConfig awsMskConfig = awsConfig.getAwsMskConfig();
KafkaClient kafkaClient = KafkaClient.builder()
.credentialsProvider(credentialsProvider)
.region(Region.of(awsConfig.getRegion()))
.build();
final GetBootstrapBrokersRequest request =
GetBootstrapBrokersRequest
.builder()
.clusterArn(awsMskConfig.getArn())
.build();

// Retry transient server-side failures up to MAX_KAFKA_CLIENT_RETRIES times.
// NOTE(review): retries are immediate (no backoff/sleep between attempts) — consider adding one.
int numRetries = 0;
boolean retryable;
GetBootstrapBrokersResponse result = null;
do {
retryable = false;
try {
result = kafkaClient.getBootstrapBrokers(request);
} catch (InternalServerErrorException | ConflictException e) {
// Treated as transient: loop again (bounded by MAX_KAFKA_CLIENT_RETRIES).
retryable = true;
} catch (Exception e) {
// Any other failure aborts the retry loop, leaving result == null so we fall back below.
// NOTE(review): the exception is swallowed without logging — consider LOG.warn with the cause.
break;
}
} while (retryable && numRetries++ < MAX_KAFKA_CLIENT_RETRIES);
if (Objects.isNull(result)) {
// MSK lookup failed (or retries exhausted): fall back to whatever the user configured.
LOG.info("Failed to get bootstrap server information from MSK, using user configured bootstrap servers");
return sourceConfig.getBootStrapServers();
}
// Pick the broker string variant for the configured connection type; SINGLE_VPC shares the
// default arm intentionally, so unrecognized types also get the single-VPC SASL/IAM string.
switch (awsMskConfig.getBrokerConnectionType()) {
case PUBLIC:
return result.bootstrapBrokerStringPublicSaslIam();
case MULTI_VPC:
return result.bootstrapBrokerStringVpcConnectivitySaslIam();
default:
case SINGLE_VPC:
return result.bootstrapBrokerStringSaslIam();
}
}

private Properties getConsumerProperties(final TopicConfig topicConfig) {
Properties properties = new Properties();
AwsIamAuthConfig awsIamAuthConfig = null;
AwsConfig awsConfig = sourceConfig.getAwsConfig();
if (sourceConfig.getAuthConfig() != null) {
AuthConfig.SaslAuthConfig saslAuthConfig = sourceConfig.getAuthConfig().getSaslAuthConfig();
if (saslAuthConfig != null) {
awsIamAuthConfig = saslAuthConfig.getAwsIamAuthConfig();
if (awsIamAuthConfig != null) {
if (encryptionType == EncryptionType.PLAINTEXT) {
throw new RuntimeException("Encryption Config must be SSL to use IAM authentication mechanism");
}
setAwsIamAuthProperties(properties, awsIamAuthConfig, awsConfig);
} else if (saslAuthConfig.getOAuthConfig() != null) {
} else if (saslAuthConfig.getPlainTextAuthConfig() != null) {
setPlainTextAuthProperties(properties);
} else {
throw new RuntimeException("No SASL auth config specified");
}
}
}
properties.put(ConsumerConfig.AUTO_COMMIT_INTERVAL_MS_CONFIG,
topicConfig.getAutoCommitInterval().toSecondsPart());
properties.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG,
topicConfig.getAutoOffsetReset());
properties.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, sourceConfig.getBootStrapServers());
String bootstrapServers = sourceConfig.getBootStrapServers();
if (Objects.nonNull(awsIamAuthConfig)) {
bootstrapServers = getBootStrapServersForMsk(awsIamAuthConfig, awsConfig);
}
if (Objects.isNull(bootstrapServers) || bootstrapServers.isEmpty()) {
throw new RuntimeException("Bootstrap servers are not specified");
}
properties.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, bootstrapServers);
properties.put(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG,
topicConfig.getAutoCommit());
properties.put(ConsumerConfig.MAX_POLL_RECORDS_CONFIG,
Expand All @@ -173,22 +275,6 @@ private Properties getConsumerProperties(TopicConfig topicConfig) {
schemaType = MessageFormat.PLAINTEXT.toString();
}
setPropertiesForSchemaType(properties, schemaType);
if (sourceConfig.getAuthConfig() != null) {
AuthConfig.SaslAuthConfig saslAuthConfig = sourceConfig.getAuthConfig().getSaslAuthConfig();
if (saslAuthConfig != null) {
if (saslAuthConfig.getPlainTextAuthConfig() != null) {
setPlainTextAuthProperties(properties);
} else if (saslAuthConfig.getAwsIamAuthConfig() != null) {
if (encryptionType == EncryptionType.PLAINTEXT) {
throw new RuntimeException("Encryption Config must be SSL to use IAM authentication mechanism");
}
setAwsIamAuthProperties(properties, saslAuthConfig.getAwsIamAuthConfig(), sourceConfig.getAwsConfig());
} else if (saslAuthConfig.getOAuthConfig() != null) {
} else {
throw new RuntimeException("No SASL auth config specified");
}
}
}
LOG.info("Starting consumer with the properties : {}", properties);
return properties;
}
Expand Down Expand Up @@ -229,7 +315,7 @@ private String getSchemaRegistryUrl() {
return sourceConfig.getSchemaConfig().getRegistryURL();
}

private void setAwsIamAuthProperties(Properties properties, AwsIamAuthConfig awsIamAuthConfig, AwsConfig awsConfig) {
private void setAwsIamAuthProperties(Properties properties, final AwsIamAuthConfig awsIamAuthConfig, final AwsConfig awsConfig) {
if (awsConfig == null) {
throw new RuntimeException("AWS Config is not specified");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,12 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Map;
import java.util.stream.Collectors;
import java.time.Duration;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.hasItem;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.notNullValue;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.Mockito.mock;
import static org.opensearch.dataprepper.test.helper.ReflectivelySetField.setField;

Expand Down Expand Up @@ -61,13 +59,8 @@ void test_kafka_config_not_null() {
@Test
void test_bootStrapServers_not_null(){
assertThat(kafkaSourceConfig.getBootStrapServers(), notNullValue());
List<String> servers = kafkaSourceConfig.getBootStrapServers();
bootstrapServers = servers.stream().
flatMap(str -> Arrays.stream(str.split(","))).
map(String::trim).
collect(Collectors.toList());
assertThat(bootstrapServers.size(), equalTo(1));
assertThat(bootstrapServers, hasItem("127.0.0.1:9093"));
String bootstrapServers = kafkaSourceConfig.getBootStrapServers();
assertTrue(bootstrapServers.contains("127.0.0.1:9093"));
}

@Test
Expand All @@ -84,7 +77,7 @@ void test_setters() throws NoSuchFieldException, IllegalAccessException {
TopicConfig topicConfig = mock(TopicConfig.class);
kafkaSourceConfig.setTopics(Collections.singletonList(topicConfig));

assertEquals(Arrays.asList("127.0.0.1:9092"), kafkaSourceConfig.getBootStrapServers());
assertEquals("127.0.0.1:9092", kafkaSourceConfig.getBootStrapServers());
assertEquals(Collections.singletonList(topicConfig), kafkaSourceConfig.getTopics());
setField(KafkaSourceConfig.class, kafkaSourceConfig, "acknowledgementsEnabled", true);
Duration testTimeout = Duration.ofSeconds(10);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
import static org.mockito.Mockito.when;

import java.util.Arrays;
import java.util.List;
import java.time.Duration;

@ExtendWith(MockitoExtension.class)
Expand Down Expand Up @@ -90,7 +89,7 @@ void setUp() throws Exception {
when(topic2.getAutoCommit()).thenReturn(false);
when(topic1.getThreadWaitingTime()).thenReturn(Duration.ofSeconds(10));
when(topic2.getThreadWaitingTime()).thenReturn(Duration.ofSeconds(10));
when(sourceConfig.getBootStrapServers()).thenReturn(List.of("http://localhost:1234"));
when(sourceConfig.getBootStrapServers()).thenReturn("http://localhost:1234");
when(sourceConfig.getTopics()).thenReturn(Arrays.asList(topic1, topic2));
}

Expand Down

0 comments on commit 351845b

Please sign in to comment.