Commit

[VL] Make bloom_filter_agg fall back when might_contain is not transformable (#3994)

zhli1142015 authored Dec 11, 2023
1 parent 5d2dd86 commit 71ec720
Showing 13 changed files with 206 additions and 37 deletions.

@@ -219,4 +219,6 @@ object CHBackendSettings extends BackendSettingsApi with Logging {
override def allowDecimalArithmetic: Boolean = !SQLConf.get.decimalOperationsAllowPrecisionLoss

override def requiredInputFilePaths(): Boolean = true

override def enableBloomFilterAggFallbackRule(): Boolean = false
}
31 changes: 3 additions & 28 deletions docs/velox-backend-limitations.md

@@ -11,40 +11,15 @@ Gluten avoids modifying Spark's existing code and uses Spark APIs if possible. Ho…

So you need to ensure that the Gluten jar is loaded preferentially, overwriting the corresponding jar of vanilla Spark. Refer to [How to prioritize loading Gluten jars in Spark](https://github.com/oap-project/gluten/blob/main/docs/velox-backend-troubleshooting.md#incompatible-class-error-when-using-native-writer).


### Runtime BloomFilter

Velox's BloomFilter implementation is different from Spark's. So if `might_contain` falls back to vanilla Spark while `bloom_filter_agg` is offloaded to Velox, an exception will be thrown.

#### example

```sql
SELECT might_contain(null, null) both_null,
       might_contain(null, 1L) null_bf,
       might_contain((SELECT bloom_filter_agg(cast(id as long)) from range(1, 10000)),
                     null) null_value
```

The below exception will be thrown.

```
Unexpected Bloom filter version number (512)
java.io.IOException: Unexpected Bloom filter version number (512)
at org.apache.spark.util.sketch.BloomFilterImpl.readFrom0(BloomFilterImpl.java:256)
at org.apache.spark.util.sketch.BloomFilterImpl.readFrom(BloomFilterImpl.java:265)
at org.apache.spark.util.sketch.BloomFilter.readFrom(BloomFilter.java:178)
```

#### Solution

Set the Gluten config `spark.gluten.sql.native.bloomFilter=false` to fall back to the vanilla bloom filter. Alternatively, disable the runtime bloom filter entirely by setting the Spark config `spark.sql.optimizer.runtime.bloomFilter.enabled=false`.
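
For example, a minimal sketch of the two settings named above (this assumes they may be changed at the session level; otherwise put them in `spark-defaults.conf` or pass them via `--conf`):

```sql
-- Fall back to vanilla Spark's bloom filter implementation
SET spark.gluten.sql.native.bloomFilter=false;

-- Or disable the runtime bloom filter optimization entirely
SET spark.sql.optimizer.runtime.bloomFilter.enabled=false;
```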

### Fallbacks
Except for the unsupported operators, functions, file formats, and data sources listed in , there are some known cases that also fall back to vanilla Spark.

#### ANSI
Gluten currently doesn't support ANSI mode. If ANSI mode is enabled, the execution of Spark plans will always fall back to vanilla Spark.
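
For reference, ANSI mode is controlled by the standard Spark flag shown in this sketch; keeping it at its default avoids the fallback:

```sql
-- Leave ANSI mode disabled (the default) so plans stay offloaded to Velox
SET spark.sql.ansi.enabled=false;
```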

#### Runtime BloomFilter
Velox's BloomFilter serialization format is different from Spark's, so a BloomFilter binary generated by Velox can't be deserialized by vanilla Spark. Therefore, if `might_contain` falls back, we also fall back `bloom_filter_agg` to vanilla Spark.
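
As an illustration, here is a sketch of the query shape this affects (adapted from the example earlier in this document): if the outer `might_contain` is not transformable, the `bloom_filter_agg` in its subquery now falls back together with it, so both sides use Spark's serialization format.

```sql
SELECT might_contain(
         (SELECT bloom_filter_agg(cast(id as long)) FROM range(1, 10000)),
         cast(5 as long)) AS contains_5
```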

#### Case Sensitive mode
Gluten only supports Spark's default case-insensitive mode. If case-sensitive mode is enabled, users may get incorrect results.
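
For reference, case sensitivity is controlled by the standard Spark flag in this sketch; keep it at its default when running with Gluten:

```sql
-- Keep the default case-insensitive resolution
SET spark.sql.caseSensitive=false;
```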

@@ -111,4 +111,6 @@ trait BackendSettingsApi {
def staticPartitionWriteOnly(): Boolean = false

def requiredInputFilePaths(): Boolean = false

def enableBloomFilterAggFallbackRule(): Boolean = true
}

@@ -795,6 +795,7 @@ case class ColumnarOverrideRules(session: SparkSession)
(spark: SparkSession) => PlanOneRowRelation(spark),
(_: SparkSession) => FallbackEmptySchemaRelation(),
(_: SparkSession) => AddTransformHintRule(),
(_: SparkSession) => FallbackBloomFilterAggIfNeeded(),
(_: SparkSession) => TransformPreOverrides(isAdaptiveContext),
(spark: SparkSession) => RewriteTransformer(spark),
(_: SparkSession) => EnsureLocalSortRequirements

@@ -20,12 +20,13 @@ import io.glutenproject.GlutenConfig
import io.glutenproject.backendsapi.BackendsApiManager
import io.glutenproject.execution._
import io.glutenproject.extension.{GlutenPlan, ValidationResult}
import io.glutenproject.sql.shims.SparkShimLoader
import io.glutenproject.utils.PhysicalPlanSelector

import org.apache.spark.api.python.EvalPythonExecTransformer
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, SortOrder}
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, SortOrder}
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight}
import org.apache.spark.sql.catalyst.plans.FullOuter
import org.apache.spark.sql.catalyst.rules.Rule

@@ -269,6 +270,52 @@ case class FallbackEmptySchemaRelation() extends Rule[SparkPlan] {
}
}

/**
* Velox BloomFilter's implementation is different from Spark's. So if might_contain falls back, we
* need to fall back the related bloom filter agg as well.
*/
case class FallbackBloomFilterAggIfNeeded() extends Rule[SparkPlan] {
override def apply(plan: SparkPlan): SparkPlan =
if (
GlutenConfig.getConf.enableNativeBloomFilter &&
BackendsApiManager.getSettings.enableBloomFilterAggFallbackRule()
) {
plan.transformDown {
case p if TransformHints.isAlreadyTagged(p) && TransformHints.isNotTransformable(p) =>
handleBloomFilterFallback(p)
p
}
} else {
plan
}

object SubPlanFromBloomFilterMightContain {
def unapply(expr: Expression): Option[SparkPlan] =
SparkShimLoader.getSparkShims.extractSubPlanFromMightContain(expr)
}

private def handleBloomFilterFallback(plan: SparkPlan): Unit = {
def tagNotTransformableRecursive(p: SparkPlan): Unit = {
p match {
case agg: org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec
if SparkShimLoader.getSparkShims.hasBloomFilterAggregate(agg) =>
TransformHints.tagNotTransformable(agg, "related BloomFilterMightContain falls back")
tagNotTransformableRecursive(agg.child)
case a: org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec =>
tagNotTransformableRecursive(a.executedPlan)
case _ =>
p.children.map(tagNotTransformableRecursive)
}
}

plan.transformExpressions {
case expr @ SubPlanFromBloomFilterMightContain(p: SparkPlan) =>
tagNotTransformableRecursive(p)
expr
}
}
}

// This rule will try to convert a plan into plan transformer.
// The doValidate function will be called to check if the conversion is supported.
// If false is returned or any unsupported exception is thrown, a row guard will

@@ -51,8 +51,6 @@ class VeloxTestSettings extends BackendTestSettings {
.exclude("string split function with positive limit")
.exclude("string split function with negative limit")
enableSuite[GlutenBloomFilterAggregateQuerySuite]
// fallback might_contain, the input argument binary is not same with vanilla spark
.exclude("Test NULL inputs for might_contain")
enableSuite[GlutenDataSourceV2DataFrameSessionCatalogSuite]
enableSuite[GlutenDataSourceV2DataFrameSuite]
enableSuite[GlutenDataSourceV2FunctionSuite]

@@ -17,12 +17,16 @@
package org.apache.spark.sql

import io.glutenproject.GlutenConfig
import io.glutenproject.backendsapi.BackendsApiManager
import io.glutenproject.execution.HashAggregateExecBaseTransformer

import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.internal.SQLConf

class GlutenBloomFilterAggregateQuerySuite
extends BloomFilterAggregateQuerySuite
with GlutenSQLTestsTrait {
with GlutenSQLTestsTrait
with AdaptiveSparkPlanHelper {
import testImplicits._

test("Test bloom_filter_agg with big RUNTIME_BLOOM_FILTER_MAX_NUM_ITEMS") {

@@ -63,4 +67,50 @@ class GlutenBloomFilterAggregateQuerySuite
| from range(1, 1)), null)""".stripMargin),
Row(null))
}

test("Test bloom_filter_agg fallback") {
val table = "bloom_filter_test"
val numEstimatedItems = 5000000L
val numBits = GlutenConfig.getConf.veloxBloomFilterMaxNumBits
val sqlString = s"""
|SELECT col positive_membership_test
|FROM $table
|WHERE might_contain(
| (SELECT bloom_filter_agg(col,
| cast($numEstimatedItems as long),
| cast($numBits as long))
| FROM $table), col)
""".stripMargin
withTempView(table) {
(Seq(Long.MinValue, 0, Long.MaxValue) ++ (1L to 200000L))
.toDF("col")
.createOrReplaceTempView(table)
withSQLConf(
GlutenConfig.COLUMNAR_PROJECT_ENABLED.key -> "false"
) {
val df = spark.sql(sqlString)
df.collect
assert(
collectWithSubqueries(df.queryExecution.executedPlan) {
case h if h.isInstanceOf[HashAggregateExecBaseTransformer] => h
}.size == 2,
df.queryExecution.executedPlan
)
}
if (BackendsApiManager.getSettings.enableBloomFilterAggFallbackRule()) {
withSQLConf(
GlutenConfig.COLUMNAR_FILTER_ENABLED.key -> "false"
) {
val df = spark.sql(sqlString)
df.collect
assert(
collectWithSubqueries(df.queryExecution.executedPlan) {
case h if h.isInstanceOf[HashAggregateExecBaseTransformer] => h
}.size == 0,
df.queryExecution.executedPlan
)
}
}
}
}
}

@@ -51,8 +51,6 @@ class VeloxTestSettings extends BackendTestSettings {
.exclude("string split function with positive limit")
.exclude("string split function with negative limit")
enableSuite[GlutenBloomFilterAggregateQuerySuite]
// fallback might_contain, the input argument binary is not same with vanilla spark
.exclude("Test NULL inputs for might_contain")
enableSuite[GlutenDataSourceV2DataFrameSessionCatalogSuite]
enableSuite[GlutenDataSourceV2DataFrameSuite]
enableSuite[GlutenDataSourceV2FunctionSuite]

@@ -17,12 +17,16 @@
package org.apache.spark.sql

import io.glutenproject.GlutenConfig
import io.glutenproject.backendsapi.BackendsApiManager
import io.glutenproject.execution.HashAggregateExecBaseTransformer

import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.internal.SQLConf

class GlutenBloomFilterAggregateQuerySuite
extends BloomFilterAggregateQuerySuite
with GlutenSQLTestsTrait {
with GlutenSQLTestsTrait
with AdaptiveSparkPlanHelper {
import testImplicits._

test("Test bloom_filter_agg with big RUNTIME_BLOOM_FILTER_MAX_NUM_ITEMS") {

@@ -63,4 +67,50 @@ class GlutenBloomFilterAggregateQuerySuite
| from range(1, 1)), null)""".stripMargin),
Row(null))
}

test("Test bloom_filter_agg fallback") {
val table = "bloom_filter_test"
val numEstimatedItems = 5000000L
val numBits = GlutenConfig.getConf.veloxBloomFilterMaxNumBits
val sqlString = s"""
|SELECT col positive_membership_test
|FROM $table
|WHERE might_contain(
| (SELECT bloom_filter_agg(col,
| cast($numEstimatedItems as long),
| cast($numBits as long))
| FROM $table), col)
""".stripMargin
withTempView(table) {
(Seq(Long.MinValue, 0, Long.MaxValue) ++ (1L to 200000L))
.toDF("col")
.createOrReplaceTempView(table)
withSQLConf(
GlutenConfig.COLUMNAR_PROJECT_ENABLED.key -> "false"
) {
val df = spark.sql(sqlString)
df.collect
assert(
collectWithSubqueries(df.queryExecution.executedPlan) {
case h if h.isInstanceOf[HashAggregateExecBaseTransformer] => h
}.size == 2,
df.queryExecution.executedPlan
)
}
if (BackendsApiManager.getSettings.enableBloomFilterAggFallbackRule()) {
withSQLConf(
GlutenConfig.COLUMNAR_FILTER_ENABLED.key -> "false"
) {
val df = spark.sql(sqlString)
df.collect
assert(
collectWithSubqueries(df.queryExecution.executedPlan) {
case h if h.isInstanceOf[HashAggregateExecBaseTransformer] => h
}.size == 0,
df.queryExecution.executedPlan
)
}
}
}
}
}

@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, PlanExpression}
import org.apache.spark.sql.catalyst.plans.physical.Distribution
import org.apache.spark.sql.connector.catalog.Table
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.execution.FileSourceScanExec
import org.apache.spark.sql.execution.{FileSourceScanExec, SparkPlan}
import org.apache.spark.sql.execution.datasources.{FilePartition, FileScanRDD, PartitionDirectory, PartitionedFile, PartitioningAwareFileIndex}
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.execution.datasources.v2.text.TextScan

@@ -81,4 +81,9 @@ trait SparkShims {
start: Long,
length: Long,
@transient locations: Array[String] = Array.empty): PartitionedFile

def hasBloomFilterAggregate(
agg: org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec): Boolean

def extractSubPlanFromMightContain(expr: Expression): Option[SparkPlan]
}

@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.physical.{Distribution, HashClusteredDistribution}
import org.apache.spark.sql.connector.catalog.Table
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.execution.{FileSourceScanExec, PartitionedFileUtil}
import org.apache.spark.sql.execution.{FileSourceScanExec, PartitionedFileUtil, SparkPlan}
import org.apache.spark.sql.execution.datasources.{BucketingUtils, FilePartition, FileScanRDD, PartitionDirectory, PartitionedFile, PartitioningAwareFileIndex}
import org.apache.spark.sql.execution.datasources.FileFormatWriter.Empty2Null
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec

@@ -101,4 +101,9 @@ class Spark32Shims extends SparkShims {
length: Long,
@transient locations: Array[String] = Array.empty): PartitionedFile =
PartitionedFile(partitionValues, filePath, start, length, locations)

override def hasBloomFilterAggregate(
agg: org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec): Boolean = false

override def extractSubPlanFromMightContain(expr: Expression): Option[SparkPlan] = None
}

@@ -127,6 +127,24 @@ class Spark33Shims extends SparkShims {
@transient locations: Array[String] = Array.empty): PartitionedFile =
PartitionedFile(partitionValues, filePath, start, length, locations)

override def hasBloomFilterAggregate(
agg: org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec): Boolean = {
agg.aggregateExpressions.exists(
expr => expr.aggregateFunction.isInstanceOf[BloomFilterAggregate])
}

override def extractSubPlanFromMightContain(expr: Expression): Option[SparkPlan] = {
expr match {
case mc @ BloomFilterMightContain(sub: org.apache.spark.sql.execution.ScalarSubquery, _) =>
Some(sub.plan)
case mc @ BloomFilterMightContain(
g @ GetStructField(sub: org.apache.spark.sql.execution.ScalarSubquery, _, _),
_) =>
Some(sub.plan)
case _ => None
}
}

private def invalidBucketFile(path: String): Throwable = {
new SparkException(
errorClass = "INVALID_BUCKET_FILE",

@@ -130,6 +130,24 @@ class Spark34Shims extends SparkShims {
@transient locations: Array[String] = Array.empty): PartitionedFile =
PartitionedFile(partitionValues, SparkPath.fromPathString(filePath), start, length, locations)

override def hasBloomFilterAggregate(
agg: org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec): Boolean = {
agg.aggregateExpressions.exists(
expr => expr.aggregateFunction.isInstanceOf[BloomFilterAggregate])
}

override def extractSubPlanFromMightContain(expr: Expression): Option[SparkPlan] = {
expr match {
case mc @ BloomFilterMightContain(sub: org.apache.spark.sql.execution.ScalarSubquery, _) =>
Some(sub.plan)
case mc @ BloomFilterMightContain(
g @ GetStructField(sub: org.apache.spark.sql.execution.ScalarSubquery, _, _),
_) =>
Some(sub.plan)
case _ => None
}
}

private def invalidBucketFile(path: String): Throwable = {
new SparkException(
errorClass = "INVALID_BUCKET_FILE",