Skip to content

Commit

Permalink
[CORE] Remove unnecessary case match in getAttrForAggregateExpr (#3629)
Browse files Browse the repository at this point in the history
  • Loading branch information
liujiayi771 authored Nov 7, 2023
1 parent 40ef132 commit 7308fdb
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 262 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,8 @@ package io.glutenproject.execution.extension
import io.glutenproject.expression._
import io.glutenproject.extension.ExpressionExtensionTrait

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._

import scala.collection.mutable.ListBuffer

case class CustomAggExpressionTransformer() extends ExpressionExtensionTrait {

lazy val expressionSigs = Seq(
Expand All @@ -32,41 +29,4 @@ case class CustomAggExpressionTransformer() extends ExpressionExtensionTrait {

/** Generate the extension expressions list, format: Sig[XXXExpression]("XXXExpressionName") */
override def expressionSigList: Seq[Sig] = expressionSigs

/** Get the attribute index of the extension aggregate functions. */
override def getAttrsForExtensionAggregateExpr(
aggregateFunc: AggregateFunction,
mode: AggregateMode,
exp: AggregateExpression,
aggregateAttributeList: Seq[Attribute],
aggregateAttr: ListBuffer[Attribute],
resIndex: Int): Int = {
var reIndex = resIndex
aggregateFunc match {
case CustomSum(_, _) =>
mode match {
case Partial | PartialMerge =>
val aggBufferAttr = aggregateFunc.inputAggBufferAttributes
if (aggBufferAttr.size == 2) {
// decimal sum check sum.resultType
aggregateAttr += ConverterUtils.getAttrFromExpr(aggBufferAttr.head)
val isEmptyAttr = ConverterUtils.getAttrFromExpr(aggBufferAttr(1))
aggregateAttr += isEmptyAttr
reIndex += 2
reIndex
} else {
val attr = ConverterUtils.getAttrFromExpr(aggBufferAttr.head)
aggregateAttr += attr
reIndex += 1
reIndex
}
case Final =>
aggregateAttr += aggregateAttributeList(reIndex)
reIndex += 1
reIndex
case other =>
throw new UnsupportedOperationException(s"Unsupported aggregate mode: $other.")
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ import com.google.protobuf.Any
import java.util

import scala.collection.JavaConverters._
import scala.collection.mutable.ListBuffer

case class HashAggregateExecTransformer(
requiredChildDistributionExpressions: Option[Seq[Expression]],
Expand All @@ -56,34 +55,18 @@ case class HashAggregateExecTransformer(
resultExpressions,
child) {

override protected def getAttrForAggregateExpr(
exp: AggregateExpression,
aggregateAttributeList: Seq[Attribute],
aggregateAttr: ListBuffer[Attribute],
index: Int): Int = {
var resIndex = index
val mode = exp.mode
val aggregateFunc = exp.aggregateFunction
aggregateFunc match {
case hllAdapter: HLLAdapter =>
override protected def checkAggFuncModeSupport(
aggFunc: AggregateFunction,
mode: AggregateMode): Boolean = {
aggFunc match {
case _: HLLAdapter =>
mode match {
case Partial =>
val aggBufferAttr = hllAdapter.inputAggBufferAttributes
for (index <- aggBufferAttr.indices) {
val attr = ConverterUtils.getAttrFromExpr(aggBufferAttr(index))
aggregateAttr += attr
}
resIndex += aggBufferAttr.size
case Final =>
aggregateAttr += aggregateAttributeList(resIndex)
resIndex += 1
case other =>
throw new UnsupportedOperationException(s"not currently supported: $other.")
case Partial | Final => true
case _ => false
}
case _ =>
resIndex = super.getAttrForAggregateExpr(exp, aggregateAttributeList, aggregateAttr, index)
super.checkAggFuncModeSupport(aggFunc, mode)
}
resIndex
}

override protected def withNewChildInternal(newChild: SparkPlan): HashAggregateExecTransformer = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.aggregate._
import org.apache.spark.sql.types._
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.sketch.BloomFilter

import com.google.protobuf.Any

Expand Down Expand Up @@ -417,195 +416,47 @@ abstract class HashAggregateExecBaseTransformer(
var resIndex = index
val mode = exp.mode
val aggregateFunc = exp.aggregateFunction
aggregateFunc match {
case extendedAggFunc
if ExpressionMappings.expressionExtensionTransformer.extensionExpressionsMapping.contains(
extendedAggFunc.getClass) =>
// get attributes from the extended aggregate functions
ExpressionMappings.expressionExtensionTransformer
.getAttrsForExtensionAggregateExpr(
aggregateFunc,
mode,
exp,
aggregateAttributeList,
aggregateAttr,
index)
case _: Average | _: First | _: Last =>
mode match {
case Partial | PartialMerge =>
val aggBufferAttr = aggregateFunc.inputAggBufferAttributes
for (index <- aggBufferAttr.indices) {
val attr = ConverterUtils.getAttrFromExpr(aggBufferAttr(index))
aggregateAttr += attr
}
resIndex += 2
resIndex
case Final =>
aggregateAttr += aggregateAttributeList(resIndex)
resIndex += 1
resIndex
case other =>
throw new UnsupportedOperationException(s"Unsupported aggregate mode: $other.")
}
case Sum(_, _) =>
mode match {
case Partial | PartialMerge =>
val sum = aggregateFunc.asInstanceOf[Sum]
val aggBufferAttr = sum.inputAggBufferAttributes
if (aggBufferAttr.size == 2) {
// decimal sum check sum.resultType
aggregateAttr += ConverterUtils.getAttrFromExpr(aggBufferAttr.head)
val isEmptyAttr = ConverterUtils.getAttrFromExpr(aggBufferAttr(1))
aggregateAttr += isEmptyAttr
resIndex += 2
resIndex
} else {
val attr = ConverterUtils.getAttrFromExpr(aggBufferAttr.head)
aggregateAttr += attr
resIndex += 1
resIndex
}
case Final =>
aggregateAttr += aggregateAttributeList(resIndex)
resIndex += 1
resIndex
case other =>
throw new UnsupportedOperationException(s"Unsupported aggregate mode: $other.")
}
case Count(_) =>
mode match {
case Partial | PartialMerge =>
val count = aggregateFunc.asInstanceOf[Count]
val aggBufferAttr = count.inputAggBufferAttributes
val attr = ConverterUtils.getAttrFromExpr(aggBufferAttr.head)
aggregateAttr += attr
resIndex += 1
resIndex
case Final =>
aggregateAttr += aggregateAttributeList(resIndex)
resIndex += 1
resIndex
case other =>
throw new UnsupportedOperationException(s"Unsupported aggregate mode: $other.")
}
case _: Max | _: Min | _: BitAndAgg | _: BitOrAgg | _: BitXorAgg =>
mode match {
case Partial | PartialMerge =>
val aggBufferAttr = aggregateFunc.inputAggBufferAttributes
assert(
aggBufferAttr.size == 1,
s"Aggregate function $aggregateFunc expects one buffer attribute.")
val attr = ConverterUtils.getAttrFromExpr(aggBufferAttr.head)
aggregateAttr += attr
resIndex += 1
resIndex
case Final =>
aggregateAttr += aggregateAttributeList(resIndex)
resIndex += 1
resIndex
case other =>
throw new UnsupportedOperationException(s"Unsupported aggregate mode: $other.")
}
case _: Corr =>
mode match {
case Partial | PartialMerge =>
val expectedBufferSize = 6
val aggBufferAttr = aggregateFunc.inputAggBufferAttributes
assert(
aggBufferAttr.size == expectedBufferSize,
s"Aggregate function $aggregateFunc" +
s" expects $expectedBufferSize buffer attribute.")
for (index <- aggBufferAttr.indices) {
val attr = ConverterUtils.getAttrFromExpr(aggBufferAttr(index))
aggregateAttr += attr
}
resIndex += expectedBufferSize
resIndex
case Final =>
aggregateAttr += aggregateAttributeList(resIndex)
resIndex += 1
resIndex
case other =>
throw new UnsupportedOperationException(s"Unsupported aggregate mode: $other.")
}
case _: CovPopulation | _: CovSample =>
mode match {
case Partial | PartialMerge =>
val expectedBufferSize = 4
val aggBufferAttr = aggregateFunc.inputAggBufferAttributes
assert(
aggBufferAttr.size == expectedBufferSize,
s"Aggregate function $aggregateFunc" +
s" expects $expectedBufferSize buffer attributes.")
for (index <- aggBufferAttr.indices) {
val attr = ConverterUtils.getAttrFromExpr(aggBufferAttr(index))
aggregateAttr += attr
}
resIndex += expectedBufferSize
resIndex
case Final =>
aggregateAttr += aggregateAttributeList(resIndex)
resIndex += 1
resIndex
case other =>
throw new UnsupportedOperationException(s"Unsupported aggregate mode: $other.")
if (!checkAggFuncModeSupport(aggregateFunc, mode)) {
throw new UnsupportedOperationException(
s"Unsupported aggregate mode: $mode for ${aggregateFunc.prettyName}")
}
mode match {
case Partial | PartialMerge =>
val aggBufferAttr = aggregateFunc.inputAggBufferAttributes
for (index <- aggBufferAttr.indices) {
val attr = ConverterUtils.getAttrFromExpr(aggBufferAttr(index))
aggregateAttr += attr
}
case _: StddevSamp | _: StddevPop | _: VarianceSamp | _: VariancePop =>
resIndex += aggBufferAttr.size
resIndex
case Final =>
aggregateAttr += aggregateAttributeList(resIndex)
resIndex += 1
resIndex
case other =>
throw new UnsupportedOperationException(s"Unsupported aggregate mode: $other.")
}
}

protected def checkAggFuncModeSupport(
aggFunc: AggregateFunction,
mode: AggregateMode): Boolean = {
aggFunc match {
case _: CollectList | _: CollectSet =>
mode match {
case Partial | PartialMerge =>
val aggBufferAttr = aggregateFunc.inputAggBufferAttributes
for (index <- aggBufferAttr.indices) {
val attr = ConverterUtils.getAttrFromExpr(aggBufferAttr(index))
aggregateAttr += attr
}
resIndex += 3
resIndex
case Final =>
aggregateAttr += aggregateAttributeList(resIndex)
resIndex += 1
resIndex
case other =>
throw new UnsupportedOperationException(s"Unsupported aggregate mode: $other.")
case Partial | Final => true
case _ => false
}
case bloom if bloom.getClass.getSimpleName.equals("BloomFilterAggregate") =>
// for spark33
mode match {
case Partial =>
val bloom = aggregateFunc.asInstanceOf[TypedImperativeAggregate[BloomFilter]]
val aggBufferAttr = bloom.inputAggBufferAttributes
for (index <- aggBufferAttr.indices) {
val attr = ConverterUtils.getAttrFromExpr(aggBufferAttr(index))
aggregateAttr += attr
}
resIndex += aggBufferAttr.size
resIndex
case Final =>
aggregateAttr += aggregateAttributeList(resIndex)
resIndex += 1
resIndex
case other =>
throw new UnsupportedOperationException(s"Unsupported aggregate mode: $other.")
case Partial | Final => true
case _ => false
}
case _: CollectList | _: CollectSet =>
case _ =>
mode match {
case Partial =>
val aggBufferAttr = aggregateFunc.inputAggBufferAttributes
for (index <- aggBufferAttr.indices) {
val attr = ConverterUtils.getAttrFromExpr(aggBufferAttr(index))
aggregateAttr += attr
}
resIndex += aggBufferAttr.size
resIndex
case Final =>
aggregateAttr += aggregateAttributeList(resIndex)
resIndex += 1
resIndex
case other =>
throw new UnsupportedOperationException(s"Unsupported aggregate mode: $other.")
case Partial | PartialMerge | Final => true
case _ => false
}
case other =>
throw new UnsupportedOperationException(
s"Unsupported aggregate function in getAttrForAggregateExpr")
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,6 @@ import io.glutenproject.expression.{ExpressionTransformer, Sig}

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, AggregateMode}

import scala.collection.mutable.ListBuffer

trait ExpressionExtensionTrait {

Expand All @@ -40,18 +37,6 @@ trait ExpressionExtensionTrait {
attributeSeq: Seq[Attribute]): ExpressionTransformer = {
throw new UnsupportedOperationException(s"${expr.getClass} or $expr is not supported.")
}

/** Get the attribute index of the extension aggregate functions. */
def getAttrsForExtensionAggregateExpr(
aggregateFunc: AggregateFunction,
mode: AggregateMode,
exp: AggregateExpression,
aggregateAttributeList: Seq[Attribute],
aggregateAttr: ListBuffer[Attribute],
resIndex: Int): Int = {
throw new UnsupportedOperationException(
s"Aggregate function ${aggregateFunc.getClass} is not supported.")
}
}

case class DefaultExpressionExtensionTransformer() extends ExpressionExtensionTrait with Logging {
Expand Down

0 comments on commit 7308fdb

Please sign in to comment.