diff --git a/core/src/main/java/org/apache/calcite/rel/rules/AggregateExpandDistinctAggregatesRule.java b/core/src/main/java/org/apache/calcite/rel/rules/AggregateExpandDistinctAggregatesRule.java index 7a00db8d72cf..1e3ec531dd5a 100644 --- a/core/src/main/java/org/apache/calcite/rel/rules/AggregateExpandDistinctAggregatesRule.java +++ b/core/src/main/java/org/apache/calcite/rel/rules/AggregateExpandDistinctAggregatesRule.java @@ -734,6 +734,9 @@ private static void doRewrite(RelBuilder relBuilder, Aggregate aggregate, int n, if (!aggCall.getArgList().equals(argList)) { continue; } + if (aggCall.filterArg != filterArg) { + continue; + } // Re-map arguments. final int argCount = aggCall.getArgList().size(); @@ -741,14 +744,10 @@ private static void doRewrite(RelBuilder relBuilder, Aggregate aggregate, int n, for (Integer arg : aggCall.getArgList()) { newArgs.add(requireNonNull(sourceOf.get(arg), () -> "sourceOf.get(" + arg + ")")); } - final int newFilterArg = - aggCall.filterArg < 0 ? -1 - : requireNonNull(sourceOf.get(aggCall.filterArg), - () -> "sourceOf.get(" + aggCall.filterArg + ")"); final AggregateCall newAggCall = AggregateCall.create(aggCall.getAggregation(), false, aggCall.isApproximate(), aggCall.ignoreNulls(), aggCall.rexList, - newArgs, newFilterArg, aggCall.distinctKeys, aggCall.collation, + newArgs, -1, aggCall.distinctKeys, aggCall.collation, aggCall.getType(), aggCall.getName()); assert refs.get(i) == null; if (leftFields == null) { diff --git a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java index 51a13e0bc50a..86811ab186b2 100644 --- a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java +++ b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java @@ -2149,6 +2149,30 @@ private void checkSemiOrAntiJoinProjectTranspose(JoinRelType type) { .check(); } + @Test void testDistinctWithFilterWithoutGroupByUsingJoin() { + final String sql = "SELECT SUM(comm), COUNT(DISTINCT sal) FILTER (WHERE sal > 1000)\n" + + "FROM emp"; + sql(sql) + .withRule(CoreRules.AGGREGATE_EXPAND_DISTINCT_AGGREGATES_TO_JOIN) + .check(); + } + + @Test + void testMultipleDistinctWithSameArgsDifferentFilterUsingJoin() { + final String sql = "select deptno, " + + "count(distinct sal) FILTER (WHERE sal > 1000), " + + "count(distinct sal) FILTER (WHERE sal > 500) " + + "from sales.emp group by deptno"; + sql(sql) + .withRule( + CoreRules.AGGREGATE_EXPAND_DISTINCT_AGGREGATES_TO_JOIN, + CoreRules.AGGREGATE_PROJECT_MERGE + ) + .check(); + + testSortUnionTranspose(); + } + @Test void testDistinctWithFilterWithoutGroupBy() { final String sql = "SELECT SUM(comm), COUNT(DISTINCT sal) FILTER (WHERE sal > 1000)\n" + "FROM emp"; diff --git a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml index bc940b804c8f..8bea453b70eb 100644 --- a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml +++ b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml @@ -2856,6 +2856,32 @@ LogicalAggregate(group=[{}], EXPR$0=[MIN($1) FILTER $3], EXPR$1=[COUNT($0) FILTE LogicalAggregate(group=[{1, 2}], groups=[[{1, 2}, {}]], EXPR$0=[SUM($0)], $g=[GROUPING($1, $2)]) LogicalProject(COMM=[$6], SAL=[$5], $f2=[>($5, 1000)]) LogicalTableScan(table=[[CATALOG, SALES, EMP]]) +]]> + + + + + 1000) +FROM emp]]> + + + ($5, 1000)]) + LogicalTableScan(table=[[CATALOG, SALES, EMP]]) +]]> + + + ($5, 1000)]) + LogicalTableScan(table=[[CATALOG, SALES, EMP]]) + LogicalAggregate(group=[{}], EXPR$1=[COUNT($0)]) + LogicalAggregate(group=[{0}]) + LogicalProject(i$SAL=[CASE($2, $1, null:INTEGER)]) + LogicalProject(COMM=[$6], SAL=[$5], $f2=[>($5, 1000)]) + LogicalTableScan(table=[[CATALOG, SALES, EMP]]) ]]> @@ -7247,6 +7273,34 @@ LogicalProject(SAL=[$0], EXPR$1=[$1], EXPR$2=[$3], EXPR$3=[$5]) LogicalProject(SAL=[$0]) LogicalProject(SAL=[$5], COMM=[$6]) LogicalTableScan(table=[[CATALOG, SALES, EMP]]) +]]> + + + + + 1000), count(distinct sal) FILTER (WHERE sal > 500) from sales.emp group by deptno]]> + + + ($5, 1000)], $f3=[>($5, 500)]) + LogicalTableScan(table=[[CATALOG, SALES, EMP]]) +]]> + + + ($5, 1000)], $f3=[>($5, 500)]) + LogicalTableScan(table=[[CATALOG, SALES, EMP]]) + LogicalAggregate(group=[{0}], EXPR$2=[COUNT($1)]) + LogicalAggregate(group=[{0, 1}]) + LogicalProject(DEPTNO=[$0], i$SAL=[CASE($3, $1, null:INTEGER)]) + LogicalProject(DEPTNO=[$7], SAL=[$5], $f2=[>($5, 1000)], $f3=[>($5, 500)]) + LogicalTableScan(table=[[CATALOG, SALES, EMP]]) ]]>