diff --git a/backends-clickhouse/src/test/scala/io/glutenproject/execution/extension/CustomAggExpressionTransformer.scala b/backends-clickhouse/src/test/scala/io/glutenproject/execution/extension/CustomAggExpressionTransformer.scala index 757795bf86343..f069b0e8a235e 100644 --- a/backends-clickhouse/src/test/scala/io/glutenproject/execution/extension/CustomAggExpressionTransformer.scala +++ b/backends-clickhouse/src/test/scala/io/glutenproject/execution/extension/CustomAggExpressionTransformer.scala @@ -19,8 +19,11 @@ package io.glutenproject.execution.extension import io.glutenproject.expression._ import io.glutenproject.extension.ExpressionExtensionTrait +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ +import scala.collection.mutable.ListBuffer + case class CustomAggExpressionTransformer() extends ExpressionExtensionTrait { lazy val expressionSigs = Seq( @@ -29,4 +32,42 @@ case class CustomAggExpressionTransformer() extends ExpressionExtensionTrait { /** Generate the extension expressions list, format: Sig[XXXExpression]("XXXExpressionName") */ override def expressionSigList: Seq[Sig] = expressionSigs + + /** CustomSum only: adds its attributes to aggregateAttr; returns resIndex plus count added. 
*/ + override def getAttrsIndexForExtensionAggregateExpr( + aggregateFunc: AggregateFunction, + mode: AggregateMode, + exp: AggregateExpression, + aggregateAttributeList: Seq[Attribute], + aggregateAttr: ListBuffer[Attribute], + resIndex: Int): Int = { + var reIndex = resIndex + aggregateFunc match { + case CustomSum(_, _) => + mode match { + // custom logic: 'PartialMerge' is not supported; it reaches the exception in 'case other' + case Partial => + val aggBufferAttr = aggregateFunc.inputAggBufferAttributes + if (aggBufferAttr.size == 2) { + // two buffer attributes (presumably decimal sum: value + isEmpty flag; confirm): add both + aggregateAttr += ConverterUtils.getAttrFromExpr(aggBufferAttr.head) + val isEmptyAttr = ConverterUtils.getAttrFromExpr(aggBufferAttr(1)) + aggregateAttr += isEmptyAttr + reIndex += 2 + reIndex + } else { + val attr = ConverterUtils.getAttrFromExpr(aggBufferAttr.head) + aggregateAttr += attr + reIndex += 1 + reIndex + } + case Final => + aggregateAttr += aggregateAttributeList(reIndex) + reIndex += 1 + reIndex + case other => + throw new UnsupportedOperationException(s"Unsupported aggregate mode: $other.") + } + } + } } diff --git a/gluten-core/src/main/scala/io/glutenproject/execution/HashAggregateExecBaseTransformer.scala b/gluten-core/src/main/scala/io/glutenproject/execution/HashAggregateExecBaseTransformer.scala index e88b87e7a4a0e..25298c53f710d 100644 --- a/gluten-core/src/main/scala/io/glutenproject/execution/HashAggregateExecBaseTransformer.scala +++ b/gluten-core/src/main/scala/io/glutenproject/execution/HashAggregateExecBaseTransformer.scala @@ -408,25 +408,40 @@ abstract class HashAggregateExecBaseTransformer( var resIndex = index val mode = exp.mode val aggregateFunc = exp.aggregateFunction - if (!checkAggFuncModeSupport(aggregateFunc, mode)) { - throw new UnsupportedOperationException( - s"Unsupported aggregate mode: $mode for ${aggregateFunc.prettyName}") - } - mode match { - case Partial | PartialMerge => - val aggBufferAttr = aggregateFunc.inputAggBufferAttributes - for (index <- aggBufferAttr.indices) { - val attr 
= ConverterUtils.getAttrFromExpr(aggBufferAttr(index)) - aggregateAttr += attr - } - resIndex += aggBufferAttr.size - resIndex - case Final => - aggregateAttr += aggregateAttributeList(resIndex) - resIndex += 1 - resIndex - case other => - throw new UnsupportedOperationException(s"Unsupported aggregate mode: $other.") + // Delegate functions found in extensionExpressionsMapping to the extension transformer + if ( + ExpressionMappings.expressionExtensionTransformer.extensionExpressionsMapping.contains( + aggregateFunc.getClass) + ) { + ExpressionMappings.expressionExtensionTransformer + .getAttrsIndexForExtensionAggregateExpr( + aggregateFunc, + mode, + exp, + aggregateAttributeList, + aggregateAttr, + index) + } else { + if (!checkAggFuncModeSupport(aggregateFunc, mode)) { + throw new UnsupportedOperationException( + s"Unsupported aggregate mode: $mode for ${aggregateFunc.prettyName}") + } + mode match { + case Partial | PartialMerge => + val aggBufferAttr = aggregateFunc.inputAggBufferAttributes + for (index <- aggBufferAttr.indices) { + val attr = ConverterUtils.getAttrFromExpr(aggBufferAttr(index)) + aggregateAttr += attr + } + resIndex += aggBufferAttr.size + resIndex + case Final => + aggregateAttr += aggregateAttributeList(resIndex) + resIndex += 1 + resIndex + case other => + throw new UnsupportedOperationException(s"Unsupported aggregate mode: $other.") + } } } diff --git a/gluten-core/src/main/scala/io/glutenproject/extension/ExpressionExtensionTrait.scala b/gluten-core/src/main/scala/io/glutenproject/extension/ExpressionExtensionTrait.scala index 5a5fb08c3c91e..d1fc219b7bbbf 100644 --- a/gluten-core/src/main/scala/io/glutenproject/extension/ExpressionExtensionTrait.scala +++ b/gluten-core/src/main/scala/io/glutenproject/extension/ExpressionExtensionTrait.scala @@ -20,6 +20,9 @@ import io.glutenproject.expression.{ExpressionTransformer, Sig} import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} +import 
org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, AggregateMode} + +import scala.collection.mutable.ListBuffer trait ExpressionExtensionTrait { @@ -37,6 +40,18 @@ trait ExpressionExtensionTrait { attributeSeq: Seq[Attribute]): ExpressionTransformer = { throw new UnsupportedOperationException(s"${expr.getClass} or $expr is not supported.") } + + /** Hook for extension aggregate functions; this default implementation always throws. */ + def getAttrsIndexForExtensionAggregateExpr( + aggregateFunc: AggregateFunction, + mode: AggregateMode, + exp: AggregateExpression, + aggregateAttributeList: Seq[Attribute], + aggregateAttr: ListBuffer[Attribute], + resIndex: Int): Int = { + throw new UnsupportedOperationException( + s"Aggregate function ${aggregateFunc.getClass} is not supported.") + } } case class DefaultExpressionExtensionTransformer() extends ExpressionExtensionTrait with Logging {