[GLUTEN-3459] Skip unnecessary local sort for static partition write #3460

Merged: 1 commit, Oct 23, 2023
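
The patch benefits INSERT statements that pin partition columns to literal values (static partitions). A minimal sketch of such a write, assuming a hypothetical partitioned table `sales`, a hypothetical source table `staging_sales`, and an existing SparkSession `spark`:

// Fully static partition write: `dt` is a literal, so every output row lands in the
// same partition directory and a local sort on `dt` contributes nothing.
spark.sql(
  """INSERT OVERWRITE TABLE sales PARTITION (dt = '2023-10-23')
    |SELECT id, amount FROM staging_sales""".stripMargin)

With this change, FileFormatWriter drops the statically specified partition columns from the required ordering, so no extra local sort is inserted for them.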
2 changes: 2 additions & 0 deletions package/pom.xml
@@ -250,6 +250,8 @@
<ignoreClass>org.apache.spark.sql.execution.datasources.EmptyDirectoryDataWriter$</ignoreClass>
<ignoreClass>org.apache.spark.sql.execution.datasources.WriterBucketSpec</ignoreClass>
<ignoreClass>org.apache.spark.sql.execution.datasources.WriterBucketSpec$</ignoreClass>
<ignoreClass>org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand</ignoreClass>
<ignoreClass>org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand$</ignoreClass>

</ignoreClasses>
<scopes>
@@ -108,6 +108,32 @@ object FileFormatWriter extends Logging {
* @return
* The set of all partition paths that were updated during this write job.
*/

// scalastyle:off argcount
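// Overload preserving the original signature; it forwards to the extended variant below
// with numStaticPartitionCols = 0 (no statically specified partition columns).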
def write(
sparkSession: SparkSession,
plan: SparkPlan,
fileFormat: FileFormat,
committer: FileCommitProtocol,
outputSpec: OutputSpec,
hadoopConf: Configuration,
partitionColumns: Seq[Attribute],
bucketSpec: Option[BucketSpec],
statsTrackers: Seq[WriteJobStatsTracker],
options: Map[String, String]): Set[String] = write(
sparkSession = sparkSession,
plan = plan,
fileFormat = fileFormat,
committer = committer,
outputSpec = outputSpec,
hadoopConf = hadoopConf,
partitionColumns = partitionColumns,
bucketSpec = bucketSpec,
statsTrackers = statsTrackers,
options = options,
numStaticPartitionCols = 0
)

def write(
sparkSession: SparkSession,
plan: SparkPlan,
@@ -118,7 +144,8 @@ object FileFormatWriter extends Logging {
partitionColumns: Seq[Attribute],
bucketSpec: Option[BucketSpec],
statsTrackers: Seq[WriteJobStatsTracker],
options: Map[String, String]): Set[String] = {
options: Map[String, String],
numStaticPartitionCols: Int = 0): Set[String] = {

val nativeEnabled =
"true".equals(sparkSession.sparkContext.getLocalProperty("isNativeAppliable"))
@@ -195,7 +222,8 @@
)

// We should first sort by partition columns, then bucket id, and finally sorting columns.
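// Static partition columns hold a single literal value for the whole write, so they are
// dropped from the required ordering and no local sort is forced on them.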
val requiredOrdering = partitionColumns ++ bucketIdExpression ++ sortColumns
val requiredOrdering = partitionColumns.drop(numStaticPartitionCols) ++
bucketIdExpression ++ sortColumns
// the sort order doesn't matter
val actualOrdering = empty2NullPlan.outputOrdering.map(_.child)
val orderingMatched = if (requiredOrdering.length > actualOrdering.length) {
@@ -328,6 +356,7 @@ object FileFormatWriter extends Logging {
throw QueryExecutionErrors.jobAbortedError(cause)
}
}
// scalastyle:on argcount

/** Writes data out in a single Spark task. */
private def executeTask(
@@ -0,0 +1,288 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.datasources

import org.apache.spark.internal.io.FileCommitProtocol
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable, CatalogTablePartition}
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils._
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.command._
import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode
import org.apache.spark.sql.util.SchemaUtils

import org.apache.hadoop.fs.{FileSystem, Path}

/**
* A command for writing data to a [[HadoopFsRelation]]. Supports both overwriting and appending.
* Writing to dynamic partitions is also supported.
*
* @param staticPartitions
* partial partitioning spec for write. This defines the scope of partition overwrites: when the
* spec is empty, all partitions are overwritten. When it covers a prefix of the partition keys,
* only partitions matching the prefix are overwritten.
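* For example (illustrative), `INSERT OVERWRITE TABLE t PARTITION (p1 = 1, p2) ...` yields a
* static spec of Map("p1" -> "1"), while p2 remains dynamic.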
* @param ifPartitionNotExists
* If true, only write if the partition does not exist. Only valid for static partitions.
*/
case class InsertIntoHadoopFsRelationCommand(
outputPath: Path,
staticPartitions: TablePartitionSpec,
ifPartitionNotExists: Boolean,
partitionColumns: Seq[Attribute],
bucketSpec: Option[BucketSpec],
fileFormat: FileFormat,
options: Map[String, String],
query: LogicalPlan,
mode: SaveMode,
catalogTable: Option[CatalogTable],
fileIndex: Option[FileIndex],
outputColumnNames: Seq[String])
extends DataWritingCommand {

private lazy val parameters = CaseInsensitiveMap(options)

private[sql] lazy val dynamicPartitionOverwrite: Boolean = {
val partitionOverwriteMode = parameters
.get("partitionOverwriteMode")
// scalastyle:off caselocale
.map(mode => PartitionOverwriteMode.withName(mode.toUpperCase))
// scalastyle:on caselocale
.getOrElse(conf.partitionOverwriteMode)
val enableDynamicOverwrite = partitionOverwriteMode == PartitionOverwriteMode.DYNAMIC
// This config only makes sense when we are overwriting a partitioned dataset with dynamic
// partition columns.
enableDynamicOverwrite && mode == SaveMode.Overwrite &&
staticPartitions.size < partitionColumns.length
}

override def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] = {
// Most formats don't do well with duplicate columns, so let's not allow that
SchemaUtils.checkColumnNameDuplication(
outputColumnNames,
s"when inserting into $outputPath",
sparkSession.sessionState.conf.caseSensitiveAnalysis)

val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(options)
val fs = outputPath.getFileSystem(hadoopConf)
val qualifiedOutputPath = outputPath.makeQualified(fs.getUri, fs.getWorkingDirectory)

val partitionsTrackedByCatalog = sparkSession.sessionState.conf.manageFilesourcePartitions &&
catalogTable.isDefined &&
catalogTable.get.partitionColumnNames.nonEmpty &&
catalogTable.get.tracksPartitionsInCatalog

var initialMatchingPartitions: Seq[TablePartitionSpec] = Nil
var customPartitionLocations: Map[TablePartitionSpec, String] = Map.empty
var matchingPartitions: Seq[CatalogTablePartition] = Seq.empty

// When partitions are tracked by the catalog, compute all custom partition locations that
// may be relevant to the insertion job.
if (partitionsTrackedByCatalog) {
matchingPartitions = sparkSession.sessionState.catalog
.listPartitions(catalogTable.get.identifier, Some(staticPartitions))
initialMatchingPartitions = matchingPartitions.map(_.spec)
customPartitionLocations =
getCustomPartitionLocations(fs, catalogTable.get, qualifiedOutputPath, matchingPartitions)
}

val jobId = java.util.UUID.randomUUID().toString
val committer = FileCommitProtocol.instantiate(
sparkSession.sessionState.conf.fileCommitProtocolClass,
jobId = jobId,
outputPath = outputPath.toString,
dynamicPartitionOverwrite = dynamicPartitionOverwrite)

val doInsertion = if (mode == SaveMode.Append) {
true
} else {
val pathExists = fs.exists(qualifiedOutputPath)
(mode, pathExists) match {
case (SaveMode.ErrorIfExists, true) =>
throw QueryCompilationErrors.outputPathAlreadyExistsError(qualifiedOutputPath)
case (SaveMode.Overwrite, true) =>
if (ifPartitionNotExists && matchingPartitions.nonEmpty) {
false
} else if (dynamicPartitionOverwrite) {
// For dynamic partition overwrite, do not delete partition directories ahead.
true
} else {
deleteMatchingPartitions(fs, qualifiedOutputPath, customPartitionLocations, committer)
true
}
case (SaveMode.Overwrite, _) | (SaveMode.ErrorIfExists, false) =>
true
case (SaveMode.Ignore, exists) =>
!exists
case (s, exists) =>
throw QueryExecutionErrors.unsupportedSaveModeError(s.toString, exists)
}
}

if (doInsertion) {

def refreshUpdatedPartitions(updatedPartitionPaths: Set[String]): Unit = {
val updatedPartitions = updatedPartitionPaths.map(PartitioningUtils.parsePathFragment)
if (partitionsTrackedByCatalog) {
val newPartitions = updatedPartitions -- initialMatchingPartitions
if (newPartitions.nonEmpty) {
AlterTableAddPartitionCommand(
catalogTable.get.identifier,
newPartitions.toSeq.map(p => (p, None)),
ifNotExists = true).run(sparkSession)
}
// For dynamic partition overwrite, we never remove partitions but only update existing
// ones.
if (mode == SaveMode.Overwrite && !dynamicPartitionOverwrite) {
val deletedPartitions = initialMatchingPartitions.toSet -- updatedPartitions
if (deletedPartitions.nonEmpty) {
AlterTableDropPartitionCommand(
catalogTable.get.identifier,
deletedPartitions.toSeq,
ifExists = true,
purge = false,
retainData = true /* already deleted */ ).run(sparkSession)
}
}
}
}

// For dynamic partition overwrite, FileOutputCommitter's output path is staging path, files
// will be renamed from staging path to final output path during commit job
val committerOutputPath = if (dynamicPartitionOverwrite) {
FileCommitProtocol
.getStagingDir(outputPath.toString, jobId)
.makeQualified(fs.getUri, fs.getWorkingDirectory)
} else {
qualifiedOutputPath
}

val updatedPartitionPaths =
FileFormatWriter.write(
sparkSession = sparkSession,
plan = child,
fileFormat = fileFormat,
committer = committer,
outputSpec = FileFormatWriter.OutputSpec(
committerOutputPath.toString,
customPartitionLocations,
outputColumns),
hadoopConf = hadoopConf,
partitionColumns = partitionColumns,
bucketSpec = bucketSpec,
statsTrackers = Seq(basicWriteJobStatsTracker(hadoopConf)),
options = options,
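// staticPartitions.size is the number of leading partition columns that are fixed to
// literal values, letting FileFormatWriter drop them from the required ordering.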
numStaticPartitionCols = staticPartitions.size
)

// update metastore partition metadata
if (
updatedPartitionPaths.isEmpty && staticPartitions.nonEmpty
&& partitionColumns.length == staticPartitions.size
) {
// Avoid the case where an empty static partition cannot be loaded into the datasource table.
val staticPathFragment =
PartitioningUtils.getPathFragment(staticPartitions, partitionColumns)
refreshUpdatedPartitions(Set(staticPathFragment))
} else {
refreshUpdatedPartitions(updatedPartitionPaths)
}

// refresh cached files in FileIndex
fileIndex.foreach(_.refresh())
// refresh data cache if table is cached
sparkSession.sharedState.cacheManager.recacheByPath(sparkSession, outputPath, fs)

if (catalogTable.nonEmpty) {
CommandUtils.updateTableStats(sparkSession, catalogTable.get)
}

} else {
logInfo("Skipping insertion into a relation that already exists.")
}

Seq.empty[Row]
}

/**
* Deletes all partition files that match the specified static prefix. Partitions with custom
* locations are also cleared based on the custom locations map given to this class.
*/
private def deleteMatchingPartitions(
fs: FileSystem,
qualifiedOutputPath: Path,
customPartitionLocations: Map[TablePartitionSpec, String],
committer: FileCommitProtocol): Unit = {
val staticPartitionPrefix = if (staticPartitions.nonEmpty) {
"/" + partitionColumns
.flatMap(p => staticPartitions.get(p.name).map(getPartitionPathString(p.name, _)))
.mkString("/")
} else {
""
}
// first clear the path determined by the static partition keys (e.g. /table/foo=1)
val staticPrefixPath = qualifiedOutputPath.suffix(staticPartitionPrefix)
if (fs.exists(staticPrefixPath) && !committer.deleteWithJob(fs, staticPrefixPath, true)) {
throw QueryExecutionErrors.cannotClearOutputDirectoryError(staticPrefixPath)
}
// now clear all custom partition locations (e.g. /custom/dir/where/foo=2/bar=4)
for ((spec, customLoc) <- customPartitionLocations) {
assert(
(staticPartitions.toSet -- spec).isEmpty,
"Custom partition location did not match static partitioning keys")
val path = new Path(customLoc)
if (fs.exists(path) && !committer.deleteWithJob(fs, path, true)) {
throw QueryExecutionErrors.cannotClearPartitionDirectoryError(path)
}
}
}

/**
* Given a set of input partitions, returns those that have locations that differ from the Hive
* default (e.g. /k1=v1/k2=v2). These partitions were manually assigned locations by the user.
*
* @return
* a mapping from partition specs to their custom locations
*/
private def getCustomPartitionLocations(
fs: FileSystem,
table: CatalogTable,
qualifiedOutputPath: Path,
partitions: Seq[CatalogTablePartition]): Map[TablePartitionSpec, String] = {
partitions.flatMap {
p =>
val defaultLocation = qualifiedOutputPath
.suffix("/" + PartitioningUtils.getPathFragment(p.spec, table.partitionSchema))
.toString
val catalogLocation =
new Path(p.location).makeQualified(fs.getUri, fs.getWorkingDirectory).toString
if (catalogLocation != defaultLocation) {
Some(p.spec -> catalogLocation)
} else {
None
}
}.toMap
}

override protected def withNewChildInternal(
newChild: LogicalPlan): InsertIntoHadoopFsRelationCommand = copy(query = newChild)
}
// scalastyle:on line.size.limit