From dc879f469a68fee86e8c562bb98db1341e434035 Mon Sep 17 00:00:00 2001 From: Jia Ke Date: Tue, 10 Oct 2023 10:40:52 +0000 Subject: [PATCH] Support spark 3.4 shim layer in gluten --- .../BasicPhysicalOperatorTransformer.scala | 3 +- .../execution/BatchScanExecTransformer.scala | 16 +- .../FileSourceScanExecTransformer.scala | 101 +---- .../extension/ColumnarOverrides.scala | 9 +- .../columnar/TransformHintRule.scala | 12 +- .../softaffinity/SoftAffinitySuite.scala | 60 +-- .../spark/sql/GlutenSubquerySuite.scala | 2 +- .../spark/sql/GlutenSubquerySuite.scala | 2 +- .../glutenproject/sql/shims/SparkShims.scala | 16 +- .../sql/shims/spark32/Spark32Shims.scala | 12 + .../sql/catalyst/expressions/Empty2Null.scala | 49 +++ .../sql/catalyst/plans/logical/Offset.scala | 31 +- .../execution/FileSourceScanExecShim.scala | 74 +++- .../datasources/v2/BatchScanExecShim.scala | 4 +- .../datasources/v2/velox/DwrfScan.scala | 46 --- .../v2/velox/DwrfScanBuilder.scala | 57 --- .../sql/execution/stat/StatFunctions.scala | 356 +++++++++++++++++ .../sql/shims/spark33/Spark33Shims.scala | 12 + .../sql/catalyst/expressions/Empty2Null.scala | 49 +++ .../sql/catalyst/plans/logical/Offset.scala | 36 ++ .../execution/FileSourceScanExecShim.scala | 74 +++- .../datasources/v2/BatchScanExecShim.scala | 4 +- .../datasources/v2/velox/DwrfScan.scala | 43 --- .../v2/velox/DwrfScanBuilder.scala | 44 --- .../sql/execution/stat/StatFunctions.scala | 364 ++++++++++++++++++ .../sql/shims/spark34/Spark34Shims.scala | 13 + .../execution/FileSourceScanExecShim.scala | 7 +- .../datasources/v2/BatchScanExecShim.scala | 18 +- .../{Spark33Scan.scala => Spark34Scan.scala} | 0 .../datasources/v2/velox/DwrfScan.scala | 43 --- .../v2/velox/DwrfScanBuilder.scala | 44 --- substrait/substrait-spark/pom.xml | 6 + .../expression/ToSubstraitExpression.scala | 2 +- .../spark/logical/ToSubstraitRel.scala | 6 +- .../logical/AbstractLogicalPlanVisitor.scala | 2 + .../logical/AbstractLogicalPlanVisitor.scala | 2 + .../logical/AbstractLogicalPlanVisitor.scala | 1 + 37 files changed, 1155 insertions(+), 465 deletions(-) create mode 100644 shims/spark32/src/main/scala/org/apache/spark/sql/catalyst/expressions/Empty2Null.scala rename gluten-core/src/main/scala/org/apache/spark/sql/utils/DataSourceStrategyUtil.scala => shims/spark32/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Offset.scala (52%) delete mode 100644 shims/spark32/src/main/scala/org/apache/spark/sql/execution/datasources/v2/velox/DwrfScan.scala delete mode 100644 shims/spark32/src/main/scala/org/apache/spark/sql/execution/datasources/v2/velox/DwrfScanBuilder.scala create mode 100644 shims/spark32/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala create mode 100644 shims/spark33/src/main/scala/org/apache/spark/sql/catalyst/expressions/Empty2Null.scala create mode 100644 shims/spark33/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Offset.scala delete mode 100644 shims/spark33/src/main/scala/org/apache/spark/sql/execution/datasources/v2/velox/DwrfScan.scala delete mode 100644 shims/spark33/src/main/scala/org/apache/spark/sql/execution/datasources/v2/velox/DwrfScanBuilder.scala create mode 100644 shims/spark33/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala rename shims/spark34/src/main/scala/org/apache/spark/sql/execution/datasources/v2/{Spark33Scan.scala => Spark34Scan.scala} (100%) delete mode 100644 shims/spark34/src/main/scala/org/apache/spark/sql/execution/datasources/v2/velox/DwrfScan.scala delete mode 100644 shims/spark34/src/main/scala/org/apache/spark/sql/execution/datasources/v2/velox/DwrfScanBuilder.scala diff --git a/gluten-core/src/main/scala/io/glutenproject/execution/BasicPhysicalOperatorTransformer.scala b/gluten-core/src/main/scala/io/glutenproject/execution/BasicPhysicalOperatorTransformer.scala index 614568f9d9c0b..e43d1497acf6d 100644 --- a/gluten-core/src/main/scala/io/glutenproject/execution/BasicPhysicalOperatorTransformer.scala +++ b/gluten-core/src/main/scala/io/glutenproject/execution/BasicPhysicalOperatorTransformer.scala @@ -22,6 +22,7 @@ import io.glutenproject.expression.{ConverterUtils, ExpressionConverter, Express import io.glutenproject.extension.{GlutenPlan, ValidationResult} import io.glutenproject.extension.columnar.TransformHints import io.glutenproject.metrics.MetricsUpdater +import io.glutenproject.sql.shims.SparkShimLoader import io.glutenproject.substrait.`type`.{TypeBuilder, TypeNode} import io.glutenproject.substrait.SubstraitContext import io.glutenproject.substrait.expression.ExpressionNode @@ -562,7 +563,7 @@ object FilterHandler { batchScan.output, scan, leftFilters ++ newPartitionFilters, - batchScan.table) + table = SparkShimLoader.getSparkShims.getBatchScanExecTable(batchScan)) case _ => if (batchScan.runtimeFilters.isEmpty) { throw new UnsupportedOperationException( diff --git a/gluten-core/src/main/scala/io/glutenproject/execution/BatchScanExecTransformer.scala b/gluten-core/src/main/scala/io/glutenproject/execution/BatchScanExecTransformer.scala index d195d335364eb..91703d2f51d18 100644 --- a/gluten-core/src/main/scala/io/glutenproject/execution/BatchScanExecTransformer.scala +++ b/gluten-core/src/main/scala/io/glutenproject/execution/BatchScanExecTransformer.scala @@ -20,6 +20,7 @@ import io.glutenproject.GlutenConfig import io.glutenproject.backendsapi.BackendsApiManager import io.glutenproject.extension.ValidationResult import io.glutenproject.metrics.MetricsUpdater +import io.glutenproject.sql.shims.SparkShimLoader import io.glutenproject.substrait.rel.LocalFilesNode.ReadFileFormat import org.apache.spark.rdd.RDD @@ -45,22 +46,13 @@ class BatchScanExecTransformer( output: Seq[AttributeReference], @transient scan: Scan, runtimeFilters: Seq[Expression], - @transient table: Table, keyGroupedPartitioning: Option[Seq[Expression]] = None, ordering: Option[Seq[SortOrder]] = None, + @transient table: Table, commonPartitionValues: Option[Seq[(InternalRow, Int)]] = None, applyPartialClustering: Boolean = false, replicatePartitions: Boolean = false) - extends BatchScanExecShim( - output, - scan, - runtimeFilters, - keyGroupedPartitioning, - ordering, - table, - commonPartitionValues, - applyPartialClustering, - replicatePartitions) + extends BatchScanExecShim(output, scan, runtimeFilters, table) with BasicScanExecTransformer { // Note: "metrics" is made transient to avoid sending driver-side metrics to tasks. @@ -154,7 +146,7 @@ class BatchScanExecTransformer( canonicalized.output, canonicalized.scan, canonicalized.runtimeFilters, - canonicalized.table + table = SparkShimLoader.getSparkShims.getBatchScanExecTable(canonicalized) ) } } diff --git a/gluten-core/src/main/scala/io/glutenproject/execution/FileSourceScanExecTransformer.scala b/gluten-core/src/main/scala/io/glutenproject/execution/FileSourceScanExecTransformer.scala index b9e7033fe592f..16c584843695c 100644 --- a/gluten-core/src/main/scala/io/glutenproject/execution/FileSourceScanExecTransformer.scala +++ b/gluten-core/src/main/scala/io/glutenproject/execution/FileSourceScanExecTransformer.scala @@ -20,24 +20,22 @@ import io.glutenproject.GlutenConfig import io.glutenproject.backendsapi.BackendsApiManager import io.glutenproject.expression.ConverterUtils import io.glutenproject.extension.ValidationResult -import io.glutenproject.metrics.{GlutenTimeMetric, MetricsUpdater} +import io.glutenproject.metrics.MetricsUpdater import io.glutenproject.substrait.SubstraitContext import io.glutenproject.substrait.rel.LocalFilesNode.ReadFileFormat import io.glutenproject.substrait.rel.ReadRelNode import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeReference, BoundReference, DynamicPruningExpression, Expression, PlanExpression, Predicate} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, PlanExpression} import org.apache.spark.sql.connector.read.InputPartition -import org.apache.spark.sql.execution.{FileSourceScanExecShim, InSubqueryExec, ScalarSubquery, SparkPlan, SQLExecution} -import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, PartitionDirectory} +import org.apache.spark.sql.execution.{FileSourceScanExecShim, SparkPlan} +import org.apache.spark.sql.execution.datasources.HadoopFsRelation import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.types.StructType import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.collection.BitSet -import java.util.concurrent.TimeUnit.NANOSECONDS - import scala.collection.JavaConverters class FileSourceScanExecTransformer( @@ -65,10 +63,10 @@ class FileSourceScanExecTransformer( // Note: "metrics" is made transient to avoid sending driver-side metrics to tasks. @transient override lazy val metrics: Map[String, SQLMetric] = BackendsApiManager.getMetricsApiInstance - .genFileSourceScanTransformerMetrics(sparkContext) ++ staticMetrics + .genFileSourceScanTransformerMetrics(sparkContext) ++ staticMetricsAlias /** SQL metrics generated only for scans using dynamic partition pruning. */ - override protected lazy val staticMetrics = + private lazy val staticMetricsAlias = if (partitionFilters.exists(FileSourceScanExecTransformer.isDynamicPruningFilter)) { Map( "staticFilesNum" -> SQLMetrics.createMetric(sparkContext, "static number of files read"), @@ -93,7 +91,7 @@ class FileSourceScanExecTransformer( override def getPartitions: Seq[InputPartition] = BackendsApiManager.getTransformerApiInstance.genInputPartitionSeq( relation, - dynamicallySelectedPartitions, + dynamicallySelectedPartitionsAlias, output, optionalBucketSet, optionalNumCoalescedBuckets, @@ -160,91 +158,6 @@ class FileSourceScanExecTransformer( override def metricsUpdater(): MetricsUpdater = BackendsApiManager.getMetricsApiInstance.genFileSourceScanTransformerMetricsUpdater(metrics) - // The codes below are copied from FileSourceScanExec in Spark, - // all of them are private. - - /** - * Send the driver-side metrics. Before calling this function, selectedPartitions has been - * initialized. See SPARK-26327 for more details. - */ - override protected def sendDriverMetrics(): Unit = { - val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) - SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, driverMetrics.values.toSeq) - } - - protected def setFilesNumAndSizeMetric( - partitions: Seq[PartitionDirectory], - static: Boolean): Unit = { - val filesNum = partitions.map(_.files.size.toLong).sum - val filesSize = partitions.map(_.files.map(_.getLen).sum).sum - if (!static || !partitionFilters.exists(FileSourceScanExecTransformer.isDynamicPruningFilter)) { - driverMetrics("numFiles").set(filesNum) - driverMetrics("filesSize").set(filesSize) - } else { - driverMetrics("staticFilesNum").set(filesNum) - driverMetrics("staticFilesSize").set(filesSize) - } - if (relation.partitionSchema.nonEmpty) { - driverMetrics("numPartitions").set(partitions.length) - } - } - - @transient override lazy val selectedPartitions: Array[PartitionDirectory] = { - val optimizerMetadataTimeNs = relation.location.metadataOpsTimeNs.getOrElse(0L) - GlutenTimeMetric.withNanoTime { - val ret = - relation.location.listFiles( - partitionFilters.filterNot(FileSourceScanExecTransformer.isDynamicPruningFilter), - dataFilters) - setFilesNumAndSizeMetric(ret, static = true) - ret - }(t => driverMetrics("metadataTime").set(NANOSECONDS.toMillis(t + optimizerMetadataTimeNs))) - }.toArray - - // We can only determine the actual partitions at runtime when a dynamic partition filter is - // present. This is because such a filter relies on information that is only available at run - // time (for instance the keys used in the other side of a join). - @transient override lazy val dynamicallySelectedPartitions: Array[PartitionDirectory] = { - val dynamicPartitionFilters = - partitionFilters.filter(FileSourceScanExecTransformer.isDynamicPruningFilter) - val selected = if (dynamicPartitionFilters.nonEmpty) { - // When it includes some DynamicPruningExpression, - // it needs to execute InSubqueryExec first, - // because doTransform path can't execute 'doExecuteColumnar' which will - // execute prepare subquery first. - dynamicPartitionFilters.foreach { - case DynamicPruningExpression(inSubquery: InSubqueryExec) => - executeInSubqueryForDynamicPruningExpression(inSubquery) - case e: Expression => - e.foreach { - case s: ScalarSubquery => s.updateResult() - case _ => - } - case _ => - } - GlutenTimeMetric.withMillisTime { - // call the file index for the files matching all filters except dynamic partition filters - val predicate = dynamicPartitionFilters.reduce(And) - val partitionColumns = relation.partitionSchema - val boundPredicate = Predicate.create( - predicate.transform { - case a: AttributeReference => - val index = partitionColumns.indexWhere(a.name == _.name) - BoundReference(index, partitionColumns(index).dataType, nullable = true) - }, - Nil - ) - val ret = selectedPartitions.filter(p => boundPredicate.eval(p.values)) - setFilesNumAndSizeMetric(ret, static = false) - ret - }(t => driverMetrics("pruningTime").set(t)) - } else { - selectedPartitions - } - sendDriverMetrics() - selected - } - override val nodeNamePrefix: String = "NativeFile" override val nodeName: String = { diff --git a/gluten-core/src/main/scala/io/glutenproject/extension/ColumnarOverrides.scala b/gluten-core/src/main/scala/io/glutenproject/extension/ColumnarOverrides.scala index 8b0dd1e1365ef..ddfa0fba7aabc 100644 --- a/gluten-core/src/main/scala/io/glutenproject/extension/ColumnarOverrides.scala +++ b/gluten-core/src/main/scala/io/glutenproject/extension/ColumnarOverrides.scala @@ -22,6 +22,7 @@ import io.glutenproject.execution._ import io.glutenproject.expression.ExpressionConverter import io.glutenproject.extension.columnar._ import io.glutenproject.metrics.GlutenTimeMetric +import io.glutenproject.sql.shims.SparkShimLoader import io.glutenproject.utils.{ColumnarShuffleUtil, LogLevelUtil, PhysicalPlanSelector} import org.apache.spark.api.python.EvalPythonExecTransformer @@ -578,8 +579,12 @@ case class TransformPreOverrides(isAdaptiveContext: Boolean) case _ => ExpressionConverter.transformDynamicPruningExpr(plan.runtimeFilters, reuseSubquery) } - val transformer = - new BatchScanExecTransformer(plan.output, plan.scan, newPartitionFilters, plan.table) + val transformer = new BatchScanExecTransformer( + plan.output, + plan.scan, + newPartitionFilters, + table = SparkShimLoader.getSparkShims.getBatchScanExecTable(plan)) + val validationResult = transformer.doValidate() if (validationResult.isValid) { logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") diff --git a/gluten-core/src/main/scala/io/glutenproject/extension/columnar/TransformHintRule.scala b/gluten-core/src/main/scala/io/glutenproject/extension/columnar/TransformHintRule.scala index f4fb99a6ccae8..7bede35f7abcf 100644 --- a/gluten-core/src/main/scala/io/glutenproject/extension/columnar/TransformHintRule.scala +++ b/gluten-core/src/main/scala/io/glutenproject/extension/columnar/TransformHintRule.scala @@ -20,6 +20,7 @@ import io.glutenproject.GlutenConfig import io.glutenproject.backendsapi.BackendsApiManager import io.glutenproject.execution._ import io.glutenproject.extension.{GlutenPlan, ValidationResult} +import io.glutenproject.sql.shims.SparkShimLoader import io.glutenproject.utils.PhysicalPlanSelector import org.apache.spark.api.python.EvalPythonExecTransformer @@ -333,12 +334,11 @@ case class AddTransformHintRule() extends Rule[SparkPlan] { if (plan.runtimeFilters.nonEmpty) { TransformHints.tagTransformable(plan) } else { - val transformer = - new BatchScanExecTransformer( - plan.output, - plan.scan, - plan.runtimeFilters, - plan.table) + val transformer = new BatchScanExecTransformer( + plan.output, + plan.scan, + plan.runtimeFilters, + table = SparkShimLoader.getSparkShims.getBatchScanExecTable(plan)) TransformHints.tag(plan, transformer.doValidate().toTransformHint) } } diff --git a/gluten-core/src/test/scala/org/apache/spark/softaffinity/SoftAffinitySuite.scala b/gluten-core/src/test/scala/org/apache/spark/softaffinity/SoftAffinitySuite.scala index eb752c3af8ad2..75302ab8ab2bb 100644 --- a/gluten-core/src/test/scala/org/apache/spark/softaffinity/SoftAffinitySuite.scala +++ b/gluten-core/src/test/scala/org/apache/spark/softaffinity/SoftAffinitySuite.scala @@ -20,16 +20,16 @@ import io.glutenproject.GlutenConfig import io.glutenproject.execution.{GlutenMergeTreePartition, GlutenPartition} import io.glutenproject.softaffinity.SoftAffinityManager import io.glutenproject.softaffinity.scheduler.SoftAffinityListener +import io.glutenproject.sql.shims.SparkShimLoader import io.glutenproject.substrait.plan.PlanBuilder import org.apache.spark.SparkConf -import org.apache.spark.paths.SparkPath import org.apache.spark.scheduler.{SparkListenerExecutorAdded, SparkListenerExecutorRemoved} import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.sql.QueryTest import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.PredicateHelper -import org.apache.spark.sql.execution.datasources.{FilePartition, PartitionedFile} +import org.apache.spark.sql.execution.datasources.FilePartition import org.apache.spark.sql.test.SharedSparkSession class SoftAffinitySuite extends QueryTest with SharedSparkSession with PredicateHelper { @@ -43,18 +43,20 @@ class SoftAffinitySuite extends QueryTest with SharedSparkSession with Predicate val partition = FilePartition( 0, Seq( - PartitionedFile( + SparkShimLoader.getSparkShims.generatePartitionedFile( InternalRow.empty, - SparkPath.fromPathString("fakePath0"), + "fakePath0", 0, 100, - Array("host-1", "host-2")), - PartitionedFile( + Array("host-1", "host-2") + ), + SparkShimLoader.getSparkShims.generatePartitionedFile( InternalRow.empty, - SparkPath.fromPathString("fakePath1"), + "fakePath1", 0, 200, - Array("host-2", "host-3")) + Array("host-2", "host-3") + ) ).toArray ) @@ -70,18 +72,20 @@ class SoftAffinitySuite extends QueryTest with SharedSparkSession with Predicate val partition = FilePartition( 0, Seq( - PartitionedFile( + SparkShimLoader.getSparkShims.generatePartitionedFile( InternalRow.empty, - SparkPath.fromPathString("fakePath0"), + "fakePath0", 0, 100, - Array("host-1", "host-2")), - PartitionedFile( + Array("host-1", "host-2") + ), + SparkShimLoader.getSparkShims.generatePartitionedFile( InternalRow.empty, - SparkPath.fromPathString("fakePath1"), + "fakePath1", 0, 200, - Array("host-4", "host-5")) + Array("host-4", "host-5") + ) ).toArray ) @@ -98,18 +102,20 @@ class SoftAffinitySuite extends QueryTest with SharedSparkSession with Predicate val partition = FilePartition( 0, Seq( - PartitionedFile( + SparkShimLoader.getSparkShims.generatePartitionedFile( InternalRow.empty, - SparkPath.fromPathString("fakePath0"), + "fakePath0", 0, 100, - Array("host-1", "host-2")), - PartitionedFile( + Array("host-1", "host-2") + ), + SparkShimLoader.getSparkShims.generatePartitionedFile( InternalRow.empty, - SparkPath.fromPathString("fakePath1"), + "fakePath1", 0, 200, - Array("host-5", "host-6")) + Array("host-5", "host-6") + ) ).toArray ) @@ -138,18 +144,20 @@ class SoftAffinitySuite extends QueryTest with SharedSparkSession with Predicate val partition = FilePartition( 0, Seq( - PartitionedFile( + SparkShimLoader.getSparkShims.generatePartitionedFile( InternalRow.empty, - SparkPath.fromPathString("fakePath0"), + "fakePath0", 0, 100, - Array("host-1", "host-2")), - PartitionedFile( + Array("host-1", "host-2") + ), + SparkShimLoader.getSparkShims.generatePartitionedFile( InternalRow.empty, - SparkPath.fromPathString("fakePath1"), + "fakePath1", 0, 200, - Array("host-5", "host-6")) + Array("host-5", "host-6") + ) ).toArray ) diff --git a/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/GlutenSubquerySuite.scala b/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/GlutenSubquerySuite.scala index 6251397f51b57..30634bf25c1bf 100644 --- a/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/GlutenSubquerySuite.scala +++ b/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/GlutenSubquerySuite.scala @@ -50,7 +50,7 @@ class GlutenSubquerySuite extends SubquerySuite with GlutenSQLTestsTrait { case t: WholeStageTransformer => t } match { case Some(WholeStageTransformer(fs: FileSourceScanExecTransformer, _)) => - fs.dynamicallySelectedPartitions + fs.dynamicallySelectedPartitionsAlias .exists(_.files.exists(_.getPath.toString.contains("p=0"))) case _ => false }) diff --git a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/GlutenSubquerySuite.scala b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/GlutenSubquerySuite.scala index 6251397f51b57..30634bf25c1bf 100644 --- a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/GlutenSubquerySuite.scala +++ b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/GlutenSubquerySuite.scala @@ -50,7 +50,7 @@ class GlutenSubquerySuite extends SubquerySuite with GlutenSQLTestsTrait { case t: WholeStageTransformer => t } match { case Some(WholeStageTransformer(fs: FileSourceScanExecTransformer, _)) => - fs.dynamicallySelectedPartitions + fs.dynamicallySelectedPartitionsAlias .exists(_.files.exists(_.getPath.toString.contains("p=0"))) case _ => false }) diff --git a/shims/common/src/main/scala/io/glutenproject/sql/shims/SparkShims.scala b/shims/common/src/main/scala/io/glutenproject/sql/shims/SparkShims.scala index 5a00c6e52bba0..de99e7efb44ce 100644 --- a/shims/common/src/main/scala/io/glutenproject/sql/shims/SparkShims.scala +++ b/shims/common/src/main/scala/io/glutenproject/sql/shims/SparkShims.scala @@ -21,12 +21,15 @@ import io.glutenproject.expression.Sig import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.catalog.BucketSpec -import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.{Expression, PlanExpression} import org.apache.spark.sql.catalyst.plans.physical.Distribution +import org.apache.spark.sql.connector.catalog.Table import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.execution.FileSourceScanExec import org.apache.spark.sql.execution.datasources.{FilePartition, FileScanRDD, PartitionDirectory, PartitionedFile, PartitioningAwareFileIndex} +import org.apache.spark.sql.execution.datasources.v2.BatchScanExec import org.apache.spark.sql.execution.datasources.v2.text.TextScan +import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -67,4 +70,15 @@ trait SparkShims { def filesGroupedToBuckets( selectedPartitions: Array[PartitionDirectory]): Map[Int, Array[PartitionedFile]] + + // Spark3.4 new add table parameter in BatchScanExec. + def getBatchScanExecTable(batchScan: BatchScanExec): Table + + // The PartitionedFile API changed in spark 3.4 + def generatePartitionedFile( + partitionValues: InternalRow, + filePath: String, + start: Long, + length: Long, + @transient locations: Array[String] = Array.empty): PartitionedFile } diff --git a/shims/spark32/src/main/scala/io/glutenproject/sql/shims/spark32/Spark32Shims.scala b/shims/spark32/src/main/scala/io/glutenproject/sql/shims/spark32/Spark32Shims.scala index db3cdf56bbc65..1cacc2f75be92 100644 --- a/shims/spark32/src/main/scala/io/glutenproject/sql/shims/spark32/Spark32Shims.scala +++ b/shims/spark32/src/main/scala/io/glutenproject/sql/shims/spark32/Spark32Shims.scala @@ -24,9 +24,11 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.physical.{Distribution, HashClusteredDistribution} +import org.apache.spark.sql.connector.catalog.Table import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.execution.{FileSourceScanExec, PartitionedFileUtil} import org.apache.spark.sql.execution.datasources.{BucketingUtils, FilePartition, FileScanRDD, PartitionDirectory, PartitionedFile, PartitioningAwareFileIndex} +import org.apache.spark.sql.execution.datasources.v2.BatchScanExec import org.apache.spark.sql.execution.datasources.v2.text.TextScan import org.apache.spark.sql.execution.datasources.v2.utils.CatalogUtil import org.apache.spark.sql.types.StructType @@ -88,4 +90,14 @@ class Spark32Shims extends SparkShims { .getOrElse(throw new IllegalStateException(s"Invalid bucket file ${f.filePath}")) } } + + override def getBatchScanExecTable(batchScan: BatchScanExec): Table = null + + override def generatePartitionedFile( + partitionValues: InternalRow, + filePath: String, + start: Long, + length: Long, + @transient locations: Array[String] = Array.empty): PartitionedFile = + PartitionedFile(partitionValues, filePath, start, length, locations) } diff --git a/shims/spark32/src/main/scala/org/apache/spark/sql/catalyst/expressions/Empty2Null.scala b/shims/spark32/src/main/scala/org/apache/spark/sql/catalyst/expressions/Empty2Null.scala new file mode 100644 index 0000000000000..241159ea0e251 --- /dev/null +++ b/shims/spark32/src/main/scala/org/apache/spark/sql/catalyst/expressions/Empty2Null.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst.expressions.{Expression, String2StringExpression, UnaryExpression} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.unsafe.types.UTF8String + +/** + * A internal function that converts the empty string to null for partition values. This function + * should be only used in V1Writes. + */ +case class Empty2Null(child: Expression) extends UnaryExpression with String2StringExpression { + override def convert(v: UTF8String): UTF8String = if (v.numBytes() == 0) null else v + + override def nullable: Boolean = true + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen( + ctx, + ev, + c => { + s"""if ($c.numBytes() == 0) { + | ${ev.isNull} = true; + | ${ev.value} = null; + |} else { + | ${ev.value} = $c; + |}""".stripMargin + } + ) + } + + override protected def withNewChildInternal(newChild: Expression): Empty2Null = + copy(child = newChild) +} diff --git a/gluten-core/src/main/scala/org/apache/spark/sql/utils/DataSourceStrategyUtil.scala b/shims/spark32/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Offset.scala similarity index 52% rename from gluten-core/src/main/scala/org/apache/spark/sql/utils/DataSourceStrategyUtil.scala rename to shims/spark32/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Offset.scala index 972a8bdaa5392..bc7cacf7995e6 100644 --- a/gluten-core/src/main/scala/org/apache/spark/sql/utils/DataSourceStrategyUtil.scala +++ b/shims/spark32/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Offset.scala @@ -14,20 +14,23 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql.utils +package org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.sql.connector.expressions.filter.Predicate -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Strategy +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, IntegerLiteral} -object DataSourceStrategyUtil { - - /** - * Translates a runtime filter into a data source filter. - * - * Runtime filters usually contain a subquery that must be evaluated before the translation. If - * the underlying subquery hasn't completed yet, this method will throw an exception. - */ - def translateRuntimeFilter(expr: Expression): Option[Predicate] = - DataSourceV2Strategy.translateRuntimeFilterV2(expr) +/** + * A logical offset, which may removing a specified number of rows from the beginning of the output + * of child logical plan. + */ +case class Offset(offsetExpr: Expression, child: LogicalPlan) extends OrderPreservingUnaryNode { + override def output: Seq[Attribute] = child.output + override def maxRows: Option[Long] = { + import scala.math.max + offsetExpr match { + case IntegerLiteral(offset) => child.maxRows.map(x => max(x - offset, 0)) + case _ => None + } + } + override protected def withNewChildInternal(newChild: LogicalPlan): Offset = + copy(child = newChild) } diff --git a/shims/spark32/src/main/scala/org/apache/spark/sql/execution/FileSourceScanExecShim.scala b/shims/spark32/src/main/scala/org/apache/spark/sql/execution/FileSourceScanExecShim.scala index c721a4beda01b..9c1170187c202 100644 --- a/shims/spark32/src/main/scala/org/apache/spark/sql/execution/FileSourceScanExecShim.scala +++ b/shims/spark32/src/main/scala/org/apache/spark/sql/execution/FileSourceScanExecShim.scala @@ -17,12 +17,14 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} -import org.apache.spark.sql.execution.datasources.HadoopFsRelation -import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeReference, BoundReference, Expression, PlanExpression, Predicate} +import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, PartitionDirectory} +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.types.StructType import org.apache.spark.util.collection.BitSet +import scala.collection.mutable + class FileSourceScanExecShim( @transient relation: HadoopFsRelation, output: Seq[Attribute], @@ -60,4 +62,70 @@ class FileSourceScanExecShim( def hasMetadataColumns: Boolean = false def hasFieldIds: Boolean = false + + // The codes below are copied from FileSourceExec in Spark, + // all of them are private. + private lazy val driverMetrics: mutable.HashMap[String, Long] = mutable.HashMap.empty + + /** + * Send the driver-side metrics. Before calling this function, selectedPartitions has been + * initialized. See SPARK-26327 for more details. + */ + private def sendDriverMetrics(): Unit = { + driverMetrics.foreach(e => metrics(e._1).add(e._2)) + val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + SQLMetrics.postDriverMetricUpdates( + sparkContext, + executionId, + metrics.filter(e => driverMetrics.contains(e._1)).values.toSeq) + } + + private def isDynamicPruningFilter(e: Expression): Boolean = + e.find(_.isInstanceOf[PlanExpression[_]]).isDefined + + private def setFilesNumAndSizeMetric( + partitions: Seq[PartitionDirectory], + static: Boolean): Unit = { + val filesNum = partitions.map(_.files.size.toLong).sum + val filesSize = partitions.map(_.files.map(_.getLen).sum).sum + if (!static || !partitionFilters.exists(isDynamicPruningFilter)) { + driverMetrics("numFiles") = filesNum + driverMetrics("filesSize") = filesSize + } else { + driverMetrics("staticFilesNum") = filesNum + driverMetrics("staticFilesSize") = filesSize + } + if (relation.partitionSchema.nonEmpty) { + driverMetrics("numPartitions") = partitions.length + } + } + + // We can only determine the actual partitions at runtime when a dynamic partition filter is + // present. This is because such a filter relies on information that is only available at run + // time (for instance the keys used in the other side of a join). + @transient lazy val dynamicallySelectedPartitionsAlias: Array[PartitionDirectory] = { + val dynamicPartitionFilters = partitionFilters.filter(isDynamicPruningFilter) + + if (dynamicPartitionFilters.nonEmpty) { + val startTime = System.nanoTime() + // call the file index for the files matching all filters except dynamic partition filters + val predicate = dynamicPartitionFilters.reduce(And) + val partitionColumns = relation.partitionSchema + val boundPredicate = Predicate.create( + predicate.transform { + case a: AttributeReference => + val index = partitionColumns.indexWhere(a.name == _.name) + BoundReference(index, partitionColumns(index).dataType, nullable = true) + }, + Nil + ) + val ret = selectedPartitions.filter(p => boundPredicate.eval(p.values)) + setFilesNumAndSizeMetric(ret, false) + val timeTakenMs = (System.nanoTime() - startTime) / 1000 / 1000 + driverMetrics("pruningTime") = timeTakenMs + ret + } else { + selectedPartitions + } + } } diff --git a/shims/spark32/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExecShim.scala b/shims/spark32/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExecShim.scala index 112b5832d95f5..ed995c77c9caa 100644 --- a/shims/spark32/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExecShim.scala +++ b/shims/spark32/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExecShim.scala @@ -21,6 +21,7 @@ import io.glutenproject.GlutenConfig import org.apache.spark.SparkException import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.connector.catalog.Table import org.apache.spark.sql.connector.expressions.aggregate.Aggregation import org.apache.spark.sql.connector.read.{InputPartition, Scan, SupportsRuntimeFiltering} import org.apache.spark.sql.execution.datasources.DataSourceStrategy @@ -30,7 +31,8 @@ import org.apache.spark.sql.vectorized.ColumnarBatch class BatchScanExecShim( output: Seq[AttributeReference], @transient scan: Scan, - runtimeFilters: Seq[Expression]) + runtimeFilters: Seq[Expression], + @transient table: Table) extends BatchScanExec(output, scan, runtimeFilters) { // Note: "metrics" is made transient to avoid sending driver-side metrics to tasks. diff --git a/shims/spark32/src/main/scala/org/apache/spark/sql/execution/datasources/v2/velox/DwrfScan.scala b/shims/spark32/src/main/scala/org/apache/spark/sql/execution/datasources/v2/velox/DwrfScan.scala deleted file mode 100644 index 621486d743e57..0000000000000 --- a/shims/spark32/src/main/scala/org/apache/spark/sql/execution/datasources/v2/velox/DwrfScan.scala +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.execution.datasources.v2.velox - -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.sql.connector.read.PartitionReaderFactory -import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex -import org.apache.spark.sql.execution.datasources.v2.FileScan -import org.apache.spark.sql.sources.Filter -import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.util.CaseInsensitiveStringMap - -case class DwrfScan( - sparkSession: SparkSession, - fileIndex: PartitioningAwareFileIndex, - readDataSchema: StructType, - readPartitionSchema: StructType, - pushedFilters: Array[Filter], - options: CaseInsensitiveStringMap, - partitionFilters: Seq[Expression] = Seq.empty, - dataFilters: Seq[Expression] = Seq.empty) - extends FileScan { - override def createReaderFactory(): PartitionReaderFactory = { - null - } - - override def withFilters( - partitionFilters: Seq[Expression], - dataFilters: Seq[Expression]): FileScan = - this.copy(partitionFilters = partitionFilters, dataFilters = dataFilters) -} diff --git a/shims/spark32/src/main/scala/org/apache/spark/sql/execution/datasources/v2/velox/DwrfScanBuilder.scala b/shims/spark32/src/main/scala/org/apache/spark/sql/execution/datasources/v2/velox/DwrfScanBuilder.scala deleted file mode 100644 index dda9aeca7a47e..0000000000000 --- a/shims/spark32/src/main/scala/org/apache/spark/sql/execution/datasources/v2/velox/DwrfScanBuilder.scala +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.execution.datasources.v2.velox - -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownFilters} -import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex -import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder -import org.apache.spark.sql.sources.Filter -import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.util.CaseInsensitiveStringMap - -case class DwrfScanBuilder( - sparkSession: SparkSession, - fileIndex: PartitioningAwareFileIndex, - schema: StructType, - dataSchema: StructType, - options: CaseInsensitiveStringMap) - extends FileScanBuilder(sparkSession, fileIndex, dataSchema) - with SupportsPushDownFilters { - - private lazy val pushedArrowFilters: Array[Filter] = { - filters // todo filter validation & pushdown - } - private var filters: Array[Filter] = Array.empty - - override def pushFilters(filters: Array[Filter]): Array[Filter] = { - this.filters = filters - this.filters - } - - override def build(): Scan = { - DwrfScan( - sparkSession, - fileIndex, - readDataSchema(), - readPartitionSchema(), - pushedFilters, - options) - } - - override def pushedFilters: Array[Filter] = pushedArrowFilters -} diff --git a/shims/spark32/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/shims/spark32/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala new file mode 100644 index 0000000000000..0dbdac871a68e --- /dev/null +++ b/shims/spark32/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -0,0 +1,356 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.stat + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.{Column, DataFrame, Dataset, Row} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Cast, Expression, GenericInternalRow, GetArrayItem, Literal, TryCast} +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.plans.logical.LocalRelation +import org.apache.spark.sql.catalyst.util.{GenericArrayData, QuantileSummaries} +import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +import java.util.Locale + +/** + * This file is copied from Spark + * + * The df.describe() and df.summary() issues are fixed by + * https://github.com/apache/spark/pull/40914. We picked it into Gluten to fix the describe and + * summary issue. And this file can be removed after upgrading spark version to 3.4 or higher + * version. + */ +object StatFunctions extends Logging { + + /** + * Calculates the approximate quantiles of multiple numerical columns of a DataFrame in one pass. + * + * The result of this algorithm has the following deterministic bound: If the DataFrame has N + * elements and if we request the quantile at probability `p` up to error `err`, then the + * algorithm will return a sample `x` from the DataFrame so that the *exact* rank of `x` is close + * to (p * N). More precisely, + * + * floor((p - err) * N) <= rank(x) <= ceil((p + err) * N). + * + * This method implements a variation of the Greenwald-Khanna algorithm (with some speed + * optimizations). The algorithm was first present in Space-efficient Online Computation of Quantile + * Summaries by Greenwald and Khanna. + * + * @param df + * the dataframe + * @param cols + * numerical columns of the dataframe + * @param probabilities + * a list of quantile probabilities Each number must belong to [0, 1]. For example 0 is the + * minimum, 0.5 is the median, 1 is the maximum. + * @param relativeError + * The relative target precision to achieve (greater than or equal 0). If set to zero, the exact + * quantiles are computed, which could be very expensive. Note that values greater than 1 are + * accepted but give the same result as 1. + * + * @return + * for each column, returns the requested approximations + * + * @note + * null and NaN values will be ignored in numerical columns before calculation. For a column + * only containing null or NaN values, an empty array is returned. + */ + def multipleApproxQuantiles( + df: DataFrame, + cols: Seq[String], + probabilities: Seq[Double], + relativeError: Double): Seq[Seq[Double]] = { + require(relativeError >= 0, s"Relative Error must be non-negative but got $relativeError") + val columns: Seq[Column] = cols.map { + colName => + val field = df.resolve(colName) + require( + field.dataType.isInstanceOf[NumericType], + s"Quantile calculation for column $colName with data type ${field.dataType}" + + " is not supported.") + Column(Cast(Column(colName).expr, DoubleType)) + } + val emptySummaries = Array.fill(cols.size)( + new QuantileSummaries(QuantileSummaries.defaultCompressThreshold, relativeError)) + + // Note that it works more or less by accident as `rdd.aggregate` is not a pure function: + // this function returns the same array as given in the input (because `aggregate` reuses + // the same argument). + def apply(summaries: Array[QuantileSummaries], row: Row): Array[QuantileSummaries] = { + var i = 0 + while (i < summaries.length) { + if (!row.isNullAt(i)) { + val v = row.getDouble(i) + if (!v.isNaN) summaries(i) = summaries(i).insert(v) + } + i += 1 + } + summaries + } + + def merge( + sum1: Array[QuantileSummaries], + sum2: Array[QuantileSummaries]): Array[QuantileSummaries] = { + sum1.zip(sum2).map { case (s1, s2) => s1.compress().merge(s2.compress()) } + } + val summaries = df.select(columns: _*).rdd.treeAggregate(emptySummaries)(apply, merge) + + summaries.map { + summary => + summary.query(probabilities) match { + case Some(q) => q + case None => Seq() + } + } + } + + /** Calculate the Pearson Correlation Coefficient for the given columns */ + def pearsonCorrelation(df: DataFrame, cols: Seq[String]): Double = { + val counts = collectStatisticalData(df, cols, "correlation") + counts.Ck / math.sqrt(counts.MkX * counts.MkY) + } + + /** Helper class to simplify tracking and merging counts. */ + private class CovarianceCounter extends Serializable { + var xAvg = 0.0 // the mean of all examples seen so far in col1 + var yAvg = 0.0 // the mean of all examples seen so far in col2 + var Ck = 0.0 // the co-moment after k examples + var MkX = 0.0 // sum of squares of differences from the (current) mean for col1 + var MkY = 0.0 // sum of squares of differences from the (current) mean for col2 + var count = 0L // count of observed examples + // add an example to the calculation + def add(x: Double, y: Double): this.type = { + val deltaX = x - xAvg + val deltaY = y - yAvg + count += 1 + xAvg += deltaX / count + yAvg += deltaY / count + Ck += deltaX * (y - yAvg) + MkX += deltaX * (x - xAvg) + MkY += deltaY * (y - yAvg) + this + } + // merge counters from other partitions. Formula can be found at: + // http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance + def merge(other: CovarianceCounter): this.type = { + if (other.count > 0) { + val totalCount = count + other.count + val deltaX = xAvg - other.xAvg + val deltaY = yAvg - other.yAvg + Ck += other.Ck + deltaX * deltaY * count / totalCount * other.count + xAvg = (xAvg * count + other.xAvg * other.count) / totalCount + yAvg = (yAvg * count + other.yAvg * other.count) / totalCount + MkX += other.MkX + deltaX * deltaX * count / totalCount * other.count + MkY += other.MkY + deltaY * deltaY * count / totalCount * other.count + count = totalCount + } + this + } + // return the sample covariance for the observed examples + def cov: Double = Ck / (count - 1) + } + + private def collectStatisticalData( + df: DataFrame, + cols: Seq[String], + functionName: String): CovarianceCounter = { + require( + cols.length == 2, + s"Currently $functionName calculation is supported " + + "between two columns.") + cols.map(name => (name, df.resolve(name))).foreach { + case (name, data) => + require( + data.dataType.isInstanceOf[NumericType], + s"Currently $functionName calculation " + + s"for columns with dataType ${data.dataType.catalogString} not supported." + ) + } + val columns = cols.map(n => Column(Cast(Column(n).expr, DoubleType))) + df.select(columns: _*) + .queryExecution + .toRdd + .treeAggregate(new CovarianceCounter)( + seqOp = (counter, row) => { + counter.add(row.getDouble(0), row.getDouble(1)) + }, + combOp = (baseCounter, other) => { + baseCounter.merge(other) + }) + } + + /** + * Calculate the covariance of two numerical columns of a DataFrame. + * @param df + * The DataFrame + * @param cols + * the column names + * @return + * the covariance of the two columns. + */ + def calculateCov(df: DataFrame, cols: Seq[String]): Double = { + val counts = collectStatisticalData(df, cols, "covariance") + counts.cov + } + + /** Generate a table of frequencies for the elements of two columns. */ + def crossTabulate(df: DataFrame, col1: String, col2: String): DataFrame = { + val tableName = s"${col1}_$col2" + val counts = df.groupBy(col1, col2).agg(count("*")).take(1e6.toInt) + if (counts.length == 1e6.toInt) { + logWarning( + "The maximum limit of 1e6 pairs have been collected, which may not be all of " + + "the pairs. Please try reducing the amount of distinct items in your columns.") + } + def cleanElement(element: Any): String = { + if (element == null) "null" else element.toString + } + // get the distinct sorted values of column 2, so that we can make them the column names + val distinctCol2: Map[Any, Int] = + counts.map(e => cleanElement(e.get(1))).distinct.sorted.zipWithIndex.toMap + val columnSize = distinctCol2.size + require( + columnSize < 1e4, + s"The number of distinct values for $col2, can't " + + s"exceed 1e4. Currently $columnSize") + val table = counts + .groupBy(_.get(0)) + .map { + case (col1Item, rows) => + val countsRow = new GenericInternalRow(columnSize + 1) + rows.foreach { + (row: Row) => + // row.get(0) is column 1 + // row.get(1) is column 2 + // row.get(2) is the frequency + val columnIndex = distinctCol2(cleanElement(row.get(1))) + countsRow.setLong(columnIndex + 1, row.getLong(2)) + } + // the value of col1 is the first value, the rest are the counts + countsRow.update(0, UTF8String.fromString(cleanElement(col1Item))) + countsRow + } + .toSeq + // Back ticks can't exist in DataFrame column names, therefore drop them. To be able to accept + // special keywords and `.`, wrap the column names in ``. + def cleanColumnName(name: String): String = { + name.replace("`", "") + } + // In the map, the column names (._1) are not ordered by the index (._2). This was the bug in + // SPARK-8681. We need to explicitly sort by the column index and assign the column names. + val headerNames = distinctCol2.toSeq.sortBy(_._2).map { + r => StructField(cleanColumnName(r._1.toString), LongType) + } + val schema = StructType(StructField(tableName, StringType) +: headerNames) + + Dataset.ofRows(df.sparkSession, LocalRelation(schema.toAttributes, table)).na.fill(0.0) + } + + /** Calculate selected summary statistics for a dataset */ + def summary(ds: Dataset[_], statistics: Seq[String]): DataFrame = { + + val defaultStatistics = Seq("count", "mean", "stddev", "min", "25%", "50%", "75%", "max") + val selectedStatistics = if (statistics.nonEmpty) statistics else defaultStatistics + + val percentiles = selectedStatistics.filter(a => a.endsWith("%")).map { + p => + try { + p.stripSuffix("%").toDouble / 100.0 + } catch { + case e: NumberFormatException => + throw QueryExecutionErrors.cannotParseStatisticAsPercentileError(p, e) + } + } + require(percentiles.forall(p => p >= 0 && p <= 1), "Percentiles must be in the range [0, 1]") + + def castAsDoubleIfNecessary(e: Expression): Expression = if (e.dataType == StringType) { + TryCast(e, DoubleType) + } else { + e + } + var percentileIndex = 0 + val statisticFns = selectedStatistics.map { + stats => + if (stats.endsWith("%")) { + val index = percentileIndex + percentileIndex += 1 + (child: Expression) => + GetArrayItem( + new ApproximatePercentile( + castAsDoubleIfNecessary(child), + Literal(new GenericArrayData(percentiles), ArrayType(DoubleType, false))) + .toAggregateExpression(), + Literal(index) + ) + } else { + stats.toLowerCase(Locale.ROOT) match { + case "count" => (child: Expression) => Count(child).toAggregateExpression() + case "count_distinct" => + (child: Expression) => Count(child).toAggregateExpression(isDistinct = true) + case "approx_count_distinct" => + (child: Expression) => HyperLogLogPlusPlus(child).toAggregateExpression() + case "mean" => + (child: Expression) => Average(castAsDoubleIfNecessary(child)).toAggregateExpression() + case "stddev" => + (child: Expression) => + StddevSamp(castAsDoubleIfNecessary(child)).toAggregateExpression() + case "min" => (child: Expression) => Min(child).toAggregateExpression() + case "max" => (child: Expression) => Max(child).toAggregateExpression() + case _ => throw QueryExecutionErrors.statisticNotRecognizedError(stats) + } + } + } + + val selectedCols = ds.logicalPlan.output + .filter(a => a.dataType.isInstanceOf[NumericType] || a.dataType.isInstanceOf[StringType]) + + val aggExprs = statisticFns.flatMap { + func => selectedCols.map(c => Column(Cast(func(c), StringType)).as(c.name)) + } + + // If there is no selected columns, we don't need to run this aggregate, so make it a lazy val. + lazy val aggResult = ds.select(aggExprs: _*).queryExecution.toRdd.map(_.copy()).collect().head + + // We will have one row for each selected statistic in the result. + val result = Array.fill[InternalRow](selectedStatistics.length) { + // each row has the statistic name, and statistic values of each selected column. + new GenericInternalRow(selectedCols.length + 1) + } + + var rowIndex = 0 + while (rowIndex < result.length) { + val statsName = selectedStatistics(rowIndex) + result(rowIndex).update(0, UTF8String.fromString(statsName)) + for (colIndex <- selectedCols.indices) { + val statsValue = aggResult.getUTF8String(rowIndex * selectedCols.length + colIndex) + result(rowIndex).update(colIndex + 1, statsValue) + } + rowIndex += 1 + } + + // All columns are string type + val output = AttributeReference("summary", StringType)() +: + selectedCols.map(c => AttributeReference(c.name, StringType)()) + + Dataset.ofRows(ds.sparkSession, LocalRelation(output, result)) + } +} diff --git a/shims/spark33/src/main/scala/io/glutenproject/sql/shims/spark33/Spark33Shims.scala b/shims/spark33/src/main/scala/io/glutenproject/sql/shims/spark33/Spark33Shims.scala index 06325480f2445..b593b6da70666 100644 --- a/shims/spark33/src/main/scala/io/glutenproject/sql/shims/spark33/Spark33Shims.scala +++ b/shims/spark33/src/main/scala/io/glutenproject/sql/shims/spark33/Spark33Shims.scala @@ -27,9 +27,11 @@ import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution} +import org.apache.spark.sql.connector.catalog.Table import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.execution.{FileSourceScanExec, PartitionedFileUtil, SparkPlan} import org.apache.spark.sql.execution.datasources.{BucketingUtils, FilePartition, FileScanRDD, PartitionDirectory, PartitionedFile, PartitioningAwareFileIndex} +import org.apache.spark.sql.execution.datasources.v2.BatchScanExec import org.apache.spark.sql.execution.datasources.v2.text.TextScan import org.apache.spark.sql.execution.datasources.v2.utils.CatalogUtil import org.apache.spark.sql.types.StructType @@ -113,6 +115,16 @@ class Spark33Shims extends SparkShims { } } + override def getBatchScanExecTable(batchScan: BatchScanExec): Table = null + + override def generatePartitionedFile( + partitionValues: InternalRow, + filePath: String, + start: Long, + length: Long, + @transient locations: Array[String] = Array.empty): PartitionedFile = + PartitionedFile(partitionValues, filePath, start, length, locations) + private def invalidBucketFile(path: String): Throwable = { new SparkException( errorClass = "INVALID_BUCKET_FILE", diff --git a/shims/spark33/src/main/scala/org/apache/spark/sql/catalyst/expressions/Empty2Null.scala b/shims/spark33/src/main/scala/org/apache/spark/sql/catalyst/expressions/Empty2Null.scala new file mode 100644 index 0000000000000..241159ea0e251 --- /dev/null +++ b/shims/spark33/src/main/scala/org/apache/spark/sql/catalyst/expressions/Empty2Null.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst.expressions.{Expression, String2StringExpression, UnaryExpression} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.unsafe.types.UTF8String + +/** + * A internal function that converts the empty string to null for partition values. This function + * should be only used in V1Writes. + */ +case class Empty2Null(child: Expression) extends UnaryExpression with String2StringExpression { + override def convert(v: UTF8String): UTF8String = if (v.numBytes() == 0) null else v + + override def nullable: Boolean = true + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen( + ctx, + ev, + c => { + s"""if ($c.numBytes() == 0) { + | ${ev.isNull} = true; + | ${ev.value} = null; + |} else { + | ${ev.value} = $c; + |}""".stripMargin + } + ) + } + + override protected def withNewChildInternal(newChild: Expression): Empty2Null = + copy(child = newChild) +} diff --git a/shims/spark33/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Offset.scala b/shims/spark33/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Offset.scala new file mode 100644 index 0000000000000..bc7cacf7995e6 --- /dev/null +++ b/shims/spark33/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Offset.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, IntegerLiteral} + +/** + * A logical offset, which may removing a specified number of rows from the beginning of the output + * of child logical plan. + */ +case class Offset(offsetExpr: Expression, child: LogicalPlan) extends OrderPreservingUnaryNode { + override def output: Seq[Attribute] = child.output + override def maxRows: Option[Long] = { + import scala.math.max + offsetExpr match { + case IntegerLiteral(offset) => child.maxRows.map(x => max(x - offset, 0)) + case _ => None + } + } + override protected def withNewChildInternal(newChild: LogicalPlan): Offset = + copy(child = newChild) +} diff --git a/shims/spark33/src/main/scala/org/apache/spark/sql/execution/FileSourceScanExecShim.scala b/shims/spark33/src/main/scala/org/apache/spark/sql/execution/FileSourceScanExecShim.scala index 88bd259c6e104..d86d5f4668b9f 100644 --- a/shims/spark33/src/main/scala/org/apache/spark/sql/execution/FileSourceScanExecShim.scala +++ b/shims/spark33/src/main/scala/org/apache/spark/sql/execution/FileSourceScanExecShim.scala @@ -17,13 +17,15 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} -import org.apache.spark.sql.execution.datasources.HadoopFsRelation +import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeReference, BoundReference, Expression, PlanExpression, Predicate} +import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, PartitionDirectory} import org.apache.spark.sql.execution.datasources.parquet.ParquetUtils -import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.types.StructType import org.apache.spark.util.collection.BitSet +import scala.collection.mutable + class FileSourceScanExecShim( @transient relation: HadoopFsRelation, output: Seq[Attribute], @@ -61,4 +63,70 @@ class FileSourceScanExecShim( def hasMetadataColumns: Boolean = metadataColumns.nonEmpty def hasFieldIds: Boolean = ParquetUtils.hasFieldIds(requiredSchema) + + // The codes below are copied from FileSourceExec in Spark, + // all of them are private. + private lazy val driverMetrics: mutable.HashMap[String, Long] = mutable.HashMap.empty + + /** + * Send the driver-side metrics. Before calling this function, selectedPartitions has been + * initialized. See SPARK-26327 for more details. + */ + private def sendDriverMetrics(): Unit = { + driverMetrics.foreach(e => metrics(e._1).add(e._2)) + val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + SQLMetrics.postDriverMetricUpdates( + sparkContext, + executionId, + metrics.filter(e => driverMetrics.contains(e._1)).values.toSeq) + } + + private def isDynamicPruningFilter(e: Expression): Boolean = + e.find(_.isInstanceOf[PlanExpression[_]]).isDefined + + private def setFilesNumAndSizeMetric( + partitions: Seq[PartitionDirectory], + static: Boolean): Unit = { + val filesNum = partitions.map(_.files.size.toLong).sum + val filesSize = partitions.map(_.files.map(_.getLen).sum).sum + if (!static || !partitionFilters.exists(isDynamicPruningFilter)) { + driverMetrics("numFiles") = filesNum + driverMetrics("filesSize") = filesSize + } else { + driverMetrics("staticFilesNum") = filesNum + driverMetrics("staticFilesSize") = filesSize + } + if (relation.partitionSchema.nonEmpty) { + driverMetrics("numPartitions") = partitions.length + } + } + + // We can only determine the actual partitions at runtime when a dynamic partition filter is + // present. This is because such a filter relies on information that is only available at run + // time (for instance the keys used in the other side of a join). + @transient lazy val dynamicallySelectedPartitionsAlias: Array[PartitionDirectory] = { + val dynamicPartitionFilters = partitionFilters.filter(isDynamicPruningFilter) + + if (dynamicPartitionFilters.nonEmpty) { + val startTime = System.nanoTime() + // call the file index for the files matching all filters except dynamic partition filters + val predicate = dynamicPartitionFilters.reduce(And) + val partitionColumns = relation.partitionSchema + val boundPredicate = Predicate.create( + predicate.transform { + case a: AttributeReference => + val index = partitionColumns.indexWhere(a.name == _.name) + BoundReference(index, partitionColumns(index).dataType, nullable = true) + }, + Nil + ) + val ret = selectedPartitions.filter(p => boundPredicate.eval(p.values)) + setFilesNumAndSizeMetric(ret, false) + val timeTakenMs = (System.nanoTime() - startTime) / 1000 / 1000 + driverMetrics("pruningTime") = timeTakenMs + ret + } else { + selectedPartitions + } + } } diff --git a/shims/spark33/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExecShim.scala b/shims/spark33/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExecShim.scala index 0fe83094a780e..331e16df380a2 100644 --- a/shims/spark33/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExecShim.scala +++ b/shims/spark33/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExecShim.scala @@ -23,6 +23,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.KeyGroupedPartitioning import org.apache.spark.sql.catalyst.util.InternalRowSet +import org.apache.spark.sql.connector.catalog.Table import org.apache.spark.sql.connector.expressions.aggregate.Aggregation import org.apache.spark.sql.connector.read.{HasPartitionKey, InputPartition, Scan, SupportsRuntimeFiltering} import org.apache.spark.sql.execution.datasources.DataSourceStrategy @@ -34,7 +35,8 @@ import org.apache.spark.sql.vectorized.ColumnarBatch class BatchScanExecShim( output: Seq[AttributeReference], @transient scan: Scan, - runtimeFilters: Seq[Expression]) + runtimeFilters: Seq[Expression], + @transient table: Table) extends BatchScanExec(output, scan, runtimeFilters) { // Note: "metrics" is made transient to avoid sending driver-side metrics to tasks. diff --git a/shims/spark33/src/main/scala/org/apache/spark/sql/execution/datasources/v2/velox/DwrfScan.scala b/shims/spark33/src/main/scala/org/apache/spark/sql/execution/datasources/v2/velox/DwrfScan.scala deleted file mode 100644 index 6536f80814744..0000000000000 --- a/shims/spark33/src/main/scala/org/apache/spark/sql/execution/datasources/v2/velox/DwrfScan.scala +++ /dev/null @@ -1,43 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.execution.datasources.v2.velox - -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.sql.connector.read.PartitionReaderFactory -import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex -import org.apache.spark.sql.execution.datasources.v2.FileScan -import org.apache.spark.sql.sources.Filter -import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.util.CaseInsensitiveStringMap - -case class DwrfScan( - sparkSession: SparkSession, - fileIndex: PartitioningAwareFileIndex, - readDataSchema: StructType, - readPartitionSchema: StructType, - pushedFilters: Array[Filter], - options: CaseInsensitiveStringMap, - partitionFilters: Seq[Expression] = Seq.empty, - dataFilters: Seq[Expression] = Seq.empty) - extends FileScan { - override def createReaderFactory(): PartitionReaderFactory = { - null - } - - override def dataSchema: StructType = readDataSchema -} diff --git a/shims/spark33/src/main/scala/org/apache/spark/sql/execution/datasources/v2/velox/DwrfScanBuilder.scala b/shims/spark33/src/main/scala/org/apache/spark/sql/execution/datasources/v2/velox/DwrfScanBuilder.scala deleted file mode 100644 index 475b18b68531b..0000000000000 --- a/shims/spark33/src/main/scala/org/apache/spark/sql/execution/datasources/v2/velox/DwrfScanBuilder.scala +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.execution.datasources.v2.velox - -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.connector.read.Scan -import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex -import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder -import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.util.CaseInsensitiveStringMap - -case class DwrfScanBuilder( - sparkSession: SparkSession, - fileIndex: PartitioningAwareFileIndex, - schema: StructType, - dataSchema: StructType, - options: CaseInsensitiveStringMap) - extends FileScanBuilder(sparkSession, fileIndex, dataSchema) { - - override def build(): Scan = { - DwrfScan( - sparkSession, - fileIndex, - readDataSchema(), - readPartitionSchema(), - pushedDataFilters, - options) - } - -} diff --git a/shims/spark33/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/shims/spark33/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala new file mode 100644 index 0000000000000..08ba7680ca701 --- /dev/null +++ b/shims/spark33/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -0,0 +1,364 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.stat + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.{Column, DataFrame, Dataset, Row} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Cast, Expression, GenericInternalRow, GetArrayItem, Literal, TryCast} +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.plans.logical.LocalRelation +import org.apache.spark.sql.catalyst.util.{GenericArrayData, QuantileSummaries} +import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.sql.functions.count +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +import java.util.Locale + +/** + * This file is copied from Spark + * + * The df.describe() and df.summary() issues are fixed by + * https://github.com/apache/spark/pull/40914. We picked it into Gluten to fix the describe and + * summary issue. And this file can be removed after upgrading spark version to 3.4 or higher + * version. + */ +object StatFunctions extends Logging { + + /** + * Calculates the approximate quantiles of multiple numerical columns of a DataFrame in one pass. + * + * The result of this algorithm has the following deterministic bound: If the DataFrame has N + * elements and if we request the quantile at probability `p` up to error `err`, then the + * algorithm will return a sample `x` from the DataFrame so that the *exact* rank of `x` is close + * to (p * N). More precisely, + * + * floor((p - err) * N) <= rank(x) <= ceil((p + err) * N). + * + * This method implements a variation of the Greenwald-Khanna algorithm (with some speed + * optimizations). The algorithm was first present in Space-efficient Online Computation of Quantile + * Summaries by Greenwald and Khanna. + * + * @param df + * the dataframe + * @param cols + * numerical columns of the dataframe + * @param probabilities + * a list of quantile probabilities Each number must belong to [0, 1]. For example 0 is the + * minimum, 0.5 is the median, 1 is the maximum. + * @param relativeError + * The relative target precision to achieve (greater than or equal 0). If set to zero, the exact + * quantiles are computed, which could be very expensive. Note that values greater than 1 are + * accepted but give the same result as 1. + * @return + * for each column, returns the requested approximations + * @note + * null and NaN values will be ignored in numerical columns before calculation. For a column + * only containing null or NaN values, an empty array is returned. + */ + def multipleApproxQuantiles( + df: DataFrame, + cols: Seq[String], + probabilities: Seq[Double], + relativeError: Double): Seq[Seq[Double]] = { + require(relativeError >= 0, s"Relative Error must be non-negative but got $relativeError") + val columns: Seq[Column] = cols.map { + colName => + val field = df.resolve(colName) + require( + field.dataType.isInstanceOf[NumericType], + s"Quantile calculation for column $colName with data type ${field.dataType}" + + " is not supported.") + Column(Cast(Column(colName).expr, DoubleType)) + } + val emptySummaries = Array.fill(cols.size)( + new QuantileSummaries(QuantileSummaries.defaultCompressThreshold, relativeError)) + + // Note that it works more or less by accident as `rdd.aggregate` is not a pure function: + // this function returns the same array as given in the input (because `aggregate` reuses + // the same argument). + def apply(summaries: Array[QuantileSummaries], row: Row): Array[QuantileSummaries] = { + var i = 0 + while (i < summaries.length) { + if (!row.isNullAt(i)) { + val v = row.getDouble(i) + if (!v.isNaN) summaries(i) = summaries(i).insert(v) + } + i += 1 + } + summaries + } + + def merge( + sum1: Array[QuantileSummaries], + sum2: Array[QuantileSummaries]): Array[QuantileSummaries] = { + sum1.zip(sum2).map { case (s1, s2) => s1.compress().merge(s2.compress()) } + } + + val summaries = df.select(columns: _*).rdd.treeAggregate(emptySummaries)(apply, merge) + + summaries.map { + summary => + summary.query(probabilities) match { + case Some(q) => q + case None => Seq() + } + } + } + + /** Calculate the Pearson Correlation Coefficient for the given columns */ + def pearsonCorrelation(df: DataFrame, cols: Seq[String]): Double = { + val counts = collectStatisticalData(df, cols, "correlation") + counts.Ck / math.sqrt(counts.MkX * counts.MkY) + } + + /** Helper class to simplify tracking and merging counts. */ + private class CovarianceCounter extends Serializable { + var xAvg = 0.0 // the mean of all examples seen so far in col1 + var yAvg = 0.0 // the mean of all examples seen so far in col2 + var Ck = 0.0 // the co-moment after k examples + var MkX = 0.0 // sum of squares of differences from the (current) mean for col1 + var MkY = 0.0 // sum of squares of differences from the (current) mean for col2 + var count = 0L // count of observed examples + + // add an example to the calculation + def add(x: Double, y: Double): this.type = { + val deltaX = x - xAvg + val deltaY = y - yAvg + count += 1 + xAvg += deltaX / count + yAvg += deltaY / count + Ck += deltaX * (y - yAvg) + MkX += deltaX * (x - xAvg) + MkY += deltaY * (y - yAvg) + this + } + + // merge counters from other partitions. Formula can be found at: + // http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance + def merge(other: CovarianceCounter): this.type = { + if (other.count > 0) { + val totalCount = count + other.count + val deltaX = xAvg - other.xAvg + val deltaY = yAvg - other.yAvg + Ck += other.Ck + deltaX * deltaY * count / totalCount * other.count + xAvg = (xAvg * count + other.xAvg * other.count) / totalCount + yAvg = (yAvg * count + other.yAvg * other.count) / totalCount + MkX += other.MkX + deltaX * deltaX * count / totalCount * other.count + MkY += other.MkY + deltaY * deltaY * count / totalCount * other.count + count = totalCount + } + this + } + + // return the sample covariance for the observed examples + def cov: Double = Ck / (count - 1) + } + + private def collectStatisticalData( + df: DataFrame, + cols: Seq[String], + functionName: String): CovarianceCounter = { + require( + cols.length == 2, + s"Currently $functionName calculation is supported " + + "between two columns.") + cols.map(name => (name, df.resolve(name))).foreach { + case (name, data) => + require( + data.dataType.isInstanceOf[NumericType], + s"Currently $functionName calculation " + + s"for columns with dataType ${data.dataType.catalogString} not supported." + ) + } + val columns = cols.map(n => Column(Cast(Column(n).expr, DoubleType))) + df.select(columns: _*) + .queryExecution + .toRdd + .treeAggregate(new CovarianceCounter)( + seqOp = (counter, row) => { + counter.add(row.getDouble(0), row.getDouble(1)) + }, + combOp = (baseCounter, other) => { + baseCounter.merge(other) + }) + } + + /** + * Calculate the covariance of two numerical columns of a DataFrame. + * + * @param df + * The DataFrame + * @param cols + * the column names + * @return + * the covariance of the two columns. + */ + def calculateCov(df: DataFrame, cols: Seq[String]): Double = { + val counts = collectStatisticalData(df, cols, "covariance") + counts.cov + } + + /** Generate a table of frequencies for the elements of two columns. */ + def crossTabulate(df: DataFrame, col1: String, col2: String): DataFrame = { + val tableName = s"${col1}_$col2" + val counts = df.groupBy(col1, col2).agg(count("*")).take(1e6.toInt) + if (counts.length == 1e6.toInt) { + logWarning( + "The maximum limit of 1e6 pairs have been collected, which may not be all of " + + "the pairs. Please try reducing the amount of distinct items in your columns.") + } + + def cleanElement(element: Any): String = { + if (element == null) "null" else element.toString + } + + // get the distinct sorted values of column 2, so that we can make them the column names + val distinctCol2: Map[Any, Int] = + counts.map(e => cleanElement(e.get(1))).distinct.sorted.zipWithIndex.toMap + val columnSize = distinctCol2.size + require( + columnSize < 1e4, + s"The number of distinct values for $col2, can't " + + s"exceed 1e4. Currently $columnSize") + val table = counts + .groupBy(_.get(0)) + .map { + case (col1Item, rows) => + val countsRow = new GenericInternalRow(columnSize + 1) + rows.foreach { + (row: Row) => + // row.get(0) is column 1 + // row.get(1) is column 2 + // row.get(2) is the frequency + val columnIndex = distinctCol2(cleanElement(row.get(1))) + countsRow.setLong(columnIndex + 1, row.getLong(2)) + } + // the value of col1 is the first value, the rest are the counts + countsRow.update(0, UTF8String.fromString(cleanElement(col1Item))) + countsRow + } + .toSeq + + // Back ticks can't exist in DataFrame column names, therefore drop them. To be able to accept + // special keywords and `.`, wrap the column names in ``. + def cleanColumnName(name: String): String = { + name.replace("`", "") + } + + // In the map, the column names (._1) are not ordered by the index (._2). This was the bug in + // SPARK-8681. We need to explicitly sort by the column index and assign the column names. + val headerNames = distinctCol2.toSeq.sortBy(_._2).map { + r => StructField(cleanColumnName(r._1.toString), LongType) + } + val schema = StructType(StructField(tableName, StringType) +: headerNames) + + Dataset.ofRows(df.sparkSession, LocalRelation(schema.toAttributes, table)).na.fill(0.0) + } + + /** Calculate selected summary statistics for a dataset */ + def summary(ds: Dataset[_], statistics: Seq[String]): DataFrame = { + + val defaultStatistics = Seq("count", "mean", "stddev", "min", "25%", "50%", "75%", "max") + val selectedStatistics = if (statistics.nonEmpty) statistics else defaultStatistics + + val percentiles = selectedStatistics.filter(a => a.endsWith("%")).map { + p => + try { + p.stripSuffix("%").toDouble / 100.0 + } catch { + case e: NumberFormatException => + throw QueryExecutionErrors.cannotParseStatisticAsPercentileError(p, e) + } + } + require(percentiles.forall(p => p >= 0 && p <= 1), "Percentiles must be in the range [0, 1]") + + def castAsDoubleIfNecessary(e: Expression): Expression = if (e.dataType == StringType) { + TryCast(e, DoubleType) + } else { + e + } + + var percentileIndex = 0 + val statisticFns = selectedStatistics.map { + stats => + if (stats.endsWith("%")) { + val index = percentileIndex + percentileIndex += 1 + (child: Expression) => + GetArrayItem( + new ApproximatePercentile( + castAsDoubleIfNecessary(child), + Literal(new GenericArrayData(percentiles), ArrayType(DoubleType, false))) + .toAggregateExpression(), + Literal(index) + ) + } else { + stats.toLowerCase(Locale.ROOT) match { + case "count" => (child: Expression) => Count(child).toAggregateExpression() + case "count_distinct" => + (child: Expression) => Count(child).toAggregateExpression(isDistinct = true) + case "approx_count_distinct" => + (child: Expression) => HyperLogLogPlusPlus(child).toAggregateExpression() + case "mean" => + (child: Expression) => Average(castAsDoubleIfNecessary(child)).toAggregateExpression() + case "stddev" => + (child: Expression) => + StddevSamp(castAsDoubleIfNecessary(child)).toAggregateExpression() + case "min" => (child: Expression) => Min(child).toAggregateExpression() + case "max" => (child: Expression) => Max(child).toAggregateExpression() + case _ => throw QueryExecutionErrors.statisticNotRecognizedError(stats) + } + } + } + + val selectedCols = ds.logicalPlan.output + .filter(a => a.dataType.isInstanceOf[NumericType] || a.dataType.isInstanceOf[StringType]) + + val aggExprs = statisticFns.flatMap { + func => selectedCols.map(c => Column(Cast(func(c), StringType)).as(c.name)) + } + + // If there is no selected columns, we don't need to run this aggregate, so make it a lazy val. + lazy val aggResult = ds.select(aggExprs: _*).queryExecution.toRdd.map(_.copy()).collect().head + + // We will have one row for each selected statistic in the result. + val result = Array.fill[InternalRow](selectedStatistics.length) { + // each row has the statistic name, and statistic values of each selected column. + new GenericInternalRow(selectedCols.length + 1) + } + + var rowIndex = 0 + while (rowIndex < result.length) { + val statsName = selectedStatistics(rowIndex) + result(rowIndex).update(0, UTF8String.fromString(statsName)) + for (colIndex <- selectedCols.indices) { + val statsValue = aggResult.getUTF8String(rowIndex * selectedCols.length + colIndex) + result(rowIndex).update(colIndex + 1, statsValue) + } + rowIndex += 1 + } + + // All columns are string type + val output = AttributeReference("summary", StringType)() +: + selectedCols.map(c => AttributeReference(c.name, StringType)()) + + Dataset.ofRows(ds.sparkSession, LocalRelation(output, result)) + } +} diff --git a/shims/spark34/src/main/scala/io/glutenproject/sql/shims/spark34/Spark34Shims.scala b/shims/spark34/src/main/scala/io/glutenproject/sql/shims/spark34/Spark34Shims.scala index 59cea64751a4b..fbaf04cbc7c28 100644 --- a/shims/spark34/src/main/scala/io/glutenproject/sql/shims/spark34/Spark34Shims.scala +++ b/shims/spark34/src/main/scala/io/glutenproject/sql/shims/spark34/Spark34Shims.scala @@ -21,16 +21,19 @@ import io.glutenproject.expression.{ExpressionNames, Sig} import io.glutenproject.sql.shims.{ShimDescriptor, SparkShims} import org.apache.spark.SparkException +import org.apache.spark.paths.SparkPath import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution} +import org.apache.spark.sql.connector.catalog.Table import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.execution.{FileSourceScanLike, PartitionedFileUtil, SparkPlan} import org.apache.spark.sql.execution.FileSourceScanExec import org.apache.spark.sql.execution.datasources.{BucketingUtils, FilePartition, FileScanRDD, PartitionDirectory, PartitionedFile, PartitioningAwareFileIndex} +import org.apache.spark.sql.execution.datasources.v2.BatchScanExec import org.apache.spark.sql.execution.datasources.v2.text.TextScan import org.apache.spark.sql.execution.datasources.v2.utils.CatalogUtil import org.apache.spark.sql.types.StructType @@ -114,6 +117,16 @@ class Spark34Shims extends SparkShims { } } + override def getBatchScanExecTable(batchScan: BatchScanExec): Table = batchScan.table + + override def generatePartitionedFile( + partitionValues: InternalRow, + filePath: String, + start: Long, + length: Long, + @transient locations: Array[String] = Array.empty): PartitionedFile = + PartitionedFile(partitionValues, SparkPath.fromPathString(filePath), start, length, locations) + private def invalidBucketFile(path: String): Throwable = { new SparkException( errorClass = "INVALID_BUCKET_FILE", diff --git a/shims/spark34/src/main/scala/org/apache/spark/sql/execution/FileSourceScanExecShim.scala b/shims/spark34/src/main/scala/org/apache/spark/sql/execution/FileSourceScanExecShim.scala index 8c4de5cb1f07e..a47c44c7a7a4c 100644 --- a/shims/spark34/src/main/scala/org/apache/spark/sql/execution/FileSourceScanExecShim.scala +++ b/shims/spark34/src/main/scala/org/apache/spark/sql/execution/FileSourceScanExecShim.scala @@ -17,8 +17,8 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} -import org.apache.spark.sql.execution.datasources.HadoopFsRelation +import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeReference, BoundReference, Expression, Predicate} +import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, PartitionDirectory} import org.apache.spark.sql.execution.datasources.parquet.ParquetUtils import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.types.StructType @@ -61,4 +61,7 @@ class FileSourceScanExecShim( def hasMetadataColumns: Boolean = fileConstantMetadataColumns.nonEmpty def hasFieldIds: Boolean = ParquetUtils.hasFieldIds(requiredSchema) + + @transient protected lazy val dynamicallySelectedPartitionsAlias: Array[PartitionDirectory] = + dynamicallySelectedPartitions } diff --git a/shims/spark34/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExecShim.scala b/shims/spark34/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExecShim.scala index 68ea957c1ef66..e8e1b090a81ae 100644 --- a/shims/spark34/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExecShim.scala +++ b/shims/spark34/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExecShim.scala @@ -40,22 +40,8 @@ class BatchScanExecShim( output: Seq[AttributeReference], @transient scan: Scan, runtimeFilters: Seq[Expression], - keyGroupedPartitioning: Option[Seq[Expression]], - ordering: Option[Seq[SortOrder]], - @transient table: Table, - commonPartitionValues: Option[Seq[(InternalRow, Int)]], - applyPartialClustering: Boolean, - replicatePartitions: Boolean) - extends BatchScanExec( - output, - scan, - runtimeFilters, - keyGroupedPartitioning, - ordering, - table, - commonPartitionValues, - applyPartialClustering, - replicatePartitions) { + @transient table: Table) + extends BatchScanExec(output, scan, runtimeFilters, table = table) { // Note: "metrics" is made transient to avoid sending driver-side metrics to tasks. @transient override lazy val metrics: Map[String, SQLMetric] = Map() diff --git a/shims/spark34/src/main/scala/org/apache/spark/sql/execution/datasources/v2/Spark33Scan.scala b/shims/spark34/src/main/scala/org/apache/spark/sql/execution/datasources/v2/Spark34Scan.scala similarity index 100% rename from shims/spark34/src/main/scala/org/apache/spark/sql/execution/datasources/v2/Spark33Scan.scala rename to shims/spark34/src/main/scala/org/apache/spark/sql/execution/datasources/v2/Spark34Scan.scala diff --git a/shims/spark34/src/main/scala/org/apache/spark/sql/execution/datasources/v2/velox/DwrfScan.scala b/shims/spark34/src/main/scala/org/apache/spark/sql/execution/datasources/v2/velox/DwrfScan.scala deleted file mode 100644 index 6536f80814744..0000000000000 --- a/shims/spark34/src/main/scala/org/apache/spark/sql/execution/datasources/v2/velox/DwrfScan.scala +++ /dev/null @@ -1,43 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.execution.datasources.v2.velox - -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.sql.connector.read.PartitionReaderFactory -import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex -import org.apache.spark.sql.execution.datasources.v2.FileScan -import org.apache.spark.sql.sources.Filter -import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.util.CaseInsensitiveStringMap - -case class DwrfScan( - sparkSession: SparkSession, - fileIndex: PartitioningAwareFileIndex, - readDataSchema: StructType, - readPartitionSchema: StructType, - pushedFilters: Array[Filter], - options: CaseInsensitiveStringMap, - partitionFilters: Seq[Expression] = Seq.empty, - dataFilters: Seq[Expression] = Seq.empty) - extends FileScan { - override def createReaderFactory(): PartitionReaderFactory = { - null - } - - override def dataSchema: StructType = readDataSchema -} diff --git a/shims/spark34/src/main/scala/org/apache/spark/sql/execution/datasources/v2/velox/DwrfScanBuilder.scala b/shims/spark34/src/main/scala/org/apache/spark/sql/execution/datasources/v2/velox/DwrfScanBuilder.scala deleted file mode 100644 index 475b18b68531b..0000000000000 --- a/shims/spark34/src/main/scala/org/apache/spark/sql/execution/datasources/v2/velox/DwrfScanBuilder.scala +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.execution.datasources.v2.velox - -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.connector.read.Scan -import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex -import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder -import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.util.CaseInsensitiveStringMap - -case class DwrfScanBuilder( - sparkSession: SparkSession, - fileIndex: PartitioningAwareFileIndex, - schema: StructType, - dataSchema: StructType, - options: CaseInsensitiveStringMap) - extends FileScanBuilder(sparkSession, fileIndex, dataSchema) { - - override def build(): Scan = { - DwrfScan( - sparkSession, - fileIndex, - readDataSchema(), - readPartitionSchema(), - pushedDataFilters, - options) - } - -} diff --git a/substrait/substrait-spark/pom.xml b/substrait/substrait-spark/pom.xml index 5a7f37cb0ee50..51fe2fd5b1084 100644 --- a/substrait/substrait-spark/pom.xml +++ b/substrait/substrait-spark/pom.xml @@ -15,6 +15,12 @@ Gluten Substrait Spark + + io.glutenproject + ${sparkshim.artifactId} + ${project.version} + compile + org.apache.spark spark-sql_${scala.binary.version} diff --git a/substrait/substrait-spark/src/main/scala/io/substrait/spark/expression/ToSubstraitExpression.scala b/substrait/substrait-spark/src/main/scala/io/substrait/spark/expression/ToSubstraitExpression.scala index 2feedf0f7efff..2789d84f8f8e8 100644 --- a/substrait/substrait-spark/src/main/scala/io/substrait/spark/expression/ToSubstraitExpression.scala +++ b/substrait/substrait-spark/src/main/scala/io/substrait/spark/expression/ToSubstraitExpression.scala @@ -134,7 +134,7 @@ abstract class ToSubstraitExpression extends HasOutputStack[Seq[Attribute]] { case SubstraitLiteral(substraitLiteral) => Some(substraitLiteral) case a: AttributeReference if currentOutput.nonEmpty => translateAttribute(a) case a: Alias => translateUp(a.child) -// case p: PromotePrecision => translateUp(p.child) + case p: PromotePrecision => translateUp(p.child) case CaseWhen(branches, elseValue) => translateCaseWhen(branches, elseValue) case scalar @ ScalarFunction(children) => Util diff --git a/substrait/substrait-spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala b/substrait/substrait-spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala index 4d06687cf741a..f8cf3767938a6 100644 --- a/substrait/substrait-spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala +++ b/substrait/substrait-spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala @@ -255,6 +255,7 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging { throw new UnsupportedOperationException( s"Unable to convert the plan to a substrait plan: $plan") } + private def toExpression(output: Seq[Attribute])(e: Expression): SExpression = { toSubstraitExp(e, output) } @@ -335,9 +336,8 @@ private[logical] class WithLogicalSubQuery(toSubstraitRel: ToSubstraitRel) override protected def translateSubQuery(expr: PlanExpression[_]): Option[SExpression] = { expr match { - case s @ ScalarSubquery(childPlan, outerAttrs, _, joinCond, _, _) - if outerAttrs.isEmpty && joinCond.isEmpty => - val rel = toSubstraitRel.visit(childPlan) + case s: ScalarSubquery if s.outerAttrs.isEmpty && s.joinCond.isEmpty => + val rel = toSubstraitRel.visit(s.plan) Some( SExpression.ScalarSubquery.builder .input(rel) diff --git a/substrait/substrait-spark/src/main/spark-3.2/io/substrait/spark/logical/AbstractLogicalPlanVisitor.scala b/substrait/substrait-spark/src/main/spark-3.2/io/substrait/spark/logical/AbstractLogicalPlanVisitor.scala index 836a087f1f53c..09b3ecc426c69 100644 --- a/substrait/substrait-spark/src/main/spark-3.2/io/substrait/spark/logical/AbstractLogicalPlanVisitor.scala +++ b/substrait/substrait-spark/src/main/spark-3.2/io/substrait/spark/logical/AbstractLogicalPlanVisitor.scala @@ -69,4 +69,6 @@ class AbstractLogicalPlanVisitor extends LogicalPlanVisitor[relation.Rel] { override def visitSort(sort: Sort): Rel = t(sort) override def visitWithCTE(p: WithCTE): Rel = t(p) + + def visitOffset(p: Offset): Rel = t(p) } diff --git a/substrait/substrait-spark/src/main/spark-3.3/io/substrait/spark/logical/AbstractLogicalPlanVisitor.scala b/substrait/substrait-spark/src/main/spark-3.3/io/substrait/spark/logical/AbstractLogicalPlanVisitor.scala index 345cb215f4ac9..081d6f93f5453 100644 --- a/substrait/substrait-spark/src/main/spark-3.3/io/substrait/spark/logical/AbstractLogicalPlanVisitor.scala +++ b/substrait/substrait-spark/src/main/spark-3.3/io/substrait/spark/logical/AbstractLogicalPlanVisitor.scala @@ -70,5 +70,7 @@ class AbstractLogicalPlanVisitor extends LogicalPlanVisitor[relation.Rel] { override def visitWithCTE(p: WithCTE): Rel = t(p) + def visitOffset(p: Offset): Rel = t(p) + override def visitRebalancePartitions(p: RebalancePartitions): Rel = t(p) } diff --git a/substrait/substrait-spark/src/main/spark-3.4/io/substrait/spark/logical/AbstractLogicalPlanVisitor.scala b/substrait/substrait-spark/src/main/spark-3.4/io/substrait/spark/logical/AbstractLogicalPlanVisitor.scala index 8962190171d62..ec3ee78e8c47c 100644 --- a/substrait/substrait-spark/src/main/spark-3.4/io/substrait/spark/logical/AbstractLogicalPlanVisitor.scala +++ b/substrait/substrait-spark/src/main/spark-3.4/io/substrait/spark/logical/AbstractLogicalPlanVisitor.scala @@ -71,5 +71,6 @@ class AbstractLogicalPlanVisitor extends LogicalPlanVisitor[relation.Rel] { override def visitWithCTE(p: WithCTE): Rel = t(p) override def visitOffset(p: Offset): Rel = t(p) + override def visitRebalancePartitions(p: RebalancePartitions): Rel = t(p) }