diff --git a/backends-velox/src/main/scala/io/glutenproject/execution/HashAggregateExecTransformer.scala b/backends-velox/src/main/scala/io/glutenproject/execution/HashAggregateExecTransformer.scala
index 0b507c06aa14..a4e11c17bf6c 100644
--- a/backends-velox/src/main/scala/io/glutenproject/execution/HashAggregateExecTransformer.scala
+++ b/backends-velox/src/main/scala/io/glutenproject/execution/HashAggregateExecTransformer.scala
@@ -79,27 +79,18 @@ case class HashAggregateExecTransformer(
    * @return
    *   extracting needed or not.
    */
-  def extractStructNeeded(): Boolean = {
-    for (expr <- aggregateExpressions) {
-      val aggregateFunction = expr.aggregateFunction
-      aggregateFunction match {
-        case _: Average | _: First | _: Last | _: StddevSamp | _: StddevPop | _: VarianceSamp |
-            _: VariancePop | _: Corr | _: CovPopulation | _: CovSample | _: MaxMinBy =>
-          expr.mode match {
-            case Partial | PartialMerge =>
-              return true
-            case _ =>
-          }
-        case sum: Sum if sum.dataType.isInstanceOf[DecimalType] =>
-          expr.mode match {
-            case Partial | PartialMerge =>
-              return true
-            case _ =>
-          }
-        case _ =>
-      }
+  private def extractStructNeeded(): Boolean = {
+    aggregateExpressions.exists {
+      expr =>
+        expr.aggregateFunction match {
+          case VeloxIntermediateTypes(_) =>
+            expr.mode match {
+              case Partial | PartialMerge => true
+              case _ => false
+            }
+          case _ => false
+        }
     }
-    false
   }
 
   /**
@@ -215,19 +206,14 @@ case class HashAggregateExecTransformer(
    * @return
    *   The type of partial outputs.
    */
-  private def getIntermediateTypeNode(aggregateFunction: AggregateFunction): TypeNode = {
-    val structTypeNodes = aggregateFunction match {
-      case _: StddevSamp | _: StddevPop | _: VarianceSamp | _: VariancePop =>
-        // Use struct type to represent Velox Row(BIGINT, DOUBLE, DOUBLE).
-        veloxVarianceIntermediateTypes.map(ConverterUtils.getTypeNode(_, nullable = false))
-      case _: Corr =>
-        veloxCorrIntermediateTypes.map(ConverterUtils.getTypeNode(_, nullable = false))
-      case _: CovPopulation | _: CovSample =>
-        veloxCovarIntermediateTypes.map(ConverterUtils.getTypeNode(_, nullable = false))
-      case aggFunc =>
-        aggFunc.aggBufferAttributes.map(
-          attr => ConverterUtils.getTypeNode(attr.dataType, nullable = true))
-    }
+  private def getIntermediateTypeNode(aggFunc: AggregateFunction): TypeNode = {
+    val structTypeNodes =
+      aggFunc match {
+        case VeloxIntermediateTypes(dataTypes) =>
+          dataTypes.map(ConverterUtils.getTypeNode(_, nullable = false))
+        case _ =>
+          throw new UnsupportedOperationException("Can not get velox intermediate types.")
+      }
     TypeBuilder.makeStruct(false, structTypeNodes.asJava)
   }
 
@@ -819,23 +805,51 @@ object VeloxAggregateFunctionsBuilder {
       return aggregateFunc.children.map(_.dataType)
     }
     aggregateFunc match {
-      case _: StddevSamp | _: StddevPop | _: VarianceSamp | _: VariancePop =>
-        Seq(StructType(veloxVarianceIntermediateTypes.map(StructField("", _)).toArray))
-      case _: CovPopulation | _: CovSample =>
-        Seq(StructType(veloxCovarIntermediateTypes.map(StructField("", _)).toArray))
-      case _: Corr =>
-        Seq(StructType(veloxCorrIntermediateTypes.map(StructField("", _)).toArray))
-      case aggFunc if aggFunc.aggBufferAttributes.size > 1 =>
-        Seq(
-          StructType(
-            aggregateFunc.aggBufferAttributes
-              .map(attribute => StructField("", attribute.dataType))
-              .toArray))
+      case VeloxIntermediateTypes(veloxDataTypes) =>
+        Seq(StructType(veloxDataTypes.map(StructField("", _)).toArray))
       case _ =>
         aggregateFunc.aggBufferAttributes.map(_.dataType)
     }
   }
 
+  /**
+   * Returns the names of the intermediate data fields of the given aggregate function.
+   * Corr and the covariance functions use the predefined Velox orders; every other function
+   * falls back to the names of its aggregation buffer attributes.
+   */
+  def veloxIntermediateDataOrder(aggFunc: AggregateFunction): Seq[String] = {
+    aggFunc match {
+      case _: Corr =>
+        veloxCorrIntermediateDataOrder
+      case _: CovPopulation | _: CovSample =>
+        veloxCovarIntermediateDataOrder
+      case _ =>
+        aggFunc.aggBufferAttributes.map(_.name)
+    }
+  }
+
+  /**
+   * Extractor for aggregate functions that have several intermediate data types.
+   * Corr, covariance and stddev/variance functions match with their predefined type lists;
+   * any other function with more than one aggregation buffer attribute matches with the
+   * data types of those attributes. Every remaining function does not match.
+   */
+  object VeloxIntermediateTypes {
+    def unapply(aggFunc: AggregateFunction): Option[Seq[DataType]] = {
+      aggFunc match {
+        case _: Corr =>
+          Some(veloxCorrIntermediateTypes)
+        case _: CovPopulation | _: CovSample =>
+          Some(veloxCovarIntermediateTypes)
+        case _: StddevSamp | _: StddevPop | _: VarianceSamp | _: VariancePop =>
+          Some(veloxVarianceIntermediateTypes)
+        case _ if aggFunc.aggBufferAttributes.size > 1 =>
+          Some(aggFunc.aggBufferAttributes.map(_.dataType))
+        case _ => None
+      }
+    }
+  }
+
   /**
    * Create an scalar function for the input aggregate function.
    * @param args: