Skip to content

Commit

Permalink
Merge branch 'main' into debug
Browse files Browse the repository at this point in the history
  • Loading branch information
yaooqinn authored Nov 21, 2023
2 parents b43ba0f + f29077e commit 0bd934e
Show file tree
Hide file tree
Showing 15 changed files with 224 additions and 255 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,14 @@
*/
package io.glutenproject.execution

import io.glutenproject.execution.VeloxAggregateFunctionsBuilder._
import io.glutenproject.expression._
import io.glutenproject.expression.ConverterUtils.FunctionConfig
import io.glutenproject.substrait.`type`.{TypeBuilder, TypeNode}
import io.glutenproject.substrait.{AggregationParams, SubstraitContext}
import io.glutenproject.substrait.expression.{AggregateFunctionNode, ExpressionBuilder, ExpressionNode, ScalarFunctionNode}
import io.glutenproject.substrait.extensions.ExtensionBuilder
import io.glutenproject.substrait.rel.{RelBuilder, RelNode}
import io.glutenproject.utils.GlutenDecimalUtil
import io.glutenproject.utils.VeloxIntermediateData

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
Expand Down Expand Up @@ -80,27 +79,18 @@ case class HashAggregateExecTransformer(
* @return
* extracting needed or not.
*/
def extractStructNeeded(): Boolean = {
for (expr <- aggregateExpressions) {
val aggregateFunction = expr.aggregateFunction
aggregateFunction match {
case _: Average | _: First | _: Last | _: StddevSamp | _: StddevPop | _: VarianceSamp |
_: VariancePop | _: Corr | _: CovPopulation | _: CovSample | _: MaxMinBy =>
expr.mode match {
case Partial | PartialMerge =>
return true
case _ =>
}
case sum: Sum if sum.dataType.isInstanceOf[DecimalType] =>
expr.mode match {
case Partial | PartialMerge =>
return true
case _ =>
}
case _ =>
}
/**
 * Whether any partial output needs extracting from a Velox intermediate struct.
 * True when some aggregate expression in Partial/PartialMerge mode has a function
 * with more than one aggregation buffer attribute (Velox emits such intermediate
 * results as a single struct column).
 * @return
 *   extracting needed or not.
 */
private def extractStructNeeded(): Boolean = {
  // The `exists` result IS the return value. (The previous text had a stray
  // trailing `false` after this expression, which discarded the `exists` result
  // and made the method constantly return false.)
  aggregateExpressions.exists {
    expr =>
      expr.aggregateFunction match {
        // Multi-attribute buffers are produced by Velox as one struct column,
        // so struct extraction is required for partial phases.
        case aggFunc if aggFunc.aggBufferAttributes.size > 1 =>
          expr.mode match {
            case Partial | PartialMerge => true
            case _ => false
          }
        case _ => false
      }
  }
}

/**
Expand Down Expand Up @@ -133,56 +123,29 @@ case class HashAggregateExecTransformer(
case _ =>
throw new UnsupportedOperationException(s"${expr.mode} not supported.")
}
val aggFunc = expr.aggregateFunction
expr.aggregateFunction match {
case _: Average | _: First | _: Last | _: MaxMinBy =>
// Select first and second aggregate buffer from Velox Struct.
expressionNodes.add(ExpressionBuilder.makeSelection(colIdx, 0))
expressionNodes.add(ExpressionBuilder.makeSelection(colIdx, 1))
colIdx += 1
case _: StddevSamp | _: StddevPop | _: VarianceSamp | _: VariancePop =>
// Select count from Velox struct with count casted from LongType into DoubleType.
expressionNodes.add(
ExpressionBuilder
.makeCast(
ConverterUtils.getTypeNode(DoubleType, nullable = false),
ExpressionBuilder.makeSelection(colIdx, 0),
SQLConf.get.ansiEnabled))
// Select avg from Velox Struct.
expressionNodes.add(ExpressionBuilder.makeSelection(colIdx, 1))
// Select m2 from Velox Struct.
expressionNodes.add(ExpressionBuilder.makeSelection(colIdx, 2))
colIdx += 1
case sum: Sum if sum.dataType.isInstanceOf[DecimalType] =>
// Select sum from Velox Struct.
expressionNodes.add(ExpressionBuilder.makeSelection(colIdx, 0))
// Select isEmpty from Velox Struct.
expressionNodes.add(ExpressionBuilder.makeSelection(colIdx, 1))
colIdx += 1
case _: Corr =>
// Select count from Velox struct with count casted from LongType into DoubleType.
expressionNodes.add(
ExpressionBuilder
.makeCast(
ConverterUtils.getTypeNode(DoubleType, nullable = false),
ExpressionBuilder.makeSelection(colIdx, 1),
SQLConf.get.ansiEnabled))
expressionNodes.add(ExpressionBuilder.makeSelection(colIdx, 4))
expressionNodes.add(ExpressionBuilder.makeSelection(colIdx, 5))
expressionNodes.add(ExpressionBuilder.makeSelection(colIdx, 0))
expressionNodes.add(ExpressionBuilder.makeSelection(colIdx, 2))
expressionNodes.add(ExpressionBuilder.makeSelection(colIdx, 3))
colIdx += 1
case _: CovPopulation | _: CovSample =>
// Select count from Velox struct with count casted from LongType into DoubleType.
expressionNodes.add(
ExpressionBuilder
.makeCast(
ConverterUtils.getTypeNode(DoubleType, nullable = false),
ExpressionBuilder.makeSelection(colIdx, 1),
SQLConf.get.ansiEnabled))
expressionNodes.add(ExpressionBuilder.makeSelection(colIdx, 2))
expressionNodes.add(ExpressionBuilder.makeSelection(colIdx, 3))
expressionNodes.add(ExpressionBuilder.makeSelection(colIdx, 0))
case _ @VeloxIntermediateData.Type(veloxTypes: Seq[DataType]) =>
val (sparkOrders, sparkTypes) =
aggFunc.aggBufferAttributes.map(attr => (attr.name, attr.dataType)).unzip
val veloxOrders = VeloxIntermediateData.veloxIntermediateDataOrder(aggFunc)
val adjustedOrders = sparkOrders.map(veloxOrders.indexOf(_))
sparkTypes.zipWithIndex.foreach {
case (sparkType, idx) =>
val veloxType = veloxTypes(adjustedOrders(idx))
if (veloxType != sparkType) {
// Velox and Spark have different type, adding a cast expression
expressionNodes.add(
ExpressionBuilder
.makeCast(
ConverterUtils.getTypeNode(sparkType, nullable = false),
ExpressionBuilder.makeSelection(colIdx, adjustedOrders(idx)),
SQLConf.get.ansiEnabled))
} else {
// Velox and Spark have the same type
expressionNodes.add(ExpressionBuilder.makeSelection(colIdx, adjustedOrders(idx)))
}
}
colIdx += 1
case _ =>
expressionNodes.add(ExpressionBuilder.makeSelection(colIdx))
Expand All @@ -209,43 +172,6 @@ case class HashAggregateExecTransformer(
}
}

/**
 * Builds the Velox-side intermediate type of a partial aggregation, expressed as a
 * non-nullable struct over the function's intermediate field types.
 * @param aggregateFunction
 *   The aggregation function.
 * @return
 *   The struct type node describing the partial outputs.
 */
private def getIntermediateTypeNode(aggregateFunction: AggregateFunction): TypeNode = {
  val fieldTypeNodes: Seq[TypeNode] = aggregateFunction match {
    case avg: Average =>
      Seq(
        ConverterUtils.getTypeNode(GlutenDecimalUtil.getAvgSumDataType(avg), nullable = true),
        ConverterUtils.getTypeNode(LongType, nullable = true))
    case first: First =>
      Seq(
        ConverterUtils.getTypeNode(first.dataType, nullable = true),
        ConverterUtils.getTypeNode(BooleanType, nullable = true))
    case last: Last =>
      Seq(
        ConverterUtils.getTypeNode(last.dataType, nullable = true),
        ConverterUtils.getTypeNode(BooleanType, nullable = true))
    case maxMinBy: MaxMinBy =>
      Seq(
        ConverterUtils.getTypeNode(maxMinBy.valueExpr.dataType, nullable = true),
        ConverterUtils.getTypeNode(maxMinBy.orderingExpr.dataType, nullable = true))
    case _: StddevSamp | _: StddevPop | _: VarianceSamp | _: VariancePop =>
      // Velox models the variance buffer as Row(BIGINT, DOUBLE, DOUBLE).
      veloxVarianceIntermediateTypes.map(ConverterUtils.getTypeNode(_, nullable = false))
    case _: Corr =>
      veloxCorrIntermediateTypes.map(ConverterUtils.getTypeNode(_, nullable = false))
    case _: CovPopulation | _: CovSample =>
      veloxCovarIntermediateTypes.map(ConverterUtils.getTypeNode(_, nullable = false))
    case sum: Sum if sum.dataType.isInstanceOf[DecimalType] =>
      Seq(
        ConverterUtils.getTypeNode(sum.dataType, nullable = true),
        ConverterUtils.getTypeNode(BooleanType, nullable = false))
    case other =>
      throw new UnsupportedOperationException(s"$other is not supported.")
  }
  TypeBuilder.makeStruct(false, fieldTypeNodes.asJava)
}

/**
 * Maps an aggregation mode to its Substrait keyword. When `mixedPartialAndMerge` is set,
 * the mode is forced to Partial regardless of the supplied one.
 */
override protected def modeToKeyWord(aggregateMode: AggregateMode): String = {
  val effectiveMode = if (mixedPartialAndMerge) Partial else aggregateMode
  super.modeToKeyWord(effectiveMode)
}
Expand All @@ -268,15 +194,16 @@ case class HashAggregateExecTransformer(
VeloxAggregateFunctionsBuilder.create(args, aggregateFunction),
childrenNodeList,
modeKeyWord,
getIntermediateTypeNode(aggregateFunction))
VeloxIntermediateData.getIntermediateTypeNode(aggregateFunction)
)
aggregateNodeList.add(partialNode)
case PartialMerge =>
val aggFunctionNode = ExpressionBuilder.makeAggregateFunction(
VeloxAggregateFunctionsBuilder
.create(args, aggregateFunction, mixedPartialAndMerge),
childrenNodeList,
modeKeyWord,
getIntermediateTypeNode(aggregateFunction)
VeloxIntermediateData.getIntermediateTypeNode(aggregateFunction)
)
aggregateNodeList.add(aggFunctionNode)
case Final =>
Expand Down Expand Up @@ -356,7 +283,7 @@ case class HashAggregateExecTransformer(
_: VariancePop | _: Corr | _: CovPopulation | _: CovSample | _: MaxMinBy =>
expression.mode match {
case Partial | PartialMerge =>
typeNodeList.add(getIntermediateTypeNode(aggregateFunction))
typeNodeList.add(VeloxIntermediateData.getIntermediateTypeNode(aggregateFunction))
case Final =>
typeNodeList.add(
ConverterUtils
Expand All @@ -367,7 +294,7 @@ case class HashAggregateExecTransformer(
case sum: Sum if sum.dataType.isInstanceOf[DecimalType] =>
expression.mode match {
case Partial | PartialMerge =>
typeNodeList.add(getIntermediateTypeNode(aggregateFunction))
typeNodeList.add(VeloxIntermediateData.getIntermediateTypeNode(aggregateFunction))
case Final =>
typeNodeList.add(
ConverterUtils
Expand Down Expand Up @@ -547,7 +474,7 @@ case class HashAggregateExecTransformer(
// Spark's Corr order is [n, xAvg, yAvg, ck, xMk, yMk]
val sparkCorrOutputAttr = aggregateFunction.inputAggBufferAttributes.map(_.name)
val veloxInputOrder =
VeloxAggregateFunctionsBuilder.veloxCorrIntermediateDataOrder.map(
VeloxIntermediateData.veloxCorrIntermediateDataOrder.map(
name => sparkCorrOutputAttr.indexOf(name))
for (order <- veloxInputOrder) {
val attr = functionInputAttributes(order)
Expand Down Expand Up @@ -590,7 +517,7 @@ case class HashAggregateExecTransformer(
// Spark's Covar order is [n, xAvg, yAvg, ck]
val sparkCorrOutputAttr = aggregateFunction.inputAggBufferAttributes.map(_.name)
val veloxInputOrder =
VeloxAggregateFunctionsBuilder.veloxCovarIntermediateDataOrder.map(
VeloxIntermediateData.veloxCovarIntermediateDataOrder.map(
name => sparkCorrOutputAttr.indexOf(name))
for (order <- veloxInputOrder) {
val attr = functionInputAttributes(order)
Expand Down Expand Up @@ -810,47 +737,6 @@ case class HashAggregateExecTransformer(
/** An aggregation function builder specifically used by Velox backend. */
object VeloxAggregateFunctionsBuilder {

// Field order of Velox's corr intermediate struct; used to remap Spark's buffer
// attributes, which arrive as [n, xAvg, yAvg, ck, xMk, yMk], onto Velox's layout.
val veloxCorrIntermediateDataOrder: Seq[String] = Seq("ck", "n", "xMk", "yMk", "xAvg", "yAvg")
// Field order of Velox's covar_pop/covar_samp intermediate struct; Spark's buffer
// attribute order is [n, xAvg, yAvg, ck].
val veloxCovarIntermediateDataOrder: Seq[String] = Seq("ck", "n", "xAvg", "yAvg")

// Field types of the Velox intermediate structs. Variance/stddev is Row(BIGINT, DOUBLE,
// DOUBLE); covar and corr add the extra average/moment columns of their layouts above.
val veloxVarianceIntermediateTypes: Seq[DataType] = Seq(LongType, DoubleType, DoubleType)
val veloxCovarIntermediateTypes: Seq[DataType] = Seq(DoubleType, LongType, DoubleType, DoubleType)
val veloxCorrIntermediateTypes: Seq[DataType] =
  Seq(DoubleType, LongType, DoubleType, DoubleType, DoubleType, DoubleType)

/**
 * Get the compatible input types for a Velox aggregate function.
 * @param aggregateFunc:
 *   the input aggregate function.
 * @param forMergeCompanion:
 *   whether this is a special case to solve mixed aggregation phases.
 * @return
 *   the input types of a Velox aggregate function.
 */
private def getInputTypes(
    aggregateFunc: AggregateFunction,
    forMergeCompanion: Boolean): Seq[DataType] = {
  if (!forMergeCompanion) {
    // Ordinary case: inputs are simply the function's child expression types.
    aggregateFunc.children.map(_.dataType)
  } else {
    // Merge-companion case: the input is the intermediate buffer, which Velox expects
    // as a single struct column over the intermediate field types.
    aggregateFunc match {
      case _: StddevSamp | _: StddevPop | _: VarianceSamp | _: VariancePop =>
        Seq(StructType(veloxVarianceIntermediateTypes.map(StructField("", _)).toArray))
      case _: CovPopulation | _: CovSample =>
        Seq(StructType(veloxCovarIntermediateTypes.map(StructField("", _)).toArray))
      case _: Corr =>
        Seq(StructType(veloxCorrIntermediateTypes.map(StructField("", _)).toArray))
      case aggFunc if aggFunc.aggBufferAttributes.size > 1 =>
        // Generic multi-attribute buffer: wrap the Spark buffer field types in a struct.
        Seq(
          StructType(
            aggregateFunc.aggBufferAttributes
              .map(attribute => StructField("", attribute.dataType))
              .toArray))
      case _ =>
        // Single-attribute buffer: pass the buffer attribute types through directly.
        aggregateFunc.aggBufferAttributes.map(_.dataType)
    }
  }
}

/**
* Create an scalar function for the input aggregate function.
* @param args:
Expand Down Expand Up @@ -887,7 +773,7 @@ object VeloxAggregateFunctionsBuilder {
functionMap,
ConverterUtils.makeFuncName(
substraitAggFuncName,
getInputTypes(aggregateFunc, forMergeCompanion),
VeloxIntermediateData.getInputTypes(aggregateFunc, forMergeCompanion),
FunctionConfig.REQ))
}
}
Loading

0 comments on commit 0bd934e

Please sign in to comment.