diff --git a/package/pom.xml b/package/pom.xml
index ba5886cc6dfbd..262391a676737 100644
--- a/package/pom.xml
+++ b/package/pom.xml
@@ -250,6 +250,8 @@
org.apache.spark.sql.execution.datasources.EmptyDirectoryDataWriter$
org.apache.spark.sql.execution.datasources.WriterBucketSpec
org.apache.spark.sql.execution.datasources.WriterBucketSpec$
+ org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand
+ org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand$
diff --git a/shims/spark33/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/shims/spark33/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
index 530f477ab742a..0e6ea140b05e6 100644
--- a/shims/spark33/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
+++ b/shims/spark33/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
@@ -100,6 +100,32 @@ object FileFormatWriter extends Logging {
* @return
* The set of all partition paths that were updated during this write job.
*/
+
+ // scalastyle:off argcount
+ def write(
+ sparkSession: SparkSession,
+ plan: SparkPlan,
+ fileFormat: FileFormat,
+ committer: FileCommitProtocol,
+ outputSpec: OutputSpec,
+ hadoopConf: Configuration,
+ partitionColumns: Seq[Attribute],
+ bucketSpec: Option[BucketSpec],
+ statsTrackers: Seq[WriteJobStatsTracker],
+ options: Map[String, String]): Set[String] = write(
+ sparkSession = sparkSession,
+ plan = plan,
+ fileFormat = fileFormat,
+ committer = committer,
+ outputSpec = outputSpec,
+ hadoopConf = hadoopConf,
+ partitionColumns = partitionColumns,
+ bucketSpec = bucketSpec,
+ statsTrackers = statsTrackers,
+ options = options,
+ numStaticPartitionCols = 0
+ )
+
def write(
sparkSession: SparkSession,
plan: SparkPlan,
@@ -110,7 +136,8 @@ object FileFormatWriter extends Logging {
partitionColumns: Seq[Attribute],
bucketSpec: Option[BucketSpec],
statsTrackers: Seq[WriteJobStatsTracker],
- options: Map[String, String]): Set[String] = {
+ options: Map[String, String],
+ numStaticPartitionCols: Int = 0): Set[String] = {
val nativeEnabled =
"true".equals(sparkSession.sparkContext.getLocalProperty("isNativeAppliable"))
@@ -215,8 +242,8 @@ object FileFormatWriter extends Logging {
)
// We should first sort by partition columns, then bucket id, and finally sorting columns.
- val requiredOrdering =
- partitionColumns ++ writerBucketSpec.map(_.bucketIdExpression) ++ sortColumns
+ val requiredOrdering = partitionColumns.drop(numStaticPartitionCols) ++
+ writerBucketSpec.map(_.bucketIdExpression) ++ sortColumns
// the sort order doesn't matter
val actualOrdering = empty2NullPlan.outputOrdering.map(_.child)
val orderingMatched = if (requiredOrdering.length > actualOrdering.length) {
@@ -349,6 +376,7 @@ object FileFormatWriter extends Logging {
throw QueryExecutionErrors.jobAbortedError(cause)
}
}
+ // scalastyle:on argcount
/** Writes data out in a single Spark task. */
private def executeTask(
diff --git a/shims/spark33/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala b/shims/spark33/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
new file mode 100644
index 0000000000000..b1e740284b560
--- /dev/null
+++ b/shims/spark33/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
@@ -0,0 +1,300 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources
+
+import org.apache.spark.internal.io.FileCommitProtocol
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable, CatalogTablePartition}
+import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
+import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils._
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
+import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.command._
+import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode
+import org.apache.spark.sql.util.SchemaUtils
+
+import org.apache.hadoop.fs.{FileSystem, Path}
+
+/**
+ * A command for writing data to a [[HadoopFsRelation]]. Supports both overwriting and appending.
+ * Writing to dynamic partitions is also supported.
+ *
+ * @param staticPartitions
+ * partial partitioning spec for write. This defines the scope of partition overwrites: when the
+ * spec is empty, all partitions are overwritten. When it covers a prefix of the partition keys,
+ * only partitions matching the prefix are overwritten.
+ * @param ifPartitionNotExists
+ * If true, only write if the partition does not exist. Only valid for static partitions.
+ */
+
+// scalastyle:off line.size.limit
+case class InsertIntoHadoopFsRelationCommand(
+ outputPath: Path,
+ staticPartitions: TablePartitionSpec,
+ ifPartitionNotExists: Boolean,
+ partitionColumns: Seq[Attribute],
+ bucketSpec: Option[BucketSpec],
+ fileFormat: FileFormat,
+ options: Map[String, String],
+ query: LogicalPlan,
+ mode: SaveMode,
+ catalogTable: Option[CatalogTable],
+ fileIndex: Option[FileIndex],
+ outputColumnNames: Seq[String])
+ extends DataWritingCommand {
+
+ private lazy val parameters = CaseInsensitiveMap(options)
+
+ private[sql] lazy val dynamicPartitionOverwrite: Boolean = {
+ val partitionOverwriteMode = parameters
+ .get(DataSourceUtils.PARTITION_OVERWRITE_MODE)
+ // scalastyle:off caselocale
+ .map(mode => PartitionOverwriteMode.withName(mode.toUpperCase))
+ // scalastyle:on caselocale
+ .getOrElse(conf.partitionOverwriteMode)
+ val enableDynamicOverwrite = partitionOverwriteMode == PartitionOverwriteMode.DYNAMIC
+ // This config only makes sense when we are overwriting a partitioned dataset with dynamic
+ // partition columns.
+ enableDynamicOverwrite && mode == SaveMode.Overwrite &&
+ staticPartitions.size < partitionColumns.length
+ }
+
+ override def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] = {
+    // Most formats don't do well with duplicate columns, so let's not allow that
+ SchemaUtils.checkColumnNameDuplication(
+ outputColumnNames,
+ s"when inserting into $outputPath",
+ sparkSession.sessionState.conf.caseSensitiveAnalysis)
+
+ val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(options)
+ val fs = outputPath.getFileSystem(hadoopConf)
+ val qualifiedOutputPath = outputPath.makeQualified(fs.getUri, fs.getWorkingDirectory)
+
+ val partitionsTrackedByCatalog = sparkSession.sessionState.conf.manageFilesourcePartitions &&
+ catalogTable.isDefined &&
+ catalogTable.get.partitionColumnNames.nonEmpty &&
+ catalogTable.get.tracksPartitionsInCatalog
+
+ var initialMatchingPartitions: Seq[TablePartitionSpec] = Nil
+ var customPartitionLocations: Map[TablePartitionSpec, String] = Map.empty
+ var matchingPartitions: Seq[CatalogTablePartition] = Seq.empty
+
+ // When partitions are tracked by the catalog, compute all custom partition locations that
+ // may be relevant to the insertion job.
+ if (partitionsTrackedByCatalog) {
+ matchingPartitions = sparkSession.sessionState.catalog
+ .listPartitions(catalogTable.get.identifier, Some(staticPartitions))
+ initialMatchingPartitions = matchingPartitions.map(_.spec)
+ customPartitionLocations =
+ getCustomPartitionLocations(fs, catalogTable.get, qualifiedOutputPath, matchingPartitions)
+ }
+
+ val jobId = java.util.UUID.randomUUID().toString
+ val committer = FileCommitProtocol.instantiate(
+ sparkSession.sessionState.conf.fileCommitProtocolClass,
+ jobId = jobId,
+ outputPath = outputPath.toString,
+ dynamicPartitionOverwrite = dynamicPartitionOverwrite)
+
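+    // Decide whether any data should be written. For a non-dynamic overwrite, matching
+    // partition directories are deleted up front; dynamic partition overwrite defers
+    // the cleanup to the commit protocol.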
+ val doInsertion = if (mode == SaveMode.Append) {
+ true
+ } else {
+ val pathExists = fs.exists(qualifiedOutputPath)
+ (mode, pathExists) match {
+ case (SaveMode.ErrorIfExists, true) =>
+ throw QueryCompilationErrors.outputPathAlreadyExistsError(qualifiedOutputPath)
+ case (SaveMode.Overwrite, true) =>
+ if (ifPartitionNotExists && matchingPartitions.nonEmpty) {
+ false
+ } else if (dynamicPartitionOverwrite) {
+ // For dynamic partition overwrite, do not delete partition directories ahead.
+ true
+ } else {
+ deleteMatchingPartitions(fs, qualifiedOutputPath, customPartitionLocations, committer)
+ true
+ }
+ case (SaveMode.Overwrite, _) | (SaveMode.ErrorIfExists, false) =>
+ true
+ case (SaveMode.Ignore, exists) =>
+ !exists
+ case (s, exists) =>
+ throw QueryExecutionErrors.saveModeUnsupportedError(s, exists)
+ }
+ }
+
+ if (doInsertion) {
+
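+      // Sync the catalog with the partitions touched by this write: register newly
+      // created partitions and, for a non-dynamic overwrite, drop partitions whose
+      // directories were already deleted.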
+ def refreshUpdatedPartitions(updatedPartitionPaths: Set[String]): Unit = {
+ val updatedPartitions = updatedPartitionPaths.map(PartitioningUtils.parsePathFragment)
+ if (partitionsTrackedByCatalog) {
+ val newPartitions = updatedPartitions -- initialMatchingPartitions
+ if (newPartitions.nonEmpty) {
+ AlterTableAddPartitionCommand(
+ catalogTable.get.identifier,
+ newPartitions.toSeq.map(p => (p, None)),
+ ifNotExists = true).run(sparkSession)
+ }
+ // For dynamic partition overwrite, we never remove partitions but only update existing
+ // ones.
+ if (mode == SaveMode.Overwrite && !dynamicPartitionOverwrite) {
+ val deletedPartitions = initialMatchingPartitions.toSet -- updatedPartitions
+ if (deletedPartitions.nonEmpty) {
+ AlterTableDropPartitionCommand(
+ catalogTable.get.identifier,
+ deletedPartitions.toSeq,
+ ifExists = true,
+ purge = false,
+ retainData = true /* already deleted */ ).run(sparkSession)
+ }
+ }
+ }
+ }
+
+      // For dynamic partition overwrite, FileOutputCommitter's output path is the staging
+      // path; files are renamed from the staging path to the final output path during job commit.
+ val committerOutputPath = if (dynamicPartitionOverwrite) {
+ FileCommitProtocol
+ .getStagingDir(outputPath.toString, jobId)
+ .makeQualified(fs.getUri, fs.getWorkingDirectory)
+ } else {
+ qualifiedOutputPath
+ }
+
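+      // Passing numStaticPartitionCols lets FileFormatWriter drop the leading static
+      // partition columns from the required sort ordering, since their values are fixed
+      // by the static partition spec.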
+ val updatedPartitionPaths =
+ FileFormatWriter.write(
+ sparkSession = sparkSession,
+ plan = child,
+ fileFormat = fileFormat,
+ committer = committer,
+ outputSpec = FileFormatWriter.OutputSpec(
+ committerOutputPath.toString,
+ customPartitionLocations,
+ outputColumns),
+ hadoopConf = hadoopConf,
+ partitionColumns = partitionColumns,
+ bucketSpec = bucketSpec,
+ statsTrackers = Seq(basicWriteJobStatsTracker(hadoopConf)),
+ options = options,
+ numStaticPartitionCols = staticPartitions.size
+ )
+
+ // update metastore partition metadata
+ if (
+ updatedPartitionPaths.isEmpty && staticPartitions.nonEmpty
+ && partitionColumns.length == staticPartitions.size
+ ) {
+        // Ensure an empty static partition can still be loaded into the datasource table.
+ val staticPathFragment =
+ PartitioningUtils.getPathFragment(staticPartitions, partitionColumns)
+ refreshUpdatedPartitions(Set(staticPathFragment))
+ } else {
+ refreshUpdatedPartitions(updatedPartitionPaths)
+ }
+
+ // refresh cached files in FileIndex
+ fileIndex.foreach(_.refresh())
+ // refresh data cache if table is cached
+ sparkSession.sharedState.cacheManager.recacheByPath(sparkSession, outputPath, fs)
+
+ if (catalogTable.nonEmpty) {
+ CommandUtils.updateTableStats(sparkSession, catalogTable.get)
+ }
+
+ } else {
+ logInfo("Skipping insertion into a relation that already exists.")
+ }
+
+ Seq.empty[Row]
+ }
+
+ /**
+ * Deletes all partition files that match the specified static prefix. Partitions with custom
+ * locations are also cleared based on the custom locations map given to this class.
+ */
+ private def deleteMatchingPartitions(
+ fs: FileSystem,
+ qualifiedOutputPath: Path,
+ customPartitionLocations: Map[TablePartitionSpec, String],
+ committer: FileCommitProtocol): Unit = {
+ val staticPartitionPrefix = if (staticPartitions.nonEmpty) {
+ "/" + partitionColumns
+ .flatMap(p => staticPartitions.get(p.name).map(getPartitionPathString(p.name, _)))
+ .mkString("/")
+ } else {
+ ""
+ }
+ // first clear the path determined by the static partition keys (e.g. /table/foo=1)
+ val staticPrefixPath = qualifiedOutputPath.suffix(staticPartitionPrefix)
+ if (fs.exists(staticPrefixPath) && !committer.deleteWithJob(fs, staticPrefixPath, true)) {
+ throw QueryExecutionErrors.cannotClearOutputDirectoryError(staticPrefixPath)
+ }
+ // now clear all custom partition locations (e.g. /custom/dir/where/foo=2/bar=4)
+ for ((spec, customLoc) <- customPartitionLocations) {
+ assert(
+ (staticPartitions.toSet -- spec).isEmpty,
+ "Custom partition location did not match static partitioning keys")
+ val path = new Path(customLoc)
+ if (fs.exists(path) && !committer.deleteWithJob(fs, path, true)) {
+ throw QueryExecutionErrors.cannotClearPartitionDirectoryError(path)
+ }
+ }
+ }
+
+ /**
+ * Given a set of input partitions, returns those that have locations that differ from the Hive
+ * default (e.g. /k1=v1/k2=v2). These partitions were manually assigned locations by the user.
+ *
+ * @return
+ * a mapping from partition specs to their custom locations
+ */
+ private def getCustomPartitionLocations(
+ fs: FileSystem,
+ table: CatalogTable,
+ qualifiedOutputPath: Path,
+ partitions: Seq[CatalogTablePartition]): Map[TablePartitionSpec, String] = {
+ partitions.flatMap {
+ p =>
+ val defaultLocation = qualifiedOutputPath
+ .suffix("/" + PartitioningUtils.getPathFragment(p.spec, table.partitionSchema))
+ .toString
+ val catalogLocation =
+ new Path(p.location).makeQualified(fs.getUri, fs.getWorkingDirectory).toString
+ if (catalogLocation != defaultLocation) {
+ Some(p.spec -> catalogLocation)
+ } else {
+ None
+ }
+ }.toMap
+ }
+
+ override protected def withNewChildInternal(
+ newChild: LogicalPlan): InsertIntoHadoopFsRelationCommand = copy(query = newChild)
+}
+
+// scalastyle:on line.size.limit