[GLUTEN-2031][VL] Enable rand function (#2749)
PHILO-HE authored Nov 27, 2023
1 parent 5a33545 commit 1f0f6a8
Showing 15 changed files with 313 additions and 58 deletions.
@@ -188,7 +188,14 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
left: ExpressionTransformer,
right: ExpressionTransformer,
original: GetMapValue): ExpressionTransformer =
new GetMapValueTransformer(substraitExprName, left, right, original.failOnError, original)
GetMapValueTransformer(substraitExprName, left, right, original.failOnError, original)

override def genRandTransformer(
substraitExprName: String,
explicitSeed: ExpressionTransformer,
original: Rand): ExpressionTransformer = {
GenericExpressionTransformer(substraitExprName, Seq(explicitSeed), original)
}

/**
* Generate ShuffleDependency for ColumnarShuffleExchangeExec.
@@ -22,10 +22,10 @@ import io.glutenproject.expression.WindowFunctionsBuilder
import io.glutenproject.substrait.rel.LocalFilesNode.ReadFileFormat
import io.glutenproject.substrait.rel.LocalFilesNode.ReadFileFormat.{DwrfReadFormat, OrcReadFormat, ParquetReadFormat}

import org.apache.spark.sql.catalyst.expressions.{Alias, CumeDist, DenseRank, Descending, Expression, Literal, NamedExpression, NthValue, PercentRank, RangeFrame, Rank, RowNumber, SortOrder, SpecialFrameBoundary, SpecifiedWindowFrame}
import org.apache.spark.sql.catalyst.expressions.{Alias, CumeDist, DenseRank, Descending, Expression, Literal, NamedExpression, NthValue, PercentRank, Rand, RangeFrame, Rank, RowNumber, SortOrder, SpecialFrameBoundary, SpecifiedWindowFrame}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Count, Sum}
import org.apache.spark.sql.catalyst.plans.JoinType
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.{ProjectExec, SparkPlan}
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.command.CreateDataSourceTableAsSelectCommand
import org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand
@@ -233,48 +233,35 @@ object BackendSettings extends BackendSettingsApi {
}

/**
* Check whether plan is Count(1).
* Check whether a plan needs to be offloaded even though it has an empty input schema,
* e.g., Sum(1), Count(1), rand(), etc.
* @param plan:
* The Spark plan to check.
* @return
* Whether plan is an Aggregation of Count(1).
*/
private def isCount1(plan: SparkPlan): Boolean = {
plan match {
case exec: HashAggregateExec
if exec.aggregateExpressions.nonEmpty &&
exec.aggregateExpressions.forall(
expression => {
expression.aggregateFunction match {
case c: Count => c.children.size == 1 && c.children.head.equals(Literal(1))
case _ => false
}
}) =>
true
case _ =>
false
private def mayNeedOffload(plan: SparkPlan): Boolean = {
def checkExpr(expr: Expression): Boolean = {
expr match {
// Keep FallbackEmptySchemaRelation from directly falling back the functions below.
case alias: Alias => checkExpr(alias.child)
case _: Rand => true
case _ => false
}
}
}

/**
* Check whether plan is Sum(1).
* @param plan:
* The Spark plan to check.
* @return
* Whether plan is an Aggregation of Sum(1).
*/
private def isSum1(plan: SparkPlan): Boolean = {
plan match {
case exec: HashAggregateExec
if exec.aggregateExpressions.nonEmpty &&
exec.aggregateExpressions.forall(
expression => {
expression.aggregateFunction match {
case s: Sum => s.children.size == 1 && s.children.head.equals(Literal(1))
case _ => false
}
}) =>
true
case exec: HashAggregateExec if exec.aggregateExpressions.nonEmpty =>
// Check Sum(1) or Count(1).
exec.aggregateExpressions.forall(
expression => {
val aggFunction = expression.aggregateFunction
aggFunction match {
case _: Sum | _: Count =>
aggFunction.children.size == 1 && aggFunction.children.head.equals(Literal(1))
case _ => false
}
})
case p: ProjectExec if p.projectList.nonEmpty =>
p.projectList.forall(checkExpr(_))
case _ =>
false
}
@@ -283,7 +270,7 @@ object BackendSettings extends BackendSettingsApi {
override def fallbackOnEmptySchema(plan: SparkPlan): Boolean = {
// Count(1), Sum(1) and bare rand() are special cases that the Velox backend can handle.
// Do not fall back such a plan or its children in the first place.
!(isCount1(plan) || isSum1(plan))
!mayNeedOffload(plan)
}

override def fallbackAggregateWithChild(): Boolean = true
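For intuition, here is a minimal sketch of the query shapes mayNeedOffload now keeps on the native backend. It is illustrative only and assumes a SparkSession `spark` with Gluten enabled; previously only the Count(1)/Sum(1) aggregations qualified, and a projection whose output is a lone rand() now qualifies too:

// Each query reads no columns, so FallbackEmptySchemaRelation would otherwise
// force the plan (and its children) back onto vanilla Spark.
spark.range(100).selectExpr("count(1)").explain() // HashAggregateExec of Count(Literal(1))
spark.range(100).selectExpr("sum(1)").explain()   // HashAggregateExec of Sum(Literal(1))
spark.range(100).selectExpr("rand()").explain()   // ProjectExec whose projectList is a single Rand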
1 change: 0 additions & 1 deletion cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc
@@ -41,7 +41,6 @@ static const std::unordered_set<std::string> kBlackList = {
"factorial",
"concat_ws",
"from_json",
"rand",
"json_array_length",
"from_unixtime",
"repeat",
@@ -127,6 +127,13 @@ trait SparkPlanExecApi {
GenericExpressionTransformer(substraitExprName, Seq(srcExpr, regexExpr, limitExpr), original)
}

def genRandTransformer(
substraitExprName: String,
explicitSeed: ExpressionTransformer,
original: Rand): ExpressionTransformer = {
RandTransformer(substraitExprName, explicitSeed, original)
}

/** Generate an expression transformer to transform GetMapValue to Substrait. */
def genGetMapValueTransformer(
substraitExprName: String,
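The genRandTransformer hook above supplies a default lowering for rand() that a backend may override; the ClickHouse backend does exactly that earlier in this diff, returning a GenericExpressionTransformer instead. A self-contained sketch of the pattern, with hypothetical names standing in for the real trait and transformers:

// Hypothetical, simplified analogue of the SparkPlanExecApi hook. Strings stand in
// for ExpressionTransformer results; only the override mechanics are shown.
trait PlanExecApiSketch {
  def genRandTransformer(name: String, seed: String): String =
    s"RandTransformer($name, $seed)" // common default, used by the Velox backend
}

object ClickHouseApiSketch extends PlanExecApiSketch {
  override def genRandTransformer(name: String, seed: String): String =
    s"GenericExpressionTransformer($name, Seq($seed))" // backend-specific lowering
}

This keeps ExpressionConverter backend-agnostic: it calls the hook through BackendsApiManager and never needs to know which transformer a given backend chose.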
@@ -390,6 +390,11 @@ object ExpressionConverter extends SQLConfHelper with Logging {
substraitExprName,
replaceWithExpressionTransformer(m.child, attributeSeq),
m)
case rand: Rand =>
BackendsApiManager.getSparkPlanExecApiInstance.genRandTransformer(
substraitExprName,
replaceWithExpressionTransformer(rand.child, attributeSeq),
rand)
case _: KnownFloatingPointNormalized | _: NormalizeNaNAndZero | _: PromotePrecision =>
ChildTransformer(
replaceWithExpressionTransformer(expr.children.head, attributeSeq)
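A minimal sketch of the dispatch shape the new Rand case adds, simplified and with strings standing in for transformer nodes: the converter recursively lowers the seed child, then delegates to the active backend's hook. By this point rand() has been resolved to Rand(Literal(&lt;seed&gt;), hideSeed = true), so the child is always a literal seed:

import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, Rand}

// Sketch only: the real code returns ExpressionTransformer nodes, not strings.
def lower(expr: Expression): String = expr match {
  case rand: Rand =>
    val seed = lower(rand.child) // transform the seed child first
    s"backend.genRandTransformer(rand, seed = $seed)"
  case Literal(value, _) => value.toString
  case other => other.prettyName // remaining cases elided in this sketch
}

// lower(Rand(Literal(42L), hideSeed = true)) == "backend.genRandTransformer(rand, seed = 42)"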
@@ -196,3 +196,31 @@ case class MakeDecimalTransformer(
ExpressionBuilder.makeScalarFunction(functionId, expressionNodes, typeNode)
}
}

/**
* The user can specify a seed for this function. If none is given, Spark generates a random
* number as the seed. We also need to pass a unique partitionIndex, provided by the framework,
* to the native library for each thread. The actual seed for the generator is then seed plus
* partitionIndex, as in vanilla Spark. This relies on the fact that partitioning is
* deterministic and each partition corresponds to one task thread.
*/
case class RandTransformer(
substraitExprName: String,
explicitSeed: ExpressionTransformer,
original: Rand)
extends ExpressionTransformer {

override def doTransform(args: java.lang.Object): ExpressionNode = {
if (!original.hideSeed) {
// TODO: for a user-specified seed, we need to pass the partition index to the native engine.
throw new UnsupportedOperationException("User-specified seed is not supported.")
}
val functionMap = args.asInstanceOf[java.util.HashMap[String, java.lang.Long]]
val functionId = ExpressionBuilder.newScalarFunction(
functionMap,
ConverterUtils.makeFuncName(substraitExprName, Seq(original.child.dataType)))
val inputNodes = Lists.newArrayList[ExpressionNode]()
val typeNode = ConverterUtils.getTypeNode(original.dataType, original.nullable)
ExpressionBuilder.makeScalarFunction(functionId, inputNodes, typeNode)
}
}
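Two points are worth noting about the transformer above. First, the hideSeed guard: rand() lets Spark pick the seed (hideSeed = true) and is offloaded, while an explicit rand(123) sets hideSeed = false, so doTransform throws and Gluten is expected to fall back for that expression. Second, a minimal sketch of the per-partition seeding the comment describes; java.util.Random stands in for Spark's internal XORShiftRandom, which is private[spark]:

import java.util.Random

// One generator per task thread, seeded with seed + partitionIndex. Partitioning is
// deterministic and one partition maps to one task thread, so each partition's
// random sequence is reproducible, matching vanilla Spark's Rand.
def partitionRng(seed: Long, partitionIndex: Int): Random =
  new Random(seed + partitionIndex)

val rng = partitionRng(seed = 42L, partitionIndex = 3)
val x = rng.nextDouble() // uniform in [0.0, 1.0), the contract of rand()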
@@ -19,6 +19,7 @@ package io.glutenproject.utils.clickhouse
import io.glutenproject.utils.{BackendTestSettings, SQLQueryTestSettings}

import org.apache.spark.sql._
import org.apache.spark.sql.GlutenTestConstants.GLUTEN_TEST
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.connector._
import org.apache.spark.sql.execution._
@@ -135,6 +136,8 @@ class ClickHouseTestSettings extends BackendTestSettings {
.exclude("collect functions should be able to cast to array type with no null values")
.exclude("SPARK-17616: distinct aggregate combined with a non-partial aggregate")
.exclude("SPARK-19471: AggregationIterator does not initialize the generated result projection before using it")
.exclude(GLUTEN_TEST + "SPARK-19471: AggregationIterator does not initialize the generated" +
" result projection before using it")
.exclude("SPARK-26021: NaN and -0.0 in grouping expressions")
.exclude("SPARK-32038: NormalizeFloatingNumbers should work on distinct aggregate")
.exclude("SPARK-32136: NormalizeFloatingNumbers should work on null struct")
@@ -51,7 +51,10 @@ class VeloxTestSettings extends BackendTestSettings {
"zero moments", // [velox does not return NaN]
"SPARK-26021: NaN and -0.0 in grouping expressions", // NaN case
// incorrect result, distinct NaN case
"SPARK-32038: NormalizeFloatingNumbers should work on distinct aggregate"
"SPARK-32038: NormalizeFloatingNumbers should work on distinct aggregate",
// Replaced with another test.
"SPARK-19471: AggregationIterator does not initialize the generated result projection" +
" before using it"
)

enableSuite[GlutenCastSuite]
@@ -123,7 +126,10 @@ class VeloxTestSettings extends BackendTestSettings {
// Rewrite this test because the describe function creates an unmatched plan.
"describe",
// Not supported for approx_count_distinct
"SPARK-34165: Add count_distinct to summary"
"SPARK-34165: Add count_distinct to summary",
// The result depends on the implementation of the nondeterministic expression rand.
// Not really an issue.
"SPARK-9083: sort with non-deterministic expressions"
)
// Double precision loss: https://github.com/facebookincubator/velox/pull/6051#issuecomment-1731028215.
.exclude("SPARK-22271: mean overflows and returns null for some decimal variables")
@@ -256,6 +262,9 @@ class VeloxTestSettings extends BackendTestSettings {
.exclude("aggregate function - array for non-primitive type")
enableSuite[GlutenDataFrameTungstenSuite]
enableSuite[GlutenDataFrameSetOperationsSuite]
// The result depends on the implementation of the nondeterministic expression rand.
// Not really an issue.
.exclude("SPARK-10740: handle nondeterministic expressions correctly for set operations")
enableSuite[GlutenDataFrameStatSuite]
enableSuite[GlutenComplexTypesSuite]
// Incorrect result for array and length.
@@ -18,10 +18,15 @@ package org.apache.spark.sql

import io.glutenproject.execution.HashAggregateExecBaseTransformer

import org.apache.spark.sql.execution.aggregate.SortAggregateExec
import org.apache.spark.sql.GlutenTestConstants.GLUTEN_TEST
import org.apache.spark.sql.execution.WholeStageCodegenExec
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SQLTestData.DecimalData

import scala.util.Random

class GlutenDataFrameAggregateSuite extends DataFrameAggregateSuite with GlutenSQLTestsTrait {

import testImplicits._
@@ -300,4 +305,61 @@ class GlutenDataFrameAggregateSuite extends DataFrameAggregateSuite with GlutenSQLTestsTrait {
)
}

// Ported from Spark's DataFrameAggregateSuite with only the plan check changed.
private def assertNoExceptions(c: Column): Unit = {
for (
(wholeStage, useObjectHashAgg) <-
Seq((true, true), (true, false), (false, true), (false, false))
) {
withSQLConf(
(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, wholeStage.toString),
(SQLConf.USE_OBJECT_HASH_AGG.key, useObjectHashAgg.toString)) {

val df = Seq(("1", 1), ("1", 2), ("2", 3), ("2", 4)).toDF("x", "y")

// test case for HashAggregate
val hashAggDF = df.groupBy("x").agg(c, sum("y"))
hashAggDF.collect()
val hashAggPlan = hashAggDF.queryExecution.executedPlan
if (wholeStage) {
assert(find(hashAggPlan) {
case WholeStageCodegenExec(_: HashAggregateExec) => true
// If offloaded, Spark's whole-stage codegen has no effect and a Gluten hash agg is
// expected to be used.
case _: HashAggregateExecBaseTransformer => true
case _ => false
}.isDefined)
} else {
assert(
stripAQEPlan(hashAggPlan).isInstanceOf[HashAggregateExec] ||
stripAQEPlan(hashAggPlan).find {
case _: HashAggregateExecBaseTransformer => true
case _ => false
}.isDefined)
}

// test case for ObjectHashAggregate and SortAggregate
val objHashAggOrSortAggDF = df.groupBy("x").agg(c, collect_list("y"))
objHashAggOrSortAggDF.collect()
val objHashAggOrSortAggPlan =
stripAQEPlan(objHashAggOrSortAggDF.queryExecution.executedPlan)
if (useObjectHashAgg) {
assert(objHashAggOrSortAggPlan.isInstanceOf[ObjectHashAggregateExec])
} else {
assert(objHashAggOrSortAggPlan.isInstanceOf[SortAggregateExec])
}
}
}
}

test(
GLUTEN_TEST + "SPARK-19471: AggregationIterator does not initialize the generated" +
" result projection before using it") {
Seq(
monotonically_increasing_id(),
spark_partition_id(),
rand(Random.nextLong()),
randn(Random.nextLong())
).foreach(assertNoExceptions)
}
}
@@ -19,6 +19,7 @@ package io.glutenproject.utils.clickhouse
import io.glutenproject.utils.{BackendTestSettings, SQLQueryTestSettings}

import org.apache.spark.sql._
import org.apache.spark.sql.GlutenTestConstants.GLUTEN_TEST
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.connector._
import org.apache.spark.sql.errors._
@@ -154,6 +155,8 @@ class ClickHouseTestSettings extends BackendTestSettings {
.exclude("collect functions should be able to cast to array type with no null values")
.exclude("SPARK-17616: distinct aggregate combined with a non-partial aggregate")
.exclude("SPARK-19471: AggregationIterator does not initialize the generated result projection before using it")
.exclude(GLUTEN_TEST + "SPARK-19471: AggregationIterator does not initialize the generated" +
" result projection before using it")
.exclude("SPARK-26021: NaN and -0.0 in grouping expressions")
.exclude("SPARK-32038: NormalizeFloatingNumbers should work on distinct aggregate")
.exclude("SPARK-32136: NormalizeFloatingNumbers should work on null struct")
@@ -987,7 +987,10 @@ class VeloxTestSettings extends BackendTestSettings {
"zero moments", // [velox does not return NaN]
"SPARK-26021: NaN and -0.0 in grouping expressions", // NaN case
// incorrect result, distinct NaN case
"SPARK-32038: NormalizeFloatingNumbers should work on distinct aggregate"
"SPARK-32038: NormalizeFloatingNumbers should work on distinct aggregate",
// Replaced with another test.
"SPARK-19471: AggregationIterator does not initialize the generated result projection" +
" before using it"
)
enableSuite[GlutenDataFrameAsOfJoinSuite]
enableSuite[GlutenDataFrameComplexTypeSuite]
@@ -1014,6 +1017,9 @@ class VeloxTestSettings extends BackendTestSettings {
// exclude as map not supported
.exclude("SPARK-36797: Union should resolve nested columns as top-level columns")
.exclude("SPARK-37371: UnionExec should support columnar if all children support columnar")
// The result depends on the implementation of the nondeterministic expression rand.
// Not really an issue.
.exclude("SPARK-10740: handle nondeterministic expressions correctly for set operations")
enableSuite[GlutenDataFrameStatSuite]
enableSuite[GlutenDataFrameSuite]
// Rewrite these tests because it checks Spark's physical operators.
@@ -1035,7 +1041,10 @@ class VeloxTestSettings extends BackendTestSettings {
// decimal failed ut.
"SPARK-22271: mean overflows and returns null for some decimal variables",
// Not supported for approx_count_distinct
"SPARK-34165: Add count_distinct to summary"
"SPARK-34165: Add count_distinct to summary",
// The result depends on the implementation of the nondeterministic expression rand.
// Not really an issue.
"SPARK-9083: sort with non-deterministic expressions"
)
enableSuite[GlutenDataFrameTimeWindowingSuite]
enableSuite[GlutenDataFrameTungstenSuite]