From fa290b719c8b65330f07968350e3798f0bb18da9 Mon Sep 17 00:00:00 2001
From: beliefer
Date: Sun, 27 Oct 2024 18:16:03 +0800
Subject: [PATCH] Replace GlutenConfig.getConf with session-bound GlutenConfig
 instances

---
 .../backendsapi/clickhouse/CHRuleApi.scala    | 14 +++++-----
 .../CommonSubexpressionEliminateRule.scala    |  8 ++++--
 .../FallbackBroadcastHashJoinRules.scala      | 23 ++++++++-------
 .../MergeTwoPhasesHashBaseAggregate.scala     | 11 +++-----
 .../RewriteDateTimestampComparisonRule.scala  |  8 +++---
 .../RewriteToDateExpresstionRule.scala        |  8 +++---
 .../CHAggregateFunctionRewriteRule.scala      |  7 +++--
 .../backendsapi/velox/VeloxRuleApi.scala      |  8 +++---
 .../ColumnarPartialProjectExec.scala          |  4 +--
 .../execution/RowToVeloxColumnarExec.scala    |  4 +--
 ...omFilterMightContainJointRewriteRule.scala |  5 +++-
 .../FlushableHashAggregateRule.scala          |  7 +++--
 .../gluten/extension/HLLRewriteRule.scala     |  7 +++--
 .../columnar/ColumnarRuleApplier.scala        |  4 +--
 .../extension/injector/GlutenInjector.scala   |  8 +++---
 .../execution/BasicScanExecTransformer.scala  |  3 +-
 ...oadcastNestedLoopJoinExecTransformer.scala |  2 +-
 .../execution/WholeStageTransformer.scala     |  4 +--
 .../execution/WindowExecTransformer.scala     |  3 +-
 .../apache/gluten/extension/GlutenPlan.scala  |  3 +-
 .../extension/columnar/FallbackRules.scala    | 28 ++++++++++---------
 .../spark/shuffle/GlutenShuffleUtils.scala    |  7 ++---
 .../execution/GlutenFallbackReporter.scala    |  6 ++--
 .../execution/FallbackStrategiesSuite.scala   |  2 +-
 .../execution/FallbackStrategiesSuite.scala   |  2 +-
 .../execution/FallbackStrategiesSuite.scala   |  2 +-
 .../execution/FallbackStrategiesSuite.scala   |  2 +-
 .../org/apache/gluten/GlutenConfig.scala      |  2 ++
 28 files changed, 102 insertions(+), 90 deletions(-)

diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
index 4323dc955833..51e1abed8ace 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
@@ -54,11 +54,11 @@ private object CHRuleApi {
     injector.injectParser(
       (spark, parserInterface) => new GlutenClickhouseSqlParser(spark, parserInterface))
     injector.injectResolutionRule(
-      spark => new RewriteToDateExpresstionRule(spark, spark.sessionState.conf))
+      spark => new RewriteToDateExpresstionRule(spark))
     injector.injectResolutionRule(
-      spark => new RewriteDateTimestampComparisonRule(spark, spark.sessionState.conf))
+      spark => new RewriteDateTimestampComparisonRule(spark))
     injector.injectOptimizerRule(
-      spark => new CommonSubexpressionEliminateRule(spark, spark.sessionState.conf))
+      spark => new CommonSubexpressionEliminateRule(spark))
     injector.injectOptimizerRule(spark => CHAggregateFunctionRewriteRule(spark))
     injector.injectOptimizerRule(_ => CountDistinctWithoutExpand)
     injector.injectOptimizerRule(_ => EqualToRewrite)
@@ -89,7 +89,7 @@ private object CHRuleApi {
     injector.injectTransform(
       c =>
         intercept(
-          SparkPlanRules.extendedColumnarRule(c.conf.extendedColumnarTransformRules)(c.session)))
+          SparkPlanRules.extendedColumnarRule(c.glutenConf.extendedColumnarTransformRules)(c.session)))
     injector.injectTransform(c => InsertTransitions(c.outputsColumnar))
 
     // Gluten columnar: Fallback policies.
@@ -101,14 +101,14 @@ private object CHRuleApi {
     SparkShimLoader.getSparkShims
       .getExtendedColumnarPostRules()
       .foreach(each => injector.injectPost(c => intercept(each(c.session))))
-    injector.injectPost(c => ColumnarCollapseTransformStages(c.conf))
+    injector.injectPost(c => ColumnarCollapseTransformStages(c.glutenConf))
     injector.injectTransform(
       c =>
-        intercept(SparkPlanRules.extendedColumnarRule(c.conf.extendedColumnarPostRules)(c.session)))
+        intercept(SparkPlanRules.extendedColumnarRule(c.glutenConf.extendedColumnarPostRules)(c.session)))
 
     // Gluten columnar: Final rules.
     injector.injectFinal(c => RemoveGlutenTableCacheColumnarToRow(c.session))
-    injector.injectFinal(c => GlutenFallbackReporter(c.conf, c.session))
+    injector.injectFinal(c => GlutenFallbackReporter(c.glutenConf, c.session))
     injector.injectFinal(_ => RemoveFallbackTagRule())
   }
 
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CommonSubexpressionEliminateRule.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CommonSubexpressionEliminateRule.scala
index 52e278b3dace..deeb8192f57e 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CommonSubexpressionEliminateRule.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CommonSubexpressionEliminateRule.scala
@@ -33,17 +33,19 @@ import scala.collection.mutable
 // 2. append two options to spark config
 //    --conf spark.sql.planChangeLog.level=error
 //    --conf spark.sql.planChangeLog.batches=all
-class CommonSubexpressionEliminateRule(session: SparkSession, conf: SQLConf)
+class CommonSubexpressionEliminateRule(spark: SparkSession)
   extends Rule[LogicalPlan]
   with Logging {
 
+  private val glutenConf = new GlutenConfig(spark)
+
   private var lastPlan: LogicalPlan = null
 
   override def apply(plan: LogicalPlan): LogicalPlan = {
     val newPlan =
       if (
-        plan.resolved && GlutenConfig.getConf.enableGluten
-        && GlutenConfig.getConf.enableCommonSubexpressionEliminate && !plan.fastEquals(lastPlan)
+        plan.resolved && glutenConf.enableGluten
+        && glutenConf.enableCommonSubexpressionEliminate && !plan.fastEquals(lastPlan)
       ) {
         lastPlan = plan
         visitPlan(plan)
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/FallbackBroadcastHashJoinRules.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/FallbackBroadcastHashJoinRules.scala
index ec465a3c1506..91c40e17332e 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/FallbackBroadcastHashJoinRules.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/FallbackBroadcastHashJoinRules.scala
@@ -36,9 +36,11 @@ import scala.util.control.Breaks.{break, breakable}
 // see each other during transformation. In order to prevent BroadcastExec being transformed
 // to columnar while BHJ fallbacks, BroadcastExec need to be tagged not transformable when applying
 // queryStagePrepRules.
-case class FallbackBroadcastHashJoinPrepQueryStage(session: SparkSession) extends Rule[SparkPlan] {
+case class FallbackBroadcastHashJoinPrepQueryStage(spark: SparkSession) extends Rule[SparkPlan] {
+
+  private val glutenConf: GlutenConfig = new GlutenConfig(spark)
+
   override def apply(plan: SparkPlan): SparkPlan = {
-    val columnarConf: GlutenConfig = GlutenConfig.getConf
     plan.foreach {
       case bhj: BroadcastHashJoinExec =>
         val buildSidePlan = bhj.buildSide match {
@@ -53,8 +55,8 @@ case class FallbackBroadcastHashJoinPrepQueryStage(session: SparkSession) extend
           case Some(exchange @ BroadcastExchangeExec(mode, child)) =>
             val isTransformable =
               if (
-                !columnarConf.enableColumnarBroadcastExchange ||
-                !columnarConf.enableColumnarBroadcastJoin
+                !glutenConf.enableColumnarBroadcastExchange ||
+                !glutenConf.enableColumnarBroadcastJoin
               ) {
                 ValidationResult.failed(
                   "columnar broadcast exchange is disabled or " +
@@ -107,8 +109,8 @@ case class FallbackBroadcastHashJoinPrepQueryStage(session: SparkSession) extend
           case Some(exchange @ BroadcastExchangeExec(mode, child)) =>
             val isTransformable =
               if (
-                !GlutenConfig.getConf.enableColumnarBroadcastExchange ||
-                !GlutenConfig.getConf.enableColumnarBroadcastJoin
+                !glutenConf.enableColumnarBroadcastExchange ||
+                !glutenConf.enableColumnarBroadcastJoin
               ) {
                 ValidationResult.failed(
                   "columnar broadcast exchange is disabled or " +
@@ -146,13 +148,14 @@ case class FallbackBroadcastHashJoinPrepQueryStage(session: SparkSession) extend
 // columnar rules.
 case class FallbackBroadcastHashJoin(session: SparkSession) extends Rule[SparkPlan] {
 
+  private val columnarConf: GlutenConfig = new GlutenConfig(session)
+
   private val enableColumnarBroadcastJoin: Boolean =
-    GlutenConfig.getConf.enableColumnarBroadcastJoin &&
-      GlutenConfig.getConf.enableColumnarBroadcastExchange
+    columnarConf.enableColumnarBroadcastJoin && columnarConf.enableColumnarBroadcastExchange
 
   private val enableColumnarBroadcastNestedLoopJoin: Boolean =
-    GlutenConfig.getConf.broadcastNestedLoopJoinTransformerTransformerEnabled &&
-      GlutenConfig.getConf.enableColumnarBroadcastExchange
+    columnarConf.broadcastNestedLoopJoinTransformerTransformerEnabled &&
+      columnarConf.enableColumnarBroadcastExchange
 
   override def apply(plan: SparkPlan): SparkPlan = {
     plan.foreachUp {
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/MergeTwoPhasesHashBaseAggregate.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/MergeTwoPhasesHashBaseAggregate.scala
index 63c5fe017f5e..43f77bda0f1f 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/MergeTwoPhasesHashBaseAggregate.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/MergeTwoPhasesHashBaseAggregate.scala
@@ -34,14 +34,11 @@ import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregat
  * Note: this rule must be applied before the `PullOutPreProject` rule, because the
  * `PullOutPreProject` rule will modify the attributes in some cases.
  */
-case class MergeTwoPhasesHashBaseAggregate(session: SparkSession)
+case class MergeTwoPhasesHashBaseAggregate(spark: SparkSession)
   extends Rule[SparkPlan]
   with Logging {
 
-  val columnarConf: GlutenConfig = GlutenConfig.getConf
-  val scanOnly: Boolean = columnarConf.enableScanOnly
-  val enableColumnarHashAgg: Boolean = !scanOnly && columnarConf.enableColumnarHashAgg
-  val replaceSortAggWithHashAgg: Boolean = GlutenConfig.getConf.forceToUseHashAgg
+  private val glutenConf: GlutenConfig = new GlutenConfig(spark)
 
   private def isPartialAgg(partialAgg: BaseAggregateExec, finalAgg: BaseAggregateExec): Boolean = {
     // TODO: now it can not support to merge agg which there are the filters in the aggregate exprs.
@@ -59,7 +56,7 @@ case class MergeTwoPhasesHashBaseAggregate(session: SparkSession)
   }
 
   override def apply(plan: SparkPlan): SparkPlan = {
-    if (!enableColumnarHashAgg) {
+    if (!(!glutenConf.enableScanOnly && glutenConf.enableColumnarHashAgg)) {
       plan
     } else {
       plan.transformDown {
@@ -111,7 +108,7 @@ case class MergeTwoPhasesHashBaseAggregate(session: SparkSession)
               _,
               resultExpressions,
               child: SortAggregateExec)
-            if replaceSortAggWithHashAgg && !isStreaming && isPartialAgg(child, sortAgg) =>
+            if glutenConf.forceToUseHashAgg && !isStreaming && isPartialAgg(child, sortAgg) =>
           // convert to complete mode aggregate expressions
           val completeAggregateExpressions = aggregateExpressions.map(_.copy(mode = Complete))
           sortAgg.copy(
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/RewriteDateTimestampComparisonRule.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/RewriteDateTimestampComparisonRule.scala
index ea92ddec2c8a..9757c7a87db8 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/RewriteDateTimestampComparisonRule.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/RewriteDateTimestampComparisonRule.scala
@@ -37,7 +37,7 @@ import org.apache.spark.unsafe.types.UTF8String
 // This rule try to make the filter condition into integer comparison, which is more efficient.
 // The above example will be rewritten into
 // select * from table where to_unixtime('2023-11-02', 'yyyy-MM-dd') >= unix_timestamp
-class RewriteDateTimestampComparisonRule(session: SparkSession, conf: SQLConf)
+class RewriteDateTimestampComparisonRule(spark: SparkSession)
   extends Rule[LogicalPlan]
   with Logging {
 
@@ -54,11 +54,11 @@ class RewriteDateTimestampComparisonRule(session: SparkSession, conf: SQLConf)
     "yyyy"
   )
 
+  private val glutenConf = new GlutenConfig(spark)
+
   override def apply(plan: LogicalPlan): LogicalPlan = {
     if (
-      plan.resolved &&
-      GlutenConfig.getConf.enableGluten &&
-      GlutenConfig.getConf.enableRewriteDateTimestampComparison
+      plan.resolved && glutenConf.enableGluten && glutenConf.enableRewriteDateTimestampComparison
     ) {
       visitPlan(plan)
     } else {
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/RewriteToDateExpresstionRule.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/RewriteToDateExpresstionRule.scala
index 34d162d71f5f..a00bd4232d49 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/RewriteToDateExpresstionRule.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/RewriteToDateExpresstionRule.scala
@@ -37,15 +37,15 @@ import org.apache.spark.sql.types._
 // Under ch backend, the StringType can be directly converted into DateType,
 // and the functions `from_unixtime` and `unix_timestamp` can be optimized here.
 // Optimized result is `to_date(stringType)`
-class RewriteToDateExpresstionRule(session: SparkSession, conf: SQLConf)
+class RewriteToDateExpresstionRule(spark: SparkSession)
   extends Rule[LogicalPlan]
   with Logging {
 
+  private val glutenConf = new GlutenConfig(spark)
+
   override def apply(plan: LogicalPlan): LogicalPlan = {
     if (
-      plan.resolved &&
-      GlutenConfig.getConf.enableGluten &&
-      GlutenConfig.getConf.enableCHRewriteDateConversion
+      plan.resolved && glutenConf.enableGluten && glutenConf.enableCHRewriteDateConversion
     ) {
       visitPlan(plan)
     } else {
diff --git a/backends-clickhouse/src/main/scala/org/apache/spark/sql/catalyst/CHAggregateFunctionRewriteRule.scala b/backends-clickhouse/src/main/scala/org/apache/spark/sql/catalyst/CHAggregateFunctionRewriteRule.scala
index b11e1e2bb306..4ea59f659ca2 100644
--- a/backends-clickhouse/src/main/scala/org/apache/spark/sql/catalyst/CHAggregateFunctionRewriteRule.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/spark/sql/catalyst/CHAggregateFunctionRewriteRule.scala
@@ -31,12 +31,15 @@ import org.apache.spark.sql.types._
  * @param spark
  */
 case class CHAggregateFunctionRewriteRule(spark: SparkSession) extends Rule[LogicalPlan] {
+
+  private val glutenConf = new GlutenConfig(spark)
+
   override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
     case a: Aggregate =>
       a.transformExpressions {
         case avgExpr @ AggregateExpression(avg: Average, _, _, _, _)
-            if GlutenConfig.getConf.enableCastAvgAggregateFunction &&
-              GlutenConfig.getConf.enableColumnarHashAgg &&
+            if glutenConf.enableCastAvgAggregateFunction &&
+              glutenConf.enableColumnarHashAgg &&
               !avgExpr.isDistinct && isDataTypeNeedConvert(avg.child.dataType) =>
           AggregateExpression(
             avg.copy(child = Cast(avg.child, DoubleType)),
diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala
index a838c463c390..3554bc5c9c01 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala
@@ -80,11 +80,11 @@ private object VeloxRuleApi {
     SparkShimLoader.getSparkShims
       .getExtendedColumnarPostRules()
       .foreach(each => injector.injectPost(c => each(c.session)))
-    injector.injectPost(c => ColumnarCollapseTransformStages(c.conf))
+    injector.injectPost(c => ColumnarCollapseTransformStages(c.glutenConf))
 
     // Gluten columnar: Final rules.
     injector.injectFinal(c => RemoveGlutenTableCacheColumnarToRow(c.session))
-    injector.injectFinal(c => GlutenFallbackReporter(c.conf, c.session))
+    injector.injectFinal(c => GlutenFallbackReporter(c.glutenConf, c.session))
     injector.injectFinal(_ => RemoveFallbackTagRule())
   }
 
@@ -116,9 +116,9 @@ private object VeloxRuleApi {
     SparkShimLoader.getSparkShims
       .getExtendedColumnarPostRules()
       .foreach(each => injector.inject(c => each(c.session)))
-    injector.inject(c => ColumnarCollapseTransformStages(c.conf))
+    injector.inject(c => ColumnarCollapseTransformStages(c.glutenConf))
     injector.inject(c => RemoveGlutenTableCacheColumnarToRow(c.session))
-    injector.inject(c => GlutenFallbackReporter(c.conf, c.session))
+    injector.inject(c => GlutenFallbackReporter(c.glutenConf, c.session))
     injector.inject(_ => RemoveFallbackTagRule())
   }
 }
diff --git a/backends-velox/src/main/scala/org/apache/gluten/execution/ColumnarPartialProjectExec.scala b/backends-velox/src/main/scala/org/apache/gluten/execution/ColumnarPartialProjectExec.scala
index 1c394103dbc8..89e530b9490f 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/execution/ColumnarPartialProjectExec.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/execution/ColumnarPartialProjectExec.scala
@@ -131,7 +131,7 @@ case class ColumnarPartialProjectExec(original: ProjectExec, child: SparkPlan)(
   }
 
   override protected def doValidateInternal(): ValidationResult = {
-    if (!GlutenConfig.getConf.enableColumnarPartialProject) {
+    if (!glutenConf.enableColumnarPartialProject) {
       return ValidationResult.failed("Config disable this feature")
     }
     if (UDFAttrNotExists) {
@@ -159,7 +159,7 @@ case class ColumnarPartialProjectExec(original: ProjectExec, child: SparkPlan)(
     if (
       ExpressionUtils.hasComplexExpressions(
         original,
-        GlutenConfig.getConf.fallbackExpressionsThreshold)
+        glutenConf.fallbackExpressionsThreshold)
     ) {
       return ValidationResult.failed("Fallback by complex expression")
     }
diff --git a/backends-velox/src/main/scala/org/apache/gluten/execution/RowToVeloxColumnarExec.scala b/backends-velox/src/main/scala/org/apache/gluten/execution/RowToVeloxColumnarExec.scala
index a853778484b1..06f715c1e99e 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/execution/RowToVeloxColumnarExec.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/execution/RowToVeloxColumnarExec.scala
@@ -48,7 +48,7 @@ case class RowToVeloxColumnarExec(child: SparkPlan) extends RowToColumnarExecBas
     val numInputRows = longMetric("numInputRows")
     val numOutputBatches = longMetric("numOutputBatches")
     val convertTime = longMetric("convertTime")
-    val numRows = GlutenConfig.getConf.maxBatchSize
+    val numRows = glutenConf.maxBatchSize
     // This avoids calling `schema` in the RDD closure, so that we don't need to include the entire
     // plan (this) in the closure.
     val localSchema = schema
@@ -68,7 +68,7 @@ case class RowToVeloxColumnarExec(child: SparkPlan) extends RowToColumnarExecBas
     val numInputRows = longMetric("numInputRows")
     val numOutputBatches = longMetric("numOutputBatches")
    val convertTime = longMetric("convertTime")
-    val numRows = GlutenConfig.getConf.maxBatchSize
+    val numRows = glutenConf.maxBatchSize
     val mode = BroadcastUtils.getBroadcastMode(outputPartitioning)
     val relation = child.executeBroadcast()
     BroadcastUtils.sparkToVeloxUnsafe(
diff --git a/backends-velox/src/main/scala/org/apache/gluten/extension/BloomFilterMightContainJointRewriteRule.scala b/backends-velox/src/main/scala/org/apache/gluten/extension/BloomFilterMightContainJointRewriteRule.scala
index 56a3d86a9038..1a2c291b7206 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/extension/BloomFilterMightContainJointRewriteRule.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/extension/BloomFilterMightContainJointRewriteRule.scala
@@ -26,8 +26,11 @@ import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.SparkPlan
 
 case class BloomFilterMightContainJointRewriteRule(spark: SparkSession) extends Rule[SparkPlan] {
+
+  private val glutenConf = new GlutenConfig(spark)
+
   override def apply(plan: SparkPlan): SparkPlan = {
-    if (!GlutenConfig.getConf.enableNativeBloomFilter) {
+    if (!glutenConf.enableNativeBloomFilter) {
       return plan
     }
     val out = plan.transformWithSubqueries {
diff --git a/backends-velox/src/main/scala/org/apache/gluten/extension/FlushableHashAggregateRule.scala b/backends-velox/src/main/scala/org/apache/gluten/extension/FlushableHashAggregateRule.scala
index 2e5390697795..c009ba5ba50a 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/extension/FlushableHashAggregateRule.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/extension/FlushableHashAggregateRule.scala
@@ -31,10 +31,13 @@ import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike
  * To transform regular aggregation to intermediate aggregation that internally enables
  * optimizations such as flushing and abandoning.
  */
-case class FlushableHashAggregateRule(session: SparkSession) extends Rule[SparkPlan] {
+case class FlushableHashAggregateRule(spark: SparkSession) extends Rule[SparkPlan] {
   import FlushableHashAggregateRule._
+
+  private val glutenConf = new GlutenConfig(spark)
+
   override def apply(plan: SparkPlan): SparkPlan = {
-    if (!GlutenConfig.getConf.enableVeloxFlushablePartialAggregation) {
+    if (!glutenConf.enableVeloxFlushablePartialAggregation) {
       return plan
     }
     plan.transformUpWithPruning(_.containsPattern(EXCHANGE)) {
diff --git a/backends-velox/src/main/scala/org/apache/gluten/extension/HLLRewriteRule.scala b/backends-velox/src/main/scala/org/apache/gluten/extension/HLLRewriteRule.scala
index 8ceee3d573b9..552cc838641b 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/extension/HLLRewriteRule.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/extension/HLLRewriteRule.scala
@@ -28,13 +28,16 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.{AGGREGATE, AGGREGATE_EXP
 import org.apache.spark.sql.types._
 
 case class HLLRewriteRule(spark: SparkSession) extends Rule[LogicalPlan] {
+
+  private val glutenConf = new GlutenConfig(spark)
+
   override def apply(plan: LogicalPlan): LogicalPlan = {
     plan.transformUpWithPruning(_.containsPattern(AGGREGATE)) {
       case a: Aggregate =>
         a.transformExpressionsWithPruning(_.containsPattern(AGGREGATE_EXPRESSION)) {
           case aggExpr @ AggregateExpression(hll: HyperLogLogPlusPlus, _, _, _, _)
-              if GlutenConfig.getConf.enableNativeHyperLogLogAggregateFunction &&
-                GlutenConfig.getConf.enableColumnarHashAgg &&
+              if glutenConf.enableNativeHyperLogLogAggregateFunction &&
+                glutenConf.enableColumnarHashAgg &&
                 isSupportedDataType(hll.child.dataType) =>
             val hllAdapter = HLLAdapter(
               hll.child,
diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/ColumnarRuleApplier.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/ColumnarRuleApplier.scala
index 27929285bea2..2d8c0aad1d7c 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/ColumnarRuleApplier.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/ColumnarRuleApplier.scala
@@ -31,8 +31,6 @@ object ColumnarRuleApplier {
       val session: SparkSession,
       val ac: AdaptiveContext,
       val outputsColumnar: Boolean) {
-    val conf: GlutenConfig = {
-      new GlutenConfig(Some(session))
-    }
+    val glutenConf: GlutenConfig = new GlutenConfig(session)
   }
 }
diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/injector/GlutenInjector.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/injector/GlutenInjector.scala
index cbf401432457..147f7c08aa8a 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/extension/injector/GlutenInjector.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/extension/injector/GlutenInjector.scala
@@ -40,12 +40,12 @@ class GlutenInjector private[injector] (control: InjectorControl) {
       control.disabler().wrapColumnarRule(s => new GlutenColumnarRule(s, applier)))
   }
 
-  private def applier(session: SparkSession): ColumnarRuleApplier = {
-    val conf = new GlutenConfig(Some(session))
+  private def applier(spark: SparkSession): ColumnarRuleApplier = {
+    val conf = new GlutenConfig(spark)
     if (conf.enableRas) {
-      return ras.createApplier(session)
+      return ras.createApplier(spark)
     }
-    legacy.createApplier(session)
+    legacy.createApplier(spark)
   }
 }
diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/execution/BasicScanExecTransformer.scala b/gluten-substrait/src/main/scala/org/apache/gluten/execution/BasicScanExecTransformer.scala
index d7b824b397e5..3de015f4be2d 100644
--- a/gluten-substrait/src/main/scala/org/apache/gluten/execution/BasicScanExecTransformer.scala
+++ b/gluten-substrait/src/main/scala/org/apache/gluten/execution/BasicScanExecTransformer.scala
@@ -16,7 +16,6 @@
  */
 package org.apache.gluten.execution
 
-import org.apache.gluten.GlutenConfig
 import org.apache.gluten.backendsapi.BackendsApiManager
 import org.apache.gluten.expression.{ConverterUtils, ExpressionConverter}
 import org.apache.gluten.extension.ValidationResult
@@ -52,7 +51,7 @@ trait BasicScanExecTransformer extends LeafTransformSupport with BaseDataSource
   val fileFormat: ReadFileFormat
 
   def getRootFilePaths: Seq[String] = {
-    if (GlutenConfig.getConf.scanFileSchemeValidationEnabled) {
+    if (glutenConf.scanFileSchemeValidationEnabled) {
       getRootPathsInternal
     } else {
       Seq.empty
diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/execution/BroadcastNestedLoopJoinExecTransformer.scala b/gluten-substrait/src/main/scala/org/apache/gluten/execution/BroadcastNestedLoopJoinExecTransformer.scala
index ae407b3b3efa..0bb80af5eecf 100644
--- a/gluten-substrait/src/main/scala/org/apache/gluten/execution/BroadcastNestedLoopJoinExecTransformer.scala
+++ b/gluten-substrait/src/main/scala/org/apache/gluten/execution/BroadcastNestedLoopJoinExecTransformer.scala
@@ -180,7 +180,7 @@ abstract class BroadcastNestedLoopJoinExecTransformer(
   }
 
   override protected def doValidateInternal(): ValidationResult = {
-    if (!GlutenConfig.getConf.broadcastNestedLoopJoinTransformerTransformerEnabled) {
+    if (!glutenConf.broadcastNestedLoopJoinTransformerTransformerEnabled) {
       return ValidationResult.failed(
         s"Config ${GlutenConfig.BROADCAST_NESTED_LOOP_JOIN_TRANSFORMER_ENABLED.key} not enabled")
     }
diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/execution/WholeStageTransformer.scala b/gluten-substrait/src/main/scala/org/apache/gluten/execution/WholeStageTransformer.scala
index e1dfd3f5704a..13ebb93eb972 100644
--- a/gluten-substrait/src/main/scala/org/apache/gluten/execution/WholeStageTransformer.scala
+++ b/gluten-substrait/src/main/scala/org/apache/gluten/execution/WholeStageTransformer.scala
@@ -132,7 +132,7 @@ case class WholeStageTransformer(child: SparkPlan, materializeInput: Boolean = f
   val sparkConf: SparkConf = sparkContext.getConf
   val serializableHadoopConf: SerializableConfiguration = new SerializableConfiguration(
     sparkContext.hadoopConfiguration)
-  val numaBindingInfo: GlutenNumaBindingInfo = GlutenConfig.getConf.numaBindingInfo
+  val numaBindingInfo: GlutenNumaBindingInfo = glutenConf.numaBindingInfo
 
   @transient private var wholeStageTransformerContext: Option[WholeStageTransformContext] = None
 
@@ -277,7 +277,7 @@ case class WholeStageTransformer(child: SparkPlan, materializeInput: Boolean = f
     }(
       t =>
         logOnLevel(
-          GlutenConfig.getConf.substraitPlanLogLevel,
+          glutenConf.substraitPlanLogLevel,
          s"$nodeName generating the substrait plan took: $t ms."))
     val inputRDDs = new ColumnarInputRDDsWrapper(columnarInputRDDs)
     // Check if BatchScan exists.
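Every operator change above follows the same pattern: flags are read from the plan's session-bound glutenConf member instead of from the GlutenConfig.getConf singleton, which resolves against whichever session happens to be active on the calling thread. A minimal sketch of why the binding matters, using only the constructor and flag accessor this patch adds to the shim (session names are illustrative; ANSI mode is assumed to be at Spark's default of false for the first session):

    import org.apache.gluten.GlutenConfig
    import org.apache.spark.sql.SparkSession

    val s1 = SparkSession.builder().master("local[1]").getOrCreate()
    val s2 = s1.newSession()
    s2.conf.set("spark.sql.ansi.enabled", "true")

    // Each instance resolves flags against its own session's SQLConf,
    // regardless of which session is active on the current thread.
    assert(!new GlutenConfig(s1).enableAnsiMode)
    assert(new GlutenConfig(s2).enableAnsiMode)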
diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/execution/WindowExecTransformer.scala b/gluten-substrait/src/main/scala/org/apache/gluten/execution/WindowExecTransformer.scala
index 4902b6c6cf1b..c209e634c2eb 100644
--- a/gluten-substrait/src/main/scala/org/apache/gluten/execution/WindowExecTransformer.scala
+++ b/gluten-substrait/src/main/scala/org/apache/gluten/execution/WindowExecTransformer.scala
@@ -16,7 +16,6 @@
  */
 package org.apache.gluten.execution
 
-import org.apache.gluten.GlutenConfig
 import org.apache.gluten.backendsapi.BackendsApiManager
 import org.apache.gluten.expression._
 import org.apache.gluten.extension.ValidationResult
@@ -83,7 +82,7 @@ case class WindowExecTransformer(
     val windowParametersStr = new StringBuffer("WindowParameters:")
     // isStreaming: 1 for streaming, 0 for sort
     val isStreaming: Int =
-      if (GlutenConfig.getConf.veloxColumnarWindowType.equals("streaming")) 1 else 0
+      if (glutenConf.veloxColumnarWindowType.equals("streaming")) 1 else 0
 
     windowParametersStr
       .append("isStreaming=")
diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/extension/GlutenPlan.scala b/gluten-substrait/src/main/scala/org/apache/gluten/extension/GlutenPlan.scala
index 856d208eada2..f1fe00c5ac81 100644
--- a/gluten-substrait/src/main/scala/org/apache/gluten/extension/GlutenPlan.scala
+++ b/gluten-substrait/src/main/scala/org/apache/gluten/extension/GlutenPlan.scala
@@ -54,9 +54,8 @@ object ValidationResult {
 
 /** Every Gluten Operator should extend this trait. */
 trait GlutenPlan extends SparkPlan with Convention.KnownBatchType with LogLevelUtil {
-  protected lazy val enableNativeValidation = glutenConf.enableNativeValidation
 
-  protected def glutenConf: GlutenConfig = GlutenConfig.getConf
+  protected val glutenConf: GlutenConfig = new GlutenConfig(session)
 
   /**
    * Validate whether this SparkPlan supports to be transformed into substrait node in Native Code.
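With glutenConf now a concrete val on GlutenPlan bound to the session that built the plan, the scattered GlutenConfig.getConf call sites in the transformers collapse into plain member reads. A hedged sketch of the inherited-member pattern (the trait and operator below are simplified stand-ins, not the real classes; SparkPlan already exposes a session member, which is what the GlutenPlan change relies on):

    import org.apache.gluten.GlutenConfig
    import org.apache.spark.sql.SparkSession

    // Simplified stand-in for GlutenPlan.
    trait ConfigBoundPlan {
      def session: SparkSession
      protected val glutenConf: GlutenConfig = new GlutenConfig(session)
    }

    // Mirrors WindowExecTransformer's use: 1 for streaming, 0 for sort.
    case class ToyWindowNode(session: SparkSession) extends ConfigBoundPlan {
      def isStreaming: Int =
        if (glutenConf.veloxColumnarWindowType.equals("streaming")) 1 else 0
    }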
diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/FallbackRules.scala b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/FallbackRules.scala
index a5bba46dc605..3499c94f9284 100644
--- a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/FallbackRules.scala
+++ b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/FallbackRules.scala
@@ -159,30 +159,32 @@ object FallbackTags {
   }
 }
 
-case class FallbackOnANSIMode(session: SparkSession) extends Rule[SparkPlan] {
+case class FallbackOnANSIMode(spark: SparkSession) extends Rule[SparkPlan] {
+
+  private val glutenConf = new GlutenConfig(spark)
+
   override def apply(plan: SparkPlan): SparkPlan = {
-    if (GlutenConfig.getConf.enableAnsiMode) {
+    if (glutenConf.enableAnsiMode) {
       plan.foreach(FallbackTags.add(_, "does not support ansi mode"))
     }
     plan
   }
 }
 
-case class FallbackMultiCodegens(session: SparkSession) extends Rule[SparkPlan] {
-  lazy val columnarConf: GlutenConfig = GlutenConfig.getConf
-  lazy val physicalJoinOptimize = columnarConf.enablePhysicalJoinOptimize
-  lazy val optimizeLevel: Integer = columnarConf.physicalJoinOptimizationThrottle
+case class FallbackMultiCodegens(spark: SparkSession) extends Rule[SparkPlan] {
+
+  private val glutenConf: GlutenConfig = new GlutenConfig(spark)
 
   def existsMultiCodegens(plan: SparkPlan, count: Int = 0): Boolean =
     plan match {
       case plan: CodegenSupport if plan.supportCodegen =>
-        if ((count + 1) >= optimizeLevel) return true
+        if ((count + 1) >= glutenConf.physicalJoinOptimizationThrottle) return true
         plan.children.exists(existsMultiCodegens(_, count + 1))
       case plan: ShuffledHashJoinExec =>
-        if ((count + 1) >= optimizeLevel) return true
+        if ((count + 1) >= glutenConf.physicalJoinOptimizationThrottle) return true
         plan.children.exists(existsMultiCodegens(_, count + 1))
-      case plan: SortMergeJoinExec if GlutenConfig.getConf.forceShuffledHashJoin =>
-        if ((count + 1) >= optimizeLevel) return true
+      case plan: SortMergeJoinExec if glutenConf.forceShuffledHashJoin =>
+        if ((count + 1) >= glutenConf.physicalJoinOptimizationThrottle) return true
         plan.children.exists(existsMultiCodegens(_, count + 1))
       case _ => false
     }
@@ -232,7 +234,7 @@ case class FallbackMultiCodegens(session: SparkSession) extends Rule[SparkPlan]
   }
 
   override def apply(plan: SparkPlan): SparkPlan = {
-    if (physicalJoinOptimize) {
+    if (glutenConf.enablePhysicalJoinOptimize) {
       tagOnFallbackForMultiCodegens(plan)
     } else plan
   }
@@ -244,11 +246,11 @@ case class FallbackMultiCodegens(session: SparkSession) extends Rule[SparkPlan]
 // be added on the top of that plan to prevent actual conversion.
 case class AddFallbackTagRule() extends Rule[SparkPlan] {
   import AddFallbackTagRule._
-  private val glutenConf: GlutenConfig = GlutenConfig.getConf
+
   private val validator = Validators
     .builder()
     .fallbackByHint()
-    .fallbackIfScanOnlyWithFilterPushed(glutenConf.enableScanOnly)
+    .fallbackIfScanOnlyWithFilterPushed(GlutenConfig.getConf.enableScanOnly)
     .fallbackComplexExpressions()
     .fallbackByBackendSettings()
     .fallbackByUserOptions()
diff --git a/gluten-substrait/src/main/scala/org/apache/spark/shuffle/GlutenShuffleUtils.scala b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/GlutenShuffleUtils.scala
index 29443b59c5f6..f5255ac318e5 100644
--- a/gluten-substrait/src/main/scala/org/apache/spark/shuffle/GlutenShuffleUtils.scala
+++ b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/GlutenShuffleUtils.scala
@@ -52,13 +52,12 @@ object GlutenShuffleUtils {
   }
 
   def getCompressionCodec(conf: SparkConf): String = {
-    val glutenConfig = GlutenConfig.getConf
-    glutenConfig.columnarShuffleCodec match {
+    GlutenConfig.getConf.columnarShuffleCodec match {
       case Some(codec) =>
         val glutenCodecKey = GlutenConfig.COLUMNAR_SHUFFLE_CODEC.key
-        if (glutenConfig.columnarShuffleEnableQat) {
+        if (GlutenConfig.getConf.columnarShuffleEnableQat) {
           checkCodecValues(glutenCodecKey, codec, GlutenConfig.GLUTEN_QAT_SUPPORTED_CODEC)
-        } else if (glutenConfig.columnarShuffleEnableIaa) {
+        } else if (GlutenConfig.getConf.columnarShuffleEnableIaa) {
           checkCodecValues(glutenCodecKey, codec, GlutenConfig.GLUTEN_IAA_SUPPORTED_CODEC)
         } else {
           checkCodecValues(
diff --git a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/GlutenFallbackReporter.scala b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/GlutenFallbackReporter.scala
index dcacf7f319f1..058e72ee362a 100644
--- a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/GlutenFallbackReporter.scala
+++ b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/GlutenFallbackReporter.scala
@@ -31,12 +31,12 @@ import org.apache.spark.sql.execution.ui.GlutenEventUtils
  * This rule is used to collect all fallback reason.
 * 1. print fallback reason for each plan node 2. post all fallback reason using one event
  */
-case class GlutenFallbackReporter(glutenConfig: GlutenConfig, spark: SparkSession)
+case class GlutenFallbackReporter(glutenConf: GlutenConfig, spark: SparkSession)
   extends Rule[SparkPlan]
   with LogLevelUtil {
 
   override def apply(plan: SparkPlan): SparkPlan = {
-    if (!glutenConfig.enableFallbackReport) {
+    if (!glutenConf.enableFallbackReport) {
       return plan
     }
     printFallbackReason(plan)
@@ -56,7 +56,7 @@ case class GlutenFallbackReporter(glutenConfig: GlutenConfig, spark: SparkSessio
       case _: GlutenPlan => // ignore
       case p: SparkPlan if FallbackTags.nonEmpty(p) =>
         val tag = FallbackTags.get(p)
-        logFallbackReason(glutenConfig.validationLogLevel, p.nodeName, tag.reason())
+        logFallbackReason(glutenConf.validationLogLevel, p.nodeName, tag.reason())
         // With in next round stage in AQE, the physical plan would be a new instance that
         // can not preserve the tag, so we need to set the fallback reason to logical plan.
         // Then we can be aware of the fallback reason for the whole plan.
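GlutenFallbackReporter now receives its config from the rule call context rather than resolving it internally, so the session binding is decided once at the injection site. A condensed sketch of that wiring, combining the ColumnarRuleCall member from gluten-core with the injection sites shown earlier (the injector value is assumed to be in scope as in CHRuleApi/VeloxRuleApi):

    // The call context exposes a session-bound config:
    //   class ColumnarRuleCall(val session: SparkSession, ...) {
    //     val glutenConf: GlutenConfig = new GlutenConfig(session)
    //   }
    // Final-phase rules then consume it directly:
    injector.injectFinal(c => GlutenFallbackReporter(c.glutenConf, c.session))
    injector.injectFinal(_ => RemoveFallbackTagRule())

Note that GlutenShuffleUtils and AddFallbackTagRule above keep the thread-local GlutenConfig.getConf lookup (neither has a SparkSession at hand), so both styles coexist after this patch.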
diff --git a/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala b/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala
index aefedfb276d4..b7ed2c2631a9 100644
--- a/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala
+++ b/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala
@@ -170,7 +170,7 @@ private object FallbackStrategiesSuite {
       List(c => ExpandFallbackPolicy(c.ac.isAdaptiveContext(), c.ac.originalPlan())),
       List(
         c => RemoveTopmostColumnarToRow(c.session, c.ac.isAdaptiveContext()),
-        c => ColumnarCollapseTransformStages(c.conf)
+        c => ColumnarCollapseTransformStages(c.glutenConf)
       ),
       List(_ => RemoveFallbackTagRule())
     )
diff --git a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala
index de210b32cf1d..31cdb5c08b04 100644
--- a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala
+++ b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala
@@ -180,7 +180,7 @@ private object FallbackStrategiesSuite {
       List(c => ExpandFallbackPolicy(c.ac.isAdaptiveContext(), c.ac.originalPlan())),
       List(
         c => RemoveTopmostColumnarToRow(c.session, c.ac.isAdaptiveContext()),
-        c => ColumnarCollapseTransformStages(c.conf)
+        c => ColumnarCollapseTransformStages(c.glutenConf)
       ),
       List(_ => RemoveFallbackTagRule())
     )
diff --git a/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala b/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala
index de210b32cf1d..31cdb5c08b04 100644
--- a/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala
+++ b/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala
@@ -180,7 +180,7 @@ private object FallbackStrategiesSuite {
       List(c => ExpandFallbackPolicy(c.ac.isAdaptiveContext(), c.ac.originalPlan())),
       List(
         c => RemoveTopmostColumnarToRow(c.session, c.ac.isAdaptiveContext()),
-        c => ColumnarCollapseTransformStages(c.conf)
+        c => ColumnarCollapseTransformStages(c.glutenConf)
      ),
       List(_ => RemoveFallbackTagRule())
     )
diff --git a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala
index fe30339a0eb9..32a83ac63a00 100644
--- a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala
+++ b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala
@@ -181,7 +181,7 @@ private object FallbackStrategiesSuite {
       List(c => ExpandFallbackPolicy(c.ac.isAdaptiveContext(), c.ac.originalPlan())),
       List(
         c => RemoveTopmostColumnarToRow(c.session, c.ac.isAdaptiveContext()),
-        c => ColumnarCollapseTransformStages(c.conf)
+        c => ColumnarCollapseTransformStages(c.glutenConf)
       ),
       List(_ => RemoveFallbackTagRule())
     )
diff --git a/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala b/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala
index b3fe40c893c7..26fa65ecd42a 100644
--- a/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala
+++ b/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala
@@ -38,6 +38,8 @@ case class GlutenNumaBindingInfo(
 class GlutenConfig(sessionOpt: Option[SparkSession] = None) extends Logging {
   import GlutenConfig._
 
+  def this(session: SparkSession) = this(Some(session))
+
   def conf: SQLConf = sessionOpt.map(_.sessionState.conf).getOrElse(SQLConf.get)
 
   def enableAnsiMode: Boolean = conf.ansiEnabled
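The shim change is the keystone of the patch: GlutenConfig gains a direct SparkSession constructor, and flag resolution prefers the bound session's SQLConf, falling back to the thread-local SQLConf.get only when no session was supplied. A stripped-down model of the resolution order (the real class carries many more flag accessors):

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.internal.SQLConf

    class GlutenConfigModel(sessionOpt: Option[SparkSession] = None) {
      def this(session: SparkSession) = this(Some(session))

      // Bound session first; otherwise the active thread's SQLConf.
      def conf: SQLConf = sessionOpt.map(_.sessionState.conf).getOrElse(SQLConf.get)

      def enableAnsiMode: Boolean = conf.ansiEnabled
    }

The no-arg default (sessionOpt = None) preserves the old thread-local behavior, so the call sites this patch leaves on GlutenConfig.getConf continue to work unchanged.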