Fix consumer synchronization. Fix consumer to use user-specified groupId #3100

Merged: 3 commits, Aug 2, 2023
@@ -83,6 +83,7 @@ public class KafkaSourceJsonTypeIT {
private String bootstrapServers;
private String testKey;
private String testTopic;
private String testGroup;

public KafkaSource createObjectUnderTest() {
return new KafkaSource(sourceConfig, pluginMetrics, acknowledgementSetManager, pipelineDescription);
@@ -112,7 +113,7 @@ public void setup() {
} catch (Exception e){}

testKey = RandomStringUtils.randomAlphabetic(5);
final String testGroup = "TestGroup_"+RandomStringUtils.randomAlphabetic(6);
testGroup = "TestGroup_"+RandomStringUtils.randomAlphabetic(6);
testTopic = "TestJsonTopic_"+RandomStringUtils.randomAlphabetic(5);
jsonTopic = mock(TopicConfig.class);
when(jsonTopic.getName()).thenReturn(testTopic);
@@ -337,6 +338,7 @@ public void TestJsonRecordsWithKafkaKeyModeAsMetadata() throws Exception {
Thread.sleep(1000);
}
kafkaSource.start(buffer);
assertThat(kafkaSource.getConsumer().groupMetadata().groupId(), equalTo(testGroup));
produceJsonRecords(bootstrapServers, topicName, numRecords);
int numRetries = 0;
while (numRetries++ < 10 && (receivedRecords.size() != numRecords)) {
@@ -18,23 +18,24 @@
* pipelines.yaml
*/
public class TopicConfig {
private static final String AUTO_COMMIT = "false";
private static final Duration DEFAULT_COMMIT_INTERVAL = Duration.ofSeconds(5);
private static final Duration DEFAULT_SESSION_TIMEOUT = Duration.ofSeconds(45);
private static final int MAX_RETRY_ATTEMPT = Integer.MAX_VALUE;
static final boolean DEFAULT_AUTO_COMMIT = false;
static final Duration DEFAULT_COMMIT_INTERVAL = Duration.ofSeconds(5);
static final Duration DEFAULT_SESSION_TIMEOUT = Duration.ofSeconds(45);
static final int DEFAULT_MAX_RETRY_ATTEMPT = Integer.MAX_VALUE;
static final String DEFAULT_AUTO_OFFSET_RESET = "latest";
static final Duration THREAD_WAITING_TIME = Duration.ofSeconds(5);
private static final Duration MAX_RECORD_FETCH_TIME = Duration.ofSeconds(4);
private static final Duration BUFFER_DEFAULT_TIMEOUT = Duration.ofSeconds(5);
private static final Duration MAX_RETRY_DELAY = Duration.ofSeconds(1);
private static final Integer FETCH_MAX_BYTES = 52428800;
private static final Integer FETCH_MAX_WAIT = 500;
private static final Integer FETCH_MIN_BYTES = 1;
private static final Duration RETRY_BACKOFF = Duration.ofSeconds(100);
private static final Duration MAX_POLL_INTERVAL = Duration.ofSeconds(300000);
private static final Integer CONSUMER_MAX_POLL_RECORDS = 500;
static final Duration DEFAULT_THREAD_WAITING_TIME = Duration.ofSeconds(5);
static final Duration DEFAULT_MAX_RECORD_FETCH_TIME = Duration.ofSeconds(4);
static final Duration DEFAULT_BUFFER_TIMEOUT = Duration.ofSeconds(5);
static final Duration DEFAULT_MAX_RETRY_DELAY = Duration.ofSeconds(1);
static final Integer DEFAULT_FETCH_MAX_BYTES = 52428800;
static final Integer DEFAULT_FETCH_MAX_WAIT = 500;
static final Integer DEFAULT_FETCH_MIN_BYTES = 1;
static final Duration DEFAULT_RETRY_BACKOFF = Duration.ofSeconds(10);
static final Duration DEFAULT_RECONNECT_BACKOFF = Duration.ofSeconds(10);
static final Duration DEFAULT_MAX_POLL_INTERVAL = Duration.ofSeconds(300000);
static final Integer DEFAULT_CONSUMER_MAX_POLL_RECORDS = 500;
static final Integer DEFAULT_NUM_OF_WORKERS = 2;
static final Duration HEART_BEAT_INTERVAL_DURATION = Duration.ofSeconds(5);
static final Duration DEFAULT_HEART_BEAT_INTERVAL_DURATION = Duration.ofSeconds(5);

@JsonProperty("name")
@NotNull
@@ -54,18 +55,18 @@ public class TopicConfig {
@JsonProperty("max_retry_attempts")
@Valid
@Size(min = 1, max = Integer.MAX_VALUE, message = " Max retry attempts should lies between 1 and Integer.MAX_VALUE")
private Integer maxRetryAttempts = MAX_RETRY_ATTEMPT;
private Integer maxRetryAttempts = DEFAULT_MAX_RETRY_ATTEMPT;

@JsonProperty("max_retry_delay")
@Valid
@Size(min = 1)
private Duration maxRetryDelay = MAX_RETRY_DELAY;
private Duration maxRetryDelay = DEFAULT_MAX_RETRY_DELAY;

@JsonProperty("serde_format")
private MessageFormat serdeFormat= MessageFormat.PLAINTEXT;

@JsonProperty("auto_commit")
private Boolean autoCommit = false;
private Boolean autoCommit = DEFAULT_AUTO_COMMIT;

@JsonProperty("commit_interval")
@Valid
@@ -86,47 +87,50 @@ public class TopicConfig {
private String groupName;

@JsonProperty("thread_waiting_time")
private Duration threadWaitingTime = THREAD_WAITING_TIME;
private Duration threadWaitingTime = DEFAULT_THREAD_WAITING_TIME;

@JsonProperty("max_record_fetch_time")
private Duration maxRecordFetchTime = MAX_RECORD_FETCH_TIME;
private Duration maxRecordFetchTime = DEFAULT_MAX_RECORD_FETCH_TIME;

@JsonProperty("buffer_default_timeout")
@Valid
@Size(min = 1)
private Duration bufferDefaultTimeout = BUFFER_DEFAULT_TIMEOUT;
private Duration bufferDefaultTimeout = DEFAULT_BUFFER_TIMEOUT;

@JsonProperty("fetch_max_bytes")
@Valid
@Size(min = 1, max = 52428800)
private Integer fetchMaxBytes = FETCH_MAX_BYTES;
private Integer fetchMaxBytes = DEFAULT_FETCH_MAX_BYTES;

@JsonProperty("fetch_max_wait")
@Valid
@Size(min = 1)
private Integer fetchMaxWait = FETCH_MAX_WAIT;
private Integer fetchMaxWait = DEFAULT_FETCH_MAX_WAIT;

@JsonProperty("fetch_min_bytes")
@Size(min = 1)
@Valid
private Integer fetchMinBytes = FETCH_MIN_BYTES;
private Integer fetchMinBytes = DEFAULT_FETCH_MIN_BYTES;

@JsonProperty("key_mode")
private KafkaKeyMode kafkaKeyMode = KafkaKeyMode.INCLUDE_AS_FIELD;

@JsonProperty("retry_backoff")
private Duration retryBackoff = RETRY_BACKOFF;
private Duration retryBackoff = DEFAULT_RETRY_BACKOFF;

@JsonProperty("reconnect_backoff")
private Duration reconnectBackoff = DEFAULT_RECONNECT_BACKOFF;

@JsonProperty("max_poll_interval")
private Duration maxPollInterval = MAX_POLL_INTERVAL;
private Duration maxPollInterval = DEFAULT_MAX_POLL_INTERVAL;

@JsonProperty("consumer_max_poll_records")
private Integer consumerMaxPollRecords = CONSUMER_MAX_POLL_RECORDS;
private Integer consumerMaxPollRecords = DEFAULT_CONSUMER_MAX_POLL_RECORDS;

@JsonProperty("heart_beat_interval")
@Valid
@Size(min = 1)
private Duration heartBeatInterval= HEART_BEAT_INTERVAL_DURATION;
private Duration heartBeatInterval= DEFAULT_HEART_BEAT_INTERVAL_DURATION;

public String getGroupId() {
return groupId;
@@ -220,6 +224,10 @@ public Duration getRetryBackoff() {
return retryBackoff;
}

public Duration getReconnectBackoff() {
return reconnectBackoff;
}

public void setRetryBackoff(Duration retryBackoff) {
this.retryBackoff = retryBackoff;
}
@@ -15,6 +15,7 @@
import org.apache.kafka.clients.consumer.OffsetAndMetadata;
import org.apache.kafka.clients.consumer.CommitFailedException;
import org.apache.kafka.common.errors.AuthenticationException;
import org.apache.kafka.common.errors.RecordDeserializationException;
import org.apache.kafka.common.TopicPartition;
import org.apache.avro.generic.GenericRecord;
import org.opensearch.dataprepper.model.log.JacksonLog;
@@ -40,6 +41,8 @@
import java.util.Collection;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.HashSet;
import java.util.concurrent.atomic.AtomicBoolean;
import org.opensearch.dataprepper.plugins.kafka.util.MessageFormat;
import com.amazonaws.services.schemaregistry.serializers.json.JsonDataWithSchema;
@@ -68,6 +71,7 @@ public class KafkaSourceCustomConsumer implements Runnable, ConsumerRebalanceLis
private static final ObjectMapper objectMapper = new ObjectMapper();
private final JsonFactory jsonFactory = new JsonFactory();
private Map<TopicPartition, OffsetAndMetadata> offsetsToCommit;
private Set<TopicPartition> partitionsToReset;
private final AcknowledgementSetManager acknowledgementSetManager;
private final Map<Integer, TopicPartitionCommitTracker> partitionCommitTrackerMap;
private final Counter positiveAcknowledgementSetCounter;
@@ -95,6 +99,7 @@ public KafkaSourceCustomConsumer(final KafkaConsumer consumer,
this.acknowledgementSetManager = acknowledgementSetManager;
this.pluginMetrics = pluginMetrics;
this.partitionCommitTrackerMap = new HashMap<>();
this.partitionsToReset = new HashSet<>();
Member: You could use a concurrent set to avoid synchronization on this object.

    this.partitionsToReset = Collections.synchronizedSet(new HashSet<>());

Collaborator (Author): Would it be better performance wise?

Member: I doubt it, but it would ensure that somebody doesn't forget to synchronize calls as the file is maintained.

this.schema = MessageFormat.getByMessageFormatByName(schemaType);
Duration bufferTimeout = Duration.ofSeconds(1);
this.bufferAccumulator = BufferAccumulator.create(buffer, DEFAULT_NUMBER_OF_RECORDS_TO_ACCUMULATE, bufferTimeout);
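
As an aside on the review thread above, a minimal sketch (not from the PR; the class and method names are invented) contrasting the plain HashSet guarded by explicit synchronized blocks used in this change with a thread-safe set, which would make the locking implicit:

    import java.util.HashSet;
    import java.util.Set;
    import java.util.concurrent.ConcurrentHashMap;
    import org.apache.kafka.common.TopicPartition;

    class PartitionResetTracking {
        // Approach taken in this PR: a plain HashSet, with every caller synchronizing explicitly.
        private final Set<TopicPartition> partitionsToReset = new HashSet<>();

        void markForReset(final TopicPartition partition) {
            synchronized (partitionsToReset) {
                partitionsToReset.add(partition);
            }
        }

        // Alternative raised in the review: a thread-safe set, so a caller cannot forget the lock.
        // Collections.synchronizedSet(new HashSet<>()) locks on every call, while
        // ConcurrentHashMap.newKeySet() avoids a single shared lock for most operations.
        private final Set<TopicPartition> partitionsToResetConcurrent = ConcurrentHashMap.newKeySet();

        void markForResetConcurrent(final TopicPartition partition) {
            partitionsToResetConcurrent.add(partition);  // no external synchronization needed
        }
    }
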
@@ -121,29 +126,21 @@ private AcknowledgementSet createAcknowledgementSet(Map<TopicPartition, Range<Lo
try {
int partitionId = partition.partition();
if (!partitionCommitTrackerMap.containsKey(partitionId)) {
OffsetAndMetadata committedOffsetAndMetadata = null;
synchronized(consumer) {
committedOffsetAndMetadata = consumer.committed(partition);
}
OffsetAndMetadata committedOffsetAndMetadata = consumer.committed(partition);
Long committedOffset = Objects.nonNull(committedOffsetAndMetadata) ? committedOffsetAndMetadata.offset() : null;
partitionCommitTrackerMap.put(partitionId, new TopicPartitionCommitTracker(partition, committedOffset));
}
OffsetAndMetadata offsetAndMetadata = partitionCommitTrackerMap.get(partitionId).addCompletedOffsets(offsetRange);
updateOffsetsToCommit(partition, offsetAndMetadata);
} catch (Exception e) {
LOG.error("Failed to seek to last committed offset upon positive acknowledgement "+partition, e);
LOG.error("Failed to seek to last committed offset upon positive acknowledgement {}", partition, e);
}
});
} else {
negativeAcknowledgementSetCounter.increment();
offsets.forEach((partition, offsetRange) -> {
try {
synchronized(consumer) {
OffsetAndMetadata committedOffsetAndMetadata = consumer.committed(partition);
consumer.seek(partition, committedOffsetAndMetadata);
}
} catch (Exception e) {
LOG.error("Failed to seek to last committed offset upon negative acknowledgement "+partition, e);
synchronized(partitionsToReset) {
partitionsToReset.add(partition);
Member: This seems a nice change. Did you do any performance testing or find any particular pitfalls with the original approach?

Collaborator (Author): I haven't done any performance testing. But definitely, taking a lock every time the consumer is accessed is not good.

}
});
}
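
For readers following the thread above, a condensed, illustrative sketch of the pattern this diff moves to (simplified; names are modeled on the PR but this is not its exact code). KafkaConsumer is not safe for multi-threaded access, so the acknowledgement callback only records the partition and the polling thread performs the seek before its next poll():

    import java.util.HashSet;
    import java.util.Set;
    import org.apache.kafka.clients.consumer.KafkaConsumer;
    import org.apache.kafka.clients.consumer.OffsetAndMetadata;
    import org.apache.kafka.common.TopicPartition;

    class DeferredPartitionReset {
        private final Set<TopicPartition> partitionsToReset = new HashSet<>();

        // Runs on the acknowledgement callback thread: record the partition, never touch the consumer here.
        void onNegativeAcknowledgement(final TopicPartition partition) {
            synchronized (partitionsToReset) {
                partitionsToReset.add(partition);
            }
        }

        // Runs only on the polling thread, which owns the consumer.
        void resetPartitions(final KafkaConsumer<?, ?> consumer) {
            synchronized (partitionsToReset) {
                for (final TopicPartition partition : partitionsToReset) {
                    final OffsetAndMetadata committed = consumer.committed(partition);
                    if (committed != null) {
                        consumer.seek(partition, committed.offset());
                    }
                }
                partitionsToReset.clear();
            }
        }
    }
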
@@ -157,10 +154,7 @@ private AcknowledgementSet createAcknowledgementSet(Map<TopicPartition, Range<Lo

public <T> void consumeRecords() throws Exception {
try {
ConsumerRecords<String, T> records = null;
synchronized(consumer) {
records = consumer.poll(topicConfig.getThreadWaitingTime().toMillis()/2);
}
ConsumerRecords<String, T> records = consumer.poll(topicConfig.getThreadWaitingTime().toMillis()/2);
if (Objects.nonNull(records) && !records.isEmpty() && records.count() > 0) {
Map<TopicPartition, Range<Long>> offsets = new HashMap<>();
AcknowledgementSet acknowledgementSet = null;
@@ -176,12 +170,27 @@ public <T> void consumeRecords() throws Exception {
}
}
} catch (AuthenticationException e) {
LOG.warn("Authentication Error while doing poll(). Will retry after 10 seconds", e);
LOG.warn("Access Denied while doing poll(). Will retry after 10 seconds", e);
Member: The Apache documentation indicates that this is an authentication error and not necessarily related to access controls.

Thread.sleep(10000);
} catch (RecordDeserializationException e) {
LOG.warn("Serialization error - topic {} partition {} offset {}, seeking past the error record",
Contributor (@hshardeesi, Aug 2, 2023): Serialization -> Deserialization. Also, increment a metric when we get to metrics.

e.topicPartition().topic(), e.topicPartition().partition(), e.offset());
consumer.seek(e.topicPartition(), e.offset()+1);
Contributor: Is it really required to explicitly seek past this record? Can we not just log the exception, count a metric, and commit the offset as usual?

Collaborator (Author): As per the documentation, we have to seek past the offset to continue reading. We cannot commit the offset unless we get acknowledgement that previously read records are flushed to the sink, right?

}
}
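
For reference, a self-contained sketch of the skip-the-bad-record handling discussed above (simplified from the diff; the wrapper class and method are invented for the example):

    import java.time.Duration;
    import org.apache.kafka.clients.consumer.ConsumerRecords;
    import org.apache.kafka.clients.consumer.KafkaConsumer;
    import org.apache.kafka.common.TopicPartition;
    import org.apache.kafka.common.errors.RecordDeserializationException;
    import org.slf4j.Logger;
    import org.slf4j.LoggerFactory;

    class SkipBadRecordExample {
        private static final Logger LOG = LoggerFactory.getLogger(SkipBadRecordExample.class);

        <T> void pollOnce(final KafkaConsumer<String, T> consumer) {
            try {
                final ConsumerRecords<String, T> records = consumer.poll(Duration.ofSeconds(2));
                // ... hand records to the buffer / acknowledgement machinery ...
            } catch (RecordDeserializationException e) {
                // The exception carries the failing partition and offset; the consumer must seek past it,
                // otherwise every subsequent poll() fails on the same record.
                final TopicPartition failed = e.topicPartition();
                LOG.warn("Deserialization error - topic {} partition {} offset {}, seeking past the error record",
                        failed.topic(), failed.partition(), e.offset());
                consumer.seek(failed, e.offset() + 1);  // skip only the bad record
            }
        }
    }
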

private void commitOffsets() {
private void resetOrCommitOffsets() {
synchronized(partitionsToReset) {
partitionsToReset.forEach(partition -> {
try {
final OffsetAndMetadata offsetAndMetadata = consumer.committed(partition);
consumer.seek(partition, offsetAndMetadata);
} catch (Exception e) {
LOG.error("Failed to seek to last committed offset upon negative acknowledgement {}", partition, e);
}
});
partitionsToReset.clear();
}
if (topicConfig.getAutoCommit()) {
return;
}
@@ -194,13 +203,11 @@ private void commitOffsets() {
return;
}
try {
synchronized(consumer) {
consumer.commitSync();
}
consumer.commitSync();
offsetsToCommit.clear();
lastCommitTime = currentTimeMillis;
} catch (CommitFailedException e) {
LOG.error("Failed to commit offsets in topic "+topicName, e);
LOG.error("Failed to commit offsets in topic {}", topicName, e);
}
}
}
@@ -211,14 +218,14 @@ Map<TopicPartition, OffsetAndMetadata> getOffsetsToCommit() {

@Override
public void run() {
try {
consumer.subscribe(Arrays.asList(topicName));
while (!shutdownInProgress.get()) {
consumer.subscribe(Arrays.asList(topicName));
while (!shutdownInProgress.get()) {
try {
resetOrCommitOffsets();
Contributor: nit: might be better to have separate functions.

Collaborator (Author): I thought about it. Resetting is just 3 or 4 lines of code and didn't feel like making it a separate function, especially since it may be a null operation most of the time.

consumeRecords();
commitOffsets();
} catch (Exception exp) {
LOG.error("Error while reading the records from the topic...", exp);
Member: I think we should avoid ellipses in our logs. It's not a big deal, but they seem to indicate either that we trailed off or that we have more to say (beyond the exception).

Collaborator (Author): Will remove the ellipsis in the next PR.

}
} catch (Exception exp) {
LOG.error("Error while reading the records from the topic...", exp);
}
}

@@ -306,9 +313,8 @@ public void shutdownConsumer(){
@Override
public void onPartitionsAssigned(Collection<TopicPartition> partitions) {
for (TopicPartition topicPartition : partitions) {
synchronized(consumer) {
Long committedOffset = consumer.committed(topicPartition).offset();
consumer.seek(topicPartition, committedOffset);
synchronized(partitionsToReset) {
partitionsToReset.add(topicPartition);
}
}
}
@@ -95,6 +95,7 @@ public class KafkaSource implements Source<Record<Event>> {
private final Counter kafkaWorkerThreadProcessingErrors;
private final PluginMetrics pluginMetrics;
private KafkaSourceCustomConsumer consumer;
private KafkaConsumer kafkaConsumer;
private String pipelineName;
private String consumerGroupID;
private String schemaType = MessageFormat.PLAINTEXT.toString();
@@ -125,7 +126,6 @@ public void start(Buffer<Record<Event>> buffer) {
int numWorkers = topic.getWorkers();
executorService = Executors.newFixedThreadPool(numWorkers);
IntStream.range(0, numWorkers + 1).forEach(index -> {
KafkaConsumer kafkaConsumer;
switch (schema) {
case JSON:
kafkaConsumer = new KafkaConsumer<String, JsonNode>(consumerProperties);
@@ -185,6 +185,9 @@ private long calculateLongestThreadWaitingTime() {
orElse(1L);
}

KafkaConsumer getConsumer() {
return kafkaConsumer;
}

private Properties getConsumerProperties(final TopicConfig topicConfig) {
Properties properties = new Properties();
@@ -361,6 +364,8 @@ private void setPropertiesForSchemaType(Properties properties, TopicConfig topic

private void setConsumerTopicProperties(Properties properties, TopicConfig topicConfig) {
properties.put(ConsumerConfig.GROUP_ID_CONFIG, consumerGroupID);
properties.put(ConsumerConfig.RETRY_BACKOFF_MS_CONFIG, ((Long)topicConfig.getRetryBackoff().toMillis()).intValue());
properties.put(ConsumerConfig.RECONNECT_BACKOFF_MS_CONFIG, ((Long)topicConfig.getReconnectBackoff().toMillis()).intValue());
properties.put(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG,
topicConfig.getAutoCommit());
properties.put(ConsumerConfig.AUTO_COMMIT_INTERVAL_MS_CONFIG,
@@ -16,10 +16,7 @@
import software.amazon.awssdk.services.kafka.KafkaClient;
import software.amazon.awssdk.services.kafka.model.GetBootstrapBrokersRequest;
import software.amazon.awssdk.services.kafka.model.GetBootstrapBrokersResponse;
import software.amazon.awssdk.services.kafka.model.InternalServerErrorException;
import software.amazon.awssdk.services.kafka.model.ConflictException;
import software.amazon.awssdk.services.kafka.model.ForbiddenException;
import software.amazon.awssdk.services.kafka.model.UnauthorizedException;
import software.amazon.awssdk.services.kafka.model.KafkaException;
import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider;
import software.amazon.awssdk.services.sts.model.StsException;
import software.amazon.awssdk.services.sts.auth.StsAssumeRoleCredentialsProvider;
@@ -214,17 +211,15 @@ public static String getBootStrapServersForMsk(final AwsIamAuthConfig awsIamAuth
retryable = false;
try {
result = kafkaClient.getBootstrapBrokers(request);
} catch (InternalServerErrorException | ConflictException | ForbiddenException | UnauthorizedException | StsException e) {
} catch (KafkaException | StsException e) {
LOG.debug("Failed to get bootstrap server information from MSK. Retrying...", e);
Contributor: Should this be LOG.info? Also make it explicit, like "will retry with exponential backoff" or "after so many seconds". Do we need to log the entire backtrace? Just e.message() may be enough?

Collaborator (Author): I am sure that every time we included just the message, the entire stack trace was later needed. Hopefully, these are not common scenarios.

retryable = true;
try {
Thread.sleep(10000);
} catch (InterruptedException exp) {}
} catch (Exception e) {
throw new RuntimeException("Failed to get bootstrap server information from MSK.", e);
}
} while (retryable && numRetries++ < MAX_KAFKA_CLIENT_RETRIES);
} while (numRetries++ < MAX_KAFKA_CLIENT_RETRIES);
if (Objects.isNull(result)) {
throw new RuntimeException("Failed to get bootstrap server information from MSK after trying multiple times with retryable exceptions.");
}