Skip to content

Commit

Permalink
Refine the code
Browse files Browse the repository at this point in the history
  • Loading branch information
PHILO-HE committed Oct 31, 2023
1 parent 1c6826c commit 8dfd6cf
Showing 1 changed file with 70 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -67,49 +67,93 @@ import org.apache.spark.sql.execution.exchange.{Exchange, ReusedExchangeExec}
case class ExpandFallbackPolicy(isAdaptiveContext: Boolean, originalPlan: SparkPlan)
extends Rule[SparkPlan] {

/**
 * Counts the leaf query stages of `plan` whose output is columnar. When a stage is entirely
 * fallen back, each such columnar leaf stage would require a ColumnarToRow adapter, so this
 * number offsets the raw fallback count.
 *
 * @param plan the (sub)plan whose leaves are inspected
 * @return number of leaf [[QueryStageExec]] nodes that report columnar output
 */
private def countColumnarToRowWhenFallbackStage(plan: SparkPlan): Int = {
  // `count` with a pattern-matching function replaces the verbose filter(...).size form.
  plan.collectLeaves().count {
    case q: QueryStageExec => q.supportsColumnar
    case _ => false
  }
}

/**
 * Counts the row-based (fallen back) transitions in `plan`.
 *
 * A "fallback" is counted when a Gluten columnar subtree is converted back to rows
 * (`ColumnarToRowExec` over a [[GlutenPlan]]), when a vanilla Spark unary node sits directly on
 * top of a Gluten table cache (the cache-to-row conversion is an implicit ColumnarToRow), or
 * when a non-Gluten leaf node is reached. Traversal stops at query-stage boundaries so each
 * AQE stage is evaluated independently.
 *
 * @param plan the (sub)plan to inspect
 * @return number of fallback transitions found in this stage
 */
private def countFallback(plan: SparkPlan): Int = {
  var fallbacks = 0
  def countFallbackInternal(plan: SparkPlan): Unit = {
    plan match {
      case _: QueryStageExec => // Another stage.
      case _: CommandResultExec | _: ExecutedCommandExec => // ignore
      // we plan exchange to columnar exchange in columnar rules and the exchange does not
      // support columnar, so the output columnar is always false in AQE postStageCreationRules
      case ColumnarToRowExec(s: Exchange) if isAdaptiveContext =>
        countFallbackInternal(s)
      case u: UnaryExecNode
          if !u.isInstanceOf[GlutenPlan] && InMemoryTableScanHelper.isGlutenTableCache(u.child) =>
        // Vanilla Spark plan will call `InMemoryTableScanExec.convertCachedBatchToInternalRow`
        // which is a kind of `ColumnarToRowExec`.
        fallbacks = fallbacks + 1
        countFallbackInternal(u.child)
      case ColumnarToRowExec(p: GlutenPlan) =>
        logDebug(s"Find a columnar to row for gluten plan:\n$p")
        fallbacks = fallbacks + 1
        countFallbackInternal(p)
      case leafPlan: LeafExecNode if InMemoryTableScanHelper.isGlutenTableCache(leafPlan) =>
      // Gluten table cache is columnar output, so it is not a fallback leaf.
      case leafPlan: LeafExecNode if !leafPlan.isInstanceOf[GlutenPlan] =>
        // Possible fallback for leaf node.
        fallbacks = fallbacks + 1
      case p => p.children.foreach(countFallbackInternal)
    }
  }
  countFallbackInternal(plan)
  fallbacks
}

/**
 * When making a stage fall back, it's possible that we need a ColumnarToRow to adapt to the
 * last stage's columnar output. So we need to evaluate the cost, i.e., the number of required
 * ColumnarToRow nodes between the entirely-fallen-back stage and the last stage(s). Thus, we
 * can avoid possible performance degradation caused by the fallback policy.
 *
 * spotless:off
 *
 * Spark plan before applying fallback policy:
 *
 *        ColumnarExchange
 *  ----------- | --------------- last stage
 *    HashAggregateTransformer
 *              |
 *        ColumnarToRow
 *              |
 *           Project
 *
 * To illustrate the effect if cost is not taken into account, here is the spark plan
 * after applying whole stage fallback policy (threshold = 1):
 *
 *        ColumnarExchange
 *  ----------- | --------------- last stage
 *         ColumnarToRow
 *              |
 *         HashAggregate
 *              |
 *           Project
 *
 * So by considering the cost, we will not apply the fallback policy.
 *
 * spotless:on
 *
 * @param plan the stage plan being considered for whole-stage fallback
 * @return number of columnar last-stage outputs this plan consumes directly through Gluten nodes
 */
private def countStageFallbackCost(plan: SparkPlan): Int = {
  var stageFallbackCost = 0

  /**
   * Find a Gluten plan whose child is QueryStageExec. Then, increase stageFallbackCost if the
   * last query stage's output is columnar.
   */
  def countStageFallbackCostInternal(plan: SparkPlan): Unit = {
    plan match {
      case p: GlutenPlan if p.children.exists(_.isInstanceOf[QueryStageExec]) =>
        // Count only the columnar stage children. The original partial function had no case
        // for a non-columnar QueryStageExec and would throw scala.MatchError on it.
        stageFallbackCost += p.children.count {
          case stage: QueryStageExec => stage.supportsColumnar
          case _ => false
        }
      case p => p.children.foreach(countStageFallbackCostInternal)
    }
  }
  countStageFallbackCostInternal(plan)
  stageFallbackCost
}

private def hasColumnarBroadcastExchangeWithJoin(plan: SparkPlan): Boolean = {
def isColumnarBroadcastExchange(p: SparkPlan): Boolean = p match {
case BroadcastQueryStageExec(_, _: ColumnarBroadcastExchangeExec, _) => true
Expand Down Expand Up @@ -146,11 +190,15 @@ case class ExpandFallbackPolicy(isAdaptiveContext: Boolean, originalPlan: SparkP
return None
}

val netFallbackNum = countFallbacks(plan) - countColumnarToRowWhenFallbackStage(plan)
val netFallbackNum = if (isAdaptiveContext) {
countFallback(plan) - countStageFallbackCost(plan)
} else {
countFallback(plan)
}
if (netFallbackNum >= fallbackThreshold) {
Some(
s"Fall back the plan due to net fallback number $netFallbackNum, " +
s"threshold $fallbackThreshold")
s"Fallback policy is taking effect, net fallback number: $netFallbackNum, " +
s"threshold: $fallbackThreshold")
} else {
None
}
Expand Down

0 comments on commit 8dfd6cf

Please sign in to comment.