diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/GroupNonMergingWindowsFunctions.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/GroupNonMergingWindowsFunctions.java index 0a8e7d8a159b..2461d5cc8d66 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/GroupNonMergingWindowsFunctions.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/GroupNonMergingWindowsFunctions.java @@ -275,17 +275,18 @@ JavaRDD>>> groupByKeyInGlobalWindow( new Tuple2<>( new ByteArray(CoderHelpers.toByteArray(wv.getValue().getKey(), keyCoder)), CoderHelpers.toByteArray(wv.getValue().getValue(), valueCoder))); - return rawKeyValues - .groupByKey() - .map( - kvs -> - WindowedValue.timestampedValueInGlobalWindow( - KV.of( - CoderHelpers.fromByteArray(kvs._1.getValue(), keyCoder), - Iterables.transform( - kvs._2, - encodedValue -> CoderHelpers.fromByteArray(encodedValue, valueCoder))), - GlobalWindow.INSTANCE.maxTimestamp(), - PaneInfo.ON_TIME_AND_ONLY_FIRING)); + + JavaPairRDD> grouped = + (partitioner == null) ? rawKeyValues.groupByKey() : rawKeyValues.groupByKey(partitioner); + return grouped.map( + kvs -> + WindowedValue.timestampedValueInGlobalWindow( + KV.of( + CoderHelpers.fromByteArray(kvs._1.getValue(), keyCoder), + Iterables.transform( + kvs._2, + encodedValue -> CoderHelpers.fromByteArray(encodedValue, valueCoder))), + GlobalWindow.INSTANCE.maxTimestamp(), + PaneInfo.ON_TIME_AND_ONLY_FIRING)); } } diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/GroupNonMergingWindowsFunctionsTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/GroupNonMergingWindowsFunctionsTest.java index fd299924af91..ed7bc078564e 100644 --- a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/GroupNonMergingWindowsFunctionsTest.java +++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/GroupNonMergingWindowsFunctionsTest.java @@ -18,6 +18,12 @@ package org.apache.beam.runners.spark.translation; import static org.junit.Assert.assertEquals; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; import java.util.Arrays; import java.util.Iterator; @@ -39,6 +45,9 @@ import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.WindowingStrategy; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.primitives.Bytes; +import org.apache.spark.Partitioner; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaRDD; import org.joda.time.Duration; import org.joda.time.Instant; import org.junit.Assert; @@ -112,6 +121,54 @@ public void testGbkIteratorValuesCannotBeReiterated() throws Coder.NonDeterminis } } + @Test + @SuppressWarnings({"rawtypes", "unchecked"}) + public void testGroupByKeyInGlobalWindowWithPartitioner() { + // mocking + Partitioner mockPartitioner = mock(Partitioner.class); + JavaRDD mockRdd = mock(JavaRDD.class); + Coder mockKeyCoder = mock(Coder.class); + Coder mockValueCoder = mock(Coder.class); + JavaPairRDD mockRawKeyValues = mock(JavaPairRDD.class); + JavaPairRDD mockGrouped = mock(JavaPairRDD.class); + + when(mockRdd.mapToPair(any())).thenReturn(mockRawKeyValues); + when(mockRawKeyValues.groupByKey(any(Partitioner.class))) + .thenAnswer( + invocation -> { + Partitioner partitioner = invocation.getArgument(0); + assertEquals(partitioner, mockPartitioner); + return mockGrouped; + }); + when(mockGrouped.map(any())).thenReturn(mock(JavaRDD.class)); + + GroupNonMergingWindowsFunctions.groupByKeyInGlobalWindow( + mockRdd, mockKeyCoder, mockValueCoder, mockPartitioner); + + verify(mockRawKeyValues, never()).groupByKey(); + verify(mockRawKeyValues, times(1)).groupByKey(any(Partitioner.class)); + } + + @Test + @SuppressWarnings({"rawtypes", "unchecked"}) + public void testGroupByKeyInGlobalWindowWithoutPartitioner() { + // mocking + JavaRDD mockRdd = mock(JavaRDD.class); + Coder mockKeyCoder = mock(Coder.class); + Coder mockValueCoder = mock(Coder.class); + JavaPairRDD mockRawKeyValues = mock(JavaPairRDD.class); + JavaPairRDD mockGrouped = mock(JavaPairRDD.class); + + when(mockRdd.mapToPair(any())).thenReturn(mockRawKeyValues); + when(mockRawKeyValues.groupByKey()).thenReturn(mockGrouped); + + GroupNonMergingWindowsFunctions.groupByKeyInGlobalWindow( + mockRdd, mockKeyCoder, mockValueCoder, null); + + verify(mockRawKeyValues, times(1)).groupByKey(); + verify(mockRawKeyValues, never()).groupByKey(any(Partitioner.class)); + } + private GroupByKeyIterator createGbkIterator() throws Coder.NonDeterministicException { return createGbkIterator(