diff --git a/.github/workflows/velox_be.yml b/.github/workflows/velox_be.yml
index 84e2e51566f6..fc860946c342 100644
--- a/.github/workflows/velox_be.yml
+++ b/.github/workflows/velox_be.yml
@@ -254,6 +254,19 @@ jobs:
--local --preset=velox --benchmark-type=h --error-on-memleak --disable-aqe --off-heap-size=10g -s=1.0 --threads=16 --iterations=1 \
&& GLUTEN_IT_JVM_ARGS=-Xmx20G sbin/gluten-it.sh queries-compare \
--local --preset=velox --benchmark-type=ds --error-on-memleak --off-heap-size=30g -s=10.0 --threads=32 --iterations=1'
+ - name: Build for Spark 3.4.1
+ run: |
+ docker exec ubuntu2204-test-$GITHUB_RUN_ID bash -c '
+ cd /opt/gluten && \
+ mvn clean install -Pspark-3.4 -Pbackends-velox -Prss -DskipTests'
+ - name: TPC-H SF1.0 && TPC-DS SF10.0 Parquet local spark3.4
+ run: |
+ docker exec ubuntu2204-test-$GITHUB_RUN_ID bash -c 'cd /opt/gluten/tools/gluten-it && \
+ mvn clean install -Pspark-3.4 \
+ && GLUTEN_IT_JVM_ARGS=-Xmx5G sbin/gluten-it.sh queries-compare \
+ --local --preset=velox --benchmark-type=h --error-on-memleak --disable-aqe --off-heap-size=10g -s=1.0 --threads=16 --iterations=1 \
+ && GLUTEN_IT_JVM_ARGS=-Xmx20G sbin/gluten-it.sh queries-compare \
+ --local --preset=velox --benchmark-type=ds --error-on-memleak --off-heap-size=30g -s=10.0 --threads=32 --iterations=1'
- name: Exit docker container
if: ${{ always() }}
run: |
diff --git a/dev/buildbundle-veloxbe.sh b/dev/buildbundle-veloxbe.sh
index ca78ddaaee13..3bfd6994a556 100755
--- a/dev/buildbundle-veloxbe.sh
+++ b/dev/buildbundle-veloxbe.sh
@@ -6,3 +6,4 @@ source "$BASEDIR/builddeps-veloxbe.sh"
cd $GLUTEN_DIR
mvn clean package -Pbackends-velox -Prss -Pspark-3.2 -DskipTests
mvn clean package -Pbackends-velox -Prss -Pspark-3.3 -DskipTests
+mvn clean package -Pbackends-velox -Prss -Pspark-3.4 -DskipTests
diff --git a/gluten-core/src/main/scala/io/glutenproject/execution/BasicPhysicalOperatorTransformer.scala b/gluten-core/src/main/scala/io/glutenproject/execution/BasicPhysicalOperatorTransformer.scala
index 0deb6e5f9c21..8572e4b7f920 100644
--- a/gluten-core/src/main/scala/io/glutenproject/execution/BasicPhysicalOperatorTransformer.scala
+++ b/gluten-core/src/main/scala/io/glutenproject/execution/BasicPhysicalOperatorTransformer.scala
@@ -21,6 +21,7 @@ import io.glutenproject.expression.{ConverterUtils, ExpressionConverter, Express
import io.glutenproject.extension.{GlutenPlan, ValidationResult}
import io.glutenproject.extension.columnar.TransformHints
import io.glutenproject.metrics.MetricsUpdater
+import io.glutenproject.sql.shims.SparkShimLoader
import io.glutenproject.substrait.`type`.{TypeBuilder, TypeNode}
import io.glutenproject.substrait.SubstraitContext
import io.glutenproject.substrait.expression.ExpressionNode
@@ -503,7 +504,7 @@ object FilterHandler {
batchScan.output,
scan,
leftFilters ++ newPartitionFilters,
- batchScan.table)
+ table = SparkShimLoader.getSparkShims.getBatchScanExecTable(batchScan))
case _ =>
if (batchScan.runtimeFilters.isEmpty) {
throw new UnsupportedOperationException(
diff --git a/gluten-core/src/main/scala/io/glutenproject/execution/BatchScanExecTransformer.scala b/gluten-core/src/main/scala/io/glutenproject/execution/BatchScanExecTransformer.scala
index e91d1896d165..1d4551afa4b2 100644
--- a/gluten-core/src/main/scala/io/glutenproject/execution/BatchScanExecTransformer.scala
+++ b/gluten-core/src/main/scala/io/glutenproject/execution/BatchScanExecTransformer.scala
@@ -19,6 +19,7 @@ package io.glutenproject.execution
import io.glutenproject.backendsapi.BackendsApiManager
import io.glutenproject.extension.ValidationResult
import io.glutenproject.metrics.MetricsUpdater
+import io.glutenproject.sql.shims.SparkShimLoader
import io.glutenproject.substrait.rel.LocalFilesNode.ReadFileFormat
import org.apache.spark.rdd.RDD
@@ -43,22 +44,13 @@ class BatchScanExecTransformer(
output: Seq[AttributeReference],
@transient scan: Scan,
runtimeFilters: Seq[Expression],
- @transient table: Table,
keyGroupedPartitioning: Option[Seq[Expression]] = None,
ordering: Option[Seq[SortOrder]] = None,
+ @transient table: Table,
commonPartitionValues: Option[Seq[(InternalRow, Int)]] = None,
applyPartialClustering: Boolean = false,
replicatePartitions: Boolean = false)
- extends BatchScanExecShim(
- output,
- scan,
- runtimeFilters,
- keyGroupedPartitioning,
- ordering,
- table,
- commonPartitionValues,
- applyPartialClustering,
- replicatePartitions)
+ extends BatchScanExecShim(output, scan, runtimeFilters, table)
with BasicScanExecTransformer {
// Note: "metrics" is made transient to avoid sending driver-side metrics to tasks.
@@ -134,7 +126,7 @@ class BatchScanExecTransformer(
canonicalized.output,
canonicalized.scan,
canonicalized.runtimeFilters,
- canonicalized.table
+ table = SparkShimLoader.getSparkShims.getBatchScanExecTable(canonicalized)
)
}
}
diff --git a/gluten-core/src/main/scala/io/glutenproject/execution/FileSourceScanExecTransformer.scala b/gluten-core/src/main/scala/io/glutenproject/execution/FileSourceScanExecTransformer.scala
index 9cffabd0b2db..d66ca52f2431 100644
--- a/gluten-core/src/main/scala/io/glutenproject/execution/FileSourceScanExecTransformer.scala
+++ b/gluten-core/src/main/scala/io/glutenproject/execution/FileSourceScanExecTransformer.scala
@@ -19,24 +19,22 @@ package io.glutenproject.execution
import io.glutenproject.backendsapi.BackendsApiManager
import io.glutenproject.expression.ConverterUtils
import io.glutenproject.extension.ValidationResult
-import io.glutenproject.metrics.{GlutenTimeMetric, MetricsUpdater}
+import io.glutenproject.metrics.MetricsUpdater
import io.glutenproject.substrait.SubstraitContext
import io.glutenproject.substrait.rel.LocalFilesNode.ReadFileFormat
import io.glutenproject.substrait.rel.ReadRelNode
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeReference, BoundReference, DynamicPruningExpression, Expression, PlanExpression, Predicate}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, PlanExpression}
import org.apache.spark.sql.connector.read.InputPartition
-import org.apache.spark.sql.execution.{FileSourceScanExecShim, InSubqueryExec, ScalarSubquery, SQLExecution}
-import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, PartitionDirectory}
+import org.apache.spark.sql.execution.FileSourceScanExecShim
+import org.apache.spark.sql.execution.datasources.HadoopFsRelation
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.collection.BitSet
-import java.util.concurrent.TimeUnit.NANOSECONDS
-
import scala.collection.JavaConverters
class FileSourceScanExecTransformer(
@@ -64,10 +62,10 @@ class FileSourceScanExecTransformer(
// Note: "metrics" is made transient to avoid sending driver-side metrics to tasks.
@transient override lazy val metrics: Map[String, SQLMetric] =
BackendsApiManager.getMetricsApiInstance
- .genFileSourceScanTransformerMetrics(sparkContext) ++ staticMetrics
+ .genFileSourceScanTransformerMetrics(sparkContext) ++ staticMetricsAlias
/** SQL metrics generated only for scans using dynamic partition pruning. */
- override protected lazy val staticMetrics =
+ private lazy val staticMetricsAlias =
if (partitionFilters.exists(FileSourceScanExecTransformer.isDynamicPruningFilter)) {
Map(
"staticFilesNum" -> SQLMetrics.createMetric(sparkContext, "static number of files read"),
@@ -135,91 +133,6 @@ class FileSourceScanExecTransformer(
override def metricsUpdater(): MetricsUpdater =
BackendsApiManager.getMetricsApiInstance.genFileSourceScanTransformerMetricsUpdater(metrics)
- // The codes below are copied from FileSourceScanExec in Spark,
- // all of them are private.
-
- /**
- * Send the driver-side metrics. Before calling this function, selectedPartitions has been
- * initialized. See SPARK-26327 for more details.
- */
- override protected def sendDriverMetrics(): Unit = {
- val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
- SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, driverMetrics.values.toSeq)
- }
-
- protected def setFilesNumAndSizeMetric(
- partitions: Seq[PartitionDirectory],
- static: Boolean): Unit = {
- val filesNum = partitions.map(_.files.size.toLong).sum
- val filesSize = partitions.map(_.files.map(_.getLen).sum).sum
- if (!static || !partitionFilters.exists(FileSourceScanExecTransformer.isDynamicPruningFilter)) {
- driverMetrics("numFiles").set(filesNum)
- driverMetrics("filesSize").set(filesSize)
- } else {
- driverMetrics("staticFilesNum").set(filesNum)
- driverMetrics("staticFilesSize").set(filesSize)
- }
- if (relation.partitionSchema.nonEmpty) {
- driverMetrics("numPartitions").set(partitions.length)
- }
- }
-
- @transient override lazy val selectedPartitions: Array[PartitionDirectory] = {
- val optimizerMetadataTimeNs = relation.location.metadataOpsTimeNs.getOrElse(0L)
- GlutenTimeMetric.withNanoTime {
- val ret =
- relation.location.listFiles(
- partitionFilters.filterNot(FileSourceScanExecTransformer.isDynamicPruningFilter),
- dataFilters)
- setFilesNumAndSizeMetric(ret, static = true)
- ret
- }(t => driverMetrics("metadataTime").set(NANOSECONDS.toMillis(t + optimizerMetadataTimeNs)))
- }.toArray
-
- // We can only determine the actual partitions at runtime when a dynamic partition filter is
- // present. This is because such a filter relies on information that is only available at run
- // time (for instance the keys used in the other side of a join).
- @transient override lazy val dynamicallySelectedPartitions: Array[PartitionDirectory] = {
- val dynamicPartitionFilters =
- partitionFilters.filter(FileSourceScanExecTransformer.isDynamicPruningFilter)
- val selected = if (dynamicPartitionFilters.nonEmpty) {
- // When it includes some DynamicPruningExpression,
- // it needs to execute InSubqueryExec first,
- // because doTransform path can't execute 'doExecuteColumnar' which will
- // execute prepare subquery first.
- dynamicPartitionFilters.foreach {
- case DynamicPruningExpression(inSubquery: InSubqueryExec) =>
- executeInSubqueryForDynamicPruningExpression(inSubquery)
- case e: Expression =>
- e.foreach {
- case s: ScalarSubquery => s.updateResult()
- case _ =>
- }
- case _ =>
- }
- GlutenTimeMetric.withMillisTime {
- // call the file index for the files matching all filters except dynamic partition filters
- val predicate = dynamicPartitionFilters.reduce(And)
- val partitionColumns = relation.partitionSchema
- val boundPredicate = Predicate.create(
- predicate.transform {
- case a: AttributeReference =>
- val index = partitionColumns.indexWhere(a.name == _.name)
- BoundReference(index, partitionColumns(index).dataType, nullable = true)
- },
- Nil
- )
- val ret = selectedPartitions.filter(p => boundPredicate.eval(p.values))
- setFilesNumAndSizeMetric(ret, static = false)
- ret
- }(t => driverMetrics("pruningTime").set(t))
- } else {
- selectedPartitions
- }
- sendDriverMetrics()
- selected
- }
-
override val nodeNamePrefix: String = "NativeFile"
override val nodeName: String = {
diff --git a/gluten-core/src/main/scala/io/glutenproject/expression/ExpressionMappings.scala b/gluten-core/src/main/scala/io/glutenproject/expression/ExpressionMappings.scala
index da46c02698ea..a388a9e02538 100644
--- a/gluten-core/src/main/scala/io/glutenproject/expression/ExpressionMappings.scala
+++ b/gluten-core/src/main/scala/io/glutenproject/expression/ExpressionMappings.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero
import org.apache.spark.sql.execution.ScalarSubquery
+import org.apache.spark.sql.execution.datasources.FileFormatWriter.Empty2Null
object ExpressionMappings {
diff --git a/gluten-core/src/main/scala/io/glutenproject/extension/ColumnarOverrides.scala b/gluten-core/src/main/scala/io/glutenproject/extension/ColumnarOverrides.scala
index ead57eafd635..d3cd79f88017 100644
--- a/gluten-core/src/main/scala/io/glutenproject/extension/ColumnarOverrides.scala
+++ b/gluten-core/src/main/scala/io/glutenproject/extension/ColumnarOverrides.scala
@@ -22,6 +22,7 @@ import io.glutenproject.execution._
import io.glutenproject.expression.ExpressionConverter
import io.glutenproject.extension.columnar._
import io.glutenproject.metrics.GlutenTimeMetric
+import io.glutenproject.sql.shims.SparkShimLoader
import io.glutenproject.utils.{ColumnarShuffleUtil, LogLevelUtil, PhysicalPlanSelector}
import org.apache.spark.api.python.EvalPythonExecTransformer
@@ -578,8 +579,12 @@ case class TransformPreOverrides(isAdaptiveContext: Boolean)
case _ =>
ExpressionConverter.transformDynamicPruningExpr(plan.runtimeFilters, reuseSubquery)
}
- val transformer =
- new BatchScanExecTransformer(plan.output, plan.scan, newPartitionFilters, plan.table)
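+ // Obtain the table through the shim layer: BatchScanExec only carries a table field on Spark 3.4, and older shims return null.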
+ val transformer = new BatchScanExecTransformer(
+ plan.output,
+ plan.scan,
+ newPartitionFilters,
+ table = SparkShimLoader.getSparkShims.getBatchScanExecTable(plan))
+
val validationResult = transformer.doValidate()
if (validationResult.isValid) {
logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
diff --git a/gluten-core/src/main/scala/io/glutenproject/extension/columnar/TransformHintRule.scala b/gluten-core/src/main/scala/io/glutenproject/extension/columnar/TransformHintRule.scala
index f4fb99a6ccae..7bede35f7abc 100644
--- a/gluten-core/src/main/scala/io/glutenproject/extension/columnar/TransformHintRule.scala
+++ b/gluten-core/src/main/scala/io/glutenproject/extension/columnar/TransformHintRule.scala
@@ -20,6 +20,7 @@ import io.glutenproject.GlutenConfig
import io.glutenproject.backendsapi.BackendsApiManager
import io.glutenproject.execution._
import io.glutenproject.extension.{GlutenPlan, ValidationResult}
+import io.glutenproject.sql.shims.SparkShimLoader
import io.glutenproject.utils.PhysicalPlanSelector
import org.apache.spark.api.python.EvalPythonExecTransformer
@@ -333,12 +334,11 @@ case class AddTransformHintRule() extends Rule[SparkPlan] {
if (plan.runtimeFilters.nonEmpty) {
TransformHints.tagTransformable(plan)
} else {
- val transformer =
- new BatchScanExecTransformer(
- plan.output,
- plan.scan,
- plan.runtimeFilters,
- plan.table)
+ val transformer = new BatchScanExecTransformer(
+ plan.output,
+ plan.scan,
+ plan.runtimeFilters,
+ table = SparkShimLoader.getSparkShims.getBatchScanExecTable(plan))
TransformHints.tag(plan, transformer.doValidate().toTransformHint)
}
}
diff --git a/gluten-core/src/test/scala/org/apache/spark/softaffinity/SoftAffinitySuite.scala b/gluten-core/src/test/scala/org/apache/spark/softaffinity/SoftAffinitySuite.scala
index eb752c3af8ad..75302ab8ab2b 100644
--- a/gluten-core/src/test/scala/org/apache/spark/softaffinity/SoftAffinitySuite.scala
+++ b/gluten-core/src/test/scala/org/apache/spark/softaffinity/SoftAffinitySuite.scala
@@ -20,16 +20,16 @@ import io.glutenproject.GlutenConfig
import io.glutenproject.execution.{GlutenMergeTreePartition, GlutenPartition}
import io.glutenproject.softaffinity.SoftAffinityManager
import io.glutenproject.softaffinity.scheduler.SoftAffinityListener
+import io.glutenproject.sql.shims.SparkShimLoader
import io.glutenproject.substrait.plan.PlanBuilder
import org.apache.spark.SparkConf
-import org.apache.spark.paths.SparkPath
import org.apache.spark.scheduler.{SparkListenerExecutorAdded, SparkListenerExecutorRemoved}
import org.apache.spark.scheduler.cluster.ExecutorInfo
import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.PredicateHelper
-import org.apache.spark.sql.execution.datasources.{FilePartition, PartitionedFile}
+import org.apache.spark.sql.execution.datasources.FilePartition
import org.apache.spark.sql.test.SharedSparkSession
class SoftAffinitySuite extends QueryTest with SharedSparkSession with PredicateHelper {
@@ -43,18 +43,20 @@ class SoftAffinitySuite extends QueryTest with SharedSparkSession with Predicate
val partition = FilePartition(
0,
Seq(
- PartitionedFile(
+ SparkShimLoader.getSparkShims.generatePartitionedFile(
InternalRow.empty,
- SparkPath.fromPathString("fakePath0"),
+ "fakePath0",
0,
100,
- Array("host-1", "host-2")),
- PartitionedFile(
+ Array("host-1", "host-2")
+ ),
+ SparkShimLoader.getSparkShims.generatePartitionedFile(
InternalRow.empty,
- SparkPath.fromPathString("fakePath1"),
+ "fakePath1",
0,
200,
- Array("host-2", "host-3"))
+ Array("host-2", "host-3")
+ )
).toArray
)
@@ -70,18 +72,20 @@ class SoftAffinitySuite extends QueryTest with SharedSparkSession with Predicate
val partition = FilePartition(
0,
Seq(
- PartitionedFile(
+ SparkShimLoader.getSparkShims.generatePartitionedFile(
InternalRow.empty,
- SparkPath.fromPathString("fakePath0"),
+ "fakePath0",
0,
100,
- Array("host-1", "host-2")),
- PartitionedFile(
+ Array("host-1", "host-2")
+ ),
+ SparkShimLoader.getSparkShims.generatePartitionedFile(
InternalRow.empty,
- SparkPath.fromPathString("fakePath1"),
+ "fakePath1",
0,
200,
- Array("host-4", "host-5"))
+ Array("host-4", "host-5")
+ )
).toArray
)
@@ -98,18 +102,20 @@ class SoftAffinitySuite extends QueryTest with SharedSparkSession with Predicate
val partition = FilePartition(
0,
Seq(
- PartitionedFile(
+ SparkShimLoader.getSparkShims.generatePartitionedFile(
InternalRow.empty,
- SparkPath.fromPathString("fakePath0"),
+ "fakePath0",
0,
100,
- Array("host-1", "host-2")),
- PartitionedFile(
+ Array("host-1", "host-2")
+ ),
+ SparkShimLoader.getSparkShims.generatePartitionedFile(
InternalRow.empty,
- SparkPath.fromPathString("fakePath1"),
+ "fakePath1",
0,
200,
- Array("host-5", "host-6"))
+ Array("host-5", "host-6")
+ )
).toArray
)
@@ -138,18 +144,20 @@ class SoftAffinitySuite extends QueryTest with SharedSparkSession with Predicate
val partition = FilePartition(
0,
Seq(
- PartitionedFile(
+ SparkShimLoader.getSparkShims.generatePartitionedFile(
InternalRow.empty,
- SparkPath.fromPathString("fakePath0"),
+ "fakePath0",
0,
100,
- Array("host-1", "host-2")),
- PartitionedFile(
+ Array("host-1", "host-2")
+ ),
+ SparkShimLoader.getSparkShims.generatePartitionedFile(
InternalRow.empty,
- SparkPath.fromPathString("fakePath1"),
+ "fakePath1",
0,
200,
- Array("host-5", "host-6"))
+ Array("host-5", "host-6")
+ )
).toArray
)
diff --git a/gluten-core/src/main/scala/io/glutenproject/metrics/GlutenTimeMetric.scala b/shims/common/src/main/scala/io/glutenproject/metrics/GlutenTimeMetric.scala
similarity index 100%
rename from gluten-core/src/main/scala/io/glutenproject/metrics/GlutenTimeMetric.scala
rename to shims/common/src/main/scala/io/glutenproject/metrics/GlutenTimeMetric.scala
diff --git a/shims/common/src/main/scala/io/glutenproject/sql/shims/SparkShims.scala b/shims/common/src/main/scala/io/glutenproject/sql/shims/SparkShims.scala
index 5a00c6e52bba..de99e7efb44c 100644
--- a/shims/common/src/main/scala/io/glutenproject/sql/shims/SparkShims.scala
+++ b/shims/common/src/main/scala/io/glutenproject/sql/shims/SparkShims.scala
@@ -21,12 +21,15 @@ import io.glutenproject.expression.Sig
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.catalog.BucketSpec
-import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.{Expression, PlanExpression}
import org.apache.spark.sql.catalyst.plans.physical.Distribution
+import org.apache.spark.sql.connector.catalog.Table
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.execution.FileSourceScanExec
import org.apache.spark.sql.execution.datasources.{FilePartition, FileScanRDD, PartitionDirectory, PartitionedFile, PartitioningAwareFileIndex}
+import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.execution.datasources.v2.text.TextScan
+import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -67,4 +70,15 @@ trait SparkShims {
def filesGroupedToBuckets(
selectedPartitions: Array[PartitionDirectory]): Map[Int, Array[PartitionedFile]]
+
+ // Spark 3.4 adds a new table parameter to BatchScanExec.
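+ // The 3.2 and 3.3 shims return null here; a Spark 3.4 shim is expected to return batchScan.table.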
+ def getBatchScanExecTable(batchScan: BatchScanExec): Table
+
+ // The PartitionedFile API changed in Spark 3.4.
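+ // Pre-3.4 shims pass the string path straight through; a Spark 3.4 shim is expected to wrap it with SparkPath.fromPathString.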
+ def generatePartitionedFile(
+ partitionValues: InternalRow,
+ filePath: String,
+ start: Long,
+ length: Long,
+ @transient locations: Array[String] = Array.empty): PartitionedFile
}
diff --git a/gluten-core/src/main/scala/io/glutenproject/utils/Arm.scala b/shims/common/src/main/scala/io/glutenproject/utils/Arm.scala
similarity index 100%
rename from gluten-core/src/main/scala/io/glutenproject/utils/Arm.scala
rename to shims/common/src/main/scala/io/glutenproject/utils/Arm.scala
diff --git a/shims/spark32/src/main/scala/io/glutenproject/sql/shims/spark32/Spark32Shims.scala b/shims/spark32/src/main/scala/io/glutenproject/sql/shims/spark32/Spark32Shims.scala
index db3cdf56bbc6..1cacc2f75be9 100644
--- a/shims/spark32/src/main/scala/io/glutenproject/sql/shims/spark32/Spark32Shims.scala
+++ b/shims/spark32/src/main/scala/io/glutenproject/sql/shims/spark32/Spark32Shims.scala
@@ -24,9 +24,11 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.physical.{Distribution, HashClusteredDistribution}
+import org.apache.spark.sql.connector.catalog.Table
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.execution.{FileSourceScanExec, PartitionedFileUtil}
import org.apache.spark.sql.execution.datasources.{BucketingUtils, FilePartition, FileScanRDD, PartitionDirectory, PartitionedFile, PartitioningAwareFileIndex}
+import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.execution.datasources.v2.text.TextScan
import org.apache.spark.sql.execution.datasources.v2.utils.CatalogUtil
import org.apache.spark.sql.types.StructType
@@ -88,4 +90,14 @@ class Spark32Shims extends SparkShims {
.getOrElse(throw new IllegalStateException(s"Invalid bucket file ${f.filePath}"))
}
}
+
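+ // BatchScanExec has no table member on Spark 3.2, so there is nothing to return.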
+ override def getBatchScanExecTable(batchScan: BatchScanExec): Table = null
+
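+ // Spark 3.2's PartitionedFile still takes the file path as a plain String, so the arguments map through unchanged.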
+ override def generatePartitionedFile(
+ partitionValues: InternalRow,
+ filePath: String,
+ start: Long,
+ length: Long,
+ @transient locations: Array[String] = Array.empty): PartitionedFile =
+ PartitionedFile(partitionValues, filePath, start, length, locations)
}
diff --git a/gluten-core/src/main/scala/org/apache/spark/sql/utils/DataSourceStrategyUtil.scala b/shims/spark32/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Offset.scala
similarity index 52%
rename from gluten-core/src/main/scala/org/apache/spark/sql/utils/DataSourceStrategyUtil.scala
rename to shims/spark32/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Offset.scala
index 972a8bdaa539..bc7cacf7995e 100644
--- a/gluten-core/src/main/scala/org/apache/spark/sql/utils/DataSourceStrategyUtil.scala
+++ b/shims/spark32/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Offset.scala
@@ -14,20 +14,23 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.spark.sql.utils
+package org.apache.spark.sql.catalyst.plans.logical
-import org.apache.spark.sql.catalyst.expressions.Expression
-import org.apache.spark.sql.connector.expressions.filter.Predicate
-import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Strategy
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, IntegerLiteral}
-object DataSourceStrategyUtil {
-
- /**
- * Translates a runtime filter into a data source filter.
- *
- * Runtime filters usually contain a subquery that must be evaluated before the translation. If
- * the underlying subquery hasn't completed yet, this method will throw an exception.
- */
- def translateRuntimeFilter(expr: Expression): Option[Predicate] =
- DataSourceV2Strategy.translateRuntimeFilterV2(expr)
+/**
+ * A logical offset, which may remove a specified number of rows from the beginning of the output
+ * of the child logical plan.
+ */
+case class Offset(offsetExpr: Expression, child: LogicalPlan) extends OrderPreservingUnaryNode {
+ override def output: Seq[Attribute] = child.output
+ override def maxRows: Option[Long] = {
+ import scala.math.max
+ offsetExpr match {
+ case IntegerLiteral(offset) => child.maxRows.map(x => max(x - offset, 0))
+ case _ => None
+ }
+ }
+ override protected def withNewChildInternal(newChild: LogicalPlan): Offset =
+ copy(child = newChild)
}
diff --git a/shims/spark32/src/main/scala/org/apache/spark/sql/execution/FileSourceScanExecShim.scala b/shims/spark32/src/main/scala/org/apache/spark/sql/execution/FileSourceScanExecShim.scala
index c721a4beda01..9e32b35b8f3b 100644
--- a/shims/spark32/src/main/scala/org/apache/spark/sql/execution/FileSourceScanExecShim.scala
+++ b/shims/spark32/src/main/scala/org/apache/spark/sql/execution/FileSourceScanExecShim.scala
@@ -16,13 +16,19 @@
*/
package org.apache.spark.sql.execution
+import io.glutenproject.metrics.GlutenTimeMetric
+
import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
-import org.apache.spark.sql.execution.datasources.HadoopFsRelation
-import org.apache.spark.sql.execution.metric.SQLMetric
+import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeReference, BoundReference, DynamicPruningExpression, Expression, PlanExpression, Predicate}
+import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, PartitionDirectory}
+import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.collection.BitSet
+import java.util.concurrent.TimeUnit.NANOSECONDS
+
+import scala.collection.mutable
+
class FileSourceScanExecShim(
@transient relation: HadoopFsRelation,
output: Seq[Attribute],
@@ -60,4 +66,95 @@ class FileSourceScanExecShim(
def hasMetadataColumns: Boolean = false
def hasFieldIds: Boolean = false
+
+ // The code below is copied from FileSourceScanExec in Spark,
+ // where all of it is private.
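+ // Collects driver-side metric values; sendDriverMetrics() below copies them into 'metrics' and posts the updates.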
+ protected lazy val driverMetrics: mutable.HashMap[String, Long] = mutable.HashMap.empty
+
+ /**
+ * Send the driver-side metrics. Before calling this function, selectedPartitions has been
+ * initialized. See SPARK-26327 for more details.
+ */
+ protected def sendDriverMetrics(): Unit = {
+ driverMetrics.foreach(e => metrics(e._1).add(e._2))
+ val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
+ SQLMetrics.postDriverMetricUpdates(
+ sparkContext,
+ executionId,
+ metrics.filter(e => driverMetrics.contains(e._1)).values.toSeq)
+ }
+
+ protected def setFilesNumAndSizeMetric(
+ partitions: Seq[PartitionDirectory],
+ static: Boolean): Unit = {
+ val filesNum = partitions.map(_.files.size.toLong).sum
+ val filesSize = partitions.map(_.files.map(_.getLen).sum).sum
+ if (!static || !partitionFilters.exists(isDynamicPruningFilter)) {
+ driverMetrics("numFiles") = filesNum
+ driverMetrics("filesSize") = filesSize
+ } else {
+ driverMetrics("staticFilesNum") = filesNum
+ driverMetrics("staticFilesSize") = filesSize
+ }
+ if (relation.partitionSchema.nonEmpty) {
+ driverMetrics("numPartitions") = partitions.length
+ }
+ }
+
+ @transient override lazy val selectedPartitions: Array[PartitionDirectory] = {
+ val optimizerMetadataTimeNs = relation.location.metadataOpsTimeNs.getOrElse(0L)
+ GlutenTimeMetric.withNanoTime {
+ val ret =
+ relation.location.listFiles(partitionFilters.filterNot(isDynamicPruningFilter), dataFilters)
+ setFilesNumAndSizeMetric(ret, static = true)
+ ret
+ }(t => driverMetrics("metadataTime") = NANOSECONDS.toMillis(t + optimizerMetadataTimeNs))
+ }.toArray
+
+ private def isDynamicPruningFilter(e: Expression): Boolean =
+ e.find(_.isInstanceOf[PlanExpression[_]]).isDefined
+
+ // We can only determine the actual partitions at runtime when a dynamic partition filter is
+ // present. This is because such a filter relies on information that is only available at run
+ // time (for instance the keys used in the other side of a join).
+ @transient lazy val dynamicallySelectedPartitions: Array[PartitionDirectory] = {
+ val dynamicPartitionFilters =
+ partitionFilters.filter(isDynamicPruningFilter)
+ val selected = if (dynamicPartitionFilters.nonEmpty) {
+ // When the filters include a DynamicPruningExpression, the InSubqueryExec must be executed
+ // first, because the doTransform path never calls 'doExecuteColumnar', which is what would
+ // normally prepare the subquery.
+ dynamicPartitionFilters.foreach {
+ case DynamicPruningExpression(inSubquery: InSubqueryExec) =>
+ if (inSubquery.values().isEmpty) inSubquery.updateResult()
+ case e: Expression =>
+ e.foreach {
+ case s: ScalarSubquery => s.updateResult()
+ case _ =>
+ }
+ case _ =>
+ }
+ GlutenTimeMetric.withMillisTime {
+ // call the file index for the files matching all filters except dynamic partition filters
+ val predicate = dynamicPartitionFilters.reduce(And)
+ val partitionColumns = relation.partitionSchema
+ val boundPredicate = Predicate.create(
+ predicate.transform {
+ case a: AttributeReference =>
+ val index = partitionColumns.indexWhere(a.name == _.name)
+ BoundReference(index, partitionColumns(index).dataType, nullable = true)
+ },
+ Nil
+ )
+ val ret = selectedPartitions.filter(p => boundPredicate.eval(p.values))
+ setFilesNumAndSizeMetric(ret, static = false)
+ ret
+ }(t => driverMetrics("pruningTime") = t)
+ } else {
+ selectedPartitions
+ }
+ sendDriverMetrics()
+ selected
+ }
}
diff --git a/shims/spark32/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExecShim.scala b/shims/spark32/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExecShim.scala
index faea3aade19d..b867c71cefda 100644
--- a/shims/spark32/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExecShim.scala
+++ b/shims/spark32/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExecShim.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.datasources.v2
import org.apache.spark.SparkException
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.connector.catalog.Table
import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
import org.apache.spark.sql.connector.read.{InputPartition, Scan, SupportsRuntimeFiltering}
import org.apache.spark.sql.execution.datasources.DataSourceStrategy
@@ -28,7 +29,8 @@ import org.apache.spark.sql.vectorized.ColumnarBatch
class BatchScanExecShim(
output: Seq[AttributeReference],
@transient scan: Scan,
- runtimeFilters: Seq[Expression])
+ runtimeFilters: Seq[Expression],
+ @transient table: Table)
extends BatchScanExec(output, scan, runtimeFilters) {
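+ // The extra 'table' constructor parameter only mirrors the Spark 3.4 signature; it is unused on Spark 3.2.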
// Note: "metrics" is made transient to avoid sending driver-side metrics to tasks.
@@ -82,4 +84,8 @@ class BatchScanExecShim(
}
@transient lazy val pushedAggregate: Option[Aggregation] = None
+
+ final override protected def otherCopyArgs: Seq[AnyRef] = {
+ output :: scan :: runtimeFilters :: Nil
+ }
}
diff --git a/shims/spark32/src/main/scala/org/apache/spark/sql/execution/datasources/v2/velox/DwrfScan.scala b/shims/spark32/src/main/scala/org/apache/spark/sql/execution/datasources/v2/velox/DwrfScan.scala
deleted file mode 100644
index 621486d743e5..000000000000
--- a/shims/spark32/src/main/scala/org/apache/spark/sql/execution/datasources/v2/velox/DwrfScan.scala
+++ /dev/null
@@ -1,46 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.sql.execution.datasources.v2.velox
-
-import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.catalyst.expressions.Expression
-import org.apache.spark.sql.connector.read.PartitionReaderFactory
-import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
-import org.apache.spark.sql.execution.datasources.v2.FileScan
-import org.apache.spark.sql.sources.Filter
-import org.apache.spark.sql.types.StructType
-import org.apache.spark.sql.util.CaseInsensitiveStringMap
-
-case class DwrfScan(
- sparkSession: SparkSession,
- fileIndex: PartitioningAwareFileIndex,
- readDataSchema: StructType,
- readPartitionSchema: StructType,
- pushedFilters: Array[Filter],
- options: CaseInsensitiveStringMap,
- partitionFilters: Seq[Expression] = Seq.empty,
- dataFilters: Seq[Expression] = Seq.empty)
- extends FileScan {
- override def createReaderFactory(): PartitionReaderFactory = {
- null
- }
-
- override def withFilters(
- partitionFilters: Seq[Expression],
- dataFilters: Seq[Expression]): FileScan =
- this.copy(partitionFilters = partitionFilters, dataFilters = dataFilters)
-}
diff --git a/shims/spark32/src/main/scala/org/apache/spark/sql/execution/datasources/v2/velox/DwrfScanBuilder.scala b/shims/spark32/src/main/scala/org/apache/spark/sql/execution/datasources/v2/velox/DwrfScanBuilder.scala
deleted file mode 100644
index dda9aeca7a47..000000000000
--- a/shims/spark32/src/main/scala/org/apache/spark/sql/execution/datasources/v2/velox/DwrfScanBuilder.scala
+++ /dev/null
@@ -1,57 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.sql.execution.datasources.v2.velox
-
-import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownFilters}
-import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
-import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder
-import org.apache.spark.sql.sources.Filter
-import org.apache.spark.sql.types.StructType
-import org.apache.spark.sql.util.CaseInsensitiveStringMap
-
-case class DwrfScanBuilder(
- sparkSession: SparkSession,
- fileIndex: PartitioningAwareFileIndex,
- schema: StructType,
- dataSchema: StructType,
- options: CaseInsensitiveStringMap)
- extends FileScanBuilder(sparkSession, fileIndex, dataSchema)
- with SupportsPushDownFilters {
-
- private lazy val pushedArrowFilters: Array[Filter] = {
- filters // todo filter validation & pushdown
- }
- private var filters: Array[Filter] = Array.empty
-
- override def pushFilters(filters: Array[Filter]): Array[Filter] = {
- this.filters = filters
- this.filters
- }
-
- override def build(): Scan = {
- DwrfScan(
- sparkSession,
- fileIndex,
- readDataSchema(),
- readPartitionSchema(),
- pushedFilters,
- options)
- }
-
- override def pushedFilters: Array[Filter] = pushedArrowFilters
-}
diff --git a/shims/spark32/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/shims/spark32/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
new file mode 100644
index 000000000000..0dbdac871a68
--- /dev/null
+++ b/shims/spark32/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
@@ -0,0 +1,356 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.stat
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.{Column, DataFrame, Dataset, Row}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Cast, Expression, GenericInternalRow, GetArrayItem, Literal, TryCast}
+import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
+import org.apache.spark.sql.catalyst.util.{GenericArrayData, QuantileSummaries}
+import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+
+import java.util.Locale
+
+/**
+ * This file is copied from Spark.
+ *
+ * The df.describe() and df.summary() issues are fixed by
+ * https://github.com/apache/spark/pull/40914. We backported that fix into Gluten to address the
+ * describe and summary issues. This file can be removed after upgrading the Spark version to 3.4
+ * or higher.
+ */
+object StatFunctions extends Logging {
+
+ /**
+ * Calculates the approximate quantiles of multiple numerical columns of a DataFrame in one pass.
+ *
+ * The result of this algorithm has the following deterministic bound: If the DataFrame has N
+ * elements and if we request the quantile at probability `p` up to error `err`, then the
+ * algorithm will return a sample `x` from the DataFrame so that the *exact* rank of `x` is close
+ * to (p * N). More precisely,
+ *
+ * floor((p - err) * N) <= rank(x) <= ceil((p + err) * N).
+ *
+ * This method implements a variation of the Greenwald-Khanna algorithm (with some speed
+ * optimizations). The algorithm was first present in Space-efficient Online Computation of Quantile
+ * Summaries by Greenwald and Khanna.
+ *
+ * @param df
+ * the dataframe
+ * @param cols
+ * numerical columns of the dataframe
+ * @param probabilities
+ * a list of quantile probabilities Each number must belong to [0, 1]. For example 0 is the
+ * minimum, 0.5 is the median, 1 is the maximum.
+ * @param relativeError
+ * The relative target precision to achieve (greater than or equal 0). If set to zero, the exact
+ * quantiles are computed, which could be very expensive. Note that values greater than 1 are
+ * accepted but give the same result as 1.
+ *
+ * @return
+ * for each column, returns the requested approximations
+ *
+ * @note
+ * null and NaN values will be ignored in numerical columns before calculation. For a column
+ * only containing null or NaN values, an empty array is returned.
+ */
+ def multipleApproxQuantiles(
+ df: DataFrame,
+ cols: Seq[String],
+ probabilities: Seq[Double],
+ relativeError: Double): Seq[Seq[Double]] = {
+ require(relativeError >= 0, s"Relative Error must be non-negative but got $relativeError")
+ val columns: Seq[Column] = cols.map {
+ colName =>
+ val field = df.resolve(colName)
+ require(
+ field.dataType.isInstanceOf[NumericType],
+ s"Quantile calculation for column $colName with data type ${field.dataType}" +
+ " is not supported.")
+ Column(Cast(Column(colName).expr, DoubleType))
+ }
+ val emptySummaries = Array.fill(cols.size)(
+ new QuantileSummaries(QuantileSummaries.defaultCompressThreshold, relativeError))
+
+ // Note that it works more or less by accident as `rdd.aggregate` is not a pure function:
+ // this function returns the same array as given in the input (because `aggregate` reuses
+ // the same argument).
+ def apply(summaries: Array[QuantileSummaries], row: Row): Array[QuantileSummaries] = {
+ var i = 0
+ while (i < summaries.length) {
+ if (!row.isNullAt(i)) {
+ val v = row.getDouble(i)
+ if (!v.isNaN) summaries(i) = summaries(i).insert(v)
+ }
+ i += 1
+ }
+ summaries
+ }
+
+ def merge(
+ sum1: Array[QuantileSummaries],
+ sum2: Array[QuantileSummaries]): Array[QuantileSummaries] = {
+ sum1.zip(sum2).map { case (s1, s2) => s1.compress().merge(s2.compress()) }
+ }
+ val summaries = df.select(columns: _*).rdd.treeAggregate(emptySummaries)(apply, merge)
+
+ summaries.map {
+ summary =>
+ summary.query(probabilities) match {
+ case Some(q) => q
+ case None => Seq()
+ }
+ }
+ }
+
+ /** Calculate the Pearson Correlation Coefficient for the given columns */
+ def pearsonCorrelation(df: DataFrame, cols: Seq[String]): Double = {
+ val counts = collectStatisticalData(df, cols, "correlation")
+ counts.Ck / math.sqrt(counts.MkX * counts.MkY)
+ }
+
+ /** Helper class to simplify tracking and merging counts. */
+ private class CovarianceCounter extends Serializable {
+ var xAvg = 0.0 // the mean of all examples seen so far in col1
+ var yAvg = 0.0 // the mean of all examples seen so far in col2
+ var Ck = 0.0 // the co-moment after k examples
+ var MkX = 0.0 // sum of squares of differences from the (current) mean for col1
+ var MkY = 0.0 // sum of squares of differences from the (current) mean for col2
+ var count = 0L // count of observed examples
+ // add an example to the calculation
+ def add(x: Double, y: Double): this.type = {
+ val deltaX = x - xAvg
+ val deltaY = y - yAvg
+ count += 1
+ xAvg += deltaX / count
+ yAvg += deltaY / count
+ Ck += deltaX * (y - yAvg)
+ MkX += deltaX * (x - xAvg)
+ MkY += deltaY * (y - yAvg)
+ this
+ }
+ // merge counters from other partitions. Formula can be found at:
+ // http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
+ def merge(other: CovarianceCounter): this.type = {
+ if (other.count > 0) {
+ val totalCount = count + other.count
+ val deltaX = xAvg - other.xAvg
+ val deltaY = yAvg - other.yAvg
+ Ck += other.Ck + deltaX * deltaY * count / totalCount * other.count
+ xAvg = (xAvg * count + other.xAvg * other.count) / totalCount
+ yAvg = (yAvg * count + other.yAvg * other.count) / totalCount
+ MkX += other.MkX + deltaX * deltaX * count / totalCount * other.count
+ MkY += other.MkY + deltaY * deltaY * count / totalCount * other.count
+ count = totalCount
+ }
+ this
+ }
+ // return the sample covariance for the observed examples
+ def cov: Double = Ck / (count - 1)
+ }
+
+ private def collectStatisticalData(
+ df: DataFrame,
+ cols: Seq[String],
+ functionName: String): CovarianceCounter = {
+ require(
+ cols.length == 2,
+ s"Currently $functionName calculation is supported " +
+ "between two columns.")
+ cols.map(name => (name, df.resolve(name))).foreach {
+ case (name, data) =>
+ require(
+ data.dataType.isInstanceOf[NumericType],
+ s"Currently $functionName calculation " +
+ s"for columns with dataType ${data.dataType.catalogString} not supported."
+ )
+ }
+ val columns = cols.map(n => Column(Cast(Column(n).expr, DoubleType)))
+ df.select(columns: _*)
+ .queryExecution
+ .toRdd
+ .treeAggregate(new CovarianceCounter)(
+ seqOp = (counter, row) => {
+ counter.add(row.getDouble(0), row.getDouble(1))
+ },
+ combOp = (baseCounter, other) => {
+ baseCounter.merge(other)
+ })
+ }
+
+ /**
+ * Calculate the covariance of two numerical columns of a DataFrame.
+ * @param df
+ * The DataFrame
+ * @param cols
+ * the column names
+ * @return
+ * the covariance of the two columns.
+ */
+ def calculateCov(df: DataFrame, cols: Seq[String]): Double = {
+ val counts = collectStatisticalData(df, cols, "covariance")
+ counts.cov
+ }
+
+ /** Generate a table of frequencies for the elements of two columns. */
+ def crossTabulate(df: DataFrame, col1: String, col2: String): DataFrame = {
+ val tableName = s"${col1}_$col2"
+ val counts = df.groupBy(col1, col2).agg(count("*")).take(1e6.toInt)
+ if (counts.length == 1e6.toInt) {
+ logWarning(
+ "The maximum limit of 1e6 pairs have been collected, which may not be all of " +
+ "the pairs. Please try reducing the amount of distinct items in your columns.")
+ }
+ def cleanElement(element: Any): String = {
+ if (element == null) "null" else element.toString
+ }
+ // get the distinct sorted values of column 2, so that we can make them the column names
+ val distinctCol2: Map[Any, Int] =
+ counts.map(e => cleanElement(e.get(1))).distinct.sorted.zipWithIndex.toMap
+ val columnSize = distinctCol2.size
+ require(
+ columnSize < 1e4,
+ s"The number of distinct values for $col2, can't " +
+ s"exceed 1e4. Currently $columnSize")
+ val table = counts
+ .groupBy(_.get(0))
+ .map {
+ case (col1Item, rows) =>
+ val countsRow = new GenericInternalRow(columnSize + 1)
+ rows.foreach {
+ (row: Row) =>
+ // row.get(0) is column 1
+ // row.get(1) is column 2
+ // row.get(2) is the frequency
+ val columnIndex = distinctCol2(cleanElement(row.get(1)))
+ countsRow.setLong(columnIndex + 1, row.getLong(2))
+ }
+ // the value of col1 is the first value, the rest are the counts
+ countsRow.update(0, UTF8String.fromString(cleanElement(col1Item)))
+ countsRow
+ }
+ .toSeq
+ // Back ticks can't exist in DataFrame column names, therefore drop them. To be able to accept
+ // special keywords and `.`, wrap the column names in ``.
+ def cleanColumnName(name: String): String = {
+ name.replace("`", "")
+ }
+ // In the map, the column names (._1) are not ordered by the index (._2). This was the bug in
+ // SPARK-8681. We need to explicitly sort by the column index and assign the column names.
+ val headerNames = distinctCol2.toSeq.sortBy(_._2).map {
+ r => StructField(cleanColumnName(r._1.toString), LongType)
+ }
+ val schema = StructType(StructField(tableName, StringType) +: headerNames)
+
+ Dataset.ofRows(df.sparkSession, LocalRelation(schema.toAttributes, table)).na.fill(0.0)
+ }
+
+ /** Calculate selected summary statistics for a dataset */
+ def summary(ds: Dataset[_], statistics: Seq[String]): DataFrame = {
+
+ val defaultStatistics = Seq("count", "mean", "stddev", "min", "25%", "50%", "75%", "max")
+ val selectedStatistics = if (statistics.nonEmpty) statistics else defaultStatistics
+
+ val percentiles = selectedStatistics.filter(a => a.endsWith("%")).map {
+ p =>
+ try {
+ p.stripSuffix("%").toDouble / 100.0
+ } catch {
+ case e: NumberFormatException =>
+ throw QueryExecutionErrors.cannotParseStatisticAsPercentileError(p, e)
+ }
+ }
+ require(percentiles.forall(p => p >= 0 && p <= 1), "Percentiles must be in the range [0, 1]")
+
+ def castAsDoubleIfNecessary(e: Expression): Expression = if (e.dataType == StringType) {
+ TryCast(e, DoubleType)
+ } else {
+ e
+ }
+ var percentileIndex = 0
+ val statisticFns = selectedStatistics.map {
+ stats =>
+ if (stats.endsWith("%")) {
+ val index = percentileIndex
+ percentileIndex += 1
+ (child: Expression) =>
+ GetArrayItem(
+ new ApproximatePercentile(
+ castAsDoubleIfNecessary(child),
+ Literal(new GenericArrayData(percentiles), ArrayType(DoubleType, false)))
+ .toAggregateExpression(),
+ Literal(index)
+ )
+ } else {
+ stats.toLowerCase(Locale.ROOT) match {
+ case "count" => (child: Expression) => Count(child).toAggregateExpression()
+ case "count_distinct" =>
+ (child: Expression) => Count(child).toAggregateExpression(isDistinct = true)
+ case "approx_count_distinct" =>
+ (child: Expression) => HyperLogLogPlusPlus(child).toAggregateExpression()
+ case "mean" =>
+ (child: Expression) => Average(castAsDoubleIfNecessary(child)).toAggregateExpression()
+ case "stddev" =>
+ (child: Expression) =>
+ StddevSamp(castAsDoubleIfNecessary(child)).toAggregateExpression()
+ case "min" => (child: Expression) => Min(child).toAggregateExpression()
+ case "max" => (child: Expression) => Max(child).toAggregateExpression()
+ case _ => throw QueryExecutionErrors.statisticNotRecognizedError(stats)
+ }
+ }
+ }
+
+ val selectedCols = ds.logicalPlan.output
+ .filter(a => a.dataType.isInstanceOf[NumericType] || a.dataType.isInstanceOf[StringType])
+
+ val aggExprs = statisticFns.flatMap {
+ func => selectedCols.map(c => Column(Cast(func(c), StringType)).as(c.name))
+ }
+
+ // If there is no selected columns, we don't need to run this aggregate, so make it a lazy val.
+ lazy val aggResult = ds.select(aggExprs: _*).queryExecution.toRdd.map(_.copy()).collect().head
+
+ // We will have one row for each selected statistic in the result.
+ val result = Array.fill[InternalRow](selectedStatistics.length) {
+ // each row has the statistic name, and statistic values of each selected column.
+ new GenericInternalRow(selectedCols.length + 1)
+ }
+
+ var rowIndex = 0
+ while (rowIndex < result.length) {
+ val statsName = selectedStatistics(rowIndex)
+ result(rowIndex).update(0, UTF8String.fromString(statsName))
+ for (colIndex <- selectedCols.indices) {
+ val statsValue = aggResult.getUTF8String(rowIndex * selectedCols.length + colIndex)
+ result(rowIndex).update(colIndex + 1, statsValue)
+ }
+ rowIndex += 1
+ }
+
+ // All columns are string type
+ val output = AttributeReference("summary", StringType)() +:
+ selectedCols.map(c => AttributeReference(c.name, StringType)())
+
+ Dataset.ofRows(ds.sparkSession, LocalRelation(output, result))
+ }
+}
diff --git a/shims/spark33/src/main/scala/io/glutenproject/sql/shims/spark33/Spark33Shims.scala b/shims/spark33/src/main/scala/io/glutenproject/sql/shims/spark33/Spark33Shims.scala
index 06325480f244..b593b6da7066 100644
--- a/shims/spark33/src/main/scala/io/glutenproject/sql/shims/spark33/Spark33Shims.scala
+++ b/shims/spark33/src/main/scala/io/glutenproject/sql/shims/spark33/Spark33Shims.scala
@@ -27,9 +27,11 @@ import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate
import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution}
+import org.apache.spark.sql.connector.catalog.Table
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.execution.{FileSourceScanExec, PartitionedFileUtil, SparkPlan}
import org.apache.spark.sql.execution.datasources.{BucketingUtils, FilePartition, FileScanRDD, PartitionDirectory, PartitionedFile, PartitioningAwareFileIndex}
+import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.execution.datasources.v2.text.TextScan
import org.apache.spark.sql.execution.datasources.v2.utils.CatalogUtil
import org.apache.spark.sql.types.StructType
@@ -113,6 +115,16 @@ class Spark33Shims extends SparkShims {
}
}
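+ // As on Spark 3.2, BatchScanExec has no table member in 3.3 and PartitionedFile still takes a plain String path.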
+ override def getBatchScanExecTable(batchScan: BatchScanExec): Table = null
+
+ override def generatePartitionedFile(
+ partitionValues: InternalRow,
+ filePath: String,
+ start: Long,
+ length: Long,
+ @transient locations: Array[String] = Array.empty): PartitionedFile =
+ PartitionedFile(partitionValues, filePath, start, length, locations)
+
private def invalidBucketFile(path: String): Throwable = {
new SparkException(
errorClass = "INVALID_BUCKET_FILE",
diff --git a/shims/spark33/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Offset.scala b/shims/spark33/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Offset.scala
new file mode 100644
index 000000000000..bc7cacf7995e
--- /dev/null
+++ b/shims/spark33/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Offset.scala
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.plans.logical
+
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, IntegerLiteral}
+
+/**
+ * A logical offset, which may remove a specified number of rows from the beginning of the output
+ * of the child logical plan.
+ */
+case class Offset(offsetExpr: Expression, child: LogicalPlan) extends OrderPreservingUnaryNode {
+ override def output: Seq[Attribute] = child.output
+ override def maxRows: Option[Long] = {
+ import scala.math.max
+ offsetExpr match {
+ case IntegerLiteral(offset) => child.maxRows.map(x => max(x - offset, 0))
+ case _ => None
+ }
+ }
+ override protected def withNewChildInternal(newChild: LogicalPlan): Offset =
+ copy(child = newChild)
+}
diff --git a/shims/spark33/src/main/scala/org/apache/spark/sql/execution/FileSourceScanExecShim.scala b/shims/spark33/src/main/scala/org/apache/spark/sql/execution/FileSourceScanExecShim.scala
index 88bd259c6e10..cfbf91bc2188 100644
--- a/shims/spark33/src/main/scala/org/apache/spark/sql/execution/FileSourceScanExecShim.scala
+++ b/shims/spark33/src/main/scala/org/apache/spark/sql/execution/FileSourceScanExecShim.scala
@@ -16,14 +16,20 @@
*/
package org.apache.spark.sql.execution
+import io.glutenproject.metrics.GlutenTimeMetric
+
import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
-import org.apache.spark.sql.execution.datasources.HadoopFsRelation
+import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeReference, BoundReference, DynamicPruningExpression, Expression, PlanExpression, Predicate}
+import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, PartitionDirectory}
import org.apache.spark.sql.execution.datasources.parquet.ParquetUtils
-import org.apache.spark.sql.execution.metric.SQLMetric
+import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.collection.BitSet
+import java.util.concurrent.TimeUnit.NANOSECONDS
+
+import scala.collection.mutable
+
class FileSourceScanExecShim(
@transient relation: HadoopFsRelation,
output: Seq[Attribute],
@@ -61,4 +67,95 @@ class FileSourceScanExecShim(
def hasMetadataColumns: Boolean = metadataColumns.nonEmpty
def hasFieldIds: Boolean = ParquetUtils.hasFieldIds(requiredSchema)
+
+ // The code below is copied from FileSourceScanExec in Spark,
+ // where all of it is private.
+ protected lazy val driverMetrics: mutable.HashMap[String, Long] = mutable.HashMap.empty
+
+ /**
+ * Send the driver-side metrics. Before calling this function, selectedPartitions has been
+ * initialized. See SPARK-26327 for more details.
+ */
+ protected def sendDriverMetrics(): Unit = {
+ driverMetrics.foreach(e => metrics(e._1).add(e._2))
+ val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
+ SQLMetrics.postDriverMetricUpdates(
+ sparkContext,
+ executionId,
+ metrics.filter(e => driverMetrics.contains(e._1)).values.toSeq)
+ }
+
+ protected def setFilesNumAndSizeMetric(
+ partitions: Seq[PartitionDirectory],
+ static: Boolean): Unit = {
+ val filesNum = partitions.map(_.files.size.toLong).sum
+ val filesSize = partitions.map(_.files.map(_.getLen).sum).sum
+ if (!static || !partitionFilters.exists(isDynamicPruningFilter)) {
+ driverMetrics("numFiles") = filesNum
+ driverMetrics("filesSize") = filesSize
+ } else {
+ driverMetrics("staticFilesNum") = filesNum
+ driverMetrics("staticFilesSize") = filesSize
+ }
+ if (relation.partitionSchema.nonEmpty) {
+ driverMetrics("numPartitions") = partitions.length
+ }
+ }
+
+ @transient override lazy val selectedPartitions: Array[PartitionDirectory] = {
+ val optimizerMetadataTimeNs = relation.location.metadataOpsTimeNs.getOrElse(0L)
+ GlutenTimeMetric.withNanoTime {
+ val ret =
+ relation.location.listFiles(partitionFilters.filterNot(isDynamicPruningFilter), dataFilters)
+ setFilesNumAndSizeMetric(ret, static = true)
+ ret
+ }(t => driverMetrics("metadataTime") = NANOSECONDS.toMillis(t + optimizerMetadataTimeNs))
+ }.toArray
+
+ private def isDynamicPruningFilter(e: Expression): Boolean =
+ e.find(_.isInstanceOf[PlanExpression[_]]).isDefined
+
+ // We can only determine the actual partitions at runtime when a dynamic partition filter is
+ // present. This is because such a filter relies on information that is only available at run
+ // time (for instance the keys used in the other side of a join).
+ @transient lazy val dynamicallySelectedPartitions: Array[PartitionDirectory] = {
+ val dynamicPartitionFilters =
+ partitionFilters.filter(isDynamicPruningFilter)
+ val selected = if (dynamicPartitionFilters.nonEmpty) {
+ // When the dynamic partition filters include a DynamicPruningExpression,
+ // InSubqueryExec needs to be executed first, because the doTransform path
+ // does not go through 'doExecuteColumnar', which would otherwise prepare the subquery first.
+ dynamicPartitionFilters.foreach {
+ case DynamicPruningExpression(inSubquery: InSubqueryExec) =>
+ if (inSubquery.values().isEmpty) inSubquery.updateResult()
+ case e: Expression =>
+ e.foreach {
+ case s: ScalarSubquery => s.updateResult()
+ case _ =>
+ }
+ case _ =>
+ }
+ GlutenTimeMetric.withMillisTime {
+ // call the file index for the files matching all filters except dynamic partition filters
+ val predicate = dynamicPartitionFilters.reduce(And)
+ val partitionColumns = relation.partitionSchema
+ val boundPredicate = Predicate.create(
+ predicate.transform {
+ case a: AttributeReference =>
+ val index = partitionColumns.indexWhere(a.name == _.name)
+ BoundReference(index, partitionColumns(index).dataType, nullable = true)
+ },
+ Nil
+ )
+ val ret = selectedPartitions.filter(p => boundPredicate.eval(p.values))
+ setFilesNumAndSizeMetric(ret, static = false)
+ ret
+ }(t => driverMetrics("pruningTime") = t)
+ } else {
+ selectedPartitions
+ }
+ sendDriverMetrics()
+ selected
+ }
}
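Note: the dynamic pruning above binds a name-based partition filter to the positional partition values and then filters selectedPartitions. A simplified standalone sketch of that idea (plain Scala, illustrative types; not the Spark Predicate/BoundReference machinery itself):

    object PartitionPruningSketch {
      // A partition directory: its partition values and the files under it.
      final case class PartitionDir(values: Map[String, Any], files: Seq[String])

      // Evaluate a runtime-discovered predicate against each partition's values.
      def prune(partitions: Seq[PartitionDir])(pred: Map[String, Any] => Boolean): Seq[PartitionDir] =
        partitions.filter(p => pred(p.values))

      def main(args: Array[String]): Unit = {
        val parts = Seq(
          PartitionDir(Map("dt" -> "2023-07-01"), Seq("part-0.parquet")),
          PartitionDir(Map("dt" -> "2023-07-02"), Seq("part-1.parquet")))
        // Keys only known at runtime, e.g. from the build side of a join.
        val keep = Set("2023-07-02")
        println(prune(parts)(v => keep(v("dt").toString)).flatMap(_.files)) // List(part-1.parquet)
      }
    }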
diff --git a/shims/spark33/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExecShim.scala b/shims/spark33/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExecShim.scala
index 21806a96f3d3..0a81d1b28922 100644
--- a/shims/spark33/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExecShim.scala
+++ b/shims/spark33/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExecShim.scala
@@ -21,6 +21,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.KeyGroupedPartitioning
import org.apache.spark.sql.catalyst.util.InternalRowSet
+import org.apache.spark.sql.connector.catalog.Table
import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
import org.apache.spark.sql.connector.read.{HasPartitionKey, InputPartition, Scan, SupportsRuntimeFiltering}
import org.apache.spark.sql.execution.datasources.DataSourceStrategy
@@ -32,7 +33,8 @@ import org.apache.spark.sql.vectorized.ColumnarBatch
class BatchScanExecShim(
output: Seq[AttributeReference],
@transient scan: Scan,
- runtimeFilters: Seq[Expression])
+ runtimeFilters: Seq[Expression],
+ @transient table: Table)
extends BatchScanExec(output, scan, runtimeFilters) {
// Note: "metrics" is made transient to avoid sending driver-side metrics to tasks.
@@ -116,4 +118,8 @@ class BatchScanExecShim(
case _ => None
}
}
+
+ final override protected def otherCopyArgs: Seq[AnyRef] = {
+ output :: scan :: runtimeFilters :: Nil
+ }
}
diff --git a/shims/spark33/src/main/scala/org/apache/spark/sql/execution/datasources/v2/velox/DwrfScan.scala b/shims/spark33/src/main/scala/org/apache/spark/sql/execution/datasources/v2/velox/DwrfScan.scala
deleted file mode 100644
index 6536f8081474..000000000000
--- a/shims/spark33/src/main/scala/org/apache/spark/sql/execution/datasources/v2/velox/DwrfScan.scala
+++ /dev/null
@@ -1,43 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.sql.execution.datasources.v2.velox
-
-import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.catalyst.expressions.Expression
-import org.apache.spark.sql.connector.read.PartitionReaderFactory
-import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
-import org.apache.spark.sql.execution.datasources.v2.FileScan
-import org.apache.spark.sql.sources.Filter
-import org.apache.spark.sql.types.StructType
-import org.apache.spark.sql.util.CaseInsensitiveStringMap
-
-case class DwrfScan(
- sparkSession: SparkSession,
- fileIndex: PartitioningAwareFileIndex,
- readDataSchema: StructType,
- readPartitionSchema: StructType,
- pushedFilters: Array[Filter],
- options: CaseInsensitiveStringMap,
- partitionFilters: Seq[Expression] = Seq.empty,
- dataFilters: Seq[Expression] = Seq.empty)
- extends FileScan {
- override def createReaderFactory(): PartitionReaderFactory = {
- null
- }
-
- override def dataSchema: StructType = readDataSchema
-}
diff --git a/shims/spark33/src/main/scala/org/apache/spark/sql/execution/datasources/v2/velox/DwrfScanBuilder.scala b/shims/spark33/src/main/scala/org/apache/spark/sql/execution/datasources/v2/velox/DwrfScanBuilder.scala
deleted file mode 100644
index 475b18b68531..000000000000
--- a/shims/spark33/src/main/scala/org/apache/spark/sql/execution/datasources/v2/velox/DwrfScanBuilder.scala
+++ /dev/null
@@ -1,44 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.sql.execution.datasources.v2.velox
-
-import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.connector.read.Scan
-import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
-import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder
-import org.apache.spark.sql.types.StructType
-import org.apache.spark.sql.util.CaseInsensitiveStringMap
-
-case class DwrfScanBuilder(
- sparkSession: SparkSession,
- fileIndex: PartitioningAwareFileIndex,
- schema: StructType,
- dataSchema: StructType,
- options: CaseInsensitiveStringMap)
- extends FileScanBuilder(sparkSession, fileIndex, dataSchema) {
-
- override def build(): Scan = {
- DwrfScan(
- sparkSession,
- fileIndex,
- readDataSchema(),
- readPartitionSchema(),
- pushedDataFilters,
- options)
- }
-
-}
diff --git a/shims/spark33/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/shims/spark33/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
new file mode 100644
index 000000000000..08ba7680ca70
--- /dev/null
+++ b/shims/spark33/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
@@ -0,0 +1,364 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.stat
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.{Column, DataFrame, Dataset, Row}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Cast, Expression, GenericInternalRow, GetArrayItem, Literal, TryCast}
+import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
+import org.apache.spark.sql.catalyst.util.{GenericArrayData, QuantileSummaries}
+import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.functions.count
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+
+import java.util.Locale
+
+/**
+ * This file is copied from Spark.
+ *
+ * The df.describe() and df.summary() issues are fixed by
+ * https://github.com/apache/spark/pull/40914. That fix is picked into Gluten here to fix the
+ * describe() and summary() issue, and this file can be removed after upgrading the Spark version
+ * to 3.4 or higher.
+ */
+object StatFunctions extends Logging {
+
+ /**
+ * Calculates the approximate quantiles of multiple numerical columns of a DataFrame in one pass.
+ *
+ * The result of this algorithm has the following deterministic bound: If the DataFrame has N
+ * elements and if we request the quantile at probability `p` up to error `err`, then the
+ * algorithm will return a sample `x` from the DataFrame so that the *exact* rank of `x` is close
+ * to (p * N). More precisely,
+ *
+ * floor((p - err) * N) <= rank(x) <= ceil((p + err) * N).
+ *
+ * This method implements a variation of the Greenwald-Khanna algorithm (with some speed
+ * optimizations). The algorithm was first presented in "Space-efficient Online Computation of
+ * Quantile Summaries" by Greenwald and Khanna.
+ *
+ * @param df
+ * the dataframe
+ * @param cols
+ * numerical columns of the dataframe
+ * @param probabilities
+ * a list of quantile probabilities. Each number must belong to [0, 1]. For example 0 is the
+ * minimum, 0.5 is the median, 1 is the maximum.
+ * @param relativeError
+ * The relative target precision to achieve (greater than or equal to 0). If set to zero, the exact
+ * quantiles are computed, which could be very expensive. Note that values greater than 1 are
+ * accepted but give the same result as 1.
+ * @return
+ * for each column, returns the requested approximations
+ * @note
+ * null and NaN values will be ignored in numerical columns before calculation. For a column
+ * only containing null or NaN values, an empty array is returned.
+ */
+ def multipleApproxQuantiles(
+ df: DataFrame,
+ cols: Seq[String],
+ probabilities: Seq[Double],
+ relativeError: Double): Seq[Seq[Double]] = {
+ require(relativeError >= 0, s"Relative Error must be non-negative but got $relativeError")
+ val columns: Seq[Column] = cols.map {
+ colName =>
+ val field = df.resolve(colName)
+ require(
+ field.dataType.isInstanceOf[NumericType],
+ s"Quantile calculation for column $colName with data type ${field.dataType}" +
+ " is not supported.")
+ Column(Cast(Column(colName).expr, DoubleType))
+ }
+ val emptySummaries = Array.fill(cols.size)(
+ new QuantileSummaries(QuantileSummaries.defaultCompressThreshold, relativeError))
+
+ // Note that it works more or less by accident as `rdd.aggregate` is not a pure function:
+ // this function returns the same array as given in the input (because `aggregate` reuses
+ // the same argument).
+ def apply(summaries: Array[QuantileSummaries], row: Row): Array[QuantileSummaries] = {
+ var i = 0
+ while (i < summaries.length) {
+ if (!row.isNullAt(i)) {
+ val v = row.getDouble(i)
+ if (!v.isNaN) summaries(i) = summaries(i).insert(v)
+ }
+ i += 1
+ }
+ summaries
+ }
+
+ def merge(
+ sum1: Array[QuantileSummaries],
+ sum2: Array[QuantileSummaries]): Array[QuantileSummaries] = {
+ sum1.zip(sum2).map { case (s1, s2) => s1.compress().merge(s2.compress()) }
+ }
+
+ val summaries = df.select(columns: _*).rdd.treeAggregate(emptySummaries)(apply, merge)
+
+ summaries.map {
+ summary =>
+ summary.query(probabilities) match {
+ case Some(q) => q
+ case None => Seq()
+ }
+ }
+ }
+
+ /** Calculate the Pearson Correlation Coefficient for the given columns */
+ def pearsonCorrelation(df: DataFrame, cols: Seq[String]): Double = {
+ val counts = collectStatisticalData(df, cols, "correlation")
+ counts.Ck / math.sqrt(counts.MkX * counts.MkY)
+ }
+
+ /** Helper class to simplify tracking and merging counts. */
+ private class CovarianceCounter extends Serializable {
+ var xAvg = 0.0 // the mean of all examples seen so far in col1
+ var yAvg = 0.0 // the mean of all examples seen so far in col2
+ var Ck = 0.0 // the co-moment after k examples
+ var MkX = 0.0 // sum of squares of differences from the (current) mean for col1
+ var MkY = 0.0 // sum of squares of differences from the (current) mean for col2
+ var count = 0L // count of observed examples
+
+ // add an example to the calculation
+ def add(x: Double, y: Double): this.type = {
+ val deltaX = x - xAvg
+ val deltaY = y - yAvg
+ count += 1
+ xAvg += deltaX / count
+ yAvg += deltaY / count
+ Ck += deltaX * (y - yAvg)
+ MkX += deltaX * (x - xAvg)
+ MkY += deltaY * (y - yAvg)
+ this
+ }
+
+ // merge counters from other partitions. Formula can be found at:
+ // http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
+ def merge(other: CovarianceCounter): this.type = {
+ if (other.count > 0) {
+ val totalCount = count + other.count
+ val deltaX = xAvg - other.xAvg
+ val deltaY = yAvg - other.yAvg
+ Ck += other.Ck + deltaX * deltaY * count / totalCount * other.count
+ xAvg = (xAvg * count + other.xAvg * other.count) / totalCount
+ yAvg = (yAvg * count + other.yAvg * other.count) / totalCount
+ MkX += other.MkX + deltaX * deltaX * count / totalCount * other.count
+ MkY += other.MkY + deltaY * deltaY * count / totalCount * other.count
+ count = totalCount
+ }
+ this
+ }
+
+ // return the sample covariance for the observed examples
+ def cov: Double = Ck / (count - 1)
+ }
+
+ private def collectStatisticalData(
+ df: DataFrame,
+ cols: Seq[String],
+ functionName: String): CovarianceCounter = {
+ require(
+ cols.length == 2,
+ s"Currently $functionName calculation is supported " +
+ "between two columns.")
+ cols.map(name => (name, df.resolve(name))).foreach {
+ case (name, data) =>
+ require(
+ data.dataType.isInstanceOf[NumericType],
+ s"Currently $functionName calculation " +
+ s"for columns with dataType ${data.dataType.catalogString} not supported."
+ )
+ }
+ val columns = cols.map(n => Column(Cast(Column(n).expr, DoubleType)))
+ df.select(columns: _*)
+ .queryExecution
+ .toRdd
+ .treeAggregate(new CovarianceCounter)(
+ seqOp = (counter, row) => {
+ counter.add(row.getDouble(0), row.getDouble(1))
+ },
+ combOp = (baseCounter, other) => {
+ baseCounter.merge(other)
+ })
+ }
+
+ /**
+ * Calculate the covariance of two numerical columns of a DataFrame.
+ *
+ * @param df
+ * The DataFrame
+ * @param cols
+ * the column names
+ * @return
+ * the covariance of the two columns.
+ */
+ def calculateCov(df: DataFrame, cols: Seq[String]): Double = {
+ val counts = collectStatisticalData(df, cols, "covariance")
+ counts.cov
+ }
+
+ /** Generate a table of frequencies for the elements of two columns. */
+ def crossTabulate(df: DataFrame, col1: String, col2: String): DataFrame = {
+ val tableName = s"${col1}_$col2"
+ val counts = df.groupBy(col1, col2).agg(count("*")).take(1e6.toInt)
+ if (counts.length == 1e6.toInt) {
+ logWarning(
+ "The maximum limit of 1e6 pairs have been collected, which may not be all of " +
+ "the pairs. Please try reducing the amount of distinct items in your columns.")
+ }
+
+ def cleanElement(element: Any): String = {
+ if (element == null) "null" else element.toString
+ }
+
+ // get the distinct sorted values of column 2, so that we can make them the column names
+ val distinctCol2: Map[Any, Int] =
+ counts.map(e => cleanElement(e.get(1))).distinct.sorted.zipWithIndex.toMap
+ val columnSize = distinctCol2.size
+ require(
+ columnSize < 1e4,
+ s"The number of distinct values for $col2, can't " +
+ s"exceed 1e4. Currently $columnSize")
+ val table = counts
+ .groupBy(_.get(0))
+ .map {
+ case (col1Item, rows) =>
+ val countsRow = new GenericInternalRow(columnSize + 1)
+ rows.foreach {
+ (row: Row) =>
+ // row.get(0) is column 1
+ // row.get(1) is column 2
+ // row.get(2) is the frequency
+ val columnIndex = distinctCol2(cleanElement(row.get(1)))
+ countsRow.setLong(columnIndex + 1, row.getLong(2))
+ }
+ // the value of col1 is the first value, the rest are the counts
+ countsRow.update(0, UTF8String.fromString(cleanElement(col1Item)))
+ countsRow
+ }
+ .toSeq
+
+ // Back ticks can't exist in DataFrame column names, therefore drop them. To be able to accept
+ // special keywords and `.`, wrap the column names in ``.
+ def cleanColumnName(name: String): String = {
+ name.replace("`", "")
+ }
+
+ // In the map, the column names (._1) are not ordered by the index (._2). This was the bug in
+ // SPARK-8681. We need to explicitly sort by the column index and assign the column names.
+ val headerNames = distinctCol2.toSeq.sortBy(_._2).map {
+ r => StructField(cleanColumnName(r._1.toString), LongType)
+ }
+ val schema = StructType(StructField(tableName, StringType) +: headerNames)
+
+ Dataset.ofRows(df.sparkSession, LocalRelation(schema.toAttributes, table)).na.fill(0.0)
+ }
+
+ /** Calculate selected summary statistics for a dataset */
+ def summary(ds: Dataset[_], statistics: Seq[String]): DataFrame = {
+
+ val defaultStatistics = Seq("count", "mean", "stddev", "min", "25%", "50%", "75%", "max")
+ val selectedStatistics = if (statistics.nonEmpty) statistics else defaultStatistics
+
+ val percentiles = selectedStatistics.filter(a => a.endsWith("%")).map {
+ p =>
+ try {
+ p.stripSuffix("%").toDouble / 100.0
+ } catch {
+ case e: NumberFormatException =>
+ throw QueryExecutionErrors.cannotParseStatisticAsPercentileError(p, e)
+ }
+ }
+ require(percentiles.forall(p => p >= 0 && p <= 1), "Percentiles must be in the range [0, 1]")
+
+ def castAsDoubleIfNecessary(e: Expression): Expression = if (e.dataType == StringType) {
+ TryCast(e, DoubleType)
+ } else {
+ e
+ }
+
+ var percentileIndex = 0
+ val statisticFns = selectedStatistics.map {
+ stats =>
+ if (stats.endsWith("%")) {
+ val index = percentileIndex
+ percentileIndex += 1
+ (child: Expression) =>
+ GetArrayItem(
+ new ApproximatePercentile(
+ castAsDoubleIfNecessary(child),
+ Literal(new GenericArrayData(percentiles), ArrayType(DoubleType, false)))
+ .toAggregateExpression(),
+ Literal(index)
+ )
+ } else {
+ stats.toLowerCase(Locale.ROOT) match {
+ case "count" => (child: Expression) => Count(child).toAggregateExpression()
+ case "count_distinct" =>
+ (child: Expression) => Count(child).toAggregateExpression(isDistinct = true)
+ case "approx_count_distinct" =>
+ (child: Expression) => HyperLogLogPlusPlus(child).toAggregateExpression()
+ case "mean" =>
+ (child: Expression) => Average(castAsDoubleIfNecessary(child)).toAggregateExpression()
+ case "stddev" =>
+ (child: Expression) =>
+ StddevSamp(castAsDoubleIfNecessary(child)).toAggregateExpression()
+ case "min" => (child: Expression) => Min(child).toAggregateExpression()
+ case "max" => (child: Expression) => Max(child).toAggregateExpression()
+ case _ => throw QueryExecutionErrors.statisticNotRecognizedError(stats)
+ }
+ }
+ }
+
+ val selectedCols = ds.logicalPlan.output
+ .filter(a => a.dataType.isInstanceOf[NumericType] || a.dataType.isInstanceOf[StringType])
+
+ val aggExprs = statisticFns.flatMap {
+ func => selectedCols.map(c => Column(Cast(func(c), StringType)).as(c.name))
+ }
+
+ // If there are no selected columns, we don't need to run this aggregate, so make it a lazy val.
+ lazy val aggResult = ds.select(aggExprs: _*).queryExecution.toRdd.map(_.copy()).collect().head
+
+ // We will have one row for each selected statistic in the result.
+ val result = Array.fill[InternalRow](selectedStatistics.length) {
+ // each row has the statistic name, and statistic values of each selected column.
+ new GenericInternalRow(selectedCols.length + 1)
+ }
+
+ var rowIndex = 0
+ while (rowIndex < result.length) {
+ val statsName = selectedStatistics(rowIndex)
+ result(rowIndex).update(0, UTF8String.fromString(statsName))
+ for (colIndex <- selectedCols.indices) {
+ val statsValue = aggResult.getUTF8String(rowIndex * selectedCols.length + colIndex)
+ result(rowIndex).update(colIndex + 1, statsValue)
+ }
+ rowIndex += 1
+ }
+
+ // All columns are string type
+ val output = AttributeReference("summary", StringType)() +:
+ selectedCols.map(c => AttributeReference(c.name, StringType)())
+
+ Dataset.ofRows(ds.sparkSession, LocalRelation(output, result))
+ }
+}
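Note: the object above is internal, but it backs the public DataFrame statistics APIs. A hedged usage sketch (any local SparkSession works; data and names are illustrative) showing the calls that route through it, including the summary() path fixed by apache/spark#40914:

    import org.apache.spark.sql.SparkSession

    object StatFunctionsUsageSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[1]").appName("stat-sketch").getOrCreate()
        import spark.implicits._

        val df = Seq((1, 10.0), (2, 20.0), (3, 30.0), (4, 40.0)).toDF("id", "value")

        // Approximate quantiles via the Greenwald-Khanna summaries (multipleApproxQuantiles).
        println(df.stat.approxQuantile("value", Array(0.25, 0.5, 0.75), 0.01).toSeq)

        // Pearson correlation and sample covariance (CovarianceCounter).
        println(df.stat.corr("id", "value"))
        println(df.stat.cov("id", "value"))

        // Frequency table (crossTabulate) and selected summary statistics (summary).
        df.stat.crosstab("id", "value").show()
        df.summary("count", "mean", "stddev", "50%").show()

        spark.stop()
      }
    }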
diff --git a/shims/spark34/src/main/scala/io/glutenproject/sql/shims/spark34/Spark34Shims.scala b/shims/spark34/src/main/scala/io/glutenproject/sql/shims/spark34/Spark34Shims.scala
index 59cea64751a4..fbaf04cbc7c2 100644
--- a/shims/spark34/src/main/scala/io/glutenproject/sql/shims/spark34/Spark34Shims.scala
+++ b/shims/spark34/src/main/scala/io/glutenproject/sql/shims/spark34/Spark34Shims.scala
@@ -21,16 +21,19 @@ import io.glutenproject.expression.{ExpressionNames, Sig}
import io.glutenproject.sql.shims.{ShimDescriptor, SparkShims}
import org.apache.spark.SparkException
+import org.apache.spark.paths.SparkPath
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate
import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution}
+import org.apache.spark.sql.connector.catalog.Table
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.execution.{FileSourceScanLike, PartitionedFileUtil, SparkPlan}
import org.apache.spark.sql.execution.FileSourceScanExec
import org.apache.spark.sql.execution.datasources.{BucketingUtils, FilePartition, FileScanRDD, PartitionDirectory, PartitionedFile, PartitioningAwareFileIndex}
+import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.execution.datasources.v2.text.TextScan
import org.apache.spark.sql.execution.datasources.v2.utils.CatalogUtil
import org.apache.spark.sql.types.StructType
@@ -114,6 +117,16 @@ class Spark34Shims extends SparkShims {
}
}
+ override def getBatchScanExecTable(batchScan: BatchScanExec): Table = batchScan.table
+
+ override def generatePartitionedFile(
+ partitionValues: InternalRow,
+ filePath: String,
+ start: Long,
+ length: Long,
+ @transient locations: Array[String] = Array.empty): PartitionedFile =
+ PartitionedFile(partitionValues, SparkPath.fromPathString(filePath), start, length, locations)
+
private def invalidBucketFile(path: String): Throwable = {
new SparkException(
errorClass = "INVALID_BUCKET_FILE",
diff --git a/shims/spark34/src/main/scala/org/apache/spark/sql/execution/FileSourceScanExecShim.scala b/shims/spark34/src/main/scala/org/apache/spark/sql/execution/FileSourceScanExecShim.scala
index 8c4de5cb1f07..6230cedbd13b 100644
--- a/shims/spark34/src/main/scala/org/apache/spark/sql/execution/FileSourceScanExecShim.scala
+++ b/shims/spark34/src/main/scala/org/apache/spark/sql/execution/FileSourceScanExecShim.scala
@@ -16,9 +16,11 @@
*/
package org.apache.spark.sql.execution
+import io.glutenproject.metrics.GlutenTimeMetric
+
import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
-import org.apache.spark.sql.execution.datasources.HadoopFsRelation
+import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeReference, BoundReference, DynamicPruningExpression, Expression, PlanExpression, Predicate}
+import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, PartitionDirectory}
import org.apache.spark.sql.execution.datasources.parquet.ParquetUtils
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.types.StructType
@@ -61,4 +63,66 @@ class FileSourceScanExecShim(
def hasMetadataColumns: Boolean = fileConstantMetadataColumns.nonEmpty
def hasFieldIds: Boolean = ParquetUtils.hasFieldIds(requiredSchema)
+
+ private def isDynamicPruningFilter(e: Expression): Boolean =
+ e.find(_.isInstanceOf[PlanExpression[_]]).isDefined
+
+ protected def setFilesNumAndSizeMetric(
+ partitions: Seq[PartitionDirectory],
+ static: Boolean): Unit = {
+ val filesNum = partitions.map(_.files.size.toLong).sum
+ val filesSize = partitions.map(_.files.map(_.getLen).sum).sum
+ if (!static || !partitionFilters.exists(isDynamicPruningFilter)) {
+ driverMetrics("numFiles").set(filesNum)
+ driverMetrics("filesSize").set(filesSize)
+ } else {
+ driverMetrics("staticFilesNum").set(filesNum)
+ driverMetrics("staticFilesSize").set(filesSize)
+ }
+ if (relation.partitionSchema.nonEmpty) {
+ driverMetrics("numPartitions").set(partitions.length)
+ }
+ }
+
+ @transient override protected lazy val dynamicallySelectedPartitions
+ : Array[PartitionDirectory] = {
+ val dynamicPartitionFilters =
+ partitionFilters.filter(isDynamicPruningFilter)
+ val selected = if (dynamicPartitionFilters.nonEmpty) {
+ // When the dynamic partition filters include a DynamicPruningExpression,
+ // InSubqueryExec needs to be executed first, because the doTransform path
+ // does not go through 'doExecuteColumnar', which would otherwise prepare the subquery first.
+ dynamicPartitionFilters.foreach {
+ case DynamicPruningExpression(inSubquery: InSubqueryExec) =>
+ if (inSubquery.values().isEmpty) inSubquery.updateResult()
+ case e: Expression =>
+ e.foreach {
+ case s: ScalarSubquery => s.updateResult()
+ case _ =>
+ }
+ case _ =>
+ }
+ GlutenTimeMetric.withMillisTime {
+ // call the file index for the files matching all filters except dynamic partition filters
+ val predicate = dynamicPartitionFilters.reduce(And)
+ val partitionColumns = relation.partitionSchema
+ val boundPredicate = Predicate.create(
+ predicate.transform {
+ case a: AttributeReference =>
+ val index = partitionColumns.indexWhere(a.name == _.name)
+ BoundReference(index, partitionColumns(index).dataType, nullable = true)
+ },
+ Nil
+ )
+ val ret = selectedPartitions.filter(p => boundPredicate.eval(p.values))
+ setFilesNumAndSizeMetric(ret, static = false)
+ ret
+ }(t => driverMetrics("pruningTime").set(t))
+ } else {
+ selectedPartitions
+ }
+ sendDriverMetrics()
+ selected
+ }
}
diff --git a/shims/spark34/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/shims/spark34/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
index 3bde7ce155ef..65215f379d77 100644
--- a/shims/spark34/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
+++ b/shims/spark34/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
@@ -26,12 +26,14 @@ import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
import org.apache.spark.sql.connector.write.WriterCommitMessage
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.{ProjectExec, SortExec, SparkPlan, SQLExecution, UnsafeExternalRowSorter}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.{SerializableConfiguration, Utils}
import org.apache.hadoop.conf.Configuration
@@ -51,6 +53,28 @@ object FileFormatWriter extends Logging {
customPartitionLocations: Map[TablePartitionSpec, String],
outputColumns: Seq[Attribute])
+ case class Empty2Null(child: Expression) extends UnaryExpression with String2StringExpression {
+ override def convert(v: UTF8String): UTF8String = if (v.numBytes() == 0) null else v
+ override def nullable: Boolean = true
+ override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ nullSafeCodeGen(
+ ctx,
+ ev,
+ c => {
+ s"""if ($c.numBytes() == 0) {
+ | ${ev.isNull} = true;
+ | ${ev.value} = null;
+ |} else {
+ | ${ev.value} = $c;
+ |}""".stripMargin
+ }
+ )
+ }
+
+ override protected def withNewChildInternal(newChild: Expression): Empty2Null =
+ copy(child = newChild)
+ }
+
/** Describes how concurrent output writers should be executed. */
case class ConcurrentOutputWriterSpec(
maxWriters: Int,
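Note: Empty2Null above maps empty strings to null. A minimal standalone sketch of the semantics (plain Scala, not the Catalyst expression itself); the usual effect when writing partitioned data is that empty partition values fall back to the default (null) partition directory:

    object Empty2NullSketch {
      def empty2Null(v: String): String = if (v != null && v.isEmpty) null else v

      def main(args: Array[String]): Unit = {
        println(empty2Null(""))     // null
        println(empty2Null("2023")) // 2023
        println(empty2Null(null))   // null
      }
    }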
diff --git a/shims/spark34/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala.deprecated b/shims/spark34/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala.deprecated
index aaae179971b3..d43331d57c47 100644
--- a/shims/spark34/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala.deprecated
+++ b/shims/spark34/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala.deprecated
@@ -14,34 +14,46 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
+
package org.apache.spark.sql.execution.datasources.v2
+import com.google.common.base.Objects
+
import org.apache.spark.SparkException
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.QueryPlan
-import org.apache.spark.sql.catalyst.plans.physical.{KeyGroupedPartitioning, SinglePartition}
-import org.apache.spark.sql.catalyst.util.{truncatedString, InternalRowSet}
+import org.apache.spark.sql.catalyst.plans.physical.{KeyGroupedPartitioning, Partitioning, SinglePartition}
+import org.apache.spark.sql.catalyst.util.{truncatedString, InternalRowComparableWrapper}
+import org.apache.spark.sql.connector.catalog.Table
import org.apache.spark.sql.connector.read._
-import org.apache.spark.sql.execution.datasources.DataSourceStrategy
-
-import java.util.Objects
+import org.apache.spark.sql.internal.SQLConf
-/** Physical plan node for scanning a batch of data from a data source v2. */
+/**
+ * Physical plan node for scanning a batch of data from a data source v2.
+ */
case class BatchScanExec(
output: Seq[AttributeReference],
@transient scan: Scan,
runtimeFilters: Seq[Expression],
- keyGroupedPartitioning: Option[Seq[Expression]] = None)
- extends DataSourceV2ScanExecBase {
+ keyGroupedPartitioning: Option[Seq[Expression]] = None,
+ ordering: Option[Seq[SortOrder]] = None,
+ @transient table: Table,
+ commonPartitionValues: Option[Seq[(InternalRow, Int)]] = None,
+ applyPartialClustering: Boolean = false,
+ replicatePartitions: Boolean = false) extends DataSourceV2ScanExecBase {
- @transient lazy val batch = scan.toBatch
+ @transient lazy val batch = if (scan == null) null else scan.toBatch
// TODO: unify the equal/hashCode implementation for all data source v2 query plans.
override def equals(other: Any): Boolean = other match {
case other: BatchScanExec =>
- this.batch == other.batch && this.runtimeFilters == other.runtimeFilters
+ this.batch != null && this.batch == other.batch &&
+ this.runtimeFilters == other.runtimeFilters &&
+ this.commonPartitionValues == other.commonPartitionValues &&
+ this.replicatePartitions == other.replicatePartitions &&
+ this.applyPartialClustering == other.applyPartialClustering
case _ =>
false
}
@@ -52,7 +64,7 @@ case class BatchScanExec(
@transient private lazy val filteredPartitions: Seq[Seq[InputPartition]] = {
val dataSourceFilters = runtimeFilters.flatMap {
- case DynamicPruningExpression(e) => DataSourceStrategy.translateRuntimeFilter(e)
+ case DynamicPruningExpression(e) => DataSourceV2Strategy.translateRuntimeFilterV2(e)
case _ => None
}
@@ -60,7 +72,7 @@ case class BatchScanExec(
val originalPartitioning = outputPartitioning
// the cast is safe as runtime filters are only assigned if the scan can be filtered
- val filterableScan = scan.asInstanceOf[SupportsRuntimeFiltering]
+ val filterableScan = scan.asInstanceOf[SupportsRuntimeV2Filtering]
filterableScan.filter(dataSourceFilters.toArray)
// call toBatch again to get filtered partitions
@@ -69,28 +81,28 @@ case class BatchScanExec(
originalPartitioning match {
case p: KeyGroupedPartitioning =>
if (newPartitions.exists(!_.isInstanceOf[HasPartitionKey])) {
- throw new SparkException(
- "Data source must have preserved the original partitioning " +
+ throw new SparkException("Data source must have preserved the original partitioning " +
"during runtime filtering: not all partitions implement HasPartitionKey after " +
"filtering")
}
-
- val newRows = new InternalRowSet(p.expressions.map(_.dataType))
- newRows ++= newPartitions.map(_.asInstanceOf[HasPartitionKey].partitionKey())
- val oldRows = p.partitionValuesOpt.get
-
- if (oldRows.size != newRows.size) {
- throw new SparkException(
- "Data source must have preserved the original partitioning " +
- "during runtime filtering: the number of unique partition values obtained " +
- s"through HasPartitionKey changed: before ${oldRows.size}, after ${newRows.size}")
+ val newPartitionValues = newPartitions.map(partition =>
+ InternalRowComparableWrapper(partition.asInstanceOf[HasPartitionKey], p.expressions))
+ .toSet
+ val oldPartitionValues = p.partitionValues
+ .map(partition => InternalRowComparableWrapper(partition, p.expressions)).toSet
+ // We require the new number of partition values to be equal or less than the old number
+ // of partition values here. In the case of less than, empty partitions will be added for
+ // those missing values that are not present in the new input partitions.
+ if (oldPartitionValues.size < newPartitionValues.size) {
+ throw new SparkException("During runtime filtering, data source must either report " +
+ "the same number of partition values, or a subset of partition values from the " +
+ s"original. Before: ${oldPartitionValues.size} partition values. " +
+ s"After: ${newPartitionValues.size} partition values")
}
- if (!oldRows.forall(newRows.contains)) {
- throw new SparkException(
- "Data source must have preserved the original partitioning " +
- "during runtime filtering: the number of unique partition values obtained " +
- s"through HasPartitionKey remain the same but do not exactly match")
+ if (!newPartitionValues.forall(oldPartitionValues.contains)) {
+ throw new SparkException("During runtime filtering, data source must not report new " +
+ "partition values that are not present in the original partitioning.")
}
groupPartitions(newPartitions).get.map(_._2)
@@ -105,20 +117,109 @@ case class BatchScanExec(
}
}
+ override def outputPartitioning: Partitioning = {
+ super.outputPartitioning match {
+ case k: KeyGroupedPartitioning if commonPartitionValues.isDefined =>
+ // We allow duplicated partition values if
+ // `spark.sql.sources.v2.bucketing.partiallyClusteredDistribution.enabled` is true
+ val newPartValues = commonPartitionValues.get.flatMap { case (partValue, numSplits) =>
+ Seq.fill(numSplits)(partValue)
+ }
+ k.copy(numPartitions = newPartValues.length, partitionValues = newPartValues)
+ case p => p
+ }
+ }
+
override lazy val readerFactory: PartitionReaderFactory = batch.createReaderFactory()
override lazy val inputRDD: RDD[InternalRow] = {
- if (filteredPartitions.isEmpty && outputPartitioning == SinglePartition) {
+ val rdd = if (filteredPartitions.isEmpty && outputPartitioning == SinglePartition) {
// return an empty RDD with 1 partition if dynamic filtering removed the only split
sparkContext.parallelize(Array.empty[InternalRow], 1)
} else {
+ var finalPartitions = filteredPartitions
+
+ outputPartitioning match {
+ case p: KeyGroupedPartitioning =>
+ if (conf.v2BucketingPushPartValuesEnabled &&
+ conf.v2BucketingPartiallyClusteredDistributionEnabled) {
+ assert(filteredPartitions.forall(_.size == 1),
+ "Expect partitions to be not grouped when " +
+ s"${SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key} " +
+ "is enabled")
+
+ val groupedPartitions = groupPartitions(finalPartitions.map(_.head), true).get
+
+ // This means the input partitions are not grouped by partition values. We'll need to
+ // check `groupByPartitionValues` and decide whether to group and replicate splits
+ // within a partition.
+ if (commonPartitionValues.isDefined && applyPartialClustering) {
+ // A mapping from the common partition values to how many splits the partition
+ // should contain. Note this no longer maintains the partition key ordering.
+ val commonPartValuesMap = commonPartitionValues
+ .get
+ .map(t => (InternalRowComparableWrapper(t._1, p.expressions), t._2))
+ .toMap
+ val nestGroupedPartitions = groupedPartitions.map {
+ case (partValue, splits) =>
+ // `commonPartValuesMap` should contain the part value since it's the super set.
+ val numSplits = commonPartValuesMap
+ .get(InternalRowComparableWrapper(partValue, p.expressions))
+ assert(numSplits.isDefined, s"Partition value $partValue does not exist in " +
+ "common partition values from Spark plan")
+
+ val newSplits = if (replicatePartitions) {
+ // We need to also replicate partitions according to the other side of join
+ Seq.fill(numSplits.get)(splits)
+ } else {
+ // Not grouping by partition values: this could be the side with partially
+ // clustered distribution. Because of dynamic filtering, we'll need to check if
+ // the final number of splits of a partition is smaller than the original
+ // number, and fill with empty splits if so. This is necessary so that both
+ // sides of a join will have the same number of partitions & splits.
+ splits.map(Seq(_)).padTo(numSplits.get, Seq.empty)
+ }
+ (InternalRowComparableWrapper(partValue, p.expressions), newSplits)
+ }
+
+ // Now fill missing partition keys with empty partitions
+ val partitionMapping = nestGroupedPartitions.toMap
+ finalPartitions = commonPartitionValues.get.flatMap { case (partValue, numSplits) =>
+ // Use empty partition for those partition values that are not present.
+ partitionMapping.getOrElse(
+ InternalRowComparableWrapper(partValue, p.expressions),
+ Seq.fill(numSplits)(Seq.empty))
+ }
+ } else {
+ val partitionMapping = groupedPartitions.map { case (row, parts) =>
+ InternalRowComparableWrapper(row, p.expressions) -> parts
+ }.toMap
+ finalPartitions = p.partitionValues.map { partValue =>
+ // Use empty partition for those partition values that are not present
+ partitionMapping.getOrElse(
+ InternalRowComparableWrapper(partValue, p.expressions), Seq.empty)
+ }
+ }
+ } else {
+ val partitionMapping = finalPartitions.map { parts =>
+ val row = parts.head.asInstanceOf[HasPartitionKey].partitionKey()
+ InternalRowComparableWrapper(row, p.expressions) -> parts
+ }.toMap
+ finalPartitions = p.partitionValues.map { partValue =>
+ // Use empty partition for those partition values that are not present
+ partitionMapping.getOrElse(
+ InternalRowComparableWrapper(partValue, p.expressions), Seq.empty)
+ }
+ }
+
+ case _ =>
+ }
+
new DataSourceRDD(
- sparkContext,
- filteredPartitions,
- readerFactory,
- supportsColumnar,
- customMetrics)
+ sparkContext, finalPartitions, readerFactory, supportsColumnar, customMetrics)
}
+ postDriverMetrics()
+ rdd
}
override def doCanonicalize(): BatchScanExec = {
@@ -126,8 +227,7 @@ case class BatchScanExec(
output = output.map(QueryPlan.normalizeExpressions(_, output)),
runtimeFilters = QueryPlan.normalizePredicates(
runtimeFilters.filterNot(_ == DynamicPruningExpression(Literal.TrueLiteral)),
- output)
- )
+ output))
}
override def simpleString(maxFields: Int): String = {
@@ -137,6 +237,7 @@ case class BatchScanExec(
redact(result)
}
- // Set to false to disable BatchScan's columnar output.
- override def supportsColumnar: Boolean = false
+ override def nodeName: String = {
+ s"BatchScan ${table.name()}".trim
+ }
}
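Note: the outputPartitioning override above expands commonPartitionValues so that each partition value appears once per split. A standalone sketch of that expansion (illustrative types, not the Spark KeyGroupedPartitioning classes):

    object CommonPartitionValuesSketch {
      // Each (partitionValue, numSplits) pair is repeated numSplits times, so both sides of a
      // storage-partitioned join end up with the same number of output partitions.
      def expand[T](common: Seq[(T, Int)]): Seq[T] =
        common.flatMap { case (value, numSplits) => Seq.fill(numSplits)(value) }

      def main(args: Array[String]): Unit = {
        println(expand(Seq(("p0", 2), ("p1", 1), ("p2", 3))))
        // List(p0, p0, p1, p2, p2, p2) -> 6 partitions
      }
    }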
diff --git a/shims/spark34/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExecShim.scala b/shims/spark34/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExecShim.scala
index 68ea957c1ef6..b7f781a0aeb8 100644
--- a/shims/spark34/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExecShim.scala
+++ b/shims/spark34/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExecShim.scala
@@ -40,28 +40,12 @@ class BatchScanExecShim(
output: Seq[AttributeReference],
@transient scan: Scan,
runtimeFilters: Seq[Expression],
- keyGroupedPartitioning: Option[Seq[Expression]],
- ordering: Option[Seq[SortOrder]],
- @transient table: Table,
- commonPartitionValues: Option[Seq[(InternalRow, Int)]],
- applyPartialClustering: Boolean,
- replicatePartitions: Boolean)
- extends BatchScanExec(
- output,
- scan,
- runtimeFilters,
- keyGroupedPartitioning,
- ordering,
- table,
- commonPartitionValues,
- applyPartialClustering,
- replicatePartitions) {
+ @transient table: Table)
+ extends BatchScanExec(output, scan, runtimeFilters, table = table) {
// Note: "metrics" is made transient to avoid sending driver-side metrics to tasks.
@transient override lazy val metrics: Map[String, SQLMetric] = Map()
- override def supportsColumnar(): Boolean = GlutenConfig.getConf.enableColumnarIterator
-
override def doExecuteColumnar(): RDD[ColumnarBatch] = {
throw new UnsupportedOperationException("Need to implement this method")
}
diff --git a/shims/spark34/src/main/scala/org/apache/spark/sql/execution/datasources/v2/Spark33Scan.scala b/shims/spark34/src/main/scala/org/apache/spark/sql/execution/datasources/v2/Spark34Scan.scala
similarity index 100%
rename from shims/spark34/src/main/scala/org/apache/spark/sql/execution/datasources/v2/Spark33Scan.scala
rename to shims/spark34/src/main/scala/org/apache/spark/sql/execution/datasources/v2/Spark34Scan.scala
diff --git a/shims/spark34/src/main/scala/org/apache/spark/sql/execution/datasources/v2/velox/DwrfScan.scala b/shims/spark34/src/main/scala/org/apache/spark/sql/execution/datasources/v2/velox/DwrfScan.scala
deleted file mode 100644
index 6536f8081474..000000000000
--- a/shims/spark34/src/main/scala/org/apache/spark/sql/execution/datasources/v2/velox/DwrfScan.scala
+++ /dev/null
@@ -1,43 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.sql.execution.datasources.v2.velox
-
-import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.catalyst.expressions.Expression
-import org.apache.spark.sql.connector.read.PartitionReaderFactory
-import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
-import org.apache.spark.sql.execution.datasources.v2.FileScan
-import org.apache.spark.sql.sources.Filter
-import org.apache.spark.sql.types.StructType
-import org.apache.spark.sql.util.CaseInsensitiveStringMap
-
-case class DwrfScan(
- sparkSession: SparkSession,
- fileIndex: PartitioningAwareFileIndex,
- readDataSchema: StructType,
- readPartitionSchema: StructType,
- pushedFilters: Array[Filter],
- options: CaseInsensitiveStringMap,
- partitionFilters: Seq[Expression] = Seq.empty,
- dataFilters: Seq[Expression] = Seq.empty)
- extends FileScan {
- override def createReaderFactory(): PartitionReaderFactory = {
- null
- }
-
- override def dataSchema: StructType = readDataSchema
-}
diff --git a/shims/spark34/src/main/scala/org/apache/spark/sql/execution/datasources/v2/velox/DwrfScanBuilder.scala b/shims/spark34/src/main/scala/org/apache/spark/sql/execution/datasources/v2/velox/DwrfScanBuilder.scala
deleted file mode 100644
index 475b18b68531..000000000000
--- a/shims/spark34/src/main/scala/org/apache/spark/sql/execution/datasources/v2/velox/DwrfScanBuilder.scala
+++ /dev/null
@@ -1,44 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.sql.execution.datasources.v2.velox
-
-import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.connector.read.Scan
-import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
-import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder
-import org.apache.spark.sql.types.StructType
-import org.apache.spark.sql.util.CaseInsensitiveStringMap
-
-case class DwrfScanBuilder(
- sparkSession: SparkSession,
- fileIndex: PartitioningAwareFileIndex,
- schema: StructType,
- dataSchema: StructType,
- options: CaseInsensitiveStringMap)
- extends FileScanBuilder(sparkSession, fileIndex, dataSchema) {
-
- override def build(): Scan = {
- DwrfScan(
- sparkSession,
- fileIndex,
- readDataSchema(),
- readPartitionSchema(),
- pushedDataFilters,
- options)
- }
-
-}
diff --git a/substrait/substrait-spark/pom.xml b/substrait/substrait-spark/pom.xml
index 5a7f37cb0ee5..51fe2fd5b108 100644
--- a/substrait/substrait-spark/pom.xml
+++ b/substrait/substrait-spark/pom.xml
@@ -15,6 +15,12 @@
<name>Gluten Substrait Spark</name>
+    <dependency>
+      <groupId>io.glutenproject</groupId>
+      <artifactId>${sparkshim.artifactId}</artifactId>
+      <version>${project.version}</version>
+      <scope>compile</scope>
+    </dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-sql_${scala.binary.version}</artifactId>
diff --git a/substrait/substrait-spark/src/main/scala/io/substrait/spark/expression/ToSubstraitExpression.scala b/substrait/substrait-spark/src/main/scala/io/substrait/spark/expression/ToSubstraitExpression.scala
index 2feedf0f7eff..2789d84f8f8e 100644
--- a/substrait/substrait-spark/src/main/scala/io/substrait/spark/expression/ToSubstraitExpression.scala
+++ b/substrait/substrait-spark/src/main/scala/io/substrait/spark/expression/ToSubstraitExpression.scala
@@ -134,7 +134,7 @@ abstract class ToSubstraitExpression extends HasOutputStack[Seq[Attribute]] {
case SubstraitLiteral(substraitLiteral) => Some(substraitLiteral)
case a: AttributeReference if currentOutput.nonEmpty => translateAttribute(a)
case a: Alias => translateUp(a.child)
-// case p: PromotePrecision => translateUp(p.child)
+ case p: PromotePrecision => translateUp(p.child)
case CaseWhen(branches, elseValue) => translateCaseWhen(branches, elseValue)
case scalar @ ScalarFunction(children) =>
Util
diff --git a/substrait/substrait-spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala b/substrait/substrait-spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala
index 4d06687cf741..f8cf3767938a 100644
--- a/substrait/substrait-spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala
+++ b/substrait/substrait-spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala
@@ -255,6 +255,7 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging {
throw new UnsupportedOperationException(
s"Unable to convert the plan to a substrait plan: $plan")
}
+
private def toExpression(output: Seq[Attribute])(e: Expression): SExpression = {
toSubstraitExp(e, output)
}
@@ -335,9 +336,8 @@ private[logical] class WithLogicalSubQuery(toSubstraitRel: ToSubstraitRel)
override protected def translateSubQuery(expr: PlanExpression[_]): Option[SExpression] = {
expr match {
- case s @ ScalarSubquery(childPlan, outerAttrs, _, joinCond, _, _)
- if outerAttrs.isEmpty && joinCond.isEmpty =>
- val rel = toSubstraitRel.visit(childPlan)
+ case s: ScalarSubquery if s.outerAttrs.isEmpty && s.joinCond.isEmpty =>
+ val rel = toSubstraitRel.visit(s.plan)
Some(
SExpression.ScalarSubquery.builder
.input(rel)
diff --git a/substrait/substrait-spark/src/main/spark-3.2/io/substrait/spark/logical/AbstractLogicalPlanVisitor.scala b/substrait/substrait-spark/src/main/spark-3.2/io/substrait/spark/logical/AbstractLogicalPlanVisitor.scala
index 836a087f1f53..09b3ecc426c6 100644
--- a/substrait/substrait-spark/src/main/spark-3.2/io/substrait/spark/logical/AbstractLogicalPlanVisitor.scala
+++ b/substrait/substrait-spark/src/main/spark-3.2/io/substrait/spark/logical/AbstractLogicalPlanVisitor.scala
@@ -69,4 +69,6 @@ class AbstractLogicalPlanVisitor extends LogicalPlanVisitor[relation.Rel] {
override def visitSort(sort: Sort): Rel = t(sort)
override def visitWithCTE(p: WithCTE): Rel = t(p)
+
+ def visitOffset(p: Offset): Rel = t(p)
}
diff --git a/substrait/substrait-spark/src/main/spark-3.3/io/substrait/spark/logical/AbstractLogicalPlanVisitor.scala b/substrait/substrait-spark/src/main/spark-3.3/io/substrait/spark/logical/AbstractLogicalPlanVisitor.scala
index 345cb215f4ac..081d6f93f545 100644
--- a/substrait/substrait-spark/src/main/spark-3.3/io/substrait/spark/logical/AbstractLogicalPlanVisitor.scala
+++ b/substrait/substrait-spark/src/main/spark-3.3/io/substrait/spark/logical/AbstractLogicalPlanVisitor.scala
@@ -70,5 +70,7 @@ class AbstractLogicalPlanVisitor extends LogicalPlanVisitor[relation.Rel] {
override def visitWithCTE(p: WithCTE): Rel = t(p)
+ def visitOffset(p: Offset): Rel = t(p)
+
override def visitRebalancePartitions(p: RebalancePartitions): Rel = t(p)
}
diff --git a/substrait/substrait-spark/src/main/spark-3.4/io/substrait/spark/logical/AbstractLogicalPlanVisitor.scala b/substrait/substrait-spark/src/main/spark-3.4/io/substrait/spark/logical/AbstractLogicalPlanVisitor.scala
index 8962190171d6..ec3ee78e8c47 100644
--- a/substrait/substrait-spark/src/main/spark-3.4/io/substrait/spark/logical/AbstractLogicalPlanVisitor.scala
+++ b/substrait/substrait-spark/src/main/spark-3.4/io/substrait/spark/logical/AbstractLogicalPlanVisitor.scala
@@ -71,5 +71,6 @@ class AbstractLogicalPlanVisitor extends LogicalPlanVisitor[relation.Rel] {
override def visitWithCTE(p: WithCTE): Rel = t(p)
override def visitOffset(p: Offset): Rel = t(p)
+
override def visitRebalancePartitions(p: RebalancePartitions): Rel = t(p)
}
diff --git a/tools/gluten-it/pom.xml b/tools/gluten-it/pom.xml
index e7c46ee97789..38d4260151d0 100644
--- a/tools/gluten-it/pom.xml
+++ b/tools/gluten-it/pom.xml
@@ -19,6 +19,7 @@
2.12.15
3.2.2
3.3.1
+    <spark34.version>3.4.1</spark34.version>
${spark32.version}
2.12
3
@@ -86,12 +87,6 @@
<version>${spark.version}</version>
<scope>provided</scope>
<type>test-jar</type>
-      <exclusions>
-        <exclusion>
-          <groupId>org.apache.arrow</groupId>
-          <artifactId>*</artifactId>
-        </exclusion>
-      </exclusions>
@@ -150,5 +145,30 @@
+    <profile>
+      <id>spark-3.4</id>
+      <properties>
+        <spark.version>${spark34.version}</spark.version>
+      </properties>
+      <dependencies>
+        <dependency>
+          <groupId>org.apache.spark</groupId>
+          <artifactId>spark-sql_${scala.binary.version}</artifactId>
+          <version>${spark.version}</version>
+          <exclusions>
+            <exclusion>
+              <groupId>com.google.protobuf</groupId>
+              <artifactId>protobuf-java</artifactId>
+            </exclusion>
+          </exclusions>
+        </dependency>
+        <dependency>
+          <groupId>org.apache.spark</groupId>
+          <artifactId>spark-sql_${scala.binary.version}</artifactId>
+          <version>${spark.version}</version>
+          <type>test-jar</type>
+        </dependency>
+      </dependencies>
+    </profile>