Skip to content

Commit

Permalink
unity agg output
Browse files Browse the repository at this point in the history
  • Loading branch information
lgbo-ustc committed Nov 5, 2024
1 parent 48d312a commit b30edcb
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -158,16 +158,38 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with Logging {
aggregateAttributes: Seq[Attribute],
initialInputBufferOffset: Int,
resultExpressions: Seq[NamedExpression],
child: SparkPlan): HashAggregateExecBaseTransformer =
CHHashAggregateExecTransformer(
child: SparkPlan): HashAggregateExecBaseTransformer = {
logError(s"xxx aggregateExpressions:$aggregateExpressions")
logError(s"xxx aggregateAttributes:$aggregateAttributes")
logError(s"xxx resultExpressions:$resultExpressions")
logError(s"xxx agg expr to result: ${aggregateExpressions.map(_.resultAttribute)}")
logError(
s"xxx agg:" +
s"${aggregateExpressions.map(e => e.aggregateFunction.aggBufferAttributes.length)}")
aggregateExpressions.foreach {
e => logError(s"xxx agg fun:$e, ${e.aggregateFunction.aggBufferAttributes}")
}
val modes = aggregateExpressions.map(_.mode)
logError(s"xxx modes: $modes")
val xoutputs = CHHashAggregateExecTransformer.getCHAggregateResultAttributes(
aggregateExpressions,
resultExpressions.slice(groupingExpressions.length, resultExpressions.length))
logError(s"xxx adjust agg output: $xoutputs")
val replacedResultExpressions =
groupingExpressions ++ xoutputs
val agg = CHHashAggregateExecTransformer(
requiredChildDistributionExpressions,
groupingExpressions.distinct,
aggregateExpressions,
aggregateAttributes,
initialInputBufferOffset,
resultExpressions.distinct,
// resultExpressions.distinct,
replacedResultExpressions,
child
)
logError(s"xxx agg output: ${agg.output}")
agg
}

/** Generate HashAggregateExecPullOutHelper */
override def genHashAggregateExecPullOutHelper(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,42 @@ import scala.collection.JavaConverters._
import scala.collection.mutable.ListBuffer

object CHHashAggregateExecTransformer {
// The result attributes of aggregate expressions from vanilla may be different from CH native.
// For example, the result attributes of `avg(x)` are `sum(x)` and `count(x)`. This could bring
// some unexpected issues. So we need to make the result attributes consistent with CH native.
def getCHAggregateResultAttributes(
aggregateExpressions: Seq[AggregateExpression],
resultExpressions: Seq[NamedExpression]): Seq[Attribute] = {
var resultExpressionIndex = 0
aggregateExpressions.flatMap {
aggExpr =>
aggExpr.mode match {
case Partial | PartialMerge =>
val aggBufferAttributesCount = aggExpr.aggregateFunction.aggBufferAttributes.length
aggExpr.aggregateFunction match {
case avg: Average =>
val res = Seq(aggExpr.resultAttribute)
resultExpressionIndex += aggBufferAttributesCount
res
case sum: Sum if (sum.dataType.isInstanceOf[DecimalType]) =>
val res = Seq(resultExpressions(resultExpressionIndex).toAttribute)
resultExpressionIndex += aggBufferAttributesCount
res
case _ =>
val res = resultExpressions
.slice(resultExpressionIndex, resultExpressionIndex + aggBufferAttributesCount)
.map(_.toAttribute)
resultExpressionIndex += aggBufferAttributesCount
res
}
case _ =>
val res = Seq(resultExpressions(resultExpressionIndex).toAttribute)
resultExpressionIndex += 1
res
}
}
}

def getAggregateResultAttributes(
groupingExpressions: Seq[NamedExpression],
aggregateExpressions: Seq[AggregateExpression]): Seq[Attribute] = {
Expand Down

0 comments on commit b30edcb

Please sign in to comment.