Skip to content

Commit

Permalink
[GLUTEN-4039][VL] Add array forall and exists function support (#5420)
Browse files Browse the repository at this point in the history
  • Loading branch information
lyy-pineapple authored Apr 22, 2024
1 parent 4fca325 commit 471f382
Show file tree
Hide file tree
Showing 8 changed files with 112 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.{AggregateFunctionRewriteRule, BloomFilterM
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions.{Add, Alias, ArrayFilter, Ascending, Attribute, Cast, CreateNamedStruct, ElementAt, Expression, ExpressionInfo, Generator, GetArrayItem, GetMapValue, GetStructField, If, IsNaN, LambdaFunction, Literal, Murmur3Hash, NamedExpression, NaNvl, PosExplode, Round, SortOrder, StringSplit, StringTrim, TryEval, Uuid, VeloxBloomFilterMightContain}
import org.apache.spark.sql.catalyst.expressions.{Add, Alias, ArrayExists, ArrayFilter, ArrayForAll, Ascending, Attribute, Cast, CreateNamedStruct, ElementAt, Expression, ExpressionInfo, Generator, GetArrayItem, GetMapValue, GetStructField, If, IsNaN, LambdaFunction, Literal, Murmur3Hash, NamedExpression, NaNvl, PosExplode, Round, SortOrder, StringSplit, StringTrim, TryEval, Uuid, VeloxBloomFilterMightContain}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, HLLAdapter, VeloxBloomFilterAggregate}
import org.apache.spark.sql.catalyst.optimizer.BuildSide
import org.apache.spark.sql.catalyst.plans.JoinType
Expand Down Expand Up @@ -203,6 +203,34 @@ class SparkPlanExecApiImpl extends SparkPlanExecApi {
}
}

/** Transform array forall to Substrait. */
override def genArrayForAllTransformer(
substraitExprName: String,
argument: ExpressionTransformer,
function: ExpressionTransformer,
expr: ArrayForAll): ExpressionTransformer = {
expr.function match {
case LambdaFunction(_, arguments, _) if arguments.size == 2 =>
throw new GlutenNotSupportException(
"forall on array with lambda using index argument is not supported yet")
case _ => GenericExpressionTransformer(substraitExprName, Seq(argument, function), expr)
}
}

/** Transform array exists to Substrait */
override def genArrayExistsTransformer(
substraitExprName: String,
argument: ExpressionTransformer,
function: ExpressionTransformer,
expr: ArrayExists): ExpressionTransformer = {
expr.function match {
case LambdaFunction(_, arguments, _) if arguments.size == 2 =>
throw new GlutenNotSupportException(
"exists on array with lambda using index argument is not supported yet")
case _ => GenericExpressionTransformer(substraitExprName, Seq(argument, function), expr)
}
}

/** Transform posexplode to Substrait. */
override def genPosExplodeTransformer(
substraitExprName: String,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -765,4 +765,44 @@ class ScalarFunctionsValidateSuite extends FunctionsValidateTest {
}
}

test("test array forall") {
withTempPath {
path =>
Seq[Seq[Integer]](Seq(1, null, 5, 4), Seq(5, -1, 8, 9, -7, 2), Seq.empty, null)
.toDF("value")
.write
.parquet(path.getCanonicalPath)

spark.read.parquet(path.getCanonicalPath).createOrReplaceTempView("array_tbl")

runQueryAndCompare("select forall(value, x -> x % 2 == 1) as res from array_tbl;") {
checkGlutenOperatorMatch[ProjectExecTransformer]
}

runQueryAndCompare("select forall(value, x -> x is not null) as res from array_tbl;") {
checkGlutenOperatorMatch[ProjectExecTransformer]
}
}
}

test("test array exists") {
withTempPath {
path =>
Seq[Seq[Integer]](Seq(1, null, 5, 4), Seq(5, -1, 8, 9, -7, 2), Seq.empty, null)
.toDF("value")
.write
.parquet(path.getCanonicalPath)

spark.read.parquet(path.getCanonicalPath).createOrReplaceTempView("array_tbl")

runQueryAndCompare("select exists(value, x -> x % 2 == 1) as res from array_tbl;") {
checkGlutenOperatorMatch[ProjectExecTransformer]
}

runQueryAndCompare("select exists(value, x -> x is not null) as res from array_tbl;") {
checkGlutenOperatorMatch[ProjectExecTransformer]
}
}
}

}
4 changes: 3 additions & 1 deletion cpp/velox/substrait/SubstraitParser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,9 @@ std::unordered_map<std::string, std::string> SubstraitParser::substraitVeloxFunc
{"modulus", "remainder"},
{"date_format", "format_datetime"},
{"collect_set", "set_agg"},
{"try_add", "plus"}};
{"try_add", "plus"},
{"forall", "all_match"},
{"exists", "any_match"}};

const std::unordered_map<std::string, std::string> SubstraitParser::typeMap_ = {
{"bool", "BOOLEAN"},
Expand Down
3 changes: 2 additions & 1 deletion docs/velox-backend-support-progress.md
Original file line number Diff line number Diff line change
Expand Up @@ -276,10 +276,11 @@ Gluten supports 199 functions. (Drag to right to see all data types)
| arrays_zip | | | | | | | | | | | | | | | | | | | | | | |
| cardinality | cardinality | | | | | | | | | | | | | | | | | | | | | |
| element_at | element_at | element_at | S | | | | | | | | | | | | | | | | S | S | | |
| exists | | | | | | | | | | | | | | | | | | | | | | |
| exists | any_match | | S | | | | | | | | | | | | | | | | | | | |
| explode, explode_outer | | | | | | | | | | | | | | | | | | | | | | |
| explode_outer, explode | | | | | | | | | | | | | | | | | | | | | | |
| filter | filter | filter | S | Lambda with index argument not supported | | | | | | | | | | | | | | | | | | |
| forall | all_match | | S | | | | | | | | | | | | | | | | | | | |
| flatten | flatten | | | | | | | | | | | | | | | | | | | | | |
| map | map | map | S | | | | | | | | | | | | | | | | | | | |
| map_concat | map_concat | | | | | | | | | | | | | | | | | | | | | |
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,24 @@ trait SparkPlanExecApi {
throw new GlutenNotSupportException("filter(on array) is not supported")
}

/** Transform array forall to Substrait. */
def genArrayForAllTransformer(
substraitExprName: String,
argument: ExpressionTransformer,
function: ExpressionTransformer,
expr: ArrayForAll): ExpressionTransformer = {
throw new GlutenNotSupportException("all_match is not supported")
}

/** Transform array exists to Substrait */
def genArrayExistsTransformer(
substraitExprName: String,
argument: ExpressionTransformer,
function: ExpressionTransformer,
expr: ArrayExists): ExpressionTransformer = {
throw new GlutenNotSupportException("any_match is not supported")
}

/** Transform inline to Substrait. */
def genInlineTransformer(
substraitExprName: String,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -613,6 +613,22 @@ object ExpressionConverter extends SQLConfHelper with Logging {
replaceWithExpressionTransformerInternal(tryEval.child, attributeSeq, expressionsMap),
tryEval
)
case a: ArrayForAll =>
BackendsApiManager.getSparkPlanExecApiInstance.genArrayForAllTransformer(
substraitExprName,
replaceWithExpressionTransformerInternal(a.argument, attributeSeq, expressionsMap),
replaceWithExpressionTransformerInternal(a.function, attributeSeq, expressionsMap),
a
)

case a: ArrayExists =>
BackendsApiManager.getSparkPlanExecApiInstance.genArrayExistsTransformer(
substraitExprName,
replaceWithExpressionTransformerInternal(a.argument, attributeSeq, expressionsMap),
replaceWithExpressionTransformerInternal(a.function, attributeSeq, expressionsMap),
a
)

case expr =>
GenericExpressionTransformer(
substraitExprName,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,8 @@ object ExpressionMappings {
Sig[ArrayRepeat](ARRAY_REPEAT),
Sig[ArrayRemove](ARRAY_REMOVE),
Sig[ArrayFilter](FILTER),
Sig[ArrayForAll](FORALL),
Sig[ArrayExists](EXISTS),
Sig[Shuffle](SHUFFLE),
// Map functions
Sig[CreateMap](CREATE_MAP),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,8 @@ object ExpressionNames {
final val ARRAY_REPEAT = "array_repeat"
final val ARRAY_REMOVE = "array_remove"
final val FILTER = "filter"
final val FORALL = "forall"
final val EXISTS = "exists"
final val SHUFFLE = "shuffle"

// Map functions
Expand Down

0 comments on commit 471f382

Please sign in to comment.