From 71ec720d428424339ad2db803a7bf588be57eb86 Mon Sep 17 00:00:00 2001
From: Zhen Li <10524738+zhli1142015@users.noreply.github.com>
Date: Mon, 11 Dec 2023 21:09:29 +0800
Subject: [PATCH] [VL] Make bloom_filter_agg fall back when might_contain is
 not transformable (#3994)

[VL] Make bloom_filter_agg fall back when might_contain is not transformable.
---
 .../backendsapi/clickhouse/CHBackend.scala    |  2 +
 docs/velox-backend-limitations.md             | 31 ++---------
 .../backendsapi/BackendSettingsApi.scala      |  2 +
 .../extension/ColumnarOverrides.scala         |  1 +
 .../columnar/TransformHintRule.scala          | 49 ++++++++++++++++-
 .../utils/velox/VeloxTestSettings.scala       |  2 -
 ...GlutenBloomFilterAggregateQuerySuite.scala | 52 ++++++++++++++++++-
 .../utils/velox/VeloxTestSettings.scala       |  2 -
 ...GlutenBloomFilterAggregateQuerySuite.scala | 52 ++++++++++++++++++-
 .../glutenproject/sql/shims/SparkShims.scala  |  7 ++-
 .../sql/shims/spark32/Spark32Shims.scala      |  7 ++-
 .../sql/shims/spark33/Spark33Shims.scala      | 18 +++++++
 .../sql/shims/spark34/Spark34Shims.scala      | 18 +++++++
 13 files changed, 206 insertions(+), 37 deletions(-)

diff --git a/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHBackend.scala b/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHBackend.scala
index 1f7dbb9cd905..f474ab5baae1 100644
--- a/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHBackend.scala
+++ b/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHBackend.scala
@@ -219,4 +219,6 @@ object CHBackendSettings extends BackendSettingsApi with Logging {
   override def allowDecimalArithmetic: Boolean = !SQLConf.get.decimalOperationsAllowPrecisionLoss
 
   override def requiredInputFilePaths(): Boolean = true
+
+  override def enableBloomFilterAggFallbackRule(): Boolean = false
 }
diff --git a/docs/velox-backend-limitations.md b/docs/velox-backend-limitations.md
index a17075e7d009..1d7bd9cdcaf1 100644
--- a/docs/velox-backend-limitations.md
+++ b/docs/velox-backend-limitations.md
@@ -11,40 +11,15 @@ Gluten avoids to modify Spark's existing code and use Spark APIs if possible. Ho
 So you need to ensure preferentially load the Gluten jar to overwrite the jar of vanilla spark. Refer to [How to prioritize loading Gluten jars in Spark](https://github.com/oap-project/gluten/blob/main/docs/velox-backend-troubleshooting.md#incompatible-class-error-when-using-native-writer).
 
-
-### Runtime BloomFilter
-
-Velox BloomFilter's implementation is different from Spark's. So if `might_contain` falls back, but `bloom_filter_agg` is offloaded to velox, an exception will be thrown.
-
-#### example
-
-```sql
-SELECT might_contain(null, null) both_null,
-       might_contain(null, 1L) null_bf,
-       might_contain((SELECT bloom_filter_agg(cast(id as long)) from range(1, 10000)),
-                     null) null_value
-```
-
-The below exception will be thrown.
-
-```
-Unexpected Bloom filter version number (512)
-java.io.IOException: Unexpected Bloom filter version number (512)
-    at org.apache.spark.util.sketch.BloomFilterImpl.readFrom0(BloomFilterImpl.java:256)
-    at org.apache.spark.util.sketch.BloomFilterImpl.readFrom(BloomFilterImpl.java:265)
-    at org.apache.spark.util.sketch.BloomFilter.readFrom(BloomFilter.java:178)
-```
-
-#### Solution
-
-Set the gluten config `spark.gluten.sql.native.bloomFilter=false` to fall back to vanilla bloom filter, you can also disable runtime filter by setting spark config `spark.sql.optimizer.runtime.bloomFilter.enabled=false`.
-
 ### Fallbacks
 Except the unsupported operators, functions, file formats, data sources listed in , there are some known cases also fall back to Vanilla Spark.
 
 #### ANSI
 Gluten currently doesn't support ANSI mode. If ANSI is enabled, Spark plan's execution will always fall back to vanilla Spark.
 
+#### Runtime BloomFilter
+Velox BloomFilter's serialization format is different from Spark's: a BloomFilter binary generated by Velox can't be deserialized by vanilla Spark. So if `might_contain` falls back, we also fall back the related `bloom_filter_agg` to vanilla Spark.
+
 #### Case Sensitive mode
 Gluten only supports spark default case-insensitive mode. If case-sensitive mode is enabled, user may get incorrect result.
diff --git a/gluten-core/src/main/scala/io/glutenproject/backendsapi/BackendSettingsApi.scala b/gluten-core/src/main/scala/io/glutenproject/backendsapi/BackendSettingsApi.scala
index a6443060a4ed..feb09e6c0fc0 100644
--- a/gluten-core/src/main/scala/io/glutenproject/backendsapi/BackendSettingsApi.scala
+++ b/gluten-core/src/main/scala/io/glutenproject/backendsapi/BackendSettingsApi.scala
@@ -111,4 +111,6 @@ trait BackendSettingsApi {
   def staticPartitionWriteOnly(): Boolean = false
 
   def requiredInputFilePaths(): Boolean = false
+
+  def enableBloomFilterAggFallbackRule(): Boolean = true
 }
diff --git a/gluten-core/src/main/scala/io/glutenproject/extension/ColumnarOverrides.scala b/gluten-core/src/main/scala/io/glutenproject/extension/ColumnarOverrides.scala
index 0bd693aa1349..4ee22c0bc24c 100644
--- a/gluten-core/src/main/scala/io/glutenproject/extension/ColumnarOverrides.scala
+++ b/gluten-core/src/main/scala/io/glutenproject/extension/ColumnarOverrides.scala
@@ -795,6 +795,7 @@ case class ColumnarOverrideRules(session: SparkSession)
       (spark: SparkSession) => PlanOneRowRelation(spark),
       (_: SparkSession) => FallbackEmptySchemaRelation(),
       (_: SparkSession) => AddTransformHintRule(),
+      (_: SparkSession) => FallbackBloomFilterAggIfNeeded(),
       (_: SparkSession) => TransformPreOverrides(isAdaptiveContext),
       (spark: SparkSession) => RewriteTransformer(spark),
       (_: SparkSession) => EnsureLocalSortRequirements
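For orientation before the rule itself: the plan shape this change targets is a `might_contain` filter fed by a `bloom_filter_agg` computed in a scalar subquery. Below is a minimal sketch of such a query, adapted from the example removed from `docs/velox-backend-limitations.md` above; it is not part of the patch, and it assumes `bloom_filter_agg`/`might_contain` are available in the session (they are internal Spark functions, registered for the test suites further down).

```sql
-- If this might_contain is tagged not transformable, the bloom_filter_agg in the
-- scalar subquery has to fall back to vanilla Spark too, because a Velox-built
-- bloom filter binary cannot be read by Spark's BloomFilterImpl.readFrom.
SELECT might_contain(
         (SELECT bloom_filter_agg(cast(id AS long)) FROM range(1, 10000)),
         cast(id AS long)) AS membership_test
FROM range(1, 10000);
```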
diff --git a/gluten-core/src/main/scala/io/glutenproject/extension/columnar/TransformHintRule.scala b/gluten-core/src/main/scala/io/glutenproject/extension/columnar/TransformHintRule.scala
index 0d1f3bd05765..45029c0ee503 100644
--- a/gluten-core/src/main/scala/io/glutenproject/extension/columnar/TransformHintRule.scala
+++ b/gluten-core/src/main/scala/io/glutenproject/extension/columnar/TransformHintRule.scala
@@ -20,12 +20,13 @@ import io.glutenproject.GlutenConfig
 import io.glutenproject.backendsapi.BackendsApiManager
 import io.glutenproject.execution._
 import io.glutenproject.extension.{GlutenPlan, ValidationResult}
+import io.glutenproject.sql.shims.SparkShimLoader
 import io.glutenproject.utils.PhysicalPlanSelector
 
 import org.apache.spark.api.python.EvalPythonExecTransformer
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, SortOrder}
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, SortOrder}
 import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight}
 import org.apache.spark.sql.catalyst.plans.FullOuter
 import org.apache.spark.sql.catalyst.rules.Rule
@@ -269,6 +270,52 @@ case class FallbackEmptySchemaRelation() extends Rule[SparkPlan] {
   }
 }
 
+/**
+ * Velox BloomFilter's implementation is different from Spark's. So if might_contain falls back, we
+ * need to fall back the related bloom_filter_agg as well.
+ */
+case class FallbackBloomFilterAggIfNeeded() extends Rule[SparkPlan] {
+  override def apply(plan: SparkPlan): SparkPlan =
+    if (
+      GlutenConfig.getConf.enableNativeBloomFilter &&
+      BackendsApiManager.getSettings.enableBloomFilterAggFallbackRule()
+    ) {
+      plan.transformDown {
+        case p if TransformHints.isAlreadyTagged(p) && TransformHints.isNotTransformable(p) =>
+          handleBloomFilterFallback(p)
+          p
+      }
+    } else {
+      plan
+    }
+
+  object SubPlanFromBloomFilterMightContain {
+    def unapply(expr: Expression): Option[SparkPlan] =
+      SparkShimLoader.getSparkShims.extractSubPlanFromMightContain(expr)
+  }
+
+  private def handleBloomFilterFallback(plan: SparkPlan): Unit = {
+    def tagNotTransformableRecursive(p: SparkPlan): Unit = {
+      p match {
+        case agg: org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec
+            if SparkShimLoader.getSparkShims.hasBloomFilterAggregate(agg) =>
+          TransformHints.tagNotTransformable(agg, "related BloomFilterMightContain falls back")
+          tagNotTransformableRecursive(agg.child)
+        case a: org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec =>
+          tagNotTransformableRecursive(a.executedPlan)
+        case _ =>
+          p.children.map(tagNotTransformableRecursive)
+      }
+    }
+
+    plan.transformExpressions {
+      case expr @ SubPlanFromBloomFilterMightContain(p: SparkPlan) =>
+        tagNotTransformableRecursive(p)
+        expr
+    }
+  }
+}
+
 // This rule will try to convert a plan into plan transformer.
 // The doValidate function will be called to check if the conversion is supported.
 // If false is returned or any unsupported exception is thrown, a row guard will
diff --git a/gluten-ut/spark33/src/test/scala/io/glutenproject/utils/velox/VeloxTestSettings.scala b/gluten-ut/spark33/src/test/scala/io/glutenproject/utils/velox/VeloxTestSettings.scala
index f28a9c31a40a..756707679822 100644
--- a/gluten-ut/spark33/src/test/scala/io/glutenproject/utils/velox/VeloxTestSettings.scala
+++ b/gluten-ut/spark33/src/test/scala/io/glutenproject/utils/velox/VeloxTestSettings.scala
@@ -51,8 +51,6 @@ class VeloxTestSettings extends BackendTestSettings {
     .exclude("string split function with positive limit")
     .exclude("string split function with negative limit")
   enableSuite[GlutenBloomFilterAggregateQuerySuite]
-    // fallback might_contain, the input argument binary is not same with vanilla spark
-    .exclude("Test NULL inputs for might_contain")
   enableSuite[GlutenDataSourceV2DataFrameSessionCatalogSuite]
   enableSuite[GlutenDataSourceV2DataFrameSuite]
   enableSuite[GlutenDataSourceV2FunctionSuite]
diff --git a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/GlutenBloomFilterAggregateQuerySuite.scala b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/GlutenBloomFilterAggregateQuerySuite.scala
index 744271e53a78..7351dcc4dd35 100644
--- a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/GlutenBloomFilterAggregateQuerySuite.scala
+++ b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/GlutenBloomFilterAggregateQuerySuite.scala
@@ -17,12 +17,16 @@
 package org.apache.spark.sql
 
 import io.glutenproject.GlutenConfig
+import io.glutenproject.backendsapi.BackendsApiManager
+import io.glutenproject.execution.HashAggregateExecBaseTransformer
 
+import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.internal.SQLConf
 
 class GlutenBloomFilterAggregateQuerySuite
   extends BloomFilterAggregateQuerySuite
-  with GlutenSQLTestsTrait {
+  with GlutenSQLTestsTrait
+  with AdaptiveSparkPlanHelper {
   import testImplicits._
 
   test("Test bloom_filter_agg with big RUNTIME_BLOOM_FILTER_MAX_NUM_ITEMS") {
@@ -63,4 +67,50 @@
           | from range(1, 1)), null)""".stripMargin),
       Row(null))
   }
+
+  test("Test bloom_filter_agg fallback") {
+    val table = "bloom_filter_test"
+    val numEstimatedItems = 5000000L
+    val numBits = GlutenConfig.getConf.veloxBloomFilterMaxNumBits
+    val sqlString = s"""
+                       |SELECT col positive_membership_test
+                       |FROM $table
+                       |WHERE might_contain(
+                       |            (SELECT bloom_filter_agg(col,
+                       |              cast($numEstimatedItems as long),
+                       |              cast($numBits as long))
+                       |             FROM $table), col)
+                  """.stripMargin
+    withTempView(table) {
+      (Seq(Long.MinValue, 0, Long.MaxValue) ++ (1L to 200000L))
+        .toDF("col")
+        .createOrReplaceTempView(table)
+      withSQLConf(
+        GlutenConfig.COLUMNAR_PROJECT_ENABLED.key -> "false"
+      ) {
+        val df = spark.sql(sqlString)
+        df.collect
+        assert(
+          collectWithSubqueries(df.queryExecution.executedPlan) {
+            case h if h.isInstanceOf[HashAggregateExecBaseTransformer] => h
+          }.size == 2,
+          df.queryExecution.executedPlan
+        )
+      }
+      if (BackendsApiManager.getSettings.enableBloomFilterAggFallbackRule()) {
+        withSQLConf(
+          GlutenConfig.COLUMNAR_FILTER_ENABLED.key -> "false"
+        ) {
+          val df = spark.sql(sqlString)
+          df.collect
+          assert(
+            collectWithSubqueries(df.queryExecution.executedPlan) {
+              case h if h.isInstanceOf[HashAggregateExecBaseTransformer] => h
+            }.size == 0,
+            df.queryExecution.executedPlan
+          )
+        }
+      }
+    }
+  }
 }
diff --git a/gluten-ut/spark34/src/test/scala/io/glutenproject/utils/velox/VeloxTestSettings.scala b/gluten-ut/spark34/src/test/scala/io/glutenproject/utils/velox/VeloxTestSettings.scala
index dc0762a79d20..1010414ceef4 100644
--- a/gluten-ut/spark34/src/test/scala/io/glutenproject/utils/velox/VeloxTestSettings.scala
+++ b/gluten-ut/spark34/src/test/scala/io/glutenproject/utils/velox/VeloxTestSettings.scala
@@ -51,8 +51,6 @@ class VeloxTestSettings extends BackendTestSettings {
     .exclude("string split function with positive limit")
     .exclude("string split function with negative limit")
   enableSuite[GlutenBloomFilterAggregateQuerySuite]
-    // fallback might_contain, the input argument binary is not same with vanilla spark
-    .exclude("Test NULL inputs for might_contain")
   enableSuite[GlutenDataSourceV2DataFrameSessionCatalogSuite]
   enableSuite[GlutenDataSourceV2DataFrameSuite]
   enableSuite[GlutenDataSourceV2FunctionSuite]
diff --git a/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/GlutenBloomFilterAggregateQuerySuite.scala b/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/GlutenBloomFilterAggregateQuerySuite.scala
index 744271e53a78..7351dcc4dd35 100644
--- a/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/GlutenBloomFilterAggregateQuerySuite.scala
+++ b/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/GlutenBloomFilterAggregateQuerySuite.scala
@@ -17,12 +17,16 @@
 package org.apache.spark.sql
 
 import io.glutenproject.GlutenConfig
+import io.glutenproject.backendsapi.BackendsApiManager
+import io.glutenproject.execution.HashAggregateExecBaseTransformer
 
+import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.internal.SQLConf
 
 class GlutenBloomFilterAggregateQuerySuite
   extends BloomFilterAggregateQuerySuite
-  with GlutenSQLTestsTrait {
+  with GlutenSQLTestsTrait
+  with AdaptiveSparkPlanHelper {
   import testImplicits._
 
   test("Test bloom_filter_agg with big RUNTIME_BLOOM_FILTER_MAX_NUM_ITEMS") {
@@ -63,4 +67,50 @@
           | from range(1, 1)), null)""".stripMargin),
       Row(null))
   }
+
+  test("Test bloom_filter_agg fallback") {
+    val table = "bloom_filter_test"
+    val numEstimatedItems = 5000000L
+    val numBits = GlutenConfig.getConf.veloxBloomFilterMaxNumBits
+    val sqlString = s"""
+                       |SELECT col positive_membership_test
+                       |FROM $table
+                       |WHERE might_contain(
+                       |            (SELECT bloom_filter_agg(col,
+                       |              cast($numEstimatedItems as long),
+                       |              cast($numBits as long))
+                       |             FROM $table), col)
+                  """.stripMargin
+    withTempView(table) {
+      (Seq(Long.MinValue, 0, Long.MaxValue) ++ (1L to 200000L))
+        .toDF("col")
+        .createOrReplaceTempView(table)
+      withSQLConf(
+        GlutenConfig.COLUMNAR_PROJECT_ENABLED.key -> "false"
+      ) {
+        val df = spark.sql(sqlString)
+        df.collect
+        assert(
+          collectWithSubqueries(df.queryExecution.executedPlan) {
+            case h if h.isInstanceOf[HashAggregateExecBaseTransformer] => h
+          }.size == 2,
+          df.queryExecution.executedPlan
+        )
+      }
+      if (BackendsApiManager.getSettings.enableBloomFilterAggFallbackRule()) {
+        withSQLConf(
+          GlutenConfig.COLUMNAR_FILTER_ENABLED.key -> "false"
+        ) {
+          val df = spark.sql(sqlString)
+          df.collect
+          assert(
+            collectWithSubqueries(df.queryExecution.executedPlan) {
+              case h if h.isInstanceOf[HashAggregateExecBaseTransformer] => h
+            }.size == 0,
+            df.queryExecution.executedPlan
+          )
+        }
+      }
+    }
+  }
 }
diff --git a/shims/common/src/main/scala/io/glutenproject/sql/shims/SparkShims.scala b/shims/common/src/main/scala/io/glutenproject/sql/shims/SparkShims.scala
index de99e7efb44c..ed833dd02738 100644
--- a/shims/common/src/main/scala/io/glutenproject/sql/shims/SparkShims.scala
+++ b/shims/common/src/main/scala/io/glutenproject/sql/shims/SparkShims.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, PlanExpression}
 import org.apache.spark.sql.catalyst.plans.physical.Distribution
 import org.apache.spark.sql.connector.catalog.Table
 import org.apache.spark.sql.connector.expressions.Transform
-import org.apache.spark.sql.execution.FileSourceScanExec
+import org.apache.spark.sql.execution.{FileSourceScanExec, SparkPlan}
 import org.apache.spark.sql.execution.datasources.{FilePartition, FileScanRDD, PartitionDirectory, PartitionedFile, PartitioningAwareFileIndex}
 import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
 import org.apache.spark.sql.execution.datasources.v2.text.TextScan
@@ -81,4 +81,9 @@
       start: Long,
       length: Long,
       @transient locations: Array[String] = Array.empty): PartitionedFile
+
+  def hasBloomFilterAggregate(
+      agg: org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec): Boolean
+
+  def extractSubPlanFromMightContain(expr: Expression): Option[SparkPlan]
 }
diff --git a/shims/spark32/src/main/scala/io/glutenproject/sql/shims/spark32/Spark32Shims.scala b/shims/spark32/src/main/scala/io/glutenproject/sql/shims/spark32/Spark32Shims.scala
index 580cc93bdf75..0e1a1fb09c5b 100644
--- a/shims/spark32/src/main/scala/io/glutenproject/sql/shims/spark32/Spark32Shims.scala
+++ b/shims/spark32/src/main/scala/io/glutenproject/sql/shims/spark32/Spark32Shims.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.plans.physical.{Distribution, HashClusteredDistribution}
 import org.apache.spark.sql.connector.catalog.Table
 import org.apache.spark.sql.connector.expressions.Transform
-import org.apache.spark.sql.execution.{FileSourceScanExec, PartitionedFileUtil}
+import org.apache.spark.sql.execution.{FileSourceScanExec, PartitionedFileUtil, SparkPlan}
 import org.apache.spark.sql.execution.datasources.{BucketingUtils, FilePartition, FileScanRDD, PartitionDirectory, PartitionedFile, PartitioningAwareFileIndex}
 import org.apache.spark.sql.execution.datasources.FileFormatWriter.Empty2Null
 import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
@@ -101,4 +101,9 @@
       length: Long,
       @transient locations: Array[String] = Array.empty): PartitionedFile =
     PartitionedFile(partitionValues, filePath, start, length, locations)
+
+  override def hasBloomFilterAggregate(
+      agg: org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec): Boolean = false
+
+  override def extractSubPlanFromMightContain(expr: Expression): Option[SparkPlan] = None
 }
diff --git a/shims/spark33/src/main/scala/io/glutenproject/sql/shims/spark33/Spark33Shims.scala b/shims/spark33/src/main/scala/io/glutenproject/sql/shims/spark33/Spark33Shims.scala
index 50e536610e7b..a1e34aab08ee 100644
--- a/shims/spark33/src/main/scala/io/glutenproject/sql/shims/spark33/Spark33Shims.scala
+++ b/shims/spark33/src/main/scala/io/glutenproject/sql/shims/spark33/Spark33Shims.scala
@@ -127,6 +127,24 @@
     @transient locations: Array[String] = Array.empty): PartitionedFile =
     PartitionedFile(partitionValues, filePath, start, length, locations)
 
+  override def hasBloomFilterAggregate(
+      agg: org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec): Boolean = {
+    agg.aggregateExpressions.exists(
+      expr => expr.aggregateFunction.isInstanceOf[BloomFilterAggregate])
+  }
+
+  override def extractSubPlanFromMightContain(expr: Expression): Option[SparkPlan] = {
+    expr match {
+      case mc @ BloomFilterMightContain(sub: org.apache.spark.sql.execution.ScalarSubquery, _) =>
+        Some(sub.plan)
+      case mc @ BloomFilterMightContain(
+            g @ GetStructField(sub: org.apache.spark.sql.execution.ScalarSubquery, _, _),
+            _) =>
+        Some(sub.plan)
+      case _ => None
+    }
+  }
+
   private def invalidBucketFile(path: String): Throwable = {
     new SparkException(
       errorClass = "INVALID_BUCKET_FILE",
diff --git a/shims/spark34/src/main/scala/io/glutenproject/sql/shims/spark34/Spark34Shims.scala b/shims/spark34/src/main/scala/io/glutenproject/sql/shims/spark34/Spark34Shims.scala
index cdc42f3b43fd..e23888e38a82 100644
--- a/shims/spark34/src/main/scala/io/glutenproject/sql/shims/spark34/Spark34Shims.scala
+++ b/shims/spark34/src/main/scala/io/glutenproject/sql/shims/spark34/Spark34Shims.scala
@@ -130,6 +130,24 @@
     @transient locations: Array[String] = Array.empty): PartitionedFile =
     PartitionedFile(partitionValues, SparkPath.fromPathString(filePath), start, length, locations)
 
+  override def hasBloomFilterAggregate(
+      agg: org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec): Boolean = {
+    agg.aggregateExpressions.exists(
+      expr => expr.aggregateFunction.isInstanceOf[BloomFilterAggregate])
+  }
+
+  override def extractSubPlanFromMightContain(expr: Expression): Option[SparkPlan] = {
+    expr match {
+      case mc @ BloomFilterMightContain(sub: org.apache.spark.sql.execution.ScalarSubquery, _) =>
+        Some(sub.plan)
+      case mc @ BloomFilterMightContain(
+            g @ GetStructField(sub: org.apache.spark.sql.execution.ScalarSubquery, _, _),
+            _) =>
+        Some(sub.plan)
+      case _ => None
+    }
+  }
+
   private def invalidBucketFile(path: String): Throwable = {
     new SparkException(
       errorClass = "INVALID_BUCKET_FILE",
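A closing practical note, carried over from the section this patch removes from `docs/velox-backend-limitations.md` rather than introduced by the patch: if the automatic fallback is not wanted, the bloom filter path can still be steered explicitly through the two configs mentioned in the removed text. A minimal sketch:

```sql
-- Keep bloom_filter_agg / might_contain entirely on vanilla Spark (Gluten config).
SET spark.gluten.sql.native.bloomFilter=false;

-- Or disable Spark's runtime bloom filter join optimization altogether (Spark config).
SET spark.sql.optimizer.runtime.bloomFilter.enabled=false;
```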