diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala index 33ce1ee72550..f54bf9b3f61e 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala @@ -21,6 +21,7 @@ import org.apache.gluten.backendsapi.SparkPlanExecApi import org.apache.gluten.datasource.ArrowConvertorRule import org.apache.gluten.exception.GlutenNotSupportException import org.apache.gluten.execution._ +import org.apache.gluten.execution.datasource.v2.ArrowBatchScanExec import org.apache.gluten.expression._ import org.apache.gluten.expression.ConverterUtils.FunctionConfig import org.apache.gluten.expression.aggregate.{HLLAdapter, VeloxBloomFilterAggregate, VeloxCollectList, VeloxCollectSet} @@ -869,6 +870,7 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi { override def outputNativeColumnarSparkCompatibleData(plan: SparkPlan): Boolean = plan match { case _: ArrowFileSourceScanExec => true + case _: ArrowBatchScanExec => true case _ => false } } diff --git a/backends-velox/src/main/scala/org/apache/gluten/datasource/ArrowCSVFileFormat.scala b/backends-velox/src/main/scala/org/apache/gluten/datasource/ArrowCSVFileFormat.scala index c05af24ff611..0f6813d8fc6a 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/datasource/ArrowCSVFileFormat.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/datasource/ArrowCSVFileFormat.scala @@ -20,6 +20,7 @@ import org.apache.gluten.columnarbatch.ColumnarBatches import org.apache.gluten.exception.SchemaMismatchException import org.apache.gluten.execution.RowToVeloxColumnarExec import org.apache.gluten.memory.arrow.alloc.ArrowBufferAllocators +import org.apache.gluten.memory.arrow.pool.ArrowNativeMemoryPool import org.apache.gluten.utils.{ArrowUtil, Iterators} import org.apache.gluten.vectorized.ArrowWritableColumnVector @@ -41,6 +42,7 @@ import org.apache.spark.util.SerializableConfiguration import org.apache.arrow.dataset.file.FileSystemDatasetFactory import org.apache.arrow.dataset.scanner.ScanOptions +import org.apache.arrow.memory.BufferAllocator import org.apache.arrow.vector.VectorUnloader import org.apache.arrow.vector.types.pojo.Schema import org.apache.hadoop.conf.Configuration @@ -66,55 +68,127 @@ class ArrowCSVFileFormat extends FileFormat with DataSourceRegister with Logging sparkSession: SparkSession, options: Map[String, String], files: Seq[FileStatus]): Option[StructType] = { - ArrowUtil.readSchema(files, fileFormat) + ArrowUtil.readSchema( + files, + fileFormat, + ArrowBufferAllocators.contextInstance(), + ArrowNativeMemoryPool.arrowPool("infer schema")) } override def supportBatch(sparkSession: SparkSession, dataSchema: StructType): Boolean = true - private def checkHeader( - file: PartitionedFile, + override def buildReaderWithPartitionValues( + sparkSession: SparkSession, dataSchema: StructType, + partitionSchema: StructType, requiredSchema: StructType, - parsedOptions: CSVOptions, - actualFilters: Seq[Filter], - conf: Configuration): Unit = { - val isStartOfFile = file.start == 0 - if (!isStartOfFile) { - return - } - val actualDataSchema = StructType( - dataSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)) - val actualRequiredSchema = StructType( - requiredSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)) - val parser 
= - new UnivocityParser(actualDataSchema, actualRequiredSchema, parsedOptions, actualFilters) - val schema = if (parsedOptions.columnPruning) actualRequiredSchema else actualDataSchema - val headerChecker = new CSVHeaderChecker( - schema, - parsedOptions, - source = s"CSV file: ${file.filePath}", - isStartOfFile) - - val lines = { - val linesReader = - new HadoopFileLinesReader(file, parser.options.lineSeparatorInRead, conf) - Option(TaskContext.get()) - .foreach(_.addTaskCompletionListener[Unit](_ => linesReader.close())) - linesReader.map { - line => new String(line.getBytes, 0, line.getLength, parser.options.charset) + filters: Seq[Filter], + options: Map[String, String], + hadoopConf: Configuration): PartitionedFile => Iterator[InternalRow] = { + val sqlConf = sparkSession.sessionState.conf + val broadcastedHadoopConf = + sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) + val batchSize = sqlConf.columnBatchSize + val caseSensitive = sqlConf.caseSensitiveAnalysis + val columnPruning = sqlConf.csvColumnPruning && + !requiredSchema.exists(_.name == sparkSession.sessionState.conf.columnNameOfCorruptRecord) + val parsedOptions = new CSVOptions( + options, + columnPruning, + sparkSession.sessionState.conf.sessionLocalTimeZone, + sparkSession.sessionState.conf.columnNameOfCorruptRecord) + val actualFilters = + filters.filterNot(_.references.contains(parsedOptions.columnNameOfCorruptRecord)) + (file: PartitionedFile) => { + ArrowCSVFileFormat.checkHeader( + file, + dataSchema, + requiredSchema, + parsedOptions, + actualFilters, + broadcastedHadoopConf.value.value) + val factory = + ArrowUtil.makeArrowDiscovery( + URLDecoder.decode(file.filePath.toString, "UTF-8"), + fileFormat, + ArrowBufferAllocators.contextInstance(), + ArrowNativeMemoryPool.arrowPool("FileSystemDatasetFactory") + ) + // todo predicate validation / pushdown + val fileFields = factory.inspect().getFields.asScala + // TODO: support array/map/struct types in out-of-order schema reading. 
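+      // Try the native Arrow CSV read first. If the requested fields cannot be matched
+      // against the file schema (SchemaMismatchException), fall back to the vanilla
+      // UnivocityParser-based read and convert the returned rows back to columnar batches.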
+ try { + val actualReadFields = + ArrowUtil.getRequestedField(requiredSchema, fileFields, caseSensitive) + ArrowCSVFileFormat + .readArrow( + ArrowBufferAllocators.contextInstance(), + file, + actualReadFields, + caseSensitive, + requiredSchema, + partitionSchema, + factory, + batchSize) + .asInstanceOf[Iterator[InternalRow]] + } catch { + case e: SchemaMismatchException => + logWarning(e.getMessage) + val iter = ArrowCSVFileFormat.fallbackReadVanilla( + dataSchema, + requiredSchema, + broadcastedHadoopConf.value.value, + parsedOptions, + file, + actualFilters, + columnPruning) + val (schema, rows) = + ArrowCSVFileFormat.withPartitionValue(requiredSchema, partitionSchema, iter, file) + ArrowCSVFileFormat + .rowToColumn(schema, batchSize, rows) + .asInstanceOf[Iterator[InternalRow]] + case d: Exception => throw d } + } - CSVHeaderCheckerHelper.checkHeaderColumnNames(headerChecker, lines, parser.tokenizer) } - private def readArrow( + override def vectorTypes( + requiredSchema: StructType, + partitionSchema: StructType, + sqlConf: SQLConf): Option[Seq[String]] = { + Option( + Seq.fill(requiredSchema.fields.length + partitionSchema.fields.length)( + classOf[ArrowWritableColumnVector].getName + )) + } + + override def shortName(): String = "arrowcsv" + + override def hashCode(): Int = getClass.hashCode() + + override def equals(other: Any): Boolean = other.isInstanceOf[ArrowCSVFileFormat] + + override def prepareWrite( + sparkSession: SparkSession, + job: _root_.org.apache.hadoop.mapreduce.Job, + options: Map[String, String], + dataSchema: StructType): OutputWriterFactory = { + throw new UnsupportedOperationException() + } +} + +object ArrowCSVFileFormat { + + def readArrow( + allocator: BufferAllocator, file: PartitionedFile, actualReadFields: Schema, caseSensitive: Boolean, requiredSchema: StructType, partitionSchema: StructType, factory: FileSystemDatasetFactory, - batchSize: Int): Iterator[InternalRow] = { + batchSize: Int): Iterator[ColumnarBatch] = { val compare = ArrowUtil.compareStringFunc(caseSensitive) val actualReadFieldNames = actualReadFields.getFields.asScala.map(_.getName).toArray val actualReadSchema = new StructType( @@ -147,7 +221,9 @@ class ArrowCSVFileFormat extends FileFormat with DataSourceRegister with Logging override def next: ColumnarBatch = { val root = reader.getVectorSchemaRoot val unloader = new VectorUnloader(root) + val batch = ArrowUtil.loadBatch( + allocator, unloader.getRecordBatch, actualReadSchema, requiredSchema, @@ -166,13 +242,48 @@ class ArrowCSVFileFormat extends FileFormat with DataSourceRegister with Logging } .recyclePayload(_.close()) .create() - .asInstanceOf[Iterator[InternalRow]] } - private def rowToColumn( + def checkHeader( + file: PartitionedFile, + dataSchema: StructType, + requiredSchema: StructType, + parsedOptions: CSVOptions, + actualFilters: Seq[Filter], + conf: Configuration): Unit = { + val isStartOfFile = file.start == 0 + if (!isStartOfFile) { + return + } + val actualDataSchema = StructType( + dataSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)) + val actualRequiredSchema = StructType( + requiredSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)) + val parser = + new UnivocityParser(actualDataSchema, actualRequiredSchema, parsedOptions, actualFilters) + val schema = if (parsedOptions.columnPruning) actualRequiredSchema else actualDataSchema + val headerChecker = new CSVHeaderChecker( + schema, + parsedOptions, + source = s"CSV file: ${file.filePath}", + isStartOfFile) + + val lines = { + val 
linesReader = + new HadoopFileLinesReader(file, parser.options.lineSeparatorInRead, conf) + Option(TaskContext.get()) + .foreach(_.addTaskCompletionListener[Unit](_ => linesReader.close())) + linesReader.map { + line => new String(line.getBytes, 0, line.getLength, parser.options.charset) + } + } + CSVHeaderCheckerHelper.checkHeaderColumnNames(headerChecker, lines, parser.tokenizer) + } + + def rowToColumn( schema: StructType, batchSize: Int, - it: Iterator[InternalRow]): Iterator[InternalRow] = { + it: Iterator[InternalRow]): Iterator[ColumnarBatch] = { // note, these metrics are unused but just make `RowToVeloxColumnarExec` happy val numInputRows = new SQLMetric("numInputRows") val numOutputBatches = new SQLMetric("numOutputBatches") @@ -187,7 +298,6 @@ class ArrowCSVFileFormat extends FileFormat with DataSourceRegister with Logging ) veloxBatch .map(v => ColumnarBatches.ensureLoaded(ArrowBufferAllocators.contextInstance(), v)) - .asInstanceOf[Iterator[InternalRow]] } private def toAttribute(field: StructField): AttributeReference = @@ -197,7 +307,7 @@ class ArrowCSVFileFormat extends FileFormat with DataSourceRegister with Logging schema.map(toAttribute) } - private def withPartitionValue( + def withPartitionValue( requiredSchema: StructType, partitionSchema: StructType, iter: Iterator[InternalRow], @@ -223,7 +333,7 @@ class ArrowCSVFileFormat extends FileFormat with DataSourceRegister with Logging } } - private def fallbackReadVanilla( + def fallbackReadVanilla( dataSchema: StructType, requiredSchema: StructType, conf: Configuration, @@ -246,93 +356,4 @@ class ArrowCSVFileFormat extends FileFormat with DataSourceRegister with Logging isStartOfFile) CSVDataSource(parsedOptions).readFile(conf, file, parser, headerChecker, requiredSchema) } - - override def buildReaderWithPartitionValues( - sparkSession: SparkSession, - dataSchema: StructType, - partitionSchema: StructType, - requiredSchema: StructType, - filters: Seq[Filter], - options: Map[String, String], - hadoopConf: Configuration): PartitionedFile => Iterator[InternalRow] = { - val sqlConf = sparkSession.sessionState.conf - val broadcastedHadoopConf = - sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) - val batchSize = sqlConf.columnBatchSize - val caseSensitive = sqlConf.caseSensitiveAnalysis - val columnPruning = sqlConf.csvColumnPruning && - !requiredSchema.exists(_.name == sparkSession.sessionState.conf.columnNameOfCorruptRecord) - val parsedOptions = new CSVOptions( - options, - columnPruning, - sparkSession.sessionState.conf.sessionLocalTimeZone, - sparkSession.sessionState.conf.columnNameOfCorruptRecord) - val actualFilters = - filters.filterNot(_.references.contains(parsedOptions.columnNameOfCorruptRecord)) - (file: PartitionedFile) => { - checkHeader( - file, - dataSchema, - requiredSchema, - parsedOptions, - actualFilters, - broadcastedHadoopConf.value.value) - val factory = - ArrowUtil.makeArrowDiscovery(URLDecoder.decode(file.filePath.toString, "UTF-8"), fileFormat) - // todo predicate validation / pushdown - val fileFields = factory.inspect().getFields.asScala - // TODO: support array/map/struct types in out-of-order schema reading. 
- try { - val actualReadFields = - ArrowUtil.getRequestedField(requiredSchema, fileFields, caseSensitive) - readArrow( - file, - actualReadFields, - caseSensitive, - requiredSchema, - partitionSchema, - factory, - batchSize) - } catch { - case e: SchemaMismatchException => - logWarning(e.getMessage) - val iter = fallbackReadVanilla( - dataSchema, - requiredSchema, - broadcastedHadoopConf.value.value, - parsedOptions, - file, - actualFilters, - columnPruning) - val (schema, rows) = withPartitionValue(requiredSchema, partitionSchema, iter, file) - rowToColumn(schema, batchSize, rows) - case d: Exception => throw d - } - - } - } - - override def vectorTypes( - requiredSchema: StructType, - partitionSchema: StructType, - sqlConf: SQLConf): Option[Seq[String]] = { - Option( - Seq.fill(requiredSchema.fields.length + partitionSchema.fields.length)( - classOf[ArrowWritableColumnVector].getName - )) - } - - override def shortName(): String = "arrowcsv" - - override def hashCode(): Int = getClass.hashCode() - - override def equals(other: Any): Boolean = other.isInstanceOf[ArrowCSVFileFormat] - - override def prepareWrite( - sparkSession: SparkSession, - job: _root_.org.apache.hadoop.mapreduce.Job, - options: Map[String, String], - dataSchema: StructType): OutputWriterFactory = { - throw new UnsupportedOperationException() - } } diff --git a/backends-velox/src/main/scala/org/apache/gluten/datasource/ArrowConvertorRule.scala b/backends-velox/src/main/scala/org/apache/gluten/datasource/ArrowConvertorRule.scala index e29313a3809e..dab1ffd3b9e3 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/datasource/ArrowConvertorRule.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/datasource/ArrowConvertorRule.scala @@ -17,6 +17,7 @@ package org.apache.gluten.datasource import org.apache.gluten.backendsapi.BackendsApiManager +import org.apache.gluten.datasource.v2.ArrowCSVTable import org.apache.gluten.sql.shims.SparkShimLoader import org.apache.spark.annotation.Experimental @@ -27,11 +28,15 @@ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.util.PermissiveMode import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.execution.datasources.v2.csv.CSVTable import org.apache.spark.sql.types.StructType import org.apache.spark.sql.utils.SparkSchemaUtil import java.nio.charset.StandardCharsets +import scala.collection.convert.ImplicitConversions.`map AsScala` + @Experimental case class ArrowConvertorRule(session: SparkSession) extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = { @@ -39,27 +44,49 @@ case class ArrowConvertorRule(session: SparkSession) extends Rule[LogicalPlan] { return plan } plan.resolveOperators { - // Read path case l @ LogicalRelation( r @ HadoopFsRelation(_, _, dataSchema, _, _: CSVFileFormat, options), _, _, - _) => - val csvOptions = new CSVOptions( + _) if validate(session, dataSchema, options) => + l.copy(relation = r.copy(fileFormat = new ArrowCSVFileFormat())(session)) + case d @ DataSourceV2Relation( + t @ CSVTable( + name, + sparkSession, + options, + paths, + userSpecifiedSchema, + fallbackFileFormat), + _, + _, + _, + _) if validate(session, t.dataSchema, options.asCaseSensitiveMap().toMap) => + d.copy(table = ArrowCSVTable( + "arrow" + name, + sparkSession, options, - columnPruning = 
session.sessionState.conf.csvColumnPruning, - session.sessionState.conf.sessionLocalTimeZone) - if ( - checkSchema(dataSchema) && - checkCsvOptions(csvOptions, session.sessionState.conf.sessionLocalTimeZone) - ) { - l.copy(relation = r.copy(fileFormat = new ArrowCSVFileFormat())(session)) - } else l + paths, + userSpecifiedSchema, + fallbackFileFormat)) case r => r } } + private def validate( + session: SparkSession, + dataSchema: StructType, + options: Map[String, String]): Boolean = { + val csvOptions = new CSVOptions( + options, + columnPruning = session.sessionState.conf.csvColumnPruning, + session.sessionState.conf.sessionLocalTimeZone) + checkSchema(dataSchema) && + checkCsvOptions(csvOptions, session.sessionState.conf.sessionLocalTimeZone) && + dataSchema.nonEmpty + } + private def checkCsvOptions(csvOptions: CSVOptions, timeZone: String): Boolean = { csvOptions.headerFlag && !csvOptions.multiLine && csvOptions.delimiter == "," && csvOptions.quote == '\"' && diff --git a/backends-velox/src/main/scala/org/apache/gluten/datasource/v2/ArrowCSVPartitionReaderFactory.scala b/backends-velox/src/main/scala/org/apache/gluten/datasource/v2/ArrowCSVPartitionReaderFactory.scala new file mode 100644 index 000000000000..ddc7f797fb93 --- /dev/null +++ b/backends-velox/src/main/scala/org/apache/gluten/datasource/v2/ArrowCSVPartitionReaderFactory.scala @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.gluten.datasource.v2 + +import org.apache.gluten.datasource.ArrowCSVFileFormat +import org.apache.gluten.exception.SchemaMismatchException +import org.apache.gluten.memory.arrow.alloc.ArrowBufferAllocators +import org.apache.gluten.memory.arrow.pool.ArrowNativeMemoryPool +import org.apache.gluten.utils.ArrowUtil + +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.csv.CSVOptions +import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader} +import org.apache.spark.sql.execution.datasources.PartitionedFile +import org.apache.spark.sql.execution.datasources.v2.FilePartitionReaderFactory +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.util.{SerializableConfiguration, TaskResources} + +import org.apache.arrow.dataset.file.FileFormat + +import java.net.URLDecoder + +import scala.collection.JavaConverters.asScalaBufferConverter + +case class ArrowCSVPartitionReaderFactory( + sqlConf: SQLConf, + broadcastedConf: Broadcast[SerializableConfiguration], + dataSchema: StructType, + readDataSchema: StructType, + readPartitionSchema: StructType, + options: CSVOptions, + filters: Seq[Filter]) + extends FilePartitionReaderFactory + with Logging { + + private val batchSize = sqlConf.parquetVectorizedReaderBatchSize + private val caseSensitive: Boolean = sqlConf.caseSensitiveAnalysis + private val csvColumnPruning: Boolean = sqlConf.csvColumnPruning + + override def supportColumnarReads(partition: InputPartition): Boolean = true + + override def buildReader(partitionedFile: PartitionedFile): PartitionReader[InternalRow] = { + // disable row based read + throw new UnsupportedOperationException + } + + override def buildColumnarReader( + partitionedFile: PartitionedFile): PartitionReader[ColumnarBatch] = { + val actualDataSchema = StructType( + dataSchema.filterNot(_.name == options.columnNameOfCorruptRecord)) + val actualReadDataSchema = StructType( + readDataSchema.filterNot(_.name == options.columnNameOfCorruptRecord)) + ArrowCSVFileFormat.checkHeader( + partitionedFile, + actualDataSchema, + actualReadDataSchema, + options, + filters, + broadcastedConf.value.value) + val (allocator, pool) = if (!TaskResources.inSparkTask()) { + TaskResources.runUnsafe( + ( + ArrowBufferAllocators.contextInstance(), + ArrowNativeMemoryPool.arrowPool("FileSystemFactory")) + ) + } else { + ( + ArrowBufferAllocators.contextInstance(), + ArrowNativeMemoryPool.arrowPool("FileSystemFactory")) + } + val factory = ArrowUtil.makeArrowDiscovery( + URLDecoder.decode(partitionedFile.filePath.toString(), "UTF-8"), + FileFormat.CSV, + allocator, + pool) + val parquetFileFields = factory.inspect().getFields.asScala + // TODO: support array/map/struct types in out-of-order schema reading. 
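+    // Mirror ArrowCSVFileFormat: attempt the native Arrow read and, on a schema
+    // mismatch, fall back to the vanilla CSV parser followed by row-to-columnar
+    // conversion of the returned rows.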
+ val iter = + try { + val actualReadFields = + ArrowUtil.getRequestedField(readDataSchema, parquetFileFields, caseSensitive) + ArrowCSVFileFormat.readArrow( + allocator, + partitionedFile, + actualReadFields, + caseSensitive, + readDataSchema, + readPartitionSchema, + factory, + batchSize) + } catch { + case e: SchemaMismatchException => + logWarning(e.getMessage) + val iter = ArrowCSVFileFormat.fallbackReadVanilla( + dataSchema, + readDataSchema, + broadcastedConf.value.value, + options, + partitionedFile, + filters, + csvColumnPruning) + val (schema, rows) = ArrowCSVFileFormat.withPartitionValue( + readDataSchema, + readPartitionSchema, + iter, + partitionedFile) + ArrowCSVFileFormat.rowToColumn(schema, batchSize, rows) + case d: Exception => throw d + } + + new PartitionReader[ColumnarBatch] { + + override def next(): Boolean = { + iter.hasNext + } + + override def get(): ColumnarBatch = { + iter.next() + } + + override def close(): Unit = {} + } + } + +} diff --git a/backends-velox/src/main/scala/org/apache/gluten/datasource/v2/ArrowCSVScan.scala b/backends-velox/src/main/scala/org/apache/gluten/datasource/v2/ArrowCSVScan.scala new file mode 100644 index 000000000000..ce3f84770464 --- /dev/null +++ b/backends-velox/src/main/scala/org/apache/gluten/datasource/v2/ArrowCSVScan.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.gluten.datasource.v2 + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.csv.CSVOptions +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.connector.read.PartitionReaderFactory +import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex +import org.apache.spark.sql.execution.datasources.v2.FileScan +import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.util.SerializableConfiguration + +import org.apache.hadoop.fs.Path + +import scala.collection.JavaConverters.mapAsScalaMapConverter + +case class ArrowCSVScan( + sparkSession: SparkSession, + fileIndex: PartitioningAwareFileIndex, + dataSchema: StructType, + readDataSchema: StructType, + readPartitionSchema: StructType, + pushedFilters: Array[Filter], + options: CaseInsensitiveStringMap, + partitionFilters: Seq[Expression] = Seq.empty, + dataFilters: Seq[Expression] = Seq.empty) + extends FileScan { + + private lazy val parsedOptions: CSVOptions = new CSVOptions( + options.asScala.toMap, + columnPruning = sparkSession.sessionState.conf.csvColumnPruning, + sparkSession.sessionState.conf.sessionLocalTimeZone, + sparkSession.sessionState.conf.columnNameOfCorruptRecord + ) + + override def isSplitable(path: Path): Boolean = { + false + } + + override def createReaderFactory(): PartitionReaderFactory = { + val caseSensitiveMap = options.asCaseSensitiveMap().asScala.toMap + val hconf = sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap) + val broadcastedConf = + sparkSession.sparkContext.broadcast(new SerializableConfiguration(hconf)) + val actualFilters = + pushedFilters.filterNot(_.references.contains(parsedOptions.columnNameOfCorruptRecord)) + ArrowCSVPartitionReaderFactory( + sparkSession.sessionState.conf, + broadcastedConf, + dataSchema, + readDataSchema, + readPartitionSchema, + parsedOptions, + actualFilters) + } + + def withFilters(partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan = + this.copy(partitionFilters = partitionFilters, dataFilters = dataFilters) +} diff --git a/backends-velox/src/main/scala/org/apache/gluten/datasource/v2/ArrowCSVScanBuilder.scala b/backends-velox/src/main/scala/org/apache/gluten/datasource/v2/ArrowCSVScanBuilder.scala new file mode 100644 index 000000000000..2b3991fe2984 --- /dev/null +++ b/backends-velox/src/main/scala/org/apache/gluten/datasource/v2/ArrowCSVScanBuilder.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.gluten.datasource.v2 + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.connector.read.Scan +import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex +import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +case class ArrowCSVScanBuilder( + sparkSession: SparkSession, + fileIndex: PartitioningAwareFileIndex, + schema: StructType, + dataSchema: StructType, + options: CaseInsensitiveStringMap) + extends FileScanBuilder(sparkSession, fileIndex, dataSchema) { + + override def build(): Scan = { + ArrowCSVScan( + sparkSession, + fileIndex, + dataSchema, + readDataSchema(), + readPartitionSchema(), + Array.empty, + options) + } +} diff --git a/backends-velox/src/main/scala/org/apache/gluten/datasource/v2/ArrowCSVTable.scala b/backends-velox/src/main/scala/org/apache/gluten/datasource/v2/ArrowCSVTable.scala new file mode 100644 index 000000000000..aa7f737f9cfc --- /dev/null +++ b/backends-velox/src/main/scala/org/apache/gluten/datasource/v2/ArrowCSVTable.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.gluten.datasource.v2 + +import org.apache.gluten.memory.arrow.alloc.ArrowBufferAllocators +import org.apache.gluten.memory.arrow.pool.ArrowNativeMemoryPool +import org.apache.gluten.utils.ArrowUtil + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.connector.read.ScanBuilder +import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder} +import org.apache.spark.sql.execution.datasources.FileFormat +import org.apache.spark.sql.execution.datasources.v2.FileTable +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.util.TaskResources + +import org.apache.hadoop.fs.FileStatus + +case class ArrowCSVTable( + name: String, + sparkSession: SparkSession, + options: CaseInsensitiveStringMap, + paths: Seq[String], + userSpecifiedSchema: Option[StructType], + fallbackFileFormat: Class[_ <: FileFormat]) + extends FileTable(sparkSession, options, paths, userSpecifiedSchema) { + + override def inferSchema(files: Seq[FileStatus]): Option[StructType] = { + val (allocator, pool) = if (!TaskResources.inSparkTask()) { + TaskResources.runUnsafe( + (ArrowBufferAllocators.contextInstance(), ArrowNativeMemoryPool.arrowPool("inferSchema")) + ) + } else { + (ArrowBufferAllocators.contextInstance(), ArrowNativeMemoryPool.arrowPool("inferSchema")) + } + ArrowUtil.readSchema( + files.head, + org.apache.arrow.dataset.file.FileFormat.CSV, + allocator, + pool + ) + } + + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { + ArrowCSVScanBuilder(sparkSession, fileIndex, schema, dataSchema, options) + } + + override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = { + throw new UnsupportedOperationException + } + + override def formatName: String = "arrowcsv" +} diff --git a/backends-velox/src/main/scala/org/apache/gluten/execution/datasource/v2/ArrowBatchScanExec.scala b/backends-velox/src/main/scala/org/apache/gluten/execution/datasource/v2/ArrowBatchScanExec.scala new file mode 100644 index 000000000000..3c1c538207c5 --- /dev/null +++ b/backends-velox/src/main/scala/org/apache/gluten/execution/datasource/v2/ArrowBatchScanExec.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.gluten.execution.datasource.v2 + +import org.apache.gluten.extension.GlutenPlan + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.physical.Partitioning +import org.apache.spark.sql.connector.read.{Batch, PartitionReaderFactory, Scan} +import org.apache.spark.sql.execution.datasources.v2.{ArrowBatchScanExecShim, BatchScanExec} + +case class ArrowBatchScanExec(original: BatchScanExec) + extends ArrowBatchScanExecShim(original) + with GlutenPlan { + + @transient lazy val batch: Batch = original.batch + + override lazy val readerFactory: PartitionReaderFactory = original.readerFactory + + override lazy val inputRDD: RDD[InternalRow] = original.inputRDD + + override def outputPartitioning: Partitioning = original.outputPartitioning + + override def scan: Scan = original.scan + + override def doCanonicalize(): ArrowBatchScanExec = + this.copy(original = original.doCanonicalize()) + + override def nodeName: String = "Arrow" + original.nodeName + + override def output: Seq[Attribute] = original.output +} diff --git a/backends-velox/src/main/scala/org/apache/gluten/extension/ArrowScanReplaceRule.scala b/backends-velox/src/main/scala/org/apache/gluten/extension/ArrowScanReplaceRule.scala index 2b7c4b1da91b..adfc6ca742c9 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/extension/ArrowScanReplaceRule.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/extension/ArrowScanReplaceRule.scala @@ -17,18 +17,23 @@ package org.apache.gluten.extension import org.apache.gluten.datasource.ArrowCSVFileFormat +import org.apache.gluten.datasource.v2.ArrowCSVScan +import org.apache.gluten.execution.datasource.v2.ArrowBatchScanExec import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{ArrowFileSourceScanExec, FileSourceScanExec, SparkPlan} +import org.apache.spark.sql.execution.datasources.v2.BatchScanExec case class ArrowScanReplaceRule(spark: SparkSession) extends Rule[SparkPlan] { override def apply(plan: SparkPlan): SparkPlan = { plan.transformUp { case plan: FileSourceScanExec if plan.relation.fileFormat.isInstanceOf[ArrowCSVFileFormat] => ArrowFileSourceScanExec(plan) + case plan: BatchScanExec if plan.scan.isInstanceOf[ArrowCSVScan] => + ArrowBatchScanExec(plan) + case plan: BatchScanExec => plan case p => p } - } } diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala index bccb06a130ae..0872ac798382 100644 --- a/backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala +++ b/backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala @@ -18,6 +18,7 @@ package org.apache.gluten.execution import org.apache.gluten.GlutenConfig import org.apache.gluten.datasource.ArrowCSVFileFormat +import org.apache.gluten.execution.datasource.v2.ArrowBatchScanExec import org.apache.gluten.sql.shims.SparkShimLoader import org.apache.spark.SparkConf @@ -491,7 +492,6 @@ class TestOperator extends VeloxWholeStageTransformerSuite { runQueryAndCompare("select * from student") { df => val plan = df.queryExecution.executedPlan - print(plan) assert(plan.find(s => s.isInstanceOf[ColumnarToRowExec]).isDefined) assert(plan.find(_.isInstanceOf[ArrowFileSourceScanExec]).isDefined) val scan = 
plan.find(_.isInstanceOf[ArrowFileSourceScanExec]).toList.head @@ -538,6 +538,26 @@ class TestOperator extends VeloxWholeStageTransformerSuite { } } + test("csv scan datasource v2") { + withSQLConf("spark.sql.sources.useV1SourceList" -> "") { + val filePath = rootPath + "/datasource/csv/student.csv" + val df = spark.read + .format("csv") + .option("header", "true") + .load(filePath) + df.createOrReplaceTempView("student") + runQueryAndCompare("select * from student") { + checkGlutenOperatorMatch[ArrowBatchScanExec] + } + runQueryAndCompare("select * from student where Name = 'Peter'") { + df => + val plan = df.queryExecution.executedPlan + assert(plan.find(s => s.isInstanceOf[ColumnarToRowExec]).isEmpty) + assert(plan.find(s => s.isInstanceOf[ArrowBatchScanExec]).isDefined) + } + } + } + test("test OneRowRelation") { val df = sql("SELECT 1") checkAnswer(df, Row(1)) diff --git a/gluten-data/src/main/java/org/apache/gluten/columnarbatch/ColumnarBatches.java b/gluten-data/src/main/java/org/apache/gluten/columnarbatch/ColumnarBatches.java index 624428dcba19..e2cfa335d5c6 100644 --- a/gluten-data/src/main/java/org/apache/gluten/columnarbatch/ColumnarBatches.java +++ b/gluten-data/src/main/java/org/apache/gluten/columnarbatch/ColumnarBatches.java @@ -19,7 +19,6 @@ import org.apache.gluten.exception.GlutenException; import org.apache.gluten.exec.Runtime; import org.apache.gluten.exec.Runtimes; -import org.apache.gluten.memory.arrow.alloc.ArrowBufferAllocators; import org.apache.gluten.memory.nmm.NativeMemoryManager; import org.apache.gluten.utils.ArrowAbiUtil; import org.apache.gluten.utils.ArrowUtil; @@ -221,8 +220,7 @@ private static ColumnarBatch offload(BufferAllocator allocator, ColumnarBatch in final Runtime runtime = Runtimes.contextInstance(); try (ArrowArray cArray = ArrowArray.allocateNew(allocator); ArrowSchema cSchema = ArrowSchema.allocateNew(allocator)) { - ArrowAbiUtil.exportFromSparkColumnarBatch( - ArrowBufferAllocators.contextInstance(), input, cSchema, cArray); + ArrowAbiUtil.exportFromSparkColumnarBatch(allocator, input, cSchema, cArray); long handle = ColumnarBatchJniWrapper.forRuntime(runtime) .createWithArrowArray(cSchema.memoryAddress(), cArray.memoryAddress()); diff --git a/gluten-data/src/main/scala/org/apache/gluten/utils/ArrowUtil.scala b/gluten-data/src/main/scala/org/apache/gluten/utils/ArrowUtil.scala index 26bebcfae713..99eb72c70ea3 100644 --- a/gluten-data/src/main/scala/org/apache/gluten/utils/ArrowUtil.scala +++ b/gluten-data/src/main/scala/org/apache/gluten/utils/ArrowUtil.scala @@ -17,8 +17,6 @@ package org.apache.gluten.utils import org.apache.gluten.exception.SchemaMismatchException -import org.apache.gluten.memory.arrow.alloc.ArrowBufferAllocators -import org.apache.gluten.memory.arrow.pool.ArrowNativeMemoryPool import org.apache.gluten.vectorized.ArrowWritableColumnVector import org.apache.spark.internal.Logging @@ -34,6 +32,7 @@ import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} import org.apache.arrow.c.{ArrowSchema, CDataDictionaryProvider, Data} import org.apache.arrow.dataset.file.{FileFormat, FileSystemDatasetFactory} +import org.apache.arrow.dataset.jni.NativeMemoryPool import org.apache.arrow.memory.BufferAllocator import org.apache.arrow.vector.ipc.message.ArrowRecordBatch import org.apache.arrow.vector.types.pojo.{ArrowType, Field, Schema} @@ -140,19 +139,22 @@ object ArrowUtil extends Logging { rewritten.toString } - def makeArrowDiscovery(encodedUri: String, format: FileFormat): FileSystemDatasetFactory = { - val allocator = 
ArrowBufferAllocators.contextInstance() - val factory = new FileSystemDatasetFactory( - allocator, - ArrowNativeMemoryPool.arrowPool("FileSystemDatasetFactory"), - format, - rewriteUri(encodedUri)) + def makeArrowDiscovery( + encodedUri: String, + format: FileFormat, + allocator: BufferAllocator, + pool: NativeMemoryPool): FileSystemDatasetFactory = { + val factory = new FileSystemDatasetFactory(allocator, pool, format, rewriteUri(encodedUri)) factory } - def readSchema(file: FileStatus, format: FileFormat): Option[StructType] = { + def readSchema( + file: FileStatus, + format: FileFormat, + allocator: BufferAllocator, + pool: NativeMemoryPool): Option[StructType] = { val factory: FileSystemDatasetFactory = - makeArrowDiscovery(file.getPath.toString, format) + makeArrowDiscovery(file.getPath.toString, format, allocator, pool) val schema = factory.inspect() try { Option(SparkSchemaUtil.fromArrowSchema(schema)) @@ -161,12 +163,16 @@ object ArrowUtil extends Logging { } } - def readSchema(files: Seq[FileStatus], format: FileFormat): Option[StructType] = { + def readSchema( + files: Seq[FileStatus], + format: FileFormat, + allocator: BufferAllocator, + pool: NativeMemoryPool): Option[StructType] = { if (files.isEmpty) { throw new IllegalArgumentException("No input file specified") } - readSchema(files.head, format) + readSchema(files.head, format, allocator, pool) } def compareStringFunc(caseSensitive: Boolean): (String, String) => Boolean = { @@ -254,6 +260,7 @@ object ArrowUtil extends Logging { } def loadBatch( + allocator: BufferAllocator, input: ArrowRecordBatch, dataSchema: StructType, requiredSchema: StructType, @@ -267,7 +274,7 @@ object ArrowUtil extends Logging { rowCount, SparkSchemaUtil.toArrowSchema(dataSchema), input, - ArrowBufferAllocators.contextInstance()) + allocator) } finally { input.close() } diff --git a/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala b/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala index dbd7dc187ba5..366796a57465 100644 --- a/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala +++ b/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala @@ -428,8 +428,15 @@ class VeloxTestSettings extends BackendTestSettings { .exclude("SPARK-23786: warning should be printed if CSV header doesn't conform to schema") // file cars.csv include null string, Arrow not support to read .exclude("DDL test with schema") - // file cars.csv include null string, Arrow not support to read .exclude("old csv data source name works") + .exclude("save csv") + .exclude("save csv with compression codec option") + .exclude("save csv with empty fields with user defined empty values") + .exclude("save csv with quote") + .exclude("SPARK-13543 Write the output as uncompressed via option()") + // Rule org.apache.spark.sql.execution.datasources.v2.V2ScanRelationPushDown in batch + // Early Filter and Projection Push-Down generated an invalid plan + .exclude("SPARK-26208: write and read empty data to csv file with headers") enableSuite[GlutenCSVLegacyTimeParserSuite] .exclude("SPARK-23786: warning should be printed if CSV header doesn't conform to schema") // file cars.csv include null string, Arrow not support to read diff --git a/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala b/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala index 9b469a98d137..128e52a79b77 100644 --- 
a/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala +++ b/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala @@ -215,6 +215,15 @@ class VeloxTestSettings extends BackendTestSettings { .exclude("Gluten - test for FAILFAST parsing mode") // file cars.csv include null string, Arrow not support to read .exclude("old csv data source name works") + .exclude("DDL test with schema") + .exclude("save csv") + .exclude("save csv with compression codec option") + .exclude("save csv with empty fields with user defined empty values") + .exclude("save csv with quote") + .exclude("SPARK-13543 Write the output as uncompressed via option()") + // Rule org.apache.spark.sql.execution.datasources.v2.V2ScanRelationPushDown in batch + // Early Filter and Projection Push-Down generated an invalid plan + .exclude("SPARK-26208: write and read empty data to csv file with headers") enableSuite[GlutenCSVLegacyTimeParserSuite] // file cars.csv include null string, Arrow not support to read .exclude("DDL test with schema") diff --git a/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala b/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala index 1afa203ab6f5..6ea29847b0a6 100644 --- a/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala +++ b/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala @@ -195,6 +195,15 @@ class VeloxTestSettings extends BackendTestSettings { .exclude("Gluten - test for FAILFAST parsing mode") // file cars.csv include null string, Arrow not support to read .exclude("old csv data source name works") + .exclude("DDL test with schema") + .exclude("save csv") + .exclude("save csv with compression codec option") + .exclude("save csv with empty fields with user defined empty values") + .exclude("save csv with quote") + .exclude("SPARK-13543 Write the output as uncompressed via option()") + // Rule org.apache.spark.sql.execution.datasources.v2.V2ScanRelationPushDown in batch + // Early Filter and Projection Push-Down generated an invalid plan + .exclude("SPARK-26208: write and read empty data to csv file with headers") enableSuite[GlutenCSVLegacyTimeParserSuite] // file cars.csv include null string, Arrow not support to read .exclude("DDL test with schema") diff --git a/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala b/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala index 61353d99f7d1..e6e42acb31a2 100644 --- a/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala +++ b/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala @@ -196,8 +196,17 @@ class VeloxTestSettings extends BackendTestSettings { .exclude("SPARK-27873: disabling enforceSchema should not fail columnNameOfCorruptRecord") enableSuite[GlutenCSVv2Suite] .exclude("Gluten - test for FAILFAST parsing mode") + // Rule org.apache.spark.sql.execution.datasources.v2.V2ScanRelationPushDown in batch + // Early Filter and Projection Push-Down generated an invalid plan + .exclude("SPARK-26208: write and read empty data to csv file with headers") // file cars.csv include null string, Arrow not support to read .exclude("old csv data source name works") + .exclude("DDL test with schema") + .exclude("save csv") + .exclude("save csv with compression codec option") + .exclude("save csv with empty fields with 
user defined empty values") + .exclude("save csv with quote") + .exclude("SPARK-13543 Write the output as uncompressed via option()") enableSuite[GlutenCSVLegacyTimeParserSuite] // file cars.csv include null string, Arrow not support to read .exclude("DDL test with schema") diff --git a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/execution/datasources/csv/GlutenCSVSuite.scala b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/execution/datasources/csv/GlutenCSVSuite.scala index 38e6c9873ee0..cb7ce87f97da 100644 --- a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/execution/datasources/csv/GlutenCSVSuite.scala +++ b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/execution/datasources/csv/GlutenCSVSuite.scala @@ -113,6 +113,7 @@ class GlutenCSVv2Suite extends GlutenCSVSuite { override def sparkConf: SparkConf = super.sparkConf .set(SQLConf.USE_V1_SOURCE_LIST, "") + .set(GlutenConfig.NATIVE_ARROW_READER_ENABLED.key, "true") override def testNameBlackList: Seq[String] = Seq( // overwritten with different test diff --git a/shims/spark32/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExecShim.scala b/shims/spark32/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExecShim.scala index 4db784782c1e..e445dd33a585 100644 --- a/shims/spark32/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExecShim.scala +++ b/shims/spark32/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExecShim.scala @@ -100,3 +100,7 @@ abstract class BatchScanExecShim( ) } } + +abstract class ArrowBatchScanExecShim(original: BatchScanExec) extends DataSourceV2ScanExecBase { + @transient override lazy val partitions: Seq[InputPartition] = original.partitions +} diff --git a/shims/spark33/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExecShim.scala b/shims/spark33/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExecShim.scala index 76556052c758..06eb69a35973 100644 --- a/shims/spark33/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExecShim.scala +++ b/shims/spark33/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExecShim.scala @@ -137,3 +137,9 @@ abstract class BatchScanExecShim( Boolean.box(replicatePartitions)) } } + +abstract class ArrowBatchScanExecShim(original: BatchScanExec) extends DataSourceV2ScanExecBase { + @transient override lazy val inputPartitions: Seq[InputPartition] = original.inputPartitions + + override def keyGroupedPartitioning: Option[Seq[Expression]] = original.keyGroupedPartitioning +} diff --git a/shims/spark34/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExecShim.scala b/shims/spark34/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExecShim.scala index ca9a7eb2d071..64afc8193f4e 100644 --- a/shims/spark34/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExecShim.scala +++ b/shims/spark34/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExecShim.scala @@ -150,3 +150,11 @@ abstract class BatchScanExecShim( } } } + +abstract class ArrowBatchScanExecShim(original: BatchScanExec) extends DataSourceV2ScanExecBase { + @transient override lazy val inputPartitions: Seq[InputPartition] = original.inputPartitions + + override def keyGroupedPartitioning: Option[Seq[Expression]] = original.keyGroupedPartitioning + + override def ordering: Option[Seq[SortOrder]] = original.ordering +} diff --git 
a/shims/spark35/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExecShim.scala b/shims/spark35/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExecShim.scala index 47adf16fb0e7..8949a46a1ddd 100644 --- a/shims/spark35/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExecShim.scala +++ b/shims/spark35/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExecShim.scala @@ -152,3 +152,11 @@ abstract class BatchScanExecShim( } } } + +abstract class ArrowBatchScanExecShim(original: BatchScanExec) extends DataSourceV2ScanExecBase { + @transient override lazy val inputPartitions: Seq[InputPartition] = original.inputPartitions + + override def keyGroupedPartitioning: Option[Seq[Expression]] = original.keyGroupedPartitioning + + override def ordering: Option[Seq[SortOrder]] = original.ordering +}
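A minimal usage sketch of the new DataSource V2 path, mirroring the "csv scan datasource v2" test added above. The SparkSession setup and file path are illustrative assumptions; the config key comes from the test changes in this patch.

  // Assumes a SparkSession with the Gluten plugin enabled.
  // Clearing useV1SourceList makes `csv` resolve to the V2 CSVTable, which
  // ArrowConvertorRule then replaces with ArrowCSVTable when the options validate.
  spark.conf.set("spark.sql.sources.useV1SourceList", "")
  val students = spark.read
    .format("csv")
    .option("header", "true")
    .load("/path/to/datasource/csv/student.csv") // hypothetical path
  students.createOrReplaceTempView("student")
  // The executed plan is expected to contain ArrowBatchScanExec rather than a vanilla BatchScanExec.
  spark.sql("select * from student where Name = 'Peter'").explain()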