From c41230067013016920cb2212559de0a4b26c9f48 Mon Sep 17 00:00:00 2001 From: PHILO-HE Date: Wed, 13 Mar 2024 14:31:17 +0800 Subject: [PATCH] Refine the code and add tests --- velox/exec/Window.cpp | 34 ++++++------ .../window/tests/AggregateWindowTest.cpp | 53 +++++++++++++++++-- 2 files changed, 67 insertions(+), 20 deletions(-) diff --git a/velox/exec/Window.cpp b/velox/exec/Window.cpp index 4dbd531796e9..3ef2181c99d3 100644 --- a/velox/exec/Window.cpp +++ b/velox/exec/Window.cpp @@ -320,24 +320,26 @@ void Window::updateKRowsFrameBounds( // Considers a very large int64 constantOffset is used. if (startValue < std::numeric_limits::min()) { std::fill_n(rawFrameBounds, numRows, startRow); - } else { - // Integer overflow cannot happen. - std::iota(rawFrameBounds, rawFrameBounds + numRows, startValue); - } - } else { - auto overflowStart = getOverflowStart(constantOffset); - if (overflowStart >= 0 && overflowStart < numRows) { - std::iota(rawFrameBounds, rawFrameBounds + overflowStart, startValue); - // For remaining, set with the largest index for this partition. - std::fill_n( - rawFrameBounds + overflowStart, - numRows - overflowStart, - startRow + numRows - 1); - } else { - // Integer overflow cannot happen. - std::iota(rawFrameBounds, rawFrameBounds + numRows, startValue); + return; } + // Integer overflow cannot happen. + std::iota(rawFrameBounds, rawFrameBounds + numRows, startValue); + return; + } + // KFollowing. + auto overflowStart = getOverflowStart(constantOffset); + if (overflowStart >= 0 && overflowStart < numRows) { + std::iota(rawFrameBounds, rawFrameBounds + overflowStart, startValue); + // For remaining, set with the largest index for this partition. + std::fill_n( + rawFrameBounds + overflowStart, + numRows - overflowStart, + startRow + numRows - 1); + return; } + // Integer overflow cannot happen. + std::iota(rawFrameBounds, rawFrameBounds + numRows, startValue); + return; } else { currentPartition_->extractColumn( frameArg.index, partitionOffset_, numRows, 0, frameArg.value); diff --git a/velox/functions/prestosql/window/tests/AggregateWindowTest.cpp b/velox/functions/prestosql/window/tests/AggregateWindowTest.cpp index 65d5cc0bf6e5..cc8999b1caa3 100644 --- a/velox/functions/prestosql/window/tests/AggregateWindowTest.cpp +++ b/velox/functions/prestosql/window/tests/AggregateWindowTest.cpp @@ -228,7 +228,7 @@ TEST_F(AggregateWindowTest, integerOverflowRowsFrame) { auto input = makeRowVector({c0, c1, c2, c3}); std::string overClause = "partition by c0 order by c1 desc"; - // Test with literal larger than INT32_MAX (2147483647). + // Test constant following larger than INT32_MAX (2147483647). std::string frameClause = "rows between 0 preceding and 2147483650 following"; auto expected = makeRowVector( {c0, @@ -239,7 +239,19 @@ TEST_F(AggregateWindowTest, integerOverflowRowsFrame) { WindowTestBase::testWindowFunction( {input}, "count(c1)", overClause, frameClause, expected); - // Test integer overflow with column-specified following (int32). + // Test overflow case that happens during the calculation for the middle + // partition. + frameClause = "rows between 0 preceding and 2147483645 following"; + expected = makeRowVector( + {c0, + c1, + c2, + c3, + makeFlatVector({6, 5, 4, 3, 2, 1, 4, 3, 2, 1})}); + WindowTestBase::testWindowFunction( + {input}, "count(c1)", overClause, frameClause, expected); + + // Test with column-specified following (int32). frameClause = "rows between 0 preceding and c2 following"; expected = makeRowVector( {c0, @@ -250,7 +262,7 @@ TEST_F(AggregateWindowTest, integerOverflowRowsFrame) { WindowTestBase::testWindowFunction( {input}, "count(c1)", overClause, frameClause, expected); - // Test integer overflow with column-specified following (int64). + // Test with column-specified following (int64). frameClause = "rows between 0 preceding and c3 following"; expected = makeRowVector( {c0, @@ -261,7 +273,29 @@ TEST_F(AggregateWindowTest, integerOverflowRowsFrame) { WindowTestBase::testWindowFunction( {input}, "count(c1)", overClause, frameClause, expected); - // Test integer overflow with column-specified preceding (int64). + // Test constant preceding larger than INT32_MAX. + frameClause = "rows between 2147483650 preceding and 0 following"; + expected = makeRowVector( + {c0, + c1, + c2, + c3, + makeFlatVector({1, 2, 3, 4, 5, 6, 1, 2, 3, 4})}); + WindowTestBase::testWindowFunction( + {input}, "count(c1)", overClause, frameClause, expected); + + // Test with column-specified preceding (int32). + frameClause = "rows between c2 preceding and 0 following"; + expected = makeRowVector( + {c0, + c1, + c2, + c3, + makeFlatVector({1, 2, 3, 4, 2, 6, 1, 2, 3, 4})}); + WindowTestBase::testWindowFunction( + {input}, "count(c1)", overClause, frameClause, expected); + + // Test with column-specified preceding (int64). frameClause = "rows between c3 preceding and 0 following"; expected = makeRowVector( {c0, @@ -271,6 +305,17 @@ TEST_F(AggregateWindowTest, integerOverflowRowsFrame) { makeFlatVector({1, 2, 3, 4, 5, 6, 1, 2, 3, 4})}); WindowTestBase::testWindowFunction( {input}, "count(c1)", overClause, frameClause, expected); + + // Test constant preceding & following both larger than INT32_MAX. + frameClause = "rows between 2147483650 preceding and 2147483651 following"; + expected = makeRowVector( + {c0, + c1, + c2, + c3, + makeFlatVector({6, 6, 6, 6, 6, 6, 4, 4, 4, 4})}); + WindowTestBase::testWindowFunction( + {input}, "count(c1)", overClause, frameClause, expected); } }; // namespace