Support spark 3.4 shim layer in gluten
JkSelf committed Oct 10, 2023
1 parent 168d16b commit 44f86e3
Showing 29 changed files with 1,153 additions and 186 deletions.
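Most of the changes below follow one pattern: code that previously touched Spark APIs whose shape differs between Spark versions (the table constructor parameter of BatchScanExec, the PartitionedFile constructor) now goes through io.glutenproject.sql.shims.SparkShimLoader. The sketch below is inferred purely from the call sites visible in this commit; the method names match what the diff uses, but the trait layout and parameter types are assumptions, not the repository's actual shim definitions.

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.catalog.Table
import org.apache.spark.sql.execution.datasources.PartitionedFile
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec

// Assumed shape of the version-agnostic shim interface behind the call sites in this commit.
trait SparkShims {
  // BatchScanExec only carries a `table` constructor parameter on some Spark versions,
  // so callers fetch it through the shim instead of reading batchScan.table directly.
  def getBatchScanExecTable(batchScan: BatchScanExec): Table

  // PartitionedFile's path argument is a plain String on older Spark versions and a SparkPath
  // on newer ones, so instances are built through the shim (see the SoftAffinitySuite hunks below).
  def generatePartitionedFile(
      partitionValues: InternalRow,
      filePath: String,
      start: Long,
      length: Long,
      locations: Array[String]): PartitionedFile
}

With something like this in place, call sites such as FilterHandler, TransformPreOverrides, and AddTransformHintRule only need to pass table = SparkShimLoader.getSparkShims.getBatchScanExecTable(plan), as the hunks below show.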
@@ -22,6 +22,7 @@ import io.glutenproject.expression.{ConverterUtils, ExpressionConverter, Express
import io.glutenproject.extension.{GlutenPlan, ValidationResult}
import io.glutenproject.extension.columnar.TransformHints
import io.glutenproject.metrics.MetricsUpdater
import io.glutenproject.sql.shims.SparkShimLoader
import io.glutenproject.substrait.`type`.{TypeBuilder, TypeNode}
import io.glutenproject.substrait.SubstraitContext
import io.glutenproject.substrait.expression.ExpressionNode
@@ -562,7 +563,7 @@ object FilterHandler {
batchScan.output,
scan,
leftFilters ++ newPartitionFilters,
batchScan.table)
table = SparkShimLoader.getSparkShims.getBatchScanExecTable(batchScan))
case _ =>
if (batchScan.runtimeFilters.isEmpty) {
throw new UnsupportedOperationException(
@@ -20,6 +20,7 @@ import io.glutenproject.GlutenConfig
import io.glutenproject.backendsapi.BackendsApiManager
import io.glutenproject.extension.ValidationResult
import io.glutenproject.metrics.MetricsUpdater
import io.glutenproject.sql.shims.SparkShimLoader
import io.glutenproject.substrait.rel.LocalFilesNode.ReadFileFormat

import org.apache.spark.rdd.RDD
@@ -45,22 +46,13 @@ class BatchScanExecTransformer(
output: Seq[AttributeReference],
@transient scan: Scan,
runtimeFilters: Seq[Expression],
@transient table: Table,
keyGroupedPartitioning: Option[Seq[Expression]] = None,
ordering: Option[Seq[SortOrder]] = None,
@transient table: Table,
commonPartitionValues: Option[Seq[(InternalRow, Int)]] = None,
applyPartialClustering: Boolean = false,
replicatePartitions: Boolean = false)
extends BatchScanExecShim(
output,
scan,
runtimeFilters,
keyGroupedPartitioning,
ordering,
table,
commonPartitionValues,
applyPartialClustering,
replicatePartitions)
extends BatchScanExecShim(output, scan, runtimeFilters, table)
with BasicScanExecTransformer {

// Note: "metrics" is made transient to avoid sending driver-side metrics to tasks.
@@ -154,7 +146,7 @@ class BatchScanExecTransformer(
canonicalized.output,
canonicalized.scan,
canonicalized.runtimeFilters,
canonicalized.table
table = SparkShimLoader.getSparkShims.getBatchScanExecTable(canonicalized)
)
}
}
@@ -20,24 +20,22 @@ import io.glutenproject.GlutenConfig
import io.glutenproject.backendsapi.BackendsApiManager
import io.glutenproject.expression.ConverterUtils
import io.glutenproject.extension.ValidationResult
import io.glutenproject.metrics.{GlutenTimeMetric, MetricsUpdater}
import io.glutenproject.metrics.MetricsUpdater
import io.glutenproject.substrait.SubstraitContext
import io.glutenproject.substrait.rel.LocalFilesNode.ReadFileFormat
import io.glutenproject.substrait.rel.ReadRelNode

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeReference, BoundReference, DynamicPruningExpression, Expression, PlanExpression, Predicate}
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, PlanExpression}
import org.apache.spark.sql.connector.read.InputPartition
import org.apache.spark.sql.execution.{FileSourceScanExecShim, InSubqueryExec, ScalarSubquery, SparkPlan, SQLExecution}
import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, PartitionDirectory}
import org.apache.spark.sql.execution.{FileSourceScanExecShim, SparkPlan}
import org.apache.spark.sql.execution.datasources.HadoopFsRelation
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.collection.BitSet

import java.util.concurrent.TimeUnit.NANOSECONDS

import scala.collection.JavaConverters

class FileSourceScanExecTransformer(
@@ -65,10 +63,10 @@ class FileSourceScanExecTransformer(
// Note: "metrics" is made transient to avoid sending driver-side metrics to tasks.
@transient override lazy val metrics: Map[String, SQLMetric] =
BackendsApiManager.getMetricsApiInstance
.genFileSourceScanTransformerMetrics(sparkContext) ++ staticMetrics
.genFileSourceScanTransformerMetrics(sparkContext) ++ staticMetricsAlias

/** SQL metrics generated only for scans using dynamic partition pruning. */
override protected lazy val staticMetrics =
private lazy val staticMetricsAlias =
if (partitionFilters.exists(FileSourceScanExecTransformer.isDynamicPruningFilter)) {
Map(
"staticFilesNum" -> SQLMetrics.createMetric(sparkContext, "static number of files read"),
@@ -93,7 +91,7 @@
override def getPartitions: Seq[InputPartition] =
BackendsApiManager.getTransformerApiInstance.genInputPartitionSeq(
relation,
dynamicallySelectedPartitions,
dynamicallySelectedPartitionsAlias,
output,
optionalBucketSet,
optionalNumCoalescedBuckets,
@@ -160,91 +158,6 @@
override def metricsUpdater(): MetricsUpdater =
BackendsApiManager.getMetricsApiInstance.genFileSourceScanTransformerMetricsUpdater(metrics)

// The codes below are copied from FileSourceScanExec in Spark,
// all of them are private.

/**
* Send the driver-side metrics. Before calling this function, selectedPartitions has been
* initialized. See SPARK-26327 for more details.
*/
override protected def sendDriverMetrics(): Unit = {
val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, driverMetrics.values.toSeq)
}

protected def setFilesNumAndSizeMetric(
partitions: Seq[PartitionDirectory],
static: Boolean): Unit = {
val filesNum = partitions.map(_.files.size.toLong).sum
val filesSize = partitions.map(_.files.map(_.getLen).sum).sum
if (!static || !partitionFilters.exists(FileSourceScanExecTransformer.isDynamicPruningFilter)) {
driverMetrics("numFiles").set(filesNum)
driverMetrics("filesSize").set(filesSize)
} else {
driverMetrics("staticFilesNum").set(filesNum)
driverMetrics("staticFilesSize").set(filesSize)
}
if (relation.partitionSchema.nonEmpty) {
driverMetrics("numPartitions").set(partitions.length)
}
}

@transient override lazy val selectedPartitions: Array[PartitionDirectory] = {
val optimizerMetadataTimeNs = relation.location.metadataOpsTimeNs.getOrElse(0L)
GlutenTimeMetric.withNanoTime {
val ret =
relation.location.listFiles(
partitionFilters.filterNot(FileSourceScanExecTransformer.isDynamicPruningFilter),
dataFilters)
setFilesNumAndSizeMetric(ret, static = true)
ret
}(t => driverMetrics("metadataTime").set(NANOSECONDS.toMillis(t + optimizerMetadataTimeNs)))
}.toArray

// We can only determine the actual partitions at runtime when a dynamic partition filter is
// present. This is because such a filter relies on information that is only available at run
// time (for instance the keys used in the other side of a join).
@transient override lazy val dynamicallySelectedPartitions: Array[PartitionDirectory] = {
val dynamicPartitionFilters =
partitionFilters.filter(FileSourceScanExecTransformer.isDynamicPruningFilter)
val selected = if (dynamicPartitionFilters.nonEmpty) {
// When it includes some DynamicPruningExpression,
// it needs to execute InSubqueryExec first,
// because doTransform path can't execute 'doExecuteColumnar' which will
// execute prepare subquery first.
dynamicPartitionFilters.foreach {
case DynamicPruningExpression(inSubquery: InSubqueryExec) =>
executeInSubqueryForDynamicPruningExpression(inSubquery)
case e: Expression =>
e.foreach {
case s: ScalarSubquery => s.updateResult()
case _ =>
}
case _ =>
}
GlutenTimeMetric.withMillisTime {
// call the file index for the files matching all filters except dynamic partition filters
val predicate = dynamicPartitionFilters.reduce(And)
val partitionColumns = relation.partitionSchema
val boundPredicate = Predicate.create(
predicate.transform {
case a: AttributeReference =>
val index = partitionColumns.indexWhere(a.name == _.name)
BoundReference(index, partitionColumns(index).dataType, nullable = true)
},
Nil
)
val ret = selectedPartitions.filter(p => boundPredicate.eval(p.values))
setFilesNumAndSizeMetric(ret, static = false)
ret
}(t => driverMetrics("pruningTime").set(t))
} else {
selectedPartitions
}
sendDriverMetrics()
selected
}

override val nodeNamePrefix: String = "NativeFile"

override val nodeName: String = {
@@ -22,6 +22,7 @@ import io.glutenproject.execution._
import io.glutenproject.expression.ExpressionConverter
import io.glutenproject.extension.columnar._
import io.glutenproject.metrics.GlutenTimeMetric
import io.glutenproject.sql.shims.SparkShimLoader
import io.glutenproject.utils.{ColumnarShuffleUtil, LogLevelUtil, PhysicalPlanSelector}

import org.apache.spark.api.python.EvalPythonExecTransformer
@@ -578,8 +579,12 @@ case class TransformPreOverrides(isAdaptiveContext: Boolean)
case _ =>
ExpressionConverter.transformDynamicPruningExpr(plan.runtimeFilters, reuseSubquery)
}
val transformer =
new BatchScanExecTransformer(plan.output, plan.scan, newPartitionFilters, plan.table)
val transformer = new BatchScanExecTransformer(
plan.output,
plan.scan,
newPartitionFilters,
table = SparkShimLoader.getSparkShims.getBatchScanExecTable(plan))

val validationResult = transformer.doValidate()
if (validationResult.isValid) {
logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
@@ -20,6 +20,7 @@ import io.glutenproject.GlutenConfig
import io.glutenproject.backendsapi.BackendsApiManager
import io.glutenproject.execution._
import io.glutenproject.extension.{GlutenPlan, ValidationResult}
import io.glutenproject.sql.shims.SparkShimLoader
import io.glutenproject.utils.PhysicalPlanSelector

import org.apache.spark.api.python.EvalPythonExecTransformer
@@ -333,12 +334,11 @@ case class AddTransformHintRule() extends Rule[SparkPlan] {
if (plan.runtimeFilters.nonEmpty) {
TransformHints.tagTransformable(plan)
} else {
val transformer =
new BatchScanExecTransformer(
plan.output,
plan.scan,
plan.runtimeFilters,
plan.table)
val transformer = new BatchScanExecTransformer(
plan.output,
plan.scan,
plan.runtimeFilters,
table = SparkShimLoader.getSparkShims.getBatchScanExecTable(plan))
TransformHints.tag(plan, transformer.doValidate().toTransformHint)
}
}
@@ -20,16 +20,16 @@ import io.glutenproject.GlutenConfig
import io.glutenproject.execution.{GlutenMergeTreePartition, GlutenPartition}
import io.glutenproject.softaffinity.SoftAffinityManager
import io.glutenproject.softaffinity.scheduler.SoftAffinityListener
import io.glutenproject.sql.shims.SparkShimLoader
import io.glutenproject.substrait.plan.PlanBuilder

import org.apache.spark.SparkConf
import org.apache.spark.paths.SparkPath
import org.apache.spark.scheduler.{SparkListenerExecutorAdded, SparkListenerExecutorRemoved}
import org.apache.spark.scheduler.cluster.ExecutorInfo
import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.PredicateHelper
import org.apache.spark.sql.execution.datasources.{FilePartition, PartitionedFile}
import org.apache.spark.sql.execution.datasources.FilePartition
import org.apache.spark.sql.test.SharedSparkSession

class SoftAffinitySuite extends QueryTest with SharedSparkSession with PredicateHelper {
@@ -43,18 +43,20 @@ class SoftAffinitySuite extends QueryTest with SharedSparkSession with Predicate
val partition = FilePartition(
0,
Seq(
PartitionedFile(
SparkShimLoader.getSparkShims.generatePartitionedFile(
InternalRow.empty,
SparkPath.fromPathString("fakePath0"),
"fakePath0",
0,
100,
Array("host-1", "host-2")),
PartitionedFile(
Array("host-1", "host-2")
),
SparkShimLoader.getSparkShims.generatePartitionedFile(
InternalRow.empty,
SparkPath.fromPathString("fakePath1"),
"fakePath1",
0,
200,
Array("host-2", "host-3"))
Array("host-2", "host-3")
)
).toArray
)

@@ -70,18 +72,20 @@ class SoftAffinitySuite extends QueryTest with SharedSparkSession with Predicate
val partition = FilePartition(
0,
Seq(
PartitionedFile(
SparkShimLoader.getSparkShims.generatePartitionedFile(
InternalRow.empty,
SparkPath.fromPathString("fakePath0"),
"fakePath0",
0,
100,
Array("host-1", "host-2")),
PartitionedFile(
Array("host-1", "host-2")
),
SparkShimLoader.getSparkShims.generatePartitionedFile(
InternalRow.empty,
SparkPath.fromPathString("fakePath1"),
"fakePath1",
0,
200,
Array("host-4", "host-5"))
Array("host-4", "host-5")
)
).toArray
)

@@ -98,18 +102,20 @@ class SoftAffinitySuite extends QueryTest with SharedSparkSession with Predicate
val partition = FilePartition(
0,
Seq(
PartitionedFile(
SparkShimLoader.getSparkShims.generatePartitionedFile(
InternalRow.empty,
SparkPath.fromPathString("fakePath0"),
"fakePath0",
0,
100,
Array("host-1", "host-2")),
PartitionedFile(
Array("host-1", "host-2")
),
SparkShimLoader.getSparkShims.generatePartitionedFile(
InternalRow.empty,
SparkPath.fromPathString("fakePath1"),
"fakePath1",
0,
200,
Array("host-5", "host-6"))
Array("host-5", "host-6")
)
).toArray
)

@@ -138,18 +144,20 @@ class SoftAffinitySuite extends QueryTest with SharedSparkSession with Predicate
val partition = FilePartition(
0,
Seq(
PartitionedFile(
SparkShimLoader.getSparkShims.generatePartitionedFile(
InternalRow.empty,
SparkPath.fromPathString("fakePath0"),
"fakePath0",
0,
100,
Array("host-1", "host-2")),
PartitionedFile(
Array("host-1", "host-2")
),
SparkShimLoader.getSparkShims.generatePartitionedFile(
InternalRow.empty,
SparkPath.fromPathString("fakePath1"),
"fakePath1",
0,
200,
Array("host-5", "host-6"))
Array("host-5", "host-6")
)
).toArray
)
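The SoftAffinitySuite hunks above are where the version difference is easiest to see: the removed lines built PartitionedFile directly with SparkPath.fromPathString(...), which only matches the newer constructor, while the new lines hand a plain path string to the shim. Below is a minimal sketch of what the Spark 3.4 side of that shim could look like; the class name Spark34Shims and the exact overrides are assumptions based on the removed test lines and the SparkShims sketch near the top of this page, not the commit's actual shim sources.

import org.apache.spark.paths.SparkPath
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.catalog.Table
import org.apache.spark.sql.execution.datasources.PartitionedFile
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec

// Hypothetical Spark 3.4 implementation of the assumed SparkShims trait sketched above.
class Spark34Shims extends SparkShims {
  // On Spark 3.4 the table is a constructor parameter of BatchScanExec, so simply return it;
  // this mirrors the batchScan.table accesses that this commit removes from the common code.
  override def getBatchScanExecTable(batchScan: BatchScanExec): Table = batchScan.table

  // Spark 3.4's PartitionedFile takes a SparkPath; a Spark 3.3-style shim would pass the
  // String path straight through instead of wrapping it.
  override def generatePartitionedFile(
      partitionValues: InternalRow,
      filePath: String,
      start: Long,
      length: Long,
      locations: Array[String]): PartitionedFile =
    PartitionedFile(partitionValues, SparkPath.fromPathString(filePath), start, length, locations)
}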
