Skip to content

Commit

Permalink
[GLUTEN-3644][CH] Revert the logic to support the custom aggregate fu…
Browse files Browse the repository at this point in the history
…nctions

In PR apache#3629, it removes the logic to support the custom aggregate functions, must be reverted.

Close apache#3644.
  • Loading branch information
zzcclp committed Nov 8, 2023
1 parent 4a72871 commit 0ca627c
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,11 @@ package io.glutenproject.execution.extension
import io.glutenproject.expression._
import io.glutenproject.extension.ExpressionExtensionTrait

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._

import scala.collection.mutable.ListBuffer

case class CustomAggExpressionTransformer() extends ExpressionExtensionTrait {

lazy val expressionSigs = Seq(
Expand All @@ -29,4 +32,33 @@ case class CustomAggExpressionTransformer() extends ExpressionExtensionTrait {

/** Generate the extension expressions list, format: Sig[XXXExpression]("XXXExpressionName") */
override def expressionSigList: Seq[Sig] = expressionSigs

/** Get the attribute index of the extension aggregate functions. */
override def getAttrsIndexForExtensionAggregateExpr(
aggregateFunc: AggregateFunction,
mode: AggregateMode,
exp: AggregateExpression,
aggregateAttributeList: Seq[Attribute],
aggregateAttr: ListBuffer[Attribute],
resIndex: Int): Int = {
var reIndex = resIndex
aggregateFunc match {
case CustomSum(_, _) =>
mode match {
case Partial =>
val aggBufferAttr = aggregateFunc.inputAggBufferAttributes
val attr = ConverterUtils.getAttrFromExpr(aggBufferAttr.head)
aggregateAttr += attr
reIndex += 1
reIndex
// custom logic: can not support 'Final'
/* case Final =>
aggregateAttr += aggregateAttributeList(reIndex)
reIndex += 1
reIndex */
case other =>
throw new UnsupportedOperationException(s"Unsupported aggregate mode: $other.")
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,15 @@
*/
package io.glutenproject.execution.extension

import io.glutenproject.execution.{GlutenClickHouseTPCHAbstractSuite, HashAggregateExecBaseTransformer}
import io.glutenproject.execution.{GlutenClickHouseTPCHAbstractSuite, WholeStageTransformerSuite}

import org.apache.spark.SparkConf
import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.analysis.FunctionRegistryBase
import org.apache.spark.sql.catalyst.expressions.aggregate.CustomSum
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.aggregate.HashAggregateExec

class GlutenCustomAggExpressionSuite
extends GlutenClickHouseTPCHAbstractSuite
with AdaptiveSparkPlanHelper {
class GlutenCustomAggExpressionSuite extends GlutenClickHouseTPCHAbstractSuite {

override protected val resourcePath: String =
"../../../../gluten-core/src/test/resources/tpch-data"
Expand Down Expand Up @@ -77,19 +75,13 @@ class GlutenCustomAggExpressionSuite
| l_returnflag,
| l_linestatus;
|""".stripMargin
compareResultsAgainstVanillaSpark(
sql,
true,
{
df =>
val hashAggExec = collect(df.queryExecution.executedPlan) {
case hash: HashAggregateExecBaseTransformer => hash
}
assert(hashAggExec.size == 2)
val df = spark.sql(sql)
// Final stage is not supported, it will be fallback
WholeStageTransformerSuite.checkFallBack(df, false)

assert(hashAggExec(0).aggregateExpressions(0).aggregateFunction.isInstanceOf[CustomSum])
assert(hashAggExec(1).aggregateExpressions(0).aggregateFunction.isInstanceOf[CustomSum])
}
)
val fallbackAggExec = df.queryExecution.executedPlan.collect {
case agg: HashAggregateExec => agg
}
assert(fallbackAggExec.size == 1)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -408,25 +408,40 @@ abstract class HashAggregateExecBaseTransformer(
var resIndex = index
val mode = exp.mode
val aggregateFunc = exp.aggregateFunction
if (!checkAggFuncModeSupport(aggregateFunc, mode)) {
throw new UnsupportedOperationException(
s"Unsupported aggregate mode: $mode for ${aggregateFunc.prettyName}")
}
mode match {
case Partial | PartialMerge =>
val aggBufferAttr = aggregateFunc.inputAggBufferAttributes
for (index <- aggBufferAttr.indices) {
val attr = ConverterUtils.getAttrFromExpr(aggBufferAttr(index))
aggregateAttr += attr
}
resIndex += aggBufferAttr.size
resIndex
case Final =>
aggregateAttr += aggregateAttributeList(resIndex)
resIndex += 1
resIndex
case other =>
throw new UnsupportedOperationException(s"Unsupported aggregate mode: $other.")
// First handle the custom aggregate functions
if (
ExpressionMappings.expressionExtensionTransformer.extensionExpressionsMapping.contains(
aggregateFunc.getClass)
) {
ExpressionMappings.expressionExtensionTransformer
.getAttrsIndexForExtensionAggregateExpr(
aggregateFunc,
mode,
exp,
aggregateAttributeList,
aggregateAttr,
index)
} else {
if (!checkAggFuncModeSupport(aggregateFunc, mode)) {
throw new UnsupportedOperationException(
s"Unsupported aggregate mode: $mode for ${aggregateFunc.prettyName}")
}
mode match {
case Partial | PartialMerge =>
val aggBufferAttr = aggregateFunc.inputAggBufferAttributes
for (index <- aggBufferAttr.indices) {
val attr = ConverterUtils.getAttrFromExpr(aggBufferAttr(index))
aggregateAttr += attr
}
resIndex += aggBufferAttr.size
resIndex
case Final =>
aggregateAttr += aggregateAttributeList(resIndex)
resIndex += 1
resIndex
case other =>
throw new UnsupportedOperationException(s"Unsupported aggregate mode: $other.")
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ import io.glutenproject.expression.{ExpressionTransformer, Sig}

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, AggregateMode}

import scala.collection.mutable.ListBuffer

trait ExpressionExtensionTrait {

Expand All @@ -37,6 +40,18 @@ trait ExpressionExtensionTrait {
attributeSeq: Seq[Attribute]): ExpressionTransformer = {
throw new UnsupportedOperationException(s"${expr.getClass} or $expr is not supported.")
}

/** Get the attribute index of the extension aggregate functions. */
def getAttrsIndexForExtensionAggregateExpr(
aggregateFunc: AggregateFunction,
mode: AggregateMode,
exp: AggregateExpression,
aggregateAttributeList: Seq[Attribute],
aggregateAttr: ListBuffer[Attribute],
resIndex: Int): Int = {
throw new UnsupportedOperationException(
s"Aggregate function ${aggregateFunc.getClass} is not supported.")
}
}

case class DefaultExpressionExtensionTransformer() extends ExpressionExtensionTrait with Logging {
Expand Down

0 comments on commit 0ca627c

Please sign in to comment.