diff --git a/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHSparkPlanExecApi.scala b/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHSparkPlanExecApi.scala index 8a9e37a13b3f..9e0230b3f6a8 100644 --- a/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHSparkPlanExecApi.scala +++ b/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHSparkPlanExecApi.scala @@ -49,6 +49,7 @@ import org.apache.spark.sql.execution.joins.{BuildSideRelation, ClickHouseBuildS import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.execution.utils.CHExecUtil import org.apache.spark.sql.extension.ClickHouseAnalysis +import org.apache.spark.sql.extension.RewriteDateTimestampComparisonRule import org.apache.spark.sql.types.StructType import org.apache.spark.sql.vectorized.ColumnarBatch @@ -329,7 +330,12 @@ class CHSparkPlanExecApi extends SparkPlanExecApi { * @return */ override def genExtendedAnalyzers(): List[SparkSession => Rule[LogicalPlan]] = { - List(spark => new ClickHouseAnalysis(spark, spark.sessionState.conf)) + val analyzers = List(spark => new ClickHouseAnalysis(spark, spark.sessionState.conf)) + if (GlutenConfig.getConf.enableDateTimestampComparison) { + analyzers :+ (spark => new RewriteDateTimestampComparisonRule(spark, spark.sessionState.conf)) + } else { + analyzers + } } /** diff --git a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenFunctionValidateSuite.scala b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenFunctionValidateSuite.scala index 817699491b7f..e7e5e317a9b9 100644 --- a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenFunctionValidateSuite.scala +++ b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenFunctionValidateSuite.scala @@ -113,26 +113,27 @@ class GlutenFunctionValidateSuite extends GlutenClickHouseWholeStageTransformerS val dateSchema = StructType( Array( + StructField("ts", IntegerType, true), StructField("day", DateType, true), StructField("weekday_abbr", StringType, true) ) ) val dateRows = sparkContext.parallelize( Seq( - Row(Date.valueOf("2019-01-01"), "MO"), - Row(Date.valueOf("2019-01-01"), "TU"), - Row(Date.valueOf("2019-01-01"), "TH"), - Row(Date.valueOf("2019-01-01"), "WE"), - Row(Date.valueOf("2019-01-01"), "FR"), - Row(Date.valueOf("2019-01-01"), "SA"), - Row(Date.valueOf("2019-01-01"), "SU"), - Row(Date.valueOf("2019-01-01"), "MO"), - Row(Date.valueOf("2019-01-02"), "MM"), - Row(Date.valueOf("2019-01-03"), "TH"), - Row(Date.valueOf("2019-01-04"), "WE"), - Row(Date.valueOf("2019-01-05"), "FR"), - Row(null, "SA"), - Row(Date.valueOf("2019-01-07"), null) + Row(1546309380, Date.valueOf("2019-01-01"), "MO"), + Row(1546273380, Date.valueOf("2019-01-01"), "TU"), + Row(1546358340, Date.valueOf("2019-01-01"), "TH"), + Row(1546311540, Date.valueOf("2019-01-01"), "WE"), + Row(1546308540, Date.valueOf("2019-01-01"), "FR"), + Row(1546319340, Date.valueOf("2019-01-01"), "SA"), + Row(1546319940, Date.valueOf("2019-01-01"), "SU"), + Row(1546323545, Date.valueOf("2019-01-01"), "MO"), + Row(1546409940, Date.valueOf("2019-01-02"), "MM"), + Row(1546496340, Date.valueOf("2019-01-03"), "TH"), + Row(1546586340, Date.valueOf("2019-01-04"), "WE"), + Row(1546676341, Date.valueOf("2019-01-05"), "FR"), + Row(null, null, "SA"), + Row(1546849141, Date.valueOf("2019-01-07"), null) ) ) val dateTableFile = Files.createTempFile("", ".parquet").toFile @@ -466,4 +467,22 @@ class GlutenFunctionValidateSuite extends GlutenClickHouseWholeStageTransformerS "select round(0.41875f * id , 4) from range(10);" )(checkOperatorMatch[ProjectExecTransformer]) } + + test("test date comparision expression override") { + runQueryAndCompare( + "select * from date_table where to_date(from_unixtime(ts)) < '2019-01-02'", + noFallBack = true) { _ => } + runQueryAndCompare( + "select * from date_table where to_date(from_unixtime(ts)) <= '2019-01-02'", + noFallBack = true) { _ => } + runQueryAndCompare( + "select * from date_table where to_date(from_unixtime(ts)) > '2019-01-02'", + noFallBack = true) { _ => } + runQueryAndCompare( + "select * from date_table where to_date(from_unixtime(ts)) >= '2019-01-02'", + noFallBack = true) { _ => } + runQueryAndCompare( + "select * from date_table where to_date(from_unixtime(ts)) = '2019-01-01'", + noFallBack = true) { _ => } + } } diff --git a/gluten-core/src/main/scala/io/glutenproject/extension/RewriteDateTimestampComparisonRule.scala b/gluten-core/src/main/scala/io/glutenproject/extension/RewriteDateTimestampComparisonRule.scala new file mode 100644 index 000000000000..5bbc9fb1820f --- /dev/null +++ b/gluten-core/src/main/scala/io/glutenproject/extension/RewriteDateTimestampComparisonRule.scala @@ -0,0 +1,326 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.extension + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +import java.lang.IllegalArgumentException + +// For readable, people usually convert a unix timestamp into date, and compare it with another +// date. For example +// select * from table where '2023-11-02' >= from_unixtime(unix_timestamp, 'yyyy-MM-dd') +// There are performance shortcomings +// 1. convert a unix timestamp into date is expensive +// 2. comparisoin with date or string is not efficient. +// +// This rule try to make the filter condition into integer comparison, which is more efficient. +// The above example will be rewritten into +// select * from table where to_unixtime('2023-11-02', 'yyyy-MM-dd') >= unix_timestamp +class RewriteDateTimestampComparisonRule(session: SparkSession, conf: SQLConf) + extends Rule[LogicalPlan] + with Logging { + + object TimeUnit extends Enumeration { + val SECOND, MINUTE, HOUR, DAY, MONTH, YEAR = Value + } + + override def apply(plan: LogicalPlan): LogicalPlan = { + if (plan.resolved) { + visitPlan(plan) + } else { + plan + } + } + + private def visitPlan(plan: LogicalPlan): LogicalPlan = plan match { + case filter: Filter => + val newCondition = visitExpression(filter.condition) + val newFilter = Filter(newCondition, filter.child) + newFilter + case other => + val children = other.children.map(visitPlan) + other.withNewChildren(children) + } + + private def visitExpression(expression: Expression): Expression = expression match { + case cmp: BinaryComparison => + if (isConstDateExpression(cmp.left) && isDateFromUnixTimestamp(cmp.right)) { + rewriteComparisionBetweenTimestampAndDate(exchangeOperators(cmp)) + } else if (isConstDateExpression(cmp.right) && isDateFromUnixTimestamp(cmp.left)) { + rewriteComparisionBetweenTimestampAndDate(cmp) + } else { + val children = expression.children.map(visitExpression) + expression.withNewChildren(children) + } + case literal: Literal => + expression + case _ => + val children = expression.children.map(visitExpression) + expression.withNewChildren(children) + } + + private def isConstDateExpression(expression: Expression): Boolean = { + def allConstExpression(expr: Expression): Boolean = expr match { + case literal: Literal => true + case attr: Attribute => false + case _ => expr.children.forall(allConstExpression) + } + if ( + !expression.dataType.isInstanceOf[DateType] && !expression.dataType.isInstanceOf[StringType] + ) { + return false + } + if (!allConstExpression(expression)) { + return false + } + true + } + + private def isDateFromUnixTimestamp(expr: Expression): Boolean = + expr match { + case toDate: ParseToDate => + isDateFromUnixTimestamp(toDate.left) + case fromUnixTime: FromUnixTime => + true + case _ => false + } + + private def getDateTimeUnit(expr: Expression): Option[TimeUnit.Value] = { + expr match { + case toDate: ParseToDate => + val timeUnit = if (toDate.format.isEmpty) { + Some(TimeUnit.DAY) + } else { + getDateTimeUnitFromLiteral(toDate.format) + } + val nestedTimeUnit = getDateTimeUnit(toDate.left) + if (nestedTimeUnit.isEmpty) { + timeUnit + } else { + if (nestedTimeUnit.get > timeUnit.get) { + nestedTimeUnit + } else { + timeUnit + } + } + case fromUnixTime: FromUnixTime => + getDateTimeUnitFromLiteral(Some(fromUnixTime.format)) + case _ => None + } + } + + private def getDateTimeUnitFromLiteral(expr: Option[Expression]): Option[TimeUnit.Value] = { + if (expr.isEmpty) { + Some(TimeUnit.SECOND) + } else if ( + !expr.get.isInstanceOf[Literal] || !expr.get + .asInstanceOf[Literal] + .dataType + .isInstanceOf[StringType] + ) { + None + } else { + val formatExpr = expr.get.asInstanceOf[Literal] + val formatStr = formatExpr.value.asInstanceOf[UTF8String].toString + if (formatStr.contains("ss")) { + Some(TimeUnit.SECOND) + } else if (formatStr.contains("mm")) { + Some(TimeUnit.MINUTE) + } else if (formatStr.contains("HH")) { + Some(TimeUnit.HOUR) + } else if (formatStr.contains("dd")) { + Some(TimeUnit.DAY) + } else if (formatStr.contains("MM")) { + Some(TimeUnit.MONTH) + } else if (formatStr.contains("yyyy")) { + Some(TimeUnit.YEAR) + } else { + None + } + } + } + + private def getTimeZoneId(expr: Expression): Option[String] = { + expr match { + case toDate: ParseToDate => + getTimeZoneId(toDate.left) + case fromUnixTime: FromUnixTime => + fromUnixTime.timeZoneId + case _ => None + } + } + + private def timeUnitToFormat(timeUnit: TimeUnit.Value): String = { + timeUnit match { + case TimeUnit.SECOND => "yyyy-MM-dd HH:mm:ss" + case TimeUnit.MINUTE => "yyyy-MM-dd HH:mm" + case TimeUnit.HOUR => "yyyy-MM-dd HH" + case TimeUnit.DAY => "yyyy-MM-dd" + case TimeUnit.MONTH => "yyyy-MM" + case TimeUnit.YEAR => "yyyy" + } + } + + private def rewriteConstDate( + expr: Expression, + timeUnit: TimeUnit.Value, + zoneId: Option[String], + adjustedOffset: Long): Expression = { + val formatExpr = Literal(UTF8String.fromString(timeUnitToFormat(timeUnit)), StringType) + val adjustExpr = Literal(adjustedOffset, LongType) + val toUnixTimestampExpr = ToUnixTimestamp(expr, formatExpr, zoneId) + Add(toUnixTimestampExpr, adjustExpr) + } + + private def rewriteUnixTimestampToDate(expr: Expression): Expression = { + expr match { + case toDate: ParseToDate => + rewriteUnixTimestampToDate(toDate.left) + case fromUnixTime: FromUnixTime => + fromUnixTime.sec + case _ => throw new IllegalArgumentException(s"Invalid expression: $expr") + } + } + + private def exchangeOperators(cmp: BinaryComparison): BinaryComparison = { + cmp match { + case gt: GreaterThan => + LessThan(cmp.right, cmp.left) + case gte: GreaterThanOrEqual => + LessThanOrEqual(cmp.right, cmp.left) + case lt: LessThan => + GreaterThan(cmp.right, cmp.left) + case lte: LessThanOrEqual => + GreaterThanOrEqual(cmp.right, cmp.left) + case eq: EqualTo => + EqualTo(cmp.right, cmp.left) + case eqn: EqualNullSafe => + EqualNullSafe(cmp.right, cmp.left) + } + } + + private def rewriteComparisionBetweenTimestampAndDate(cmp: BinaryComparison): Expression = { + val res = cmp match { + case gt: GreaterThan => + rewriteGreaterThen(gt) + case gte: GreaterThanOrEqual => + rewriteGreaterThanOrEqual(gte) + case lt: LessThan => + rewriteLessThen(lt) + case lte: LessThanOrEqual => + rewriteLessThenOrEqual(lte) + case eq: EqualTo => + rewriteEqualTo(eq) + case eqn: EqualNullSafe => + rewriteEqualNullSafe(eqn) + } + logInfo(s"rewrite expresion $cmp to $res") + res + } + + def TimeUnitToSeconds(timeUnit: TimeUnit.Value): Long = timeUnit match { + case TimeUnit.SECOND => 1 + case TimeUnit.MINUTE => 60 + case TimeUnit.HOUR => 3600 + case TimeUnit.DAY => 86400 + case TimeUnit.MONTH => 2592000 + case TimeUnit.YEAR => 31536000 + } + + private def rewriteGreaterThen(cmp: GreaterThan): Expression = { + val timeUnit = getDateTimeUnit(cmp.left) + if (timeUnit.isEmpty) { + return cmp + } + val zoneId = getTimeZoneId(cmp.left) + val adjustedOffset = TimeUnitToSeconds(timeUnit.get) + val newLeft = rewriteUnixTimestampToDate(cmp.left) + val newRight = rewriteConstDate(cmp.right, timeUnit.get, zoneId, adjustedOffset) + GreaterThanOrEqual(newLeft, newRight) + } + + private def rewriteGreaterThanOrEqual(cmp: GreaterThanOrEqual): Expression = { + val timeUnit = getDateTimeUnit(cmp.left) + if (timeUnit.isEmpty) { + return cmp + } + val zoneId = getTimeZoneId(cmp.left) + val adjustedOffset = 0 + val newLeft = rewriteUnixTimestampToDate(cmp.left) + val newRight = rewriteConstDate(cmp.right, timeUnit.get, zoneId, adjustedOffset) + GreaterThanOrEqual(newLeft, newRight) + } + + private def rewriteLessThen(cmp: LessThan): Expression = { + val timeUnit = getDateTimeUnit(cmp.left) + if (timeUnit.isEmpty) { + return cmp + } + val zoneId = getTimeZoneId(cmp.left) + val adjustedOffset = 0 + val newLeft = rewriteUnixTimestampToDate(cmp.left) + val newRight = rewriteConstDate(cmp.right, timeUnit.get, zoneId, adjustedOffset) + LessThan(newLeft, newRight) + } + + private def rewriteLessThenOrEqual(cmp: LessThanOrEqual): Expression = { + val timeUnit = getDateTimeUnit(cmp.left) + if (timeUnit.isEmpty) { + return cmp + } + val zoneId = getTimeZoneId(cmp.left) + val adjustedOffset = TimeUnitToSeconds(timeUnit.get) + val newLeft = rewriteUnixTimestampToDate(cmp.left) + val newRight = rewriteConstDate(cmp.right, timeUnit.get, zoneId, adjustedOffset) + LessThan(newLeft, newRight) + } + + private def rewriteEqualTo(cmp: EqualTo): Expression = { + val timeUnit = getDateTimeUnit(cmp.left) + if (timeUnit.isEmpty) { + return cmp + } + val zoneId = getTimeZoneId(cmp.left) + val timestampLeft = rewriteUnixTimestampToDate(cmp.left) + val adjustedOffset = Literal(TimeUnitToSeconds(timeUnit.get), timestampLeft.dataType) + val addjustedOffsetExpr = Remainder(timestampLeft, adjustedOffset) + val newLeft = Subtract(timestampLeft, addjustedOffsetExpr) + val newRight = rewriteConstDate(cmp.right, timeUnit.get, zoneId, 0) + EqualTo(newLeft, newRight) + } + + private def rewriteEqualNullSafe(cmp: EqualNullSafe): Expression = { + val timeUnit = getDateTimeUnit(cmp.left) + if (timeUnit.isEmpty) { + return cmp + } + val zoneId = getTimeZoneId(cmp.left) + val timestampLeft = rewriteUnixTimestampToDate(cmp.left) + val adjustedOffset = Literal(TimeUnitToSeconds(timeUnit.get), timestampLeft.dataType) + val addjustedOffsetExpr = Remainder(timestampLeft, adjustedOffset) + val newLeft = Subtract(timestampLeft, addjustedOffsetExpr) + val newRight = rewriteConstDate(cmp.right, timeUnit.get, zoneId, 0) + EqualNullSafe(newLeft, newRight) + } +} diff --git a/shims/common/src/main/scala/io/glutenproject/GlutenConfig.scala b/shims/common/src/main/scala/io/glutenproject/GlutenConfig.scala index ee1c5fa5431d..0299637dc531 100644 --- a/shims/common/src/main/scala/io/glutenproject/GlutenConfig.scala +++ b/shims/common/src/main/scala/io/glutenproject/GlutenConfig.scala @@ -83,6 +83,8 @@ class GlutenConfig(conf: SQLConf) extends Logging { def columnarTableCacheEnabled: Boolean = conf.getConf(COLUMNAR_TABLE_CACHE_ENABLED) + def enableDateTimestampComparison: Boolean = conf.getConf(ENABLE_DATE_TIMESTAMP_COMPARISON) + // whether to use ColumnarShuffleManager def isUseColumnarShuffleManager: Boolean = conf @@ -1267,4 +1269,12 @@ object GlutenConfig { + "partial aggregation may be early abandoned.") .intConf .createOptional + + val ENABLE_DATE_TIMESTAMP_COMPARISON = + buildConf("spark.gluten.sql.rewrite.dateTimestampComparison") + .internal() + .doc("Rewrite the comparision between date and timestamp to timestamp comparison." + + "For example `fron_unixtime(ts) > date` will be rewritten to `ts > to_unixtime(date)`") + .booleanConf + .createWithDefault(true) }