Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Kafka Source - Cleanup and Enhancements for MSK #3029

Merged
merged 5 commits into from
Jul 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions data-prepper-plugins/kafka-plugins/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ dependencies {
implementation 'io.confluent:kafka-schema-registry-client:7.3.3'
implementation 'io.confluent:kafka-schema-registry:7.3.3:tests'
implementation 'software.amazon.msk:aws-msk-iam-auth:1.1.6'
implementation 'software.amazon.awssdk:sts:2.20.103'
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor: it's better not to pin the SDK version here, but you can change this later.

implementation 'software.amazon.awssdk:auth:2.20.103'
implementation 'software.amazon.awssdk:kafka:2.20.103'
testImplementation 'org.mockito:mockito-inline:4.1.0'
testImplementation 'org.yaml:snakeyaml:2.0'
testImplementation testLibs.spring.test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ public static class AwsMskConfig {
private String arn;

@JsonProperty("broker_connection_type")
private MskBrokerConnectionType brokerConnectionType;
private MskBrokerConnectionType brokerConnectionType = MskBrokerConnectionType.SINGLE_VPC;

public String getArn() {
return arn;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import jakarta.validation.constraints.Size;

import java.util.List;
import java.util.Objects;
import java.time.Duration;

/**
Expand All @@ -22,8 +23,6 @@ public class KafkaSourceConfig {
public static final Duration DEFAULT_ACKNOWLEDGEMENTS_TIMEOUT = Duration.ofSeconds(30);

@JsonProperty("bootstrap_servers")
@NotNull
@Size(min = 1, message = "Bootstrap servers can't be empty")
private List<String> bootStrapServers;

@JsonProperty("topics")
Expand Down Expand Up @@ -68,8 +67,11 @@ public void setTopics(List<TopicConfig> topics) {
this.topics = topics;
}

public List<String> getBootStrapServers() {
return bootStrapServers;
public String getBootStrapServers() {
if (Objects.nonNull(bootStrapServers)) {
return String.join(",", bootStrapServers);
}
return null;
}

public void setBootStrapServers(List<String> bootStrapServers) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,35 +40,44 @@ public class TopicConfig {
private String name;

@JsonProperty("group_id")
@Valid
@Size(min = 1, max = 255, message = "size of group id should be between 1 and 255")
private String groupId;

@JsonProperty("workers")
@Valid
@Size(min = 1, max = 200, message = "Number of worker threads should lies between 1 and 200")
private Integer workers = NUM_OF_WORKERS;

@JsonProperty("max_retry_attempts")
@Valid
@Size(min = 1, max = Integer.MAX_VALUE, message = " Max retry attempts should lies between 1 and Integer.MAX_VALUE")
private Integer maxRetryAttempts = MAX_RETRY_ATTEMPT;

@JsonProperty("max_retry_delay")
@Valid
@Size(min = 1)
private Duration maxRetryDelay = MAX_RETRY_DELAY;

@JsonProperty("auto_commit")
private Boolean autoCommit = false;

@JsonProperty("auto_commit_interval")
@Valid
@Size(min = 1)
private Duration autoCommitInterval = AUTOCOMMIT_INTERVAL;

@JsonProperty("session_timeout")
@Valid
@Size(min = 1)
private Duration sessionTimeOut = SESSION_TIMEOUT;

@JsonProperty("auto_offset_reset")
private String autoOffsetReset = AUTO_OFFSET_RESET;

@JsonProperty("group_name")
@Valid
@Size(min = 1, max = 255, message = "size of group name should be between 1 and 255")
private String groupName;

@JsonProperty("thread_waiting_time")
Expand All @@ -78,19 +87,23 @@ public class TopicConfig {
private Duration maxRecordFetchTime = MAX_RECORD_FETCH_TIME;

@JsonProperty("buffer_default_timeout")
@Valid
@Size(min = 1)
private Duration bufferDefaultTimeout = BUFFER_DEFAULT_TIMEOUT;

@JsonProperty("fetch_max_bytes")
@Valid
@Size(min = 1, max = 52428800)
private Long fetchMaxBytes = FETCH_MAX_BYTES;

@JsonProperty("fetch_max_wait")
@Valid
@Size(min = 1)
private Long fetchMaxWait = FETCH_MAX_WAIT;

@JsonProperty("fetch_min_bytes")
@Size(min = 1)
@Valid
private Long fetchMinBytes = FETCH_MIN_BYTES;

@JsonProperty("retry_backoff")
Expand All @@ -103,6 +116,7 @@ public class TopicConfig {
private Integer consumerMaxPollRecords = CONSUMER_MAX_POLL_RECORDS;

@JsonProperty("heart_beat_interval")
@Valid
@Size(min = 1)
private Duration heartBeatInterval= HEART_BEAT_INTERVAL_DURATION;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,19 @@
import org.opensearch.dataprepper.plugins.kafka.consumer.KafkaSourceCustomConsumer;
import org.opensearch.dataprepper.plugins.kafka.util.KafkaSourceJsonDeserializer;
import org.opensearch.dataprepper.plugins.kafka.util.MessageFormat;

import software.amazon.awssdk.services.kafka.KafkaClient;
import software.amazon.awssdk.services.kafka.model.GetBootstrapBrokersRequest;
import software.amazon.awssdk.services.kafka.model.GetBootstrapBrokersResponse;
import software.amazon.awssdk.services.kafka.model.InternalServerErrorException;
import software.amazon.awssdk.services.kafka.model.ConflictException;
import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider;
import software.amazon.awssdk.services.sts.auth.StsAssumeRoleCredentialsProvider;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.services.sts.model.AssumeRoleRequest;
import software.amazon.awssdk.services.sts.StsClient;
import software.amazon.awssdk.regions.Region;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -44,8 +57,10 @@
import java.net.URISyntaxException;
import java.net.URL;
import java.util.Comparator;
import java.util.Objects;
import java.util.List;
import java.util.Properties;
import java.util.UUID;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
Expand All @@ -60,11 +75,12 @@
@SuppressWarnings("deprecation")
@DataPrepperPlugin(name = "kafka", pluginType = Source.class, pluginConfigurationType = KafkaSourceConfig.class)
public class KafkaSource implements Source<Record<Event>> {
private static final String KAFKA_WORKER_THREAD_PROCESSING_ERRORS = "kafkaWorkerThreadProcessingErrors";
private static final int MAX_KAFKA_CLIENT_RETRIES = 10;
private static final Logger LOG = LoggerFactory.getLogger(KafkaSource.class);
private final KafkaSourceConfig sourceConfig;
private AtomicBoolean shutdownInProgress;
private ExecutorService executorService;
private static final String KAFKA_WORKER_THREAD_PROCESSING_ERRORS = "kafkaWorkerThreadProcessingErrors";
private final Counter kafkaWorkerThreadProcessingErrors;
private final PluginMetrics pluginMetrics;
private KafkaSourceCustomConsumer consumer;
Expand Down Expand Up @@ -154,13 +170,99 @@ private long calculateLongestThreadWaitingTime() {
orElse(1L);
}

private Properties getConsumerProperties(TopicConfig topicConfig) {
/**
 * Looks up the bootstrap broker string of the configured MSK cluster.
 *
 * <p>Assumes the STS role given in {@code awsConfig}, calls the MSK {@code GetBootstrapBrokers}
 * API for the cluster ARN in the MSK config, and returns the SASL/IAM broker string that matches
 * the configured broker connection type. If the lookup does not produce a response, the user
 * configured bootstrap servers from the source config are returned instead (this may be
 * {@code null} when none were configured).
 *
 * @param awsIamAuthConfig IAM auth mode; only {@code ROLE} is accepted by this method
 * @param awsConfig AWS settings supplying the region, STS role ARN and MSK cluster config
 * @return the bootstrap broker string from MSK, or the user configured value when the lookup fails
 * @throws RuntimeException if {@code awsIamAuthConfig} is anything other than {@code ROLE}
 */
public String getBootStrapServersForMsk(final AwsIamAuthConfig awsIamAuthConfig, final AwsConfig awsConfig) {
    if (awsIamAuthConfig != AwsIamAuthConfig.ROLE) {
        // NOTE(review): the previous code created a default credentials provider and then
        // rejected every non-ROLE mode anyway; confirm whether a default-credentials mode
        // should be accepted here instead of failing.
        throw new RuntimeException("Unknown AWS IAM auth mode");
    }

    final Region region = Region.of(awsConfig.getRegion());
    final String sessionName = "data-prepper-kafka-session" + UUID.randomUUID();
    final AwsConfig.AwsMskConfig awsMskConfig = awsConfig.getAwsMskConfig();
    final GetBootstrapBrokersRequest request =
            GetBootstrapBrokersRequest
                    .builder()
                    .clusterArn(awsMskConfig.getArn())
                    .build();

    GetBootstrapBrokersResponse result = null;
    // The clients are only needed for this one lookup, so release their resources when done.
    // Resources are closed in reverse order: Kafka client, credentials provider, STS client.
    try (StsClient stsClient = StsClient.builder()
                 .region(region)
                 .credentialsProvider(DefaultCredentialsProvider.create())
                 .build();
         StsAssumeRoleCredentialsProvider credentialsProvider = StsAssumeRoleCredentialsProvider
                 .builder()
                 .stsClient(stsClient)
                 .refreshRequest(
                         AssumeRoleRequest
                                 .builder()
                                 .roleArn(awsConfig.getStsRoleArn())
                                 .roleSessionName(sessionName)
                                 .build()
                 ).build();
         KafkaClient kafkaClient = KafkaClient.builder()
                 .credentialsProvider(credentialsProvider)
                 .region(region)
                 .build()) {
        // One initial attempt plus up to MAX_KAFKA_CLIENT_RETRIES retries.
        for (int attempt = 0; attempt <= MAX_KAFKA_CLIENT_RETRIES; attempt++) {
            try {
                result = kafkaClient.getBootstrapBrokers(request);
                break;
            } catch (InternalServerErrorException | ConflictException e) {
                // Only these two service errors are retried.
                LOG.warn("Retryable failure while getting bootstrap brokers from MSK (attempt {} of {})",
                        attempt + 1, MAX_KAFKA_CLIENT_RETRIES + 1, e);
            } catch (Exception e) {
                // Best effort: any other failure ends the lookup and falls back to user config.
                LOG.error("Non-retryable failure while getting bootstrap brokers from MSK", e);
                break;
            }
        }
    }

    if (Objects.isNull(result)) {
        LOG.info("Failed to get bootstrap server information from MSK, using user configured bootstrap servers");
        return sourceConfig.getBootStrapServers();
    }
    switch (awsMskConfig.getBrokerConnectionType()) {
        case PUBLIC:
            return result.bootstrapBrokerStringPublicSaslIam();
        case MULTI_VPC:
            return result.bootstrapBrokerStringVpcConnectivitySaslIam();
        default:
        case SINGLE_VPC:
            return result.bootstrapBrokerStringSaslIam();
    }
}

private Properties getConsumerProperties(final TopicConfig topicConfig) {
Properties properties = new Properties();
AwsIamAuthConfig awsIamAuthConfig = null;
AwsConfig awsConfig = sourceConfig.getAwsConfig();
if (sourceConfig.getAuthConfig() != null) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can a user configure awsAuth without sourceConfig.getAuthConfig(), or with a non-SASL config? It may be better to create a separate validation class to cover all combinations.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, they could configure awsConfig without authConfig, but it won't be used. On the other hand, if they configure authConfig but not awsConfig, I will have to print an error and shut down the pipeline. Let me add that.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The check for AwsConfig being null when a non-null config is expected is already there.

AuthConfig.SaslAuthConfig saslAuthConfig = sourceConfig.getAuthConfig().getSaslAuthConfig();
if (saslAuthConfig != null) {
awsIamAuthConfig = saslAuthConfig.getAwsIamAuthConfig();
if (awsIamAuthConfig != null) {
if (encryptionType == EncryptionType.PLAINTEXT) {
throw new RuntimeException("Encryption Config must be SSL to use IAM authentication mechanism");
}
setAwsIamAuthProperties(properties, awsIamAuthConfig, awsConfig);
} else if (saslAuthConfig.getOAuthConfig() != null) {
} else if (saslAuthConfig.getPlainTextAuthConfig() != null) {
setPlainTextAuthProperties(properties);
} else {
throw new RuntimeException("No SASL auth config specified");
}
}
}
properties.put(ConsumerConfig.AUTO_COMMIT_INTERVAL_MS_CONFIG,
topicConfig.getAutoCommitInterval().toSecondsPart());
properties.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG,
topicConfig.getAutoOffsetReset());
properties.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, sourceConfig.getBootStrapServers());
String bootstrapServers = sourceConfig.getBootStrapServers();
if (Objects.nonNull(awsIamAuthConfig)) {
bootstrapServers = getBootStrapServersForMsk(awsIamAuthConfig, awsConfig);
}
if (Objects.isNull(bootstrapServers) || bootstrapServers.isEmpty()) {
throw new RuntimeException("Bootstrap servers are not specified");
}
properties.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, bootstrapServers);
properties.put(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG,
topicConfig.getAutoCommit());
properties.put(ConsumerConfig.MAX_POLL_RECORDS_CONFIG,
Expand All @@ -173,22 +275,6 @@ private Properties getConsumerProperties(TopicConfig topicConfig) {
schemaType = MessageFormat.PLAINTEXT.toString();
}
setPropertiesForSchemaType(properties, schemaType);
if (sourceConfig.getAuthConfig() != null) {
AuthConfig.SaslAuthConfig saslAuthConfig = sourceConfig.getAuthConfig().getSaslAuthConfig();
if (saslAuthConfig != null) {
if (saslAuthConfig.getPlainTextAuthConfig() != null) {
setPlainTextAuthProperties(properties);
} else if (saslAuthConfig.getAwsIamAuthConfig() != null) {
if (encryptionType == EncryptionType.PLAINTEXT) {
throw new RuntimeException("Encryption Config must be SSL to use IAM authentication mechanism");
}
setAwsIamAuthProperties(properties, saslAuthConfig.getAwsIamAuthConfig(), sourceConfig.getAwsConfig());
} else if (saslAuthConfig.getOAuthConfig() != null) {
} else {
throw new RuntimeException("No SASL auth config specified");
}
}
}
LOG.info("Starting consumer with the properties : {}", properties);
return properties;
}
Expand Down Expand Up @@ -229,7 +315,7 @@ private String getSchemaRegistryUrl() {
return sourceConfig.getSchemaConfig().getRegistryURL();
}

private void setAwsIamAuthProperties(Properties properties, AwsIamAuthConfig awsIamAuthConfig, AwsConfig awsConfig) {
private void setAwsIamAuthProperties(Properties properties, final AwsIamAuthConfig awsIamAuthConfig, final AwsConfig awsConfig) {
if (awsConfig == null) {
throw new RuntimeException("AWS Config is not specified");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,12 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Map;
import java.util.stream.Collectors;
import java.time.Duration;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.hasItem;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.notNullValue;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.Mockito.mock;
import static org.opensearch.dataprepper.test.helper.ReflectivelySetField.setField;

Expand Down Expand Up @@ -61,13 +59,8 @@ void test_kafka_config_not_null() {
@Test
void test_bootStrapServers_not_null(){
assertThat(kafkaSourceConfig.getBootStrapServers(), notNullValue());
List<String> servers = kafkaSourceConfig.getBootStrapServers();
bootstrapServers = servers.stream().
flatMap(str -> Arrays.stream(str.split(","))).
map(String::trim).
collect(Collectors.toList());
assertThat(bootstrapServers.size(), equalTo(1));
assertThat(bootstrapServers, hasItem("127.0.0.1:9093"));
String bootstrapServers = kafkaSourceConfig.getBootStrapServers();
assertTrue(bootstrapServers.contains("127.0.0.1:9093"));
}

@Test
Expand All @@ -84,7 +77,7 @@ void test_setters() throws NoSuchFieldException, IllegalAccessException {
TopicConfig topicConfig = mock(TopicConfig.class);
kafkaSourceConfig.setTopics(Collections.singletonList(topicConfig));

assertEquals(Arrays.asList("127.0.0.1:9092"), kafkaSourceConfig.getBootStrapServers());
assertEquals("127.0.0.1:9092", kafkaSourceConfig.getBootStrapServers());
assertEquals(Collections.singletonList(topicConfig), kafkaSourceConfig.getTopics());
setField(KafkaSourceConfig.class, kafkaSourceConfig, "acknowledgementsEnabled", true);
Duration testTimeout = Duration.ofSeconds(10);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
import static org.mockito.Mockito.when;

import java.util.Arrays;
import java.util.List;
import java.time.Duration;

@ExtendWith(MockitoExtension.class)
Expand Down Expand Up @@ -90,7 +89,7 @@ void setUp() throws Exception {
when(topic2.getAutoCommit()).thenReturn(false);
when(topic1.getThreadWaitingTime()).thenReturn(Duration.ofSeconds(10));
when(topic2.getThreadWaitingTime()).thenReturn(Duration.ofSeconds(10));
when(sourceConfig.getBootStrapServers()).thenReturn(List.of("http://localhost:1234"));
when(sourceConfig.getBootStrapServers()).thenReturn("http://localhost:1234");
when(sourceConfig.getTopics()).thenReturn(Arrays.asList(topic1, topic2));
}

Expand Down
Loading