diff --git a/gluten-data/src/main/scala/org/apache/spark/sql/execution/ColumnarBuildSideRelation.scala b/gluten-data/src/main/scala/org/apache/spark/sql/execution/ColumnarBuildSideRelation.scala
index 3c90ab5ea9531..024887e28d9e8 100644
--- a/gluten-data/src/main/scala/org/apache/spark/sql/execution/ColumnarBuildSideRelation.scala
+++ b/gluten-data/src/main/scala/org/apache/spark/sql/execution/ColumnarBuildSideRelation.scala
@@ -21,6 +21,7 @@ import io.glutenproject.exec.Runtimes
 import io.glutenproject.execution.BroadCastHashJoinContext
 import io.glutenproject.memory.arrowalloc.ArrowBufferAllocators
 import io.glutenproject.memory.nmm.NativeMemoryManagers
+import io.glutenproject.sql.shims._
 import io.glutenproject.utils.{ArrowAbiUtil, Iterators}
 import io.glutenproject.vectorized.{ColumnarBatchSerializerJniWrapper, NativeColumnarToRowJniWrapper}
 
@@ -48,7 +49,7 @@ case class ColumnarBuildSideRelation(
       val allocator = ArrowBufferAllocators.contextInstance()
       val cSchema = ArrowSchema.allocateNew(allocator)
       val arrowSchema = SparkArrowUtil.toArrowSchema(
-        StructType.fromAttributes(output),
+        SparkShimLoader.getSparkShims.structFromAttributes(output),
         SQLConf.get.sessionLocalTimeZone)
       ArrowAbiUtil.exportSchema(allocator, arrowSchema, cSchema)
       val handle = ColumnarBatchSerializerJniWrapper
@@ -102,7 +103,7 @@
       val allocator = ArrowBufferAllocators.globalInstance()
       val cSchema = ArrowSchema.allocateNew(allocator)
       val arrowSchema = SparkArrowUtil.toArrowSchema(
-        StructType.fromAttributes(output),
+        SparkShimLoader.getSparkShims.structFromAttributes(output),
         SQLConf.get.sessionLocalTimeZone)
       ArrowAbiUtil.exportSchema(allocator, arrowSchema, cSchema)
       val handle = serializerJniWrapper
diff --git a/shims/common/src/main/scala/io/glutenproject/sql/shims/SparkShims.scala b/shims/common/src/main/scala/io/glutenproject/sql/shims/SparkShims.scala
index 73dcb2e6cb5ca..f0c4cb7cf059b 100644
--- a/shims/common/src/main/scala/io/glutenproject/sql/shims/SparkShims.scala
+++ b/shims/common/src/main/scala/io/glutenproject/sql/shims/SparkShims.scala
@@ -21,7 +21,7 @@ import io.glutenproject.expression.Sig
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.catalog.BucketSpec
-import org.apache.spark.sql.catalyst.expressions.{Expression, PlanExpression}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, PlanExpression}
 import org.apache.spark.sql.catalyst.plans.physical.Distribution
 import org.apache.spark.sql.connector.catalog.Table
 import org.apache.spark.sql.connector.expressions.Transform
@@ -50,6 +50,8 @@ trait SparkShims {
   // https://github.com/apache/spark/pull/32875
   def getDistribution(leftKeys: Seq[Expression], rightKeys: Seq[Expression]): Seq[Distribution]
 
+  def structFromAttributes(attrs: Seq[Attribute]): StructType
+
   def expressionMappings: Seq[Sig]
 
   def convertPartitionTransforms(partitions: Seq[Transform]): (Seq[String], Option[BucketSpec])
diff --git a/shims/spark32/src/main/scala/io/glutenproject/sql/shims/spark32/Spark32Shims.scala b/shims/spark32/src/main/scala/io/glutenproject/sql/shims/spark32/Spark32Shims.scala
index e77a0e6acb915..a22f5d1559ceb 100644
--- a/shims/spark32/src/main/scala/io/glutenproject/sql/shims/spark32/Spark32Shims.scala
+++ b/shims/spark32/src/main/scala/io/glutenproject/sql/shims/spark32/Spark32Shims.scala
@@ -22,7 +22,7 @@ import io.glutenproject.sql.shims.{ShimDescriptor, SparkShims}
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.catalog.BucketSpec
-import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
 import org.apache.spark.sql.catalyst.plans.physical.{Distribution, HashClusteredDistribution}
 import org.apache.spark.sql.connector.catalog.Table
 import org.apache.spark.sql.connector.expressions.Transform
@@ -45,6 +45,11 @@
     HashClusteredDistribution(leftKeys) :: HashClusteredDistribution(rightKeys) :: Nil
   }
 
+  override def structFromAttributes(
+      attrs: Seq[Attribute]): StructType = {
+    StructType.fromAttributes(attrs)
+  }
+
   override def expressionMappings: Seq[Sig] = Seq(Sig[Empty2Null](ExpressionNames.EMPTY2NULL))
 
   override def convertPartitionTransforms(
diff --git a/shims/spark33/src/main/scala/io/glutenproject/sql/shims/spark33/Spark33Shims.scala b/shims/spark33/src/main/scala/io/glutenproject/sql/shims/spark33/Spark33Shims.scala
index 8b9ec5fd23cb7..d43634746f91d 100644
--- a/shims/spark33/src/main/scala/io/glutenproject/sql/shims/spark33/Spark33Shims.scala
+++ b/shims/spark33/src/main/scala/io/glutenproject/sql/shims/spark33/Spark33Shims.scala
@@ -49,6 +49,11 @@
     ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
   }
 
+  override def structFromAttributes(
+      attrs: Seq[Attribute]): StructType = {
+    StructType.fromAttributes(attrs)
+  }
+
   override def expressionMappings: Seq[Sig] = {
     val list = if (GlutenConfig.getConf.enableNativeBloomFilter) {
       Seq(
diff --git a/shims/spark34/src/main/scala/io/glutenproject/sql/shims/spark34/Spark34Shims.scala b/shims/spark34/src/main/scala/io/glutenproject/sql/shims/spark34/Spark34Shims.scala
index bccece645811b..cc0d7ef928d62 100644
--- a/shims/spark34/src/main/scala/io/glutenproject/sql/shims/spark34/Spark34Shims.scala
+++ b/shims/spark34/src/main/scala/io/glutenproject/sql/shims/spark34/Spark34Shims.scala
@@ -51,6 +51,11 @@
     ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
   }
 
+  override def structFromAttributes(
+      attrs: Seq[Attribute]): StructType = {
+    StructType.fromAttributes(attrs)
+  }
+
   override def expressionMappings: Seq[Sig] = {
     val list = if (GlutenConfig.getConf.enableNativeBloomFilter) {
       Seq(
diff --git a/shims/spark35/src/main/scala/io/glutenproject/sql/shims/spark35/Spark35Shims.scala b/shims/spark35/src/main/scala/io/glutenproject/sql/shims/spark35/Spark35Shims.scala
index 8b90a0faa9eff..ff643339d1374 100644
--- a/shims/spark35/src/main/scala/io/glutenproject/sql/shims/spark35/Spark35Shims.scala
+++ b/shims/spark35/src/main/scala/io/glutenproject/sql/shims/spark35/Spark35Shims.scala
@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.catalog.BucketSpec
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate
 import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution}
+import org.apache.spark.sql.catalyst.types.DataTypeUtils
 import org.apache.spark.sql.connector.catalog.Table
 import org.apache.spark.sql.connector.expressions.Transform
 import org.apache.spark.sql.execution.{FileSourceScanLike, PartitionedFileUtil, SparkPlan}
@@ -50,6 +51,10 @@
     ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
   }
 
+  override def structFromAttributes(attrs: Seq[Attribute]): StructType = {
+    DataTypeUtils.fromAttributes(attrs)
+  }
+
   override def expressionMappings: Seq[Sig] = {
     val list = if (GlutenConfig.getConf.enableNativeBloomFilter) {
       Seq(
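
Background for the change above: Spark 3.2 through 3.4 expose this helper as StructType.fromAttributes, while Spark 3.5 provides it as org.apache.spark.sql.catalyst.types.DataTypeUtils.fromAttributes, so call sites in gluten-data now route through the shim. What follows is a minimal sketch of the resulting call-site pattern, not part of the patch; it assumes a Spark plus Gluten-shims classpath, and the fully qualified SparkShimLoader import is inferred from the wildcard import added in ColumnarBuildSideRelation.

// Sketch only: illustrates the call-site change this patch applies.
// SparkShimLoader's package is an assumption based on the `io.glutenproject.sql.shims._` import above.
import io.glutenproject.sql.shims.SparkShimLoader

import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}

object StructFromAttributesSketch {
  def main(args: Array[String]): Unit = {
    // A small output schema expressed as Catalyst attributes, as a join's build side would carry.
    val output: Seq[Attribute] = Seq(
      AttributeReference("id", IntegerType, nullable = false)(),
      AttributeReference("name", StringType)())

    // Before this patch (compiles on Spark 3.2-3.4 only, since Spark 3.5 moved the helper):
    //   val schema: StructType = StructType.fromAttributes(output)

    // After this patch: the shim resolves to StructType.fromAttributes on Spark 3.2-3.4
    // or DataTypeUtils.fromAttributes on Spark 3.5, depending on the loaded shim.
    val schema: StructType = SparkShimLoader.getSparkShims.structFromAttributes(output)
    println(schema.treeString)
  }
}

The design point is that only the per-version shim modules reference the version-specific API, so the common gluten-data module compiles unchanged against every supported Spark release.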