diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala index 0fb5fb54900b..86a62a4471a5 100644 --- a/backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala +++ b/backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala @@ -354,6 +354,13 @@ class TestOperator extends VeloxWholeStageTransformerSuite with AdaptiveSparkPla Seq("sort", "streaming").foreach { windowType => withSQLConf("spark.gluten.sql.columnar.backend.velox.window.type" -> windowType) { + runQueryAndCompare( + "select max(l_partkey) over" + + " (partition by l_suppkey order by l_commitdate" + + " RANGE BETWEEN 1 PRECEDING AND CURRENT ROW) from lineitem ") { + checkSparkOperatorMatch[WindowExecTransformer] + } + runQueryAndCompare( "select max(l_partkey) over" + " (partition by l_suppkey order by l_orderkey" + diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/PullOutPreProject.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/PullOutPreProject.scala index 73b8ab2607eb..51cdb76a1559 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/PullOutPreProject.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/PullOutPreProject.scala @@ -76,17 +76,7 @@ object PullOutPreProject extends RewriteSingleNode with PullOutProjectHelper { } case _ => false }.isDefined) || - window.windowExpression.exists(_.find { - case we: WindowExpression => - we.windowSpec.frameSpecification match { - case swf: SpecifiedWindowFrame - if needPreComputeRangeFrame(swf) && supportPreComputeRangeFrame( - we.windowSpec.orderSpec) => - true - case _ => false - } - case _ => false - }.isDefined) + windowNeedPreComputeRangeFrame(window) case plan if SparkShimLoader.getSparkShims.isWindowGroupLimitExec(plan) => val window = SparkShimLoader.getSparkShims .getWindowGroupLimitExecShim(plan) @@ -176,14 +166,16 @@ object PullOutPreProject extends RewriteSingleNode with PullOutProjectHelper { case window: WindowExec if needsPreProject(window) => val expressionMap = new mutable.HashMap[Expression, NamedExpression]() - // Handle orderSpec. - val newOrderSpec = getNewSortOrder(window.orderSpec, expressionMap) - - // Handle partitionSpec. + // Handle foldable orderSpec and foldable partitionSpec. Spark analyzer rule + // ExtractWindowExpressions will extract expressions from non-foldable orderSpec and + // partitionSpec. + var newOrderSpec = getNewSortOrder(window.orderSpec, expressionMap) val newPartitionSpec = window.partitionSpec.map(replaceExpressionWithAttribute(_, expressionMap)) // Handle windowExpressions. + newOrderSpec = rewriteOrderSpecs(window, newOrderSpec, expressionMap) + val newWindowExpressions = window.windowExpression.toIndexedSeq.map { _.transform { case we: WindowExpression => rewriteWindowExpression(we, newOrderSpec, expressionMap) diff --git a/gluten-core/src/main/scala/org/apache/gluten/utils/PullOutProjectHelper.scala b/gluten-core/src/main/scala/org/apache/gluten/utils/PullOutProjectHelper.scala index 12055f9e9721..85be57493f02 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/utils/PullOutProjectHelper.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/utils/PullOutProjectHelper.scala @@ -22,8 +22,10 @@ import org.apache.gluten.exception.{GlutenException, GlutenNotSupportException} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction} import org.apache.spark.sql.execution.aggregate._ +import org.apache.spark.sql.execution.window.WindowExec import org.apache.spark.sql.types.{ByteType, DateType, IntegerType, LongType, ShortType} +import java.sql.Date import java.util.concurrent.atomic.AtomicInteger import scala.collection.mutable @@ -161,14 +163,32 @@ trait PullOutProjectHelper { case _: PreComputeRangeFrameBound => bound case _ if !bound.foldable => bound case _ if bound.foldable => + val orderExpr = if (expressionMap.contains(orderSpec.child)) { + expressionMap(orderSpec.child).asInstanceOf[Alias].child + } else { + orderSpec.child + } val a = expressionMap .getOrElseUpdate( bound.canonicalized, - Alias(Add(orderSpec.child, bound), generatePreAliasName)()) + Alias(Add(orderExpr, bound), generatePreAliasName)()) PreComputeRangeFrameBound(a.asInstanceOf[Alias], bound) } } + protected def windowNeedPreComputeRangeFrame(w: WindowExec): Boolean = + w.windowExpression.exists(_.find { + case we: WindowExpression => + we.windowSpec.frameSpecification match { + case swf: SpecifiedWindowFrame + if needPreComputeRangeFrame(swf) && supportPreComputeRangeFrame( + we.windowSpec.orderSpec) => + true + case _ => false + } + case _ => false + }.isDefined) + protected def needPreComputeRangeFrame(swf: SpecifiedWindowFrame): Boolean = { BackendsApiManager.getSettings.needPreComputeRangeFrameBoundary && swf.frameType == RangeFrame && @@ -185,6 +205,36 @@ trait PullOutProjectHelper { } } + /** + * Convert DateType to IntType for orderSpec if needPreComputeRangeFrame, because spark's frame + * type does not support DateType. It does not affect the correctness of sort. + */ + protected def rewriteOrderSpecs( + window: WindowExec, + orderSpecs: Seq[SortOrder], + expressionMap: mutable.HashMap[Expression, NamedExpression]): Seq[SortOrder] = { + if (windowNeedPreComputeRangeFrame(window)) { + // This is guaranteed by Spark, but we still check it here + if (orderSpecs.size != 1) { + throw new GlutenException( + s"A range window frame with value boundaries expects one and only one " + + s"order by expression: ${orderSpecs.mkString(",")}") + } + val orderSpec = orderSpecs.head + orderSpec.child.dataType match { + case DateType => + val alias = Alias( + DateDiff(orderSpec.child, Literal(Date.valueOf("1970-01-01"))), + generatePreAliasName)() + expressionMap.getOrElseUpdate(alias.toAttribute, alias) + Seq(orderSpec.copy(child = alias.toAttribute)) + case _ => orderSpecs + } + } else { + orderSpecs + } + } + protected def rewriteWindowExpression( we: WindowExpression, orderSpecs: Seq[SortOrder], @@ -202,12 +252,6 @@ trait PullOutProjectHelper { val newWindowSpec = we.windowSpec.frameSpecification match { case swf: SpecifiedWindowFrame if needPreComputeRangeFrame(swf) => - // This is guaranteed by Spark, but we still check it here - if (orderSpecs.size != 1) { - throw new GlutenException( - s"A range window frame with value boundaries expects one and only one " + - s"order by expression: ${orderSpecs.mkString(",")}") - } val orderSpec = orderSpecs.head val lowerFrameCol = preComputeRangeFrameBoundary(swf.lower, orderSpec, expressionMap) val upperFrameCol = preComputeRangeFrameBoundary(swf.upper, orderSpec, expressionMap)