From 283c1c111d2d2682c728f681bebc5307d40d70a5 Mon Sep 17 00:00:00 2001 From: abhishekagarwal87 <1477457+abhishekagarwal87@users.noreply.github.com> Date: Fri, 15 Mar 2024 15:49:34 +0530 Subject: [PATCH] [CALCITE-6333] NullPointerException in AggregateExpandDistinctAggregatesRule.doRewrite when rewriting filtered distinct aggregation Fix test order Fix one more --- ...AggregateExpandDistinctAggregatesRule.java | 9 ++-- .../apache/calcite/test/RelOptRulesTest.java | 21 ++++++++ .../apache/calcite/test/RelOptRulesTest.xml | 54 +++++++++++++++++++ 3 files changed, 79 insertions(+), 5 deletions(-) diff --git a/core/src/main/java/org/apache/calcite/rel/rules/AggregateExpandDistinctAggregatesRule.java b/core/src/main/java/org/apache/calcite/rel/rules/AggregateExpandDistinctAggregatesRule.java index ccf32284e74..56e224572c1 100644 --- a/core/src/main/java/org/apache/calcite/rel/rules/AggregateExpandDistinctAggregatesRule.java +++ b/core/src/main/java/org/apache/calcite/rel/rules/AggregateExpandDistinctAggregatesRule.java @@ -741,6 +741,9 @@ private static void doRewrite(RelBuilder relBuilder, Aggregate aggregate, int n, if (!aggCall.getArgList().equals(argList)) { continue; } + if (aggCall.filterArg != filterArg) { + continue; + } // Re-map arguments. final int argCount = aggCall.getArgList().size(); @@ -748,14 +751,10 @@ private static void doRewrite(RelBuilder relBuilder, Aggregate aggregate, int n, for (Integer arg : aggCall.getArgList()) { newArgs.add(requireNonNull(sourceOf.get(arg), () -> "sourceOf.get(" + arg + ")")); } - final int newFilterArg = - aggCall.filterArg < 0 ? -1 - : requireNonNull(sourceOf.get(aggCall.filterArg), - () -> "sourceOf.get(" + aggCall.filterArg + ")"); final AggregateCall newAggCall = AggregateCall.create(aggCall.getAggregation(), false, aggCall.isApproximate(), aggCall.ignoreNulls(), aggCall.rexList, - newArgs, newFilterArg, aggCall.distinctKeys, aggCall.collation, + newArgs, -1, aggCall.distinctKeys, aggCall.collation, aggCall.getType(), aggCall.getName()); assert refs.get(i) == null; if (leftFields == null) { diff --git a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java index 0681d4b7fc6..d02cf05c1d4 100644 --- a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java +++ b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java @@ -2175,6 +2175,27 @@ private void checkSemiOrAntiJoinProjectTranspose(JoinRelType type) { .check(); } + @Test void testDistinctWithFilterWithoutGroupByUsingJoin() { + final String sql = "SELECT SUM(comm), COUNT(DISTINCT sal) FILTER (WHERE sal > 1000)\n" + + "FROM emp"; + sql(sql) + .withRule(CoreRules.AGGREGATE_EXPAND_DISTINCT_AGGREGATES_TO_JOIN) + .check(); + } + + @Test void testMultipleDistinctWithSameArgsDifferentFilterUsingJoin() { + final String sql = "select deptno, " + + "count(distinct sal) FILTER (WHERE sal > 1000), " + + "count(distinct sal) FILTER (WHERE sal > 500) " + + "from sales.emp group by deptno"; + sql(sql) + .withRule( + CoreRules.AGGREGATE_EXPAND_DISTINCT_AGGREGATES_TO_JOIN, + CoreRules.AGGREGATE_PROJECT_MERGE + ) + .check(); + } + @Test void testDistinctWithFilterWithoutGroupBy() { final String sql = "SELECT SUM(comm), COUNT(DISTINCT sal) FILTER (WHERE sal > 1000)\n" + "FROM emp"; diff --git a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml index 177abc997b6..bb0c2f9ca81 100644 --- a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml +++ b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml @@ -2873,6 +2873,32 @@ LogicalAggregate(group=[{}], EXPR$0=[MIN($1) FILTER $3], EXPR$1=[COUNT($0) FILTE LogicalAggregate(group=[{1, 2}], groups=[[{1, 2}, {}]], EXPR$0=[SUM($0)], $g=[GROUPING($1, $2)]) LogicalProject(COMM=[$6], SAL=[$5], $f2=[>($5, 1000)]) LogicalTableScan(table=[[CATALOG, SALES, EMP]]) +]]> + + + + + 1000) +FROM emp]]> + + + ($5, 1000)]) + LogicalTableScan(table=[[CATALOG, SALES, EMP]]) +]]> + + + ($5, 1000)]) + LogicalTableScan(table=[[CATALOG, SALES, EMP]]) + LogicalAggregate(group=[{}], EXPR$1=[COUNT($0)]) + LogicalAggregate(group=[{0}]) + LogicalProject(i$SAL=[CASE($2, $1, null:INTEGER)]) + LogicalProject(COMM=[$6], SAL=[$5], $f2=[>($5, 1000)]) + LogicalTableScan(table=[[CATALOG, SALES, EMP]]) ]]> @@ -7279,6 +7305,34 @@ LogicalProject(SAL=[$0], EXPR$1=[$1], EXPR$2=[$3], EXPR$3=[$5]) LogicalProject(SAL=[$0]) LogicalProject(SAL=[$5], COMM=[$6]) LogicalTableScan(table=[[CATALOG, SALES, EMP]]) +]]> + + + + + 1000), count(distinct sal) FILTER (WHERE sal > 500) from sales.emp group by deptno]]> + + + ($5, 1000)], $f3=[>($5, 500)]) + LogicalTableScan(table=[[CATALOG, SALES, EMP]]) +]]> + + + ($5, 1000)], $f3=[>($5, 500)]) + LogicalTableScan(table=[[CATALOG, SALES, EMP]]) + LogicalAggregate(group=[{0}], EXPR$2=[COUNT($1)]) + LogicalAggregate(group=[{0, 1}]) + LogicalProject(DEPTNO=[$0], i$SAL=[CASE($3, $1, null:INTEGER)]) + LogicalProject(DEPTNO=[$7], SAL=[$5], $f2=[>($5, 1000)], $f3=[>($5, 500)]) + LogicalTableScan(table=[[CATALOG, SALES, EMP]]) ]]>