Skip to content

Commit

Permalink
1028
Browse files Browse the repository at this point in the history
  • Loading branch information
lgbo-ustc committed Oct 31, 2024
1 parent 1b29404 commit 9488005
Show file tree
Hide file tree
Showing 6 changed files with 330 additions and 126 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,18 @@ object CHBackendSettings extends BackendSettingsApi with Logging {
)
}

// Tries to move the expand node after the pre-aggregate node, i.e. rewrite the plan from
//   expand -> pre-aggregate -> shuffle -> final-aggregate
// to
//   pre-aggregate -> expand -> shuffle -> final-aggregate
// This could reduce the overhead of the pre-aggregate node.
// Controlled by the runtime config key "enable_lazy_aggregate_expand"; enabled by default.
def enableLazyAggregateExpand(): Boolean = {
SparkEnv.get.conf.getBoolean(
CHConf.runtimeConfig("enable_lazy_aggregate_expand"),
defaultValue = true
)
}

override def enableNativeWriteFiles(): Boolean = {
GlutenConfig.getConf.enableNativeWriter.getOrElse(false)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
*/
package org.apache.gluten.extension

import org.apache.gluten.backendsapi.clickhouse.CHBackendSettings

import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions._
Expand Down Expand Up @@ -48,104 +50,172 @@ import org.apache.spark.sql.execution.exchange._
*/

case class LazyExpandRule(session: SparkSession) extends Rule[SparkPlan] with Logging {
override def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
case finalAggregate @ HashAggregateExec(
_,
_,
_,
_,
_,
_,
_,
_,
ShuffleExchangeExec(
HashPartitioning(hashExpressions, _),
HashAggregateExec(
_,
_,
_,
groupingExpressions,
aggregateExpressions,
_,
_,
resultExpressions,
ExpandExec(projections, output, child)),
_
override def apply(plan: SparkPlan): SparkPlan = {
logDebug(s"xxx enable lazy aggregate expand: ${CHBackendSettings.enableLazyAggregateExpand}")
if (!CHBackendSettings.enableLazyAggregateExpand) {
return plan
}
plan.transformUp {
case finalAggregate @ HashAggregateExec(
_,
_,
_,
_,
_,
_,
_,
_,
ShuffleExchangeExec(
HashPartitioning(hashExpressions, _),
HashAggregateExec(
_,
_,
_,
groupingExpressions,
aggregateExpressions,
_,
_,
resultExpressions,
ExpandExec(projections, output, child)),
_
)
) =>
logError(s"xxx match plan:$finalAggregate")
// move expand node after shuffle node
if (
groupingExpressions.forall(_.isInstanceOf[Attribute]) &&
hashExpressions.forall(_.isInstanceOf[Attribute]) &&
aggregateExpressions.forall(_.filter.isEmpty)
) {
val shuffle =
finalAggregate.asInstanceOf[HashAggregateExec].child.asInstanceOf[ShuffleExchangeExec]
val partialAggregate = shuffle.child.asInstanceOf[HashAggregateExec]
val expand = partialAggregate.child.asInstanceOf[ExpandExec]

val attributesToReplace = buildReplaceAttributeMapForAggregate(
groupingExpressions,
projections,
output
)
) =>
// move expand node after shuffle node
if (
projections.exists(
projection =>
projection.forall(
e => !e.isInstanceOf[Literal] || e.asInstanceOf[Literal].value != null)) &&
groupingExpressions.forall(_.isInstanceOf[Attribute]) &&
hashExpressions.forall(_.isInstanceOf[Attribute]) &&
aggregateExpressions.forall(_.filter.isEmpty)
) {
val shuffle =
finalAggregate.asInstanceOf[HashAggregateExec].child.asInstanceOf[ShuffleExchangeExec]
val partialAggregate = shuffle.child.asInstanceOf[HashAggregateExec]

val attributesToReplace = buildReplaceAttributeMapForAggregate(
groupingExpressions,
projections,
output
)
val newGroupingExpresion =
groupingExpressions
.filter(_.name.startsWith("spark_grouping_id") == false)
.map(e => attributesToReplace.getOrElse(e.name, e))
val newResultExpressions =
resultExpressions
.filter(_.name.startsWith("spark_grouping_id") == false)
.map(e => attributesToReplace.getOrElse(e.name, e))
val newHashExpresions =
hashExpressions
.filter(_.asInstanceOf[Attribute].name.startsWith("spark_grouping_id") == false)
.map {
e =>
e match {
case ne: NamedExpression => attributesToReplace.getOrElse(ne.name, e)
case _ => e
}
}
val newExpandProjectionTemplate =
partialAggregate.output.map(e => attributesToReplace.getOrElse(e.name, e))
val newExpandProjections = buildNewExpandProjections(
groupingExpressions,
projections,
output,
newExpandProjectionTemplate
)
val newPartialAggregate = partialAggregate.copy(
groupingExpressions = newGroupingExpresion,
resultExpressions = newResultExpressions,
child = child
)
val newExpand =
ExpandExec(newExpandProjections, partialAggregate.output, newPartialAggregate)
val newShuffle = shuffle.copy(child = newExpand)
finalAggregate.copy(child = newShuffle)
} else {
finalAggregate
}
logError(s"xxx attributesToReplace: $attributesToReplace")

val newPartialAggregate = buildNewAggregateExec(
partialAggregate,
expand,
attributesToReplace
)

val newExpand = buildNewExpandExec(
expand,
partialAggregate,
newPartialAggregate,
attributesToReplace
)

val newShuffle = shuffle.copy(child = newExpand)
val newFinalAggregate = finalAggregate.copy(child = newShuffle)
logError(s"xxx new plan: $newFinalAggregate")
newFinalAggregate
} else {
finalAggregate
}
case finalAggregate @ HashAggregateExec(
_,
_,
_,
_,
_,
_,
_,
_,
ShuffleExchangeExec(
HashPartitioning(hashExpressions, _),
HashAggregateExec(
_,
_,
_,
groupingExpressions,
aggregateExpressions,
_,
_,
resultExpressions,
FilterExec(_, ExpandExec(projections, output, child))),
_
)
) =>
logError(s"xxx match plan:$finalAggregate")
if (
groupingExpressions.forall(_.isInstanceOf[Attribute]) &&
hashExpressions.forall(_.isInstanceOf[Attribute]) &&
aggregateExpressions.forall(_.filter.isEmpty)
) {
val shuffle =
finalAggregate.asInstanceOf[HashAggregateExec].child.asInstanceOf[ShuffleExchangeExec]
val partialAggregate = shuffle.child.asInstanceOf[HashAggregateExec]
val filter = partialAggregate.child.asInstanceOf[FilterExec]
val expand = filter.child.asInstanceOf[ExpandExec]

val attributesToReplace = buildReplaceAttributeMapForAggregate(
groupingExpressions,
projections,
output
)
logDebug(s"xxx attributesToReplace: $attributesToReplace")

val newPartialAggregate = buildNewAggregateExec(
partialAggregate,
expand,
attributesToReplace
)

val newExpand = buildNewExpandExec(
expand,
partialAggregate,
newPartialAggregate,
attributesToReplace
)

val newFilter = filter.copy(child = newExpand)
val newShuffle = shuffle.copy(child = newFilter)
val newFinalAggregate = finalAggregate.copy(child = newShuffle)
logError(s"xxx new plan: $newFinalAggregate")
newFinalAggregate
} else {
finalAggregate
}
}
}

/**
 * Looks up the replacement for `toReplace` in the replacement map.
 * Falls back to the attribute itself when no replacement is registered.
 */
def getReplaceAttribute(
    toReplace: Attribute,
    attributesToReplace: Map[Attribute, Attribute]): Attribute =
  attributesToReplace.get(toReplace) match {
    case Some(replacement) => replacement
    case None => toReplace
  }

def buildReplaceAttributeMapForAggregate(
originalGroupingExpressions: Seq[NamedExpression],
originalExpandProjections: Seq[Seq[Expression]],
originalExpandOutput: Seq[Attribute]): Map[String, Attribute] = {
val fullExpandProjection = originalExpandProjections
.filter(
projection =>
projection.forall(
e => !e.isInstanceOf[Literal] || e.asInstanceOf[Literal].value != null))(0)
var attributeMap = Map[String, Attribute]()
originalGroupingExpressions.filter(_.name.startsWith("spark_grouping_id") == false).foreach {
originalExpandOutput: Seq[Attribute]): Map[Attribute, Attribute] = {

var fullExpandProjection = Seq[Expression]()
for (i <- 0 until originalExpandProjections(0).length) {
val attr = originalExpandProjections.find(x => x(i).isInstanceOf[Attribute]) match {
case Some(projection) => projection(i).asInstanceOf[Attribute]
case None => null
}
fullExpandProjection = fullExpandProjection :+ attr
}
var attributeMap = Map[Attribute, Attribute]()
val groupIdAttribute = originalExpandOutput(originalExpandOutput.length - 1)
originalGroupingExpressions.filter(_.toAttribute != groupIdAttribute).foreach {
e =>
val index = originalExpandOutput.indexWhere(_.semanticEquals(e.toAttribute))
attributeMap += (e.name -> fullExpandProjection(index).asInstanceOf[Attribute])
val attr = fullExpandProjection(index).asInstanceOf[Attribute]
// if the grouping key is a literal, cast it to Attribute will be null
if (attr != null) {
attributeMap += (e.toAttribute -> attr)
}
// attributeMap +=(e.toAttribute -> fullExpandProjection(index).asInstanceOf[Attribute])
}
attributeMap
}
Expand All @@ -172,14 +242,80 @@ case class LazyExpandRule(session: SparkSession) extends Rule[SparkPlan] with Lo
projection =>
val res = newExpandOutput.map {
attr =>
groupingKeysPosition.get(attr.name) match {
case Some(attrPos) => projection(attrPos)
case None => attr
if (attr.isInstanceOf[Attribute]) {
groupingKeysPosition.get(attr.name) match {
case Some(attrPos) => projection(attrPos)
case None => attr
}
} else {
attr
}
}
res
}
newExpandProjections
}

// Make the child of expand be the child of aggregate
// Need to replace some attributes
def buildNewAggregateExec(
partialAggregate: HashAggregateExec,
expand: ExpandExec,
attributesToReplace: Map[Attribute, Attribute]): SparkPlan = {
val expandOutput = expand.output
// As far as know, the last attribute in the output is the groupId attribute.
val groupIdAttribute = expandOutput(expandOutput.length - 1)

// if the grouping keys contains literal, they should not be in attributesToReplace
// And we need to remove them from the grouping keys
val groupingExpressions = partialAggregate.groupingExpressions
val newGroupingExpresion =
groupingExpressions
.filter(
e => e.toAttribute != groupIdAttribute && attributesToReplace.contains(e.toAttribute))
.map(e => getReplaceAttribute(e.toAttribute, attributesToReplace))
logError(
s"xxx newGroupingExpresion: $newGroupingExpresion,\n" +
s"groupingExpressions: $groupingExpressions")

// Also need to remove literal grouping keys from the result expressions
val resultExpressions = partialAggregate.resultExpressions
val newResultExpressions =
resultExpressions
.filter {
e =>
e.toAttribute != groupIdAttribute && (groupingExpressions
.find(_.toAttribute == e.toAttribute)
.isEmpty || attributesToReplace.contains(e.toAttribute))
}
.map(e => getReplaceAttribute(e.toAttribute, attributesToReplace))
logError(
s"xxx newResultExpressions: $newResultExpressions\n" +
s"resultExpressions:$resultExpressions")
partialAggregate.copy(
groupingExpressions = newGroupingExpresion,
resultExpressions = newResultExpressions,
child = expand.child)
}

// Rebuild the expand node on top of `child` (the new pre-aggregate node), using the
// pre-aggregate's output as the projection template and replacing attributes
// according to `attributesToReplace`.
def buildNewExpandExec(
    expand: ExpandExec,
    partialAggregate: HashAggregateExec,
    child: SparkPlan,
    attributesToReplace: Map[Attribute, Attribute]): SparkPlan = {
  val newExpandProjectionTemplate =
    partialAggregate.output
      .map(e => getReplaceAttribute(e, attributesToReplace))
  // Trace logging belongs at debug level, not error level.
  logDebug(s"newExpandProjectionTemplate: $newExpandProjectionTemplate")

  val newExpandProjections = buildNewExpandProjections(
    partialAggregate.groupingExpressions,
    expand.projections,
    expand.output,
    newExpandProjectionTemplate
  )
  logDebug(s"newExpandProjections: $newExpandProjections\nprojections: ${expand.projections}")
  ExpandExec(newExpandProjections, partialAggregate.output, child)
}

}
Loading

0 comments on commit 9488005

Please sign in to comment.