From 8dfd6cfa78f55d70419075a3f3c9ad64b8daed05 Mon Sep 17 00:00:00 2001 From: PHILO-HE Date: Mon, 30 Oct 2023 21:22:30 +0800 Subject: [PATCH] Refine the code --- .../extension/ExpandFallbackPolicy.scala | 92 ++++++++++++++----- 1 file changed, 70 insertions(+), 22 deletions(-) diff --git a/gluten-core/src/main/scala/io/glutenproject/extension/ExpandFallbackPolicy.scala b/gluten-core/src/main/scala/io/glutenproject/extension/ExpandFallbackPolicy.scala index 191c4a80f1d15..2c5f0825647ac 100644 --- a/gluten-core/src/main/scala/io/glutenproject/extension/ExpandFallbackPolicy.scala +++ b/gluten-core/src/main/scala/io/glutenproject/extension/ExpandFallbackPolicy.scala @@ -67,49 +67,93 @@ import org.apache.spark.sql.execution.exchange.{Exchange, ReusedExchangeExec} case class ExpandFallbackPolicy(isAdaptiveContext: Boolean, originalPlan: SparkPlan) extends Rule[SparkPlan] { - private def countColumnarToRowWhenFallbackStage(plan: SparkPlan): Int = { - plan - .collectLeaves() - .filter( - p => - p match { - case q: QueryStageExec if q.supportsColumnar => true - case _ => false - }) - .size - } - - private def countFallbacks(plan: SparkPlan): Int = { + private def countFallback(plan: SparkPlan): Int = { var fallbacks = 0 - def countFallback(plan: SparkPlan): Unit = { + def countFallbackInternal(plan: SparkPlan): Unit = { plan match { case _: QueryStageExec => // Another stage. case _: CommandResultExec | _: ExecutedCommandExec => // ignore // we plan exchange to columnar exchange in columnar rules and the exchange does not // support columnar, so the output columnar is always false in AQE postStageCreationRules case ColumnarToRowExec(s: Exchange) if isAdaptiveContext => - countFallback(s) + countFallbackInternal(s) case u: UnaryExecNode if !u.isInstanceOf[GlutenPlan] && InMemoryTableScanHelper.isGlutenTableCache(u.child) => // Vanilla Spark plan will call `InMemoryTableScanExec.convertCachedBatchToInternalRow` // which is a kind of `ColumnarToRowExec`. 
fallbacks = fallbacks + 1 - countFallback(u.child) + countFallbackInternal(u.child) case ColumnarToRowExec(p: GlutenPlan) => logDebug(s"Find a columnar to row for gluten plan:\n$p") fallbacks = fallbacks + 1 - countFallback(p) + countFallbackInternal(p) case leafPlan: LeafExecNode if InMemoryTableScanHelper.isGlutenTableCache(leafPlan) => case leafPlan: LeafExecNode if !leafPlan.isInstanceOf[GlutenPlan] => // Possible fallback for leaf node. fallbacks = fallbacks + 1 - case p => p.children.foreach(countFallback) + case p => p.children.foreach(countFallbackInternal) } } - countFallback(plan) + countFallbackInternal(plan) fallbacks } + /** + * When making a stage fall back, it's possible that we need a ColumnarToRow to adapt to the last + * stage's columnar output. So we need to evaluate the cost, i.e., the number of required + * ColumnarToRow nodes between an entirely fallen-back stage and the last stage(s). Thus, we can avoid possible + * performance degradation caused by the fallback policy. + * + * spotless:off + * + * Spark plan before applying fallback policy: + * + * ColumnarExchange + * ----------- | --------------- last stage + * HashAggregateTransformer + * | + * ColumnarToRow + * | + * Project + * + * To illustrate the effect if cost is not taken into account, here is the Spark plan + * after applying the whole-stage fallback policy (threshold = 1): + * + * ColumnarExchange + * ----------- | --------------- last stage + * ColumnarToRow + * | + * HashAggregate + * | + * Project + * + * So by considering the cost, we will not apply the fallback policy. + * + * spotless:on + */ + private def countStageFallbackCost(plan: SparkPlan): Int = { + var stageFallbackCost = 0 + + /** + * Find a Gluten plan whose child is QueryStageExec. Then, increase stageFallbackCost if the + * last query stage's output is columnar. 
+ */ + def countStageFallbackCostInternal(plan: SparkPlan): Unit = { + plan match { + case p: GlutenPlan if p.children.find(_.isInstanceOf[QueryStageExec]).isDefined => + p.children + .filter(_.isInstanceOf[QueryStageExec]) + .foreach { + case stage: QueryStageExec => + if (stage.supportsColumnar) stageFallbackCost = stageFallbackCost + 1 + } + case p => p.children.foreach(countStageFallbackCostInternal) + } + } + countStageFallbackCostInternal(plan) + stageFallbackCost + } + private def hasColumnarBroadcastExchangeWithJoin(plan: SparkPlan): Boolean = { def isColumnarBroadcastExchange(p: SparkPlan): Boolean = p match { case BroadcastQueryStageExec(_, _: ColumnarBroadcastExchangeExec, _) => true @@ -146,11 +190,15 @@ case class ExpandFallbackPolicy(isAdaptiveContext: Boolean, originalPlan: SparkP return None } - val netFallbackNum = countFallbacks(plan) - countColumnarToRowWhenFallbackStage(plan) + val netFallbackNum = if (isAdaptiveContext) { + countFallback(plan) - countStageFallbackCost(plan) + } else { + countFallback(plan) + } if (netFallbackNum >= fallbackThreshold) { Some( - s"Fall back the plan due to net fallback number $netFallbackNum, " + - s"threshold $fallbackThreshold") + s"Fallback policy is taking effect, net fallback number: $netFallbackNum, " + + s"threshold: $fallbackThreshold") } else { None }