From 32663610df61d70141a836718eb21f3afb1b12a2 Mon Sep 17 00:00:00 2001 From: lgbo-ustc Date: Tue, 5 Nov 2024 09:47:50 +0800 Subject: [PATCH] unity agg output --- .../clickhouse/CHSparkPlanExecApi.scala | 26 ++++++++++++++++--- 1 file changed, 23 insertions(+), 3 deletions(-) diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala index ba165d936eed..3508f908f823 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala @@ -158,16 +158,36 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with Logging { aggregateAttributes: Seq[Attribute], initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], - child: SparkPlan): HashAggregateExecBaseTransformer = - CHHashAggregateExecTransformer( + child: SparkPlan): HashAggregateExecBaseTransformer = { + logError(s"xxx aggregateExpressions:$aggregateExpressions") + logError(s"xxx aggregateAttributes:$aggregateAttributes") + logError(s"xxx resultExpressions:$resultExpressions") + logError(s"xxx agg expr to result: ${aggregateExpressions.map(_.resultAttribute)}") + logError( + s"xxx agg:" + + s"${aggregateExpressions.map(e => e.aggregateFunction.aggBufferAttributes.length)}") + aggregateExpressions.foreach { + e => logError(s"xxx agg fun:$e, ${e.aggregateFunction.aggBufferAttributes}") + } + val replacedResultExpressions = + groupingExpressions ++ aggregateExpressions.map(_.resultAttribute) + val agg = CHHashAggregateExecTransformer( requiredChildDistributionExpressions, groupingExpressions.distinct, aggregateExpressions, aggregateAttributes, initialInputBufferOffset, - resultExpressions.distinct, + // resultExpressions.distinct, + replacedResultExpressions, child ) + // val xoutputs = CHHashAggregateExecTransformer.getCHAggregateResultAttributes( + // aggregateExpressions, + // resultExpressions.slice(groupingExpressions.length, resultExpressions.length)) + // logError(s"xxx adjust agg output: $xoutputs") + logError(s"xxx agg output: ${agg.output}") + agg + } /** Generate HashAggregateExecPullOutHelper */ override def genHashAggregateExecPullOutHelper(