From 27ab77c97b72bac0614ae4f1ec749047089dca2b Mon Sep 17 00:00:00 2001
From: kevinyhzou <37431499+KevinyhZou@users.noreply.github.com>
Date: Thu, 7 Nov 2024 14:10:25 +0800
Subject: [PATCH] [GLUTEN-7759][CH]Fix pre project push down in aggregate
 (#7779)

* fix pre project

* add test

* another fix

* fix ci

* fix ci

* fix review

---------

Co-authored-by: zouyunhe <811-zouyunhe@users.noreply.git.sysop.bigo.sg>
---
 ...ownAggregatePreProjectionAheadExpand.scala | 26 ++++++++++++++++---
 ...enClickHouseTPCHSaltNullParquetSuite.scala | 25 ++++++++++++++++++
 2 files changed, 48 insertions(+), 3 deletions(-)

diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/PushdownAggregatePreProjectionAheadExpand.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/PushdownAggregatePreProjectionAheadExpand.scala
index a3fab3c954ee..21f1be2f2f2e 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/PushdownAggregatePreProjectionAheadExpand.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/PushdownAggregatePreProjectionAheadExpand.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.SparkPlan
 
 // If there is an expression (not a attribute) in an aggregation function's
-// parameters. It will introduce a pr-projection to calculate the expression
+// parameters. It will introduce a pre-projection to calculate the expression
 // at first, and make all the parameters be attributes.
 // If it's a aggregation with grouping set, this pre-projection is placed after
 // expand operator. This is not efficient, we cannot move this pre-projection
@@ -83,7 +83,7 @@ case class PushdownAggregatePreProjectionAheadExpand(session: SparkSession)
 
     val originInputAttributes = aheadProjectExprs.filter(e => isAttributeOrLiteral(e))
     val preProjectExprs = aheadProjectExprs.filter(e => !isAttributeOrLiteral(e))
-    if (preProjectExprs.length == 0) {
+    if (preProjectExprs.isEmpty) {
       return hashAggregate
     }
 
@@ -93,11 +93,31 @@ case class PushdownAggregatePreProjectionAheadExpand(session: SparkSession)
       return hashAggregate
     }
 
+    def projectInputExists(expr: Expression, inputs: Seq[Attribute]): Boolean = {
+      expr.children.foreach {
+        case a: Attribute =>
+          return inputs.indexOf(a) != -1
+        case p: Expression =>
+          return projectInputExists(p, inputs)
+        case _ =>
+          return true
+      }
+      true
+    }
+
+    val couldPushDown = preProjectExprs.forall {
+      case p: Expression => projectInputExists(p, rootChild.output)
+      case _ => true
+    }
+
+    if (!couldPushDown) {
+      return hashAggregate;
+    }
+
     // The new ahead project node will take rootChild's output and preProjectExprs as the
     // the projection expressions.
     val aheadProject = ProjectExecTransformer(rootChild.output ++ preProjectExprs, rootChild)
     val aheadProjectOuput = aheadProject.output
-
     val preProjectOutputAttrs = aheadProjectOuput.filter(
       e =>
         !originInputAttributes.exists(_.exprId.equals(e.asInstanceOf[NamedExpression].exprId)))
diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala
index 40b704d2e8d7..12047b300c9c 100644
--- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala
+++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala
@@ -3067,5 +3067,30 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr
         |""".stripMargin
     compareResultsAgainstVanillaSpark(sql, true, checkLazyExpand)
   }
+
+  test("GLUTEN-7759: Fix bug of agg pre-project push down") {
+    val table_create_sql =
+      "create table test_tbl_7759(id bigint, name string, day string) using parquet"
+    val insert_data_sql =
+      "insert into test_tbl_7759 values(1, 'a123', '2024-11-01'),(2, 'a124', '2024-11-01')"
+    val query_sql =
+      """
+        |select distinct day, name from(
+        |select '2024-11-01' as day
+        |,coalesce(name,'all') name
+        |,cnt
+        |from
+        |(
+        |select count(distinct id) as cnt, name
+        |from test_tbl_7759
+        |group by name
+        |with cube
+        |)) limit 10
+        |""".stripMargin
+    spark.sql(table_create_sql)
+    spark.sql(insert_data_sql)
+    compareResultsAgainstVanillaSpark(query_sql, true, { _ => })
+    spark.sql("drop table test_tbl_7759")
+  }
 }
 // scalastyle:on line.size.limit
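
Editorial note, not part of the patch: the sketch below illustrates the guard this change introduces. A pre-projection expression may be pushed ahead of the Expand operator only when every attribute it references is already produced by the node below Expand; otherwise it depends on columns that Expand itself generates (for example the grouping id), and the rule now bails out instead of pushing down. Here, inputsCover is a simplified recursive form of the patch's projectInputExists, and the object name and attributes are invented for illustration.

import org.apache.spark.sql.catalyst.expressions.{Add, Attribute, AttributeReference, Expression, Literal}
import org.apache.spark.sql.types.LongType

object PushdownGuardSketch extends App {
  // Simplified recursive form of the patch's projectInputExists: true iff every
  // attribute referenced (transitively) by expr is available in inputs.
  def inputsCover(expr: Expression, inputs: Seq[Attribute]): Boolean = expr match {
    case a: Attribute => inputs.exists(_.exprId == a.exprId)
    case e => e.children.forall(inputsCover(_, inputs))
  }

  // Hypothetical attributes: id comes from the scan below Expand, while
  // spark_grouping_id is generated by Expand itself.
  val id = AttributeReference("id", LongType)()
  val gid = AttributeReference("spark_grouping_id", LongType)()
  val belowExpandOutput = Seq(id)

  // id + 1 references only id, so this pre-projection could move below Expand.
  println(inputsCover(Add(id, Literal(1L)), belowExpandOutput)) // true

  // spark_grouping_id + 1 references a column that only exists above Expand,
  // so the pre-projection must stay where it is (the patched rule returns early).
  println(inputsCover(Add(gid, Literal(1L)), belowExpandOutput)) // false
}

Comparing attributes by exprId rather than by name mirrors the patch's indexOf check over rootChild.output: two columns in a plan may share a name, but attribute identity (including the ExprId) is what determines whether the input is really available.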