diff --git a/backends-velox/src/main/scala/io/glutenproject/execution/HashAggregateExecTransformer.scala b/backends-velox/src/main/scala/io/glutenproject/execution/HashAggregateExecTransformer.scala index 2d12eae0d41ff..bd7b601f23ffe 100644 --- a/backends-velox/src/main/scala/io/glutenproject/execution/HashAggregateExecTransformer.scala +++ b/backends-velox/src/main/scala/io/glutenproject/execution/HashAggregateExecTransformer.scala @@ -37,6 +37,7 @@ import java.lang.{Long => JLong} import java.util.{ArrayList => JArrayList, HashMap => JHashMap, List => JList} import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer case class HashAggregateExecTransformer( requiredChildDistributionExpressions: Option[Seq[Expression]], @@ -124,7 +125,7 @@ case class HashAggregateExecTransformer( throw new UnsupportedOperationException(s"${expr.mode} not supported.") } val aggFunc = expr.aggregateFunction - expr.aggregateFunction match { + aggFunc match { case _ @VeloxIntermediateData.Type(veloxTypes: Seq[DataType]) => val (sparkOrders, sparkTypes) = aggFunc.aggBufferAttributes.map(attr => (attr.name, attr.dataType)).unzip @@ -245,10 +246,7 @@ case class HashAggregateExecTransformer( case other => throw new UnsupportedOperationException(s"$other is not supported.") } - case sum: Sum if sum.dataType.isInstanceOf[DecimalType] => - generateMergeCompanionNode() - case _: Average | _: StddevSamp | _: StddevPop | _: VarianceSamp | _: VariancePop | _: Corr | - _: CovPopulation | _: CovSample | _: First | _: Last | _: MaxMinBy => + case _ if aggregateFunction.aggBufferAttributes.size > 1 => generateMergeCompanionNode() case _ => val aggFunctionNode = ExpressionBuilder.makeAggregateFunction( @@ -279,19 +277,7 @@ case class HashAggregateExecTransformer( expression => { val aggregateFunction = expression.aggregateFunction aggregateFunction match { - case _: Average | _: First | _: Last | _: StddevSamp | _: StddevPop | _: VarianceSamp | - _: VariancePop | _: Corr | _: CovPopulation | _: CovSample | _: MaxMinBy => - expression.mode match { - case Partial | PartialMerge => - typeNodeList.add(VeloxIntermediateData.getIntermediateTypeNode(aggregateFunction)) - case Final => - typeNodeList.add( - ConverterUtils - .getTypeNode(aggregateFunction.dataType, aggregateFunction.nullable)) - case other => - throw new UnsupportedOperationException(s"$other is not supported.") - } - case sum: Sum if sum.dataType.isInstanceOf[DecimalType] => + case _ if aggregateFunction.aggBufferAttributes.size > 1 => expression.mode match { case Partial | PartialMerge => typeNodeList.add(VeloxIntermediateData.getIntermediateTypeNode(aggregateFunction)) @@ -328,10 +314,10 @@ case class HashAggregateExecTransformer( args: java.lang.Object, childNodes: JList[ExpressionNode], rowConstructAttributes: Seq[Attribute], - withNull: Boolean = true): ScalarFunctionNode = { + aggFunc: AggregateFunction): ScalarFunctionNode = { val functionMap = args.asInstanceOf[JHashMap[String, JLong]] val functionName = ConverterUtils.makeFuncName( - if (withNull) "row_constructor_with_null" else "row_constructor", + VeloxIntermediateData.getRowConstructFuncName(aggFunc), rowConstructAttributes.map(attr => attr.dataType)) val functionId = ExpressionBuilder.newScalarFunction(functionMap, functionName) @@ -367,208 +353,60 @@ case class HashAggregateExecTransformer( }) for (aggregateExpression <- aggregateExpressions) { - val functionInputAttributes = aggregateExpression.aggregateFunction.inputAggBufferAttributes - val aggregateFunction = aggregateExpression.aggregateFunction - aggregateFunction match { + val aggFunc = aggregateExpression.aggregateFunction + val functionInputAttributes = aggFunc.inputAggBufferAttributes + aggFunc match { case _ if mixedPartialAndMerge && aggregateExpression.mode == Partial => - val childNodes = new JArrayList[ExpressionNode]( - aggregateFunction.children - .map( - attr => { - ExpressionConverter - .replaceWithExpressionTransformer(attr, originalInputAttributes) - .doTransform(args) - }) - .asJava) + val childNodes = aggFunc.children + .map( + ExpressionConverter + .replaceWithExpressionTransformer(_, originalInputAttributes) + .doTransform(args) + ) + .asJava exprNodes.addAll(childNodes) - case avg: Average => - aggregateExpression.mode match { - case PartialMerge | Final => - assert( - functionInputAttributes.size == 2, - s"${aggregateExpression.mode.toString} of Average expects two input attributes.") - // Use a Velox function to combine the intermediate columns into struct. - val childNodes = - functionInputAttributes.toList - .map( - ExpressionConverter - .replaceWithExpressionTransformer(_, originalInputAttributes) - .doTransform(args)) - .asJava - exprNodes.add( - getRowConstructNode( - args, - childNodes, - functionInputAttributes, - withNull = !avg.dataType.isInstanceOf[DecimalType])) - case other => - throw new UnsupportedOperationException(s"$other is not supported.") - } - case _: First | _: Last | _: MaxMinBy => - aggregateExpression.mode match { - case PartialMerge | Final => - assert( - functionInputAttributes.size == 2, - s"${aggregateExpression.mode.toString} of " + - s"${aggregateFunction.getClass.toString} expects two input attributes.") - // Use a Velox function to combine the intermediate columns into struct. - val childNodes = functionInputAttributes.toList - .map( - ExpressionConverter - .replaceWithExpressionTransformer(_, originalInputAttributes) - .doTransform(args) - ) - .asJava - exprNodes.add(getRowConstructNode(args, childNodes, functionInputAttributes)) - case other => - throw new UnsupportedOperationException(s"$other is not supported.") - } - case _: StddevSamp | _: StddevPop | _: VarianceSamp | _: VariancePop => + + case _: HyperLogLogPlusPlus if aggFunc.aggBufferAttributes.size != 1 => + throw new UnsupportedOperationException("Only one input attribute is expected.") + + case _ @VeloxIntermediateData.Type(veloxTypes: Seq[DataType]) => + // The process of handling the inconsistency in column types and order between + // Spark and Velox is exactly the opposite of applyExtractStruct. aggregateExpression.mode match { case PartialMerge | Final => - assert( - functionInputAttributes.size == 3, - s"${aggregateExpression.mode.toString} mode of" + - s"${aggregateFunction.getClass.toString} expects three input attributes." - ) - // Use a Velox function to combine the intermediate columns into struct. - var index = 0 - var newInputAttributes: Seq[Attribute] = Seq() - val childNodes = functionInputAttributes.toList.map { - attr => - val aggExpr: ExpressionTransformer = ExpressionConverter + val newInputAttributes = new ArrayBuffer[Attribute]() + val childNodes = new JArrayList[ExpressionNode]() + val (sparkOrders, sparkTypes) = + aggFunc.aggBufferAttributes.map(attr => (attr.name, attr.dataType)).unzip + val veloxOrders = VeloxIntermediateData.veloxIntermediateDataOrder(aggFunc) + val adjustedOrders = veloxOrders.map(sparkOrders.indexOf(_)) + veloxTypes.zipWithIndex.foreach { + case (veloxType, idx) => + val sparkType = sparkTypes(adjustedOrders(idx)) + val attr = functionInputAttributes(adjustedOrders(idx)) + val aggFuncInputAttrNode = ExpressionConverter .replaceWithExpressionTransformer(attr, originalInputAttributes) - val aggNode = aggExpr.doTransform(args) - val expressionNode = if (index == 0) { - // Cast count from DoubleType into LongType to align with Velox semantics. - newInputAttributes = newInputAttributes :+ - attr.copy(attr.name, LongType, attr.nullable, attr.metadata)( - attr.exprId, - attr.qualifier) + .doTransform(args) + val expressionNode = if (sparkType != veloxType) { + newInputAttributes += + attr.copy(dataType = veloxType)(attr.exprId, attr.qualifier) ExpressionBuilder.makeCast( - ConverterUtils.getTypeNode(LongType, attr.nullable), - aggNode, + ConverterUtils.getTypeNode(veloxType, attr.nullable), + aggFuncInputAttrNode, SQLConf.get.ansiEnabled) } else { - newInputAttributes = newInputAttributes :+ attr - aggNode + newInputAttributes += attr + aggFuncInputAttrNode } - index += 1 - expressionNode - }.asJava - exprNodes.add(getRowConstructNode(args, childNodes, newInputAttributes)) - case other => - throw new UnsupportedOperationException(s"$other is not supported.") - } - case _: Corr => - aggregateExpression.mode match { - case PartialMerge | Final => - assert( - functionInputAttributes.size == 6, - s"${aggregateExpression.mode.toString} mode of Corr expects 6 input attributes.") - // Use a Velox function to combine the intermediate columns into struct. - var index = 0 - var newInputAttributes: Seq[Attribute] = Seq() - val childNodes = new JArrayList[ExpressionNode]() - // Velox's Corr order is [ck, n, xMk, yMk, xAvg, yAvg] - // Spark's Corr order is [n, xAvg, yAvg, ck, xMk, yMk] - val sparkCorrOutputAttr = aggregateFunction.inputAggBufferAttributes.map(_.name) - val veloxInputOrder = - VeloxIntermediateData.veloxCorrIntermediateDataOrder.map( - name => sparkCorrOutputAttr.indexOf(name)) - for (order <- veloxInputOrder) { - val attr = functionInputAttributes(order) - val aggExpr: ExpressionTransformer = ExpressionConverter - .replaceWithExpressionTransformer(attr, originalInputAttributes) - val aggNode = aggExpr.doTransform(args) - val expressionNode = if (order == 0) { - // Cast count from DoubleType into LongType to align with Velox semantics. - newInputAttributes = newInputAttributes :+ - attr.copy(attr.name, LongType, attr.nullable, attr.metadata)( - attr.exprId, - attr.qualifier) - ExpressionBuilder.makeCast( - ConverterUtils.getTypeNode(LongType, attr.nullable), - aggNode, - SQLConf.get.ansiEnabled) - } else { - newInputAttributes = newInputAttributes :+ attr - aggNode - } - index += 1 - childNodes.add(expressionNode) + childNodes.add(expressionNode) } - exprNodes.add(getRowConstructNode(args, childNodes, newInputAttributes)) - case other => - throw new UnsupportedOperationException(s"$other is not supported.") - } - case _: CovPopulation | _: CovSample => - aggregateExpression.mode match { - case PartialMerge | Final => - assert( - functionInputAttributes.size == 4, - s"${aggregateExpression.mode.toString} mode of" + - s"${aggregateFunction.getClass.toString} expects 4 input attributes.") - // Use a Velox function to combine the intermediate columns into struct. - var index = 0 - var newInputAttributes: Seq[Attribute] = Seq() - val childNodes = new JArrayList[ExpressionNode]() - // Velox's Covar order is [ck, n, xAvg, yAvg] - // Spark's Covar order is [n, xAvg, yAvg, ck] - val sparkCorrOutputAttr = aggregateFunction.inputAggBufferAttributes.map(_.name) - val veloxInputOrder = - VeloxIntermediateData.veloxCovarIntermediateDataOrder.map( - name => sparkCorrOutputAttr.indexOf(name)) - for (order <- veloxInputOrder) { - val attr = functionInputAttributes(order) - val aggExpr: ExpressionTransformer = ExpressionConverter - .replaceWithExpressionTransformer(attr, originalInputAttributes) - val aggNode = aggExpr.doTransform(args) - val expressionNode = if (order == 0) { - // Cast count from DoubleType into LongType to align with Velox semantics. - newInputAttributes = newInputAttributes :+ - attr.copy(attr.name, LongType, attr.nullable, attr.metadata)( - attr.exprId, - attr.qualifier) - ExpressionBuilder.makeCast( - ConverterUtils.getTypeNode(LongType, attr.nullable), - aggNode, - SQLConf.get.ansiEnabled) - } else { - newInputAttributes = newInputAttributes :+ attr - aggNode - } - index += 1 - childNodes.add(expressionNode) - } - exprNodes.add(getRowConstructNode(args, childNodes, newInputAttributes)) - case other => - throw new UnsupportedOperationException(s"$other is not supported.") - } - case sum: Sum if sum.dataType.isInstanceOf[DecimalType] => - aggregateExpression.mode match { - case PartialMerge | Final => - assert( - functionInputAttributes.size == 2, - "Final stage of Average expects two input attributes.") - // Use a Velox function to combine the intermediate columns into struct. - val childNodes = functionInputAttributes.toList - .map( - ExpressionConverter - .replaceWithExpressionTransformer(_, originalInputAttributes) - .doTransform(args) - ) - .asJava - exprNodes.add( - getRowConstructNode(args, childNodes, functionInputAttributes, withNull = false)) + exprNodes.add(getRowConstructNode(args, childNodes, newInputAttributes, aggFunc)) case other => throw new UnsupportedOperationException(s"$other is not supported.") } + case _ => - if (functionInputAttributes.size != 1) { - throw new UnsupportedOperationException("Only one input attribute is expected.") - } - val childNodes = functionInputAttributes.toList + val childNodes = functionInputAttributes .map( ExpressionConverter .replaceWithExpressionTransformer(_, originalInputAttributes) @@ -602,11 +440,11 @@ case class HashAggregateExecTransformer( // Create aggregation rel. val groupingList = new JArrayList[ExpressionNode]() var colIdx = 0 - groupingExpressions.foreach( - _ => { + groupingExpressions.foreach { + _ => groupingList.add(ExpressionBuilder.makeSelection(colIdx)) colIdx += 1 - }) + } val aggFilterList = new JArrayList[ExpressionNode]() val aggregateFunctionList = new JArrayList[AggregateFunctionNode]() @@ -619,40 +457,25 @@ case class HashAggregateExecTransformer( aggFilterList.add(null) } - val aggregateFunc = aggExpr.aggregateFunction + val aggFunc = aggExpr.aggregateFunction val childrenNodes = new JArrayList[ExpressionNode]() - aggregateFunc match { - case _: Average | _: First | _: Last | _: StddevSamp | _: StddevPop | _: VarianceSamp | - _: VariancePop | _: Corr | _: CovPopulation | _: CovSample | _: MaxMinBy - if aggExpr.mode == PartialMerge | aggExpr.mode == Final => + aggExpr.mode match { + case PartialMerge | Final => // Only occupies one column due to intermediate results are combined // by previous projection. childrenNodes.add(ExpressionBuilder.makeSelection(colIdx)) colIdx += 1 - case sum: Sum - if sum.dataType.isInstanceOf[DecimalType] && - (aggExpr.mode == PartialMerge | aggExpr.mode == Final) => - childrenNodes.add(ExpressionBuilder.makeSelection(colIdx)) - colIdx += 1 - case _ if aggExpr.mode == PartialMerge | aggExpr.mode == Final => - aggregateFunc.inputAggBufferAttributes.toList.map( - _ => { - childrenNodes.add(ExpressionBuilder.makeSelection(colIdx)) - colIdx += 1 - aggExpr - }) - case _ if aggExpr.mode == Partial => - aggregateFunc.children.toList.map( - _ => { + case Partial => + aggFunc.children.foreach { + _ => childrenNodes.add(ExpressionBuilder.makeSelection(colIdx)) colIdx += 1 - aggExpr - }) - case function => + } + case _ => throw new UnsupportedOperationException( - s"$function of ${aggExpr.mode.toString} is not supported.") + s"$aggFunc of ${aggExpr.mode.toString} is not supported.") } - addFunctionNode(args, aggregateFunc, childrenNodes, aggExpr.mode, aggregateFunctionList) + addFunctionNode(args, aggFunc, childrenNodes, aggExpr.mode, aggregateFunctionList) }) RelBuilder.makeAggregateRel( projectRel, diff --git a/backends-velox/src/main/scala/io/glutenproject/utils/VeloxIntermediateData.scala b/backends-velox/src/main/scala/io/glutenproject/utils/VeloxIntermediateData.scala index 7c4d5ecc00a7d..03523a6c43b0f 100644 --- a/backends-velox/src/main/scala/io/glutenproject/utils/VeloxIntermediateData.scala +++ b/backends-velox/src/main/scala/io/glutenproject/utils/VeloxIntermediateData.scala @@ -102,6 +102,11 @@ object VeloxIntermediateData { TypeBuilder.makeStruct(false, structTypeNodes.asJava) } + def getRowConstructFuncName(aggFunc: AggregateFunction): String = aggFunc match { + case _: Average | _: Sum if aggFunc.dataType.isInstanceOf[DecimalType] => "row_constructor" + case _ => "row_constructor_with_null" + } + object Type { /**