diff --git a/gluten-core/src/main/scala/io/glutenproject/extension/ExpandFallbackPolicy.scala b/gluten-core/src/main/scala/io/glutenproject/extension/ExpandFallbackPolicy.scala index 6ac39bb3f1a4b..97159378585b0 100644 --- a/gluten-core/src/main/scala/io/glutenproject/extension/ExpandFallbackPolicy.scala +++ b/gluten-core/src/main/scala/io/glutenproject/extension/ExpandFallbackPolicy.scala @@ -136,7 +136,7 @@ case class ExpandFallbackPolicy(isAdaptiveContext: Boolean, originalPlan: SparkP /** * Find a Gluten plan whose child is QueryStageExec. Then, increase stageFallbackCost if the - * last query stage's output is columnar. + * last query stage's plan is GlutenPlan and decrease stageFallbackCost if not. */ def countStageFallbackCostInternal(plan: SparkPlan): Unit = { plan match { @@ -144,8 +144,14 @@ case class ExpandFallbackPolicy(isAdaptiveContext: Boolean, originalPlan: SparkP p.children .filter(_.isInstanceOf[QueryStageExec]) .foreach { - case stage: QueryStageExec if stage.supportsColumnar => + case stage: QueryStageExec + if stage.plan.isInstanceOf[GlutenPlan] || + InMemoryTableScanHelper.isGlutenTableCache(stage) => stageFallbackCost = stageFallbackCost + 1 + // For other cases, RowToColumnar will be removed if stage falls back, so reduce + // the cost. + case _ => + stageFallbackCost = stageFallbackCost - 1 } case p => p.children.foreach(countStageFallbackCostInternal) } @@ -174,7 +180,7 @@ case class ExpandFallbackPolicy(isAdaptiveContext: Boolean, originalPlan: SparkP GlutenConfig.getConf.wholeStageFallbackThreshold } else if (plan.find(_.isInstanceOf[AdaptiveSparkPlanExec]).isDefined) { // if we are here, that means we are now at `QueryExecution.preparations` and - // AQE is actually applied. We do nothing for this case, and later in + // AQE is actually not applied. We do nothing for this case, and later in // AQE we can check `wholeStageFallbackThreshold`. return None } else {