Commit: refactor 1101

lgbo-ustc committed Nov 6, 2024
1 parent 98b49ae commit 6b458b6
Showing 3 changed files with 82 additions and 105 deletions.
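For orientation: the rule refactored below, LazyAggregateExpandRule, reorders plans so that the partial aggregate runs before the grouping-set expand, letting the expand multiply already-aggregated rows instead of raw input rows. A minimal sketch of that reordering over simplified stand-in plan nodes (the case classes below are illustrative only; the real rule works on CHHashAggregateExecTransformer and ExpandExecTransformer, with a shuffle above the aggregate and the validation step elided here):

  // Illustrative stand-ins for the real Gluten plan nodes.
  sealed trait Plan
  case class Scan(table: String) extends Plan
  case class Expand(child: Plan) extends Plan
  case class PartialAggregate(child: Plan) extends Plan

  object LazyExpandSketch {
    // Core reordering: Aggregate(Expand(x)) becomes Expand(Aggregate(x)).
    // The real rule fires only when doValidation accepts the pattern, and it
    // also rewrites grouping keys, projections and attribute references.
    def rewrite(plan: Plan): Plan = plan match {
      case PartialAggregate(Expand(child)) => Expand(PartialAggregate(child))
      case other => other
    }

    def main(args: Array[String]): Unit = {
      println(rewrite(PartialAggregate(Expand(Scan("nation")))))
      // prints: Expand(PartialAggregate(Scan(nation)))
    }
  }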
@@ -50,7 +50,6 @@ private object CHRuleApi {
// Inject the regular Spark rules directly.
injector.injectQueryStagePrepRule(FallbackBroadcastHashJoinPrepQueryStage.apply)
injector.injectQueryStagePrepRule(spark => CHAQEPropagateEmptyRelation(spark))
- // injector.injectQueryStagePrepRule(spark => LazyExpandRule(spark))
injector.injectParser(
(spark, parserInterface) => new GlutenCacheFilesSqlParser(spark, parserInterface))
injector.injectParser(
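Each injector call above registers a factory that builds a rule per session. A minimal sketch of that registration pattern (the RuleInjector class below is a hypothetical stand-in, not Gluten's actual injector API):

  import scala.collection.mutable.ArrayBuffer

  // Stand-in types; in Gluten these are Spark's SparkSession and Rule[SparkPlan].
  trait Session
  trait PlanRule

  // Hypothetical injector: stores Session => PlanRule factories so each rule
  // can be instantiated once per session, mirroring the calls shown above.
  class RuleInjector {
    private val prepRules = ArrayBuffer.empty[Session => PlanRule]
    def injectQueryStagePrepRule(builder: Session => PlanRule): Unit =
      prepRules += builder
    def rulesFor(session: Session): Seq[PlanRule] = prepRules.toSeq.map(_(session))
  }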
@@ -52,9 +52,7 @@ import org.apache.spark.sql.types._

case class LazyAggregateExpandRule(session: SparkSession) extends Rule[SparkPlan] with Logging {
override def apply(plan: SparkPlan): SparkPlan = {
- logDebug(
- s"xxx 1031 enable lazy aggregate expand: " +
- s"${CHBackendSettings.enableLazyAggregateExpand}")
+ logDebug(s"xxx enable lazy aggregate expand: ${CHBackendSettings.enableLazyAggregateExpand}")
if (!CHBackendSettings.enableLazyAggregateExpand) {
return plan
}
@@ -82,18 +80,18 @@ case class LazyAggregateExpandRule(session: SparkSession) extends Rule[SparkPlan
s"aggregateAttributes:${partialAggregate.aggregateAttributes}\n" +
s"aggregateExpressions:${partialAggregate.aggregateExpressions}\n" +
s"resultExpressions:${partialAggregate.resultExpressions}")
- if (isSupportedAggregate(partialAggregate, expand, shuffle)) {
+ if (doValidation(partialAggregate, expand, shuffle)) {

val attributesToReplace = buildReplaceAttributeMap(expand)
logDebug(s"xxx attributesToReplace: $attributesToReplace")

- val newPartialAggregate = buildNewAggregateExec(
+ val newPartialAggregate = buildAheadAggregateExec(
partialAggregate,
expand,
attributesToReplace
)

- val newExpand = buildNewExpandExec(
+ val newExpand = buildPostExpandExec(
expand,
partialAggregate,
newPartialAggregate,
@@ -129,17 +127,17 @@ case class LazyAggregateExpandRule(session: SparkSession) extends Rule[SparkPlan
s"aggregateAttributes:${partialAggregate.aggregateAttributes}\n" +
s"aggregateExpressions:${partialAggregate.aggregateExpressions}\n" +
s"resultExpressions:${partialAggregate.resultExpressions}")
- if (isSupportedAggregate(partialAggregate, expand, shuffle)) {
+ if (doValidation(partialAggregate, expand, shuffle)) {
val attributesToReplace = buildReplaceAttributeMap(expand)
logDebug(s"xxx attributesToReplace: $attributesToReplace")

- val newPartialAggregate = buildNewAggregateExec(
+ val newPartialAggregate = buildAheadAggregateExec(
partialAggregate,
expand,
attributesToReplace
)

- val newExpand = buildNewExpandExec(
+ val newExpand = buildPostExpandExec(
expand,
partialAggregate,
newPartialAggregate,
@@ -164,83 +162,92 @@ case class LazyAggregateExpandRule(session: SparkSession) extends Rule[SparkPlan
// 2. select n_name, count(distinct n_regionkey) as col1,
// count(distinct concat(n_regionkey, n_nationkey)) as col2 from
// nation group by n_name;
- def isSupportedAggregate(
+ def doValidation(
aggregate: CHHashAggregateExecTransformer,
expand: ExpandExecTransformer,
shuffle: ColumnarShuffleExchangeExec): Boolean = {
// all grouping keys must be attribute references
val expandOutputAttributes = expand.child.output.toSet
- if (aggregate.groupingExpressions.exists(!_.isInstanceOf[Attribute])) {
+ if (
+ !aggregate.groupingExpressions.forall(
+ e => e.isInstanceOf[Attribute] || e.isInstanceOf[Literal])
+ ) {
logDebug(s"xxx Not all grouping expression are attribute references")
return false
}
// all shuffle keys must be attribute references
if (
- shuffle.outputPartitioning
+ !shuffle.outputPartitioning
.asInstanceOf[HashPartitioning]
.expressions
- .exists(!_.isInstanceOf[Attribute])
+ .forall(e => e.isInstanceOf[Attribute] || e.isInstanceOf[Literal])
) {
logDebug(s"xxx Not all shuffle hash expression are attribute references")
return false
}

// 1. for safety, we don't enable this optimization for all aggregate functions.
- // 2. if any aggregate function uses attributes from expand's child, we don't enable this
+ // 2. if any aggregate function uses attributes which are not from expand's child, we don't
+ // enable this
if (
!aggregate.aggregateExpressions.forall(
e =>
isSupportedAggregateFunction(e) &&
isValidAggregateFunction(e) &&
e.aggregateFunction.references.forall(expandOutputAttributes.contains(_)))
) {
logDebug(s"xxx Some aggregate functions are not supported")
return false
}

- // ensure the last column of expand is grouping id
- val groupIdIndex = findGroupingIdIndex(expand)
- logDebug(s"xxx Find group id at index: $groupIdIndex")
- if (groupIdIndex == -1) {
- return false;
- }
- val groupIdAttribute = expand.output(groupIdIndex)
- if (
- !groupIdAttribute.name.startsWith("grouping_id") && !groupIdAttribute.name.startsWith("gid")
- && !groupIdAttribute.name.startsWith("spark_grouping_id")
- ) {
- logDebug(s"xxx Not found group id column at index $groupIdIndex")
- return false
- }
- expand.projections.forall {
- projection =>
- val groupId = projection(groupIdIndex)
- groupId
- .isInstanceOf[Literal] && (groupId.dataType.isInstanceOf[LongType] || groupId.dataType
- .isInstanceOf[IntegerType])
- }
+ // get the group id's position in the expand's output
+ val gidIndex = findGroupingIdIndex(expand)
+ gidIndex != -1
}

+ // group id column doesn't have a fixed position, so we need to find it.
def findGroupingIdIndex(expand: ExpandExecTransformer): Int = {
+ def isValidGroupIdColumn(e: Expression, gids: Set[Long]): Long = {
+ if (!e.isInstanceOf[Literal]) {
+ return -1
+ }
+ val literalValue = e.asInstanceOf[Literal].value
+ e.dataType match {
+ case _: LongType =>
+ if (gids.contains(literalValue.asInstanceOf[Long])) {
+ -1
+ } else {
+ literalValue.asInstanceOf[Long]
+ }
+ case _: IntegerType =>
+ if (gids.contains(literalValue.asInstanceOf[Int].toLong)) {
+ -1
+ } else {
+ literalValue.asInstanceOf[Int].toLong
+ }
+ case _ => -1
+ }
+ }

var groupIdIndexes = Seq[Int]()
for (col <- 0 until expand.output.length) {
val expandCol = expand.projections(0)(col)
- if (
- expandCol.isInstanceOf[Literal] && (expandCol.dataType
- .isInstanceOf[LongType] || expandCol.dataType.isInstanceOf[IntegerType])
- ) {
+ // gids should be unique
+ var gids = Set[Long]()
+ if (isValidGroupIdColumn(expandCol, gids) != -1) {
if (
expand.projections.forall {
projection =>
- val e = projection(col)
- e.isInstanceOf[Literal] &&
- (e.dataType.isInstanceOf[LongType] || e.dataType.isInstanceOf[IntegerType])
+ val res = isValidGroupIdColumn(projection(col), gids)
+ gids += res
+ res != -1
}
) {
groupIdIndexes +:= col
}
}
}
if (groupIdIndexes.length == 1) {
logDebug(s"xxx gid is at pos ${groupIdIndexes(0)}")
groupIdIndexes(0)
} else {
-1
@@ -252,7 +259,7 @@ case class LazyAggregateExpandRule(session: SparkSession) extends Rule[SparkPlan
// column, avg.
// - sum: if the input's type is decimal, the output are sum and isEmpty, but gluten doesn't use
// the isEmpty column.
- def isSupportedAggregateFunction(aggregateExpression: AggregateExpression): Boolean = {
+ def isValidAggregateFunction(aggregateExpression: AggregateExpression): Boolean = {
if (aggregateExpression.filter.isDefined) {
return false
}
@@ -290,7 +297,7 @@ case class LazyAggregateExpandRule(session: SparkSession) extends Rule[SparkPlan
attributeMap
}

- def buildNewExpandProjections(
+ def buildPostExpandProjections(
originalExpandProjections: Seq[Seq[Expression]],
originalExpandOutput: Seq[Attribute],
newExpandOutput: Seq[Attribute]): Seq[Seq[Expression]] = {
@@ -309,49 +316,39 @@ case class LazyAggregateExpandRule(session: SparkSession) extends Rule[SparkPlan
newExpandProjections
}

- // Make the child of expand be the child of aggregate
- // Need to replace some attributes
- def buildNewAggregateExec(
+ // 1. make expand's child be aggregate's child
+ // 2. replace the attributes in groupingExpressions and resultExpressions as needed
+ def buildAheadAggregateExec(
partialAggregate: CHHashAggregateExecTransformer,
expand: ExpandExecTransformer,
attributesToReplace: Map[Attribute, Attribute]): SparkPlan = {
val groupIdAttribute = expand.output(findGroupingIdIndex(expand))

- // if the grouping keys contains literal, they should not be in attributesToReplace
- // And we need to remove them from the grouping keys
- val groupingExpressions = partialAggregate.groupingExpressions
- val newGroupingExpresion =
- groupingExpressions
+ // New grouping expressions should include the group id column
+ val groupingExpressions =
+ partialAggregate.groupingExpressions
.filter(
e => e.toAttribute != groupIdAttribute && attributesToReplace.contains(e.toAttribute))
.map(e => getReplaceAttribute(e.toAttribute, attributesToReplace))
.distinct
logDebug(
s"xxx newGroupingExpresion: $newGroupingExpresion,\n" +
s"groupingExpressions: $groupingExpressions")
s"xxx newGroupingExpresion: $groupingExpressions,\n" +
s"groupingExpressions: ${partialAggregate.groupingExpressions}")

- // Also need to remove literal grouping keys from the result expressions
+ // Remove group id column from result expressions
val resultExpressions = partialAggregate.resultExpressions
- val newResultExpressions =
- resultExpressions
- .filter {
- e =>
- e.toAttribute != groupIdAttribute && (groupingExpressions
- .find(_.toAttribute == e.toAttribute)
- .isEmpty || attributesToReplace.contains(e.toAttribute))
- }
- .map(e => getReplaceAttribute(e.toAttribute, attributesToReplace))
- .distinct
+ .filter(_.toAttribute != groupIdAttribute)
+ .map(e => getReplaceAttribute(e.toAttribute, attributesToReplace))
logDebug(
s"xxx newResultExpressions: $newResultExpressions\n" +
s"resultExpressions:$resultExpressions")
s"xxx newResultExpressions: $resultExpressions\n" +
s"resultExpressions:${partialAggregate.resultExpressions}")
partialAggregate.copy(
- groupingExpressions = newGroupingExpresion,
- resultExpressions = newResultExpressions,
+ groupingExpressions = groupingExpressions,
+ resultExpressions = resultExpressions,
child = expand.child)
}

- def buildNewExpandExec(
+ def buildPostExpandExec(
expand: ExpandExecTransformer,
partialAggregate: CHHashAggregateExecTransformer,
child: SparkPlan,
@@ -361,13 +358,13 @@ case class LazyAggregateExpandRule(session: SparkSession) extends Rule[SparkPlan
logDebug(s"xxx aggregateResultAttributes: ${partialAggregate.aggregateResultAttributes}")
logDebug(s"xxx aggregateOutput: $aggregateOutput")

- val newExpandProjections = buildNewExpandProjections(
+ val expandProjections = buildPostExpandProjections(
expand.projections,
expand.output,
aggregateOutput
)
logDebug(s"xxx newExpandProjections: $newExpandProjections\nprojections:${expand.projections}")
ExpandExecTransformer(newExpandProjections, aggregateOutput, child)
logDebug(s"xxx expandProjections: $expandProjections\nprojections:${expand.projections}")
ExpandExecTransformer(expandProjections, aggregateOutput, child)
}

}
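A note on findGroupingIdIndex above: the grouping-id column has no fixed position in the expand output, so the rule scans every column and keeps those holding a distinct integer or long literal in every projection. A self-contained sketch of the same scan over plain values (an assumed simplification; the real code inspects Literal expressions and their data types):

  object GroupingIdScanSketch {
    // Each inner Seq is one expand projection; a grouping-id column holds a
    // literal long that is unique across projections (0, 1, 2, 3 for a cube).
    def findGroupingIdIndex(projections: Seq[Seq[Any]]): Int = {
      val candidates = (0 until projections.head.length).filter { col =>
        val values = projections.map(_(col))
        values.forall(_.isInstanceOf[Long]) && values.distinct.length == values.length
      }
      if (candidates.length == 1) candidates.head else -1
    }

    def main(args: Array[String]): Unit = {
      // group by a, b with cube expands to four projections; the last column
      // is the grouping id.
      val projections: Seq[Seq[Any]] = Seq(
        Seq("a", "b", 0L),
        Seq("a", null, 1L),
        Seq(null, "b", 2L),
        Seq(null, null, 3L))
      println(findGroupingIdIndex(projections)) // prints: 2
    }
  }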
@@ -3024,47 +3024,28 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr
}

test("GLLUTEN-7647 lazy expand") {
+ def checkLazyExpand(df: DataFrame): Unit = {
+ val expands = collectWithSubqueries(df.queryExecution.executedPlan) {
+ case e: ExpandExecTransformer if (e.child.isInstanceOf[HashAggregateExecBaseTransformer]) =>
+ e
+ }
+ assert(expands.size == 1)
+ }
var sql =
"""
|select n_regionkey, n_nationkey,
|sum(n_regionkey), count(n_name), max(n_regionkey), min(n_regionkey)
|from nation group by n_regionkey, n_nationkey with cube
|order by n_regionkey, n_nationkey
|""".stripMargin
- compareResultsAgainstVanillaSpark(sql, true, { _ => })
+ compareResultsAgainstVanillaSpark(sql, true, checkLazyExpand)

sql = """
|select n_regionkey, n_nationkey, sum(n_regionkey), count(distinct n_name)
|from nation group by n_regionkey, n_nationkey with cube
|order by n_regionkey, n_nationkey
|""".stripMargin
- compareResultsAgainstVanillaSpark(sql, true, { _ => })
-
- sql = """
- |select n_regionkey, n_nationkey,
- |sum(distinct n_regionkey), count(distinct n_name), max(n_regionkey), min(n_regionkey)
- |from nation group by n_regionkey, n_nationkey with cube
- |order by n_regionkey, n_nationkey
- |""".stripMargin
- compareResultsAgainstVanillaSpark(sql, true, { _ => })
-
- sql =
- """
- |select n_regionkey, n_nationkey,
- |sum(distinct n_regionkey), count(distinct n_name), max(n_regionkey), min(n_regionkey)
- |from nation group by n_regionkey, n_nationkey grouping sets((n_regionkey), (n_nationkey))
- |order by n_regionkey, n_nationkey
- |""".stripMargin
- compareResultsAgainstVanillaSpark(sql, true, { _ => })
-
- sql = """
- |select n_regionkey, n_nationkey,
- |sum(distinct n_regionkey), count(distinct n_name), max(n_regionkey), min(n_regionkey)
- |from nation group by n_regionkey, n_nationkey
- |grouping sets((n_regionkey, null), (null, n_nationkey))
- |order by n_regionkey, n_nationkey
- |""".stripMargin
- compareResultsAgainstVanillaSpark(sql, true, { _ => })
+ compareResultsAgainstVanillaSpark(sql, true, checkLazyExpand)

sql = """
|select * from(
@@ -3074,7 +3055,7 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr
|) where n_regionkey != 0
|order by n_regionkey, n_nationkey
|""".stripMargin
- compareResultsAgainstVanillaSpark(sql, true, { _ => })
+ compareResultsAgainstVanillaSpark(sql, true, checkLazyExpand)

sql = """
|select * from(
@@ -3084,7 +3065,7 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr
|) where n_regionkey != 0
|order by n_regionkey, n_nationkey
|""".stripMargin
- compareResultsAgainstVanillaSpark(sql, true, { _ => })
+ compareResultsAgainstVanillaSpark(sql, true, checkLazyExpand)
}
}
// scalastyle:on line.size.limit
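The checkLazyExpand helper added in the test above asserts that exactly one ExpandExecTransformer sits directly on top of a hash aggregate, i.e. that the rewrite actually fired. A toy version of that plan-shape check with stand-in node types (not Spark's API; collectWithSubqueries in the real test also descends into subqueries):

  object PlanCheckSketch {
    sealed trait Node { def children: Seq[Node] }
    case class Expand(child: Node) extends Node { def children = Seq(child) }
    case class HashAgg(child: Node) extends Node { def children = Seq(child) }
    case object Scan extends Node { def children = Seq.empty }

    // Collect every node matched by the partial function, recursing top-down.
    def collect(root: Node)(pf: PartialFunction[Node, Node]): Seq[Node] =
      pf.lift(root).toSeq ++ root.children.flatMap(collect(_)(pf))

    def main(args: Array[String]): Unit = {
      // The lazy-expand shape: Expand directly above the partial aggregate.
      val optimized: Node = Expand(HashAgg(Scan))
      assert(collect(optimized) { case e @ Expand(_: HashAgg) => e }.size == 1)
    }
  }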
