Skip to content

Commit

Permalink
[GLUTEN-3705][FOLLOW][CH] Set the correct agg schema names after mapp…
Browse files Browse the repository at this point in the history
…ing one agg function to the other name (#3734)

With CH backend, in the final stage, the agg schema names must be the `agg_function#exprId#Partial#custom_sum`,
after mapping the agg function to the other name, it does not modify according to the new one.
  • Loading branch information
zzcclp authored Nov 16, 2023
1 parent 808e091 commit c5ae59d
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -327,8 +327,19 @@ case class CHHashAggregateExecTransformer(
ConverterUtils.genColumnNameWithExprId(resultAttr)
} else {
val aggExpr = aggExpressions(columnIndex - groupingExprs.length)
val aggregateFunc = aggExpr.aggregateFunction
var aggFunctionName =
AggregateFunctionsBuilder.getSubstraitFunctionName(aggExpr.aggregateFunction).get
if (
ExpressionMappings.expressionExtensionTransformer.extensionExpressionsMapping.contains(
aggregateFunc.getClass)
) {
ExpressionMappings.expressionExtensionTransformer
.buildCustomAggregateFunction(aggregateFunc)
._1
.get
} else {
AggregateFunctionsBuilder.getSubstraitFunctionName(aggregateFunc).get
}
ConverterUtils.genColumnNameWithExprId(resultAttr) + "#Partial#" + aggFunctionName
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,17 +46,17 @@ case class CustomAggExpressionTransformer() extends ExpressionExtensionTrait {
aggregateFunc match {
case CustomSum(_, _) =>
mode match {
case Partial =>
// custom logic: can not support 'Partial'
/* case Partial =>
val aggBufferAttr = aggregateFunc.inputAggBufferAttributes
val attr = ConverterUtils.getAttrFromExpr(aggBufferAttr.head)
aggregateAttr += attr
reIndex += 1
reIndex
// custom logic: can not support 'Final'
/* case Final =>
reIndex */
case Final =>
aggregateAttr += aggregateAttributeList(reIndex)
reIndex += 1
reIndex */
reIndex
case other =>
throw new UnsupportedOperationException(s"Unsupported aggregate mode: $other.")
}
Expand All @@ -74,10 +74,12 @@ case class CustomAggExpressionTransformer() extends ExpressionExtensionTrait {
Some("custom_sum_double")
}
case _ =>
throw new UnsupportedOperationException(
s"Aggregate function ${aggregateFunc.getClass} is not supported.")
extensionExpressionsMapping.get(aggregateFunc.getClass)
}
if (substraitAggFuncName.isEmpty) {
throw new UnsupportedOperationException(
s"Aggregate function ${aggregateFunc.getClass} is not supported.")
}

(substraitAggFuncName, aggregateFunc.children.map(child => child.dataType))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@
*/
package io.glutenproject.execution.extension

import io.glutenproject.execution.{CHHashAggregateExecTransformer, GlutenClickHouseTPCHAbstractSuite, HashAggregateExecBaseTransformer, WholeStageTransformerSuite}
import io.glutenproject.execution._
import io.glutenproject.substrait.SubstraitContext
import io.glutenproject.utils.SubstraitPlanPrinterUtil

import org.apache.spark.SparkConf
import org.apache.spark.sql.catalyst.FunctionIdentifier
Expand Down Expand Up @@ -81,18 +82,27 @@ class GlutenCustomAggExpressionSuite extends GlutenClickHouseTPCHAbstractSuite {
// Final stage is not supported, it will be fallback
WholeStageTransformerSuite.checkFallBack(df, false)

val aggExecs = df.queryExecution.executedPlan.collect {
val planExecs = df.queryExecution.executedPlan.collect {
case agg: HashAggregateExec => agg
case aggTransformer: HashAggregateExecBaseTransformer => aggTransformer
case wholeStage: WholeStageTransformer => wholeStage
}

assert(aggExecs(0).isInstanceOf[HashAggregateExec])
// First stage fallback
assert(planExecs(3).isInstanceOf[HashAggregateExec])

val substraitContext = new SubstraitContext
aggExecs(1).asInstanceOf[CHHashAggregateExecTransformer].doTransform(substraitContext)
planExecs(2).asInstanceOf[CHHashAggregateExecTransformer].doTransform(substraitContext)

// Check the functions
assert(substraitContext.registeredFunction.containsKey("custom_sum_double:req_fp64"))
assert(substraitContext.registeredFunction.containsKey("custom_sum:req_i64"))
assert(substraitContext.registeredFunction.containsKey("sum:req_fp64"))

val wx = planExecs(1).asInstanceOf[WholeStageTransformer].doWholeStageTransform()
val planJson = SubstraitPlanPrinterUtil.substraitPlanToJson(wx.root.toProtobuf)
assert(planJson.contains("#Partial#custom_sum_double"))
assert(planJson.contains("#Partial#custom_sum"))
assert(planJson.contains("#Partial#sum"))
}
}

0 comments on commit c5ae59d

Please sign in to comment.