Skip to content

Commit

Permalink
[VL] Declare IntermediateTypes for specific agg function (#3679)
Browse files Browse the repository at this point in the history
  • Loading branch information
liujiayi771 authored Nov 13, 2023
1 parent 70061ac commit 7cf715b
Showing 1 changed file with 35 additions and 88 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
*/
package io.glutenproject.execution

import io.glutenproject.execution.VeloxAggregateFunctionsBuilder.{veloxFourIntermediateTypes, veloxSixIntermediateTypes, veloxThreeIntermediateTypes}
import io.glutenproject.execution.VeloxAggregateFunctionsBuilder._
import io.glutenproject.expression._
import io.glutenproject.expression.ConverterUtils.FunctionConfig
import io.glutenproject.substrait.`type`.{TypeBuilder, TypeNode}
Expand Down Expand Up @@ -217,73 +217,33 @@ case class HashAggregateExecTransformer(
* The type of partial outputs.
*/
private def getIntermediateTypeNode(aggregateFunction: AggregateFunction): TypeNode = {
val structTypeNodes = new JArrayList[TypeNode]()
aggregateFunction match {
val structTypeNodes = aggregateFunction match {
case avg: Average =>
structTypeNodes.add(
ConverterUtils.getTypeNode(GlutenDecimalUtil.getAvgSumDataType(avg), nullable = true))
structTypeNodes.add(ConverterUtils.getTypeNode(LongType, nullable = true))
ConverterUtils.getTypeNode(GlutenDecimalUtil.getAvgSumDataType(avg), nullable = true) ::
ConverterUtils.getTypeNode(LongType, nullable = true) :: Nil
case first: First =>
structTypeNodes.add(ConverterUtils.getTypeNode(first.dataType, nullable = true))
structTypeNodes.add(ConverterUtils.getTypeNode(BooleanType, nullable = true))
ConverterUtils.getTypeNode(first.dataType, nullable = true) ::
ConverterUtils.getTypeNode(BooleanType, nullable = true) :: Nil
case last: Last =>
structTypeNodes.add(ConverterUtils.getTypeNode(last.dataType, nullable = true))
structTypeNodes.add(ConverterUtils.getTypeNode(BooleanType, nullable = true))
ConverterUtils.getTypeNode(last.dataType, nullable = true) ::
ConverterUtils.getTypeNode(BooleanType, nullable = true) :: Nil
case maxMinBy: MaxMinBy =>
structTypeNodes
.add(ConverterUtils.getTypeNode(maxMinBy.valueExpr.dataType, nullable = true))
structTypeNodes
.add(ConverterUtils.getTypeNode(maxMinBy.orderingExpr.dataType, nullable = true))
ConverterUtils.getTypeNode(maxMinBy.valueExpr.dataType, nullable = true) ::
ConverterUtils.getTypeNode(maxMinBy.orderingExpr.dataType, nullable = true) :: Nil
case _: StddevSamp | _: StddevPop | _: VarianceSamp | _: VariancePop =>
// Use struct type to represent Velox Row(BIGINT, DOUBLE, DOUBLE).
structTypeNodes.add(
ConverterUtils
.getTypeNode(veloxThreeIntermediateTypes.head, nullable = false))
structTypeNodes.add(
ConverterUtils
.getTypeNode(veloxThreeIntermediateTypes(1), nullable = false))
structTypeNodes.add(
ConverterUtils
.getTypeNode(veloxThreeIntermediateTypes(2), nullable = false))
veloxVarianceIntermediateTypes.map(ConverterUtils.getTypeNode(_, nullable = false))
case _: Corr =>
structTypeNodes.add(
ConverterUtils
.getTypeNode(veloxSixIntermediateTypes.head, nullable = false))
structTypeNodes.add(
ConverterUtils
.getTypeNode(veloxSixIntermediateTypes(1), nullable = false))
structTypeNodes.add(
ConverterUtils
.getTypeNode(veloxSixIntermediateTypes(2), nullable = false))
structTypeNodes.add(
ConverterUtils
.getTypeNode(veloxSixIntermediateTypes(3), nullable = false))
structTypeNodes.add(
ConverterUtils
.getTypeNode(veloxSixIntermediateTypes(4), nullable = false))
structTypeNodes.add(
ConverterUtils
.getTypeNode(veloxSixIntermediateTypes(5), nullable = false))
veloxCorrIntermediateTypes.map(ConverterUtils.getTypeNode(_, nullable = false))
case _: CovPopulation | _: CovSample =>
structTypeNodes.add(
ConverterUtils
.getTypeNode(veloxFourIntermediateTypes.head, nullable = false))
structTypeNodes.add(
ConverterUtils
.getTypeNode(veloxFourIntermediateTypes(1), nullable = false))
structTypeNodes.add(
ConverterUtils
.getTypeNode(veloxFourIntermediateTypes(2), nullable = false))
structTypeNodes.add(
ConverterUtils
.getTypeNode(veloxFourIntermediateTypes(3), nullable = false))
veloxCovarIntermediateTypes.map(ConverterUtils.getTypeNode(_, nullable = false))
case sum: Sum if sum.dataType.isInstanceOf[DecimalType] =>
structTypeNodes.add(ConverterUtils.getTypeNode(sum.dataType, nullable = true))
structTypeNodes.add(ConverterUtils.getTypeNode(BooleanType, nullable = false))
ConverterUtils.getTypeNode(sum.dataType, nullable = true) ::
ConverterUtils.getTypeNode(BooleanType, nullable = false) :: Nil
case other =>
throw new UnsupportedOperationException(s"$other is not supported.")
}
TypeBuilder.makeStruct(false, structTypeNodes)
TypeBuilder.makeStruct(false, structTypeNodes.asJava)
}

override protected def modeToKeyWord(aggregateMode: AggregateMode): String = {
Expand Down Expand Up @@ -853,9 +813,9 @@ object VeloxAggregateFunctionsBuilder {
val veloxCorrIntermediateDataOrder: Seq[String] = Seq("ck", "n", "xMk", "yMk", "xAvg", "yAvg")
val veloxCovarIntermediateDataOrder: Seq[String] = Seq("ck", "n", "xAvg", "yAvg")

val veloxThreeIntermediateTypes: Seq[DataType] = Seq(LongType, DoubleType, DoubleType)
val veloxFourIntermediateTypes: Seq[DataType] = Seq(DoubleType, LongType, DoubleType, DoubleType)
val veloxSixIntermediateTypes: Seq[DataType] =
val veloxVarianceIntermediateTypes: Seq[DataType] = Seq(LongType, DoubleType, DoubleType)
val veloxCovarIntermediateTypes: Seq[DataType] = Seq(DoubleType, LongType, DoubleType, DoubleType)
val veloxCorrIntermediateTypes: Seq[DataType] =
Seq(DoubleType, LongType, DoubleType, DoubleType, DoubleType, DoubleType)

/**
Expand All @@ -871,37 +831,24 @@ object VeloxAggregateFunctionsBuilder {
aggregateFunc: AggregateFunction,
forMergeCompanion: Boolean): Seq[DataType] = {
if (!forMergeCompanion) {
return aggregateFunc.children.map(child => child.dataType)
}
if (aggregateFunc.aggBufferAttributes.size == veloxThreeIntermediateTypes.size) {
return Seq(
StructType(
veloxThreeIntermediateTypes
.map(intermediateType => StructField("", intermediateType))
.toArray))
}
if (aggregateFunc.aggBufferAttributes.size == veloxFourIntermediateTypes.size) {
return Seq(
StructType(
veloxFourIntermediateTypes
.map(intermediateType => StructField("", intermediateType))
.toArray))
return aggregateFunc.children.map(_.dataType)
}
if (aggregateFunc.aggBufferAttributes.size == veloxSixIntermediateTypes.size) {
return Seq(
StructType(
veloxSixIntermediateTypes
.map(intermediateType => StructField("", intermediateType))
.toArray))
}
if (aggregateFunc.aggBufferAttributes.size > 1) {
return Seq(
StructType(
aggregateFunc.aggBufferAttributes
.map(attribute => StructField("", attribute.dataType))
.toArray))
aggregateFunc match {
case _: StddevSamp | _: StddevPop | _: VarianceSamp | _: VariancePop =>
Seq(StructType(veloxVarianceIntermediateTypes.map(StructField("", _)).toArray))
case _: CovPopulation | _: CovSample =>
Seq(StructType(veloxCovarIntermediateTypes.map(StructField("", _)).toArray))
case _: Corr =>
Seq(StructType(veloxCorrIntermediateTypes.map(StructField("", _)).toArray))
case aggFunc if aggFunc.aggBufferAttributes.size > 1 =>
Seq(
StructType(
aggregateFunc.aggBufferAttributes
.map(attribute => StructField("", attribute.dataType))
.toArray))
case _ =>
aggregateFunc.aggBufferAttributes.map(_.dataType)
}
aggregateFunc.aggBufferAttributes.map(child => child.dataType)
}

/**
Expand Down

0 comments on commit 7cf715b

Please sign in to comment.