From b30edcb6601ccb3e278562258f715c9d50db06d6 Mon Sep 17 00:00:00 2001
From: lgbo-ustc
Date: Tue, 5 Nov 2024 09:47:50 +0800
Subject: [PATCH] unity agg output

---
 .../clickhouse/CHSparkPlanExecApi.scala       | 12 ++++--
 .../CHHashAggregateExecTransformer.scala      | 37 +++++++++++++++++++
 2 files changed, 47 insertions(+), 2 deletions(-)

diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
index ba165d936eed..e17e15b04bd7 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
@@ -158,16 +158,24 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with Logging {
       aggregateAttributes: Seq[Attribute],
       initialInputBufferOffset: Int,
       resultExpressions: Seq[NamedExpression],
-      child: SparkPlan): HashAggregateExecBaseTransformer =
+      child: SparkPlan): HashAggregateExecBaseTransformer = {
+    // resultExpressions is cut after the grouping columns, i.e. it is assumed to be laid out
+    // as grouping columns followed by aggregate result columns. Only that aggregate tail is
+    // replaced by the attributes computed for the CH native aggregate output.
+    val aggregateResultExpressions = resultExpressions.drop(groupingExpressions.length)
+    val chAggregateOutputs = CHHashAggregateExecTransformer.getCHAggregateResultAttributes(
+      aggregateExpressions,
+      aggregateResultExpressions)
     CHHashAggregateExecTransformer(
       requiredChildDistributionExpressions,
       groupingExpressions.distinct,
       aggregateExpressions,
       aggregateAttributes,
       initialInputBufferOffset,
-      resultExpressions.distinct,
+      groupingExpressions ++ chAggregateOutputs,
       child
     )
+  }
 
   /** Generate HashAggregateExecPullOutHelper */
   override def genHashAggregateExecPullOutHelper(
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashAggregateExecTransformer.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashAggregateExecTransformer.scala
index f5e64330cd15..8a0b66fcf191 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashAggregateExecTransformer.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashAggregateExecTransformer.scala
@@ -43,6 +43,43 @@ import scala.collection.JavaConverters._
 import scala.collection.mutable.ListBuffer
 
 object CHHashAggregateExecTransformer {
+  // The result attributes of aggregate expressions from vanilla may be different from CH native.
+  // For example, the result attributes of `avg(x)` are `sum(x)` and `count(x)`. This could bring
+  // some unexpected issues. So we need to make the result attributes consistent with CH native.
+  //
+  // `resultExpressions` holds only the aggregate result columns (the caller strips the grouping
+  // columns), in the same order as `aggregateExpressions`. In Partial/PartialMerge mode an
+  // expression spans as many of them as it has aggregation buffer attributes; in any other
+  // mode it spans exactly one.
+  def getCHAggregateResultAttributes(
+      aggregateExpressions: Seq[AggregateExpression],
+      resultExpressions: Seq[NamedExpression]): Seq[Attribute] = {
+    val columnCounts = aggregateExpressions.map {
+      aggExpr =>
+        aggExpr.mode match {
+          case Partial | PartialMerge => aggExpr.aggregateFunction.aggBufferAttributes.length
+          case _ => 1
+        }
+    }
+    // Start offset of every aggregate expression inside `resultExpressions`.
+    val startOffsets = columnCounts.scanLeft(0)(_ + _)
+    aggregateExpressions.zip(startOffsets).zip(columnCounts).flatMap {
+      case ((aggExpr, start), count) =>
+        (aggExpr.mode, aggExpr.aggregateFunction) match {
+          case (Partial | PartialMerge, _: Average) =>
+            // avg is represented by its single result attribute instead of its buffer columns.
+            Seq(aggExpr.resultAttribute)
+          case (Partial | PartialMerge, sum: Sum) if sum.dataType.isInstanceOf[DecimalType] =>
+            // Only the first buffer column of a decimal sum is kept.
+            Seq(resultExpressions(start).toAttribute)
+          case (Partial | PartialMerge, _) =>
+            resultExpressions.slice(start, start + count).map(_.toAttribute)
+          case _ =>
+            Seq(resultExpressions(start).toAttribute)
+        }
+    }
+  }
+
   def getAggregateResultAttributes(
       groupingExpressions: Seq[NamedExpression],
       aggregateExpressions: Seq[AggregateExpression]): Seq[Attribute] = {