Skip to content

Commit

Permalink
[CORE] Optimize some methods in agg transformer (#3564)
Browse files Browse the repository at this point in the history
  • Loading branch information
liujiayi771 authored Nov 6, 2023
1 parent 7185955 commit ba045c7
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 87 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -437,16 +437,14 @@ case class HashAggregateExecTransformer(
// Return whether the outputs of partial aggregation need to be combined for Velox computing.
// When a partial output spans multiple columns, a row construct is needed.
private def rowConstructNeeded: Boolean = {
  // True as soon as one PartialMerge/Final aggregate expression has more than one
  // input aggregation-buffer attribute. Expressions in any other mode never
  // trigger row construction here.
  aggregateExpressions.exists {
    aggExpr =>
      aggExpr.mode match {
        case PartialMerge | Final =>
          aggExpr.aggregateFunction.inputAggBufferAttributes.size > 1
        case _ => false
      }
  }
}

// Return a scalar function node representing row construct function in Velox.
Expand Down Expand Up @@ -807,14 +805,8 @@ case class HashAggregateExecTransformer(
* whether partial and partial-merge functions coexist.
*/
def mixedPartialAndMerge: Boolean = {
  // Each flag is computed exactly once; the result is true only when at least one
  // PartialMerge-mode and at least one Partial-mode aggregate expression are present.
  val partialMergeExists = aggregateExpressions.exists(_.mode == PartialMerge)
  val partialExists = aggregateExpressions.exists(_.mode == Partial)
  partialMergeExists && partialExists
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ import com.google.protobuf.Any
import java.util

import scala.collection.mutable.ListBuffer
import scala.util.control.Breaks.{break, breakable}

/** Columnar Based HashAggregateExec. */
abstract class HashAggregateExecBaseTransformer(
Expand Down Expand Up @@ -159,74 +158,45 @@ abstract class HashAggregateExecBaseTransformer(
// Members declared in org.apache.spark.sql.execution.AliasAwareOutputPartitioning
override protected def outputExpressions: Seq[NamedExpression] = resultExpressions

// Check if Pre-Projection is needed before the Aggregation.
protected def needsPreProjection: Boolean = {
  // A grouping expression that is not a plain Attribute has to be projected first.
  val groupingNeedsProjection = groupingExpressions.exists {
    case _: Attribute => false
    case _ => true
  }

  groupingNeedsProjection || aggregateExpressions.exists {
    expr =>
      // A filter is fine when absent, or when it is a bare Attribute or Literal.
      // This is a plain boolean rather than a `return` from inside the lambda,
      // which would be a non-local return.
      val filterNeedsProjection = expr.filter.exists {
        case _: Attribute | _: Literal => false
        case _ => true
      }

      filterNeedsProjection || (expr.mode match {
        case Partial =>
          // Any aggregate-function child other than an Attribute or Literal
          // requires pre-projection.
          expr.aggregateFunction.children.exists {
            case _: Attribute | _: Literal => false
            case _ => true
          }
        // No need to consider pre-projection for PartialMerge and Final Agg.
        case _ => false
      })
  }
}

// Check if Post-Projection is needed after the Aggregation.
protected def needsPostProjection(aggOutAttributes: List[Attribute]): Boolean = {
  // If the result expressions and the output attributes differ in size,
  // post-projection is needed. This check comes first, so the zip below only
  // runs on equally sized sequences and never silently drops trailing items.
  resultExpressions.size != aggOutAttributes.size ||
  // Compare each result expression with the output attribute at the same position.
  resultExpressions.zip(aggOutAttributes).exists {
    case (exprAttr: Attribute, resAttr) =>
      // A result attribute whose name or data type differs from the corresponding
      // output attribute requires post-projection.
      exprAttr.name != resAttr.name || exprAttr.dataType != resAttr.dataType
    case _ =>
      // A result expression that is not an Attribute always requires post-projection.
      true
  }
}

protected def getAggRelWithPreProjection(
Expand Down Expand Up @@ -738,19 +708,5 @@ abstract class HashAggregateExecBaseTransformer(
operatorId: Long,
aggParams: AggregationParams,
input: RelNode = null,
validation: Boolean = false): RelNode = {
val originalInputAttributes = child.output
val aggRel = if (needsPreProjection) {
getAggRelWithPreProjection(context, originalInputAttributes, operatorId, input, validation)
} else {
getAggRelWithoutPreProjection(context, originalInputAttributes, operatorId, input, validation)
}
// Will check if post-projection is needed. If yes, a ProjectRel will be added after the
// AggregateRel.
if (!needsPostProjection(allAggregateResultAttributes)) {
aggRel
} else {
applyPostProjection(context, aggRel, operatorId, validation)
}
}
validation: Boolean = false): RelNode
}

0 comments on commit ba045c7

Please sign in to comment.