diff --git a/velox/core/SimpleFunctionMetadata.h b/velox/core/SimpleFunctionMetadata.h index 39686f7a3f144..df319c4467b62 100644 --- a/velox/core/SimpleFunctionMetadata.h +++ b/velox/core/SimpleFunctionMetadata.h @@ -666,15 +666,16 @@ class UDFHolder { DECLARE_METHOD_RESOLVER(callAscii_method_resolver, callAscii); DECLARE_METHOD_RESOLVER(initialize_method_resolver, initialize); - // Check which flavor of the call() method is provided by the UDF object. UDFs - // are required to provide at least one of the following methods: + // Check which flavor of the call()/callNullable()/callNullFree() method is + // provided by the UDF object. UDFs are required to provide at least one of + // the following methods: // // - bool|void|Status call(...) // - bool|void|Status callNullable(...) // - bool|void|Status callNullFree(...) // - // Each of these methods can return either bool or void or Status. Returning - // void means that the UDF is assumed never to return null values. Returning + // Each of these methods can return bool, void or Status. Returning void + // means that the UDF is assumed never to return null values. Returning // Status to hold success or error outcome of the function call. // // Optionally, UDFs can also provide the following methods: @@ -708,14 +709,9 @@ class UDFHolder { udf_has_call_return_void | udf_has_call_return_status; static_assert( - !(udf_has_call_return_bool && udf_has_call_return_void), - "Provided call() methods need to return either void OR bool OR Status."); - static_assert( - !(udf_has_call_return_bool && udf_has_call_return_status), - "Provided call() methods need to return either void OR bool OR Status."); - static_assert( - !(udf_has_call_return_void && udf_has_call_return_status), - "Provided call() methods need to return either void OR bool OR Status."); + udf_has_call_return_void ^ udf_has_call_return_bool ^ + udf_has_call_return_status, + "Provided call() methods need to only return void, bool OR Status."); // callNullable(): static constexpr bool udf_has_callNullable_return_bool = util::has_method< @@ -739,15 +735,11 @@ class UDFHolder { static constexpr bool udf_has_callNullable = udf_has_callNullable_return_bool | udf_has_callNullable_return_void | udf_has_callNullable_return_status; + static_assert( - !(udf_has_callNullable_return_bool && udf_has_callNullable_return_void), - "Provided callNullable() methods need to return either void OR bool OR Status."); - static_assert( - !(udf_has_callNullable_return_bool && udf_has_callNullable_return_status), - "Provided callNullable() methods need to return either void OR bool OR Status."); - static_assert( - !(udf_has_callNullable_return_void && udf_has_callNullable_return_status), - "Provided callNullable() methods need to return either void OR bool OR Status."); + udf_has_callNullable_return_void ^ udf_has_callNullable_return_bool ^ + udf_has_callNullable_return_status, + "Provided callNullable() methods need to only return void, bool OR Status."); // callNullFree(): static constexpr bool udf_has_callNullFree_return_bool = util::has_method< @@ -771,15 +763,11 @@ class UDFHolder { static constexpr bool udf_has_callNullFree = udf_has_callNullFree_return_bool | udf_has_callNullFree_return_void | udf_has_callNullFree_return_status; + static_assert( - !(udf_has_callNullFree_return_bool && udf_has_callNullFree_return_void), - "Provided callNullFree() methods need to return either void OR bool OR Status."); - static_assert( - !(udf_has_callNullFree_return_bool && udf_has_callNullFree_return_status), - "Provided callNullFree() methods need to return either void OR bool OR Status."); - static_assert( - !(udf_has_callNullFree_return_void && udf_has_callNullFree_return_status), - "Provided callNullFree() methods need to return either void OR bool OR Status."); + udf_has_callNullFree_return_void ^ udf_has_callNullFree_return_bool ^ + udf_has_callNullFree_return_status, + "Provided callNullFree() methods need to only return void, bool OR Status."); // callAscii(): static constexpr bool udf_has_callAscii_return_bool = util::has_method< diff --git a/velox/expression/tests/SimpleFunctionTest.cpp b/velox/expression/tests/SimpleFunctionTest.cpp index 471259d70a107..ade4601ea949a 100644 --- a/velox/expression/tests/SimpleFunctionTest.cpp +++ b/velox/expression/tests/SimpleFunctionTest.cpp @@ -1508,13 +1508,6 @@ struct NoThrowFunction { out = in / 6; return Status::OK(); } - - Status callNullable(out_type& out, const arg_type* in) { - if (!in) { - return Status::UserError("Input cannot be NULL"); - } - return Status::OK(); - } }; TEST_F(SimpleFunctionTest, noThrow) { @@ -1547,12 +1540,66 @@ TEST_F(SimpleFunctionTest, noThrow) { VELOX_ASSERT_THROW( (evaluateOnce("try(no_throw(c0))", 6)), "Input must not be 6"); +} + +template +struct CallNullableNoThrowFunction { + VELOX_DEFINE_FUNCTION_TYPES(TExec); + + Status callNullable(out_type& out, const arg_type* in) { + if (!in) { + return Status::UserError("Input cannot be NULL"); + } + out = *in + 1; + return Status::OK(); + } +}; +TEST_F(SimpleFunctionTest, callNullableNoThrow) { + registerFunction( + {"nullable_no_throw"}); // Error reported via Status by callNullable. VELOX_ASSERT_THROW( - (evaluateOnce("no_throw(c0)", std::nullopt)), + (evaluateOnce("nullable_no_throw(c0)", std::nullopt)), "Input cannot be NULL"); - result = evaluateOnce("try(no_throw(c0))", std::nullopt); + + result = evaluateOnce( + "try(nullable_no_throw(c0))", std::nullopt); + EXPECT_EQ(std::nullopt, result); +} + +template +struct CallNullFreeNoThrowFunction { + VELOX_DEFINE_FUNCTION_TYPES(TExec); + + Status callNullFree(out_type& out, const arg_type& in) { + if (in == 0) { + return Status::UserError("Input cannot be 0"); + } + if (in % 2 == 0) { + return Status::UserError("Input cannot be even"); + } + out = *in + 1; + return Status::OK(); + } +}; + +TEST_F(SimpleFunctionTest, callNullFreeNoThrow) { + registerFunction( + {"null_free_no_throw"}); + // Error reported via Status by callNullable. + VELOX_ASSERT_THROW( + (evaluateOnce("null_free_no_throw(c0)", 0)), + "Input cannot be 0"); + + result = evaluateOnce("try(null_free_no_throw(c0))", 0); + EXPECT_EQ(std::nullopt, result); + + VELOX_ASSERT_THROW( + (evaluateOnce("null_free_no_throw(c0)", 4)), + "Input cannot be even"); + + result = evaluateOnce("try(null_free_no_throw(c0))", 4); EXPECT_EQ(std::nullopt, result); }