[GLUTEN-3459] Skip unnecessary local sort for static partition write
lwz9103 committed Oct 23, 2023
1 parent 327b8eb commit a3c4f5f
Showing 5 changed files with 643 additions and 5 deletions.
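
Background (editor's sketch, not part of the commit message): FileFormatWriter requires the task output to be ordered by the partition columns before writing and adds a local sort when the child plan's ordering does not already match. For an insert whose partition spec is fully or partly static, the static partition columns hold the same value for every row, so sorting by them cannot change which files the rows land in. The change below drops that static prefix from the required ordering. A minimal, illustrative Spark example of such a write (table and column names are made up):

import org.apache.spark.sql.SparkSession

object StaticPartitionInsertExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("static-partition-insert")
      .master("local[*]")
      .getOrCreate()
    spark.sql(
      "CREATE TABLE IF NOT EXISTS t (id INT, dt STRING, hour STRING) " +
        "USING parquet PARTITIONED BY (dt, hour)")
    spark.range(100).createOrReplaceTempView("src")
    // Both partition columns are static here: every row goes to dt=2023-10-23/hour=12,
    // so a pre-write local sort on (dt, hour) does no useful work.
    spark.sql(
      "INSERT OVERWRITE TABLE t PARTITION (dt = '2023-10-23', hour = '12') " +
        "SELECT CAST(id AS INT) FROM src")
    spark.stop()
  }
}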
2 changes: 2 additions & 0 deletions package/pom.xml
@@ -250,6 +250,8 @@
<ignoreClass>org.apache.spark.sql.execution.datasources.EmptyDirectoryDataWriter$</ignoreClass>
<ignoreClass>org.apache.spark.sql.execution.datasources.WriterBucketSpec</ignoreClass>
<ignoreClass>org.apache.spark.sql.execution.datasources.WriterBucketSpec$</ignoreClass>
<ignoreClass>org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand</ignoreClass>
<ignoreClass>org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand$</ignoreClass>

</ignoreClasses>
<scopes>
@@ -108,6 +108,32 @@ object FileFormatWriter extends Logging {
* @return
* The set of all partition paths that were updated during this write job.
*/

// scalastyle:off argcount
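// Compatibility overload (editor's note added as a comment): existing callers keep the
// original ten-parameter signature; it delegates to the extended variant below with
// numStaticPartitionCols = 0, i.e. all partition columns are treated as dynamic.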
def write(
sparkSession: SparkSession,
plan: SparkPlan,
fileFormat: FileFormat,
committer: FileCommitProtocol,
outputSpec: OutputSpec,
hadoopConf: Configuration,
partitionColumns: Seq[Attribute],
bucketSpec: Option[BucketSpec],
statsTrackers: Seq[WriteJobStatsTracker],
options: Map[String, String]): Set[String] = write(
sparkSession = sparkSession,
plan = plan,
fileFormat = fileFormat,
committer = committer,
outputSpec = outputSpec,
hadoopConf = hadoopConf,
partitionColumns = partitionColumns,
bucketSpec = bucketSpec,
statsTrackers = statsTrackers,
options = options,
numStaticPartitionCols = 0
)

def write(
sparkSession: SparkSession,
plan: SparkPlan,
@@ -118,7 +144,8 @@
partitionColumns: Seq[Attribute],
bucketSpec: Option[BucketSpec],
statsTrackers: Seq[WriteJobStatsTracker],
options: Map[String, String]): Set[String] = {
options: Map[String, String],
numStaticPartitionCols: Int = 0): Set[String] = {

val nativeEnabled =
"true".equals(sparkSession.sparkContext.getLocalProperty("isNativeAppliable"))
@@ -195,7 +222,8 @@
)

// We should first sort by partition columns, then bucket id, and finally sorting columns.
val requiredOrdering = partitionColumns ++ bucketIdExpression ++ sortColumns
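// Static partition columns are constant for this write, so they are dropped from the
// required ordering below; this is what lets the writer skip the otherwise unnecessary
// local sort on them.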
val requiredOrdering = partitionColumns.drop(numStaticPartitionCols) ++
bucketIdExpression ++ sortColumns
// the sort order doesn't matter
val actualOrdering = empty2NullPlan.outputOrdering.map(_.child)
val orderingMatched = if (requiredOrdering.length > actualOrdering.length) {
@@ -328,6 +356,7 @@
throw QueryExecutionErrors.jobAbortedError(cause)
}
}
// scalastyle:on argcount

/** Writes data out in a single Spark task. */
private def executeTask(
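A standalone sketch (editor's illustration, not part of the diff) of the ordering rule above, using plain strings in place of Catalyst attributes, to show why a fully static partition spec leaves nothing to sort by:

object RequiredOrderingSketch {
  // Mirrors the expression above: dynamic partition columns, then bucket id, then sort columns.
  def requiredOrdering(
      partitionColumns: Seq[String],
      numStaticPartitionCols: Int,
      bucketIdExpression: Option[String],
      sortColumns: Seq[String]): Seq[String] =
    partitionColumns.drop(numStaticPartitionCols) ++ bucketIdExpression ++ sortColumns

  def main(args: Array[String]): Unit = {
    val partitionCols = Seq("dt", "hour")
    // Fully dynamic insert: ordering on (dt, hour) is still required, so a sort may be added.
    println(requiredOrdering(partitionCols, numStaticPartitionCols = 0, None, Nil)) // List(dt, hour)
    // Partially static insert (dt fixed): only hour needs to be ordered.
    println(requiredOrdering(partitionCols, numStaticPartitionCols = 1, None, Nil)) // List(hour)
    // Fully static insert (dt and hour fixed): nothing to sort by, so the local sort is skipped.
    println(requiredOrdering(partitionCols, numStaticPartitionCols = 2, None, Nil)) // List()
  }
}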
@@ -0,0 +1,288 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.datasources

import org.apache.spark.internal.io.FileCommitProtocol
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable, CatalogTablePartition}
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils._
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.command._
import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode
import org.apache.spark.sql.util.SchemaUtils

import org.apache.hadoop.fs.{FileSystem, Path}

/**
* A command for writing data to a [[HadoopFsRelation]]. Supports both overwriting and appending.
* Writing to dynamic partitions is also supported.
*
* @param staticPartitions
* partial partitioning spec for write. This defines the scope of partition overwrites: when the
* spec is empty, all partitions are overwritten. When it covers a prefix of the partition keys,
* only partitions matching the prefix are overwritten.
* @param ifPartitionNotExists
* If true, only write if the partition does not exist. Only valid for static partitions.
*/
case class InsertIntoHadoopFsRelationCommand(
outputPath: Path,
staticPartitions: TablePartitionSpec,
ifPartitionNotExists: Boolean,
partitionColumns: Seq[Attribute],
bucketSpec: Option[BucketSpec],
fileFormat: FileFormat,
options: Map[String, String],
query: LogicalPlan,
mode: SaveMode,
catalogTable: Option[CatalogTable],
fileIndex: Option[FileIndex],
outputColumnNames: Seq[String])
extends DataWritingCommand {

private lazy val parameters = CaseInsensitiveMap(options)

private[sql] lazy val dynamicPartitionOverwrite: Boolean = {
val partitionOverwriteMode = parameters
.get("partitionOverwriteMode")
// scalastyle:off caselocale
.map(mode => PartitionOverwriteMode.withName(mode.toUpperCase))
// scalastyle:on caselocale
.getOrElse(conf.partitionOverwriteMode)
val enableDynamicOverwrite = partitionOverwriteMode == PartitionOverwriteMode.DYNAMIC
// This config only makes sense when we are overwriting a partitioned dataset with dynamic
// partition columns.
enableDynamicOverwrite && mode == SaveMode.Overwrite &&
staticPartitions.size < partitionColumns.length
}

override def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] = {
// Most formats don't do well with duplicate columns, so let's not allow that
SchemaUtils.checkColumnNameDuplication(
outputColumnNames,
s"when inserting into $outputPath",
sparkSession.sessionState.conf.caseSensitiveAnalysis)

val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(options)
val fs = outputPath.getFileSystem(hadoopConf)
val qualifiedOutputPath = outputPath.makeQualified(fs.getUri, fs.getWorkingDirectory)

val partitionsTrackedByCatalog = sparkSession.sessionState.conf.manageFilesourcePartitions &&
catalogTable.isDefined &&
catalogTable.get.partitionColumnNames.nonEmpty &&
catalogTable.get.tracksPartitionsInCatalog

var initialMatchingPartitions: Seq[TablePartitionSpec] = Nil
var customPartitionLocations: Map[TablePartitionSpec, String] = Map.empty
var matchingPartitions: Seq[CatalogTablePartition] = Seq.empty

// When partitions are tracked by the catalog, compute all custom partition locations that
// may be relevant to the insertion job.
if (partitionsTrackedByCatalog) {
matchingPartitions = sparkSession.sessionState.catalog
.listPartitions(catalogTable.get.identifier, Some(staticPartitions))
initialMatchingPartitions = matchingPartitions.map(_.spec)
customPartitionLocations =
getCustomPartitionLocations(fs, catalogTable.get, qualifiedOutputPath, matchingPartitions)
}

val jobId = java.util.UUID.randomUUID().toString
val committer = FileCommitProtocol.instantiate(
sparkSession.sessionState.conf.fileCommitProtocolClass,
jobId = jobId,
outputPath = outputPath.toString,
dynamicPartitionOverwrite = dynamicPartitionOverwrite)

val doInsertion = if (mode == SaveMode.Append) {
true
} else {
val pathExists = fs.exists(qualifiedOutputPath)
(mode, pathExists) match {
case (SaveMode.ErrorIfExists, true) =>
throw QueryCompilationErrors.outputPathAlreadyExistsError(qualifiedOutputPath)
case (SaveMode.Overwrite, true) =>
if (ifPartitionNotExists && matchingPartitions.nonEmpty) {
false
} else if (dynamicPartitionOverwrite) {
// For dynamic partition overwrite, do not delete partition directories ahead.
true
} else {
deleteMatchingPartitions(fs, qualifiedOutputPath, customPartitionLocations, committer)
true
}
case (SaveMode.Overwrite, _) | (SaveMode.ErrorIfExists, false) =>
true
case (SaveMode.Ignore, exists) =>
!exists
case (s, exists) =>
throw QueryExecutionErrors.unsupportedSaveModeError(s.toString, exists)
}
}

if (doInsertion) {

def refreshUpdatedPartitions(updatedPartitionPaths: Set[String]): Unit = {
val updatedPartitions = updatedPartitionPaths.map(PartitioningUtils.parsePathFragment)
if (partitionsTrackedByCatalog) {
val newPartitions = updatedPartitions -- initialMatchingPartitions
if (newPartitions.nonEmpty) {
AlterTableAddPartitionCommand(
catalogTable.get.identifier,
newPartitions.toSeq.map(p => (p, None)),
ifNotExists = true).run(sparkSession)
}
// For dynamic partition overwrite, we never remove partitions but only update existing
// ones.
if (mode == SaveMode.Overwrite && !dynamicPartitionOverwrite) {
val deletedPartitions = initialMatchingPartitions.toSet -- updatedPartitions
if (deletedPartitions.nonEmpty) {
AlterTableDropPartitionCommand(
catalogTable.get.identifier,
deletedPartitions.toSeq,
ifExists = true,
purge = false,
retainData = true /* already deleted */ ).run(sparkSession)
}
}
}
}

// For dynamic partition overwrite, FileOutputCommitter's output path is staging path, files
// will be renamed from staging path to final output path during commit job
val committerOutputPath = if (dynamicPartitionOverwrite) {
FileCommitProtocol
.getStagingDir(outputPath.toString, jobId)
.makeQualified(fs.getUri, fs.getWorkingDirectory)
} else {
qualifiedOutputPath
}

val updatedPartitionPaths =
FileFormatWriter.write(
sparkSession = sparkSession,
plan = child,
fileFormat = fileFormat,
committer = committer,
outputSpec = FileFormatWriter.OutputSpec(
committerOutputPath.toString,
customPartitionLocations,
outputColumns),
hadoopConf = hadoopConf,
partitionColumns = partitionColumns,
bucketSpec = bucketSpec,
statsTrackers = Seq(basicWriteJobStatsTracker(hadoopConf)),
options = options,
numStaticPartitionCols = staticPartitions.size
)

// update metastore partition metadata
if (
updatedPartitionPaths.isEmpty && staticPartitions.nonEmpty
&& partitionColumns.length == staticPartitions.size
) {
// If a fully static insert wrote no files, still register the (empty) static partition
// so it is not missing from the datasource table.
val staticPathFragment =
PartitioningUtils.getPathFragment(staticPartitions, partitionColumns)
refreshUpdatedPartitions(Set(staticPathFragment))
} else {
refreshUpdatedPartitions(updatedPartitionPaths)
}

// refresh cached files in FileIndex
fileIndex.foreach(_.refresh())
// refresh data cache if table is cached
sparkSession.sharedState.cacheManager.recacheByPath(sparkSession, outputPath, fs)

if (catalogTable.nonEmpty) {
CommandUtils.updateTableStats(sparkSession, catalogTable.get)
}

} else {
logInfo("Skipping insertion into a relation that already exists.")
}

Seq.empty[Row]
}

/**
* Deletes all partition files that match the specified static prefix. Partitions with custom
* locations are also cleared based on the custom locations map given to this class.
*/
private def deleteMatchingPartitions(
fs: FileSystem,
qualifiedOutputPath: Path,
customPartitionLocations: Map[TablePartitionSpec, String],
committer: FileCommitProtocol): Unit = {
val staticPartitionPrefix = if (staticPartitions.nonEmpty) {
"/" + partitionColumns
.flatMap(p => staticPartitions.get(p.name).map(getPartitionPathString(p.name, _)))
.mkString("/")
} else {
""
}
// first clear the path determined by the static partition keys (e.g. /table/foo=1)
val staticPrefixPath = qualifiedOutputPath.suffix(staticPartitionPrefix)
if (fs.exists(staticPrefixPath) && !committer.deleteWithJob(fs, staticPrefixPath, true)) {
throw QueryExecutionErrors.cannotClearOutputDirectoryError(staticPrefixPath)
}
// now clear all custom partition locations (e.g. /custom/dir/where/foo=2/bar=4)
for ((spec, customLoc) <- customPartitionLocations) {
assert(
(staticPartitions.toSet -- spec).isEmpty,
"Custom partition location did not match static partitioning keys")
val path = new Path(customLoc)
if (fs.exists(path) && !committer.deleteWithJob(fs, path, true)) {
throw QueryExecutionErrors.cannotClearPartitionDirectoryError(path)
}
}
}

/**
* Given a set of input partitions, returns those that have locations that differ from the Hive
* default (e.g. /k1=v1/k2=v2). These partitions were manually assigned locations by the user.
*
* @return
* a mapping from partition specs to their custom locations
*/
private def getCustomPartitionLocations(
fs: FileSystem,
table: CatalogTable,
qualifiedOutputPath: Path,
partitions: Seq[CatalogTablePartition]): Map[TablePartitionSpec, String] = {
partitions.flatMap {
p =>
val defaultLocation = qualifiedOutputPath
.suffix("/" + PartitioningUtils.getPathFragment(p.spec, table.partitionSchema))
.toString
val catalogLocation =
new Path(p.location).makeQualified(fs.getUri, fs.getWorkingDirectory).toString
if (catalogLocation != defaultLocation) {
Some(p.spec -> catalogLocation)
} else {
None
}
}.toMap
}

override protected def withNewChildInternal(
newChild: LogicalPlan): InsertIntoHadoopFsRelationCommand = copy(query = newChild)
}
// scalastyle:on line.size.limit
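
A usage-level sketch (editor's illustration; it reuses the made-up table t and view src from the earlier example) of how this command handles a partially static INSERT OVERWRITE: with spark.sql.sources.partitionOverwriteMode set to DYNAMIC, SaveMode.Overwrite, and staticPartitions.size (1) smaller than partitionColumns.length (2), dynamicPartitionOverwrite above evaluates to true, and staticPartitions.size is forwarded to FileFormatWriter.write as numStaticPartitionCols so only the dynamic column remains in the required ordering:

import org.apache.spark.sql.SparkSession

object PartitionOverwriteExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("partition-overwrite")
      .master("local[*]")
      .getOrCreate()
    // Session-level counterpart of the per-write "partitionOverwriteMode" option read above.
    spark.conf.set("spark.sql.sources.partitionOverwriteMode", "DYNAMIC")
    spark.range(100).createOrReplaceTempView("src")
    // dt is static, hour is dynamic: only partitions under dt=2023-10-23 that receive rows
    // are overwritten, and the writer only needs the data ordered by hour, not by dt.
    spark.sql(
      "INSERT OVERWRITE TABLE t PARTITION (dt = '2023-10-23', hour) " +
        "SELECT CAST(id AS INT) AS id, CAST(id % 24 AS STRING) AS hour FROM src")
    spark.stop()
  }
}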
