[GLUTEN-5414] [VL] Support datasource v2 scan csv (#5717)
jinchengchenghh authored May 16, 2024
1 parent 0cbb7f2 commit 6a110e5
Showing 21 changed files with 680 additions and 159 deletions.
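For context, here is a minimal sketch of how the DataSource V2 CSV scan path this commit wires up can be exercised from user code. It assumes a Spark session with the Gluten plugin already enabled; the app name and input path are placeholders. spark.sql.sources.useV1SourceList is a standard Spark setting, and dropping "csv" from it makes Spark plan CSV reads through the V2 source (a BatchScanExec) rather than the V1 FileSourceScanExec path.

import org.apache.spark.sql.SparkSession

object CsvV2ScanSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .appName("csv-v2-scan-sketch") // placeholder name
      // Dropping "csv" from the V1 source list routes CSV reads through DataSource V2.
      .config("spark.sql.sources.useV1SourceList", "avro,json,kafka,orc,parquet,text")
      .getOrCreate()

    val df = spark.read
      .option("header", "true")
      .csv("/tmp/example.csv") // placeholder path

    df.show()
    spark.stop()
  }
}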
@@ -21,6 +21,7 @@ import org.apache.gluten.backendsapi.SparkPlanExecApi
import org.apache.gluten.datasource.ArrowConvertorRule
import org.apache.gluten.exception.GlutenNotSupportException
import org.apache.gluten.execution._
import org.apache.gluten.execution.datasource.v2.ArrowBatchScanExec
import org.apache.gluten.expression._
import org.apache.gluten.expression.ConverterUtils.FunctionConfig
import org.apache.gluten.expression.aggregate.{HLLAdapter, VeloxBloomFilterAggregate, VeloxCollectList, VeloxCollectSet}
@@ -869,6 +870,7 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {

override def outputNativeColumnarSparkCompatibleData(plan: SparkPlan): Boolean = plan match {
case _: ArrowFileSourceScanExec => true
case _: ArrowBatchScanExec => true
case _ => false
}
}
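With ArrowBatchScanExec now reported as producing Spark-compatible columnar data, a hedged sketch of how one might confirm that a CSV query actually hits the Arrow-backed V2 scan is to search the executed plan. This collect-based check is illustrative only, not the repository's own test utility.

import org.apache.gluten.execution.datasource.v2.ArrowBatchScanExec
import org.apache.spark.sql.DataFrame

object PlanChecks {
  // True if the query's physical plan contains an Arrow-backed V2 batch scan.
  def usesArrowBatchScan(df: DataFrame): Boolean =
    df.queryExecution.executedPlan.collect { case s: ArrowBatchScanExec => s }.nonEmpty
}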
@@ -20,6 +20,7 @@ import org.apache.gluten.columnarbatch.ColumnarBatches
import org.apache.gluten.exception.SchemaMismatchException
import org.apache.gluten.execution.RowToVeloxColumnarExec
import org.apache.gluten.memory.arrow.alloc.ArrowBufferAllocators
import org.apache.gluten.memory.arrow.pool.ArrowNativeMemoryPool
import org.apache.gluten.utils.{ArrowUtil, Iterators}
import org.apache.gluten.vectorized.ArrowWritableColumnVector

@@ -41,6 +42,7 @@ import org.apache.spark.util.SerializableConfiguration

import org.apache.arrow.dataset.file.FileSystemDatasetFactory
import org.apache.arrow.dataset.scanner.ScanOptions
import org.apache.arrow.memory.BufferAllocator
import org.apache.arrow.vector.VectorUnloader
import org.apache.arrow.vector.types.pojo.Schema
import org.apache.hadoop.conf.Configuration
@@ -66,55 +68,127 @@ class ArrowCSVFileFormat extends FileFormat with DataSourceRegister with Logging
sparkSession: SparkSession,
options: Map[String, String],
files: Seq[FileStatus]): Option[StructType] = {
ArrowUtil.readSchema(files, fileFormat)
ArrowUtil.readSchema(
files,
fileFormat,
ArrowBufferAllocators.contextInstance(),
ArrowNativeMemoryPool.arrowPool("infer schema"))
}
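Schema inference now takes an explicit Arrow allocator and a native memory pool instead of relying on ambient state. As a minimal illustration of that discipline, using plain Arrow Java rather than Gluten's task-scoped ArrowBufferAllocators/ArrowNativeMemoryPool wrappers, the caller owns the allocator's lifetime:

import org.apache.arrow.memory.{BufferAllocator, RootAllocator}

object AllocatorSketch {
  // Run `f` with a freshly created allocator and always release it afterwards;
  // closing the allocator is what surfaces any leaked Arrow buffers.
  def withAllocator[T](f: BufferAllocator => T): T = {
    val allocator: BufferAllocator = new RootAllocator(Long.MaxValue)
    try f(allocator)
    finally allocator.close()
  }
}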

override def supportBatch(sparkSession: SparkSession, dataSchema: StructType): Boolean = true

private def checkHeader(
file: PartitionedFile,
override def buildReaderWithPartitionValues(
sparkSession: SparkSession,
dataSchema: StructType,
partitionSchema: StructType,
requiredSchema: StructType,
parsedOptions: CSVOptions,
actualFilters: Seq[Filter],
conf: Configuration): Unit = {
val isStartOfFile = file.start == 0
if (!isStartOfFile) {
return
}
val actualDataSchema = StructType(
dataSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord))
val actualRequiredSchema = StructType(
requiredSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord))
val parser =
new UnivocityParser(actualDataSchema, actualRequiredSchema, parsedOptions, actualFilters)
val schema = if (parsedOptions.columnPruning) actualRequiredSchema else actualDataSchema
val headerChecker = new CSVHeaderChecker(
schema,
parsedOptions,
source = s"CSV file: ${file.filePath}",
isStartOfFile)

val lines = {
val linesReader =
new HadoopFileLinesReader(file, parser.options.lineSeparatorInRead, conf)
Option(TaskContext.get())
.foreach(_.addTaskCompletionListener[Unit](_ => linesReader.close()))
linesReader.map {
line => new String(line.getBytes, 0, line.getLength, parser.options.charset)
filters: Seq[Filter],
options: Map[String, String],
hadoopConf: Configuration): PartitionedFile => Iterator[InternalRow] = {
val sqlConf = sparkSession.sessionState.conf
val broadcastedHadoopConf =
sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))
val batchSize = sqlConf.columnBatchSize
val caseSensitive = sqlConf.caseSensitiveAnalysis
val columnPruning = sqlConf.csvColumnPruning &&
!requiredSchema.exists(_.name == sparkSession.sessionState.conf.columnNameOfCorruptRecord)
val parsedOptions = new CSVOptions(
options,
columnPruning,
sparkSession.sessionState.conf.sessionLocalTimeZone,
sparkSession.sessionState.conf.columnNameOfCorruptRecord)
val actualFilters =
filters.filterNot(_.references.contains(parsedOptions.columnNameOfCorruptRecord))
(file: PartitionedFile) => {
ArrowCSVFileFormat.checkHeader(
file,
dataSchema,
requiredSchema,
parsedOptions,
actualFilters,
broadcastedHadoopConf.value.value)
val factory =
ArrowUtil.makeArrowDiscovery(
URLDecoder.decode(file.filePath.toString, "UTF-8"),
fileFormat,
ArrowBufferAllocators.contextInstance(),
ArrowNativeMemoryPool.arrowPool("FileSystemDatasetFactory")
)
// todo predicate validation / pushdown
val fileFields = factory.inspect().getFields.asScala
// TODO: support array/map/struct types in out-of-order schema reading.
try {
val actualReadFields =
ArrowUtil.getRequestedField(requiredSchema, fileFields, caseSensitive)
ArrowCSVFileFormat
.readArrow(
ArrowBufferAllocators.contextInstance(),
file,
actualReadFields,
caseSensitive,
requiredSchema,
partitionSchema,
factory,
batchSize)
.asInstanceOf[Iterator[InternalRow]]
} catch {
case e: SchemaMismatchException =>
logWarning(e.getMessage)
val iter = ArrowCSVFileFormat.fallbackReadVanilla(
dataSchema,
requiredSchema,
broadcastedHadoopConf.value.value,
parsedOptions,
file,
actualFilters,
columnPruning)
val (schema, rows) =
ArrowCSVFileFormat.withPartitionValue(requiredSchema, partitionSchema, iter, file)
ArrowCSVFileFormat
.rowToColumn(schema, batchSize, rows)
.asInstanceOf[Iterator[InternalRow]]
case d: Exception => throw d
}

}
CSVHeaderCheckerHelper.checkHeaderColumnNames(headerChecker, lines, parser.tokenizer)
}

private def readArrow(
override def vectorTypes(
requiredSchema: StructType,
partitionSchema: StructType,
sqlConf: SQLConf): Option[Seq[String]] = {
Option(
Seq.fill(requiredSchema.fields.length + partitionSchema.fields.length)(
classOf[ArrowWritableColumnVector].getName
))
}

override def shortName(): String = "arrowcsv"

override def hashCode(): Int = getClass.hashCode()

override def equals(other: Any): Boolean = other.isInstanceOf[ArrowCSVFileFormat]

override def prepareWrite(
sparkSession: SparkSession,
job: _root_.org.apache.hadoop.mapreduce.Job,
options: Map[String, String],
dataSchema: StructType): OutputWriterFactory = {
throw new UnsupportedOperationException()
}
}
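The reader built above follows an Arrow-first strategy: attempt the native Arrow dataset scan and, when the file's layout does not match the requested schema, fall back to Spark's vanilla row-based CSV parser followed by a row-to-columnar conversion. Below is a simplified sketch of that control flow, with hypothetical function parameters standing in for the concrete helpers.

import org.apache.gluten.exception.SchemaMismatchException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.vectorized.ColumnarBatch

object FallbackSketch {
  def readWithFallback(
      readArrowNative: () => Iterator[ColumnarBatch],
      readVanillaRows: () => Iterator[InternalRow],
      rowsToBatches: Iterator[InternalRow] => Iterator[ColumnarBatch]): Iterator[ColumnarBatch] =
    try {
      // Fast path: Arrow dataset scan producing columnar batches directly.
      readArrowNative()
    } catch {
      case _: SchemaMismatchException =>
        // Recoverable schema drift: parse rows with the vanilla CSV reader,
        // then batch them into columnar form.
        rowsToBatches(readVanillaRows())
    }
}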

object ArrowCSVFileFormat {

def readArrow(
allocator: BufferAllocator,
file: PartitionedFile,
actualReadFields: Schema,
caseSensitive: Boolean,
requiredSchema: StructType,
partitionSchema: StructType,
factory: FileSystemDatasetFactory,
batchSize: Int): Iterator[InternalRow] = {
batchSize: Int): Iterator[ColumnarBatch] = {
val compare = ArrowUtil.compareStringFunc(caseSensitive)
val actualReadFieldNames = actualReadFields.getFields.asScala.map(_.getName).toArray
val actualReadSchema = new StructType(
@@ -147,7 +221,9 @@ class ArrowCSVFileFormat extends FileFormat with DataSourceRegister with Logging
override def next: ColumnarBatch = {
val root = reader.getVectorSchemaRoot
val unloader = new VectorUnloader(root)

val batch = ArrowUtil.loadBatch(
allocator,
unloader.getRecordBatch,
actualReadSchema,
requiredSchema,
@@ -166,13 +242,48 @@ class ArrowCSVFileFormat extends FileFormat with DataSourceRegister with Logging
}
.recyclePayload(_.close())
.create()
.asInstanceOf[Iterator[InternalRow]]
}

private def rowToColumn(
def checkHeader(
file: PartitionedFile,
dataSchema: StructType,
requiredSchema: StructType,
parsedOptions: CSVOptions,
actualFilters: Seq[Filter],
conf: Configuration): Unit = {
val isStartOfFile = file.start == 0
if (!isStartOfFile) {
return
}
val actualDataSchema = StructType(
dataSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord))
val actualRequiredSchema = StructType(
requiredSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord))
val parser =
new UnivocityParser(actualDataSchema, actualRequiredSchema, parsedOptions, actualFilters)
val schema = if (parsedOptions.columnPruning) actualRequiredSchema else actualDataSchema
val headerChecker = new CSVHeaderChecker(
schema,
parsedOptions,
source = s"CSV file: ${file.filePath}",
isStartOfFile)

val lines = {
val linesReader =
new HadoopFileLinesReader(file, parser.options.lineSeparatorInRead, conf)
Option(TaskContext.get())
.foreach(_.addTaskCompletionListener[Unit](_ => linesReader.close()))
linesReader.map {
line => new String(line.getBytes, 0, line.getLength, parser.options.charset)
}
}
CSVHeaderCheckerHelper.checkHeaderColumnNames(headerChecker, lines, parser.tokenizer)
}

def rowToColumn(
schema: StructType,
batchSize: Int,
it: Iterator[InternalRow]): Iterator[InternalRow] = {
it: Iterator[InternalRow]): Iterator[ColumnarBatch] = {
// note, these metrics are unused but just make `RowToVeloxColumnarExec` happy
val numInputRows = new SQLMetric("numInputRows")
val numOutputBatches = new SQLMetric("numOutputBatches")
@@ -187,7 +298,6 @@ class ArrowCSVFileFormat extends FileFormat with DataSourceRegister with Logging
)
veloxBatch
.map(v => ColumnarBatches.ensureLoaded(ArrowBufferAllocators.contextInstance(), v))
.asInstanceOf[Iterator[InternalRow]]
}

private def toAttribute(field: StructField): AttributeReference =
@@ -197,7 +307,7 @@ class ArrowCSVFileFormat extends FileFormat with DataSourceRegister with Logging
schema.map(toAttribute)
}

private def withPartitionValue(
def withPartitionValue(
requiredSchema: StructType,
partitionSchema: StructType,
iter: Iterator[InternalRow],
@@ -223,7 +333,7 @@ class ArrowCSVFileFormat extends FileFormat with DataSourceRegister with Logging
}
}

private def fallbackReadVanilla(
def fallbackReadVanilla(
dataSchema: StructType,
requiredSchema: StructType,
conf: Configuration,
@@ -246,93 +356,4 @@ class ArrowCSVFileFormat extends FileFormat with DataSourceRegister with Logging
isStartOfFile)
CSVDataSource(parsedOptions).readFile(conf, file, parser, headerChecker, requiredSchema)
}

override def buildReaderWithPartitionValues(
sparkSession: SparkSession,
dataSchema: StructType,
partitionSchema: StructType,
requiredSchema: StructType,
filters: Seq[Filter],
options: Map[String, String],
hadoopConf: Configuration): PartitionedFile => Iterator[InternalRow] = {
val sqlConf = sparkSession.sessionState.conf
val broadcastedHadoopConf =
sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))
val batchSize = sqlConf.columnBatchSize
val caseSensitive = sqlConf.caseSensitiveAnalysis
val columnPruning = sqlConf.csvColumnPruning &&
!requiredSchema.exists(_.name == sparkSession.sessionState.conf.columnNameOfCorruptRecord)
val parsedOptions = new CSVOptions(
options,
columnPruning,
sparkSession.sessionState.conf.sessionLocalTimeZone,
sparkSession.sessionState.conf.columnNameOfCorruptRecord)
val actualFilters =
filters.filterNot(_.references.contains(parsedOptions.columnNameOfCorruptRecord))
(file: PartitionedFile) => {
checkHeader(
file,
dataSchema,
requiredSchema,
parsedOptions,
actualFilters,
broadcastedHadoopConf.value.value)
val factory =
ArrowUtil.makeArrowDiscovery(URLDecoder.decode(file.filePath.toString, "UTF-8"), fileFormat)
// todo predicate validation / pushdown
val fileFields = factory.inspect().getFields.asScala
// TODO: support array/map/struct types in out-of-order schema reading.
try {
val actualReadFields =
ArrowUtil.getRequestedField(requiredSchema, fileFields, caseSensitive)
readArrow(
file,
actualReadFields,
caseSensitive,
requiredSchema,
partitionSchema,
factory,
batchSize)
} catch {
case e: SchemaMismatchException =>
logWarning(e.getMessage)
val iter = fallbackReadVanilla(
dataSchema,
requiredSchema,
broadcastedHadoopConf.value.value,
parsedOptions,
file,
actualFilters,
columnPruning)
val (schema, rows) = withPartitionValue(requiredSchema, partitionSchema, iter, file)
rowToColumn(schema, batchSize, rows)
case d: Exception => throw d
}

}
}

override def vectorTypes(
requiredSchema: StructType,
partitionSchema: StructType,
sqlConf: SQLConf): Option[Seq[String]] = {
Option(
Seq.fill(requiredSchema.fields.length + partitionSchema.fields.length)(
classOf[ArrowWritableColumnVector].getName
))
}

override def shortName(): String = "arrowcsv"

override def hashCode(): Int = getClass.hashCode()

override def equals(other: Any): Boolean = other.isInstanceOf[ArrowCSVFileFormat]

override def prepareWrite(
sparkSession: SparkSession,
job: _root_.org.apache.hadoop.mapreduce.Job,
options: Map[String, String],
dataSchema: StructType): OutputWriterFactory = {
throw new UnsupportedOperationException()
}
}