Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CORE] Remove unnecessary case match in getAttrForAggregateExpr #3629

Merged
merged 2 commits into from
Nov 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,8 @@ package io.glutenproject.execution.extension
import io.glutenproject.expression._
import io.glutenproject.extension.ExpressionExtensionTrait

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._

import scala.collection.mutable.ListBuffer

case class CustomAggExpressionTransformer() extends ExpressionExtensionTrait {

lazy val expressionSigs = Seq(
Expand All @@ -32,41 +29,4 @@ case class CustomAggExpressionTransformer() extends ExpressionExtensionTrait {

/** Generate the extension expressions list, format: Sig[XXXExpression]("XXXExpressionName") */
override def expressionSigList: Seq[Sig] = expressionSigs

/** Get the attribute index of the extension aggregate functions. */
override def getAttrsForExtensionAggregateExpr(
aggregateFunc: AggregateFunction,
mode: AggregateMode,
exp: AggregateExpression,
aggregateAttributeList: Seq[Attribute],
aggregateAttr: ListBuffer[Attribute],
resIndex: Int): Int = {
var reIndex = resIndex
aggregateFunc match {
case CustomSum(_, _) =>
mode match {
case Partial | PartialMerge =>
val aggBufferAttr = aggregateFunc.inputAggBufferAttributes
if (aggBufferAttr.size == 2) {
// decimal sum check sum.resultType
aggregateAttr += ConverterUtils.getAttrFromExpr(aggBufferAttr.head)
val isEmptyAttr = ConverterUtils.getAttrFromExpr(aggBufferAttr(1))
aggregateAttr += isEmptyAttr
reIndex += 2
reIndex
} else {
val attr = ConverterUtils.getAttrFromExpr(aggBufferAttr.head)
aggregateAttr += attr
reIndex += 1
reIndex
}
case Final =>
aggregateAttr += aggregateAttributeList(reIndex)
reIndex += 1
reIndex
case other =>
throw new UnsupportedOperationException(s"Unsupported aggregate mode: $other.")
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ import com.google.protobuf.Any
import java.util

import scala.collection.JavaConverters._
import scala.collection.mutable.ListBuffer

case class HashAggregateExecTransformer(
requiredChildDistributionExpressions: Option[Seq[Expression]],
Expand All @@ -56,34 +55,18 @@ case class HashAggregateExecTransformer(
resultExpressions,
child) {

override protected def getAttrForAggregateExpr(
exp: AggregateExpression,
aggregateAttributeList: Seq[Attribute],
aggregateAttr: ListBuffer[Attribute],
index: Int): Int = {
var resIndex = index
val mode = exp.mode
val aggregateFunc = exp.aggregateFunction
aggregateFunc match {
case hllAdapter: HLLAdapter =>
override protected def checkAggFuncModeSupport(
aggFunc: AggregateFunction,
mode: AggregateMode): Boolean = {
aggFunc match {
case _: HLLAdapter =>
mode match {
case Partial =>
val aggBufferAttr = hllAdapter.inputAggBufferAttributes
for (index <- aggBufferAttr.indices) {
val attr = ConverterUtils.getAttrFromExpr(aggBufferAttr(index))
aggregateAttr += attr
}
resIndex += aggBufferAttr.size
case Final =>
aggregateAttr += aggregateAttributeList(resIndex)
resIndex += 1
case other =>
throw new UnsupportedOperationException(s"not currently supported: $other.")
case Partial | Final => true
case _ => false
}
case _ =>
resIndex = super.getAttrForAggregateExpr(exp, aggregateAttributeList, aggregateAttr, index)
super.checkAggFuncModeSupport(aggFunc, mode)
}
resIndex
}

override protected def withNewChildInternal(newChild: SparkPlan): HashAggregateExecTransformer = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.aggregate._
import org.apache.spark.sql.types._
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.sketch.BloomFilter

import com.google.protobuf.Any

Expand Down Expand Up @@ -417,195 +416,47 @@ abstract class HashAggregateExecBaseTransformer(
var resIndex = index
val mode = exp.mode
val aggregateFunc = exp.aggregateFunction
aggregateFunc match {
case extendedAggFunc
if ExpressionMappings.expressionExtensionTransformer.extensionExpressionsMapping.contains(
extendedAggFunc.getClass) =>
// get attributes from the extended aggregate functions
ExpressionMappings.expressionExtensionTransformer
.getAttrsForExtensionAggregateExpr(
aggregateFunc,
mode,
exp,
aggregateAttributeList,
aggregateAttr,
index)
Comment on lines -422 to -432
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can not delete this part for the extension, when there are some custom agg functions which have the specified logic to implement.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see that CustomSum should also follow this pattern.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CustomSum is only a simple custom ut case, in out internal product, we need to implement some custom logic.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the explanation. I see. Sorry for removing that portion of the logic.

case _: Average | _: First | _: Last =>
mode match {
case Partial | PartialMerge =>
val aggBufferAttr = aggregateFunc.inputAggBufferAttributes
for (index <- aggBufferAttr.indices) {
val attr = ConverterUtils.getAttrFromExpr(aggBufferAttr(index))
aggregateAttr += attr
}
resIndex += 2
resIndex
case Final =>
aggregateAttr += aggregateAttributeList(resIndex)
resIndex += 1
resIndex
case other =>
throw new UnsupportedOperationException(s"Unsupported aggregate mode: $other.")
}
case Sum(_, _) =>
mode match {
case Partial | PartialMerge =>
val sum = aggregateFunc.asInstanceOf[Sum]
val aggBufferAttr = sum.inputAggBufferAttributes
if (aggBufferAttr.size == 2) {
// decimal sum check sum.resultType
aggregateAttr += ConverterUtils.getAttrFromExpr(aggBufferAttr.head)
val isEmptyAttr = ConverterUtils.getAttrFromExpr(aggBufferAttr(1))
aggregateAttr += isEmptyAttr
resIndex += 2
resIndex
} else {
val attr = ConverterUtils.getAttrFromExpr(aggBufferAttr.head)
aggregateAttr += attr
resIndex += 1
resIndex
}
case Final =>
aggregateAttr += aggregateAttributeList(resIndex)
resIndex += 1
resIndex
case other =>
throw new UnsupportedOperationException(s"Unsupported aggregate mode: $other.")
}
case Count(_) =>
mode match {
case Partial | PartialMerge =>
val count = aggregateFunc.asInstanceOf[Count]
val aggBufferAttr = count.inputAggBufferAttributes
val attr = ConverterUtils.getAttrFromExpr(aggBufferAttr.head)
aggregateAttr += attr
resIndex += 1
resIndex
case Final =>
aggregateAttr += aggregateAttributeList(resIndex)
resIndex += 1
resIndex
case other =>
throw new UnsupportedOperationException(s"Unsupported aggregate mode: $other.")
}
case _: Max | _: Min | _: BitAndAgg | _: BitOrAgg | _: BitXorAgg =>
mode match {
case Partial | PartialMerge =>
val aggBufferAttr = aggregateFunc.inputAggBufferAttributes
assert(
aggBufferAttr.size == 1,
s"Aggregate function $aggregateFunc expects one buffer attribute.")
val attr = ConverterUtils.getAttrFromExpr(aggBufferAttr.head)
aggregateAttr += attr
resIndex += 1
resIndex
case Final =>
aggregateAttr += aggregateAttributeList(resIndex)
resIndex += 1
resIndex
case other =>
throw new UnsupportedOperationException(s"Unsupported aggregate mode: $other.")
}
case _: Corr =>
mode match {
case Partial | PartialMerge =>
val expectedBufferSize = 6
val aggBufferAttr = aggregateFunc.inputAggBufferAttributes
assert(
aggBufferAttr.size == expectedBufferSize,
s"Aggregate function $aggregateFunc" +
s" expects $expectedBufferSize buffer attribute.")
for (index <- aggBufferAttr.indices) {
val attr = ConverterUtils.getAttrFromExpr(aggBufferAttr(index))
aggregateAttr += attr
}
resIndex += expectedBufferSize
resIndex
case Final =>
aggregateAttr += aggregateAttributeList(resIndex)
resIndex += 1
resIndex
case other =>
throw new UnsupportedOperationException(s"Unsupported aggregate mode: $other.")
}
case _: CovPopulation | _: CovSample =>
mode match {
case Partial | PartialMerge =>
val expectedBufferSize = 4
val aggBufferAttr = aggregateFunc.inputAggBufferAttributes
assert(
aggBufferAttr.size == expectedBufferSize,
s"Aggregate function $aggregateFunc" +
s" expects $expectedBufferSize buffer attributes.")
for (index <- aggBufferAttr.indices) {
val attr = ConverterUtils.getAttrFromExpr(aggBufferAttr(index))
aggregateAttr += attr
}
resIndex += expectedBufferSize
resIndex
case Final =>
aggregateAttr += aggregateAttributeList(resIndex)
resIndex += 1
resIndex
case other =>
throw new UnsupportedOperationException(s"Unsupported aggregate mode: $other.")
if (!checkAggFuncModeSupport(aggregateFunc, mode)) {
throw new UnsupportedOperationException(
s"Unsupported aggregate mode: $mode for ${aggregateFunc.prettyName}")
}
mode match {
case Partial | PartialMerge =>
val aggBufferAttr = aggregateFunc.inputAggBufferAttributes
for (index <- aggBufferAttr.indices) {
val attr = ConverterUtils.getAttrFromExpr(aggBufferAttr(index))
aggregateAttr += attr
}
case _: StddevSamp | _: StddevPop | _: VarianceSamp | _: VariancePop =>
resIndex += aggBufferAttr.size
resIndex
case Final =>
aggregateAttr += aggregateAttributeList(resIndex)
resIndex += 1
resIndex
case other =>
throw new UnsupportedOperationException(s"Unsupported aggregate mode: $other.")
}
}

protected def checkAggFuncModeSupport(
aggFunc: AggregateFunction,
mode: AggregateMode): Boolean = {
aggFunc match {
case _: CollectList | _: CollectSet =>
mode match {
case Partial | PartialMerge =>
val aggBufferAttr = aggregateFunc.inputAggBufferAttributes
for (index <- aggBufferAttr.indices) {
val attr = ConverterUtils.getAttrFromExpr(aggBufferAttr(index))
aggregateAttr += attr
}
resIndex += 3
resIndex
case Final =>
aggregateAttr += aggregateAttributeList(resIndex)
resIndex += 1
resIndex
case other =>
throw new UnsupportedOperationException(s"Unsupported aggregate mode: $other.")
case Partial | Final => true
case _ => false
}
case bloom if bloom.getClass.getSimpleName.equals("BloomFilterAggregate") =>
// for spark33
mode match {
case Partial =>
val bloom = aggregateFunc.asInstanceOf[TypedImperativeAggregate[BloomFilter]]
val aggBufferAttr = bloom.inputAggBufferAttributes
for (index <- aggBufferAttr.indices) {
val attr = ConverterUtils.getAttrFromExpr(aggBufferAttr(index))
aggregateAttr += attr
}
resIndex += aggBufferAttr.size
resIndex
case Final =>
aggregateAttr += aggregateAttributeList(resIndex)
resIndex += 1
resIndex
case other =>
throw new UnsupportedOperationException(s"Unsupported aggregate mode: $other.")
case Partial | Final => true
case _ => false
}
case _: CollectList | _: CollectSet =>
case _ =>
mode match {
case Partial =>
val aggBufferAttr = aggregateFunc.inputAggBufferAttributes
for (index <- aggBufferAttr.indices) {
val attr = ConverterUtils.getAttrFromExpr(aggBufferAttr(index))
aggregateAttr += attr
}
resIndex += aggBufferAttr.size
resIndex
case Final =>
aggregateAttr += aggregateAttributeList(resIndex)
resIndex += 1
resIndex
case other =>
throw new UnsupportedOperationException(s"Unsupported aggregate mode: $other.")
case Partial | PartialMerge | Final => true
case _ => false
}
case other =>
throw new UnsupportedOperationException(
s"Unsupported aggregate function in getAttrForAggregateExpr")
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,6 @@ import io.glutenproject.expression.{ExpressionTransformer, Sig}

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, AggregateMode}

import scala.collection.mutable.ListBuffer

trait ExpressionExtensionTrait {

Expand All @@ -40,18 +37,6 @@ trait ExpressionExtensionTrait {
attributeSeq: Seq[Attribute]): ExpressionTransformer = {
throw new UnsupportedOperationException(s"${expr.getClass} or $expr is not supported.")
}

/** Get the attribute index of the extension aggregate functions. */
def getAttrsForExtensionAggregateExpr(
aggregateFunc: AggregateFunction,
mode: AggregateMode,
exp: AggregateExpression,
aggregateAttributeList: Seq[Attribute],
aggregateAttr: ListBuffer[Attribute],
resIndex: Int): Int = {
throw new UnsupportedOperationException(
s"Aggregate function ${aggregateFunc.getClass} is not supported.")
}
}

case class DefaultExpressionExtensionTransformer() extends ExpressionExtensionTrait with Logging {
Expand Down
Loading