From 7cf715b8ad66fecacde2cbe99879b7f1ad944110 Mon Sep 17 00:00:00 2001 From: Joey Date: Mon, 13 Nov 2023 20:50:46 +0800 Subject: [PATCH] [VL] Declare IntermediateTypes for specific agg function (#3679) --- .../HashAggregateExecTransformer.scala | 123 +++++------------- 1 file changed, 35 insertions(+), 88 deletions(-) diff --git a/backends-velox/src/main/scala/io/glutenproject/execution/HashAggregateExecTransformer.scala b/backends-velox/src/main/scala/io/glutenproject/execution/HashAggregateExecTransformer.scala index a082a8c79560..1b8e10893a89 100644 --- a/backends-velox/src/main/scala/io/glutenproject/execution/HashAggregateExecTransformer.scala +++ b/backends-velox/src/main/scala/io/glutenproject/execution/HashAggregateExecTransformer.scala @@ -16,7 +16,7 @@ */ package io.glutenproject.execution -import io.glutenproject.execution.VeloxAggregateFunctionsBuilder.{veloxFourIntermediateTypes, veloxSixIntermediateTypes, veloxThreeIntermediateTypes} +import io.glutenproject.execution.VeloxAggregateFunctionsBuilder._ import io.glutenproject.expression._ import io.glutenproject.expression.ConverterUtils.FunctionConfig import io.glutenproject.substrait.`type`.{TypeBuilder, TypeNode} @@ -217,73 +217,33 @@ case class HashAggregateExecTransformer( * The type of partial outputs. */ private def getIntermediateTypeNode(aggregateFunction: AggregateFunction): TypeNode = { - val structTypeNodes = new JArrayList[TypeNode]() - aggregateFunction match { + val structTypeNodes = aggregateFunction match { case avg: Average => - structTypeNodes.add( - ConverterUtils.getTypeNode(GlutenDecimalUtil.getAvgSumDataType(avg), nullable = true)) - structTypeNodes.add(ConverterUtils.getTypeNode(LongType, nullable = true)) + ConverterUtils.getTypeNode(GlutenDecimalUtil.getAvgSumDataType(avg), nullable = true) :: + ConverterUtils.getTypeNode(LongType, nullable = true) :: Nil case first: First => - structTypeNodes.add(ConverterUtils.getTypeNode(first.dataType, nullable = true)) - structTypeNodes.add(ConverterUtils.getTypeNode(BooleanType, nullable = true)) + ConverterUtils.getTypeNode(first.dataType, nullable = true) :: + ConverterUtils.getTypeNode(BooleanType, nullable = true) :: Nil case last: Last => - structTypeNodes.add(ConverterUtils.getTypeNode(last.dataType, nullable = true)) - structTypeNodes.add(ConverterUtils.getTypeNode(BooleanType, nullable = true)) + ConverterUtils.getTypeNode(last.dataType, nullable = true) :: + ConverterUtils.getTypeNode(BooleanType, nullable = true) :: Nil case maxMinBy: MaxMinBy => - structTypeNodes - .add(ConverterUtils.getTypeNode(maxMinBy.valueExpr.dataType, nullable = true)) - structTypeNodes - .add(ConverterUtils.getTypeNode(maxMinBy.orderingExpr.dataType, nullable = true)) + ConverterUtils.getTypeNode(maxMinBy.valueExpr.dataType, nullable = true) :: + ConverterUtils.getTypeNode(maxMinBy.orderingExpr.dataType, nullable = true) :: Nil case _: StddevSamp | _: StddevPop | _: VarianceSamp | _: VariancePop => // Use struct type to represent Velox Row(BIGINT, DOUBLE, DOUBLE). - structTypeNodes.add( - ConverterUtils - .getTypeNode(veloxThreeIntermediateTypes.head, nullable = false)) - structTypeNodes.add( - ConverterUtils - .getTypeNode(veloxThreeIntermediateTypes(1), nullable = false)) - structTypeNodes.add( - ConverterUtils - .getTypeNode(veloxThreeIntermediateTypes(2), nullable = false)) + veloxVarianceIntermediateTypes.map(ConverterUtils.getTypeNode(_, nullable = false)) case _: Corr => - structTypeNodes.add( - ConverterUtils - .getTypeNode(veloxSixIntermediateTypes.head, nullable = false)) - structTypeNodes.add( - ConverterUtils - .getTypeNode(veloxSixIntermediateTypes(1), nullable = false)) - structTypeNodes.add( - ConverterUtils - .getTypeNode(veloxSixIntermediateTypes(2), nullable = false)) - structTypeNodes.add( - ConverterUtils - .getTypeNode(veloxSixIntermediateTypes(3), nullable = false)) - structTypeNodes.add( - ConverterUtils - .getTypeNode(veloxSixIntermediateTypes(4), nullable = false)) - structTypeNodes.add( - ConverterUtils - .getTypeNode(veloxSixIntermediateTypes(5), nullable = false)) + veloxCorrIntermediateTypes.map(ConverterUtils.getTypeNode(_, nullable = false)) case _: CovPopulation | _: CovSample => - structTypeNodes.add( - ConverterUtils - .getTypeNode(veloxFourIntermediateTypes.head, nullable = false)) - structTypeNodes.add( - ConverterUtils - .getTypeNode(veloxFourIntermediateTypes(1), nullable = false)) - structTypeNodes.add( - ConverterUtils - .getTypeNode(veloxFourIntermediateTypes(2), nullable = false)) - structTypeNodes.add( - ConverterUtils - .getTypeNode(veloxFourIntermediateTypes(3), nullable = false)) + veloxCovarIntermediateTypes.map(ConverterUtils.getTypeNode(_, nullable = false)) case sum: Sum if sum.dataType.isInstanceOf[DecimalType] => - structTypeNodes.add(ConverterUtils.getTypeNode(sum.dataType, nullable = true)) - structTypeNodes.add(ConverterUtils.getTypeNode(BooleanType, nullable = false)) + ConverterUtils.getTypeNode(sum.dataType, nullable = true) :: + ConverterUtils.getTypeNode(BooleanType, nullable = false) :: Nil case other => throw new UnsupportedOperationException(s"$other is not supported.") } - TypeBuilder.makeStruct(false, structTypeNodes) + TypeBuilder.makeStruct(false, structTypeNodes.asJava) } override protected def modeToKeyWord(aggregateMode: AggregateMode): String = { @@ -853,9 +813,9 @@ object VeloxAggregateFunctionsBuilder { val veloxCorrIntermediateDataOrder: Seq[String] = Seq("ck", "n", "xMk", "yMk", "xAvg", "yAvg") val veloxCovarIntermediateDataOrder: Seq[String] = Seq("ck", "n", "xAvg", "yAvg") - val veloxThreeIntermediateTypes: Seq[DataType] = Seq(LongType, DoubleType, DoubleType) - val veloxFourIntermediateTypes: Seq[DataType] = Seq(DoubleType, LongType, DoubleType, DoubleType) - val veloxSixIntermediateTypes: Seq[DataType] = + val veloxVarianceIntermediateTypes: Seq[DataType] = Seq(LongType, DoubleType, DoubleType) + val veloxCovarIntermediateTypes: Seq[DataType] = Seq(DoubleType, LongType, DoubleType, DoubleType) + val veloxCorrIntermediateTypes: Seq[DataType] = Seq(DoubleType, LongType, DoubleType, DoubleType, DoubleType, DoubleType) /** @@ -871,37 +831,24 @@ object VeloxAggregateFunctionsBuilder { aggregateFunc: AggregateFunction, forMergeCompanion: Boolean): Seq[DataType] = { if (!forMergeCompanion) { - return aggregateFunc.children.map(child => child.dataType) - } - if (aggregateFunc.aggBufferAttributes.size == veloxThreeIntermediateTypes.size) { - return Seq( - StructType( - veloxThreeIntermediateTypes - .map(intermediateType => StructField("", intermediateType)) - .toArray)) - } - if (aggregateFunc.aggBufferAttributes.size == veloxFourIntermediateTypes.size) { - return Seq( - StructType( - veloxFourIntermediateTypes - .map(intermediateType => StructField("", intermediateType)) - .toArray)) + return aggregateFunc.children.map(_.dataType) } - if (aggregateFunc.aggBufferAttributes.size == veloxSixIntermediateTypes.size) { - return Seq( - StructType( - veloxSixIntermediateTypes - .map(intermediateType => StructField("", intermediateType)) - .toArray)) - } - if (aggregateFunc.aggBufferAttributes.size > 1) { - return Seq( - StructType( - aggregateFunc.aggBufferAttributes - .map(attribute => StructField("", attribute.dataType)) - .toArray)) + aggregateFunc match { + case _: StddevSamp | _: StddevPop | _: VarianceSamp | _: VariancePop => + Seq(StructType(veloxVarianceIntermediateTypes.map(StructField("", _)).toArray)) + case _: CovPopulation | _: CovSample => + Seq(StructType(veloxCovarIntermediateTypes.map(StructField("", _)).toArray)) + case _: Corr => + Seq(StructType(veloxCorrIntermediateTypes.map(StructField("", _)).toArray)) + case aggFunc if aggFunc.aggBufferAttributes.size > 1 => + Seq( + StructType( + aggregateFunc.aggBufferAttributes + .map(attribute => StructField("", attribute.dataType)) + .toArray)) + case _ => + aggregateFunc.aggBufferAttributes.map(_.dataType) } - aggregateFunc.aggBufferAttributes.map(child => child.dataType) } /**