diff --git a/backends-clickhouse/src/test/scala/io/glutenproject/execution/extension/CustomAggExpressionTransformer.scala b/backends-clickhouse/src/test/scala/io/glutenproject/execution/extension/CustomAggExpressionTransformer.scala index 757795bf8634..43decfba28b7 100644 --- a/backends-clickhouse/src/test/scala/io/glutenproject/execution/extension/CustomAggExpressionTransformer.scala +++ b/backends-clickhouse/src/test/scala/io/glutenproject/execution/extension/CustomAggExpressionTransformer.scala @@ -19,8 +19,11 @@ package io.glutenproject.execution.extension import io.glutenproject.expression._ import io.glutenproject.extension.ExpressionExtensionTrait +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ +import scala.collection.mutable.ListBuffer + case class CustomAggExpressionTransformer() extends ExpressionExtensionTrait { lazy val expressionSigs = Seq( @@ -29,4 +32,33 @@ case class CustomAggExpressionTransformer() extends ExpressionExtensionTrait { /** Generate the extension expressions list, format: Sig[XXXExpression]("XXXExpressionName") */ override def expressionSigList: Seq[Sig] = expressionSigs + + /** Get the attribute index of the extension aggregate functions. */ + override def getAttrsIndexForExtensionAggregateExpr( + aggregateFunc: AggregateFunction, + mode: AggregateMode, + exp: AggregateExpression, + aggregateAttributeList: Seq[Attribute], + aggregateAttr: ListBuffer[Attribute], + resIndex: Int): Int = { + var reIndex = resIndex + aggregateFunc match { + case CustomSum(_, _) => + mode match { + case Partial => + val aggBufferAttr = aggregateFunc.inputAggBufferAttributes + val attr = ConverterUtils.getAttrFromExpr(aggBufferAttr.head) + aggregateAttr += attr + reIndex += 1 + reIndex + // custom logic: can not support 'Final' + /* case Final => + aggregateAttr += aggregateAttributeList(reIndex) + reIndex += 1 + reIndex */ + case other => + throw new UnsupportedOperationException(s"Unsupported aggregate mode: $other.") + } + } + } } diff --git a/backends-clickhouse/src/test/scala/io/glutenproject/execution/extension/GlutenCustomAggExpressionSuite.scala b/backends-clickhouse/src/test/scala/io/glutenproject/execution/extension/GlutenCustomAggExpressionSuite.scala index 25a7c7d012ab..3a52e0e7b148 100644 --- a/backends-clickhouse/src/test/scala/io/glutenproject/execution/extension/GlutenCustomAggExpressionSuite.scala +++ b/backends-clickhouse/src/test/scala/io/glutenproject/execution/extension/GlutenCustomAggExpressionSuite.scala @@ -16,17 +16,15 @@ */ package io.glutenproject.execution.extension -import io.glutenproject.execution.{GlutenClickHouseTPCHAbstractSuite, HashAggregateExecBaseTransformer} +import io.glutenproject.execution.{GlutenClickHouseTPCHAbstractSuite, WholeStageTransformerSuite} import org.apache.spark.SparkConf import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.analysis.FunctionRegistryBase import org.apache.spark.sql.catalyst.expressions.aggregate.CustomSum -import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper +import org.apache.spark.sql.execution.aggregate.HashAggregateExec -class GlutenCustomAggExpressionSuite - extends GlutenClickHouseTPCHAbstractSuite - with AdaptiveSparkPlanHelper { +class GlutenCustomAggExpressionSuite extends GlutenClickHouseTPCHAbstractSuite { override protected val resourcePath: String = "../../../../gluten-core/src/test/resources/tpch-data" @@ -77,19 +75,13 @@ class GlutenCustomAggExpressionSuite | l_returnflag, | l_linestatus; |""".stripMargin - compareResultsAgainstVanillaSpark( - sql, - true, - { - df => - val hashAggExec = collect(df.queryExecution.executedPlan) { - case hash: HashAggregateExecBaseTransformer => hash - } - assert(hashAggExec.size == 2) + val df = spark.sql(sql) + // Final stage is not supported, it will be fallback + WholeStageTransformerSuite.checkFallBack(df, false) - assert(hashAggExec(0).aggregateExpressions(0).aggregateFunction.isInstanceOf[CustomSum]) - assert(hashAggExec(1).aggregateExpressions(0).aggregateFunction.isInstanceOf[CustomSum]) - } - ) + val fallbackAggExec = df.queryExecution.executedPlan.collect { + case agg: HashAggregateExec => agg + } + assert(fallbackAggExec.size == 1) } } diff --git a/gluten-core/src/main/scala/io/glutenproject/execution/HashAggregateExecBaseTransformer.scala b/gluten-core/src/main/scala/io/glutenproject/execution/HashAggregateExecBaseTransformer.scala index e88b87e7a4a0..25298c53f710 100644 --- a/gluten-core/src/main/scala/io/glutenproject/execution/HashAggregateExecBaseTransformer.scala +++ b/gluten-core/src/main/scala/io/glutenproject/execution/HashAggregateExecBaseTransformer.scala @@ -408,25 +408,40 @@ abstract class HashAggregateExecBaseTransformer( var resIndex = index val mode = exp.mode val aggregateFunc = exp.aggregateFunction - if (!checkAggFuncModeSupport(aggregateFunc, mode)) { - throw new UnsupportedOperationException( - s"Unsupported aggregate mode: $mode for ${aggregateFunc.prettyName}") - } - mode match { - case Partial | PartialMerge => - val aggBufferAttr = aggregateFunc.inputAggBufferAttributes - for (index <- aggBufferAttr.indices) { - val attr = ConverterUtils.getAttrFromExpr(aggBufferAttr(index)) - aggregateAttr += attr - } - resIndex += aggBufferAttr.size - resIndex - case Final => - aggregateAttr += aggregateAttributeList(resIndex) - resIndex += 1 - resIndex - case other => - throw new UnsupportedOperationException(s"Unsupported aggregate mode: $other.") + // First handle the custom aggregate functions + if ( + ExpressionMappings.expressionExtensionTransformer.extensionExpressionsMapping.contains( + aggregateFunc.getClass) + ) { + ExpressionMappings.expressionExtensionTransformer + .getAttrsIndexForExtensionAggregateExpr( + aggregateFunc, + mode, + exp, + aggregateAttributeList, + aggregateAttr, + index) + } else { + if (!checkAggFuncModeSupport(aggregateFunc, mode)) { + throw new UnsupportedOperationException( + s"Unsupported aggregate mode: $mode for ${aggregateFunc.prettyName}") + } + mode match { + case Partial | PartialMerge => + val aggBufferAttr = aggregateFunc.inputAggBufferAttributes + for (index <- aggBufferAttr.indices) { + val attr = ConverterUtils.getAttrFromExpr(aggBufferAttr(index)) + aggregateAttr += attr + } + resIndex += aggBufferAttr.size + resIndex + case Final => + aggregateAttr += aggregateAttributeList(resIndex) + resIndex += 1 + resIndex + case other => + throw new UnsupportedOperationException(s"Unsupported aggregate mode: $other.") + } } } diff --git a/gluten-core/src/main/scala/io/glutenproject/extension/ExpressionExtensionTrait.scala b/gluten-core/src/main/scala/io/glutenproject/extension/ExpressionExtensionTrait.scala index 5a5fb08c3c91..d1fc219b7bbb 100644 --- a/gluten-core/src/main/scala/io/glutenproject/extension/ExpressionExtensionTrait.scala +++ b/gluten-core/src/main/scala/io/glutenproject/extension/ExpressionExtensionTrait.scala @@ -20,6 +20,9 @@ import io.glutenproject.expression.{ExpressionTransformer, Sig} import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, AggregateMode} + +import scala.collection.mutable.ListBuffer trait ExpressionExtensionTrait { @@ -37,6 +40,18 @@ trait ExpressionExtensionTrait { attributeSeq: Seq[Attribute]): ExpressionTransformer = { throw new UnsupportedOperationException(s"${expr.getClass} or $expr is not supported.") } + + /** Get the attribute index of the extension aggregate functions. */ + def getAttrsIndexForExtensionAggregateExpr( + aggregateFunc: AggregateFunction, + mode: AggregateMode, + exp: AggregateExpression, + aggregateAttributeList: Seq[Attribute], + aggregateAttr: ListBuffer[Attribute], + resIndex: Int): Int = { + throw new UnsupportedOperationException( + s"Aggregate function ${aggregateFunc.getClass} is not supported.") + } } case class DefaultExpressionExtensionTransformer() extends ExpressionExtensionTrait with Logging {