Commit

[#30941] Fix upgrade test due to missed config ConsumerPollingTimeout (#30998)

* [#30941] Fix upgrade test due to missed config ConsumerPollingTimeout in KafkaIOTranslation

* fixed upgrade test and changed consumer timeout to long

* fixed spotless issues

* fixed test
xianhualiu authored Apr 17, 2024
1 parent 6366bd4 commit 8092932
Showing 7 changed files with 59 additions and 43 deletions.
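The change in a nutshell: KafkaIO's consumer polling timeout is now a plain long number of seconds instead of a joda-time Duration, and KafkaIOTranslation now round-trips the value so transform upgrades reconstruct it correctly. A minimal before/after sketch of the user-facing API (bootstrap server and topic are placeholders; deserializer setup is omitted):

    // Before this commit: a joda-time Duration.
    KafkaIO.Read<Long, String> before =
        KafkaIO.<Long, String>read()
            .withBootstrapServers("kafka:9092")
            .withTopic("my-topic")
            .withConsumerPollingTimeout(org.joda.time.Duration.standardSeconds(5L));

    // After this commit: a long, interpreted as seconds.
    KafkaIO.Read<Long, String> after =
        KafkaIO.<Long, String>read()
            .withBootstrapServers("kafka:9092")
            .withTopic("my-topic")
            .withConsumerPollingTimeout(5L);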
@@ -587,7 +587,7 @@ public static <K, V> Read<K, V> read() {
.setCommitOffsetsInFinalizeEnabled(false)
.setDynamicRead(false)
.setTimestampPolicyFactory(TimestampPolicyFactory.withProcessingTime())
- .setConsumerPollingTimeout(Duration.standardSeconds(2L))
+ .setConsumerPollingTimeout(2L)
.build();
}

@@ -708,7 +708,7 @@ public abstract static class Read<K, V>
public abstract @Nullable ErrorHandler<BadRecord, ?> getBadRecordErrorHandler();

@Pure
- public abstract @Nullable Duration getConsumerPollingTimeout();
+ public abstract long getConsumerPollingTimeout();

abstract Builder<K, V> toBuilder();

@@ -766,7 +766,7 @@ Builder<K, V> setCheckStopReadingFn(
return setCheckStopReadingFn(CheckStopReadingFnWrapper.of(checkStopReadingFn));
}

- abstract Builder<K, V> setConsumerPollingTimeout(Duration consumerPollingTimeout);
+ abstract Builder<K, V> setConsumerPollingTimeout(long consumerPollingTimeout);

abstract Read<K, V> build();

@@ -836,10 +836,9 @@ static <K, V> void setupExternalBuilder(
if (config.consumerPollingTimeout <= 0) {
throw new IllegalArgumentException("consumerPollingTimeout should be > 0.");
}
- builder.setConsumerPollingTimeout(
- Duration.standardSeconds(config.consumerPollingTimeout));
+ builder.setConsumerPollingTimeout(config.consumerPollingTimeout);
} else {
- builder.setConsumerPollingTimeout(Duration.standardSeconds(2L));
+ builder.setConsumerPollingTimeout(2L);
}
}

@@ -1356,14 +1355,12 @@ public Read<K, V> withBadRecordErrorHandler(ErrorHandler<BadRecord, ?> badRecord
}

/**
- * Sets the timeout time for Kafka consumer polling request in the {@link ReadFromKafkaDoFn}. A
- * lower timeout optimizes for latency. Increase the timeout if the consumer is not fetching
- * enough (or any) records. The default is 2 seconds.
+ * Sets the timeout time in seconds for Kafka consumer polling request in the {@link
+ * ReadFromKafkaDoFn}. A lower timeout optimizes for latency. Increase the timeout if the
+ * consumer is not fetching any records. The default is 2 seconds.
*/
- public Read<K, V> withConsumerPollingTimeout(Duration duration) {
- checkState(
- duration == null || duration.compareTo(Duration.ZERO) > 0,
- "Consumer polling timeout must be greater than 0.");
+ public Read<K, V> withConsumerPollingTimeout(long duration) {
+ checkState(duration > 0, "Consumer polling timeout must be greater than 0.");
return toBuilder().setConsumerPollingTimeout(duration).build();
}

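With a primitive long there is no null "use the default" path any more: read() seeds the builder with 2L, and the setter only accepts positive values. A small sketch of the resulting contract (server and topic names are placeholders):

    KafkaIO.Read<Long, String> read =
        KafkaIO.<Long, String>read()
            .withBootstrapServers("kafka:9092")
            .withTopic("my-topic");

    // Default seeded by read().
    long timeout = read.getConsumerPollingTimeout(); // 2L

    // Positive values are accepted; 0 or a negative value throws IllegalStateException.
    read = read.withConsumerPollingTimeout(10L);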
@@ -2071,7 +2068,7 @@ public abstract static class ReadSourceDescriptors<K, V>
abstract ErrorHandler<BadRecord, ?> getBadRecordErrorHandler();

@Pure
- abstract @Nullable Duration getConsumerPollingTimeout();
+ abstract long getConsumerPollingTimeout();

abstract boolean isBounded();

@@ -2123,8 +2120,7 @@ abstract ReadSourceDescriptors.Builder<K, V> setBadRecordRouter(
abstract ReadSourceDescriptors.Builder<K, V> setBadRecordErrorHandler(
ErrorHandler<BadRecord, ?> badRecordErrorHandler);

- abstract ReadSourceDescriptors.Builder<K, V> setConsumerPollingTimeout(
- @Nullable Duration duration);
+ abstract ReadSourceDescriptors.Builder<K, V> setConsumerPollingTimeout(long duration);

abstract ReadSourceDescriptors.Builder<K, V> setBounded(boolean bounded);

@@ -2139,7 +2135,7 @@ public static <K, V> ReadSourceDescriptors<K, V> read() {
.setBounded(false)
.setBadRecordRouter(BadRecordRouter.THROWING_ROUTER)
.setBadRecordErrorHandler(new ErrorHandler.DefaultErrorHandler<>())
- .setConsumerPollingTimeout(Duration.standardSeconds(2L))
+ .setConsumerPollingTimeout(2L)
.build()
.withProcessingTime()
.withMonotonicallyIncreasingWatermarkEstimator();
@@ -2402,11 +2398,11 @@ public ReadSourceDescriptors<K, V> withBadRecordErrorHandler(
}

/**
- * Sets the timeout time for Kafka consumer polling request in the {@link ReadFromKafkaDoFn}. A
- * lower timeout optimizes for latency. Increase the timeout if the consumer is not fetching
- * enough (or any) records. The default is 2 seconds.
+ * Sets the timeout time in seconds for Kafka consumer polling request in the {@link
+ * ReadFromKafkaDoFn}. A lower timeout optimizes for latency. Increase the timeout if the
+ * consumer is not fetching any records. The default is 2 seconds.
*/
- public ReadSourceDescriptors<K, V> withConsumerPollingTimeout(@Nullable Duration duration) {
+ public ReadSourceDescriptors<K, V> withConsumerPollingTimeout(long duration) {
return toBuilder().setConsumerPollingTimeout(duration).build();
}

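One asymmetry worth noting: unlike Read, this ReadSourceDescriptors variant does not checkState the value, so a non-positive timeout is only corrected later, when ReadFromKafkaDoFn falls back to its default (see the getConsumerPollingTimeout() > 0 branch in the next file). A sketch of the effective behavior, assuming the usual KafkaIO.readSourceDescriptors() entry point:

    // Rejected at construction time with IllegalStateException.
    KafkaIO.<Long, String>read().withConsumerPollingTimeout(-5L);

    // Accepted here, but ReadFromKafkaDoFn substitutes the 2-second default
    // because its constructor only honors timeouts that are > 0.
    KafkaIO.<Long, String>readSourceDescriptors().withConsumerPollingTimeout(-5L);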
@@ -112,7 +112,12 @@ Object getDefaultValue() {
VALUE_DESERIALIZER_PROVIDER,
CHECK_STOP_READING_FN(SDF),
BAD_RECORD_ERROR_HANDLER(SDF),
- CONSUMER_POLLING_TIMEOUT,
+ CONSUMER_POLLING_TIMEOUT(SDF) {
+ @Override
+ Object getDefaultValue() {
+ return Long.valueOf(2);
+ }
+ },
;

@Nonnull private final ImmutableSet<KafkaIOReadImplementation> supportedImplementations;
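The compatibility checker now treats CONSUMER_POLLING_TIMEOUT as an SDF-only property with an explicit default, so a Read still carrying the default 2L does not force a particular read implementation. A stripped-down sketch of the pattern, not the actual Beam enum (the isConfigured helper is hypothetical):

    import java.util.Objects;

    enum ReadProperty {
      // Overrides the default so "unset" and "2" are indistinguishable.
      CONSUMER_POLLING_TIMEOUT {
        @Override
        Object getDefaultValue() {
          return Long.valueOf(2);
        }
      };

      Object getDefaultValue() {
        return null;
      }

      // A property only constrains the implementation choice when its value
      // has been changed away from the default.
      boolean isConfigured(Object value) {
        return value != null && !Objects.equals(value, getDefaultValue());
      }
    }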
@@ -191,11 +191,10 @@ private ReadFromKafkaDoFn(
this.checkStopReadingFn = transform.getCheckStopReadingFn();
this.badRecordRouter = transform.getBadRecordRouter();
this.recordTag = recordTag;
- if (transform.getConsumerPollingTimeout() != null) {
- this.consumerPollingTimeout =
- java.time.Duration.ofMillis(transform.getConsumerPollingTimeout().getMillis());
+ if (transform.getConsumerPollingTimeout() > 0) {
+ this.consumerPollingTimeout = transform.getConsumerPollingTimeout();
} else {
- this.consumerPollingTimeout = KAFKA_POLL_TIMEOUT;
+ this.consumerPollingTimeout = DEFAULT_KAFKA_POLL_TIMEOUT;
}
}

@@ -222,10 +221,8 @@ private ReadFromKafkaDoFn(
private transient @Nullable Map<TopicPartition, KafkaLatestOffsetEstimator> offsetEstimatorCache;

private transient @Nullable LoadingCache<TopicPartition, AverageRecordSize> avgRecordSize;

- private static final java.time.Duration KAFKA_POLL_TIMEOUT = java.time.Duration.ofSeconds(2);
-
- @VisibleForTesting final java.time.Duration consumerPollingTimeout;
+ private static final long DEFAULT_KAFKA_POLL_TIMEOUT = 2L;
+ @VisibleForTesting final long consumerPollingTimeout;
@VisibleForTesting final DeserializerProvider<K> keyDeserializerProvider;
@VisibleForTesting final DeserializerProvider<V> valueDeserializerProvider;
@VisibleForTesting final Map<String, Object> consumerConfig;
@@ -513,9 +510,9 @@ private ConsumerRecords<byte[], byte[]> poll(
final Stopwatch sw = Stopwatch.createStarted();
long previousPosition = -1;
java.time.Duration elapsed = java.time.Duration.ZERO;
+ java.time.Duration timeout = java.time.Duration.ofSeconds(this.consumerPollingTimeout);
while (true) {
- final ConsumerRecords<byte[], byte[]> rawRecords =
- consumer.poll(consumerPollingTimeout.minus(elapsed));
+ final ConsumerRecords<byte[], byte[]> rawRecords = consumer.poll(timeout.minus(elapsed));
if (!rawRecords.isEmpty()) {
// return as we have found some entries
return rawRecords;
@@ -525,11 +522,11 @@
return rawRecords;
}
elapsed = sw.elapsed();
- if (elapsed.toMillis() >= consumerPollingTimeout.toMillis()) {
+ if (elapsed.toMillis() >= timeout.toMillis()) {
// timeout is over
LOG.warn(
"No messages retrieved with polling timeout {} seconds. Consider increasing the consumer polling timeout using withConsumerPollingTimeout method.",
- consumerPollingTimeout.getSeconds());
+ consumerPollingTimeout);
return rawRecords;
}
}
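ReadFromKafkaDoFn now keeps the timeout as a long and converts it to java.time.Duration once per poll() call. A standalone sketch of the same retry-until-deadline shape, assuming a Kafka Consumer<byte[], byte[]> and Guava's Stopwatch (the real method also checks whether the topic partition still exists and logs a warning on timeout):

    import java.time.Duration;
    import com.google.common.base.Stopwatch;
    import org.apache.kafka.clients.consumer.Consumer;
    import org.apache.kafka.clients.consumer.ConsumerRecords;

    final class PollUntilDeadline {
      // Poll repeatedly until records arrive or the seconds budget is spent.
      static ConsumerRecords<byte[], byte[]> poll(
          Consumer<byte[], byte[]> consumer, long pollingTimeoutSeconds) {
        final Stopwatch sw = Stopwatch.createStarted();
        final Duration timeout = Duration.ofSeconds(pollingTimeoutSeconds);
        Duration elapsed = Duration.ZERO;
        while (true) {
          // Each poll only gets the time remaining in the overall budget.
          ConsumerRecords<byte[], byte[]> records = consumer.poll(timeout.minus(elapsed));
          if (!records.isEmpty()) {
            return records;
          }
          elapsed = sw.elapsed();
          if (elapsed.compareTo(timeout) >= 0) {
            // Budget exhausted: hand back the empty batch and let the caller decide.
            return records;
          }
        }
      }
    }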
@@ -2123,14 +2123,14 @@ public void testSinkMetrics() throws Exception {

@Test(expected = IllegalStateException.class)
public void testWithInvalidConsumerPollingTimeout() {
- KafkaIO.<Integer, Long>read().withConsumerPollingTimeout(Duration.standardSeconds(-5));
+ KafkaIO.<Integer, Long>read().withConsumerPollingTimeout(-5L);
}

@Test
public void testWithValidConsumerPollingTimeout() {
KafkaIO.Read<Integer, Long> reader =
- KafkaIO.<Integer, Long>read().withConsumerPollingTimeout(Duration.standardSeconds(15));
- assertEquals(15, reader.getConsumerPollingTimeout().getStandardSeconds());
+ KafkaIO.<Integer, Long>read().withConsumerPollingTimeout(15L);
+ assertEquals(15, reader.getConsumerPollingTimeout());
}

private static void verifyProducerRecords(
@@ -646,13 +646,12 @@ public void testConstructorWithPollTimeout() {
ReadSourceDescriptors<String, String> descriptors = makeReadSourceDescriptor(consumer);
// default poll timeout = 2 seconds
ReadFromKafkaDoFn<String, String> dofnInstance = ReadFromKafkaDoFn.create(descriptors, RECORDS);
- Assert.assertEquals(Duration.ofSeconds(2L), dofnInstance.consumerPollingTimeout);
+ Assert.assertEquals(2L, dofnInstance.consumerPollingTimeout);
// updated timeout = 5 seconds
- descriptors =
- descriptors.withConsumerPollingTimeout(org.joda.time.Duration.standardSeconds(5L));
+ descriptors = descriptors.withConsumerPollingTimeout(5L);
ReadFromKafkaDoFn<String, String> dofnInstanceNew =
ReadFromKafkaDoFn.create(descriptors, RECORDS);
- Assert.assertEquals(Duration.ofSeconds(5L), dofnInstanceNew.consumerPollingTimeout);
+ Assert.assertEquals(5L, dofnInstanceNew.consumerPollingTimeout);
}

private BoundednessVisitor testBoundedness(
@@ -40,6 +40,7 @@
import org.apache.beam.sdk.io.kafka.KafkaIOUtils;
import org.apache.beam.sdk.io.kafka.TimestampPolicyFactory;
import org.apache.beam.sdk.options.PipelineOptions;
+ import org.apache.beam.sdk.options.StreamingOptions;
import org.apache.beam.sdk.runners.AppliedPTransform;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.Schema.FieldType;
@@ -51,6 +52,7 @@
import org.apache.beam.sdk.util.construction.PTransformTranslation.TransformPayloadTranslator;
import org.apache.beam.sdk.util.construction.SdkComponents;
import org.apache.beam.sdk.util.construction.TransformPayloadTranslatorRegistrar;
+ import org.apache.beam.sdk.util.construction.TransformUpgrader;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
@@ -102,6 +104,7 @@ static class KafkaIOReadWithMetadataTranslator implements TransformPayloadTransl
.addNullableByteArrayField("key_deserializer_provider")
.addNullableByteArrayField("value_deserializer_provider")
.addNullableByteArrayField("check_stop_reading_fn")
+ .addNullableInt64Field("consumer_polling_timeout")
.build();

@Override
@@ -173,7 +176,7 @@ public Row toConfigRow(Read<?, ?> transform) {
if (transform.getStopReadTime() != null) {
fieldValues.put("stop_read_time", transform.getStopReadTime());
}

fieldValues.put("consumer_polling_timeout", transform.getConsumerPollingTimeout());
fieldValues.put(
"is_commit_offset_finalize_enabled", transform.isCommitOffsetsInFinalizeEnabled());
fieldValues.put("is_dynamic_read", transform.isDynamicRead());
@@ -217,6 +220,13 @@ public Read<?, ?> fromConfigRow(Row configRow, PipelineOptions options) {

@Override
public Read<?, ?> fromConfigRow(Row configRow, PipelineOptions options) {
+ String updateCompatibilityBeamVersion =
+ options.as(StreamingOptions.class).getUpdateCompatibilityVersion();
+ // We need to set a default 'updateCompatibilityBeamVersion' here since this PipelineOption
+ // is not correctly passed in for pipelines that use Beam 2.55.0.
+ // This is fixed for Beam 2.56.0 and later.
+ updateCompatibilityBeamVersion =
+ (updateCompatibilityBeamVersion != null) ? updateCompatibilityBeamVersion : "2.55.0";
try {
Read<?, ?> transform = KafkaIO.read();

@@ -320,6 +330,15 @@
transform =
transform.withMaxReadTime(org.joda.time.Duration.millis(maxReadTime.toMillis()));
}
+ if (TransformUpgrader.compareVersions(updateCompatibilityBeamVersion, "2.56.0") < 0) {
+ // set to current default
+ transform = transform.withConsumerPollingTimeout(2L);
+ } else {
+ Long consumerPollingTimeout = configRow.getInt64("consumer_polling_timeout");
+ if (consumerPollingTimeout != null) {
+ transform = transform.withConsumerPollingTimeout(consumerPollingTimeout);
+ }
+ }
Instant startReadTime = configRow.getValue("start_read_time");
if (startReadTime != null) {
transform = transform.withStartReadTime(startReadTime);
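The upgrade fix condensed into one place: default the compatibility version for 2.55.0 launches (where the option is not propagated), then gate on it when rebuilding the transform. A sketch restating the logic above as a hypothetical helper, assuming TransformUpgrader.compareVersions returns a negative value when the first version is older:

    import org.apache.beam.sdk.options.PipelineOptions;
    import org.apache.beam.sdk.options.StreamingOptions;
    import org.apache.beam.sdk.util.construction.TransformUpgrader;
    import org.apache.beam.sdk.values.Row;

    final class PollingTimeoutUpgrade {
      // Pick the polling timeout when rebuilding a transform from a config row.
      static long resolvePollingTimeout(Row configRow, PipelineOptions options) {
        String version = options.as(StreamingOptions.class).getUpdateCompatibilityVersion();
        // Beam 2.55.0 does not propagate this option, so treat "unset" as 2.55.0.
        if (version == null) {
          version = "2.55.0";
        }
        if (TransformUpgrader.compareVersions(version, "2.56.0") < 0) {
          // Pre-2.56.0 rows never carried consumer_polling_timeout; use the default.
          return 2L;
        }
        // The field is nullable, so a row without it still decodes; fall back to 2s.
        Long timeout = configRow.getInt64("consumer_polling_timeout");
        return timeout != null ? timeout : 2L;
      }
    }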
2 changes: 1 addition & 1 deletion sdks/python/apache_beam/io/kafka.py
@@ -135,7 +135,7 @@ def __init__(
max_read_time=None,
commit_offset_in_finalize=False,
timestamp_policy=processing_time_policy,
- consumer_polling_timeout=None,
+ consumer_polling_timeout=2,
with_metadata=False,
expansion_service=None,
):