diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala index dea0d50c9da6..f7cfa215a042 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala @@ -66,28 +66,28 @@ private object CHRuleApi { // Gluten columnar: Transform rules. injector.injectTransform(_ => RemoveTransitions) injector.injectTransform(_ => PushDownInputFileExpression.PreOffload) - injector.injectTransform(c => FallbackOnANSIMode.apply(c.session)) - injector.injectTransform(c => FallbackMultiCodegens.apply(c.session)) + injector.injectTransform(c => FallbackOnANSIMode.apply(c.spark)) + injector.injectTransform(c => FallbackMultiCodegens.apply(c.spark)) injector.injectTransform(_ => RewriteSubqueryBroadcast()) - injector.injectTransform(c => FallbackBroadcastHashJoin.apply(c.session)) - injector.injectTransform(c => MergeTwoPhasesHashBaseAggregate.apply(c.session)) + injector.injectTransform(c => FallbackBroadcastHashJoin.apply(c.spark)) + injector.injectTransform(c => MergeTwoPhasesHashBaseAggregate.apply(c.spark)) injector.injectTransform(_ => intercept(RewriteSparkPlanRulesManager())) injector.injectTransform(_ => intercept(AddFallbackTagRule())) injector.injectTransform(_ => intercept(TransformPreOverrides())) injector.injectTransform(_ => RemoveNativeWriteFilesSortAndProject()) - injector.injectTransform(c => intercept(RewriteTransformer.apply(c.session))) + injector.injectTransform(c => intercept(RewriteTransformer.apply(c.spark))) injector.injectTransform(_ => PushDownFilterToScan) injector.injectTransform(_ => PushDownInputFileExpression.PostOffload) injector.injectTransform(_ => EnsureLocalSortRequirements) injector.injectTransform(_ => EliminateLocalSort) injector.injectTransform(_ => CollapseProjectExecTransformer) - injector.injectTransform(c => RewriteSortMergeJoinToHashJoinRule.apply(c.session)) - injector.injectTransform(c => PushdownAggregatePreProjectionAheadExpand.apply(c.session)) + injector.injectTransform(c => RewriteSortMergeJoinToHashJoinRule.apply(c.spark)) + injector.injectTransform(c => PushdownAggregatePreProjectionAheadExpand.apply(c.spark)) injector.injectTransform( c => intercept( SparkPlanRules.extendedColumnarRule(c.glutenConf.extendedColumnarTransformRules)( - c.session))) + c.spark))) injector.injectTransform(c => InsertTransitions(c.outputsColumnar)) // Gluten columnar: Fallback policies. @@ -95,19 +95,19 @@ private object CHRuleApi { c => ExpandFallbackPolicy(c.ac.isAdaptiveContext(), c.ac.originalPlan())) // Gluten columnar: Post rules. - injector.injectPost(c => RemoveTopmostColumnarToRow(c.session, c.ac.isAdaptiveContext())) + injector.injectPost(c => RemoveTopmostColumnarToRow(c.spark, c.ac.isAdaptiveContext())) SparkShimLoader.getSparkShims .getExtendedColumnarPostRules() - .foreach(each => injector.injectPost(c => intercept(each(c.session)))) + .foreach(each => injector.injectPost(c => intercept(each(c.spark)))) injector.injectPost(c => ColumnarCollapseTransformStages(c.glutenConf)) injector.injectTransform( c => intercept( - SparkPlanRules.extendedColumnarRule(c.glutenConf.extendedColumnarPostRules)(c.session))) + SparkPlanRules.extendedColumnarRule(c.glutenConf.extendedColumnarPostRules)(c.spark))) // Gluten columnar: Final rules. - injector.injectFinal(c => RemoveGlutenTableCacheColumnarToRow(c.session)) - injector.injectFinal(c => GlutenFallbackReporter(c.glutenConf, c.session)) + injector.injectFinal(c => RemoveGlutenTableCacheColumnarToRow(c.spark)) + injector.injectFinal(c => GlutenFallbackReporter(c.glutenConf, c.spark)) injector.injectFinal(_ => RemoveFallbackTagRule()) } diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CommonSubexpressionEliminateRule.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CommonSubexpressionEliminateRule.scala index a944f55450d4..2299400f9520 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CommonSubexpressionEliminateRule.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CommonSubexpressionEliminateRule.scala @@ -34,13 +34,15 @@ import scala.collection.mutable // --conf spark.sql.planChangeLog.batches=all class CommonSubexpressionEliminateRule(spark: SparkSession) extends Rule[LogicalPlan] with Logging { + private val glutenConf = new GlutenConfig(spark) + private var lastPlan: LogicalPlan = null override def apply(plan: LogicalPlan): LogicalPlan = { val newPlan = if ( - plan.resolved && GlutenConfig.getConf.enableGluten - && GlutenConfig.getConf.enableCommonSubexpressionEliminate && !plan.fastEquals(lastPlan) + plan.resolved && glutenConf.enableGluten + && glutenConf.enableCommonSubexpressionEliminate && !plan.fastEquals(lastPlan) ) { lastPlan = plan visitPlan(plan) diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/FallbackBroadcastHashJoinRules.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/FallbackBroadcastHashJoinRules.scala index 207bb0e3a4d7..403658e5f3ab 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/FallbackBroadcastHashJoinRules.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/FallbackBroadcastHashJoinRules.scala @@ -36,9 +36,11 @@ import scala.util.control.Breaks.{break, breakable} // see each other during transformation. In order to prevent BroadcastExec being transformed // to columnar while BHJ fallbacks, BroadcastExec need to be tagged not transformable when applying // queryStagePrepRules. -case class FallbackBroadcastHashJoinPrepQueryStage(session: SparkSession) extends Rule[SparkPlan] { +case class FallbackBroadcastHashJoinPrepQueryStage(spark: SparkSession) extends Rule[SparkPlan] { + + private val glutenConf: GlutenConfig = new GlutenConfig(spark) + override def apply(plan: SparkPlan): SparkPlan = { - val glutenConf: GlutenConfig = GlutenConfig.getConf plan.foreach { case bhj: BroadcastHashJoinExec => val buildSidePlan = bhj.buildSide match { @@ -144,15 +146,17 @@ case class FallbackBroadcastHashJoinPrepQueryStage(session: SparkSession) extend // For similar purpose with FallbackBroadcastHashJoinPrepQueryStage, executed during applying // columnar rules. -case class FallbackBroadcastHashJoin(session: SparkSession) extends Rule[SparkPlan] { +case class FallbackBroadcastHashJoin(spark: SparkSession) extends Rule[SparkPlan] { + + private val glutenConf: GlutenConfig = new GlutenConfig(spark) private val enableColumnarBroadcastJoin: Boolean = - GlutenConfig.getConf.enableColumnarBroadcastJoin && - GlutenConfig.getConf.enableColumnarBroadcastExchange + glutenConf.enableColumnarBroadcastJoin && + glutenConf.enableColumnarBroadcastExchange private val enableColumnarBroadcastNestedLoopJoin: Boolean = - GlutenConfig.getConf.broadcastNestedLoopJoinTransformerTransformerEnabled && - GlutenConfig.getConf.enableColumnarBroadcastExchange + glutenConf.broadcastNestedLoopJoinTransformerTransformerEnabled && + glutenConf.enableColumnarBroadcastExchange override def apply(plan: SparkPlan): SparkPlan = { plan.foreachUp { diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/MergeTwoPhasesHashBaseAggregate.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/MergeTwoPhasesHashBaseAggregate.scala index a10659b6d5e7..a652bf42a627 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/MergeTwoPhasesHashBaseAggregate.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/MergeTwoPhasesHashBaseAggregate.scala @@ -34,14 +34,11 @@ import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregat * Note: this rule must be applied before the `PullOutPreProject` rule, because the * `PullOutPreProject` rule will modify the attributes in some cases. */ -case class MergeTwoPhasesHashBaseAggregate(session: SparkSession) +case class MergeTwoPhasesHashBaseAggregate(spark: SparkSession) extends Rule[SparkPlan] with Logging { - val glutenConf: GlutenConfig = GlutenConfig.getConf - val scanOnly: Boolean = glutenConf.enableScanOnly - val enableColumnarHashAgg: Boolean = !scanOnly && glutenConf.enableColumnarHashAgg - val replaceSortAggWithHashAgg: Boolean = GlutenConfig.getConf.forceToUseHashAgg + private val glutenConf: GlutenConfig = new GlutenConfig(spark) private def isPartialAgg(partialAgg: BaseAggregateExec, finalAgg: BaseAggregateExec): Boolean = { // TODO: now it can not support to merge agg which there are the filters in the aggregate exprs. @@ -59,7 +56,7 @@ case class MergeTwoPhasesHashBaseAggregate(session: SparkSession) } override def apply(plan: SparkPlan): SparkPlan = { - if (!enableColumnarHashAgg) { + if (glutenConf.enableScanOnly || !glutenConf.enableColumnarHashAgg) { plan } else { plan.transformDown { @@ -111,7 +108,7 @@ case class MergeTwoPhasesHashBaseAggregate(session: SparkSession) _, resultExpressions, child: SortAggregateExec) - if replaceSortAggWithHashAgg && !isStreaming && isPartialAgg(child, sortAgg) => + if glutenConf.forceToUseHashAgg && !isStreaming && isPartialAgg(child, sortAgg) => // convert to complete mode aggregate expressions val completeAggregateExpressions = aggregateExpressions.map(_.copy(mode = Complete)) sortAgg.copy( diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/RewriteDateTimestampComparisonRule.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/RewriteDateTimestampComparisonRule.scala index fa8a37ffa2be..de04ed7f3f37 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/RewriteDateTimestampComparisonRule.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/RewriteDateTimestampComparisonRule.scala @@ -53,11 +53,11 @@ class RewriteDateTimestampComparisonRule(spark: SparkSession) "yyyy" ) + private val glutenConf = new GlutenConfig(spark) + override def apply(plan: LogicalPlan): LogicalPlan = { if ( - plan.resolved && - GlutenConfig.getConf.enableGluten && - GlutenConfig.getConf.enableRewriteDateTimestampComparison + plan.resolved && glutenConf.enableGluten && glutenConf.enableRewriteDateTimestampComparison ) { visitPlan(plan) } else { diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/RewriteToDateExpresstionRule.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/RewriteToDateExpresstionRule.scala index 6e8486330465..8274cf2d728d 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/RewriteToDateExpresstionRule.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/RewriteToDateExpresstionRule.scala @@ -39,12 +39,10 @@ import org.apache.spark.sql.types._ // Optimized result is `to_date(stringType)` class RewriteToDateExpresstionRule(spark: SparkSession) extends Rule[LogicalPlan] with Logging { + private val glutenConf = new GlutenConfig(spark) + override def apply(plan: LogicalPlan): LogicalPlan = { - if ( - plan.resolved && - GlutenConfig.getConf.enableGluten && - GlutenConfig.getConf.enableCHRewriteDateConversion - ) { + if (plan.resolved && glutenConf.enableGluten && glutenConf.enableCHRewriteDateConversion) { visitPlan(plan) } else { plan diff --git a/backends-clickhouse/src/main/scala/org/apache/spark/sql/catalyst/CHAggregateFunctionRewriteRule.scala b/backends-clickhouse/src/main/scala/org/apache/spark/sql/catalyst/CHAggregateFunctionRewriteRule.scala index b11e1e2bb306..4ea59f659ca2 100644 --- a/backends-clickhouse/src/main/scala/org/apache/spark/sql/catalyst/CHAggregateFunctionRewriteRule.scala +++ b/backends-clickhouse/src/main/scala/org/apache/spark/sql/catalyst/CHAggregateFunctionRewriteRule.scala @@ -31,12 +31,15 @@ import org.apache.spark.sql.types._ * @param spark */ case class CHAggregateFunctionRewriteRule(spark: SparkSession) extends Rule[LogicalPlan] { + + private val glutenConf = new GlutenConfig(spark) + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { case a: Aggregate => a.transformExpressions { case avgExpr @ AggregateExpression(avg: Average, _, _, _, _) - if GlutenConfig.getConf.enableCastAvgAggregateFunction && - GlutenConfig.getConf.enableColumnarHashAgg && + if glutenConf.enableCastAvgAggregateFunction && + glutenConf.enableColumnarHashAgg && !avgExpr.isDistinct && isDataTypeNeedConvert(avg.child.dataType) => AggregateExpression( avg.copy(child = Cast(avg.child, DoubleType)), diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala index 3554bc5c9c01..5f49f16469e0 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala @@ -52,23 +52,23 @@ private object VeloxRuleApi { // Gluten columnar: Transform rules. injector.injectTransform(_ => RemoveTransitions) injector.injectTransform(_ => PushDownInputFileExpression.PreOffload) - injector.injectTransform(c => FallbackOnANSIMode.apply(c.session)) - injector.injectTransform(c => FallbackMultiCodegens.apply(c.session)) + injector.injectTransform(c => FallbackOnANSIMode.apply(c.spark)) + injector.injectTransform(c => FallbackMultiCodegens.apply(c.spark)) injector.injectTransform(_ => RewriteSubqueryBroadcast()) - injector.injectTransform(c => BloomFilterMightContainJointRewriteRule.apply(c.session)) - injector.injectTransform(c => ArrowScanReplaceRule.apply(c.session)) + injector.injectTransform(c => BloomFilterMightContainJointRewriteRule.apply(c.spark)) + injector.injectTransform(c => ArrowScanReplaceRule.apply(c.spark)) injector.injectTransform(_ => RewriteSparkPlanRulesManager()) injector.injectTransform(_ => AddFallbackTagRule()) injector.injectTransform(_ => TransformPreOverrides()) - injector.injectTransform(c => PartialProjectRule.apply(c.session)) + injector.injectTransform(c => PartialProjectRule.apply(c.spark)) injector.injectTransform(_ => RemoveNativeWriteFilesSortAndProject()) - injector.injectTransform(c => RewriteTransformer.apply(c.session)) + injector.injectTransform(c => RewriteTransformer.apply(c.spark)) injector.injectTransform(_ => PushDownFilterToScan) injector.injectTransform(_ => PushDownInputFileExpression.PostOffload) injector.injectTransform(_ => EnsureLocalSortRequirements) injector.injectTransform(_ => EliminateLocalSort) injector.injectTransform(_ => CollapseProjectExecTransformer) - injector.injectTransform(c => FlushableHashAggregateRule.apply(c.session)) + injector.injectTransform(c => FlushableHashAggregateRule.apply(c.spark)) injector.injectTransform(c => InsertTransitions(c.outputsColumnar)) // Gluten columnar: Fallback policies. @@ -76,15 +76,15 @@ private object VeloxRuleApi { c => ExpandFallbackPolicy(c.ac.isAdaptiveContext(), c.ac.originalPlan())) // Gluten columnar: Post rules. - injector.injectPost(c => RemoveTopmostColumnarToRow(c.session, c.ac.isAdaptiveContext())) + injector.injectPost(c => RemoveTopmostColumnarToRow(c.spark, c.ac.isAdaptiveContext())) SparkShimLoader.getSparkShims .getExtendedColumnarPostRules() - .foreach(each => injector.injectPost(c => each(c.session))) + .foreach(each => injector.injectPost(c => each(c.spark))) injector.injectPost(c => ColumnarCollapseTransformStages(c.glutenConf)) // Gluten columnar: Final rules. - injector.injectFinal(c => RemoveGlutenTableCacheColumnarToRow(c.session)) - injector.injectFinal(c => GlutenFallbackReporter(c.glutenConf, c.session)) + injector.injectFinal(c => RemoveGlutenTableCacheColumnarToRow(c.spark)) + injector.injectFinal(c => GlutenFallbackReporter(c.glutenConf, c.spark)) injector.injectFinal(_ => RemoveFallbackTagRule()) } @@ -92,33 +92,33 @@ private object VeloxRuleApi { // Gluten RAS: Pre rules. injector.inject(_ => RemoveTransitions) injector.inject(_ => PushDownInputFileExpression.PreOffload) - injector.inject(c => FallbackOnANSIMode.apply(c.session)) + injector.inject(c => FallbackOnANSIMode.apply(c.spark)) injector.inject(_ => RewriteSubqueryBroadcast()) - injector.inject(c => BloomFilterMightContainJointRewriteRule.apply(c.session)) - injector.inject(c => ArrowScanReplaceRule.apply(c.session)) + injector.inject(c => BloomFilterMightContainJointRewriteRule.apply(c.spark)) + injector.inject(c => ArrowScanReplaceRule.apply(c.spark)) // Gluten RAS: The RAS rule. - injector.inject(c => EnumeratedTransform(c.session, c.outputsColumnar)) + injector.inject(c => EnumeratedTransform(c.spark, c.outputsColumnar)) // Gluten RAS: Post rules. injector.inject(_ => RemoveTransitions) - injector.inject(c => PartialProjectRule.apply(c.session)) + injector.inject(c => PartialProjectRule.apply(c.spark)) injector.inject(_ => RemoveNativeWriteFilesSortAndProject()) - injector.inject(c => RewriteTransformer.apply(c.session)) + injector.inject(c => RewriteTransformer.apply(c.spark)) injector.inject(_ => PushDownFilterToScan) injector.inject(_ => PushDownInputFileExpression.PostOffload) injector.inject(_ => EnsureLocalSortRequirements) injector.inject(_ => EliminateLocalSort) injector.inject(_ => CollapseProjectExecTransformer) - injector.inject(c => FlushableHashAggregateRule.apply(c.session)) + injector.inject(c => FlushableHashAggregateRule.apply(c.spark)) injector.inject(c => InsertTransitions(c.outputsColumnar)) - injector.inject(c => RemoveTopmostColumnarToRow(c.session, c.ac.isAdaptiveContext())) + injector.inject(c => RemoveTopmostColumnarToRow(c.spark, c.ac.isAdaptiveContext())) SparkShimLoader.getSparkShims .getExtendedColumnarPostRules() - .foreach(each => injector.inject(c => each(c.session))) + .foreach(each => injector.inject(c => each(c.spark))) injector.inject(c => ColumnarCollapseTransformStages(c.glutenConf)) - injector.inject(c => RemoveGlutenTableCacheColumnarToRow(c.session)) - injector.inject(c => GlutenFallbackReporter(c.glutenConf, c.session)) + injector.inject(c => RemoveGlutenTableCacheColumnarToRow(c.spark)) + injector.inject(c => GlutenFallbackReporter(c.glutenConf, c.spark)) injector.inject(_ => RemoveFallbackTagRule()) } } diff --git a/backends-velox/src/main/scala/org/apache/gluten/execution/ColumnarPartialProjectExec.scala b/backends-velox/src/main/scala/org/apache/gluten/execution/ColumnarPartialProjectExec.scala index d993e399dbf4..26adcb009f43 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/execution/ColumnarPartialProjectExec.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/execution/ColumnarPartialProjectExec.scala @@ -16,7 +16,6 @@ */ package org.apache.gluten.execution -import org.apache.gluten.GlutenConfig import org.apache.gluten.columnarbatch.{ColumnarBatches, VeloxColumnarBatches} import org.apache.gluten.expression.{ArrowProjection, ExpressionUtils} import org.apache.gluten.extension.{GlutenPlan, ValidationResult} @@ -134,7 +133,7 @@ case class ColumnarPartialProjectExec(original: ProjectExec, child: SparkPlan)( } override protected def doValidateInternal(): ValidationResult = { - if (!GlutenConfig.getConf.enableColumnarPartialProject) { + if (!glutenConf.enableColumnarPartialProject) { return ValidationResult.failed("Config disable this feature") } if (UDFAttrNotExists) { @@ -159,11 +158,7 @@ case class ColumnarPartialProjectExec(original: ProjectExec, child: SparkPlan)( if (!original.projectList.forall(validateExpression(_))) { return ValidationResult.failed("Contains expression not supported") } - if ( - ExpressionUtils.hasComplexExpressions( - original, - GlutenConfig.getConf.fallbackExpressionsThreshold) - ) { + if (ExpressionUtils.hasComplexExpressions(original, glutenConf.fallbackExpressionsThreshold)) { return ValidationResult.failed("Fallback by complex expression") } ValidationResult.succeeded diff --git a/backends-velox/src/main/scala/org/apache/gluten/execution/RowToVeloxColumnarExec.scala b/backends-velox/src/main/scala/org/apache/gluten/execution/RowToVeloxColumnarExec.scala index a853778484b1..04af9eb1fa6e 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/execution/RowToVeloxColumnarExec.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/execution/RowToVeloxColumnarExec.scala @@ -16,7 +16,6 @@ */ package org.apache.gluten.execution -import org.apache.gluten.GlutenConfig import org.apache.gluten.columnarbatch.ColumnarBatches import org.apache.gluten.iterator.Iterators import org.apache.gluten.memory.arrow.alloc.ArrowBufferAllocators @@ -48,7 +47,7 @@ case class RowToVeloxColumnarExec(child: SparkPlan) extends RowToColumnarExecBas val numInputRows = longMetric("numInputRows") val numOutputBatches = longMetric("numOutputBatches") val convertTime = longMetric("convertTime") - val numRows = GlutenConfig.getConf.maxBatchSize + val numRows = glutenConf.maxBatchSize // This avoids calling `schema` in the RDD closure, so that we don't need to include the entire // plan (this) in the closure. val localSchema = schema @@ -68,7 +67,7 @@ case class RowToVeloxColumnarExec(child: SparkPlan) extends RowToColumnarExecBas val numInputRows = longMetric("numInputRows") val numOutputBatches = longMetric("numOutputBatches") val convertTime = longMetric("convertTime") - val numRows = GlutenConfig.getConf.maxBatchSize + val numRows = glutenConf.maxBatchSize val mode = BroadcastUtils.getBroadcastMode(outputPartitioning) val relation = child.executeBroadcast() BroadcastUtils.sparkToVeloxUnsafe( diff --git a/backends-velox/src/main/scala/org/apache/gluten/extension/BloomFilterMightContainJointRewriteRule.scala b/backends-velox/src/main/scala/org/apache/gluten/extension/BloomFilterMightContainJointRewriteRule.scala index 56a3d86a9038..1a2c291b7206 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/extension/BloomFilterMightContainJointRewriteRule.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/extension/BloomFilterMightContainJointRewriteRule.scala @@ -26,8 +26,11 @@ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.SparkPlan case class BloomFilterMightContainJointRewriteRule(spark: SparkSession) extends Rule[SparkPlan] { + + private val glutenConf = new GlutenConfig(spark) + override def apply(plan: SparkPlan): SparkPlan = { - if (!GlutenConfig.getConf.enableNativeBloomFilter) { + if (!glutenConf.enableNativeBloomFilter) { return plan } val out = plan.transformWithSubqueries { diff --git a/backends-velox/src/main/scala/org/apache/gluten/extension/FlushableHashAggregateRule.scala b/backends-velox/src/main/scala/org/apache/gluten/extension/FlushableHashAggregateRule.scala index 2e5390697795..c009ba5ba50a 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/extension/FlushableHashAggregateRule.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/extension/FlushableHashAggregateRule.scala @@ -31,10 +31,13 @@ import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike * To transform regular aggregation to intermediate aggregation that internally enables * optimizations such as flushing and abandoning. */ -case class FlushableHashAggregateRule(session: SparkSession) extends Rule[SparkPlan] { +case class FlushableHashAggregateRule(spark: SparkSession) extends Rule[SparkPlan] { import FlushableHashAggregateRule._ + + private val glutenConf = new GlutenConfig(spark) + override def apply(plan: SparkPlan): SparkPlan = { - if (!GlutenConfig.getConf.enableVeloxFlushablePartialAggregation) { + if (!glutenConf.enableVeloxFlushablePartialAggregation) { return plan } plan.transformUpWithPruning(_.containsPattern(EXCHANGE)) { diff --git a/backends-velox/src/main/scala/org/apache/gluten/extension/HLLRewriteRule.scala b/backends-velox/src/main/scala/org/apache/gluten/extension/HLLRewriteRule.scala index 8ceee3d573b9..552cc838641b 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/extension/HLLRewriteRule.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/extension/HLLRewriteRule.scala @@ -28,13 +28,16 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.{AGGREGATE, AGGREGATE_EXP import org.apache.spark.sql.types._ case class HLLRewriteRule(spark: SparkSession) extends Rule[LogicalPlan] { + + private val glutenConf = new GlutenConfig(spark) + override def apply(plan: LogicalPlan): LogicalPlan = { plan.transformUpWithPruning(_.containsPattern(AGGREGATE)) { case a: Aggregate => a.transformExpressionsWithPruning(_.containsPattern(AGGREGATE_EXPRESSION)) { case aggExpr @ AggregateExpression(hll: HyperLogLogPlusPlus, _, _, _, _) - if GlutenConfig.getConf.enableNativeHyperLogLogAggregateFunction && - GlutenConfig.getConf.enableColumnarHashAgg && + if glutenConf.enableNativeHyperLogLogAggregateFunction && + glutenConf.enableColumnarHashAgg && isSupportedDataType(hll.child.dataType) => val hllAdapter = HLLAdapter( hll.child, diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/ColumnarRuleApplier.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/ColumnarRuleApplier.scala index bf7b84c9b316..149b37c78c88 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/ColumnarRuleApplier.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/ColumnarRuleApplier.scala @@ -28,11 +28,9 @@ trait ColumnarRuleApplier { object ColumnarRuleApplier { class ColumnarRuleCall( - val session: SparkSession, + val spark: SparkSession, val ac: AdaptiveContext, val outputsColumnar: Boolean) { - val glutenConf: GlutenConfig = { - new GlutenConfig(session.sessionState.conf) - } + val glutenConf: GlutenConfig = new GlutenConfig(spark) } } diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/injector/GlutenInjector.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/injector/GlutenInjector.scala index 498f040c9075..b330870121ef 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/injector/GlutenInjector.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/injector/GlutenInjector.scala @@ -40,12 +40,12 @@ class GlutenInjector private[injector] (control: InjectorControl) { control.disabler().wrapColumnarRule(s => new GlutenColumnarRule(s, applier))) } - private def applier(session: SparkSession): ColumnarRuleApplier = { - val conf = new GlutenConfig(session.sessionState.conf) - if (conf.enableRas) { - return ras.createApplier(session) + private def applier(spark: SparkSession): ColumnarRuleApplier = { + val glutenConf = new GlutenConfig(spark) + if (glutenConf.enableRas) { + return ras.createApplier(spark) } - legacy.createApplier(session) + legacy.createApplier(spark) } } diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/execution/BasicScanExecTransformer.scala b/gluten-substrait/src/main/scala/org/apache/gluten/execution/BasicScanExecTransformer.scala index f272dc3eca72..73d277d4a1bd 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/execution/BasicScanExecTransformer.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/execution/BasicScanExecTransformer.scala @@ -16,7 +16,6 @@ */ package org.apache.gluten.execution -import org.apache.gluten.GlutenConfig import org.apache.gluten.backendsapi.BackendsApiManager import org.apache.gluten.expression.{ConverterUtils, ExpressionConverter} import org.apache.gluten.extension.ValidationResult @@ -52,7 +51,7 @@ trait BasicScanExecTransformer extends LeafTransformSupport with BaseDataSource val fileFormat: ReadFileFormat def getRootFilePaths: Seq[String] = { - if (GlutenConfig.getConf.scanFileSchemeValidationEnabled) { + if (glutenConf.scanFileSchemeValidationEnabled) { getRootPathsInternal } else { Seq.empty diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/execution/BroadcastNestedLoopJoinExecTransformer.scala b/gluten-substrait/src/main/scala/org/apache/gluten/execution/BroadcastNestedLoopJoinExecTransformer.scala index ae407b3b3efa..0bb80af5eecf 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/execution/BroadcastNestedLoopJoinExecTransformer.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/execution/BroadcastNestedLoopJoinExecTransformer.scala @@ -180,7 +180,7 @@ abstract class BroadcastNestedLoopJoinExecTransformer( } override protected def doValidateInternal(): ValidationResult = { - if (!GlutenConfig.getConf.broadcastNestedLoopJoinTransformerTransformerEnabled) { + if (!glutenConf.broadcastNestedLoopJoinTransformerTransformerEnabled) { return ValidationResult.failed( s"Config ${GlutenConfig.BROADCAST_NESTED_LOOP_JOIN_TRANSFORMER_ENABLED.key} not enabled") } diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/execution/WholeStageTransformer.scala b/gluten-substrait/src/main/scala/org/apache/gluten/execution/WholeStageTransformer.scala index e1dfd3f5704a..13ebb93eb972 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/execution/WholeStageTransformer.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/execution/WholeStageTransformer.scala @@ -132,7 +132,7 @@ case class WholeStageTransformer(child: SparkPlan, materializeInput: Boolean = f val sparkConf: SparkConf = sparkContext.getConf val serializableHadoopConf: SerializableConfiguration = new SerializableConfiguration( sparkContext.hadoopConfiguration) - val numaBindingInfo: GlutenNumaBindingInfo = GlutenConfig.getConf.numaBindingInfo + val numaBindingInfo: GlutenNumaBindingInfo = glutenConf.numaBindingInfo @transient private var wholeStageTransformerContext: Option[WholeStageTransformContext] = None @@ -277,7 +277,7 @@ case class WholeStageTransformer(child: SparkPlan, materializeInput: Boolean = f }( t => logOnLevel( - GlutenConfig.getConf.substraitPlanLogLevel, + glutenConf.substraitPlanLogLevel, s"$nodeName generating the substrait plan took: $t ms.")) val inputRDDs = new ColumnarInputRDDsWrapper(columnarInputRDDs) // Check if BatchScan exists. diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/execution/WindowExecTransformer.scala b/gluten-substrait/src/main/scala/org/apache/gluten/execution/WindowExecTransformer.scala index 4902b6c6cf1b..c209e634c2eb 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/execution/WindowExecTransformer.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/execution/WindowExecTransformer.scala @@ -16,7 +16,6 @@ */ package org.apache.gluten.execution -import org.apache.gluten.GlutenConfig import org.apache.gluten.backendsapi.BackendsApiManager import org.apache.gluten.expression._ import org.apache.gluten.extension.ValidationResult @@ -83,7 +82,7 @@ case class WindowExecTransformer( val windowParametersStr = new StringBuffer("WindowParameters:") // isStreaming: 1 for streaming, 0 for sort val isStreaming: Int = - if (GlutenConfig.getConf.veloxColumnarWindowType.equals("streaming")) 1 else 0 + if (glutenConf.veloxColumnarWindowType.equals("streaming")) 1 else 0 windowParametersStr .append("isStreaming=") diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/extension/GlutenPlan.scala b/gluten-substrait/src/main/scala/org/apache/gluten/extension/GlutenPlan.scala index 856d208eada2..da697106dc71 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/extension/GlutenPlan.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/extension/GlutenPlan.scala @@ -54,9 +54,8 @@ object ValidationResult { /** Every Gluten Operator should extend this trait. */ trait GlutenPlan extends SparkPlan with Convention.KnownBatchType with LogLevelUtil { - protected lazy val enableNativeValidation = glutenConf.enableNativeValidation - protected def glutenConf: GlutenConfig = GlutenConfig.getConf + protected def glutenConf: GlutenConfig = new GlutenConfig(session) /** * Validate whether this SparkPlan supports to be transformed into substrait node in Native Code. diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/FallbackRules.scala b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/FallbackRules.scala index 794186bfa957..8a2de504602d 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/FallbackRules.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/FallbackRules.scala @@ -159,30 +159,32 @@ object FallbackTags { } } -case class FallbackOnANSIMode(session: SparkSession) extends Rule[SparkPlan] { +case class FallbackOnANSIMode(spark: SparkSession) extends Rule[SparkPlan] { + + private val glutenConf = new GlutenConfig(spark) + override def apply(plan: SparkPlan): SparkPlan = { - if (GlutenConfig.getConf.enableAnsiMode) { + if (glutenConf.enableAnsiMode) { plan.foreach(FallbackTags.add(_, "does not support ansi mode")) } plan } } -case class FallbackMultiCodegens(session: SparkSession) extends Rule[SparkPlan] { - lazy val glutenConf: GlutenConfig = GlutenConfig.getConf - lazy val physicalJoinOptimize = glutenConf.enablePhysicalJoinOptimize - lazy val optimizeLevel: Integer = glutenConf.physicalJoinOptimizationThrottle +case class FallbackMultiCodegens(spark: SparkSession) extends Rule[SparkPlan] { + + private val glutenConf: GlutenConfig = new GlutenConfig(spark) def existsMultiCodegens(plan: SparkPlan, count: Int = 0): Boolean = plan match { case plan: CodegenSupport if plan.supportCodegen => - if ((count + 1) >= optimizeLevel) return true + if ((count + 1) >= glutenConf.physicalJoinOptimizationThrottle) return true plan.children.exists(existsMultiCodegens(_, count + 1)) case plan: ShuffledHashJoinExec => - if ((count + 1) >= optimizeLevel) return true + if ((count + 1) >= glutenConf.physicalJoinOptimizationThrottle) return true plan.children.exists(existsMultiCodegens(_, count + 1)) - case plan: SortMergeJoinExec if GlutenConfig.getConf.forceShuffledHashJoin => - if ((count + 1) >= optimizeLevel) return true + case plan: SortMergeJoinExec if glutenConf.forceShuffledHashJoin => + if ((count + 1) >= glutenConf.physicalJoinOptimizationThrottle) return true plan.children.exists(existsMultiCodegens(_, count + 1)) case _ => false } @@ -232,7 +234,7 @@ case class FallbackMultiCodegens(session: SparkSession) extends Rule[SparkPlan] } override def apply(plan: SparkPlan): SparkPlan = { - if (physicalJoinOptimize) { + if (glutenConf.enablePhysicalJoinOptimize) { tagOnFallbackForMultiCodegens(plan) } else plan } diff --git a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/GlutenFallbackReporter.scala b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/GlutenFallbackReporter.scala index 481e16b0a5be..058e72ee362a 100644 --- a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/GlutenFallbackReporter.scala +++ b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/GlutenFallbackReporter.scala @@ -52,12 +52,11 @@ case class GlutenFallbackReporter(glutenConf: GlutenConfig, spark: SparkSession) } private def printFallbackReason(plan: SparkPlan): Unit = { - val validationLogLevel = glutenConf.validationLogLevel plan.foreachUp { case _: GlutenPlan => // ignore case p: SparkPlan if FallbackTags.nonEmpty(p) => val tag = FallbackTags.get(p) - logFallbackReason(validationLogLevel, p.nodeName, tag.reason()) + logFallbackReason(glutenConf.validationLogLevel, p.nodeName, tag.reason()) // With in next round stage in AQE, the physical plan would be a new instance that // can not preserve the tag, so we need to set the fallback reason to logical plan. // Then we can be aware of the fallback reason for the whole plan. diff --git a/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala b/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala index fc31289119a1..b7ed2c2631a9 100644 --- a/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala +++ b/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala @@ -16,7 +16,6 @@ */ package org.apache.spark.sql.execution -import org.apache.gluten.GlutenConfig import org.apache.gluten.execution.BasicScanExecTransformer import org.apache.gluten.extension.{GlutenPlan, GlutenSessionExtensions} import org.apache.gluten.extension.columnar.{ExpandFallbackPolicy, RemoveFallbackTagRule} @@ -171,7 +170,7 @@ private object FallbackStrategiesSuite { List(c => ExpandFallbackPolicy(c.ac.isAdaptiveContext(), c.ac.originalPlan())), List( c => RemoveTopmostColumnarToRow(c.session, c.ac.isAdaptiveContext()), - _ => ColumnarCollapseTransformStages(GlutenConfig.getConf) + c => ColumnarCollapseTransformStages(c.glutenConf) ), List(_ => RemoveFallbackTagRule()) ) diff --git a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala index 3e776721311c..31cdb5c08b04 100644 --- a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala +++ b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala @@ -16,7 +16,6 @@ */ package org.apache.spark.sql.execution -import org.apache.gluten.GlutenConfig import org.apache.gluten.execution.BasicScanExecTransformer import org.apache.gluten.extension.{GlutenPlan, GlutenSessionExtensions} import org.apache.gluten.extension.columnar.{ExpandFallbackPolicy, FallbackTags, RemoveFallbackTagRule} @@ -181,7 +180,7 @@ private object FallbackStrategiesSuite { List(c => ExpandFallbackPolicy(c.ac.isAdaptiveContext(), c.ac.originalPlan())), List( c => RemoveTopmostColumnarToRow(c.session, c.ac.isAdaptiveContext()), - _ => ColumnarCollapseTransformStages(GlutenConfig.getConf) + c => ColumnarCollapseTransformStages(c.glutenConf) ), List(_ => RemoveFallbackTagRule()) ) diff --git a/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala b/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala index 3e776721311c..31cdb5c08b04 100644 --- a/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala +++ b/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala @@ -16,7 +16,6 @@ */ package org.apache.spark.sql.execution -import org.apache.gluten.GlutenConfig import org.apache.gluten.execution.BasicScanExecTransformer import org.apache.gluten.extension.{GlutenPlan, GlutenSessionExtensions} import org.apache.gluten.extension.columnar.{ExpandFallbackPolicy, FallbackTags, RemoveFallbackTagRule} @@ -181,7 +180,7 @@ private object FallbackStrategiesSuite { List(c => ExpandFallbackPolicy(c.ac.isAdaptiveContext(), c.ac.originalPlan())), List( c => RemoveTopmostColumnarToRow(c.session, c.ac.isAdaptiveContext()), - _ => ColumnarCollapseTransformStages(GlutenConfig.getConf) + c => ColumnarCollapseTransformStages(c.glutenConf) ), List(_ => RemoveFallbackTagRule()) ) diff --git a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala index a214d9755e69..32a83ac63a00 100644 --- a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala +++ b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala @@ -16,7 +16,6 @@ */ package org.apache.spark.sql.execution -import org.apache.gluten.GlutenConfig import org.apache.gluten.execution.BasicScanExecTransformer import org.apache.gluten.extension.{GlutenPlan, GlutenSessionExtensions} import org.apache.gluten.extension.columnar.{ExpandFallbackPolicy, FallbackTags, RemoveFallbackTagRule} @@ -182,7 +181,7 @@ private object FallbackStrategiesSuite { List(c => ExpandFallbackPolicy(c.ac.isAdaptiveContext(), c.ac.originalPlan())), List( c => RemoveTopmostColumnarToRow(c.session, c.ac.isAdaptiveContext()), - _ => ColumnarCollapseTransformStages(GlutenConfig.getConf) + c => ColumnarCollapseTransformStages(c.glutenConf) ), List(_ => RemoveFallbackTagRule()) ) diff --git a/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala b/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala index a28a7d26b386..7b17bf61ff16 100644 --- a/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala +++ b/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala @@ -18,6 +18,7 @@ package org.apache.gluten import org.apache.spark.internal.Logging import org.apache.spark.network.util.ByteUnit +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.internal.SQLConf import com.google.common.collect.ImmutableList @@ -34,9 +35,13 @@ case class GlutenNumaBindingInfo( totalCoreRange: Array[String] = null, numCoresPerExecutor: Int = -1) {} -class GlutenConfig(conf: SQLConf) extends Logging { +class GlutenConfig(sessionOpt: Option[SparkSession] = None) extends Logging { import GlutenConfig._ + def this(spark: SparkSession) = this(Some(spark)) + + def conf: SQLConf = sessionOpt.map(_.sessionState.conf).getOrElse(SQLConf.get) + def enableAnsiMode: Boolean = conf.ansiEnabled def enableGluten: Boolean = conf.getConf(GLUTEN_ENABLED) @@ -648,9 +653,7 @@ object GlutenConfig { var ins: GlutenConfig = _ - def getConf: GlutenConfig = { - new GlutenConfig(SQLConf.get) - } + def getConf: GlutenConfig = new GlutenConfig() @deprecated def getTempFile: String = synchronized {