From 283c1c111d2d2682c728f681bebc5307d40d70a5 Mon Sep 17 00:00:00 2001
From: abhishekagarwal87 <1477457+abhishekagarwal87@users.noreply.github.com>
Date: Fri, 15 Mar 2024 15:49:34 +0530
Subject: [PATCH] [CALCITE-6333] NullPointerException in
AggregateExpandDistinctAggregatesRule.doRewrite when rewriting filtered
distinct aggregation
Fix test order
Fix one more
---
...AggregateExpandDistinctAggregatesRule.java | 9 ++--
.../apache/calcite/test/RelOptRulesTest.java | 21 ++++++++
.../apache/calcite/test/RelOptRulesTest.xml | 54 +++++++++++++++++++
3 files changed, 79 insertions(+), 5 deletions(-)
diff --git a/core/src/main/java/org/apache/calcite/rel/rules/AggregateExpandDistinctAggregatesRule.java b/core/src/main/java/org/apache/calcite/rel/rules/AggregateExpandDistinctAggregatesRule.java
index ccf32284e74..56e224572c1 100644
--- a/core/src/main/java/org/apache/calcite/rel/rules/AggregateExpandDistinctAggregatesRule.java
+++ b/core/src/main/java/org/apache/calcite/rel/rules/AggregateExpandDistinctAggregatesRule.java
@@ -741,6 +741,9 @@ private static void doRewrite(RelBuilder relBuilder, Aggregate aggregate, int n,
if (!aggCall.getArgList().equals(argList)) {
continue;
}
+ if (aggCall.filterArg != filterArg) {
+ continue;
+ }
// Re-map arguments.
final int argCount = aggCall.getArgList().size();
@@ -748,14 +751,10 @@ private static void doRewrite(RelBuilder relBuilder, Aggregate aggregate, int n,
for (Integer arg : aggCall.getArgList()) {
newArgs.add(requireNonNull(sourceOf.get(arg), () -> "sourceOf.get(" + arg + ")"));
}
- final int newFilterArg =
- aggCall.filterArg < 0 ? -1
- : requireNonNull(sourceOf.get(aggCall.filterArg),
- () -> "sourceOf.get(" + aggCall.filterArg + ")");
final AggregateCall newAggCall =
AggregateCall.create(aggCall.getAggregation(), false,
aggCall.isApproximate(), aggCall.ignoreNulls(), aggCall.rexList,
- newArgs, newFilterArg, aggCall.distinctKeys, aggCall.collation,
+ newArgs, -1, aggCall.distinctKeys, aggCall.collation,
aggCall.getType(), aggCall.getName());
assert refs.get(i) == null;
if (leftFields == null) {
diff --git a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
index 0681d4b7fc6..d02cf05c1d4 100644
--- a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
+++ b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
@@ -2175,6 +2175,27 @@ private void checkSemiOrAntiJoinProjectTranspose(JoinRelType type) {
.check();
}
+ @Test void testDistinctWithFilterWithoutGroupByUsingJoin() {
+ final String sql = "SELECT SUM(comm), COUNT(DISTINCT sal) FILTER (WHERE sal > 1000)\n"
+ + "FROM emp";
+ sql(sql)
+ .withRule(CoreRules.AGGREGATE_EXPAND_DISTINCT_AGGREGATES_TO_JOIN)
+ .check();
+ }
+
+ @Test void testMultipleDistinctWithSameArgsDifferentFilterUsingJoin() {
+ final String sql = "select deptno, "
+ + "count(distinct sal) FILTER (WHERE sal > 1000), "
+ + "count(distinct sal) FILTER (WHERE sal > 500) "
+ + "from sales.emp group by deptno";
+ sql(sql)
+ .withRule(
+ CoreRules.AGGREGATE_EXPAND_DISTINCT_AGGREGATES_TO_JOIN,
+ CoreRules.AGGREGATE_PROJECT_MERGE
+ )
+ .check();
+ }
+
@Test void testDistinctWithFilterWithoutGroupBy() {
final String sql = "SELECT SUM(comm), COUNT(DISTINCT sal) FILTER (WHERE sal > 1000)\n"
+ "FROM emp";
diff --git a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
index 177abc997b6..bb0c2f9ca81 100644
--- a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
+++ b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
@@ -2873,6 +2873,32 @@ LogicalAggregate(group=[{}], EXPR$0=[MIN($1) FILTER $3], EXPR$1=[COUNT($0) FILTE
LogicalAggregate(group=[{1, 2}], groups=[[{1, 2}, {}]], EXPR$0=[SUM($0)], $g=[GROUPING($1, $2)])
LogicalProject(COMM=[$6], SAL=[$5], $f2=[>($5, 1000)])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+
+
+
+
+ 1000)
+FROM emp]]>
+
+
+ ($5, 1000)])
+ LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+
+
+ ($5, 1000)])
+ LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+ LogicalAggregate(group=[{}], EXPR$1=[COUNT($0)])
+ LogicalAggregate(group=[{0}])
+ LogicalProject(i$SAL=[CASE($2, $1, null:INTEGER)])
+ LogicalProject(COMM=[$6], SAL=[$5], $f2=[>($5, 1000)])
+ LogicalTableScan(table=[[CATALOG, SALES, EMP]])
]]>
@@ -7279,6 +7305,34 @@ LogicalProject(SAL=[$0], EXPR$1=[$1], EXPR$2=[$3], EXPR$3=[$5])
LogicalProject(SAL=[$0])
LogicalProject(SAL=[$5], COMM=[$6])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+
+
+
+
+ 1000), count(distinct sal) FILTER (WHERE sal > 500) from sales.emp group by deptno]]>
+
+
+ ($5, 1000)], $f3=[>($5, 500)])
+ LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+
+
+ ($5, 1000)], $f3=[>($5, 500)])
+ LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+ LogicalAggregate(group=[{0}], EXPR$2=[COUNT($1)])
+ LogicalAggregate(group=[{0, 1}])
+ LogicalProject(DEPTNO=[$0], i$SAL=[CASE($3, $1, null:INTEGER)])
+ LogicalProject(DEPTNO=[$7], SAL=[$5], $f2=[>($5, 1000)], $f3=[>($5, 500)])
+ LogicalTableScan(table=[[CATALOG, SALES, EMP]])
]]>