From 4c1b9298ae82ff801932e4e9834dea4cb1f70ddf Mon Sep 17 00:00:00 2001 From: zhiqiang Date: Wed, 4 Sep 2024 09:53:50 +0800 Subject: [PATCH] [fix](unary function) Fix wrong result of asin, acos and sqrt when processing invalid input (#40267) When input of asin, acos and sqrt is invalid, result of them should be null (same with mysql). --- .../function_math_unary_alway_nullable.h | 94 ++++++++++++++++++ be/src/vec/functions/math.cpp | 16 +++- .../expressions/functions/scalar/Acos.java | 4 +- .../expressions/functions/scalar/Asin.java | 4 +- .../expressions/functions/scalar/Dsqrt.java | 4 +- .../expressions/functions/scalar/Sqrt.java | 4 +- .../test_math_unary_always_nullable.out | 95 +++++++++++++++++++ .../test_math_unary_always_nullable.groovy | 85 +++++++++++++++++ 8 files changed, 295 insertions(+), 11 deletions(-) create mode 100644 be/src/vec/functions/function_math_unary_alway_nullable.h create mode 100644 regression-test/data/query_p0/sql_functions/math_functions/test_math_unary_always_nullable.out create mode 100644 regression-test/suites/query_p0/sql_functions/math_functions/test_math_unary_always_nullable.groovy diff --git a/be/src/vec/functions/function_math_unary_alway_nullable.h b/be/src/vec/functions/function_math_unary_alway_nullable.h new file mode 100644 index 00000000000000..8d2cea1bc0db87 --- /dev/null +++ b/be/src/vec/functions/function_math_unary_alway_nullable.h @@ -0,0 +1,94 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "vec/columns/column.h" +#include "vec/columns/column_decimal.h" +#include "vec/columns/column_nullable.h" +#include "vec/columns/columns_number.h" +#include "vec/core/call_on_type_index.h" +#include "vec/core/types.h" +#include "vec/data_types/data_type_decimal.h" +#include "vec/data_types/data_type_nullable.h" +#include "vec/data_types/data_type_number.h" +#include "vec/functions/function.h" +#include "vec/functions/function_helpers.h" +#include "vec/utils/util.hpp" + +namespace doris::vectorized { + +template +class FunctionMathUnaryAlwayNullable : public IFunction { +public: + using IFunction::execute; + + static constexpr auto name = Impl::name; + static FunctionPtr create() { return std::make_shared(); } + +private: + String get_name() const override { return name; } + size_t get_number_of_arguments() const override { return 1; } + + DataTypePtr get_return_type_impl(const DataTypes& arguments) const override { + return make_nullable(std::make_shared()); + } + + static void execute_in_iterations(const double* src_data, double* dst_data, size_t size) { + for (size_t i = 0; i < size; i++) { + Impl::execute(&src_data[i], &dst_data[i]); + } + } + + Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments, + size_t result, size_t input_rows_count) const override { + const ColumnFloat64* col = + assert_cast(block.get_by_position(arguments[0]).column.get()); + auto dst = ColumnFloat64::create(); + auto& dst_data = dst->get_data(); + dst_data.resize(input_rows_count); + + execute_in_iterations(col->get_data().data(), dst_data.data(), input_rows_count); + + auto result_null_map = ColumnUInt8::create(input_rows_count, 0); + + for (size_t i = 0; i < input_rows_count; i++) { + if (Impl::is_invalid_input(col->get_data()[i])) [[unlikely]] { + result_null_map->get_data().data()[i] = 1; + } + } + + block.replace_by_position( + result, ColumnNullable::create(std::move(dst), std::move(result_null_map))); + return Status::OK(); + } +}; + +template +struct UnaryFunctionPlainAlwayNullable { + using Type = DataTypeFloat64; + static constexpr auto name = Name::name; + + static constexpr bool is_invalid_input(Float64 x) { return Name::is_invalid_input(x); } + + template + static void execute(const T* src, U* dst) { + *dst = static_cast(Function(*src)); + } +}; + +} // namespace doris::vectorized diff --git a/be/src/vec/functions/math.cpp b/be/src/vec/functions/math.cpp index a3b54c8026db75..af2e68ec9822c8 100644 --- a/be/src/vec/functions/math.cpp +++ b/be/src/vec/functions/math.cpp @@ -37,6 +37,7 @@ #include "vec/functions/function_const.h" #include "vec/functions/function_math_log.h" #include "vec/functions/function_math_unary.h" +#include "vec/functions/function_math_unary_alway_nullable.h" #include "vec/functions/function_string.h" #include "vec/functions/function_totype.h" #include "vec/functions/function_unary_arithmetic.h" @@ -53,13 +54,19 @@ struct Log2Impl; namespace doris::vectorized { struct AcosName { static constexpr auto name = "acos"; + // https://dev.mysql.com/doc/refman/8.4/en/mathematical-functions.html#function_acos + static constexpr bool is_invalid_input(Float64 x) { return x < -1 || x > 1; } }; -using FunctionAcos = FunctionMathUnary>; +using FunctionAcos = + FunctionMathUnaryAlwayNullable>; struct AsinName { static constexpr auto name = "asin"; + // https://dev.mysql.com/doc/refman/8.4/en/mathematical-functions.html#function_asin + static constexpr bool is_invalid_input(Float64 x) { return x < -1 || x > 1; } }; -using FunctionAsin = FunctionMathUnary>; +using FunctionAsin = + FunctionMathUnaryAlwayNullable>; struct AtanName { static constexpr auto name = "atan"; @@ -242,8 +249,11 @@ using FunctionSin = FunctionMathUnary; struct SqrtName { static constexpr auto name = "sqrt"; + // https://dev.mysql.com/doc/refman/8.4/en/mathematical-functions.html#function_sqrt + static constexpr bool is_invalid_input(Float64 x) { return x < 0; } }; -using FunctionSqrt = FunctionMathUnary>; +using FunctionSqrt = + FunctionMathUnaryAlwayNullable>; struct CbrtName { static constexpr auto name = "cbrt"; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Acos.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Acos.java index c99af81123fc63..2193221c326363 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Acos.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Acos.java @@ -19,8 +19,8 @@ import org.apache.doris.catalog.FunctionSignature; import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.functions.AlwaysNullable; import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature; -import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable; import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.types.DoubleType; @@ -34,7 +34,7 @@ * ScalarFunction 'acos'. This class is generated by GenerateFunction. */ public class Acos extends ScalarFunction - implements UnaryExpression, ExplicitlyCastableSignature, PropagateNullable { + implements UnaryExpression, ExplicitlyCastableSignature, AlwaysNullable { public static final List SIGNATURES = ImmutableList.of( FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Asin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Asin.java index 0e06d8d77edb10..22e1ff59b7df28 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Asin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Asin.java @@ -19,8 +19,8 @@ import org.apache.doris.catalog.FunctionSignature; import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.functions.AlwaysNullable; import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature; -import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable; import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.types.DoubleType; @@ -34,7 +34,7 @@ * ScalarFunction 'asin'. This class is generated by GenerateFunction. */ public class Asin extends ScalarFunction - implements UnaryExpression, ExplicitlyCastableSignature, PropagateNullable { + implements UnaryExpression, ExplicitlyCastableSignature, AlwaysNullable { public static final List SIGNATURES = ImmutableList.of( FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Dsqrt.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Dsqrt.java index 874befd09dba4d..3caef79776b3bb 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Dsqrt.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Dsqrt.java @@ -19,8 +19,8 @@ import org.apache.doris.catalog.FunctionSignature; import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.functions.AlwaysNullable; import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature; -import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable; import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.types.DoubleType; @@ -34,7 +34,7 @@ * ScalarFunction 'dsqrt'. This class is generated by GenerateFunction. */ public class Dsqrt extends ScalarFunction - implements UnaryExpression, ExplicitlyCastableSignature, PropagateNullable { + implements UnaryExpression, ExplicitlyCastableSignature, AlwaysNullable { public static final List SIGNATURES = ImmutableList.of( FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Sqrt.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Sqrt.java index 495321c6dfa8c5..f954eb07a54083 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Sqrt.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Sqrt.java @@ -19,8 +19,8 @@ import org.apache.doris.catalog.FunctionSignature; import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.functions.AlwaysNullable; import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature; -import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable; import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.types.DoubleType; @@ -34,7 +34,7 @@ * ScalarFunction 'sqrt'. This class is generated by GenerateFunction. */ public class Sqrt extends ScalarFunction - implements UnaryExpression, ExplicitlyCastableSignature, PropagateNullable { + implements UnaryExpression, ExplicitlyCastableSignature, AlwaysNullable { public static final List SIGNATURES = ImmutableList.of( FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE) diff --git a/regression-test/data/query_p0/sql_functions/math_functions/test_math_unary_always_nullable.out b/regression-test/data/query_p0/sql_functions/math_functions/test_math_unary_always_nullable.out new file mode 100644 index 00000000000000..0a190f0bd6b2f9 --- /dev/null +++ b/regression-test/data/query_p0/sql_functions/math_functions/test_math_unary_always_nullable.out @@ -0,0 +1,95 @@ +-- This file is automatically generated. You should know what you did if you want to edit this +-- !acos_1 -- +\N true + +-- !acos_2 -- +\N true + +-- !acos_3 -- +\N true 0 +\N true 1 +\N true 2 +\N true 3 +\N true 4 +\N true 5 +\N true 6 +\N true 7 +\N true 8 +\N true 9 + +-- !asin_1 -- +\N true + +-- !asin_2 -- +\N true + +-- !asin_3 -- +\N true 0 +\N true 1 +\N true 2 +\N true 3 +\N true 4 +\N true 5 +\N true 6 +\N true 7 +\N true 8 +\N true 9 + +-- !sqrt_1 -- +\N true + +-- !sqrt_2 -- +\N true + +-- !sqrt_3 -- +\N true 0 +\N true 1 +\N true 2 +\N true 3 +\N true 4 +\N true 5 +\N true 6 +\N true 7 +\N true 8 +\N true 9 + +-- !acos_tbl_1 -- +1 \N true +2 \N true +3 1.5707963267948966 false +4 \N true +5 \N true +6 \N true +7 \N true +8 \N true + +-- !asin_tbl_1 -- +1 \N true +2 \N true +3 0.0 false +4 \N true +5 \N true +6 \N true +7 \N true +8 \N true + +-- !sqrt_tbl_1 -- +1 1.0488088481701516 false +2 \N true +3 0.0 false +4 \N true +5 \N true +6 \N true +7 \N true +8 \N true + +-- !dsqrt_tbl_1 -- +1 1.0488088481701516 false +2 \N true +3 0.0 false +4 \N true +5 \N true +6 \N true +7 \N true +8 \N true + diff --git a/regression-test/suites/query_p0/sql_functions/math_functions/test_math_unary_always_nullable.groovy b/regression-test/suites/query_p0/sql_functions/math_functions/test_math_unary_always_nullable.groovy new file mode 100644 index 00000000000000..282d4e3c5754e4 --- /dev/null +++ b/regression-test/suites/query_p0/sql_functions/math_functions/test_math_unary_always_nullable.groovy @@ -0,0 +1,85 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +suite("test_math_unary_alway_nullable") { + sql """ + set debug_skip_fold_constant=true; + """ + + qt_acos_1 """ + select acos(1.1), acos(1.1) is null; + """ + qt_acos_2 """ + select acos(-1.1), acos(-1.1) is null; + """ + qt_acos_3 """ + select acos(-1.1), acos(-1.1) is NULL, number from numbers("number"="10") + """ + + qt_asin_1 """ + select asin(1.1), asin(1.1) is null; + """ + qt_asin_2 """ + select asin(-1.1), asin(-1.1) is null; + """ + qt_asin_3 """ + select asin(-1.1), asin(-1.1) is NULL, number from numbers("number"="10") + """ + + qt_sqrt_1 """ + select sqrt(-1), sqrt(-1) is null; + """ + qt_sqrt_2 """ + select sqrt(-1.1), sqrt(-1.1) is null; + """ + qt_sqrt_3 """ + select sqrt(-1.1), sqrt(-1.1) is NULL, number from numbers("number"="10") + """ + + sql "drop table if exists test_math_unary_alway_nullable" + + sql """ + create table if not exists test_math_unary_alway_nullable (rowid int, val double NULL) + distributed by hash(rowid) properties ("replication_num"="1"); + """ + + sql """ + insert into test_math_unary_alway_nullable values + (1, 1.1), (2, -1.1), (3, 0), (4, NULL) + """ + sql """ + insert into test_math_unary_alway_nullable values + (5, NULL), (6, NULL), (7, NULL), (8, NULL) + """ + + qt_acos_tbl_1 """ + select rowid, acos(val), acos(val) is null from test_math_unary_alway_nullable order by rowid; + """ + + qt_asin_tbl_1 """ + select rowid, asin(val), asin(val) is null from test_math_unary_alway_nullable order by rowid; + """ + + qt_sqrt_tbl_1 """ + select rowid, sqrt(val), sqrt(val) is null from test_math_unary_alway_nullable order by rowid; + """ + + qt_dsqrt_tbl_1 """ + select rowid, dsqrt(val), dsqrt(val) is null from test_math_unary_alway_nullable order by rowid; + """ + +} \ No newline at end of file