Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GLUTEN-3719][VL] Optimize agg func match in getAggRelWithRowConstruct #3819

Merged
merged 2 commits into from
Nov 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ import java.lang.{Long => JLong}
import java.util.{ArrayList => JArrayList, HashMap => JHashMap, List => JList}

import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer

case class HashAggregateExecTransformer(
requiredChildDistributionExpressions: Option[Seq[Expression]],
Expand Down Expand Up @@ -124,7 +125,7 @@ case class HashAggregateExecTransformer(
throw new UnsupportedOperationException(s"${expr.mode} not supported.")
}
val aggFunc = expr.aggregateFunction
expr.aggregateFunction match {
aggFunc match {
case _ @VeloxIntermediateData.Type(veloxTypes: Seq[DataType]) =>
val (sparkOrders, sparkTypes) =
aggFunc.aggBufferAttributes.map(attr => (attr.name, attr.dataType)).unzip
Expand Down Expand Up @@ -245,10 +246,7 @@ case class HashAggregateExecTransformer(
case other =>
throw new UnsupportedOperationException(s"$other is not supported.")
}
case sum: Sum if sum.dataType.isInstanceOf[DecimalType] =>
generateMergeCompanionNode()
case _: Average | _: StddevSamp | _: StddevPop | _: VarianceSamp | _: VariancePop | _: Corr |
_: CovPopulation | _: CovSample | _: First | _: Last | _: MaxMinBy =>
case _ if aggregateFunction.aggBufferAttributes.size > 1 =>
generateMergeCompanionNode()
case _ =>
val aggFunctionNode = ExpressionBuilder.makeAggregateFunction(
Expand Down Expand Up @@ -279,19 +277,7 @@ case class HashAggregateExecTransformer(
expression => {
val aggregateFunction = expression.aggregateFunction
aggregateFunction match {
case _: Average | _: First | _: Last | _: StddevSamp | _: StddevPop | _: VarianceSamp |
_: VariancePop | _: Corr | _: CovPopulation | _: CovSample | _: MaxMinBy =>
expression.mode match {
case Partial | PartialMerge =>
typeNodeList.add(VeloxIntermediateData.getIntermediateTypeNode(aggregateFunction))
case Final =>
typeNodeList.add(
ConverterUtils
.getTypeNode(aggregateFunction.dataType, aggregateFunction.nullable))
case other =>
throw new UnsupportedOperationException(s"$other is not supported.")
}
case sum: Sum if sum.dataType.isInstanceOf[DecimalType] =>
case _ if aggregateFunction.aggBufferAttributes.size > 1 =>
expression.mode match {
case Partial | PartialMerge =>
typeNodeList.add(VeloxIntermediateData.getIntermediateTypeNode(aggregateFunction))
Expand Down Expand Up @@ -328,10 +314,10 @@ case class HashAggregateExecTransformer(
args: java.lang.Object,
childNodes: JList[ExpressionNode],
rowConstructAttributes: Seq[Attribute],
withNull: Boolean = true): ScalarFunctionNode = {
aggFunc: AggregateFunction): ScalarFunctionNode = {
val functionMap = args.asInstanceOf[JHashMap[String, JLong]]
val functionName = ConverterUtils.makeFuncName(
if (withNull) "row_constructor_with_null" else "row_constructor",
VeloxIntermediateData.getRowConstructFuncName(aggFunc),
rowConstructAttributes.map(attr => attr.dataType))
val functionId = ExpressionBuilder.newScalarFunction(functionMap, functionName)

Expand Down Expand Up @@ -367,208 +353,60 @@ case class HashAggregateExecTransformer(
})

for (aggregateExpression <- aggregateExpressions) {
val functionInputAttributes = aggregateExpression.aggregateFunction.inputAggBufferAttributes
val aggregateFunction = aggregateExpression.aggregateFunction
aggregateFunction match {
val aggFunc = aggregateExpression.aggregateFunction
val functionInputAttributes = aggFunc.inputAggBufferAttributes
aggFunc match {
case _ if mixedPartialAndMerge && aggregateExpression.mode == Partial =>
val childNodes = new JArrayList[ExpressionNode](
aggregateFunction.children
.map(
attr => {
ExpressionConverter
.replaceWithExpressionTransformer(attr, originalInputAttributes)
.doTransform(args)
})
.asJava)
val childNodes = aggFunc.children
.map(
ExpressionConverter
.replaceWithExpressionTransformer(_, originalInputAttributes)
.doTransform(args)
)
.asJava
exprNodes.addAll(childNodes)
case avg: Average =>
aggregateExpression.mode match {
case PartialMerge | Final =>
assert(
functionInputAttributes.size == 2,
s"${aggregateExpression.mode.toString} of Average expects two input attributes.")
// Use a Velox function to combine the intermediate columns into struct.
val childNodes =
functionInputAttributes.toList
.map(
ExpressionConverter
.replaceWithExpressionTransformer(_, originalInputAttributes)
.doTransform(args))
.asJava
exprNodes.add(
getRowConstructNode(
args,
childNodes,
functionInputAttributes,
withNull = !avg.dataType.isInstanceOf[DecimalType]))
case other =>
throw new UnsupportedOperationException(s"$other is not supported.")
}
case _: First | _: Last | _: MaxMinBy =>
aggregateExpression.mode match {
case PartialMerge | Final =>
assert(
functionInputAttributes.size == 2,
s"${aggregateExpression.mode.toString} of " +
s"${aggregateFunction.getClass.toString} expects two input attributes.")
// Use a Velox function to combine the intermediate columns into struct.
val childNodes = functionInputAttributes.toList
.map(
ExpressionConverter
.replaceWithExpressionTransformer(_, originalInputAttributes)
.doTransform(args)
)
.asJava
exprNodes.add(getRowConstructNode(args, childNodes, functionInputAttributes))
case other =>
throw new UnsupportedOperationException(s"$other is not supported.")
}
case _: StddevSamp | _: StddevPop | _: VarianceSamp | _: VariancePop =>

case _: HyperLogLogPlusPlus if aggFunc.aggBufferAttributes.size != 1 =>
throw new UnsupportedOperationException("Only one input attribute is expected.")

case _ @VeloxIntermediateData.Type(veloxTypes: Seq[DataType]) =>
// The process of handling the inconsistency in column types and order between
// Spark and Velox is exactly the opposite of applyExtractStruct.
aggregateExpression.mode match {
case PartialMerge | Final =>
assert(
functionInputAttributes.size == 3,
s"${aggregateExpression.mode.toString} mode of" +
s"${aggregateFunction.getClass.toString} expects three input attributes."
)
// Use a Velox function to combine the intermediate columns into struct.
var index = 0
var newInputAttributes: Seq[Attribute] = Seq()
val childNodes = functionInputAttributes.toList.map {
attr =>
val aggExpr: ExpressionTransformer = ExpressionConverter
val newInputAttributes = new ArrayBuffer[Attribute]()
val childNodes = new JArrayList[ExpressionNode]()
val (sparkOrders, sparkTypes) =
aggFunc.aggBufferAttributes.map(attr => (attr.name, attr.dataType)).unzip
val veloxOrders = VeloxIntermediateData.veloxIntermediateDataOrder(aggFunc)
val adjustedOrders = veloxOrders.map(sparkOrders.indexOf(_))
veloxTypes.zipWithIndex.foreach {
case (veloxType, idx) =>
val sparkType = sparkTypes(adjustedOrders(idx))
val attr = functionInputAttributes(adjustedOrders(idx))
val aggFuncInputAttrNode = ExpressionConverter
.replaceWithExpressionTransformer(attr, originalInputAttributes)
val aggNode = aggExpr.doTransform(args)
val expressionNode = if (index == 0) {
// Cast count from DoubleType into LongType to align with Velox semantics.
newInputAttributes = newInputAttributes :+
attr.copy(attr.name, LongType, attr.nullable, attr.metadata)(
attr.exprId,
attr.qualifier)
.doTransform(args)
val expressionNode = if (sparkType != veloxType) {
newInputAttributes +=
attr.copy(dataType = veloxType)(attr.exprId, attr.qualifier)
ExpressionBuilder.makeCast(
ConverterUtils.getTypeNode(LongType, attr.nullable),
aggNode,
ConverterUtils.getTypeNode(veloxType, attr.nullable),
aggFuncInputAttrNode,
SQLConf.get.ansiEnabled)
} else {
newInputAttributes = newInputAttributes :+ attr
aggNode
newInputAttributes += attr
aggFuncInputAttrNode
}
index += 1
expressionNode
}.asJava
exprNodes.add(getRowConstructNode(args, childNodes, newInputAttributes))
case other =>
throw new UnsupportedOperationException(s"$other is not supported.")
}
case _: Corr =>
aggregateExpression.mode match {
case PartialMerge | Final =>
assert(
functionInputAttributes.size == 6,
s"${aggregateExpression.mode.toString} mode of Corr expects 6 input attributes.")
// Use a Velox function to combine the intermediate columns into struct.
var index = 0
var newInputAttributes: Seq[Attribute] = Seq()
val childNodes = new JArrayList[ExpressionNode]()
// Velox's Corr order is [ck, n, xMk, yMk, xAvg, yAvg]
// Spark's Corr order is [n, xAvg, yAvg, ck, xMk, yMk]
val sparkCorrOutputAttr = aggregateFunction.inputAggBufferAttributes.map(_.name)
val veloxInputOrder =
VeloxIntermediateData.veloxCorrIntermediateDataOrder.map(
name => sparkCorrOutputAttr.indexOf(name))
for (order <- veloxInputOrder) {
val attr = functionInputAttributes(order)
val aggExpr: ExpressionTransformer = ExpressionConverter
.replaceWithExpressionTransformer(attr, originalInputAttributes)
val aggNode = aggExpr.doTransform(args)
val expressionNode = if (order == 0) {
// Cast count from DoubleType into LongType to align with Velox semantics.
newInputAttributes = newInputAttributes :+
attr.copy(attr.name, LongType, attr.nullable, attr.metadata)(
attr.exprId,
attr.qualifier)
ExpressionBuilder.makeCast(
ConverterUtils.getTypeNode(LongType, attr.nullable),
aggNode,
SQLConf.get.ansiEnabled)
} else {
newInputAttributes = newInputAttributes :+ attr
aggNode
}
index += 1
childNodes.add(expressionNode)
childNodes.add(expressionNode)
}
exprNodes.add(getRowConstructNode(args, childNodes, newInputAttributes))
case other =>
throw new UnsupportedOperationException(s"$other is not supported.")
}
case _: CovPopulation | _: CovSample =>
aggregateExpression.mode match {
case PartialMerge | Final =>
assert(
functionInputAttributes.size == 4,
s"${aggregateExpression.mode.toString} mode of" +
s"${aggregateFunction.getClass.toString} expects 4 input attributes.")
// Use a Velox function to combine the intermediate columns into struct.
var index = 0
var newInputAttributes: Seq[Attribute] = Seq()
val childNodes = new JArrayList[ExpressionNode]()
// Velox's Covar order is [ck, n, xAvg, yAvg]
// Spark's Covar order is [n, xAvg, yAvg, ck]
val sparkCorrOutputAttr = aggregateFunction.inputAggBufferAttributes.map(_.name)
val veloxInputOrder =
VeloxIntermediateData.veloxCovarIntermediateDataOrder.map(
name => sparkCorrOutputAttr.indexOf(name))
for (order <- veloxInputOrder) {
val attr = functionInputAttributes(order)
val aggExpr: ExpressionTransformer = ExpressionConverter
.replaceWithExpressionTransformer(attr, originalInputAttributes)
val aggNode = aggExpr.doTransform(args)
val expressionNode = if (order == 0) {
// Cast count from DoubleType into LongType to align with Velox semantics.
newInputAttributes = newInputAttributes :+
attr.copy(attr.name, LongType, attr.nullable, attr.metadata)(
attr.exprId,
attr.qualifier)
ExpressionBuilder.makeCast(
ConverterUtils.getTypeNode(LongType, attr.nullable),
aggNode,
SQLConf.get.ansiEnabled)
} else {
newInputAttributes = newInputAttributes :+ attr
aggNode
}
index += 1
childNodes.add(expressionNode)
}
exprNodes.add(getRowConstructNode(args, childNodes, newInputAttributes))
case other =>
throw new UnsupportedOperationException(s"$other is not supported.")
}
case sum: Sum if sum.dataType.isInstanceOf[DecimalType] =>
aggregateExpression.mode match {
case PartialMerge | Final =>
assert(
functionInputAttributes.size == 2,
"Final stage of Average expects two input attributes.")
// Use a Velox function to combine the intermediate columns into struct.
val childNodes = functionInputAttributes.toList
.map(
ExpressionConverter
.replaceWithExpressionTransformer(_, originalInputAttributes)
.doTransform(args)
)
.asJava
exprNodes.add(
getRowConstructNode(args, childNodes, functionInputAttributes, withNull = false))
exprNodes.add(getRowConstructNode(args, childNodes, newInputAttributes, aggFunc))
case other =>
throw new UnsupportedOperationException(s"$other is not supported.")
}

case _ =>
if (functionInputAttributes.size != 1) {
throw new UnsupportedOperationException("Only one input attribute is expected.")
}
val childNodes = functionInputAttributes.toList
val childNodes = functionInputAttributes
.map(
ExpressionConverter
.replaceWithExpressionTransformer(_, originalInputAttributes)
Expand Down Expand Up @@ -602,11 +440,11 @@ case class HashAggregateExecTransformer(
// Create aggregation rel.
val groupingList = new JArrayList[ExpressionNode]()
var colIdx = 0
groupingExpressions.foreach(
_ => {
groupingExpressions.foreach {
_ =>
groupingList.add(ExpressionBuilder.makeSelection(colIdx))
colIdx += 1
})
}

val aggFilterList = new JArrayList[ExpressionNode]()
val aggregateFunctionList = new JArrayList[AggregateFunctionNode]()
Expand All @@ -619,40 +457,25 @@ case class HashAggregateExecTransformer(
aggFilterList.add(null)
}

val aggregateFunc = aggExpr.aggregateFunction
val aggFunc = aggExpr.aggregateFunction
val childrenNodes = new JArrayList[ExpressionNode]()
aggregateFunc match {
case _: Average | _: First | _: Last | _: StddevSamp | _: StddevPop | _: VarianceSamp |
_: VariancePop | _: Corr | _: CovPopulation | _: CovSample | _: MaxMinBy
if aggExpr.mode == PartialMerge | aggExpr.mode == Final =>
aggExpr.mode match {
case PartialMerge | Final =>
// Only occupies one column due to intermediate results are combined
// by previous projection.
childrenNodes.add(ExpressionBuilder.makeSelection(colIdx))
colIdx += 1
case sum: Sum
if sum.dataType.isInstanceOf[DecimalType] &&
(aggExpr.mode == PartialMerge | aggExpr.mode == Final) =>
childrenNodes.add(ExpressionBuilder.makeSelection(colIdx))
colIdx += 1
case _ if aggExpr.mode == PartialMerge | aggExpr.mode == Final =>
aggregateFunc.inputAggBufferAttributes.toList.map(
_ => {
childrenNodes.add(ExpressionBuilder.makeSelection(colIdx))
colIdx += 1
aggExpr
})
case _ if aggExpr.mode == Partial =>
aggregateFunc.children.toList.map(
_ => {
case Partial =>
aggFunc.children.foreach {
_ =>
childrenNodes.add(ExpressionBuilder.makeSelection(colIdx))
colIdx += 1
aggExpr
})
case function =>
}
case _ =>
throw new UnsupportedOperationException(
s"$function of ${aggExpr.mode.toString} is not supported.")
s"$aggFunc of ${aggExpr.mode.toString} is not supported.")
}
addFunctionNode(args, aggregateFunc, childrenNodes, aggExpr.mode, aggregateFunctionList)
addFunctionNode(args, aggFunc, childrenNodes, aggExpr.mode, aggregateFunctionList)
})
RelBuilder.makeAggregateRel(
projectRel,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,16 @@ object VeloxIntermediateData {
TypeBuilder.makeStruct(false, structTypeNodes.asJava)
}

/**
 * Selects the Velox row-construct function name for the given aggregate
 * function. Decimal-typed `Average` and `Sum` are the only aggregates that
 * need the plain `row_constructor`; all other aggregates use Gluten's custom
 * `row_constructor_with_null` variant.
 */
def getRowConstructFuncName(aggFunc: AggregateFunction): String = {
  // The guard must hold only for decimal avg/sum; everything else falls through.
  val isDecimalAvgOrSum = aggFunc match {
    case _: Average | _: Sum => aggFunc.dataType.isInstanceOf[DecimalType]
    case _ => false
  }
  if (isDecimalAvgOrSum) "row_constructor" else "row_constructor_with_null"
}

object Type {

/**
Expand Down
Loading