Skip to content

Commit

Permalink
introduce VeloxIntermediateTypes
Browse files Browse the repository at this point in the history
  • Loading branch information
liujiayi771 committed Nov 15, 2023
1 parent 5fb2fb9 commit 55ececc
Showing 1 changed file with 48 additions and 45 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -79,27 +79,18 @@ case class HashAggregateExecTransformer(
* @return
* extracting needed or not.
*/
def extractStructNeeded(): Boolean = {
for (expr <- aggregateExpressions) {
val aggregateFunction = expr.aggregateFunction
aggregateFunction match {
case _: Average | _: First | _: Last | _: StddevSamp | _: StddevPop | _: VarianceSamp |
_: VariancePop | _: Corr | _: CovPopulation | _: CovSample | _: MaxMinBy =>
expr.mode match {
case Partial | PartialMerge =>
return true
case _ =>
}
case sum: Sum if sum.dataType.isInstanceOf[DecimalType] =>
expr.mode match {
case Partial | PartialMerge =>
return true
case _ =>
}
case _ =>
}
private def extractStructNeeded(): Boolean = {
aggregateExpressions.exists {
expr =>
expr.aggregateFunction match {
case _ @VeloxIntermediateTypes(_) =>
expr.mode match {
case Partial | PartialMerge => true
case _ => false
}
case _ => false
}
}
false
}

/**
Expand Down Expand Up @@ -215,19 +206,14 @@ case class HashAggregateExecTransformer(
* @return
* The type of partial outputs.
*/
private def getIntermediateTypeNode(aggregateFunction: AggregateFunction): TypeNode = {
val structTypeNodes = aggregateFunction match {
case _: StddevSamp | _: StddevPop | _: VarianceSamp | _: VariancePop =>
// Use struct type to represent Velox Row(BIGINT, DOUBLE, DOUBLE).
veloxVarianceIntermediateTypes.map(ConverterUtils.getTypeNode(_, nullable = false))
case _: Corr =>
veloxCorrIntermediateTypes.map(ConverterUtils.getTypeNode(_, nullable = false))
case _: CovPopulation | _: CovSample =>
veloxCovarIntermediateTypes.map(ConverterUtils.getTypeNode(_, nullable = false))
case aggFunc =>
aggFunc.aggBufferAttributes.map(
attr => ConverterUtils.getTypeNode(attr.dataType, nullable = true))
}
private def getIntermediateTypeNode(aggFunc: AggregateFunction): TypeNode = {
val structTypeNodes =
aggFunc match {
case _ @VeloxIntermediateTypes(dataTypes: Seq[DataType]) =>
dataTypes.map(ConverterUtils.getTypeNode(_, nullable = false))
case _ =>
throw new UnsupportedOperationException("Can not get velox intermediate types.")
}
TypeBuilder.makeStruct(false, structTypeNodes.asJava)
}

Expand Down Expand Up @@ -819,23 +805,40 @@ object VeloxAggregateFunctionsBuilder {
return aggregateFunc.children.map(_.dataType)
}
aggregateFunc match {
case _: StddevSamp | _: StddevPop | _: VarianceSamp | _: VariancePop =>
Seq(StructType(veloxVarianceIntermediateTypes.map(StructField("", _)).toArray))
case _: CovPopulation | _: CovSample =>
Seq(StructType(veloxCovarIntermediateTypes.map(StructField("", _)).toArray))
case _: Corr =>
Seq(StructType(veloxCorrIntermediateTypes.map(StructField("", _)).toArray))
case aggFunc if aggFunc.aggBufferAttributes.size > 1 =>
Seq(
StructType(
aggregateFunc.aggBufferAttributes
.map(attribute => StructField("", attribute.dataType))
.toArray))
case _ @VeloxIntermediateTypes(veloxDataTypes: Seq[DataType]) =>
Seq(StructType(veloxDataTypes.map(StructField("", _)).toArray))
case _ =>
aggregateFunc.aggBufferAttributes.map(_.dataType)
}
}

def veloxIntermediateDataOrder(aggFunc: AggregateFunction): Seq[String] = {
aggFunc match {
case _: Corr =>
veloxCorrIntermediateDataOrder
case _: CovPopulation | _: CovSample =>
veloxCovarIntermediateDataOrder
case _ =>
aggFunc.aggBufferAttributes.map(_.name)
}
}

object VeloxIntermediateTypes {
def unapply(aggFunc: AggregateFunction): Option[Seq[DataType]] = {
aggFunc match {
case _: Corr =>
Some(veloxCorrIntermediateTypes)
case _: CovPopulation | _: CovSample =>
Some(veloxCovarIntermediateTypes)
case _: StddevSamp | _: StddevPop | _: VarianceSamp | _: VariancePop =>
Some(veloxVarianceIntermediateTypes)
case _ if aggFunc.aggBufferAttributes.size > 1 =>
Some(aggFunc.aggBufferAttributes.map(_.dataType))
case _ => None
}
}
}

/**
* Create an scalar function for the input aggregate function.
* @param args:
Expand Down

0 comments on commit 55ececc

Please sign in to comment.