Add shim for fromAttributes
holdenk committed Nov 17, 2023
1 parent 32f8cfb commit c9c7173
Showing 6 changed files with 27 additions and 4 deletions.

Spark 3.5 moves the StructType.fromAttributes helper to DataTypeUtils.fromAttributes, so this commit adds a structFromAttributes method to the SparkShims trait and routes the two call sites in ColumnarBuildSideRelation through SparkShimLoader.getSparkShims.structFromAttributes. The Spark 3.2-3.4 shims keep delegating to StructType.fromAttributes, while the Spark 3.5 shim calls DataTypeUtils.fromAttributes.
ColumnarBuildSideRelation.scala
@@ -21,6 +21,7 @@
 import io.glutenproject.exec.Runtimes
 import io.glutenproject.execution.BroadCastHashJoinContext
 import io.glutenproject.memory.arrowalloc.ArrowBufferAllocators
 import io.glutenproject.memory.nmm.NativeMemoryManagers
+import io.glutenproject.sql.shims._
 import io.glutenproject.utils.{ArrowAbiUtil, Iterators}
 import io.glutenproject.vectorized.{ColumnarBatchSerializerJniWrapper, NativeColumnarToRowJniWrapper}

@@ -48,7 +49,7 @@ case class ColumnarBuildSideRelation(
     val allocator = ArrowBufferAllocators.contextInstance()
     val cSchema = ArrowSchema.allocateNew(allocator)
     val arrowSchema = SparkArrowUtil.toArrowSchema(
-      StructType.fromAttributes(output),
+      SparkShimLoader.getSparkShims.structFromAttributes(output),
       SQLConf.get.sessionLocalTimeZone)
     ArrowAbiUtil.exportSchema(allocator, arrowSchema, cSchema)
     val handle = ColumnarBatchSerializerJniWrapper
@@ -102,7 +103,7 @@
     val allocator = ArrowBufferAllocators.globalInstance()
     val cSchema = ArrowSchema.allocateNew(allocator)
     val arrowSchema = SparkArrowUtil.toArrowSchema(
-      StructType.fromAttributes(output),
+      SparkShimLoader.getSparkShims.structFromAttributes(output),
       SQLConf.get.sessionLocalTimeZone)
     ArrowAbiUtil.exportSchema(allocator, arrowSchema, cSchema)
     val handle = serializerJniWrapper
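Both hunks apply the same one-line fix to an identical schema-export sequence. A condensed sketch of that sequence follows; the Gluten helper names come from the diff above, but their import paths, the exportOutputSchema name, and the resource handling around cSchema are assumptions for illustration:

import io.glutenproject.memory.arrowalloc.ArrowBufferAllocators
import io.glutenproject.sql.shims.SparkShimLoader
import io.glutenproject.utils.ArrowAbiUtil
import org.apache.arrow.c.ArrowSchema
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.utils.SparkArrowUtil

// Hypothetical helper mirroring the pattern above: turn the relation's
// output attributes into an Arrow schema and export it through the
// Arrow C data interface for the native side to consume.
def exportOutputSchema(output: Seq[Attribute]): ArrowSchema = {
  val allocator = ArrowBufferAllocators.contextInstance()
  val cSchema = ArrowSchema.allocateNew(allocator)
  val arrowSchema = SparkArrowUtil.toArrowSchema(
    // The shim call replaces the version-dependent StructType.fromAttributes.
    SparkShimLoader.getSparkShims.structFromAttributes(output),
    SQLConf.get.sessionLocalTimeZone)
  ArrowAbiUtil.exportSchema(allocator, arrowSchema, cSchema)
  cSchema
}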
SparkShims.scala
@@ -21,7 +21,7 @@ import io.glutenproject.expression.Sig
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.catalog.BucketSpec
-import org.apache.spark.sql.catalyst.expressions.{Expression, PlanExpression}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, PlanExpression}
 import org.apache.spark.sql.catalyst.plans.physical.Distribution
 import org.apache.spark.sql.connector.catalog.Table
 import org.apache.spark.sql.connector.expressions.Transform
@@ -50,6 +50,8 @@ trait SparkShims {
   // https://github.com/apache/spark/pull/32875
   def getDistribution(leftKeys: Seq[Expression], rightKeys: Seq[Expression]): Seq[Distribution]
 
+  def structFromAttributes(attrs: Seq[Attribute]): StructType
+
   def expressionMappings: Seq[Sig]
 
   def convertPartitionTransforms(partitions: Seq[Transform]): (Seq[String], Option[BucketSpec])
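With the trait method in place, version-agnostic code builds schemas through the shim rather than calling StructType.fromAttributes directly, as the ColumnarBuildSideRelation hunks above do. A minimal caller-side sketch, assuming only the shim API visible in this diff (schemaOf is a hypothetical name):

import io.glutenproject.sql.shims.SparkShimLoader
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.types.StructType

// Resolve the shim matching the running Spark version, then build the
// StructType through it; the same line compiles against Spark 3.2-3.5.
def schemaOf(output: Seq[Attribute]): StructType =
  SparkShimLoader.getSparkShims.structFromAttributes(output)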
Spark32Shims.scala
@@ -22,7 +22,7 @@ import io.glutenproject.sql.shims.{ShimDescriptor, SparkShims}
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.catalog.BucketSpec
-import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
 import org.apache.spark.sql.catalyst.plans.physical.{Distribution, HashClusteredDistribution}
 import org.apache.spark.sql.connector.catalog.Table
 import org.apache.spark.sql.connector.expressions.Transform
@@ -45,6 +45,11 @@ class Spark32Shims extends SparkShims {
     HashClusteredDistribution(leftKeys) :: HashClusteredDistribution(rightKeys) :: Nil
   }
 
+  override def structFromAttributes(
+      attrs: Seq[Attribute]): StructType = {
+    StructType.fromAttributes(attrs)
+  }
+
   override def expressionMappings: Seq[Sig] = Seq(Sig[Empty2Null](ExpressionNames.EMPTY2NULL))
 
   override def convertPartitionTransforms(
Spark33Shims.scala
@@ -49,6 +49,11 @@ class Spark33Shims extends SparkShims {
     ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
   }
 
+  override def structFromAttributes(
+      attrs: Seq[Attribute]): StructType = {
+    StructType.fromAttributes(attrs)
+  }
+
   override def expressionMappings: Seq[Sig] = {
     val list = if (GlutenConfig.getConf.enableNativeBloomFilter) {
       Seq(
Spark34Shims.scala
@@ -51,6 +51,11 @@ class Spark34Shims extends SparkShims {
     ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
   }
 
+  override def structFromAttributes(
+      attrs: Seq[Attribute]): StructType = {
+    StructType.fromAttributes(attrs)
+  }
+
   override def expressionMappings: Seq[Sig] = {
     val list = if (GlutenConfig.getConf.enableNativeBloomFilter) {
       Seq(
Spark35Shims.scala
@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.catalog.BucketSpec
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate
 import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution}
+import org.apache.spark.sql.catalyst.types.DataTypeUtils
 import org.apache.spark.sql.connector.catalog.Table
 import org.apache.spark.sql.connector.expressions.Transform
 import org.apache.spark.sql.execution.{FileSourceScanLike, PartitionedFileUtil, SparkPlan}
@@ -50,6 +51,10 @@ class Spark35Shims extends SparkShims {
     ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
   }
 
+  override def structFromAttributes(attrs: Seq[Attribute]): StructType = {
+    DataTypeUtils.fromAttributes(attrs)
+  }
+
   override def expressionMappings: Seq[Sig] = {
     val list = if (GlutenConfig.getConf.enableNativeBloomFilter) {
       Seq(
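Spark 3.5 is the odd one out: the helper now lives on DataTypeUtils (org.apache.spark.sql.catalyst.types), which is what forces the shim in the first place. Both variants compute the same result; roughly, a sketch in terms of the stable Catalyst API rather than either helper (structFromAttributesSketch is a hypothetical name):

import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.types.{StructField, StructType}

// What fromAttributes amounts to: one StructField per attribute,
// carrying its name, data type, nullability, and metadata.
def structFromAttributesSketch(attrs: Seq[Attribute]): StructType =
  StructType(attrs.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata)))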
