diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkCombineFn.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkCombineFn.java index ddf4b12bae13..1075ae0d2a7d 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkCombineFn.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkCombineFn.java @@ -41,7 +41,6 @@ import org.apache.beam.runners.spark.util.SideInputBroadcast; import org.apache.beam.runners.spark.util.SparkSideInputReader; import org.apache.beam.sdk.coders.Coder; -import org.apache.beam.sdk.coders.CoderException; import org.apache.beam.sdk.coders.IterableCoder; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.transforms.CombineWithContext; @@ -101,7 +100,7 @@ void add(WindowedValue value, SparkCombineFn throws Exception; /** - * Merge other acccumulator into this one. + * Merge other accumulator into this one. * * @param other the other accumulator to merge */ @@ -173,7 +172,7 @@ static SingleWindowWindowedAccumulator(toValue); } - static WindowedAccumulator create( + static SingleWindowWindowedAccumulator create( Function toValue, WindowedValue accumulator) { return new SingleWindowWindowedAccumulator<>(toValue, accumulator); } @@ -191,10 +190,7 @@ static SingleWindowWindowedAccumulator toValue, WindowedValue accumulator) { this.toValue = toValue; this.windowAccumulator = accumulator.getValue(); - this.accTimestamp = - accumulator.getTimestamp().equals(BoundedWindow.TIMESTAMP_MIN_VALUE) - ? null - : accumulator.getTimestamp(); + this.accTimestamp = accumulator.getTimestamp(); this.accWindow = getWindow(accumulator); } @@ -247,7 +243,7 @@ public void merge( @Override public Collection> extractOutput() { if (windowAccumulator != null) { - return Arrays.asList( + return Collections.singletonList( WindowedValue.of( windowAccumulator, accTimestamp, accWindow, PaneInfo.ON_TIME_AND_ONLY_FIRING)); } @@ -516,7 +512,8 @@ static class WindowedAccumulatorCoder @Override public void encode(WindowedAccumulator value, OutputStream outStream) - throws CoderException, IOException { + throws IOException { + if (type.isMapBased()) { wrap.encode(((MapBasedWindowedAccumulator) value).map.values(), outStream); } else { @@ -536,7 +533,8 @@ public void encode(WindowedAccumulator value, OutputS @Override public WindowedAccumulator decode(InputStream inStream) - throws CoderException, IOException { + throws IOException { + if (type.isMapBased()) { return WindowedAccumulator.create(toValue, type, wrap.decode(inStream), windowComparator); } diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/SparkCombineFnTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/SparkCombineFnTest.java index 295b7ef2b948..9cb4b44c897c 100644 --- a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/SparkCombineFnTest.java +++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/SparkCombineFnTest.java @@ -36,6 +36,7 @@ import org.apache.beam.sdk.transforms.windowing.PaneInfo; import org.apache.beam.sdk.transforms.windowing.Sessions; import org.apache.beam.sdk.transforms.windowing.SlidingWindows; +import org.apache.beam.sdk.transforms.windowing.TimestampCombiner; import org.apache.beam.sdk.transforms.windowing.WindowFn; import org.apache.beam.sdk.util.CombineFnUtil; import org.apache.beam.sdk.util.WindowedValue; @@ -219,6 +220,34 @@ public void testSlidingCombineFnExplode() throws Exception { result); } + @Test + public void testGlobalWindowMergeAccumulatorsWithEarliestCombiner() throws Exception { + SparkCombineFn, Integer, Long, Long> sparkCombineFn = + SparkCombineFn.keyed( + combineFn, + opts, + Collections.emptyMap(), + WindowingStrategy.globalDefault().withTimestampCombiner(TimestampCombiner.EARLIEST)); + + Instant ts = BoundedWindow.TIMESTAMP_MIN_VALUE; + WindowedValue> first = input("key", 1, ts); + WindowedValue> second = input("key", 2, ts); + WindowedValue> third = input("key", 3, ts); + WindowedValue accumulator = WindowedValue.valueInGlobalWindow(0L); + SparkCombineFn.SingleWindowWindowedAccumulator, Integer, Long> acc1 = + SparkCombineFn.SingleWindowWindowedAccumulator.create(KV::getValue, accumulator); + SparkCombineFn.SingleWindowWindowedAccumulator, Integer, Long> acc2 = + SparkCombineFn.SingleWindowWindowedAccumulator.create(KV::getValue, accumulator); + SparkCombineFn.SingleWindowWindowedAccumulator, Integer, Long> acc3 = + SparkCombineFn.SingleWindowWindowedAccumulator.create(KV::getValue, accumulator); + acc1.add(first, sparkCombineFn); + acc2.add(second, sparkCombineFn); + acc3.merge(acc1, sparkCombineFn); + acc3.merge(acc2, sparkCombineFn); + acc3.add(third, sparkCombineFn); + assertEquals(6, (long) Iterables.getOnlyElement(sparkCombineFn.extractOutput(acc3)).getValue()); + } + private static Combine.CombineFn getSumFn() { return new Combine.CombineFn() {