diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala index 9884a0c6ef39..41ffbdb58354 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala @@ -53,6 +53,7 @@ class CHBackend extends Backend { override def validatorApi(): ValidatorApi = new CHValidatorApi override def metricsApi(): MetricsApi = new CHMetricsApi override def listenerApi(): ListenerApi = new CHListenerApi + override def ruleApi(): RuleApi = new CHRuleApi override def settings(): BackendSettingsApi = CHBackendSettings } diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala new file mode 100644 index 000000000000..253285f1bbaa --- /dev/null +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.gluten.backendsapi.clickhouse + +import org.apache.gluten.GlutenConfig +import org.apache.gluten.backendsapi.RuleApi +import org.apache.gluten.extension._ +import org.apache.gluten.extension.columnar._ +import org.apache.gluten.extension.columnar.MiscColumnarRules.{RemoveGlutenTableCacheColumnarToRow, RemoveTopmostColumnarToRow, RewriteSubqueryBroadcast, TransformPreOverrides} +import org.apache.gluten.extension.columnar.rewrite.RewriteSparkPlanRulesManager +import org.apache.gluten.extension.columnar.transition.{InsertTransitions, RemoveTransitions} +import org.apache.gluten.parser.GlutenClickhouseSqlParser +import org.apache.gluten.sql.shims.SparkShimLoader + +import org.apache.spark.sql.catalyst.{CHAggregateFunctionRewriteRule, EqualToRewrite} +import org.apache.spark.sql.execution.{ColumnarCollapseTransformStages, GlutenFallbackReporter} +import org.apache.spark.util.SparkPlanRules + +class CHRuleApi extends RuleApi { + import CHRuleApi._ + override def injectRules(injector: RuleInjector): Unit = { + injectSpark(injector.spark) + injectGluten(injector.gluten) + injectRas(injector.ras) + } +} + +private object CHRuleApi { + def injectSpark(injector: RuleInjector.SparkInjector): Unit = { + // Regular Spark rules. + injector.injectQueryStagePrepRule(FallbackBroadcastHashJoinPrepQueryStage.apply) + injector.injectParser( + (spark, parserInterface) => new GlutenClickhouseSqlParser(spark, parserInterface)) + injector.injectResolutionRule( + spark => new RewriteToDateExpresstionRule(spark, spark.sessionState.conf)) + injector.injectResolutionRule( + spark => new RewriteDateTimestampComparisonRule(spark, spark.sessionState.conf)) + injector.injectOptimizerRule( + spark => new CommonSubexpressionEliminateRule(spark, spark.sessionState.conf)) + injector.injectOptimizerRule(spark => CHAggregateFunctionRewriteRule(spark)) + injector.injectOptimizerRule(_ => CountDistinctWithoutExpand) + injector.injectOptimizerRule(_ => EqualToRewrite) + + } + + def injectGluten(injector: RuleInjector.GlutenInjector): Unit = { + // Gluten columnar: Transform rules. + injector.injectTransform(_ => RemoveTransitions) + injector.injectTransform(c => FallbackOnANSIMode.apply(c.session)) + injector.injectTransform(c => FallbackMultiCodegens.apply(c.session)) + injector.injectTransform(c => PlanOneRowRelation.apply(c.session)) + injector.injectTransform(_ => RewriteSubqueryBroadcast()) + injector.injectTransform(c => FallbackBroadcastHashJoin.apply(c.session)) + injector.injectTransform(_ => FallbackEmptySchemaRelation()) + injector.injectTransform(c => MergeTwoPhasesHashBaseAggregate.apply(c.session)) + injector.injectTransform(_ => RewriteSparkPlanRulesManager()) + injector.injectTransform(_ => AddFallbackTagRule()) + injector.injectTransform(_ => TransformPreOverrides()) + injector.injectTransform(_ => RemoveNativeWriteFilesSortAndProject()) + injector.injectTransform(c => RewriteTransformer.apply(c.session)) + injector.injectTransform(_ => EnsureLocalSortRequirements) + injector.injectTransform(_ => EliminateLocalSort) + injector.injectTransform(_ => CollapseProjectExecTransformer) + injector.injectTransform(c => RewriteSortMergeJoinToHashJoinRule.apply(c.session)) + SparkPlanRules + .extendedColumnarRules(GlutenConfig.getConf.extendedColumnarTransformRules) + .foreach(each => injector.injectTransform(c => each(c.session))) + injector.injectTransform(c => InsertTransitions(c.outputsColumnar)) + + // Gluten columnar: Fallback policies. + injector.injectFallbackPolicy( + c => ExpandFallbackPolicy(c.ac.isAdaptiveContext(), c.ac.originalPlan())) + + // Gluten columnar: Post rules. + injector.injectPost(c => RemoveTopmostColumnarToRow(c.session, c.ac.isAdaptiveContext())) + SparkShimLoader.getSparkShims + .getExtendedColumnarPostRules() + .foreach(each => injector.injectPost(c => each(c.session))) + injector.injectPost(_ => ColumnarCollapseTransformStages(GlutenConfig.getConf)) + SparkPlanRules + .extendedColumnarRules(GlutenConfig.getConf.extendedColumnarPostRules) + .foreach(each => injector.injectTransform(c => each(c.session))) + + // Gluten columnar: Final rules. + injector.injectFinal(c => RemoveGlutenTableCacheColumnarToRow(c.session)) + injector.injectFinal(c => GlutenFallbackReporter(GlutenConfig.getConf, c.session)) + injector.injectFinal(_ => RemoveFallbackTagRule()) + } + + def injectRas(injector: RuleInjector.RasInjector): Unit = { + // CH backend doesn't work with RAS at the moment. Inject a rule that aborts any + // execution calls. + injector.inject( + _ => + new SparkPlanRules.AbortRule( + "Clickhouse backend doesn't yet have RAS support, please try disabling RAS and" + + " rerun the application")) + } +} diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala index 8fdc2645a5fb..02b4777e7120 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala @@ -21,11 +21,9 @@ import org.apache.gluten.backendsapi.{BackendsApiManager, SparkPlanExecApi} import org.apache.gluten.exception.GlutenNotSupportException import org.apache.gluten.execution._ import org.apache.gluten.expression._ -import org.apache.gluten.extension.{CommonSubexpressionEliminateRule, CountDistinctWithoutExpand, FallbackBroadcastHashJoin, FallbackBroadcastHashJoinPrepQueryStage, RewriteDateTimestampComparisonRule, RewriteSortMergeJoinToHashJoinRule, RewriteToDateExpresstionRule} import org.apache.gluten.extension.columnar.AddFallbackTagRule import org.apache.gluten.extension.columnar.MiscColumnarRules.TransformPreOverrides import org.apache.gluten.extension.columnar.transition.Convention -import org.apache.gluten.parser.GlutenClickhouseSqlParser import org.apache.gluten.sql.shims.SparkShimLoader import org.apache.gluten.substrait.expression.{ExpressionBuilder, ExpressionNode, WindowFunctionNode} import org.apache.gluten.utils.{CHJoinValidateUtil, UnknownJoinStrategy} @@ -36,18 +34,13 @@ import org.apache.spark.rdd.RDD import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{GenShuffleWriterParameters, GlutenShuffleWriterWrapper, HashPartitioningWrapper} import org.apache.spark.shuffle.utils.CHShuffleUtil -import org.apache.spark.sql.{SparkSession, Strategy} -import org.apache.spark.sql.catalyst.{CHAggregateFunctionRewriteRule, EqualToRewrite} import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, CollectList, CollectSet} import org.apache.spark.sql.catalyst.optimizer.BuildSide -import org.apache.spark.sql.catalyst.parser.ParserInterface import org.apache.spark.sql.catalyst.plans.JoinType -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, HashPartitioning, Partitioning, RangePartitioning} -import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.delta.files.TahoeFileIndex import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.adaptive.AQEShuffleReadExec @@ -549,82 +542,6 @@ class CHSparkPlanExecApi extends SparkPlanExecApi { ClickHouseBuildSideRelation(mode, newOutput, batches.flatten, rowCount, newBuildKeys) } - /** - * Generate extended DataSourceV2 Strategies. Currently only for ClickHouse backend. - * - * @return - */ - override def genExtendedDataSourceV2Strategies(): List[SparkSession => Strategy] = { - List.empty - } - - /** - * Generate extended query stage preparation rules. - * - * @return - */ - override def genExtendedQueryStagePrepRules(): List[SparkSession => Rule[SparkPlan]] = { - List(spark => FallbackBroadcastHashJoinPrepQueryStage(spark)) - } - - /** - * Generate extended Analyzers. Currently only for ClickHouse backend. - * - * @return - */ - override def genExtendedAnalyzers(): List[SparkSession => Rule[LogicalPlan]] = { - List( - spark => new RewriteToDateExpresstionRule(spark, spark.sessionState.conf), - spark => new RewriteDateTimestampComparisonRule(spark, spark.sessionState.conf)) - } - - /** - * Generate extended Optimizers. - * - * @return - */ - override def genExtendedOptimizers(): List[SparkSession => Rule[LogicalPlan]] = { - List( - spark => new CommonSubexpressionEliminateRule(spark, spark.sessionState.conf), - spark => CHAggregateFunctionRewriteRule(spark), - _ => CountDistinctWithoutExpand, - _ => EqualToRewrite - ) - } - - /** - * Generate extended columnar pre-rules, in the validation phase. - * - * @return - */ - override def genExtendedColumnarValidationRules(): List[SparkSession => Rule[SparkPlan]] = - List(spark => FallbackBroadcastHashJoin(spark)) - - /** - * Generate extended columnar pre-rules. - * - * @return - */ - override def genExtendedColumnarTransformRules(): List[SparkSession => Rule[SparkPlan]] = - List(spark => RewriteSortMergeJoinToHashJoinRule(spark)) - - override def genInjectPostHocResolutionRules(): List[SparkSession => Rule[LogicalPlan]] = { - List() - } - - /** - * Generate extended Strategies. - * - * @return - */ - override def genExtendedStrategies(): List[SparkSession => Strategy] = - List() - - override def genInjectExtendedParser() - : List[(SparkSession, ParserInterface) => ParserInterface] = { - List((spark, parserInterface) => new GlutenClickhouseSqlParser(spark, parserInterface)) - } - /** Define backend specfic expression mappings. */ override def extraExpressionMappings: Seq[Sig] = { List( diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala index d32911f4a4c7..21175f20eb64 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala @@ -55,6 +55,7 @@ class VeloxBackend extends Backend { override def validatorApi(): ValidatorApi = new VeloxValidatorApi override def metricsApi(): MetricsApi = new VeloxMetricsApi override def listenerApi(): ListenerApi = new VeloxListenerApi + override def ruleApi(): RuleApi = new VeloxRuleApi override def settings(): BackendSettingsApi = VeloxBackendSettings } diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala new file mode 100644 index 000000000000..34180eceaf3a --- /dev/null +++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.gluten.backendsapi.velox + +import org.apache.gluten.GlutenConfig +import org.apache.gluten.backendsapi.RuleApi +import org.apache.gluten.datasource.ArrowConvertorRule +import org.apache.gluten.extension.{ArrowScanReplaceRule, BloomFilterMightContainJointRewriteRule, CollectRewriteRule, FlushableHashAggregateRule, HLLRewriteRule, RuleInjector} +import org.apache.gluten.extension.columnar.{AddFallbackTagRule, CollapseProjectExecTransformer, EliminateLocalSort, EnsureLocalSortRequirements, ExpandFallbackPolicy, FallbackEmptySchemaRelation, FallbackMultiCodegens, FallbackOnANSIMode, MergeTwoPhasesHashBaseAggregate, PlanOneRowRelation, RemoveFallbackTagRule, RemoveNativeWriteFilesSortAndProject, RewriteTransformer} +import org.apache.gluten.extension.columnar.MiscColumnarRules.{RemoveGlutenTableCacheColumnarToRow, RemoveTopmostColumnarToRow, RewriteSubqueryBroadcast, TransformPreOverrides} +import org.apache.gluten.extension.columnar.enumerated.EnumeratedTransform +import org.apache.gluten.extension.columnar.rewrite.RewriteSparkPlanRulesManager +import org.apache.gluten.extension.columnar.transition.{InsertTransitions, RemoveTransitions} +import org.apache.gluten.sql.shims.SparkShimLoader + +import org.apache.spark.sql.execution.{ColumnarCollapseTransformStages, GlutenFallbackReporter} +import org.apache.spark.sql.expression.UDFResolver +import org.apache.spark.util.SparkPlanRules + +class VeloxRuleApi extends RuleApi { + import VeloxRuleApi._ + + override def injectRules(injector: RuleInjector): Unit = { + injectSpark(injector.spark) + injectGluten(injector.gluten) + injectRas(injector.ras) + } +} + +private object VeloxRuleApi { + def injectSpark(injector: RuleInjector.SparkInjector): Unit = { + // Regular Spark rules. + injector.injectOptimizerRule(CollectRewriteRule.apply) + injector.injectOptimizerRule(HLLRewriteRule.apply) + UDFResolver.getFunctionSignatures.foreach(injector.injectFunction) + injector.injectPostHocResolutionRule(ArrowConvertorRule.apply) + } + + def injectGluten(injector: RuleInjector.GlutenInjector): Unit = { + // Gluten columnar: Transform rules. + injector.injectTransform(_ => RemoveTransitions) + injector.injectTransform(c => FallbackOnANSIMode.apply(c.session)) + injector.injectTransform(c => FallbackMultiCodegens.apply(c.session)) + injector.injectTransform(c => PlanOneRowRelation.apply(c.session)) + injector.injectTransform(_ => RewriteSubqueryBroadcast()) + injector.injectTransform(c => BloomFilterMightContainJointRewriteRule.apply(c.session)) + injector.injectTransform(c => ArrowScanReplaceRule.apply(c.session)) + injector.injectTransform(_ => FallbackEmptySchemaRelation()) + injector.injectTransform(c => MergeTwoPhasesHashBaseAggregate.apply(c.session)) + injector.injectTransform(_ => RewriteSparkPlanRulesManager()) + injector.injectTransform(_ => AddFallbackTagRule()) + injector.injectTransform(_ => TransformPreOverrides()) + injector.injectTransform(_ => RemoveNativeWriteFilesSortAndProject()) + injector.injectTransform(c => RewriteTransformer.apply(c.session)) + injector.injectTransform(_ => EnsureLocalSortRequirements) + injector.injectTransform(_ => EliminateLocalSort) + injector.injectTransform(_ => CollapseProjectExecTransformer) + if (GlutenConfig.getConf.enableVeloxFlushablePartialAggregation) { + injector.injectTransform(c => FlushableHashAggregateRule.apply(c.session)) + } + SparkPlanRules + .extendedColumnarRules(GlutenConfig.getConf.extendedColumnarTransformRules) + .foreach(each => injector.injectTransform(c => each(c.session))) + injector.injectTransform(c => InsertTransitions(c.outputsColumnar)) + + // Gluten columnar: Fallback policies. + injector.injectFallbackPolicy( + c => ExpandFallbackPolicy(c.ac.isAdaptiveContext(), c.ac.originalPlan())) + + // Gluten columnar: Post rules. + injector.injectPost(c => RemoveTopmostColumnarToRow(c.session, c.ac.isAdaptiveContext())) + SparkShimLoader.getSparkShims + .getExtendedColumnarPostRules() + .foreach(each => injector.injectPost(c => each(c.session))) + injector.injectPost(_ => ColumnarCollapseTransformStages(GlutenConfig.getConf)) + SparkPlanRules + .extendedColumnarRules(GlutenConfig.getConf.extendedColumnarPostRules) + .foreach(each => injector.injectTransform(c => each(c.session))) + + // Gluten columnar: Final rules. + injector.injectFinal(c => RemoveGlutenTableCacheColumnarToRow(c.session)) + injector.injectFinal(c => GlutenFallbackReporter(GlutenConfig.getConf, c.session)) + injector.injectFinal(_ => RemoveFallbackTagRule()) + } + + def injectRas(injector: RuleInjector.RasInjector): Unit = { + // Gluten RAS: Pre rules. + injector.inject(_ => RemoveTransitions) + injector.inject(c => FallbackOnANSIMode.apply(c.session)) + injector.inject(c => PlanOneRowRelation.apply(c.session)) + injector.inject(_ => FallbackEmptySchemaRelation()) + injector.inject(_ => RewriteSubqueryBroadcast()) + injector.inject(c => BloomFilterMightContainJointRewriteRule.apply(c.session)) + injector.inject(c => ArrowScanReplaceRule.apply(c.session)) + injector.inject(c => MergeTwoPhasesHashBaseAggregate.apply(c.session)) + + // Gluten RAS: The RAS rule. + injector.inject(c => EnumeratedTransform(c.session, c.outputsColumnar)) + + // Gluten RAS: Post rules. + injector.inject(_ => RemoveTransitions) + injector.inject(_ => RemoveNativeWriteFilesSortAndProject()) + injector.inject(c => RewriteTransformer.apply(c.session)) + injector.inject(_ => EnsureLocalSortRequirements) + injector.inject(_ => EliminateLocalSort) + injector.inject(_ => CollapseProjectExecTransformer) + if (GlutenConfig.getConf.enableVeloxFlushablePartialAggregation) { + injector.inject(c => FlushableHashAggregateRule.apply(c.session)) + } + SparkPlanRules + .extendedColumnarRules(GlutenConfig.getConf.extendedColumnarTransformRules) + .foreach(each => injector.inject(c => each(c.session))) + injector.inject(c => InsertTransitions(c.outputsColumnar)) + injector.inject(c => RemoveTopmostColumnarToRow(c.session, c.ac.isAdaptiveContext())) + SparkShimLoader.getSparkShims + .getExtendedColumnarPostRules() + .foreach(each => injector.inject(c => each(c.session))) + injector.inject(_ => ColumnarCollapseTransformStages(GlutenConfig.getConf)) + SparkPlanRules + .extendedColumnarRules(GlutenConfig.getConf.extendedColumnarPostRules) + .foreach(each => injector.inject(c => each(c.session))) + injector.inject(c => RemoveGlutenTableCacheColumnarToRow(c.session)) + injector.inject(c => GlutenFallbackReporter(GlutenConfig.getConf, c.session)) + injector.inject(_ => RemoveFallbackTagRule()) + } +} diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala index fd0fc62dcbb6..bd390004feda 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala @@ -18,12 +18,10 @@ package org.apache.gluten.backendsapi.velox import org.apache.gluten.GlutenConfig import org.apache.gluten.backendsapi.SparkPlanExecApi -import org.apache.gluten.datasource.ArrowConvertorRule import org.apache.gluten.exception.GlutenNotSupportException import org.apache.gluten.execution._ import org.apache.gluten.expression._ import org.apache.gluten.expression.aggregate.{HLLAdapter, VeloxBloomFilterAggregate, VeloxCollectList, VeloxCollectSet} -import org.apache.gluten.extension._ import org.apache.gluten.extension.columnar.FallbackTags import org.apache.gluten.extension.columnar.transition.Convention import org.apache.gluten.extension.columnar.transition.ConventionFunc.BatchOverride @@ -36,18 +34,13 @@ import org.apache.spark.rdd.RDD import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{GenShuffleWriterParameters, GlutenShuffleWriterWrapper} import org.apache.spark.shuffle.utils.ShuffleUtil -import org.apache.spark.sql.{SparkSession, Strategy} -import org.apache.spark.sql.catalyst.FunctionIdentifier -import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.optimizer.BuildSide import org.apache.spark.sql.catalyst.plans.JoinType -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec import org.apache.spark.sql.execution.datasources.FileFormat @@ -56,7 +49,7 @@ import org.apache.spark.sql.execution.joins.{BuildSideRelation, HashedRelationBr import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.execution.python.ArrowEvalPythonExec import org.apache.spark.sql.execution.utils.ExecUtil -import org.apache.spark.sql.expression.{UDFExpression, UDFResolver, UserDefinedAggregateFunction} +import org.apache.spark.sql.expression.{UDFExpression, UserDefinedAggregateFunction} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.sql.vectorized.ColumnarBatch @@ -65,8 +58,6 @@ import org.apache.commons.lang3.ClassUtils import javax.ws.rs.core.UriBuilder -import scala.collection.mutable.ListBuffer - class VeloxSparkPlanExecApi extends SparkPlanExecApi { /** The columnar-batch type this backend is using. */ @@ -760,74 +751,6 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi { } } - /** - * * Rules and strategies. - */ - - /** - * Generate extended DataSourceV2 Strategy. - * - * @return - */ - override def genExtendedDataSourceV2Strategies(): List[SparkSession => Strategy] = List() - - /** - * Generate extended query stage preparation rules. - * - * @return - */ - override def genExtendedQueryStagePrepRules(): List[SparkSession => Rule[SparkPlan]] = List() - - /** - * Generate extended Analyzer. - * - * @return - */ - override def genExtendedAnalyzers(): List[SparkSession => Rule[LogicalPlan]] = List() - - /** - * Generate extended Optimizer. Currently only for Velox backend. - * - * @return - */ - override def genExtendedOptimizers(): List[SparkSession => Rule[LogicalPlan]] = - List(CollectRewriteRule.apply, HLLRewriteRule.apply) - - /** - * Generate extended columnar pre-rules, in the validation phase. - * - * @return - */ - override def genExtendedColumnarValidationRules(): List[SparkSession => Rule[SparkPlan]] = { - List(BloomFilterMightContainJointRewriteRule.apply, ArrowScanReplaceRule.apply) - } - - /** - * Generate extended columnar pre-rules. - * - * @return - */ - override def genExtendedColumnarTransformRules(): List[SparkSession => Rule[SparkPlan]] = { - val buf: ListBuffer[SparkSession => Rule[SparkPlan]] = ListBuffer() - if (GlutenConfig.getConf.enableVeloxFlushablePartialAggregation) { - buf += FlushableHashAggregateRule.apply - } - buf.result - } - - override def genInjectPostHocResolutionRules(): List[SparkSession => Rule[LogicalPlan]] = { - List(ArrowConvertorRule) - } - - /** - * Generate extended Strategy. - * - * @return - */ - override def genExtendedStrategies(): List[SparkSession => Strategy] = { - List() - } - /** Define backend specfic expression mappings. */ override def extraExpressionMappings: Seq[Sig] = { Seq( @@ -844,11 +767,6 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi { ) } - override def genInjectedFunctions() - : Seq[(FunctionIdentifier, ExpressionInfo, FunctionBuilder)] = { - UDFResolver.getFunctionSignatures - } - override def rewriteSpillPath(path: String): String = { val fs = GlutenConfig.getConf.veloxSpillFileSystem fs match { diff --git a/gluten-core/src/main/scala/org/apache/gluten/GlutenPlugin.scala b/gluten-core/src/main/scala/org/apache/gluten/GlutenPlugin.scala index dbf927909187..6e3484dfa969 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/GlutenPlugin.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/GlutenPlugin.scala @@ -17,12 +17,11 @@ package org.apache.gluten import org.apache.gluten.GlutenConfig.GLUTEN_DEFAULT_SESSION_TIMEZONE_KEY -import org.apache.gluten.GlutenPlugin.{GLUTEN_SESSION_EXTENSION_NAME, SPARK_SESSION_EXTS_KEY} import org.apache.gluten.backendsapi.BackendsApiManager import org.apache.gluten.events.GlutenBuildInfoEvent import org.apache.gluten.exception.GlutenException import org.apache.gluten.expression.ExpressionMappings -import org.apache.gluten.extension.{ColumnarOverrides, OthersExtensionOverrides, QueryStagePrepOverrides} +import org.apache.gluten.extension.GlutenSessionExtensions.{GLUTEN_SESSION_EXTENSION_NAME, SPARK_SESSION_EXTS_KEY} import org.apache.gluten.test.TestStats import org.apache.gluten.utils.TaskListener @@ -31,14 +30,13 @@ import org.apache.spark.api.plugin.{DriverPlugin, ExecutorPlugin, PluginContext, import org.apache.spark.internal.Logging import org.apache.spark.listener.GlutenListenerFactory import org.apache.spark.network.util.JavaUtils -import org.apache.spark.sql.SparkSessionExtensions import org.apache.spark.sql.execution.ui.GlutenEventUtils -import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.utils.ExpressionUtil import org.apache.spark.util.{SparkResourceUtil, TaskResources} import java.util -import java.util.{Collections, Objects} +import java.util.Collections import scala.collection.mutable @@ -298,25 +296,4 @@ private[gluten] class GlutenExecutorPlugin extends ExecutorPlugin { } } -private[gluten] class GlutenSessionExtensions extends (SparkSessionExtensions => Unit) { - override def apply(exts: SparkSessionExtensions): Unit = { - GlutenPlugin.DEFAULT_INJECTORS.foreach(injector => injector.inject(exts)) - } -} - -private[gluten] trait GlutenSparkExtensionsInjector { - def inject(extensions: SparkSessionExtensions): Unit -} - -private[gluten] object GlutenPlugin { - val SPARK_SESSION_EXTS_KEY: String = StaticSQLConf.SPARK_SESSION_EXTENSIONS.key - val GLUTEN_SESSION_EXTENSION_NAME: String = - Objects.requireNonNull(classOf[GlutenSessionExtensions].getCanonicalName) - - /** Specify all injectors that Gluten is using in following list. */ - val DEFAULT_INJECTORS: List[GlutenSparkExtensionsInjector] = List( - QueryStagePrepOverrides, - ColumnarOverrides, - OthersExtensionOverrides - ) -} +private object GlutenPlugin {} diff --git a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/Backend.scala b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/Backend.scala index 2c465ac61993..3a597552207b 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/Backend.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/Backend.scala @@ -33,6 +33,8 @@ trait Backend { def listenerApi(): ListenerApi + def ruleApi(): RuleApi + def settings(): BackendSettingsApi } diff --git a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendsApiManager.scala b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendsApiManager.scala index f2c93d8c70fc..16aa9161eba0 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendsApiManager.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendsApiManager.scala @@ -83,6 +83,10 @@ object BackendsApiManager { backend.metricsApi() } + def getRuleApiInstance: RuleApi = { + backend.ruleApi() + } + def getSettings: BackendSettingsApi = { backend.settings } diff --git a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/RuleApi.scala b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/RuleApi.scala new file mode 100644 index 000000000000..951317d6580e --- /dev/null +++ b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/RuleApi.scala @@ -0,0 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.gluten.backendsapi + +import org.apache.gluten.extension.RuleInjector + +trait RuleApi { + def injectRules(injector: RuleInjector): Unit +} diff --git a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala index 3b9e87a2055a..0227ed5da127 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala @@ -27,20 +27,14 @@ import org.apache.spark.ShuffleDependency import org.apache.spark.rdd.RDD import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{GenShuffleWriterParameters, GlutenShuffleWriterWrapper} -import org.apache.spark.sql.{SparkSession, Strategy} -import org.apache.spark.sql.catalyst.FunctionIdentifier -import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.optimizer.BuildSide -import org.apache.spark.sql.catalyst.parser.ParserInterface import org.apache.spark.sql.catalyst.plans.JoinType -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, Partitioning} -import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.execution.{ColumnarWriteFilesExec, FileSourceScanExec, GenerateExec, LeafExecNode, SparkPlan} +import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.datasources.FileFormat import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, FileScan} import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec @@ -392,69 +386,6 @@ trait SparkPlanExecApi { child: SparkPlan, evalType: Int): SparkPlan - /** - * Generate extended DataSourceV2 Strategies. Currently only for ClickHouse backend. - * - * @return - */ - def genExtendedDataSourceV2Strategies(): List[SparkSession => Strategy] - - /** - * Generate extended query stage preparation rules. - * - * @return - */ - def genExtendedQueryStagePrepRules(): List[SparkSession => Rule[SparkPlan]] - - /** - * Generate extended Analyzers. Currently only for ClickHouse backend. - * - * @return - */ - def genExtendedAnalyzers(): List[SparkSession => Rule[LogicalPlan]] - - /** - * Generate extended Optimizers. Currently only for Velox backend. - * - * @return - */ - def genExtendedOptimizers(): List[SparkSession => Rule[LogicalPlan]] - - /** - * Generate extended Strategies - * - * @return - */ - def genExtendedStrategies(): List[SparkSession => Strategy] - - /** - * Generate extended columnar pre-rules, in the validation phase. - * - * @return - */ - def genExtendedColumnarValidationRules(): List[SparkSession => Rule[SparkPlan]] - - /** - * Generate extended columnar transform-rules. - * - * @return - */ - def genExtendedColumnarTransformRules(): List[SparkSession => Rule[SparkPlan]] - - /** - * Generate extended columnar post-rules. - * - * @return - */ - def genExtendedColumnarPostRules(): List[SparkSession => Rule[SparkPlan]] = { - SparkShimLoader.getSparkShims.getExtendedColumnarPostRules() ::: List() - } - - def genInjectPostHocResolutionRules(): List[SparkSession => Rule[LogicalPlan]] - - def genInjectExtendedParser(): List[(SparkSession, ParserInterface) => ParserInterface] = - List.empty - def genGetStructFieldTransformer( substraitExprName: String, childTransformer: ExpressionTransformer, @@ -665,8 +596,6 @@ trait SparkPlanExecApi { } } - def genInjectedFunctions(): Seq[(FunctionIdentifier, ExpressionInfo, FunctionBuilder)] = Seq.empty - def rewriteSpillPath(path: String): String = path /** diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/ColumnarOverrides.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/ColumnarOverrides.scala index 067976b63b2c..eb21937994f6 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/ColumnarOverrides.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/ColumnarOverrides.scala @@ -16,17 +16,14 @@ */ package org.apache.gluten.extension -import org.apache.gluten.{GlutenConfig, GlutenSparkExtensionsInjector} import org.apache.gluten.extension.columnar._ -import org.apache.gluten.extension.columnar.enumerated.EnumeratedApplier -import org.apache.gluten.extension.columnar.heuristic.HeuristicApplier import org.apache.gluten.extension.columnar.transition.Transitions import org.apache.gluten.utils.LogLevelUtil import org.apache.spark.broadcast.Broadcast import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{SparkSession, SparkSessionExtensions} +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.rules.Rule @@ -95,7 +92,7 @@ object ColumnarOverrideRules { } } -case class ColumnarOverrideRules(session: SparkSession) +case class ColumnarOverrideRules(session: SparkSession, applier: ColumnarRuleApplier) extends ColumnarRule with Logging with LogLevelUtil { @@ -117,19 +114,10 @@ case class ColumnarOverrideRules(session: SparkSession) val outputsColumnar = OutputsColumnarTester.inferOutputsColumnar(plan) val unwrapped = OutputsColumnarTester.unwrap(plan) val vanillaPlan = Transitions.insertTransitions(unwrapped, outputsColumnar) - val applier: ColumnarRuleApplier = if (GlutenConfig.getConf.enableRas) { - new EnumeratedApplier(session) - } else { - new HeuristicApplier(session) - } val out = applier.apply(vanillaPlan, outputsColumnar) out } } -object ColumnarOverrides extends GlutenSparkExtensionsInjector { - override def inject(extensions: SparkSessionExtensions): Unit = { - extensions.injectColumnar(spark => ColumnarOverrideRules(spark)) - } -} +object ColumnarOverrides {} diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/GlutenSessionExtensions.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/GlutenSessionExtensions.scala new file mode 100644 index 000000000000..cceb3851a0f7 --- /dev/null +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/GlutenSessionExtensions.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.gluten.extension + +import org.apache.gluten.backendsapi.BackendsApiManager + +import org.apache.spark.sql.SparkSessionExtensions +import org.apache.spark.sql.internal.StaticSQLConf + +import java.util.Objects + +private[gluten] class GlutenSessionExtensions extends (SparkSessionExtensions => Unit) { + override def apply(exts: SparkSessionExtensions): Unit = { + val injector = new RuleInjector() + BackendsApiManager.getRuleApiInstance.injectRules(injector) + injector.inject(exts) + } +} + +private[gluten] object GlutenSessionExtensions { + val SPARK_SESSION_EXTS_KEY: String = StaticSQLConf.SPARK_SESSION_EXTENSIONS.key + val GLUTEN_SESSION_EXTENSION_NAME: String = + Objects.requireNonNull(classOf[GlutenSessionExtensions].getCanonicalName) +} diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/OthersExtensionOverrides.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/OthersExtensionOverrides.scala deleted file mode 100644 index f2ccf6e81ca1..000000000000 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/OthersExtensionOverrides.scala +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.gluten.extension - -import org.apache.gluten.GlutenSparkExtensionsInjector -import org.apache.gluten.backendsapi.BackendsApiManager - -import org.apache.spark.sql.SparkSessionExtensions - -object OthersExtensionOverrides extends GlutenSparkExtensionsInjector { - override def inject(extensions: SparkSessionExtensions): Unit = { - BackendsApiManager.getSparkPlanExecApiInstance - .genInjectExtendedParser() - .foreach(extensions.injectParser) - BackendsApiManager.getSparkPlanExecApiInstance - .genExtendedAnalyzers() - .foreach(extensions.injectResolutionRule) - BackendsApiManager.getSparkPlanExecApiInstance - .genExtendedOptimizers() - .foreach(extensions.injectOptimizerRule) - BackendsApiManager.getSparkPlanExecApiInstance - .genExtendedDataSourceV2Strategies() - .foreach(extensions.injectPlannerStrategy) - BackendsApiManager.getSparkPlanExecApiInstance - .genExtendedStrategies() - .foreach(extensions.injectPlannerStrategy) - BackendsApiManager.getSparkPlanExecApiInstance - .genInjectedFunctions() - .foreach(extensions.injectFunction) - BackendsApiManager.getSparkPlanExecApiInstance - .genInjectPostHocResolutionRules() - .foreach(extensions.injectPostHocResolutionRule) - } -} diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/QueryStagePrepOverrides.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/QueryStagePrepOverrides.scala deleted file mode 100644 index 8f9e2326ca71..000000000000 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/QueryStagePrepOverrides.scala +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.gluten.extension - -import org.apache.gluten.GlutenSparkExtensionsInjector -import org.apache.gluten.backendsapi.BackendsApiManager - -import org.apache.spark.sql.{SparkSession, SparkSessionExtensions} -import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.execution.SparkPlan - -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -object QueryStagePrepOverrides extends GlutenSparkExtensionsInjector { - private val RULES: Seq[SparkSession => Rule[SparkPlan]] = - BackendsApiManager.getSparkPlanExecApiInstance.genExtendedQueryStagePrepRules() - - override def inject(extensions: SparkSessionExtensions): Unit = { - RULES.foreach(extensions.injectQueryStagePrepRule) - } -} diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/RuleInjector.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/RuleInjector.scala new file mode 100644 index 000000000000..e24d89e79b92 --- /dev/null +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/RuleInjector.scala @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.gluten.extension + +import org.apache.gluten.GlutenConfig +import org.apache.gluten.extension.columnar.ColumnarRuleApplier.ColumnarRuleBuilder +import org.apache.gluten.extension.columnar.enumerated.EnumeratedApplier +import org.apache.gluten.extension.columnar.heuristic.HeuristicApplier + +import org.apache.spark.sql.{SparkSession, SparkSessionExtensions, Strategy} +import org.apache.spark.sql.catalyst.FunctionIdentifier +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder +import org.apache.spark.sql.catalyst.expressions.ExpressionInfo +import org.apache.spark.sql.catalyst.parser.ParserInterface +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.SparkPlan + +import scala.collection.mutable + +class RuleInjector { + import RuleInjector._ + + val spark: SparkInjector = SparkInjector() + val gluten: GlutenInjector = GlutenInjector() + val ras: RasInjector = RasInjector() + + private[extension] def inject(extensions: SparkSessionExtensions): Unit = { + spark.inject(extensions) + if (GlutenConfig.getConf.enableRas) { + ras.inject(extensions) + } else { + gluten.inject(extensions) + } + } +} + +object RuleInjector { + class SparkInjector private { + private type RuleBuilder = SparkSession => Rule[LogicalPlan] + private type StrategyBuilder = SparkSession => Strategy + private type ParserBuilder = (SparkSession, ParserInterface) => ParserInterface + private type FunctionDescription = (FunctionIdentifier, ExpressionInfo, FunctionBuilder) + private type QueryStagePrepRuleBuilder = SparkSession => Rule[SparkPlan] + + private val queryStagePrepRuleBuilders = mutable.Buffer.empty[QueryStagePrepRuleBuilder] + private val parserBuilders = mutable.Buffer.empty[ParserBuilder] + private val resolutionRuleBuilders = mutable.Buffer.empty[RuleBuilder] + private val optimizerRules = mutable.Buffer.empty[RuleBuilder] + private val plannerStrategyBuilders = mutable.Buffer.empty[StrategyBuilder] + private val injectedFunctions = mutable.Buffer.empty[FunctionDescription] + private val postHocResolutionRuleBuilders = mutable.Buffer.empty[RuleBuilder] + + def injectQueryStagePrepRule(builder: QueryStagePrepRuleBuilder): Unit = { + queryStagePrepRuleBuilders += builder + } + + def injectParser(builder: ParserBuilder): Unit = { + parserBuilders += builder + } + + def injectResolutionRule(builder: RuleBuilder): Unit = { + resolutionRuleBuilders += builder + } + + def injectOptimizerRule(builder: RuleBuilder): Unit = { + optimizerRules += builder + } + + def injectPlannerStrategy(builder: StrategyBuilder): Unit = { + plannerStrategyBuilders += builder + } + + def injectFunction(functionDescription: FunctionDescription): Unit = { + injectedFunctions += functionDescription + } + + def injectPostHocResolutionRule(builder: RuleBuilder): Unit = { + postHocResolutionRuleBuilders += builder + } + + private[extension] def inject(extensions: SparkSessionExtensions): Unit = { + queryStagePrepRuleBuilders.foreach(extensions.injectQueryStagePrepRule) + parserBuilders.foreach(extensions.injectParser) + resolutionRuleBuilders.foreach(extensions.injectResolutionRule) + optimizerRules.foreach(extensions.injectOptimizerRule) + plannerStrategyBuilders.foreach(extensions.injectPlannerStrategy) + injectedFunctions.foreach(extensions.injectFunction) + postHocResolutionRuleBuilders.foreach(extensions.injectPostHocResolutionRule) + } + } + + class GlutenInjector private { + private val transformBuilders = mutable.Buffer.empty[ColumnarRuleBuilder] + private val fallbackPolicyBuilders = mutable.Buffer.empty[ColumnarRuleBuilder] + private val postBuilders = mutable.Buffer.empty[ColumnarRuleBuilder] + private val finalBuilders = mutable.Buffer.empty[ColumnarRuleBuilder] + + def injectTransform(builder: ColumnarRuleBuilder): Unit = { + transformBuilders += builder + } + + def injectFallbackPolicy(builder: ColumnarRuleBuilder): Unit = { + fallbackPolicyBuilders += builder + } + + def injectPost(builder: ColumnarRuleBuilder): Unit = { + postBuilders += builder + } + + def injectFinal(builder: ColumnarRuleBuilder): Unit = { + finalBuilders += builder + } + + private[extension] def inject(extensions: SparkSessionExtensions): Unit = { + val applierBuilder = (session: SparkSession) => + new HeuristicApplier( + session, + transformBuilders, + fallbackPolicyBuilders, + postBuilders, + finalBuilders) + val ruleBuilder = (session: SparkSession) => + new ColumnarOverrideRules(session, applierBuilder(session)) + extensions.injectColumnar(session => ruleBuilder(session)) + } + } + + class RasInjector private { + private val ruleBuilders = mutable.Buffer.empty[ColumnarRuleBuilder] + + def inject(builder: ColumnarRuleBuilder): Unit = { + ruleBuilders += builder + } + + private[extension] def inject(extensions: SparkSessionExtensions): Unit = { + val applierBuilder = (session: SparkSession) => new EnumeratedApplier(session, ruleBuilders) + val ruleBuilder = (session: SparkSession) => + new ColumnarOverrideRules(session, applierBuilder(session)) + extensions.injectColumnar(session => ruleBuilder(session)) + } + + } + + private object SparkInjector { + def apply(): SparkInjector = { + new SparkInjector() + } + } + + private object GlutenInjector { + def apply(): GlutenInjector = { + new GlutenInjector() + } + } + private object RasInjector { + def apply(): RasInjector = { + new RasInjector() + } + } +} diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/ColumnarRuleApplier.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/ColumnarRuleApplier.scala index 27213698b9f2..34beb09937c7 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/ColumnarRuleApplier.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/ColumnarRuleApplier.scala @@ -17,10 +17,12 @@ package org.apache.gluten.extension.columnar import org.apache.gluten.GlutenConfig +import org.apache.gluten.extension.columnar.util.AdaptiveContext import org.apache.gluten.metrics.GlutenTimeMetric import org.apache.gluten.utils.LogLevelUtil import org.apache.spark.internal.Logging +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor} import org.apache.spark.sql.catalyst.util.sideBySide import org.apache.spark.sql.execution.SparkPlan @@ -30,6 +32,10 @@ trait ColumnarRuleApplier { } object ColumnarRuleApplier { + type ColumnarRuleBuilder = ColumnarRuleCall => Rule[SparkPlan] + + case class ColumnarRuleCall(session: SparkSession, ac: AdaptiveContext, outputsColumnar: Boolean) + class Executor(phase: String, rules: Seq[Rule[SparkPlan]]) extends RuleExecutor[SparkPlan] { private val batch: Batch = Batch(s"Columnar (Phase [$phase])", Once, rules.map(r => new LoggedRule(r)): _*) diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedApplier.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedApplier.scala index 5cf3961c548b..ed8a8ba78472 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedApplier.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedApplier.scala @@ -16,11 +16,8 @@ */ package org.apache.gluten.extension.columnar.enumerated -import org.apache.gluten.GlutenConfig -import org.apache.gluten.backendsapi.BackendsApiManager import org.apache.gluten.extension.columnar._ -import org.apache.gluten.extension.columnar.MiscColumnarRules.{RemoveGlutenTableCacheColumnarToRow, RemoveTopmostColumnarToRow, RewriteSubqueryBroadcast} -import org.apache.gluten.extension.columnar.transition.{InsertTransitions, RemoveTransitions} +import org.apache.gluten.extension.columnar.ColumnarRuleApplier.{ColumnarRuleBuilder, ColumnarRuleCall} import org.apache.gluten.extension.columnar.util.AdaptiveContext import org.apache.gluten.utils.{LogLevelUtil, PhysicalPlanSelector} @@ -28,8 +25,7 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.execution.{ColumnarCollapseTransformStages, GlutenFallbackReporter, SparkPlan} -import org.apache.spark.util.SparkRuleUtil +import org.apache.spark.sql.execution.SparkPlan /** * Columnar rule applier that optimizes, implements Spark plan into Gluten plan by enumerating on @@ -40,7 +36,7 @@ import org.apache.spark.util.SparkRuleUtil * implementing them in EnumeratedTransform. */ @Experimental -class EnumeratedApplier(session: SparkSession) +class EnumeratedApplier(session: SparkSession, ruleBuilders: Seq[ColumnarRuleBuilder]) extends ColumnarRuleApplier with Logging with LogLevelUtil { @@ -53,22 +49,18 @@ class EnumeratedApplier(session: SparkSession) } private val adaptiveContext = AdaptiveContext(session, aqeStackTraceIndex) - override def apply(plan: SparkPlan, outputsColumnar: Boolean): SparkPlan = + override def apply(plan: SparkPlan, outputsColumnar: Boolean): SparkPlan = { + val call = ColumnarRuleCall(session, adaptiveContext, outputsColumnar) PhysicalPlanSelector.maybe(session, plan) { - val transformed = - transformPlan("transform", transformRules(outputsColumnar).map(_(session)), plan) - val postPlan = maybeAqe { - transformPlan("post", postRules().map(_(session)), transformed) + val finalPlan = maybeAqe { + apply0(ruleBuilders.map(b => b(call)), plan) } - val finalPlan = transformPlan("final", finalRules().map(_(session)), postPlan) finalPlan } + } - private def transformPlan( - phase: String, - rules: Seq[Rule[SparkPlan]], - plan: SparkPlan): SparkPlan = { - val executor = new ColumnarRuleApplier.Executor(phase, rules) + private def apply0(rules: Seq[Rule[SparkPlan]], plan: SparkPlan): SparkPlan = { + val executor = new ColumnarRuleApplier.Executor("ras", rules) executor.execute(plan) } @@ -80,61 +72,4 @@ class EnumeratedApplier(session: SparkSession) adaptiveContext.resetAdaptiveContext() } } - - /** - * Rules to let planner create a suggested Gluten plan being sent to `fallbackPolicies` in which - * the plan will be breakdown and decided to be fallen back or not. - */ - private def transformRules(outputsColumnar: Boolean): Seq[SparkSession => Rule[SparkPlan]] = { - List( - (_: SparkSession) => RemoveTransitions, - (spark: SparkSession) => FallbackOnANSIMode(spark), - (spark: SparkSession) => PlanOneRowRelation(spark), - (_: SparkSession) => FallbackEmptySchemaRelation(), - (_: SparkSession) => RewriteSubqueryBroadcast() - ) ::: - BackendsApiManager.getSparkPlanExecApiInstance.genExtendedColumnarValidationRules() ::: - List((spark: SparkSession) => MergeTwoPhasesHashBaseAggregate(spark)) ::: - List( - (session: SparkSession) => EnumeratedTransform(session, outputsColumnar), - (_: SparkSession) => RemoveTransitions - ) ::: - List( - (_: SparkSession) => RemoveNativeWriteFilesSortAndProject(), - (spark: SparkSession) => RewriteTransformer(spark), - (_: SparkSession) => EnsureLocalSortRequirements, - (_: SparkSession) => EliminateLocalSort, - (_: SparkSession) => CollapseProjectExecTransformer - ) ::: - BackendsApiManager.getSparkPlanExecApiInstance.genExtendedColumnarTransformRules() ::: - SparkRuleUtil - .extendedColumnarRules(session, GlutenConfig.getConf.extendedColumnarTransformRules) ::: - List((_: SparkSession) => InsertTransitions(outputsColumnar)) - } - - /** - * Rules applying to non-fallen-back Gluten plans. To do some post cleanup works on the plan to - * make sure it be able to run and be compatible with Spark's execution engine. - */ - private def postRules(): Seq[SparkSession => Rule[SparkPlan]] = - List( - (s: SparkSession) => RemoveTopmostColumnarToRow(s, adaptiveContext.isAdaptiveContext())) ::: - BackendsApiManager.getSparkPlanExecApiInstance.genExtendedColumnarPostRules() ::: - List((_: SparkSession) => ColumnarCollapseTransformStages(GlutenConfig.getConf)) ::: - SparkRuleUtil.extendedColumnarRules(session, GlutenConfig.getConf.extendedColumnarPostRules) - - /* - * Rules consistently applying to all input plans after all other rules have been applied, despite - * whether the input plan is fallen back or not. - */ - private def finalRules(): Seq[SparkSession => Rule[SparkPlan]] = { - List( - // The rule is required despite whether the stage is fallen back or not. Since - // ColumnarCachedBatchSerializer is statically registered to Spark without a columnar rule - // when columnar table cache is enabled. - (s: SparkSession) => RemoveGlutenTableCacheColumnarToRow(s), - (s: SparkSession) => GlutenFallbackReporter(GlutenConfig.getConf, s), - (_: SparkSession) => RemoveFallbackTagRule() - ) - } } diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/heuristic/HeuristicApplier.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/heuristic/HeuristicApplier.scala index f776a1dcc3cd..0e4a5876bc92 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/heuristic/HeuristicApplier.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/heuristic/HeuristicApplier.scala @@ -16,26 +16,26 @@ */ package org.apache.gluten.extension.columnar.heuristic -import org.apache.gluten.GlutenConfig -import org.apache.gluten.backendsapi.BackendsApiManager import org.apache.gluten.extension.columnar._ -import org.apache.gluten.extension.columnar.MiscColumnarRules.{RemoveGlutenTableCacheColumnarToRow, RemoveTopmostColumnarToRow, RewriteSubqueryBroadcast, TransformPreOverrides} -import org.apache.gluten.extension.columnar.rewrite.RewriteSparkPlanRulesManager -import org.apache.gluten.extension.columnar.transition.{InsertTransitions, RemoveTransitions} +import org.apache.gluten.extension.columnar.ColumnarRuleApplier.{ColumnarRuleBuilder, ColumnarRuleCall} import org.apache.gluten.extension.columnar.util.AdaptiveContext import org.apache.gluten.utils.{LogLevelUtil, PhysicalPlanSelector} import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.execution.{ColumnarCollapseTransformStages, GlutenFallbackReporter, SparkPlan} -import org.apache.spark.util.SparkRuleUtil +import org.apache.spark.sql.execution.SparkPlan /** * Columnar rule applier that optimizes, implements Spark plan into Gluten plan by heuristically * applying columnar rules in fixed order. */ -class HeuristicApplier(session: SparkSession) +class HeuristicApplier( + session: SparkSession, + transformBuilders: Seq[ColumnarRuleBuilder], + fallbackPolicyBuilders: Seq[ColumnarRuleBuilder], + postBuilders: Seq[ColumnarRuleBuilder], + finalBuilders: Seq[ColumnarRuleBuilder]) extends ColumnarRuleApplier with Logging with LogLevelUtil { @@ -49,27 +49,27 @@ class HeuristicApplier(session: SparkSession) private val adaptiveContext = AdaptiveContext(session, aqeStackTraceIndex) override def apply(plan: SparkPlan, outputsColumnar: Boolean): SparkPlan = { - withTransformRules(transformRules(outputsColumnar)).apply(plan) + val call = ColumnarRuleCall(session, adaptiveContext, outputsColumnar) + makeRule(call).apply(plan) } - // Visible for testing. - def withTransformRules(transformRules: Seq[SparkSession => Rule[SparkPlan]]): Rule[SparkPlan] = + private def makeRule(call: ColumnarRuleCall): Rule[SparkPlan] = plan => PhysicalPlanSelector.maybe(session, plan) { val finalPlan = prepareFallback(plan) { p => - val suggestedPlan = transformPlan("transform", transformRules.map(_(session)), p) - transformPlan("fallback", fallbackPolicies().map(_(session)), suggestedPlan) match { + val suggestedPlan = transformPlan("transform", transformRules(call), p) + transformPlan("fallback", fallbackPolicies(call), suggestedPlan) match { case FallbackNode(fallbackPlan) => // we should use vanilla c2r rather than native c2r, // and there should be no `GlutenPlan` any more, // so skip the `postRules()`. fallbackPlan case plan => - transformPlan("post", postRules().map(_(session)), plan) + transformPlan("post", postRules(call), plan) } } - transformPlan("final", finalRules().map(_(session)), finalPlan) + transformPlan("final", finalRules(call), finalPlan) } private def transformPlan( @@ -95,69 +95,32 @@ class HeuristicApplier(session: SparkSession) * Rules to let planner create a suggested Gluten plan being sent to `fallbackPolicies` in which * the plan will be breakdown and decided to be fallen back or not. */ - private def transformRules(outputsColumnar: Boolean): Seq[SparkSession => Rule[SparkPlan]] = { - List( - (_: SparkSession) => RemoveTransitions, - (spark: SparkSession) => FallbackOnANSIMode(spark), - (spark: SparkSession) => FallbackMultiCodegens(spark), - (spark: SparkSession) => PlanOneRowRelation(spark), - (_: SparkSession) => RewriteSubqueryBroadcast() - ) ::: - BackendsApiManager.getSparkPlanExecApiInstance.genExtendedColumnarValidationRules() ::: - List( - (_: SparkSession) => FallbackEmptySchemaRelation(), - (spark: SparkSession) => MergeTwoPhasesHashBaseAggregate(spark), - (_: SparkSession) => RewriteSparkPlanRulesManager(), - (_: SparkSession) => AddFallbackTagRule() - ) ::: - List((_: SparkSession) => TransformPreOverrides()) ::: - List( - (_: SparkSession) => RemoveNativeWriteFilesSortAndProject(), - (spark: SparkSession) => RewriteTransformer(spark), - (_: SparkSession) => EnsureLocalSortRequirements, - (_: SparkSession) => EliminateLocalSort, - (_: SparkSession) => CollapseProjectExecTransformer - ) ::: - BackendsApiManager.getSparkPlanExecApiInstance.genExtendedColumnarTransformRules() ::: - SparkRuleUtil - .extendedColumnarRules(session, GlutenConfig.getConf.extendedColumnarTransformRules) ::: - List((_: SparkSession) => InsertTransitions(outputsColumnar)) + private def transformRules(call: ColumnarRuleCall): Seq[Rule[SparkPlan]] = { + transformBuilders.map(b => b.apply(call)) } /** * Rules to add wrapper `FallbackNode`s on top of the input plan, as hints to make planner fall * back the whole input plan to the original vanilla Spark plan. */ - private def fallbackPolicies(): Seq[SparkSession => Rule[SparkPlan]] = { - List( - (_: SparkSession) => - ExpandFallbackPolicy(adaptiveContext.isAdaptiveContext(), adaptiveContext.originalPlan())) + private def fallbackPolicies(call: ColumnarRuleCall): Seq[Rule[SparkPlan]] = { + fallbackPolicyBuilders.map(b => b.apply(call)) } /** * Rules applying to non-fallen-back Gluten plans. To do some post cleanup works on the plan to * make sure it be able to run and be compatible with Spark's execution engine. */ - private def postRules(): Seq[SparkSession => Rule[SparkPlan]] = - List( - (s: SparkSession) => RemoveTopmostColumnarToRow(s, adaptiveContext.isAdaptiveContext())) ::: - BackendsApiManager.getSparkPlanExecApiInstance.genExtendedColumnarPostRules() ::: - List((_: SparkSession) => ColumnarCollapseTransformStages(GlutenConfig.getConf)) ::: - SparkRuleUtil.extendedColumnarRules(session, GlutenConfig.getConf.extendedColumnarPostRules) + private def postRules(call: ColumnarRuleCall): Seq[Rule[SparkPlan]] = { + postBuilders.map(b => b.apply(call)) + } /* * Rules consistently applying to all input plans after all other rules have been applied, despite * whether the input plan is fallen back or not. */ - private def finalRules(): Seq[SparkSession => Rule[SparkPlan]] = { - List( - // The rule is required despite whether the stage is fallen back or not. Since - // ColumnarCachedBatchSerializer is statically registered to Spark without a columnar rule - // when columnar table cache is enabled. - (s: SparkSession) => RemoveGlutenTableCacheColumnarToRow(s), - (s: SparkSession) => GlutenFallbackReporter(GlutenConfig.getConf, s), - (_: SparkSession) => RemoveFallbackTagRule() - ) + private def finalRules(call: ColumnarRuleCall): Seq[Rule[SparkPlan]] = { + finalBuilders.map(b => b.apply(call)) } // Just for test use. @@ -166,3 +129,5 @@ class HeuristicApplier(session: SparkSession) this } } + +object HeuristicApplier {} diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/util/AdaptiveContext.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/util/AdaptiveContext.scala index 4a9d69f8f0b1..e1f594fd36e5 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/util/AdaptiveContext.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/util/AdaptiveContext.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec import scala.collection.mutable.ListBuffer +// Since: https://github.com/apache/incubator-gluten/pull/3294. sealed trait AdaptiveContext { def enableAdaptiveContext(): Unit def isAdaptiveContext(): Boolean diff --git a/gluten-core/src/main/scala/org/apache/spark/util/SparkPlanRules.scala b/gluten-core/src/main/scala/org/apache/spark/util/SparkPlanRules.scala new file mode 100644 index 000000000000..e5c03f8bd0b2 --- /dev/null +++ b/gluten-core/src/main/scala/org/apache/spark/util/SparkPlanRules.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.util + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.SparkPlan + +object SparkPlanRules extends Logging { + // Since https://github.com/apache/incubator-gluten/pull/1523 + def extendedColumnarRules(ruleNames: String): Seq[SparkSession => Rule[SparkPlan]] = { + val extendedRules = ruleNames.split(",").filter(_.nonEmpty) + extendedRules.map { + ruleName => session: SparkSession => + try { + val ruleClass = Utils.classForName(ruleName) + val rule = + ruleClass + .getConstructor(classOf[SparkSession]) + .newInstance(session) + .asInstanceOf[Rule[SparkPlan]] + rule + } catch { + // Ignore the error if we cannot find the class or when the class has the wrong type. + case e @ (_: ClassCastException | _: ClassNotFoundException | _: NoClassDefFoundError) => + logWarning(s"Cannot create extended rule $ruleName", e) + EmptyRule // The rule does nothing. + } + }.toList + } + + object EmptyRule extends Rule[SparkPlan] { + override def apply(plan: SparkPlan): SparkPlan = plan + } + + class AbortRule(message: String) extends Rule[SparkPlan] { + override def apply(plan: SparkPlan): SparkPlan = + throw new IllegalStateException( + "AbortRule is being executed, this should not happen. Reason: " + message) + } +} diff --git a/gluten-core/src/main/scala/org/apache/spark/util/SparkRuleUtil.scala b/gluten-core/src/main/scala/org/apache/spark/util/SparkRuleUtil.scala deleted file mode 100644 index 100ec36d2424..000000000000 --- a/gluten-core/src/main/scala/org/apache/spark/util/SparkRuleUtil.scala +++ /dev/null @@ -1,56 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.util - -import org.apache.spark.internal.Logging -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.execution.SparkPlan - -object SparkRuleUtil extends Logging { - - /** Add the extended pre/post column rules */ - def extendedColumnarRules( - session: SparkSession, - conf: String - ): List[SparkSession => Rule[SparkPlan]] = { - val extendedRules = conf.split(",").filter(_.nonEmpty) - extendedRules - .map { - ruleStr => - try { - val extensionConfClass = Utils.classForName(ruleStr) - val extensionConf = - extensionConfClass - .getConstructor(classOf[SparkSession]) - .newInstance(session) - .asInstanceOf[Rule[SparkPlan]] - - Some((sparkSession: SparkSession) => extensionConf) - } catch { - // Ignore the error if we cannot find the class or when the class has the wrong type. - case e @ (_: ClassCastException | _: ClassNotFoundException | - _: NoClassDefFoundError) => - logWarning(s"Cannot create extended rule $ruleStr", e) - None - } - } - .filter(_.isDefined) - .map(_.get) - .toList - } -} diff --git a/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala b/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala index 7c7aa08791e8..5d171a36bdd4 100644 --- a/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala +++ b/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala @@ -16,8 +16,12 @@ */ package org.apache.spark.sql.execution +import org.apache.gluten.GlutenConfig import org.apache.gluten.execution.BasicScanExecTransformer import org.apache.gluten.extension.GlutenPlan +import org.apache.gluten.extension.columnar.{ExpandFallbackPolicy, RemoveFallbackTagRule} +import org.apache.gluten.extension.columnar.ColumnarRuleApplier.ColumnarRuleBuilder +import org.apache.gluten.extension.columnar.MiscColumnarRules.RemoveTopmostColumnarToRow import org.apache.gluten.extension.columnar.heuristic.HeuristicApplier import org.apache.gluten.extension.columnar.transition.InsertTransitions import org.apache.gluten.utils.QueryPlanSelector @@ -28,18 +32,20 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute class FallbackStrategiesSuite extends GlutenSQLTestsTrait { + import FallbackStrategiesSuite._ testGluten("Fall back the whole query if one unsupported") { withSQLConf(("spark.gluten.sql.columnar.query.fallback.threshold", "1")) { val originalPlan = UnaryOp2(UnaryOp1(UnaryOp2(UnaryOp1(LeafOp())))) - val rule = new HeuristicApplier(spark).withTransformRules( + val rule = newRuleApplier( + spark, List( _ => _ => { UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOp())))) }, - (_: SparkSession) => InsertTransitions(outputsColumnar = false))) - val outputPlan = rule.apply(originalPlan) + c => InsertTransitions(c.outputsColumnar))) + val outputPlan = rule.apply(originalPlan, false) // Expect to fall back the entire plan. assert(outputPlan == originalPlan) } @@ -48,16 +54,16 @@ class FallbackStrategiesSuite extends GlutenSQLTestsTrait { testGluten("Fall back the whole plan if meeting the configured threshold") { withSQLConf(("spark.gluten.sql.columnar.wholeStage.fallback.threshold", "1")) { val originalPlan = UnaryOp2(UnaryOp1(UnaryOp2(UnaryOp1(LeafOp())))) - val rule = new HeuristicApplier(spark) + val rule = newRuleApplier( + spark, + List( + _ => + _ => { + UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOp())))) + }, + c => InsertTransitions(c.outputsColumnar))) .enableAdaptiveContext() - .withTransformRules( - List( - _ => - _ => { - UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOp())))) - }, - (_: SparkSession) => InsertTransitions(outputsColumnar = false))) - val outputPlan = rule.apply(originalPlan) + val outputPlan = rule.apply(originalPlan, false) // Expect to fall back the entire plan. assert(outputPlan == originalPlan) } @@ -66,16 +72,16 @@ class FallbackStrategiesSuite extends GlutenSQLTestsTrait { testGluten("Don't fall back the whole plan if NOT meeting the configured threshold") { withSQLConf(("spark.gluten.sql.columnar.wholeStage.fallback.threshold", "4")) { val originalPlan = UnaryOp2(UnaryOp1(UnaryOp2(UnaryOp1(LeafOp())))) - val rule = new HeuristicApplier(spark) + val rule = newRuleApplier( + spark, + List( + _ => + _ => { + UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOp())))) + }, + c => InsertTransitions(c.outputsColumnar))) .enableAdaptiveContext() - .withTransformRules( - List( - _ => - _ => { - UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOp())))) - }, - (_: SparkSession) => InsertTransitions(outputsColumnar = false))) - val outputPlan = rule.apply(originalPlan) + val outputPlan = rule.apply(originalPlan, false) // Expect to get the plan with columnar rule applied. assert(outputPlan != originalPlan) } @@ -86,16 +92,16 @@ class FallbackStrategiesSuite extends GlutenSQLTestsTrait { " transformable)") { withSQLConf(("spark.gluten.sql.columnar.wholeStage.fallback.threshold", "2")) { val originalPlan = UnaryOp2(UnaryOp1(UnaryOp2(UnaryOp1(LeafOp())))) - val rule = new HeuristicApplier(spark) + val rule = newRuleApplier( + spark, + List( + _ => + _ => { + UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOpTransformer())))) + }, + c => InsertTransitions(c.outputsColumnar))) .enableAdaptiveContext() - .withTransformRules( - List( - _ => - _ => { - UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOpTransformer())))) - }, - (_: SparkSession) => InsertTransitions(outputsColumnar = false))) - val outputPlan = rule.apply(originalPlan) + val outputPlan = rule.apply(originalPlan, false) // Expect to fall back the entire plan. assert(outputPlan == originalPlan) } @@ -106,16 +112,16 @@ class FallbackStrategiesSuite extends GlutenSQLTestsTrait { "leaf node is transformable)") { withSQLConf(("spark.gluten.sql.columnar.wholeStage.fallback.threshold", "3")) { val originalPlan = UnaryOp2(UnaryOp1(UnaryOp2(UnaryOp1(LeafOp())))) - val rule = new HeuristicApplier(spark) + val rule = newRuleApplier( + spark, + List( + _ => + _ => { + UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOpTransformer())))) + }, + c => InsertTransitions(c.outputsColumnar))) .enableAdaptiveContext() - .withTransformRules( - List( - _ => - _ => { - UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOpTransformer())))) - }, - (_: SparkSession) => InsertTransitions(outputsColumnar = false))) - val outputPlan = rule.apply(originalPlan) + val outputPlan = rule.apply(originalPlan, false) // Expect to get the plan with columnar rule applied. assert(outputPlan != originalPlan) } @@ -153,43 +159,60 @@ class FallbackStrategiesSuite extends GlutenSQLTestsTrait { } } -case class LeafOp(override val supportsColumnar: Boolean = false) extends LeafExecNode { - override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() - override def output: Seq[Attribute] = Seq.empty -} +private object FallbackStrategiesSuite { + def newRuleApplier( + spark: SparkSession, + transformBuilders: Seq[ColumnarRuleBuilder]): HeuristicApplier = { + new HeuristicApplier( + spark, + transformBuilders, + List(c => ExpandFallbackPolicy(c.ac.isAdaptiveContext(), c.ac.originalPlan())), + List( + c => RemoveTopmostColumnarToRow(c.session, c.ac.isAdaptiveContext()), + _ => ColumnarCollapseTransformStages(GlutenConfig.getConf) + ), + List(_ => RemoveFallbackTagRule()) + ) + } -case class UnaryOp1(child: SparkPlan, override val supportsColumnar: Boolean = false) - extends UnaryExecNode { - override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() - override def output: Seq[Attribute] = child.output - override protected def withNewChildInternal(newChild: SparkPlan): UnaryOp1 = - copy(child = newChild) -} + case class LeafOp(override val supportsColumnar: Boolean = false) extends LeafExecNode { + override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() + override def output: Seq[Attribute] = Seq.empty + } -case class UnaryOp2(child: SparkPlan, override val supportsColumnar: Boolean = false) - extends UnaryExecNode { - override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() - override def output: Seq[Attribute] = child.output - override protected def withNewChildInternal(newChild: SparkPlan): UnaryOp2 = - copy(child = newChild) -} + case class UnaryOp1(child: SparkPlan, override val supportsColumnar: Boolean = false) + extends UnaryExecNode { + override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() + override def output: Seq[Attribute] = child.output + override protected def withNewChildInternal(newChild: SparkPlan): UnaryOp1 = + copy(child = newChild) + } + + case class UnaryOp2(child: SparkPlan, override val supportsColumnar: Boolean = false) + extends UnaryExecNode { + override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() + override def output: Seq[Attribute] = child.output + override protected def withNewChildInternal(newChild: SparkPlan): UnaryOp2 = + copy(child = newChild) + } // For replacing LeafOp. -case class LeafOpTransformer(override val supportsColumnar: Boolean = true) - extends LeafExecNode - with GlutenPlan { - override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() - override def output: Seq[Attribute] = Seq.empty -} + case class LeafOpTransformer(override val supportsColumnar: Boolean = true) + extends LeafExecNode + with GlutenPlan { + override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() + override def output: Seq[Attribute] = Seq.empty + } // For replacing UnaryOp1. -case class UnaryOp1Transformer( - override val child: SparkPlan, - override val supportsColumnar: Boolean = true) - extends UnaryExecNode - with GlutenPlan { - override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() - override def output: Seq[Attribute] = child.output - override protected def withNewChildInternal(newChild: SparkPlan): UnaryOp1Transformer = - copy(child = newChild) + case class UnaryOp1Transformer( + override val child: SparkPlan, + override val supportsColumnar: Boolean = true) + extends UnaryExecNode + with GlutenPlan { + override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() + override def output: Seq[Attribute] = child.output + override protected def withNewChildInternal(newChild: SparkPlan): UnaryOp1Transformer = + copy(child = newChild) + } } diff --git a/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/extension/GlutenSessionExtensionSuite.scala b/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/extension/GlutenSessionExtensionSuite.scala index 6816534094f3..2ca7429f1679 100644 --- a/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/extension/GlutenSessionExtensionSuite.scala +++ b/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/extension/GlutenSessionExtensionSuite.scala @@ -31,7 +31,8 @@ class GlutenSessionExtensionSuite extends GlutenSQLTestsTrait { } testGluten("test gluten extensions") { - assert(spark.sessionState.columnarRules.contains(ColumnarOverrideRules(spark))) + assert( + spark.sessionState.columnarRules.map(_.getClass).contains(classOf[ColumnarOverrideRules])) assert(spark.sessionState.planner.strategies.contains(MySparkStrategy(spark))) assert(spark.sessionState.analyzer.extendedResolutionRules.contains(MyRule(spark))) diff --git a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala index 54d7596b602c..1ce0025f2944 100644 --- a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala +++ b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala @@ -16,10 +16,13 @@ */ package org.apache.spark.sql.execution +import org.apache.gluten.GlutenConfig import org.apache.gluten.backendsapi.BackendsApiManager import org.apache.gluten.execution.BasicScanExecTransformer import org.apache.gluten.extension.GlutenPlan -import org.apache.gluten.extension.columnar.{FallbackEmptySchemaRelation, FallbackTags} +import org.apache.gluten.extension.columnar.{ExpandFallbackPolicy, FallbackEmptySchemaRelation, FallbackTags, RemoveFallbackTagRule} +import org.apache.gluten.extension.columnar.ColumnarRuleApplier.ColumnarRuleBuilder +import org.apache.gluten.extension.columnar.MiscColumnarRules.RemoveTopmostColumnarToRow import org.apache.gluten.extension.columnar.heuristic.HeuristicApplier import org.apache.gluten.extension.columnar.transition.InsertTransitions import org.apache.gluten.utils.QueryPlanSelector @@ -30,17 +33,19 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute class FallbackStrategiesSuite extends GlutenSQLTestsTrait { + import FallbackStrategiesSuite._ testGluten("Fall back the whole query if one unsupported") { withSQLConf(("spark.gluten.sql.columnar.query.fallback.threshold", "1")) { val originalPlan = UnaryOp2(UnaryOp1(UnaryOp2(UnaryOp1(LeafOp())))) - val rule = new HeuristicApplier(spark).withTransformRules( + val rule = newRuleApplier( + spark, List( _ => _ => { UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOp())))) }, - (_: SparkSession) => InsertTransitions(outputsColumnar = false))) - val outputPlan = rule.apply(originalPlan) + c => InsertTransitions(c.outputsColumnar))) + val outputPlan = rule.apply(originalPlan, false) // Expect to fall back the entire plan. assert(outputPlan == originalPlan) } @@ -49,16 +54,16 @@ class FallbackStrategiesSuite extends GlutenSQLTestsTrait { testGluten("Fall back the whole plan if meeting the configured threshold") { withSQLConf(("spark.gluten.sql.columnar.wholeStage.fallback.threshold", "1")) { val originalPlan = UnaryOp2(UnaryOp1(UnaryOp2(UnaryOp1(LeafOp())))) - val rule = new HeuristicApplier(spark) + val rule = newRuleApplier( + spark, + List( + _ => + _ => { + UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOp())))) + }, + c => InsertTransitions(c.outputsColumnar))) .enableAdaptiveContext() - .withTransformRules( - List( - _ => - _ => { - UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOp())))) - }, - (_: SparkSession) => InsertTransitions(outputsColumnar = false))) - val outputPlan = rule.apply(originalPlan) + val outputPlan = rule.apply(originalPlan, false) // Expect to fall back the entire plan. assert(outputPlan == originalPlan) } @@ -67,16 +72,16 @@ class FallbackStrategiesSuite extends GlutenSQLTestsTrait { testGluten("Don't fall back the whole plan if NOT meeting the configured threshold") { withSQLConf(("spark.gluten.sql.columnar.wholeStage.fallback.threshold", "4")) { val originalPlan = UnaryOp2(UnaryOp1(UnaryOp2(UnaryOp1(LeafOp())))) - val rule = new HeuristicApplier(spark) + val rule = newRuleApplier( + spark, + List( + _ => + _ => { + UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOp())))) + }, + c => InsertTransitions(c.outputsColumnar))) .enableAdaptiveContext() - .withTransformRules( - List( - _ => - _ => { - UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOp())))) - }, - (_: SparkSession) => InsertTransitions(outputsColumnar = false))) - val outputPlan = rule.apply(originalPlan) + val outputPlan = rule.apply(originalPlan, false) // Expect to get the plan with columnar rule applied. assert(outputPlan != originalPlan) } @@ -87,16 +92,16 @@ class FallbackStrategiesSuite extends GlutenSQLTestsTrait { " transformable)") { withSQLConf(("spark.gluten.sql.columnar.wholeStage.fallback.threshold", "2")) { val originalPlan = UnaryOp2(UnaryOp1(UnaryOp2(UnaryOp1(LeafOp())))) - val rule = new HeuristicApplier(spark) + val rule = newRuleApplier( + spark, + List( + _ => + _ => { + UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOpTransformer())))) + }, + c => InsertTransitions(c.outputsColumnar))) .enableAdaptiveContext() - .withTransformRules( - List( - _ => - _ => { - UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOpTransformer())))) - }, - (_: SparkSession) => InsertTransitions(outputsColumnar = false))) - val outputPlan = rule.apply(originalPlan) + val outputPlan = rule.apply(originalPlan, false) // Expect to fall back the entire plan. assert(outputPlan == originalPlan) } @@ -107,16 +112,16 @@ class FallbackStrategiesSuite extends GlutenSQLTestsTrait { "leaf node is transformable)") { withSQLConf(("spark.gluten.sql.columnar.wholeStage.fallback.threshold", "3")) { val originalPlan = UnaryOp2(UnaryOp1(UnaryOp2(UnaryOp1(LeafOp())))) - val rule = new HeuristicApplier(spark) + val rule = newRuleApplier( + spark, + List( + _ => + _ => { + UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOpTransformer())))) + }, + c => InsertTransitions(c.outputsColumnar))) .enableAdaptiveContext() - .withTransformRules( - List( - _ => - _ => { - UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOpTransformer())))) - }, - (_: SparkSession) => InsertTransitions(outputsColumnar = false))) - val outputPlan = rule.apply(originalPlan) + val outputPlan = rule.apply(originalPlan, false) // Expect to get the plan with columnar rule applied. assert(outputPlan != originalPlan) } @@ -168,44 +173,60 @@ class FallbackStrategiesSuite extends GlutenSQLTestsTrait { thread.join(10000) } } +private object FallbackStrategiesSuite { + def newRuleApplier( + spark: SparkSession, + transformBuilders: Seq[ColumnarRuleBuilder]): HeuristicApplier = { + new HeuristicApplier( + spark, + transformBuilders, + List(c => ExpandFallbackPolicy(c.ac.isAdaptiveContext(), c.ac.originalPlan())), + List( + c => RemoveTopmostColumnarToRow(c.session, c.ac.isAdaptiveContext()), + _ => ColumnarCollapseTransformStages(GlutenConfig.getConf) + ), + List(_ => RemoveFallbackTagRule()) + ) + } -case class LeafOp(override val supportsColumnar: Boolean = false) extends LeafExecNode { - override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() - override def output: Seq[Attribute] = Seq.empty -} + case class LeafOp(override val supportsColumnar: Boolean = false) extends LeafExecNode { + override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() + override def output: Seq[Attribute] = Seq.empty + } -case class UnaryOp1(child: SparkPlan, override val supportsColumnar: Boolean = false) - extends UnaryExecNode { - override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() - override def output: Seq[Attribute] = child.output - override protected def withNewChildInternal(newChild: SparkPlan): UnaryOp1 = - copy(child = newChild) -} + case class UnaryOp1(child: SparkPlan, override val supportsColumnar: Boolean = false) + extends UnaryExecNode { + override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() + override def output: Seq[Attribute] = child.output + override protected def withNewChildInternal(newChild: SparkPlan): UnaryOp1 = + copy(child = newChild) + } -case class UnaryOp2(child: SparkPlan, override val supportsColumnar: Boolean = false) - extends UnaryExecNode { - override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() - override def output: Seq[Attribute] = child.output - override protected def withNewChildInternal(newChild: SparkPlan): UnaryOp2 = - copy(child = newChild) -} + case class UnaryOp2(child: SparkPlan, override val supportsColumnar: Boolean = false) + extends UnaryExecNode { + override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() + override def output: Seq[Attribute] = child.output + override protected def withNewChildInternal(newChild: SparkPlan): UnaryOp2 = + copy(child = newChild) + } -// For replacing LeafOp. -case class LeafOpTransformer(override val supportsColumnar: Boolean = true) - extends LeafExecNode - with GlutenPlan { - override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() - override def output: Seq[Attribute] = Seq.empty -} + // For replacing LeafOp. + case class LeafOpTransformer(override val supportsColumnar: Boolean = true) + extends LeafExecNode + with GlutenPlan { + override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() + override def output: Seq[Attribute] = Seq.empty + } -// For replacing UnaryOp1. -case class UnaryOp1Transformer( - override val child: SparkPlan, - override val supportsColumnar: Boolean = true) - extends UnaryExecNode - with GlutenPlan { - override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() - override def output: Seq[Attribute] = child.output - override protected def withNewChildInternal(newChild: SparkPlan): UnaryOp1Transformer = - copy(child = newChild) + // For replacing UnaryOp1. + case class UnaryOp1Transformer( + override val child: SparkPlan, + override val supportsColumnar: Boolean = true) + extends UnaryExecNode + with GlutenPlan { + override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() + override def output: Seq[Attribute] = child.output + override protected def withNewChildInternal(newChild: SparkPlan): UnaryOp1Transformer = + copy(child = newChild) + } } diff --git a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/extension/GlutenSessionExtensionSuite.scala b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/extension/GlutenSessionExtensionSuite.scala index 6816534094f3..2ca7429f1679 100644 --- a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/extension/GlutenSessionExtensionSuite.scala +++ b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/extension/GlutenSessionExtensionSuite.scala @@ -31,7 +31,8 @@ class GlutenSessionExtensionSuite extends GlutenSQLTestsTrait { } testGluten("test gluten extensions") { - assert(spark.sessionState.columnarRules.contains(ColumnarOverrideRules(spark))) + assert( + spark.sessionState.columnarRules.map(_.getClass).contains(classOf[ColumnarOverrideRules])) assert(spark.sessionState.planner.strategies.contains(MySparkStrategy(spark))) assert(spark.sessionState.analyzer.extendedResolutionRules.contains(MyRule(spark))) diff --git a/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala b/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala index 5150a4768851..3acc9c4b39aa 100644 --- a/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala +++ b/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala @@ -16,10 +16,13 @@ */ package org.apache.spark.sql.execution +import org.apache.gluten.GlutenConfig import org.apache.gluten.backendsapi.BackendsApiManager import org.apache.gluten.execution.BasicScanExecTransformer import org.apache.gluten.extension.GlutenPlan -import org.apache.gluten.extension.columnar.{FallbackEmptySchemaRelation, FallbackTags} +import org.apache.gluten.extension.columnar.{ExpandFallbackPolicy, FallbackEmptySchemaRelation, FallbackTags, RemoveFallbackTagRule} +import org.apache.gluten.extension.columnar.ColumnarRuleApplier.ColumnarRuleBuilder +import org.apache.gluten.extension.columnar.MiscColumnarRules.RemoveTopmostColumnarToRow import org.apache.gluten.extension.columnar.heuristic.HeuristicApplier import org.apache.gluten.extension.columnar.transition.InsertTransitions import org.apache.gluten.utils.QueryPlanSelector @@ -30,18 +33,19 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute class FallbackStrategiesSuite extends GlutenSQLTestsTrait { - + import FallbackStrategiesSuite._ testGluten("Fall back the whole query if one unsupported") { withSQLConf(("spark.gluten.sql.columnar.query.fallback.threshold", "1")) { val originalPlan = UnaryOp2(UnaryOp1(UnaryOp2(UnaryOp1(LeafOp())))) - val rule = new HeuristicApplier(spark).withTransformRules( + val rule = newRuleApplier( + spark, List( _ => _ => { UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOp())))) }, - (_: SparkSession) => InsertTransitions(outputsColumnar = false))) - val outputPlan = rule.apply(originalPlan) + c => InsertTransitions(c.outputsColumnar))) + val outputPlan = rule.apply(originalPlan, false) // Expect to fall back the entire plan. assert(outputPlan == originalPlan) } @@ -50,16 +54,16 @@ class FallbackStrategiesSuite extends GlutenSQLTestsTrait { testGluten("Fall back the whole plan if meeting the configured threshold") { withSQLConf(("spark.gluten.sql.columnar.wholeStage.fallback.threshold", "1")) { val originalPlan = UnaryOp2(UnaryOp1(UnaryOp2(UnaryOp1(LeafOp())))) - val rule = new HeuristicApplier(spark) + val rule = newRuleApplier( + spark, + List( + _ => + _ => { + UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOp())))) + }, + c => InsertTransitions(c.outputsColumnar))) .enableAdaptiveContext() - .withTransformRules( - List( - _ => - _ => { - UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOp())))) - }, - (_: SparkSession) => InsertTransitions(outputsColumnar = false))) - val outputPlan = rule.apply(originalPlan) + val outputPlan = rule.apply(originalPlan, false) // Expect to fall back the entire plan. assert(outputPlan == originalPlan) } @@ -68,16 +72,16 @@ class FallbackStrategiesSuite extends GlutenSQLTestsTrait { testGluten("Don't fall back the whole plan if NOT meeting the configured threshold") { withSQLConf(("spark.gluten.sql.columnar.wholeStage.fallback.threshold", "4")) { val originalPlan = UnaryOp2(UnaryOp1(UnaryOp2(UnaryOp1(LeafOp())))) - val rule = new HeuristicApplier(spark) + val rule = newRuleApplier( + spark, + List( + _ => + _ => { + UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOp())))) + }, + c => InsertTransitions(c.outputsColumnar))) .enableAdaptiveContext() - .withTransformRules( - List( - _ => - _ => { - UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOp())))) - }, - (_: SparkSession) => InsertTransitions(outputsColumnar = false))) - val outputPlan = rule.apply(originalPlan) + val outputPlan = rule.apply(originalPlan, false) // Expect to get the plan with columnar rule applied. assert(outputPlan != originalPlan) } @@ -88,16 +92,16 @@ class FallbackStrategiesSuite extends GlutenSQLTestsTrait { " transformable)") { withSQLConf(("spark.gluten.sql.columnar.wholeStage.fallback.threshold", "2")) { val originalPlan = UnaryOp2(UnaryOp1(UnaryOp2(UnaryOp1(LeafOp())))) - val rule = new HeuristicApplier(spark) + val rule = newRuleApplier( + spark, + List( + _ => + _ => { + UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOpTransformer())))) + }, + c => InsertTransitions(c.outputsColumnar))) .enableAdaptiveContext() - .withTransformRules( - List( - _ => - _ => { - UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOpTransformer())))) - }, - (_: SparkSession) => InsertTransitions(outputsColumnar = false))) - val outputPlan = rule.apply(originalPlan) + val outputPlan = rule.apply(originalPlan, false) // Expect to fall back the entire plan. assert(outputPlan == originalPlan) } @@ -108,16 +112,16 @@ class FallbackStrategiesSuite extends GlutenSQLTestsTrait { "leaf node is transformable)") { withSQLConf(("spark.gluten.sql.columnar.wholeStage.fallback.threshold", "3")) { val originalPlan = UnaryOp2(UnaryOp1(UnaryOp2(UnaryOp1(LeafOp())))) - val rule = new HeuristicApplier(spark) + val rule = newRuleApplier( + spark, + List( + _ => + _ => { + UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOpTransformer())))) + }, + c => InsertTransitions(c.outputsColumnar))) .enableAdaptiveContext() - .withTransformRules( - List( - _ => - _ => { - UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOpTransformer())))) - }, - (_: SparkSession) => InsertTransitions(outputsColumnar = false))) - val outputPlan = rule.apply(originalPlan) + val outputPlan = rule.apply(originalPlan, false) // Expect to get the plan with columnar rule applied. assert(outputPlan != originalPlan) } @@ -170,43 +174,60 @@ class FallbackStrategiesSuite extends GlutenSQLTestsTrait { } } -case class LeafOp(override val supportsColumnar: Boolean = false) extends LeafExecNode { - override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() - override def output: Seq[Attribute] = Seq.empty -} +private object FallbackStrategiesSuite { + def newRuleApplier( + spark: SparkSession, + transformBuilders: Seq[ColumnarRuleBuilder]): HeuristicApplier = { + new HeuristicApplier( + spark, + transformBuilders, + List(c => ExpandFallbackPolicy(c.ac.isAdaptiveContext(), c.ac.originalPlan())), + List( + c => RemoveTopmostColumnarToRow(c.session, c.ac.isAdaptiveContext()), + _ => ColumnarCollapseTransformStages(GlutenConfig.getConf) + ), + List(_ => RemoveFallbackTagRule()) + ) + } -case class UnaryOp1(child: SparkPlan, override val supportsColumnar: Boolean = false) - extends UnaryExecNode { - override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() - override def output: Seq[Attribute] = child.output - override protected def withNewChildInternal(newChild: SparkPlan): UnaryOp1 = - copy(child = newChild) -} + case class LeafOp(override val supportsColumnar: Boolean = false) extends LeafExecNode { + override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() + override def output: Seq[Attribute] = Seq.empty + } -case class UnaryOp2(child: SparkPlan, override val supportsColumnar: Boolean = false) - extends UnaryExecNode { - override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() - override def output: Seq[Attribute] = child.output - override protected def withNewChildInternal(newChild: SparkPlan): UnaryOp2 = - copy(child = newChild) -} + case class UnaryOp1(child: SparkPlan, override val supportsColumnar: Boolean = false) + extends UnaryExecNode { + override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() + override def output: Seq[Attribute] = child.output + override protected def withNewChildInternal(newChild: SparkPlan): UnaryOp1 = + copy(child = newChild) + } -// For replacing LeafOp. -case class LeafOpTransformer(override val supportsColumnar: Boolean = true) - extends LeafExecNode - with GlutenPlan { - override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() - override def output: Seq[Attribute] = Seq.empty -} + case class UnaryOp2(child: SparkPlan, override val supportsColumnar: Boolean = false) + extends UnaryExecNode { + override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() + override def output: Seq[Attribute] = child.output + override protected def withNewChildInternal(newChild: SparkPlan): UnaryOp2 = + copy(child = newChild) + } -// For replacing UnaryOp1. -case class UnaryOp1Transformer( - override val child: SparkPlan, - override val supportsColumnar: Boolean = true) - extends UnaryExecNode - with GlutenPlan { - override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() - override def output: Seq[Attribute] = child.output - override protected def withNewChildInternal(newChild: SparkPlan): UnaryOp1Transformer = - copy(child = newChild) + // For replacing LeafOp. + case class LeafOpTransformer(override val supportsColumnar: Boolean = true) + extends LeafExecNode + with GlutenPlan { + override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() + override def output: Seq[Attribute] = Seq.empty + } + + // For replacing UnaryOp1. + case class UnaryOp1Transformer( + override val child: SparkPlan, + override val supportsColumnar: Boolean = true) + extends UnaryExecNode + with GlutenPlan { + override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() + override def output: Seq[Attribute] = child.output + override protected def withNewChildInternal(newChild: SparkPlan): UnaryOp1Transformer = + copy(child = newChild) + } } diff --git a/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/extension/GlutenSessionExtensionSuite.scala b/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/extension/GlutenSessionExtensionSuite.scala index 6816534094f3..2ca7429f1679 100644 --- a/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/extension/GlutenSessionExtensionSuite.scala +++ b/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/extension/GlutenSessionExtensionSuite.scala @@ -31,7 +31,8 @@ class GlutenSessionExtensionSuite extends GlutenSQLTestsTrait { } testGluten("test gluten extensions") { - assert(spark.sessionState.columnarRules.contains(ColumnarOverrideRules(spark))) + assert( + spark.sessionState.columnarRules.map(_.getClass).contains(classOf[ColumnarOverrideRules])) assert(spark.sessionState.planner.strategies.contains(MySparkStrategy(spark))) assert(spark.sessionState.analyzer.extendedResolutionRules.contains(MyRule(spark))) diff --git a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala index 5150a4768851..bcc4e829b535 100644 --- a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala +++ b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala @@ -16,10 +16,13 @@ */ package org.apache.spark.sql.execution +import org.apache.gluten.GlutenConfig import org.apache.gluten.backendsapi.BackendsApiManager import org.apache.gluten.execution.BasicScanExecTransformer import org.apache.gluten.extension.GlutenPlan -import org.apache.gluten.extension.columnar.{FallbackEmptySchemaRelation, FallbackTags} +import org.apache.gluten.extension.columnar.{ExpandFallbackPolicy, FallbackEmptySchemaRelation, FallbackTags, RemoveFallbackTagRule} +import org.apache.gluten.extension.columnar.ColumnarRuleApplier.ColumnarRuleBuilder +import org.apache.gluten.extension.columnar.MiscColumnarRules.RemoveTopmostColumnarToRow import org.apache.gluten.extension.columnar.heuristic.HeuristicApplier import org.apache.gluten.extension.columnar.transition.InsertTransitions import org.apache.gluten.utils.QueryPlanSelector @@ -30,18 +33,20 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute class FallbackStrategiesSuite extends GlutenSQLTestsTrait { + import FallbackStrategiesSuite._ testGluten("Fall back the whole query if one unsupported") { withSQLConf(("spark.gluten.sql.columnar.query.fallback.threshold", "1")) { val originalPlan = UnaryOp2(UnaryOp1(UnaryOp2(UnaryOp1(LeafOp())))) - val rule = new HeuristicApplier(spark).withTransformRules( + val rule = newRuleApplier( + spark, List( _ => _ => { UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOp())))) }, - (_: SparkSession) => InsertTransitions(outputsColumnar = false))) - val outputPlan = rule.apply(originalPlan) + c => InsertTransitions(c.outputsColumnar))) + val outputPlan = rule.apply(originalPlan, false) // Expect to fall back the entire plan. assert(outputPlan == originalPlan) } @@ -50,16 +55,16 @@ class FallbackStrategiesSuite extends GlutenSQLTestsTrait { testGluten("Fall back the whole plan if meeting the configured threshold") { withSQLConf(("spark.gluten.sql.columnar.wholeStage.fallback.threshold", "1")) { val originalPlan = UnaryOp2(UnaryOp1(UnaryOp2(UnaryOp1(LeafOp())))) - val rule = new HeuristicApplier(spark) + val rule = newRuleApplier( + spark, + List( + _ => + _ => { + UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOp())))) + }, + c => InsertTransitions(c.outputsColumnar))) .enableAdaptiveContext() - .withTransformRules( - List( - _ => - _ => { - UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOp())))) - }, - (_: SparkSession) => InsertTransitions(outputsColumnar = false))) - val outputPlan = rule.apply(originalPlan) + val outputPlan = rule.apply(originalPlan, false) // Expect to fall back the entire plan. assert(outputPlan == originalPlan) } @@ -68,16 +73,16 @@ class FallbackStrategiesSuite extends GlutenSQLTestsTrait { testGluten("Don't fall back the whole plan if NOT meeting the configured threshold") { withSQLConf(("spark.gluten.sql.columnar.wholeStage.fallback.threshold", "4")) { val originalPlan = UnaryOp2(UnaryOp1(UnaryOp2(UnaryOp1(LeafOp())))) - val rule = new HeuristicApplier(spark) + val rule = newRuleApplier( + spark, + List( + _ => + _ => { + UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOp())))) + }, + c => InsertTransitions(c.outputsColumnar))) .enableAdaptiveContext() - .withTransformRules( - List( - _ => - _ => { - UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOp())))) - }, - (_: SparkSession) => InsertTransitions(outputsColumnar = false))) - val outputPlan = rule.apply(originalPlan) + val outputPlan = rule.apply(originalPlan, false) // Expect to get the plan with columnar rule applied. assert(outputPlan != originalPlan) } @@ -88,16 +93,16 @@ class FallbackStrategiesSuite extends GlutenSQLTestsTrait { " transformable)") { withSQLConf(("spark.gluten.sql.columnar.wholeStage.fallback.threshold", "2")) { val originalPlan = UnaryOp2(UnaryOp1(UnaryOp2(UnaryOp1(LeafOp())))) - val rule = new HeuristicApplier(spark) + val rule = newRuleApplier( + spark, + List( + _ => + _ => { + UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOpTransformer())))) + }, + c => InsertTransitions(c.outputsColumnar))) .enableAdaptiveContext() - .withTransformRules( - List( - _ => - _ => { - UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOpTransformer())))) - }, - (_: SparkSession) => InsertTransitions(outputsColumnar = false))) - val outputPlan = rule.apply(originalPlan) + val outputPlan = rule.apply(originalPlan, false) // Expect to fall back the entire plan. assert(outputPlan == originalPlan) } @@ -108,16 +113,16 @@ class FallbackStrategiesSuite extends GlutenSQLTestsTrait { "leaf node is transformable)") { withSQLConf(("spark.gluten.sql.columnar.wholeStage.fallback.threshold", "3")) { val originalPlan = UnaryOp2(UnaryOp1(UnaryOp2(UnaryOp1(LeafOp())))) - val rule = new HeuristicApplier(spark) + val rule = newRuleApplier( + spark, + List( + _ => + _ => { + UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOpTransformer())))) + }, + c => InsertTransitions(c.outputsColumnar))) .enableAdaptiveContext() - .withTransformRules( - List( - _ => - _ => { - UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOpTransformer())))) - }, - (_: SparkSession) => InsertTransitions(outputsColumnar = false))) - val outputPlan = rule.apply(originalPlan) + val outputPlan = rule.apply(originalPlan, false) // Expect to get the plan with columnar rule applied. assert(outputPlan != originalPlan) } @@ -170,43 +175,60 @@ class FallbackStrategiesSuite extends GlutenSQLTestsTrait { } } -case class LeafOp(override val supportsColumnar: Boolean = false) extends LeafExecNode { - override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() - override def output: Seq[Attribute] = Seq.empty -} +private object FallbackStrategiesSuite { + def newRuleApplier( + spark: SparkSession, + transformBuilders: Seq[ColumnarRuleBuilder]): HeuristicApplier = { + new HeuristicApplier( + spark, + transformBuilders, + List(c => ExpandFallbackPolicy(c.ac.isAdaptiveContext(), c.ac.originalPlan())), + List( + c => RemoveTopmostColumnarToRow(c.session, c.ac.isAdaptiveContext()), + _ => ColumnarCollapseTransformStages(GlutenConfig.getConf) + ), + List(_ => RemoveFallbackTagRule()) + ) + } -case class UnaryOp1(child: SparkPlan, override val supportsColumnar: Boolean = false) - extends UnaryExecNode { - override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() - override def output: Seq[Attribute] = child.output - override protected def withNewChildInternal(newChild: SparkPlan): UnaryOp1 = - copy(child = newChild) -} + case class LeafOp(override val supportsColumnar: Boolean = false) extends LeafExecNode { + override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() + override def output: Seq[Attribute] = Seq.empty + } -case class UnaryOp2(child: SparkPlan, override val supportsColumnar: Boolean = false) - extends UnaryExecNode { - override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() - override def output: Seq[Attribute] = child.output - override protected def withNewChildInternal(newChild: SparkPlan): UnaryOp2 = - copy(child = newChild) -} + case class UnaryOp1(child: SparkPlan, override val supportsColumnar: Boolean = false) + extends UnaryExecNode { + override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() + override def output: Seq[Attribute] = child.output + override protected def withNewChildInternal(newChild: SparkPlan): UnaryOp1 = + copy(child = newChild) + } + + case class UnaryOp2(child: SparkPlan, override val supportsColumnar: Boolean = false) + extends UnaryExecNode { + override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() + override def output: Seq[Attribute] = child.output + override protected def withNewChildInternal(newChild: SparkPlan): UnaryOp2 = + copy(child = newChild) + } // For replacing LeafOp. -case class LeafOpTransformer(override val supportsColumnar: Boolean = true) - extends LeafExecNode - with GlutenPlan { - override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() - override def output: Seq[Attribute] = Seq.empty -} + case class LeafOpTransformer(override val supportsColumnar: Boolean = true) + extends LeafExecNode + with GlutenPlan { + override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() + override def output: Seq[Attribute] = Seq.empty + } // For replacing UnaryOp1. -case class UnaryOp1Transformer( - override val child: SparkPlan, - override val supportsColumnar: Boolean = true) - extends UnaryExecNode - with GlutenPlan { - override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() - override def output: Seq[Attribute] = child.output - override protected def withNewChildInternal(newChild: SparkPlan): UnaryOp1Transformer = - copy(child = newChild) + case class UnaryOp1Transformer( + override val child: SparkPlan, + override val supportsColumnar: Boolean = true) + extends UnaryExecNode + with GlutenPlan { + override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() + override def output: Seq[Attribute] = child.output + override protected def withNewChildInternal(newChild: SparkPlan): UnaryOp1Transformer = + copy(child = newChild) + } } diff --git a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/extension/GlutenSessionExtensionSuite.scala b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/extension/GlutenSessionExtensionSuite.scala index 6816534094f3..2ca7429f1679 100644 --- a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/extension/GlutenSessionExtensionSuite.scala +++ b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/extension/GlutenSessionExtensionSuite.scala @@ -31,7 +31,8 @@ class GlutenSessionExtensionSuite extends GlutenSQLTestsTrait { } testGluten("test gluten extensions") { - assert(spark.sessionState.columnarRules.contains(ColumnarOverrideRules(spark))) + assert( + spark.sessionState.columnarRules.map(_.getClass).contains(classOf[ColumnarOverrideRules])) assert(spark.sessionState.planner.strategies.contains(MySparkStrategy(spark))) assert(spark.sessionState.analyzer.extendedResolutionRules.contains(MyRule(spark)))