diff --git a/backends-clickhouse/src/main/scala/io/glutenproject/execution/CHHashAggregateExecTransformer.scala b/backends-clickhouse/src/main/scala/io/glutenproject/execution/CHHashAggregateExecTransformer.scala index 2ebb80469e46..ad997843eee6 100644 --- a/backends-clickhouse/src/main/scala/io/glutenproject/execution/CHHashAggregateExecTransformer.scala +++ b/backends-clickhouse/src/main/scala/io/glutenproject/execution/CHHashAggregateExecTransformer.scala @@ -327,8 +327,19 @@ case class CHHashAggregateExecTransformer( ConverterUtils.genColumnNameWithExprId(resultAttr) } else { val aggExpr = aggExpressions(columnIndex - groupingExprs.length) + val aggregateFunc = aggExpr.aggregateFunction var aggFunctionName = - AggregateFunctionsBuilder.getSubstraitFunctionName(aggExpr.aggregateFunction).get + if ( + ExpressionMappings.expressionExtensionTransformer.extensionExpressionsMapping.contains( + aggregateFunc.getClass) + ) { + ExpressionMappings.expressionExtensionTransformer + .buildCustomAggregateFunction(aggregateFunc) + ._1 + .get + } else { + AggregateFunctionsBuilder.getSubstraitFunctionName(aggregateFunc).get + } ConverterUtils.genColumnNameWithExprId(resultAttr) + "#Partial#" + aggFunctionName } } diff --git a/backends-clickhouse/src/test/scala/io/glutenproject/execution/extension/CustomAggExpressionTransformer.scala b/backends-clickhouse/src/test/scala/io/glutenproject/execution/extension/CustomAggExpressionTransformer.scala index 3e2e44426e23..76ec0d7ad1fb 100644 --- a/backends-clickhouse/src/test/scala/io/glutenproject/execution/extension/CustomAggExpressionTransformer.scala +++ b/backends-clickhouse/src/test/scala/io/glutenproject/execution/extension/CustomAggExpressionTransformer.scala @@ -46,17 +46,17 @@ case class CustomAggExpressionTransformer() extends ExpressionExtensionTrait { aggregateFunc match { case CustomSum(_, _) => mode match { - case Partial => + // custom logic: can not support 'Partial' + /* case Partial => val aggBufferAttr = aggregateFunc.inputAggBufferAttributes val attr = ConverterUtils.getAttrFromExpr(aggBufferAttr.head) aggregateAttr += attr reIndex += 1 - reIndex - // custom logic: can not support 'Final' - /* case Final => + reIndex */ + case Final => aggregateAttr += aggregateAttributeList(reIndex) reIndex += 1 - reIndex */ + reIndex case other => throw new UnsupportedOperationException(s"Unsupported aggregate mode: $other.") } @@ -74,10 +74,12 @@ case class CustomAggExpressionTransformer() extends ExpressionExtensionTrait { Some("custom_sum_double") } case _ => - throw new UnsupportedOperationException( - s"Aggregate function ${aggregateFunc.getClass} is not supported.") + extensionExpressionsMapping.get(aggregateFunc.getClass) + } + if (substraitAggFuncName.isEmpty) { + throw new UnsupportedOperationException( + s"Aggregate function ${aggregateFunc.getClass} is not supported.") } - (substraitAggFuncName, aggregateFunc.children.map(child => child.dataType)) } } diff --git a/backends-clickhouse/src/test/scala/io/glutenproject/execution/extension/GlutenCustomAggExpressionSuite.scala b/backends-clickhouse/src/test/scala/io/glutenproject/execution/extension/GlutenCustomAggExpressionSuite.scala index f38cb712160f..d70010957bec 100644 --- a/backends-clickhouse/src/test/scala/io/glutenproject/execution/extension/GlutenCustomAggExpressionSuite.scala +++ b/backends-clickhouse/src/test/scala/io/glutenproject/execution/extension/GlutenCustomAggExpressionSuite.scala @@ -16,8 +16,9 @@ */ package io.glutenproject.execution.extension -import io.glutenproject.execution.{CHHashAggregateExecTransformer, GlutenClickHouseTPCHAbstractSuite, HashAggregateExecBaseTransformer, WholeStageTransformerSuite} +import io.glutenproject.execution._ import io.glutenproject.substrait.SubstraitContext +import io.glutenproject.utils.SubstraitPlanPrinterUtil import org.apache.spark.SparkConf import org.apache.spark.sql.catalyst.FunctionIdentifier @@ -81,18 +82,27 @@ class GlutenCustomAggExpressionSuite extends GlutenClickHouseTPCHAbstractSuite { // Final stage is not supported, it will be fallback WholeStageTransformerSuite.checkFallBack(df, false) - val aggExecs = df.queryExecution.executedPlan.collect { + val planExecs = df.queryExecution.executedPlan.collect { case agg: HashAggregateExec => agg case aggTransformer: HashAggregateExecBaseTransformer => aggTransformer + case wholeStage: WholeStageTransformer => wholeStage } - assert(aggExecs(0).isInstanceOf[HashAggregateExec]) + // First stage fallback + assert(planExecs(3).isInstanceOf[HashAggregateExec]) + val substraitContext = new SubstraitContext - aggExecs(1).asInstanceOf[CHHashAggregateExecTransformer].doTransform(substraitContext) + planExecs(2).asInstanceOf[CHHashAggregateExecTransformer].doTransform(substraitContext) // Check the functions assert(substraitContext.registeredFunction.containsKey("custom_sum_double:req_fp64")) assert(substraitContext.registeredFunction.containsKey("custom_sum:req_i64")) assert(substraitContext.registeredFunction.containsKey("sum:req_fp64")) + + val wx = planExecs(1).asInstanceOf[WholeStageTransformer].doWholeStageTransform() + val planJson = SubstraitPlanPrinterUtil.substraitPlanToJson(wx.root.toProtobuf) + assert(planJson.contains("#Partial#custom_sum_double")) + assert(planJson.contains("#Partial#custom_sum")) + assert(planJson.contains("#Partial#sum")) } }