diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/PushdownAggregatePreProjectionAheadExpand.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/PushdownAggregatePreProjectionAheadExpand.scala index 3c9bc38b33ba..b106b3f1a5dd 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/PushdownAggregatePreProjectionAheadExpand.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/PushdownAggregatePreProjectionAheadExpand.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.SparkPlan // If there is an expression (not a attribute) in an aggregation function's -// parameters. It will introduce a pr-projection to calculate the expression +// parameters. It will introduce a pre-projection to calculate the expression // at first, and make all the parameters be attributes. // If it's a aggregation with grouping set, this pre-projection is placed after // expand operator. This is not efficient, we cannot move this pre-projection @@ -83,7 +83,7 @@ case class PushdownAggregatePreProjectionAheadExpand(session: SparkSession) val originInputAttributes = aheadProjectExprs.filter(e => isAttributeOrLiteral(e)) val preProjectExprs = aheadProjectExprs.filter(e => !isAttributeOrLiteral(e)) - if (preProjectExprs.isEmpty || originInputAttributes.nonEmpty) { + if (preProjectExprs.isEmpty) { return hashAggregate } @@ -93,29 +93,32 @@ case class PushdownAggregatePreProjectionAheadExpand(session: SparkSession) return hashAggregate } - def replaceProjectInputs(expr: Expression, inputs: Seq[Attribute]): Expression = { - var newChildren = Seq.empty[Expression] + def projectInputExists(expr: Expression, inputs: Seq[Attribute]): Boolean = { expr.children.foreach { case a: Attribute => - for (input <- inputs if input.name.equals(a.name)) { - newChildren :+= input + var exist = false + for (input <- inputs if input.name.equals(a.name) && input.exprId.equals(a.exprId)) { + exist = true } + return exist case p: Expression => - val newChild = replaceProjectInputs(p, inputs) - newChildren :+= newChild + return projectInputExists(p, inputs) case _ => + return true } - expr.withNewChildren(newChildren) + true } - var newProjectExprs = Seq.empty[NamedExpression] + preProjectExprs.foreach( - p => - newProjectExprs :+= replaceProjectInputs(p, rootChild.output) - .asInstanceOf[NamedExpression]) + p => { + if (!projectInputExists(p, rootChild.output)) { + return hashAggregate + } + }) // The new ahead project node will take rootChild's output and preProjectExprs as the // the projection expressions. - val aheadProject = ProjectExecTransformer(rootChild.output ++ newProjectExprs, rootChild) + val aheadProject = ProjectExecTransformer(rootChild.output ++ preProjectExprs, rootChild) val aheadProjectOuput = aheadProject.output val preProjectOutputAttrs = aheadProjectOuput.filter( e =>