Skip to content

Commit

Permalink
Spark Runner: Change to use partitioner in GroupNonMergingWindowsFunc…
Browse files Browse the repository at this point in the history
…tions#groupByKeyInGlobalWindow (#32610)

* change GroupNonMergingWindowsFunctions#groupByKeyInGlobalWindow to use partitioner
  • Loading branch information
twosom authored Oct 4, 2024
1 parent d84cfff commit a2710ed
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -275,17 +275,18 @@ JavaRDD<WindowedValue<KV<K, Iterable<V>>>> groupByKeyInGlobalWindow(
new Tuple2<>(
new ByteArray(CoderHelpers.toByteArray(wv.getValue().getKey(), keyCoder)),
CoderHelpers.toByteArray(wv.getValue().getValue(), valueCoder)));
return rawKeyValues
.groupByKey()
.map(
kvs ->
WindowedValue.timestampedValueInGlobalWindow(
KV.of(
CoderHelpers.fromByteArray(kvs._1.getValue(), keyCoder),
Iterables.transform(
kvs._2,
encodedValue -> CoderHelpers.fromByteArray(encodedValue, valueCoder))),
GlobalWindow.INSTANCE.maxTimestamp(),
PaneInfo.ON_TIME_AND_ONLY_FIRING));

JavaPairRDD<ByteArray, Iterable<byte[]>> grouped =
(partitioner == null) ? rawKeyValues.groupByKey() : rawKeyValues.groupByKey(partitioner);
return grouped.map(
kvs ->
WindowedValue.timestampedValueInGlobalWindow(
KV.of(
CoderHelpers.fromByteArray(kvs._1.getValue(), keyCoder),
Iterables.transform(
kvs._2,
encodedValue -> CoderHelpers.fromByteArray(encodedValue, valueCoder))),
GlobalWindow.INSTANCE.maxTimestamp(),
PaneInfo.ON_TIME_AND_ONLY_FIRING));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@
package org.apache.beam.runners.spark.translation;

import static org.junit.Assert.assertEquals;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import java.util.Arrays;
import java.util.Iterator;
Expand All @@ -39,6 +45,9 @@
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.WindowingStrategy;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.primitives.Bytes;
import org.apache.spark.Partitioner;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.joda.time.Duration;
import org.joda.time.Instant;
import org.junit.Assert;
Expand Down Expand Up @@ -112,6 +121,54 @@ public void testGbkIteratorValuesCannotBeReiterated() throws Coder.NonDeterminis
}
}

@Test
@SuppressWarnings({"rawtypes", "unchecked"})
public void testGroupByKeyInGlobalWindowWithPartitioner() {
// mocking
Partitioner mockPartitioner = mock(Partitioner.class);
JavaRDD mockRdd = mock(JavaRDD.class);
Coder mockKeyCoder = mock(Coder.class);
Coder mockValueCoder = mock(Coder.class);
JavaPairRDD mockRawKeyValues = mock(JavaPairRDD.class);
JavaPairRDD mockGrouped = mock(JavaPairRDD.class);

when(mockRdd.mapToPair(any())).thenReturn(mockRawKeyValues);
when(mockRawKeyValues.groupByKey(any(Partitioner.class)))
.thenAnswer(
invocation -> {
Partitioner partitioner = invocation.getArgument(0);
assertEquals(partitioner, mockPartitioner);
return mockGrouped;
});
when(mockGrouped.map(any())).thenReturn(mock(JavaRDD.class));

GroupNonMergingWindowsFunctions.groupByKeyInGlobalWindow(
mockRdd, mockKeyCoder, mockValueCoder, mockPartitioner);

verify(mockRawKeyValues, never()).groupByKey();
verify(mockRawKeyValues, times(1)).groupByKey(any(Partitioner.class));
}

@Test
@SuppressWarnings({"rawtypes", "unchecked"})
public void testGroupByKeyInGlobalWindowWithoutPartitioner() {
// mocking
JavaRDD mockRdd = mock(JavaRDD.class);
Coder mockKeyCoder = mock(Coder.class);
Coder mockValueCoder = mock(Coder.class);
JavaPairRDD mockRawKeyValues = mock(JavaPairRDD.class);
JavaPairRDD mockGrouped = mock(JavaPairRDD.class);

when(mockRdd.mapToPair(any())).thenReturn(mockRawKeyValues);
when(mockRawKeyValues.groupByKey()).thenReturn(mockGrouped);

GroupNonMergingWindowsFunctions.groupByKeyInGlobalWindow(
mockRdd, mockKeyCoder, mockValueCoder, null);

verify(mockRawKeyValues, times(1)).groupByKey();
verify(mockRawKeyValues, never()).groupByKey(any(Partitioner.class));
}

private GroupByKeyIterator<String, Integer, GlobalWindow> createGbkIterator()
throws Coder.NonDeterministicException {
return createGbkIterator(
Expand Down

0 comments on commit a2710ed

Please sign in to comment.