diff --git a/backends-velox/src/test/scala/io/glutenproject/execution/VeloxStringFunctionsSuite.scala b/backends-velox/src/test/scala/io/glutenproject/execution/VeloxStringFunctionsSuite.scala index 252bc5fed1ca..ec1a203f34e8 100644 --- a/backends-velox/src/test/scala/io/glutenproject/execution/VeloxStringFunctionsSuite.scala +++ b/backends-velox/src/test/scala/io/glutenproject/execution/VeloxStringFunctionsSuite.scala @@ -27,6 +27,8 @@ class VeloxStringFunctionsSuite extends WholeStageTransformerSuite { override protected val resourcePath: String = "/tpch-data-parquet-velox" override protected val fileFormat: String = "parquet" + final val LENGTH = 1000 + override def beforeAll(): Unit = { super.beforeAll() createTPCHNotNullTables() @@ -43,107 +45,240 @@ class VeloxStringFunctionsSuite extends WholeStageTransformerSuite { .set("spark.sql.sources.useV1SourceList", "avro") } - def checkLengthAndPlan(df: DataFrame) { - this.checkLengthAndPlan(df, 5) - } - test("ascii") { - runQueryAndCompare("select l_orderkey, ascii(l_comment) " + - "from lineitem limit 5") { checkLengthAndPlan } + runQueryAndCompare(s"select l_orderkey, ascii(l_comment) " + + s"from lineitem limit $LENGTH") { checkOperatorMatch[ProjectExecTransformer] } - runQueryAndCompare("select l_orderkey, ascii(null) " + - "from lineitem limit 5") { checkLengthAndPlan } + runQueryAndCompare(s"select l_orderkey, ascii(null) " + + s"from lineitem limit $LENGTH") { checkOperatorMatch[ProjectExecTransformer] } } test("concat") { - runQueryAndCompare("select l_orderkey, concat(l_comment, 'hello') " + - "from lineitem limit 5") { checkLengthAndPlan } - runQueryAndCompare("select l_orderkey, concat(l_comment, 'hello', 'world') " + - "from lineitem limit 5") { checkLengthAndPlan } + runQueryAndCompare(s"select l_orderkey, concat(l_comment, 'hello') " + + s"from lineitem limit $LENGTH") { checkOperatorMatch[ProjectExecTransformer] } + runQueryAndCompare(s"select l_orderkey, concat(l_comment, 'hello', 'world') " + + s"from lineitem limit $LENGTH") { checkOperatorMatch[ProjectExecTransformer] } } test("instr") { - runQueryAndCompare("select l_orderkey, instr(l_comment, 'h') " + - "from lineitem limit 5") { checkLengthAndPlan } - runQueryAndCompare("select l_orderkey, instr(l_comment, null) " + - "from lineitem limit 5") { checkLengthAndPlan } - runQueryAndCompare("select l_orderkey, instr(null, 'h') " + - "from lineitem limit 5") { checkLengthAndPlan } + runQueryAndCompare(s"select l_orderkey, instr(l_comment, 'h') " + + s"from lineitem limit $LENGTH") { checkOperatorMatch[ProjectExecTransformer] } + runQueryAndCompare(s"select l_orderkey, instr(l_comment, null) " + + s"from lineitem limit $LENGTH") { checkOperatorMatch[ProjectExecTransformer] } + runQueryAndCompare(s"select l_orderkey, instr(null, 'h') " + + s"from lineitem limit $LENGTH") { checkOperatorMatch[ProjectExecTransformer] } } test("length") { - runQueryAndCompare("select l_orderkey, length(l_comment) " + - "from lineitem limit 5") { checkLengthAndPlan } - runQueryAndCompare("select l_orderkey, length(null) " + - "from lineitem limit 5") { checkLengthAndPlan } + runQueryAndCompare(s"select l_orderkey, length(l_comment) " + + s"from lineitem limit $LENGTH") { checkOperatorMatch[ProjectExecTransformer] } + runQueryAndCompare(s"select l_orderkey, length(null) " + + s"from lineitem limit $LENGTH") { checkOperatorMatch[ProjectExecTransformer] } + + runQueryAndCompare(s"select l_orderkey, CHAR_LENGTH(l_comment) " + + s"from lineitem limit $LENGTH") { 
checkOperatorMatch[ProjectExecTransformer] } + runQueryAndCompare(s"select l_orderkey, CHAR_LENGTH(null) " + + s"from lineitem limit $LENGTH") { checkOperatorMatch[ProjectExecTransformer] } + + runQueryAndCompare(s"select l_orderkey, CHARACTER_LENGTH(l_comment) " + + s"from lineitem limit $LENGTH") { checkOperatorMatch[ProjectExecTransformer] } + runQueryAndCompare(s"select l_orderkey, CHARACTER_LENGTH(null) " + + s"from lineitem limit $LENGTH") { checkOperatorMatch[ProjectExecTransformer] } } test("lower") { - runQueryAndCompare("select l_orderkey, lower(l_comment) " + - "from lineitem limit 5") { checkLengthAndPlan } - runQueryAndCompare("select l_orderkey, lower(null) " + - "from lineitem limit 5") { checkLengthAndPlan } + runQueryAndCompare(s"select l_orderkey, lower(l_comment) " + + s"from lineitem limit $LENGTH") { checkOperatorMatch[ProjectExecTransformer] } + runQueryAndCompare(s"select l_orderkey, lower(null) " + + s"from lineitem limit $LENGTH") { checkOperatorMatch[ProjectExecTransformer] } } test("upper") { - runQueryAndCompare("select l_orderkey, upper(l_comment) " + - "from lineitem limit 5") { checkLengthAndPlan } - runQueryAndCompare("select l_orderkey, upper(null) " + - "from lineitem limit 5") { checkLengthAndPlan } + runQueryAndCompare(s"select l_orderkey, upper(l_comment) " + + s"from lineitem limit $LENGTH") { checkOperatorMatch[ProjectExecTransformer] } + runQueryAndCompare(s"select l_orderkey, upper(null) " + + s"from lineitem limit $LENGTH") { checkOperatorMatch[ProjectExecTransformer] } + } + + test("lcase") { + runQueryAndCompare(s"select l_orderkey, lcase(l_comment) " + + s"from lineitem limit $LENGTH") { checkOperatorMatch[ProjectExecTransformer] } + runQueryAndCompare(s"select l_orderkey, lcase(null) " + + s"from lineitem limit $LENGTH") { checkOperatorMatch[ProjectExecTransformer] } + } + + test("ucase") { + runQueryAndCompare(s"select l_orderkey, ucase(l_comment) " + + s"from lineitem limit $LENGTH") { checkOperatorMatch[ProjectExecTransformer] } + runQueryAndCompare(s"select l_orderkey, ucase(null) " + + s"from lineitem limit $LENGTH") { checkOperatorMatch[ProjectExecTransformer] } + } + + ignore("locate") { + runQueryAndCompare(s"select l_orderkey, locate(l_comment, 'a', 1) " + + s"from lineitem limit $LENGTH") { checkOperatorMatch[ProjectExecTransformer] } + runQueryAndCompare(s"select l_orderkey, locate(null, 'a', 1) " + + s"from lineitem limit $LENGTH") { checkOperatorMatch[ProjectExecTransformer] } + } + + test("ltrim/rtrim") { + runQueryAndCompare(s"select l_orderkey, ltrim('SparkSQL ', 'Spark') " + + s"from lineitem limit $LENGTH") { checkOperatorMatch[ProjectExecTransformer] } + runQueryAndCompare(s"select l_orderkey, ltrim(' SparkSQL ', 'Spark') " + + s"from lineitem limit $LENGTH") { checkOperatorMatch[ProjectExecTransformer] } + runQueryAndCompare(s"select l_orderkey, ltrim(' SparkSQL ') " + + s"from lineitem limit $LENGTH") { checkOperatorMatch[ProjectExecTransformer] } + runQueryAndCompare(s"select l_orderkey, ltrim(null) " + + s"from lineitem limit $LENGTH") { checkOperatorMatch[ProjectExecTransformer] } + runQueryAndCompare(s"select l_orderkey, ltrim(l_comment) " + + s"from lineitem limit $LENGTH") { checkOperatorMatch[ProjectExecTransformer] } + + runQueryAndCompare(s"select l_orderkey, rtrim(' SparkSQL ') " + + s"from lineitem limit $LENGTH") { checkOperatorMatch[ProjectExecTransformer] } + runQueryAndCompare(s"select l_orderkey, rtrim(null) " + + s"from lineitem limit $LENGTH") { checkOperatorMatch[ProjectExecTransformer] } + 
runQueryAndCompare(s"select l_orderkey, rtrim(l_comment) " + + s"from lineitem limit $LENGTH") { checkOperatorMatch[ProjectExecTransformer] } + } + + test("lpad") { + runQueryAndCompare(s"select l_orderkey, lpad(null, 80) " + + s"from lineitem limit $LENGTH") { checkOperatorMatch[ProjectExecTransformer] } + runQueryAndCompare(s"select l_orderkey, lpad(l_comment, 80) " + + s"from lineitem limit $LENGTH") { checkOperatorMatch[ProjectExecTransformer] } + runQueryAndCompare(s"select l_orderkey, lpad(l_comment, 80, '??') " + + s"from lineitem limit $LENGTH") { checkOperatorMatch[ProjectExecTransformer] } + runQueryAndCompare(s"select l_orderkey, lpad(l_comment, null, '??') " + + s"from lineitem limit $LENGTH") { checkOperatorMatch[ProjectExecTransformer] } + runQueryAndCompare(s"select l_orderkey, lpad(l_comment, 80, null) " + + s"from lineitem limit $LENGTH") { checkOperatorMatch[ProjectExecTransformer] } + } + + test("rpad") { + runQueryAndCompare(s"select l_orderkey, rpad(null, 80) " + + s"from lineitem limit $LENGTH") { checkOperatorMatch[ProjectExecTransformer] } + runQueryAndCompare(s"select l_orderkey, rpad(l_comment, 80) " + + s"from lineitem limit $LENGTH") { checkOperatorMatch[ProjectExecTransformer] } + runQueryAndCompare(s"select l_orderkey, rpad(l_comment, 80, '??') " + + s"from lineitem limit $LENGTH") { checkOperatorMatch[ProjectExecTransformer] } + runQueryAndCompare(s"select l_orderkey, rpad(l_comment, null, '??') " + + s"from lineitem limit $LENGTH") { checkOperatorMatch[ProjectExecTransformer] } + runQueryAndCompare(s"select l_orderkey, rpad(l_comment, 80, null) " + + s"from lineitem limit $LENGTH") { checkOperatorMatch[ProjectExecTransformer] } } test("like") { - runQueryAndCompare("select l_orderkey, like(l_comment, '%a%') " + - "from lineitem limit 5") { checkLengthAndPlan } - runQueryAndCompare("select l_orderkey, like(l_comment, ' ') " + - "from lineitem limit 5") { checkLengthAndPlan } - runQueryAndCompare("select l_orderkey, like(null, '%a%') " + - "from lineitem limit 5") { checkLengthAndPlan } - runQueryAndCompare("select l_orderkey, l_comment " + - "from lineitem where l_comment like '%a%' limit 5") { checkLengthAndPlan } - runQueryAndCompare("select l_orderkey, like(l_comment, ' ') " + - "from lineitem where l_comment like '' limit 5") { _ => } - runQueryAndCompare("select l_orderkey, like(null, '%a%') " + - "from lineitem where l_comment like '%$$$##@@#&&' limit 5") { _ => } + runQueryAndCompare("""select l_orderkey, like(l_comment, '%\%') """ + + s"from lineitem limit $LENGTH") { checkOperatorMatch[ProjectExecTransformer] } + runQueryAndCompare(s"select l_orderkey, like(l_comment, 'a_%b') " + + s"from lineitem limit $LENGTH") { checkOperatorMatch[ProjectExecTransformer] } + runQueryAndCompare(s"select l_orderkey, like('l_comment', 'a\\__b') " + + s"from lineitem limit $LENGTH") { checkOperatorMatch[ProjectExecTransformer] } + runQueryAndCompare(s"select l_orderkey, like(l_comment, 'abc_') " + + s"from lineitem limit $LENGTH") { checkOperatorMatch[ProjectExecTransformer] } + runQueryAndCompare(s"select l_orderkey, like(l_comment, ' ') " + + s"from lineitem limit $LENGTH") { checkOperatorMatch[ProjectExecTransformer] } + runQueryAndCompare(s"select l_orderkey, like(null, '%a%') " + + s"from lineitem limit $LENGTH") { checkOperatorMatch[ProjectExecTransformer] } + runQueryAndCompare(s"select l_orderkey, l_comment " + + s"from lineitem where l_comment like '%a%' limit $LENGTH") { + checkOperatorMatch[ProjectExecTransformer] } + runQueryAndCompare(s"select l_orderkey, 
like(l_comment, ' ') " + + s"from lineitem where l_comment like '' limit $LENGTH") { _ => } + runQueryAndCompare(s"select l_orderkey, like(null, '%a%') " + + s"from lineitem where l_comment like '%$$##@@#&&' limit $LENGTH") { _ => } } test("rlike") { - runQueryAndCompare("select l_orderkey, l_comment, rlike(l_comment, 'a*') " + - "from lineitem limit 5") { checkLengthAndPlan } - runQueryAndCompare("select l_orderkey, rlike(l_comment, ' ') " + - "from lineitem limit 5") { checkLengthAndPlan } - runQueryAndCompare("select l_orderkey, rlike(null, '%a%') " + - "from lineitem limit 5") { checkLengthAndPlan } - runQueryAndCompare("select l_orderkey, l_comment " + - "from lineitem where l_comment rlike '%a%' limit 5") { _ => } - runQueryAndCompare("select l_orderkey, like(l_comment, ' ') " + - "from lineitem where l_comment rlike '' limit 5") { _ => } - runQueryAndCompare("select l_orderkey, like(null, '%a%') " + - "from lineitem where l_comment rlike '%$$$##@@#&&' limit 5") { _ => } + runQueryAndCompare(s"select l_orderkey, l_comment, rlike(l_comment, 'a*') " + + s"from lineitem limit $LENGTH") { checkOperatorMatch[ProjectExecTransformer] } + runQueryAndCompare(s"select l_orderkey, rlike(l_comment, ' ') " + + s"from lineitem limit $LENGTH") { checkOperatorMatch[ProjectExecTransformer] } + runQueryAndCompare(s"select l_orderkey, rlike(null, '%a%') " + + s"from lineitem limit $LENGTH") { checkOperatorMatch[ProjectExecTransformer] } + runQueryAndCompare(s"select l_orderkey, l_comment " + + s"from lineitem where l_comment rlike '%a%' limit $LENGTH") { _ => } + runQueryAndCompare(s"select l_orderkey, like(l_comment, ' ') " + + s"from lineitem where l_comment rlike '' limit $LENGTH") { _ => } + runQueryAndCompare(s"select l_orderkey, like(null, '%a%') " + + s"from lineitem where l_comment rlike '%$$##@@#&&' limit $LENGTH") { _ => } + } + + test("regexp") { + runQueryAndCompare(s"select l_orderkey, l_comment, regexp(l_comment, 'a*') " + + s"from lineitem limit $LENGTH") { checkOperatorMatch[ProjectExecTransformer] } + runQueryAndCompare(s"select l_orderkey, regexp(l_comment, ' ') " + + s"from lineitem limit $LENGTH") { checkOperatorMatch[ProjectExecTransformer] } + runQueryAndCompare(s"select l_orderkey, regexp(null, '%a%') " + + s"from lineitem limit $LENGTH") { checkOperatorMatch[ProjectExecTransformer] } + runQueryAndCompare(s"select l_orderkey, l_comment " + + s"from lineitem where l_comment regexp '%a%' limit $LENGTH") { _ => } + runQueryAndCompare(s"select l_orderkey, l_comment " + + s"from lineitem where l_comment regexp '' limit $LENGTH") { _ => } + runQueryAndCompare(s"select l_orderkey, l_comment " + + s"from lineitem where l_comment regexp '%$$##@@#&&' limit $LENGTH") { _ => } + } + + test("regexp_like") { + runQueryAndCompare(s"select l_orderkey, l_comment, regexp_like(l_comment, 'a*') " + + s"from lineitem limit $LENGTH") { checkOperatorMatch[ProjectExecTransformer] } + runQueryAndCompare(s"select l_orderkey, regexp_like(l_comment, ' ') " + + s"from lineitem limit $LENGTH") { checkOperatorMatch[ProjectExecTransformer] } + runQueryAndCompare(s"select l_orderkey, regexp_like(null, '%a%') " + + s"from lineitem limit $LENGTH") { checkOperatorMatch[ProjectExecTransformer] } } test("regexp_extract") { - runQueryAndCompare("select l_orderkey, regexp_extract(l_comment, '([a-z])', 1) " + - "from lineitem limit 5") { checkLengthAndPlan } - runQueryAndCompare("select l_orderkey, regexp_extract(null, '([a-z])', 1) " + - "from lineitem limit 5") { checkLengthAndPlan } + runQueryAndCompare(s"select 
l_orderkey, regexp_extract(l_comment, '([a-z])', 1) " + + s"from lineitem limit $LENGTH") { checkOperatorMatch[ProjectExecTransformer] } + runQueryAndCompare(s"select l_orderkey, regexp_extract(null, '([a-z])', 1) " + + s"from lineitem limit $LENGTH") { checkOperatorMatch[ProjectExecTransformer] } } test("replace") { - runQueryAndCompare("select l_orderkey, replace(l_comment, ' ', 'hello') " + - "from lineitem limit 5") { checkLengthAndPlan } - runQueryAndCompare("select l_orderkey, replace(l_comment, 'ha') " + - "from lineitem limit 5") { checkLengthAndPlan } - runQueryAndCompare("select l_orderkey, replace(l_comment, ' ', null) " + - "from lineitem limit 5") { checkLengthAndPlan } - runQueryAndCompare("select l_orderkey, replace(l_comment, null, 'hello') " + - "from lineitem limit 5") { checkLengthAndPlan } + runQueryAndCompare(s"select l_orderkey, replace(l_comment, ' ', 'hello') " + + s"from lineitem limit $LENGTH") { checkOperatorMatch[ProjectExecTransformer] } + runQueryAndCompare(s"select l_orderkey, replace(l_comment, 'ha') " + + s"from lineitem limit $LENGTH") { checkOperatorMatch[ProjectExecTransformer] } + runQueryAndCompare(s"select l_orderkey, replace(l_comment, ' ', null) " + + s"from lineitem limit $LENGTH") { checkOperatorMatch[ProjectExecTransformer] } + runQueryAndCompare(s"select l_orderkey, replace(l_comment, null, 'hello') " + + s"from lineitem limit $LENGTH") { checkOperatorMatch[ProjectExecTransformer] } } test("split") { - val df = runQueryAndCompare("select l_orderkey, split(l_comment, 'h', 3) " + - "from lineitem limit 5") { _ => } - assert(df.collect().length == 5) + val df = runQueryAndCompare(s"select l_orderkey, split(l_comment, 'h', 3) " + + s"from lineitem limit $LENGTH") { _ => } + assert(df.collect().length == LENGTH) } + test("substr") { + runQueryAndCompare(s"select l_orderkey, substr(l_comment, 1) " + + s"from lineitem limit $LENGTH") { checkOperatorMatch[ProjectExecTransformer] } + runQueryAndCompare(s"select l_orderkey, substr(l_comment, 1, 3) " + + s"from lineitem limit $LENGTH") { checkOperatorMatch[ProjectExecTransformer] } + runQueryAndCompare(s"select l_orderkey, substr(null, 1) " + + s"from lineitem limit $LENGTH") { checkOperatorMatch[ProjectExecTransformer] } + runQueryAndCompare(s"select l_orderkey, substr(null, 1, 3) " + + s"from lineitem limit $LENGTH") { checkOperatorMatch[ProjectExecTransformer] } + runQueryAndCompare(s"select l_orderkey, substr(l_comment, null) " + + s"from lineitem limit $LENGTH") { checkOperatorMatch[ProjectExecTransformer] } + runQueryAndCompare(s"select l_orderkey, substr(l_comment, null, 3) " + + s"from lineitem limit $LENGTH") { checkOperatorMatch[ProjectExecTransformer] } + } + + test("substring") { + runQueryAndCompare(s"select l_orderkey, substring(l_comment, 1) " + + s"from lineitem limit $LENGTH") { checkOperatorMatch[ProjectExecTransformer] } + runQueryAndCompare(s"select l_orderkey, substring(l_comment, 1, 3) " + + s"from lineitem limit $LENGTH") { checkOperatorMatch[ProjectExecTransformer] } + runQueryAndCompare(s"select l_orderkey, substring(null, 1) " + + s"from lineitem limit $LENGTH") { checkOperatorMatch[ProjectExecTransformer] } + runQueryAndCompare(s"select l_orderkey, substring(null, 1, 3) " + + s"from lineitem limit $LENGTH") { checkOperatorMatch[ProjectExecTransformer] } + runQueryAndCompare(s"select l_orderkey, substring(l_comment, null) " + + s"from lineitem limit $LENGTH") { checkOperatorMatch[ProjectExecTransformer] } + runQueryAndCompare(s"select l_orderkey, substring(l_comment, null, 3) " + 
+ s"from lineitem limit $LENGTH") { checkOperatorMatch[ProjectExecTransformer] } + } } diff --git a/cpp/velox/compute/ArrowTypeUtils.cc b/cpp/velox/compute/ArrowTypeUtils.cc index 47b62bdbc154..0484b518e673 100644 --- a/cpp/velox/compute/ArrowTypeUtils.cc +++ b/cpp/velox/compute/ArrowTypeUtils.cc @@ -36,12 +36,18 @@ std::shared_ptr toArrowTypeFromName( if (type_name == "BIGINT") { return arrow::int64(); } + if (type_name == "REAL") { + return arrow::float32(); + } if (type_name == "DOUBLE") { return arrow::float64(); } if (type_name == "VARCHAR") { return arrow::utf8(); } + if (type_name == "VARBINARY") { + return arrow::utf8(); + } // The type name of Array type is like ARRAY. std::string arrayType = "ARRAY"; if (type_name.substr(0, arrayType.length()) == arrayType) { @@ -63,10 +69,14 @@ std::shared_ptr toArrowType(const TypePtr& type) { return arrow::int32(); case TypeKind::BIGINT: return arrow::int64(); + case TypeKind::REAL: + return arrow::float32(); case TypeKind::DOUBLE: return arrow::float64(); case TypeKind::VARCHAR: return arrow::utf8(); + case TypeKind::VARBINARY: + return arrow::utf8(); case TypeKind::TIMESTAMP: return arrow::timestamp(arrow::TimeUnit::MICRO); default: diff --git a/gluten-core/src/main/scala/io/glutenproject/expression/ConverterUtils.scala b/gluten-core/src/main/scala/io/glutenproject/expression/ConverterUtils.scala index c4a0068752c4..6fb7eb1ede92 100644 --- a/gluten-core/src/main/scala/io/glutenproject/expression/ConverterUtils.scala +++ b/gluten-core/src/main/scala/io/glutenproject/expression/ConverterUtils.scala @@ -444,8 +444,11 @@ object ConverterUtils extends Logging { final val LENGTH = "char_length" // length final val LOWER = "lower" final val UPPER = "upper" + final val LOCATE = "locate" final val LTRIM = "ltrim" final val RTRIM = "rtrim" + final val LPAD = "lpad" + final val RPAD = "rpad" final val REPLACE = "replace" final val SPLIT = "split" final val STARTS_WITH = "starts_with" diff --git a/gluten-core/src/main/scala/io/glutenproject/expression/ExpressionConverter.scala b/gluten-core/src/main/scala/io/glutenproject/expression/ExpressionConverter.scala index 060d9cdb7012..357f4e8ef79b 100644 --- a/gluten-core/src/main/scala/io/glutenproject/expression/ExpressionConverter.scala +++ b/gluten-core/src/main/scala/io/glutenproject/expression/ExpressionConverter.scala @@ -148,56 +148,17 @@ object ExpressionConverter extends Logging { attributeSeq), i.hset, expr) - case ss: StringReplace => + case t: TernaryExpression => logInfo(s"${expr.getClass} ${expr} is supported.") - TernaryOperatorTransformer.create( + TernaryExpressionTransformer.create( replaceWithExpressionTransformer( - ss.srcExpr, + t.first, attributeSeq), replaceWithExpressionTransformer( - ss.searchExpr, + t.second, attributeSeq), replaceWithExpressionTransformer( - ss.replaceExpr, - attributeSeq), - expr) - case ss: StringSplit => - logInfo(s"${expr.getClass} ${expr} is supported.") - TernaryOperatorTransformer.create( - replaceWithExpressionTransformer( - ss.str, - attributeSeq), - replaceWithExpressionTransformer( - ss.regex, - attributeSeq), - replaceWithExpressionTransformer( - ss.limit, - attributeSeq), - expr) - case ss: RegExpExtract => - logInfo(s"${expr.getClass} ${expr} is supported.") - TernaryOperatorTransformer.create( - replaceWithExpressionTransformer( - ss.subject, - attributeSeq), - replaceWithExpressionTransformer( - ss.regexp, - attributeSeq), - replaceWithExpressionTransformer( - ss.idx, - attributeSeq), - expr) - case ss: Substring => - logInfo(s"${expr.getClass} 
${expr} is supported.") - TernaryOperatorTransformer.create( - replaceWithExpressionTransformer( - ss.str, - attributeSeq), - replaceWithExpressionTransformer( - ss.pos, - attributeSeq), - replaceWithExpressionTransformer( - ss.len, + t.third, attributeSeq), expr) case u: UnaryExpression => diff --git a/gluten-core/src/main/scala/io/glutenproject/expression/TernaryExpressionTransformer.scala b/gluten-core/src/main/scala/io/glutenproject/expression/TernaryExpressionTransformer.scala new file mode 100644 index 000000000000..b29b0599daeb --- /dev/null +++ b/gluten-core/src/main/scala/io/glutenproject/expression/TernaryExpressionTransformer.scala @@ -0,0 +1,242 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.glutenproject.expression + +import com.google.common.collect.Lists +import io.glutenproject.expression.ConverterUtils.FunctionConfig +import io.glutenproject.substrait.`type`.TypeBuilder +import io.glutenproject.substrait.expression.{ExpressionBuilder, ExpressionNode} +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.expressions._ + +class LocateTransformer(first: Expression, second: Expression, + third: Expression, original: Expression) + extends StringLocate(first: Expression, second: Expression, second: Expression) + with ExpressionTransformer + with Logging { + + override def doTransform(args: java.lang.Object): ExpressionNode = { + val firstNode = + first.asInstanceOf[ExpressionTransformer].doTransform(args) + val secondNode = + second.asInstanceOf[ExpressionTransformer].doTransform(args) + val thirdNode = + third.asInstanceOf[ExpressionTransformer].doTransform(args) + if (!firstNode.isInstanceOf[ExpressionNode] || + !secondNode.isInstanceOf[ExpressionNode] || + !thirdNode.isInstanceOf[ExpressionNode]) { + throw new UnsupportedOperationException(s"Not supported yet.") + } + + val functionMap = args.asInstanceOf[java.util.HashMap[String, java.lang.Long]] + val functionName = ConverterUtils.makeFuncName(ConverterUtils.LOCATE, + Seq(first.dataType, second.dataType, third.dataType), FunctionConfig.OPT) + val functionId = ExpressionBuilder.newScalarFunction(functionMap, functionName) + val expressionNodes = Lists.newArrayList(firstNode, secondNode, thirdNode) + val typeNode = TypeBuilder.makeI64(original.nullable) + ExpressionBuilder.makeScalarFunction(functionId, expressionNodes, typeNode) + } +} + +class LPadTransformer(first: Expression, second: Expression, + third: Expression, original: Expression) + extends StringLPad(first: Expression, second: Expression, second: Expression) + with ExpressionTransformer + with Logging { + + override def doTransform(args: java.lang.Object): ExpressionNode = { + val firstNode = + first.asInstanceOf[ExpressionTransformer].doTransform(args) + val secondNode = 
+ second.asInstanceOf[ExpressionTransformer].doTransform(args) + val thirdNode = + third.asInstanceOf[ExpressionTransformer].doTransform(args) + if (!firstNode.isInstanceOf[ExpressionNode] || + !secondNode.isInstanceOf[ExpressionNode] || + !thirdNode.isInstanceOf[ExpressionNode]) { + throw new UnsupportedOperationException(s"Not supported yet.") + } + + val functionMap = args.asInstanceOf[java.util.HashMap[String, java.lang.Long]] + val functionName = ConverterUtils.makeFuncName(ConverterUtils.LPAD, + Seq(first.dataType, second.dataType, third.dataType), + FunctionConfig.OPT) + val functionId = ExpressionBuilder.newScalarFunction(functionMap, functionName) + val expressionNodes = Lists.newArrayList(firstNode, secondNode, thirdNode) + val typeNode = ConverterUtils.getTypeNode(original.dataType, original.nullable) + ExpressionBuilder.makeScalarFunction(functionId, expressionNodes, typeNode) + } +} + +class RPadTransformer(first: Expression, second: Expression, + third: Expression, original: Expression) + extends StringRPad(first: Expression, second: Expression, second: Expression) + with ExpressionTransformer + with Logging { + + override def doTransform(args: java.lang.Object): ExpressionNode = { + val firstNode = + first.asInstanceOf[ExpressionTransformer].doTransform(args) + val secondNode = + second.asInstanceOf[ExpressionTransformer].doTransform(args) + val thirdNode = + third.asInstanceOf[ExpressionTransformer].doTransform(args) + if (!firstNode.isInstanceOf[ExpressionNode] || + !secondNode.isInstanceOf[ExpressionNode] || + !thirdNode.isInstanceOf[ExpressionNode]) { + throw new UnsupportedOperationException(s"Not supported yet.") + } + + val functionMap = args.asInstanceOf[java.util.HashMap[String, java.lang.Long]] + val functionName = ConverterUtils.makeFuncName(ConverterUtils.RPAD, + Seq(first.dataType, second.dataType, third.dataType), + FunctionConfig.OPT) + val functionId = ExpressionBuilder.newScalarFunction(functionMap, functionName) + val expressionNodes = Lists.newArrayList(firstNode, secondNode, thirdNode) + val typeNode = ConverterUtils.getTypeNode(original.dataType, original.nullable) + ExpressionBuilder.makeScalarFunction(functionId, expressionNodes, typeNode) + } +} + +class RegExpExtractTransformer(subject: Expression, regexp: Expression, + index: Expression, original: Expression) + extends RegExpExtract(subject: Expression, regexp: Expression, regexp: Expression) + with ExpressionTransformer + with Logging { + + override def doTransform(args: java.lang.Object): ExpressionNode = { + val firstNode = + subject.asInstanceOf[ExpressionTransformer].doTransform(args) + val secondNode = + regexp.asInstanceOf[ExpressionTransformer].doTransform(args) + val thirdNode = + index.asInstanceOf[ExpressionTransformer].doTransform(args) + if (!firstNode.isInstanceOf[ExpressionNode] || + !secondNode.isInstanceOf[ExpressionNode] || + !thirdNode.isInstanceOf[ExpressionNode]) { + throw new UnsupportedOperationException(s"Not supported yet.") + } + + val functionMap = args.asInstanceOf[java.util.HashMap[String, java.lang.Long]] + val functionName = ConverterUtils.makeFuncName(ConverterUtils.REGEXP_EXTRACT, + Seq(subject.dataType, regexp.dataType, index.dataType), FunctionConfig.OPT) + val functionId = ExpressionBuilder.newScalarFunction(functionMap, functionName) + val expressionNodes = Lists.newArrayList(firstNode, secondNode, thirdNode) + val typeNode = TypeBuilder.makeString(original.nullable) + ExpressionBuilder.makeScalarFunction(functionId, expressionNodes, typeNode) + } +} + +class 
ReplaceTransformer(str: Expression, search: Expression, replace: Expression, + original: Expression) + extends StringReplace(str: Expression, search: Expression, search: Expression) + with ExpressionTransformer + with Logging { + + override def doTransform(args: java.lang.Object): ExpressionNode = { + val strNode = + str.asInstanceOf[ExpressionTransformer].doTransform(args) + val searchNode = + search.asInstanceOf[ExpressionTransformer].doTransform(args) + val replaceNode = + replace.asInstanceOf[ExpressionTransformer].doTransform(args) + if (!strNode.isInstanceOf[ExpressionNode] || + !searchNode.isInstanceOf[ExpressionNode] || + !replaceNode.isInstanceOf[ExpressionNode]) { + throw new UnsupportedOperationException(s"Not supported yet.") + } + + val functionMap = args.asInstanceOf[java.util.HashMap[String, java.lang.Long]] + val functionName = ConverterUtils.makeFuncName(ConverterUtils.REPLACE, + Seq(str.dataType, search.dataType, replace.dataType), FunctionConfig.OPT) + val functionId = ExpressionBuilder.newScalarFunction(functionMap, functionName) + val expressionNodes = Lists.newArrayList(strNode, searchNode, replaceNode) + val typeNode = TypeBuilder.makeString(original.nullable) + ExpressionBuilder.makeScalarFunction(functionId, expressionNodes, typeNode) + } +} + +class SplitTransformer(str: Expression, delimiter: Expression, limit: Expression, + original: Expression) + extends StringSplit(str: Expression, delimiter: Expression, limit: Expression) + with ExpressionTransformer + with Logging { + + override def doTransform(args: java.lang.Object): ExpressionNode = { + throw new UnsupportedOperationException("Not supported: Split.") + } +} + +class SubStringTransformer(str: Expression, pos: Expression, len: Expression, original: Expression) + extends Substring(str: Expression, pos: Expression, len: Expression) + with ExpressionTransformer + with Logging { + + override def doTransform(args: java.lang.Object): ExpressionNode = { + val strNode = + str.asInstanceOf[ExpressionTransformer].doTransform(args) + val posNode = + pos.asInstanceOf[ExpressionTransformer].doTransform(args) + val lenNode = + len.asInstanceOf[ExpressionTransformer].doTransform(args) + + if (!strNode.isInstanceOf[ExpressionNode] || + !posNode.isInstanceOf[ExpressionNode] || + !lenNode.isInstanceOf[ExpressionNode]) { + throw new UnsupportedOperationException(s"not supported yet.") + } + + val functionMap = args.asInstanceOf[java.util.HashMap[String, java.lang.Long]] + val functionName = ConverterUtils.makeFuncName( + ConverterUtils.SUBSTRING, Seq(str.dataType), FunctionConfig.OPT) + val functionId = ExpressionBuilder.newScalarFunction(functionMap, functionName) + val expressionNodes = Lists.newArrayList( + strNode.asInstanceOf[ExpressionNode], + posNode.asInstanceOf[ExpressionNode], + lenNode.asInstanceOf[ExpressionNode]) + // Substring inherits NullIntolerant, the output is nullable + val typeNode = TypeBuilder.makeString(true) + + ExpressionBuilder.makeScalarFunction(functionId, expressionNodes, typeNode) + } +} + +object TernaryExpressionTransformer { + + def create(first: Expression, second: Expression, third: Expression, + original: Expression): Expression = original match { + case _: StringLocate => + // locate() gets incorrect results, so fall back to Vanilla Spark + throw new UnsupportedOperationException("Not supported: locate().") + case _: StringSplit => + // split() gets incorrect results, so fall back to Vanilla Spark + throw new UnsupportedOperationException("Not supported: split().") + case lpad: 
StringLPad => + new LPadTransformer(first, second, third, lpad) + case rpad: StringRPad => + new RPadTransformer(first, second, third, rpad) + case extract: RegExpExtract => + new RegExpExtractTransformer(first, second, third, extract) + case replace: StringReplace => + new ReplaceTransformer(first, second, third, replace) + case ss: Substring => + new SubStringTransformer(first, second, third, ss) + case other => + throw new UnsupportedOperationException(s"not currently supported: $other.") + } +} diff --git a/gluten-core/src/main/scala/io/glutenproject/expression/TernaryOperatorTransformer.scala b/gluten-core/src/main/scala/io/glutenproject/expression/TernaryOperatorTransformer.scala index 1db7aca89521..e69de29bb2d1 100644 --- a/gluten-core/src/main/scala/io/glutenproject/expression/TernaryOperatorTransformer.scala +++ b/gluten-core/src/main/scala/io/glutenproject/expression/TernaryOperatorTransformer.scala @@ -1,146 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.glutenproject.expression - -import com.google.common.collect.Lists -import io.glutenproject.expression.ConverterUtils.FunctionConfig -import io.glutenproject.substrait.`type`.TypeBuilder -import io.glutenproject.substrait.expression.{ExpressionBuilder, ExpressionNode} - -import org.apache.spark.internal.Logging -import org.apache.spark.sql.catalyst.expressions._ - -class RegExpExtractTransformer(subject: Expression, regexp: Expression, - index: Expression, original: Expression) - extends RegExpExtract(subject: Expression, regexp: Expression, regexp: Expression) - with ExpressionTransformer - with Logging { - - override def doTransform(args: java.lang.Object): ExpressionNode = { - val firstNode = - subject.asInstanceOf[ExpressionTransformer].doTransform(args) - val secondNode = - regexp.asInstanceOf[ExpressionTransformer].doTransform(args) - val thirdNode = - index.asInstanceOf[ExpressionTransformer].doTransform(args) - if (!firstNode.isInstanceOf[ExpressionNode] || - !secondNode.isInstanceOf[ExpressionNode] || - !thirdNode.isInstanceOf[ExpressionNode]) { - throw new UnsupportedOperationException(s"Not supported yet.") - } - - val functionMap = args.asInstanceOf[java.util.HashMap[String, java.lang.Long]] - val functionName = ConverterUtils.makeFuncName(ConverterUtils.REGEXP_EXTRACT, - Seq(subject.dataType, regexp.dataType, index.dataType), FunctionConfig.OPT) - val functionId = ExpressionBuilder.newScalarFunction(functionMap, functionName) - val expressionNodes = Lists.newArrayList(firstNode, secondNode, thirdNode) - val typeNode = TypeBuilder.makeString(original.nullable) - ExpressionBuilder.makeScalarFunction(functionId, expressionNodes, typeNode) - } -} - -class ReplaceTransformer(str: Expression, search: Expression, replace: Expression, - original: 
Expression) - extends StringReplace(str: Expression, search: Expression, search: Expression) - with ExpressionTransformer - with Logging { - - override def doTransform(args: java.lang.Object): ExpressionNode = { - val strNode = - str.asInstanceOf[ExpressionTransformer].doTransform(args) - val searchNode = - search.asInstanceOf[ExpressionTransformer].doTransform(args) - val replaceNode = - replace.asInstanceOf[ExpressionTransformer].doTransform(args) - if (!strNode.isInstanceOf[ExpressionNode] || - !searchNode.isInstanceOf[ExpressionNode] || - !replaceNode.isInstanceOf[ExpressionNode]) { - throw new UnsupportedOperationException(s"Not supported yet.") - } - - val functionMap = args.asInstanceOf[java.util.HashMap[String, java.lang.Long]] - val functionName = ConverterUtils.makeFuncName(ConverterUtils.REPLACE, - Seq(str.dataType, search.dataType, replace.dataType), FunctionConfig.OPT) - val functionId = ExpressionBuilder.newScalarFunction(functionMap, functionName) - val expressionNodes = Lists.newArrayList(strNode, searchNode, replaceNode) - val typeNode = TypeBuilder.makeString(original.nullable) - ExpressionBuilder.makeScalarFunction(functionId, expressionNodes, typeNode) - } -} - -class SplitTransformer(str: Expression, delimiter: Expression, limit: Expression, - original: Expression) - extends StringSplit(str: Expression, delimiter: Expression, limit: Expression) - with ExpressionTransformer - with Logging { - - override def doTransform(args: java.lang.Object): ExpressionNode = { - throw new UnsupportedOperationException("Not supported: Split.") - } -} - -class SubStringTransformer(str: Expression, pos: Expression, len: Expression, original: Expression) - extends Substring(str: Expression, pos: Expression, len: Expression) - with ExpressionTransformer - with Logging { - - override def doTransform(args: java.lang.Object): ExpressionNode = { - val strNode = - str.asInstanceOf[ExpressionTransformer].doTransform(args) - val posNode = - pos.asInstanceOf[ExpressionTransformer].doTransform(args) - val lenNode = - len.asInstanceOf[ExpressionTransformer].doTransform(args) - - if (!strNode.isInstanceOf[ExpressionNode] || - !posNode.isInstanceOf[ExpressionNode] || - !lenNode.isInstanceOf[ExpressionNode]) { - throw new UnsupportedOperationException(s"not supported yet.") - } - - val functionMap = args.asInstanceOf[java.util.HashMap[String, java.lang.Long]] - val functionName = ConverterUtils.makeFuncName( - ConverterUtils.SUBSTRING, Seq(str.dataType), FunctionConfig.OPT) - val functionId = ExpressionBuilder.newScalarFunction(functionMap, functionName) - val expressionNodes = Lists.newArrayList( - strNode.asInstanceOf[ExpressionNode], - posNode.asInstanceOf[ExpressionNode], - lenNode.asInstanceOf[ExpressionNode]) - // Substring inherits NullIntolerant, the output is nullable - val typeNode = TypeBuilder.makeString(true) - - ExpressionBuilder.makeScalarFunction(functionId, expressionNodes, typeNode) - } -} - -object TernaryOperatorTransformer { - - def create(str: Expression, pos: Expression, len: Expression, original: Expression): Expression = - original match { - case extract: RegExpExtract => - new RegExpExtractTransformer(str, pos, len, extract) - case split: StringSplit => - new SplitTransformer(str, pos, len, split) - case replace: StringReplace => - new ReplaceTransformer(str, pos, len, replace) - case ss: Substring => - new SubStringTransformer(str, pos, len, ss) - case other => - throw new UnsupportedOperationException(s"not currently supported: $other.") - } -} diff --git 
a/gluten-ut/src/test/scala/io/glutenproject/utils/velox/VeloxNotSupport.scala b/gluten-ut/src/test/scala/io/glutenproject/utils/velox/VeloxNotSupport.scala index 017d8ffd60cd..ceeae0362a99 100644 --- a/gluten-ut/src/test/scala/io/glutenproject/utils/velox/VeloxNotSupport.scala +++ b/gluten-ut/src/test/scala/io/glutenproject/utils/velox/VeloxNotSupport.scala @@ -19,6 +19,7 @@ package io.glutenproject.utils.velox import io.glutenproject.utils.NotSupport import org.apache.spark.sql.DateFunctionsSuite +import org.apache.spark.sql.StringFunctionsSuite import org.apache.spark.sql.catalyst.expressions._ object VeloxNotSupport extends NotSupport { @@ -28,7 +29,9 @@ object VeloxNotSupport extends NotSupport { override lazy val fullSupportSuiteList: Set[String] = Set( simpleClassName[LiteralExpressionSuite], simpleClassName[IntervalExpressionsSuite], + simpleClassName[DateExpressionsSuite], simpleClassName[DecimalExpressionSuite], - simpleClassName[DateExpressionsSuite] + simpleClassName[StringFunctionsSuite], + simpleClassName[RegexpExpressionsSuite] ) } diff --git a/gluten-ut/src/test/scala/org/apache/spark/sql/GlutenStringFunctionsSuite.scala b/gluten-ut/src/test/scala/org/apache/spark/sql/GlutenStringFunctionsSuite.scala index 0f2e355c157d..b3f1388585e6 100644 --- a/gluten-ut/src/test/scala/org/apache/spark/sql/GlutenStringFunctionsSuite.scala +++ b/gluten-ut/src/test/scala/org/apache/spark/sql/GlutenStringFunctionsSuite.scala @@ -17,20 +17,13 @@ package org.apache.spark.sql -import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types._ -import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.ExpressionEvalHelper -import org.apache.spark.sql.catalyst.expressions.Lower -import org.apache.spark.sql.catalyst.expressions.Upper class GlutenStringFunctionsSuite extends StringFunctionsSuite with GlutenSQLTestsTrait with ExpressionEvalHelper { - override def whiteTestNameList: Seq[String] = Seq( - "string trim functions" - ) - override def blackTestNameList: Seq[String] = Seq( + override def blackTestNameList: Seq[String] = super.blackTestNameList ++ Seq( + "string / binary length function" ) } diff --git a/gluten-ut/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenRegexpExpressionsSuite.scala b/gluten-ut/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenRegexpExpressionsSuite.scala index a75bc4a200c4..9eeddada56e8 100644 --- a/gluten-ut/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenRegexpExpressionsSuite.scala +++ b/gluten-ut/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenRegexpExpressionsSuite.scala @@ -17,7 +17,23 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.{GlutenTestConstants, GlutenTestsTrait} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.catalyst.analysis.ResolveTimeZone +import org.apache.spark.sql.GlutenTestsTrait class GlutenRegexpExpressionsSuite extends RegexpExpressionsSuite with GlutenTestsTrait { + + override protected def checkEvaluation(expression: => Expression, + expected: Any, + inputRow: InternalRow = EmptyRow): Unit = { + val resolver = ResolveTimeZone + val expr = resolver.resolveTimeZones(expression) + assert(expr.resolved) + + val catalystValue = CatalystTypeConverters.convertToCatalyst(expected) + // Consistent with the evaluation approach in vanilla spark UT to avoid overflow issue + // in resultDF.collect() for some corner cases. 
+ glutenCheckExpression(expr, catalystValue, inputRow, justEvalExpr = true) + } + }