Skip to content

Commit

Permalink
optimize agg func match in getAggRelWithRowConstruct
Browse files Browse the repository at this point in the history
  • Loading branch information
liujiayi771 committed Nov 23, 2023
1 parent 16abf39 commit 534a319
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 236 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ import java.lang.{Long => JLong}
import java.util.{ArrayList => JArrayList, HashMap => JHashMap, List => JList}

import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer

case class HashAggregateExecTransformer(
requiredChildDistributionExpressions: Option[Seq[Expression]],
Expand Down Expand Up @@ -124,7 +125,7 @@ case class HashAggregateExecTransformer(
throw new UnsupportedOperationException(s"${expr.mode} not supported.")
}
val aggFunc = expr.aggregateFunction
expr.aggregateFunction match {
aggFunc match {
case _ @VeloxIntermediateData.Type(veloxTypes: Seq[DataType]) =>
val (sparkOrders, sparkTypes) =
aggFunc.aggBufferAttributes.map(attr => (attr.name, attr.dataType)).unzip
Expand Down Expand Up @@ -245,10 +246,7 @@ case class HashAggregateExecTransformer(
case other =>
throw new UnsupportedOperationException(s"$other is not supported.")
}
case sum: Sum if sum.dataType.isInstanceOf[DecimalType] =>
generateMergeCompanionNode()
case _: Average | _: StddevSamp | _: StddevPop | _: VarianceSamp | _: VariancePop | _: Corr |
_: CovPopulation | _: CovSample | _: First | _: Last | _: MaxMinBy =>
case _ if aggregateFunction.aggBufferAttributes.size > 1 =>
generateMergeCompanionNode()
case _ =>
val aggFunctionNode = ExpressionBuilder.makeAggregateFunction(
Expand Down Expand Up @@ -279,19 +277,7 @@ case class HashAggregateExecTransformer(
expression => {
val aggregateFunction = expression.aggregateFunction
aggregateFunction match {
case _: Average | _: First | _: Last | _: StddevSamp | _: StddevPop | _: VarianceSamp |
_: VariancePop | _: Corr | _: CovPopulation | _: CovSample | _: MaxMinBy =>
expression.mode match {
case Partial | PartialMerge =>
typeNodeList.add(VeloxIntermediateData.getIntermediateTypeNode(aggregateFunction))
case Final =>
typeNodeList.add(
ConverterUtils
.getTypeNode(aggregateFunction.dataType, aggregateFunction.nullable))
case other =>
throw new UnsupportedOperationException(s"$other is not supported.")
}
case sum: Sum if sum.dataType.isInstanceOf[DecimalType] =>
case _ if aggregateFunction.aggBufferAttributes.size > 1 =>
expression.mode match {
case Partial | PartialMerge =>
typeNodeList.add(VeloxIntermediateData.getIntermediateTypeNode(aggregateFunction))
Expand Down Expand Up @@ -328,10 +314,10 @@ case class HashAggregateExecTransformer(
args: java.lang.Object,
childNodes: JList[ExpressionNode],
rowConstructAttributes: Seq[Attribute],
withNull: Boolean = true): ScalarFunctionNode = {
aggFunc: AggregateFunction): ScalarFunctionNode = {
val functionMap = args.asInstanceOf[JHashMap[String, JLong]]
val functionName = ConverterUtils.makeFuncName(
if (withNull) "row_constructor_with_null" else "row_constructor",
VeloxIntermediateData.getRowConstructFuncName(aggFunc),
rowConstructAttributes.map(attr => attr.dataType))
val functionId = ExpressionBuilder.newScalarFunction(functionMap, functionName)

Expand Down Expand Up @@ -367,208 +353,60 @@ case class HashAggregateExecTransformer(
})

for (aggregateExpression <- aggregateExpressions) {
val functionInputAttributes = aggregateExpression.aggregateFunction.inputAggBufferAttributes
val aggregateFunction = aggregateExpression.aggregateFunction
aggregateFunction match {
val aggFunc = aggregateExpression.aggregateFunction
val functionInputAttributes = aggFunc.inputAggBufferAttributes
aggFunc match {
case _ if mixedPartialAndMerge && aggregateExpression.mode == Partial =>
val childNodes = new JArrayList[ExpressionNode](
aggregateFunction.children
.map(
attr => {
ExpressionConverter
.replaceWithExpressionTransformer(attr, originalInputAttributes)
.doTransform(args)
})
.asJava)
val childNodes = aggFunc.children
.map(
ExpressionConverter
.replaceWithExpressionTransformer(_, originalInputAttributes)
.doTransform(args)
)
.asJava
exprNodes.addAll(childNodes)
case avg: Average =>
aggregateExpression.mode match {
case PartialMerge | Final =>
assert(
functionInputAttributes.size == 2,
s"${aggregateExpression.mode.toString} of Average expects two input attributes.")
// Use a Velox function to combine the intermediate columns into struct.
val childNodes =
functionInputAttributes.toList
.map(
ExpressionConverter
.replaceWithExpressionTransformer(_, originalInputAttributes)
.doTransform(args))
.asJava
exprNodes.add(
getRowConstructNode(
args,
childNodes,
functionInputAttributes,
withNull = !avg.dataType.isInstanceOf[DecimalType]))
case other =>
throw new UnsupportedOperationException(s"$other is not supported.")
}
case _: First | _: Last | _: MaxMinBy =>
aggregateExpression.mode match {
case PartialMerge | Final =>
assert(
functionInputAttributes.size == 2,
s"${aggregateExpression.mode.toString} of " +
s"${aggregateFunction.getClass.toString} expects two input attributes.")
// Use a Velox function to combine the intermediate columns into struct.
val childNodes = functionInputAttributes.toList
.map(
ExpressionConverter
.replaceWithExpressionTransformer(_, originalInputAttributes)
.doTransform(args)
)
.asJava
exprNodes.add(getRowConstructNode(args, childNodes, functionInputAttributes))
case other =>
throw new UnsupportedOperationException(s"$other is not supported.")
}
case _: StddevSamp | _: StddevPop | _: VarianceSamp | _: VariancePop =>

case _: HyperLogLogPlusPlus if aggFunc.aggBufferAttributes.size != 1 =>
throw new UnsupportedOperationException("Only one input attribute is expected.")

case _ @VeloxIntermediateData.Type(veloxTypes: Seq[DataType]) =>
// The process of handling the inconsistency in column types and order between
// Spark and Velox is exactly the opposite of applyExtractStruct.
aggregateExpression.mode match {
case PartialMerge | Final =>
assert(
functionInputAttributes.size == 3,
s"${aggregateExpression.mode.toString} mode of" +
s"${aggregateFunction.getClass.toString} expects three input attributes."
)
// Use a Velox function to combine the intermediate columns into struct.
var index = 0
var newInputAttributes: Seq[Attribute] = Seq()
val childNodes = functionInputAttributes.toList.map {
attr =>
val aggExpr: ExpressionTransformer = ExpressionConverter
val newInputAttributes = new ArrayBuffer[Attribute]()
val childNodes = new JArrayList[ExpressionNode]()
val (sparkOrders, sparkTypes) =
aggFunc.aggBufferAttributes.map(attr => (attr.name, attr.dataType)).unzip
val veloxOrders = VeloxIntermediateData.veloxIntermediateDataOrder(aggFunc)
val adjustedOrders = veloxOrders.map(sparkOrders.indexOf(_))
veloxTypes.zipWithIndex.foreach {
case (veloxType, idx) =>
val sparkType = sparkTypes(adjustedOrders(idx))
val attr = functionInputAttributes(adjustedOrders(idx))
val aggFuncInputAttrNode = ExpressionConverter
.replaceWithExpressionTransformer(attr, originalInputAttributes)
val aggNode = aggExpr.doTransform(args)
val expressionNode = if (index == 0) {
// Cast count from DoubleType into LongType to align with Velox semantics.
newInputAttributes = newInputAttributes :+
attr.copy(attr.name, LongType, attr.nullable, attr.metadata)(
attr.exprId,
attr.qualifier)
.doTransform(args)
val expressionNode = if (sparkType != veloxType) {
newInputAttributes +=
attr.copy(dataType = veloxType)(attr.exprId, attr.qualifier)
ExpressionBuilder.makeCast(
ConverterUtils.getTypeNode(LongType, attr.nullable),
aggNode,
ConverterUtils.getTypeNode(veloxType, attr.nullable),
aggFuncInputAttrNode,
SQLConf.get.ansiEnabled)
} else {
newInputAttributes = newInputAttributes :+ attr
aggNode
newInputAttributes += attr
aggFuncInputAttrNode
}
index += 1
expressionNode
}.asJava
exprNodes.add(getRowConstructNode(args, childNodes, newInputAttributes))
case other =>
throw new UnsupportedOperationException(s"$other is not supported.")
}
case _: Corr =>
aggregateExpression.mode match {
case PartialMerge | Final =>
assert(
functionInputAttributes.size == 6,
s"${aggregateExpression.mode.toString} mode of Corr expects 6 input attributes.")
// Use a Velox function to combine the intermediate columns into struct.
var index = 0
var newInputAttributes: Seq[Attribute] = Seq()
val childNodes = new JArrayList[ExpressionNode]()
// Velox's Corr order is [ck, n, xMk, yMk, xAvg, yAvg]
// Spark's Corr order is [n, xAvg, yAvg, ck, xMk, yMk]
val sparkCorrOutputAttr = aggregateFunction.inputAggBufferAttributes.map(_.name)
val veloxInputOrder =
VeloxIntermediateData.veloxCorrIntermediateDataOrder.map(
name => sparkCorrOutputAttr.indexOf(name))
for (order <- veloxInputOrder) {
val attr = functionInputAttributes(order)
val aggExpr: ExpressionTransformer = ExpressionConverter
.replaceWithExpressionTransformer(attr, originalInputAttributes)
val aggNode = aggExpr.doTransform(args)
val expressionNode = if (order == 0) {
// Cast count from DoubleType into LongType to align with Velox semantics.
newInputAttributes = newInputAttributes :+
attr.copy(attr.name, LongType, attr.nullable, attr.metadata)(
attr.exprId,
attr.qualifier)
ExpressionBuilder.makeCast(
ConverterUtils.getTypeNode(LongType, attr.nullable),
aggNode,
SQLConf.get.ansiEnabled)
} else {
newInputAttributes = newInputAttributes :+ attr
aggNode
}
index += 1
childNodes.add(expressionNode)
childNodes.add(expressionNode)
}
exprNodes.add(getRowConstructNode(args, childNodes, newInputAttributes))
case other =>
throw new UnsupportedOperationException(s"$other is not supported.")
}
case _: CovPopulation | _: CovSample =>
aggregateExpression.mode match {
case PartialMerge | Final =>
assert(
functionInputAttributes.size == 4,
s"${aggregateExpression.mode.toString} mode of" +
s"${aggregateFunction.getClass.toString} expects 4 input attributes.")
// Use a Velox function to combine the intermediate columns into struct.
var index = 0
var newInputAttributes: Seq[Attribute] = Seq()
val childNodes = new JArrayList[ExpressionNode]()
// Velox's Covar order is [ck, n, xAvg, yAvg]
// Spark's Covar order is [n, xAvg, yAvg, ck]
val sparkCorrOutputAttr = aggregateFunction.inputAggBufferAttributes.map(_.name)
val veloxInputOrder =
VeloxIntermediateData.veloxCovarIntermediateDataOrder.map(
name => sparkCorrOutputAttr.indexOf(name))
for (order <- veloxInputOrder) {
val attr = functionInputAttributes(order)
val aggExpr: ExpressionTransformer = ExpressionConverter
.replaceWithExpressionTransformer(attr, originalInputAttributes)
val aggNode = aggExpr.doTransform(args)
val expressionNode = if (order == 0) {
// Cast count from DoubleType into LongType to align with Velox semantics.
newInputAttributes = newInputAttributes :+
attr.copy(attr.name, LongType, attr.nullable, attr.metadata)(
attr.exprId,
attr.qualifier)
ExpressionBuilder.makeCast(
ConverterUtils.getTypeNode(LongType, attr.nullable),
aggNode,
SQLConf.get.ansiEnabled)
} else {
newInputAttributes = newInputAttributes :+ attr
aggNode
}
index += 1
childNodes.add(expressionNode)
}
exprNodes.add(getRowConstructNode(args, childNodes, newInputAttributes))
case other =>
throw new UnsupportedOperationException(s"$other is not supported.")
}
case sum: Sum if sum.dataType.isInstanceOf[DecimalType] =>
aggregateExpression.mode match {
case PartialMerge | Final =>
assert(
functionInputAttributes.size == 2,
"Final stage of Average expects two input attributes.")
// Use a Velox function to combine the intermediate columns into struct.
val childNodes = functionInputAttributes.toList
.map(
ExpressionConverter
.replaceWithExpressionTransformer(_, originalInputAttributes)
.doTransform(args)
)
.asJava
exprNodes.add(
getRowConstructNode(args, childNodes, functionInputAttributes, withNull = false))
exprNodes.add(getRowConstructNode(args, childNodes, newInputAttributes, aggFunc))
case other =>
throw new UnsupportedOperationException(s"$other is not supported.")
}

case _ =>
if (functionInputAttributes.size != 1) {
throw new UnsupportedOperationException("Only one input attribute is expected.")
}
val childNodes = functionInputAttributes.toList
val childNodes = functionInputAttributes
.map(
ExpressionConverter
.replaceWithExpressionTransformer(_, originalInputAttributes)
Expand Down Expand Up @@ -602,11 +440,11 @@ case class HashAggregateExecTransformer(
// Create aggregation rel.
val groupingList = new JArrayList[ExpressionNode]()
var colIdx = 0
groupingExpressions.foreach(
_ => {
groupingExpressions.foreach {
_ =>
groupingList.add(ExpressionBuilder.makeSelection(colIdx))
colIdx += 1
})
}

val aggFilterList = new JArrayList[ExpressionNode]()
val aggregateFunctionList = new JArrayList[AggregateFunctionNode]()
Expand All @@ -619,40 +457,25 @@ case class HashAggregateExecTransformer(
aggFilterList.add(null)
}

val aggregateFunc = aggExpr.aggregateFunction
val aggFunc = aggExpr.aggregateFunction
val childrenNodes = new JArrayList[ExpressionNode]()
aggregateFunc match {
case _: Average | _: First | _: Last | _: StddevSamp | _: StddevPop | _: VarianceSamp |
_: VariancePop | _: Corr | _: CovPopulation | _: CovSample | _: MaxMinBy
if aggExpr.mode == PartialMerge | aggExpr.mode == Final =>
aggExpr.mode match {
case PartialMerge | Final =>
// Only occupies one column due to intermediate results are combined
// by previous projection.
childrenNodes.add(ExpressionBuilder.makeSelection(colIdx))
colIdx += 1
case sum: Sum
if sum.dataType.isInstanceOf[DecimalType] &&
(aggExpr.mode == PartialMerge | aggExpr.mode == Final) =>
childrenNodes.add(ExpressionBuilder.makeSelection(colIdx))
colIdx += 1
case _ if aggExpr.mode == PartialMerge | aggExpr.mode == Final =>
aggregateFunc.inputAggBufferAttributes.toList.map(
_ => {
childrenNodes.add(ExpressionBuilder.makeSelection(colIdx))
colIdx += 1
aggExpr
})
case _ if aggExpr.mode == Partial =>
aggregateFunc.children.toList.map(
_ => {
case Partial =>
aggFunc.children.foreach {
_ =>
childrenNodes.add(ExpressionBuilder.makeSelection(colIdx))
colIdx += 1
aggExpr
})
case function =>
}
case _ =>
throw new UnsupportedOperationException(
s"$function of ${aggExpr.mode.toString} is not supported.")
s"$aggFunc of ${aggExpr.mode.toString} is not supported.")
}
addFunctionNode(args, aggregateFunc, childrenNodes, aggExpr.mode, aggregateFunctionList)
addFunctionNode(args, aggFunc, childrenNodes, aggExpr.mode, aggregateFunctionList)
})
RelBuilder.makeAggregateRel(
projectRel,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,11 @@ object VeloxIntermediateData {
TypeBuilder.makeStruct(false, structTypeNodes.asJava)
}

def getRowConstructFuncName(aggFunc: AggregateFunction): String = aggFunc match {
case _: Average | _: Sum if aggFunc.dataType.isInstanceOf[DecimalType] => "row_constructor"
case _ => "row_constructor_with_null"
}

object Type {

/**
Expand Down

0 comments on commit 534a319

Please sign in to comment.