diff --git a/core/src/main/java/org/apache/calcite/rel/rules/AggregateExpandDistinctAggregatesRule.java b/core/src/main/java/org/apache/calcite/rel/rules/AggregateExpandDistinctAggregatesRule.java
index 7a00db8d72cf..1e3ec531dd5a 100644
--- a/core/src/main/java/org/apache/calcite/rel/rules/AggregateExpandDistinctAggregatesRule.java
+++ b/core/src/main/java/org/apache/calcite/rel/rules/AggregateExpandDistinctAggregatesRule.java
@@ -734,6 +734,9 @@ private static void doRewrite(RelBuilder relBuilder, Aggregate aggregate, int n,
if (!aggCall.getArgList().equals(argList)) {
continue;
}
+ if (aggCall.filterArg != filterArg) {
+ continue;
+ }
// Re-map arguments.
final int argCount = aggCall.getArgList().size();
@@ -741,14 +744,10 @@ private static void doRewrite(RelBuilder relBuilder, Aggregate aggregate, int n,
for (Integer arg : aggCall.getArgList()) {
newArgs.add(requireNonNull(sourceOf.get(arg), () -> "sourceOf.get(" + arg + ")"));
}
- final int newFilterArg =
- aggCall.filterArg < 0 ? -1
- : requireNonNull(sourceOf.get(aggCall.filterArg),
- () -> "sourceOf.get(" + aggCall.filterArg + ")");
final AggregateCall newAggCall =
AggregateCall.create(aggCall.getAggregation(), false,
aggCall.isApproximate(), aggCall.ignoreNulls(), aggCall.rexList,
- newArgs, newFilterArg, aggCall.distinctKeys, aggCall.collation,
+ newArgs, -1, aggCall.distinctKeys, aggCall.collation,
aggCall.getType(), aggCall.getName());
assert refs.get(i) == null;
if (leftFields == null) {
diff --git a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
index 51a13e0bc50a..86811ab186b2 100644
--- a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
+++ b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
@@ -2149,6 +2149,30 @@ private void checkSemiOrAntiJoinProjectTranspose(JoinRelType type) {
.check();
}
+ @Test void testDistinctWithFilterWithoutGroupByUsingJoin() {
+ final String sql = "SELECT SUM(comm), COUNT(DISTINCT sal) FILTER (WHERE sal > 1000)\n"
+ + "FROM emp";
+ sql(sql)
+ .withRule(CoreRules.AGGREGATE_EXPAND_DISTINCT_AGGREGATES_TO_JOIN)
+ .check();
+ }
+
+ @Test
+ void testMultipleDistinctWithSameArgsDifferentFilterUsingJoin() {
+ final String sql = "select deptno, "
+ + "count(distinct sal) FILTER (WHERE sal > 1000), "
+ + "count(distinct sal) FILTER (WHERE sal > 500) "
+ + "from sales.emp group by deptno";
+ sql(sql)
+ .withRule(
+ CoreRules.AGGREGATE_EXPAND_DISTINCT_AGGREGATES_TO_JOIN,
+ CoreRules.AGGREGATE_PROJECT_MERGE
+ )
+ .check();
+
+ testSortUnionTranspose();
+ }
+
@Test void testDistinctWithFilterWithoutGroupBy() {
final String sql = "SELECT SUM(comm), COUNT(DISTINCT sal) FILTER (WHERE sal > 1000)\n"
+ "FROM emp";
diff --git a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
index bc940b804c8f..8bea453b70eb 100644
--- a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
+++ b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
@@ -2856,6 +2856,32 @@ LogicalAggregate(group=[{}], EXPR$0=[MIN($1) FILTER $3], EXPR$1=[COUNT($0) FILTE
LogicalAggregate(group=[{1, 2}], groups=[[{1, 2}, {}]], EXPR$0=[SUM($0)], $g=[GROUPING($1, $2)])
LogicalProject(COMM=[$6], SAL=[$5], $f2=[>($5, 1000)])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+
+
+
+
+ 1000)
+FROM emp]]>
+
+
+ ($5, 1000)])
+ LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+
+
+ ($5, 1000)])
+ LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+ LogicalAggregate(group=[{}], EXPR$1=[COUNT($0)])
+ LogicalAggregate(group=[{0}])
+ LogicalProject(i$SAL=[CASE($2, $1, null:INTEGER)])
+ LogicalProject(COMM=[$6], SAL=[$5], $f2=[>($5, 1000)])
+ LogicalTableScan(table=[[CATALOG, SALES, EMP]])
]]>
@@ -7247,6 +7273,34 @@ LogicalProject(SAL=[$0], EXPR$1=[$1], EXPR$2=[$3], EXPR$3=[$5])
LogicalProject(SAL=[$0])
LogicalProject(SAL=[$5], COMM=[$6])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+
+
+
+
+ 1000), count(distinct sal) FILTER (WHERE sal > 500) from sales.emp group by deptno]]>
+
+
+ ($5, 1000)], $f3=[>($5, 500)])
+ LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+
+
+ ($5, 1000)], $f3=[>($5, 500)])
+ LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+ LogicalAggregate(group=[{0}], EXPR$2=[COUNT($1)])
+ LogicalAggregate(group=[{0, 1}])
+ LogicalProject(DEPTNO=[$0], i$SAL=[CASE($3, $1, null:INTEGER)])
+ LogicalProject(DEPTNO=[$7], SAL=[$5], $f2=[>($5, 1000)], $f3=[>($5, 500)])
+ LogicalTableScan(table=[[CATALOG, SALES, EMP]])
]]>