Commit

refactor
lgbo-ustc committed Oct 31, 2024
1 parent db5ed73 commit b152a26
Showing 1 changed file with 34 additions and 56 deletions.
@@ -52,7 +52,7 @@ import org.apache.spark.sql.types._

case class LazyAggregateExpandRule(session: SparkSession) extends Rule[SparkPlan] with Logging {
override def apply(plan: SparkPlan): SparkPlan = {
- logError(
+ logDebug(
s"xxx 1031 enable lazy aggregate expand: " +
s"${CHBackendSettings.enableLazyAggregateExpand}")
if (!CHBackendSettings.enableLazyAggregateExpand) {
@@ -73,23 +73,19 @@ case class LazyAggregateExpandRule(session: SparkSession) extends Rule[SparkPlan
_,
_
) =>
logError(s"xxx match plan:$shuffle")
logDebug(s"xxx match plan:$shuffle")
val partialAggregate = shuffle.child.asInstanceOf[CHHashAggregateExecTransformer]
val expand = partialAggregate.child.asInstanceOf[ExpandExecTransformer]
- logError(
+ logDebug(
s"xxx partialAggregate: groupingExpressions:" +
s"${partialAggregate.groupingExpressions}\n" +
s"aggregateAttributes:${partialAggregate.aggregateAttributes}\n" +
s"aggregateExpressions:${partialAggregate.aggregateExpressions}\n" +
s"resultExpressions:${partialAggregate.resultExpressions}")
if (isSupportedAggregate(partialAggregate, expand, shuffle)) {

- val attributesToReplace = buildReplaceAttributeMapForAggregate(
- groupingExpressions,
- projections,
- output
- )
- logError(s"xxx attributesToReplace: $attributesToReplace")
+ val attributesToReplace = buildReplaceAttributeMap(expand)
+ logDebug(s"xxx attributesToReplace: $attributesToReplace")

val newPartialAggregate = buildNewAggregateExec(
partialAggregate,
@@ -105,7 +101,7 @@ case class LazyAggregateExpandRule(session: SparkSession) extends Rule[SparkPlan
)

val newShuffle = shuffle.copy(child = newExpand)
logError(s"xxx new plan: $newShuffle")
logDebug(s"xxx new plan: $newShuffle")
newShuffle
} else {
shuffle
@@ -127,19 +123,15 @@ case class LazyAggregateExpandRule(session: SparkSession) extends Rule[SparkPlan
val partialAggregate = shuffle.child.asInstanceOf[CHHashAggregateExecTransformer]
val filter = partialAggregate.child.asInstanceOf[FilterExecTransformer]
val expand = filter.child.asInstanceOf[ExpandExecTransformer]
- logError(
+ logDebug(
s"xxx partialAggregate: groupingExpressions:" +
s"${partialAggregate.groupingExpressions}\n" +
s"aggregateAttributes:${partialAggregate.aggregateAttributes}\n" +
s"aggregateExpressions:${partialAggregate.aggregateExpressions}\n" +
s"resultExpressions:${partialAggregate.resultExpressions}")
if (isSupportedAggregate(partialAggregate, expand, shuffle)) {
- val attributesToReplace = buildReplaceAttributeMapForAggregate(
- groupingExpressions,
- projections,
- output
- )
- logError(s"xxx attributesToReplace: $attributesToReplace")
+ val attributesToReplace = buildReplaceAttributeMap(expand)
+ logDebug(s"xxx attributesToReplace: $attributesToReplace")

val newPartialAggregate = buildNewAggregateExec(
partialAggregate,
@@ -157,7 +149,7 @@ case class LazyAggregateExpandRule(session: SparkSession) extends Rule[SparkPlan
val newFilter = filter.copy(child = newExpand)

val newShuffle = shuffle.copy(child = newFilter)
logError(s"xxx new plan: $newShuffle")
logDebug(s"xxx new plan: $newShuffle")
newShuffle

} else {
@@ -179,7 +171,7 @@ case class LazyAggregateExpandRule(session: SparkSession) extends Rule[SparkPlan
// all grouping keys must be attribute references
val expandOutputAttributes = expand.child.output.toSet
if (aggregate.groupingExpressions.exists(!_.isInstanceOf[Attribute])) {
logError(s"xxx Not all grouping expression are attribute references")
logDebug(s"xxx Not all grouping expression are attribute references")
return false
}
// all shuffle keys must be attribute references
@@ -189,26 +181,25 @@ case class LazyAggregateExpandRule(session: SparkSession) extends Rule[SparkPlan
.expressions
.exists(!_.isInstanceOf[Attribute])
) {
logError(s"xxx Not all shuffle hash expression are attribute references")
logDebug(s"xxx Not all shuffle hash expression are attribute references")
return false
}

- // For safety, only enalbe for some aggregate functions
- // All the parameters in aggregate functions must be the references of the output of expand's
- // child
+ // 1. for safety, we don't enable this optimization for all aggregate functions.
+ // 2. if any aggregate function uses attributes that are not from expand's child, we don't enable this
if (
!aggregate.aggregateExpressions.forall(
e =>
- isSupportedAggregateFunction(e) && e.aggregateFunction.references.forall(
- expandOutputAttributes.contains(_)))
+ isSupportedAggregateFunction(e) &&
+ e.aggregateFunction.references.forall(expandOutputAttributes.contains(_)))
) {
logError(s"xxx Some aggregate functions are not supported")
logDebug(s"xxx Some aggregate functions are not supported")
return false
}

// ensure the last column of expand is grouping id
val groupIdIndex = findGroupingIdIndex(expand)
logError(s"xxx Find group id at index: $groupIdIndex")
logDebug(s"xxx Find group id at index: $groupIdIndex")
if (groupIdIndex == -1) {
return false;
}
@@ -217,7 +208,7 @@ case class LazyAggregateExpandRule(session: SparkSession) extends Rule[SparkPlan
!groupIdAttribute.name.startsWith("grouping_id") && !groupIdAttribute.name.startsWith("gid")
&& !groupIdAttribute.name.startsWith("spark_grouping_id")
) {
logError(s"xxx Not found group id column at index $groupIdIndex")
logDebug(s"xxx Not found group id column at index $groupIdIndex")
return false
}
expand.projections.forall {
@@ -280,35 +271,26 @@ case class LazyAggregateExpandRule(session: SparkSession) extends Rule[SparkPlan
attributesToReplace.getOrElse(toReplace, toReplace)
}

- def buildReplaceAttributeMapForAggregate(
- originalGroupingExpressions: Seq[NamedExpression],
- originalExpandProjections: Seq[Seq[Expression]],
- originalExpandOutput: Seq[Attribute]): Map[Attribute, Attribute] = {

+ def buildReplaceAttributeMap(expand: ExpandExecTransformer): Map[Attribute, Attribute] = {
var fullExpandProjection = Seq[Expression]()
- for (i <- 0 until originalExpandProjections(0).length) {
- val attr = originalExpandProjections.find(x => x(i).isInstanceOf[Attribute]) match {
+ for (i <- 0 until expand.projections(0).length) {
+ val attr = expand.projections.find(x => x(i).isInstanceOf[Attribute]) match {
case Some(projection) => projection(i).asInstanceOf[Attribute]
case None => null
}
fullExpandProjection = fullExpandProjection :+ attr
}

var attributeMap = Map[Attribute, Attribute]()
- val groupIdAttribute = originalExpandOutput(originalExpandOutput.length - 1)
- originalGroupingExpressions.filter(_.toAttribute != groupIdAttribute).foreach {
- e =>
- val index = originalExpandOutput.indexWhere(_.semanticEquals(e.toAttribute))
- val attr = fullExpandProjection(index).asInstanceOf[Attribute]
- // if the grouping key is a literal, cast it to Attribute will be null
- if (attr != null) {
- attributeMap += (e.toAttribute -> attr)
- }
+ for (i <- 0 until expand.output.length) {
+ if (fullExpandProjection(i).isInstanceOf[Attribute]) {
+ attributeMap += (expand.output(i) -> fullExpandProjection(i).asInstanceOf[Attribute])
+ }
}
attributeMap
}

def buildNewExpandProjections(
originalGroupingExpressions: Seq[NamedExpression],
originalExpandProjections: Seq[Seq[Expression]],
originalExpandOutput: Seq[Attribute],
newExpandOutput: Seq[Attribute]): Seq[Seq[Expression]] = {
@@ -333,8 +315,7 @@ case class LazyAggregateExpandRule(session: SparkSession) extends Rule[SparkPlan
partialAggregate: CHHashAggregateExecTransformer,
expand: ExpandExecTransformer,
attributesToReplace: Map[Attribute, Attribute]): SparkPlan = {
- val expandOutput = expand.output
- val groupIdAttribute = expandOutput(expandOutput.length - 1)
+ val groupIdAttribute = expand.output(findGroupingIdIndex(expand))

// if the grouping keys contain literals, they should not be in attributesToReplace
// And we need to remove them from the grouping keys
@@ -345,7 +326,7 @@ case class LazyAggregateExpandRule(session: SparkSession) extends Rule[SparkPlan
e => e.toAttribute != groupIdAttribute && attributesToReplace.contains(e.toAttribute))
.map(e => getReplaceAttribute(e.toAttribute, attributesToReplace))
.distinct
- logError(
+ logDebug(
s"xxx newGroupingExpresion: $newGroupingExpresion,\n" +
s"groupingExpressions: $groupingExpressions")

@@ -361,7 +342,7 @@ case class LazyAggregateExpandRule(session: SparkSession) extends Rule[SparkPlan
}
.map(e => getReplaceAttribute(e.toAttribute, attributesToReplace))
.distinct
- logError(
+ logDebug(
s"xxx newResultExpressions: $newResultExpressions\n" +
s"resultExpressions:$resultExpressions")
partialAggregate.copy(
@@ -377,18 +358,15 @@ case class LazyAggregateExpandRule(session: SparkSession) extends Rule[SparkPlan
attributesToReplace: Map[Attribute, Attribute]): SparkPlan = {
// The output of the native plan is not completely consistent with Spark.
val aggregateOutput = partialAggregate.output
- val newExpandProjectionTemplate = aggregateOutput
- // aggregateOutput.map(e => getReplaceAttribute(e, attributesToReplace))
- logError(s"xxx aggregateResultAttributes: ${partialAggregate.aggregateResultAttributes}")
- logError(s"xxx newExpandProjectionTemplate: $newExpandProjectionTemplate")
+ logDebug(s"xxx aggregateResultAttributes: ${partialAggregate.aggregateResultAttributes}")
+ logDebug(s"xxx aggregateOutput: $aggregateOutput")

val newExpandProjections = buildNewExpandProjections(
partialAggregate.groupingExpressions,
expand.projections,
expand.output,
- newExpandProjectionTemplate
+ aggregateOutput
)
logError(s"xxx newExpandProjections: $newExpandProjections\nprojections:${expand.projections}")
logDebug(s"xxx newExpandProjections: $newExpandProjections\nprojections:${expand.projections}")
ExpandExecTransformer(newExpandProjections, aggregateOutput, child)
}

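The core of the refactor is that the replacement map is now built directly from the expand operator: for each of its output columns, pick any projection row that carries a plain attribute in that position and map the output attribute to it; positions that never hold an attribute (such as the grouping id column) are skipped. Below is a minimal standalone sketch of that idea, not part of the commit; Spark's Attribute/Expression types are stood in for by plain strings and Option purely for illustration.

```scala
// Simplified illustration only: Some("a") plays the role of an attribute
// reference, None plays the role of a literal or other non-attribute expression.
object ReplaceAttributeMapSketch {
  // projections: one row per grouping set; output: the expand operator's output columns
  def buildReplaceAttributeMap(
      projections: Seq[Seq[Option[String]]],
      output: Seq[String]): Map[String, String] = {
    // For every column position, take the first projection row that holds a
    // real attribute there (rows with a literal in that slot are ignored).
    val fullProjection = output.indices.map {
      i => projections.collectFirst { case row if row(i).isDefined => row(i).get }
    }
    // Map each expand output column to the underlying attribute, skipping
    // positions that never carry an attribute (e.g. the grouping id column).
    output.indices.flatMap {
      i => fullProjection(i).map(attr => output(i) -> attr)
    }.toMap
  }

  def main(args: Array[String]): Unit = {
    // Two grouping sets over columns a and b; the last column is the grouping id.
    val projections = Seq(
      Seq(Some("a"), None, None),
      Seq(Some("a"), Some("b"), None)
    )
    val output = Seq("a#1", "b#2", "spark_grouping_id#3")
    println(buildReplaceAttributeMap(projections, output))
    // Map(a#1 -> a, b#2 -> b)
  }
}
```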
