From 9628777479a970db5d0c2d0b456dac6633864760 Mon Sep 17 00:00:00 2001 From: PaulXiCao Date: Tue, 23 Jul 2024 15:11:44 +0000 Subject: [PATCH] [libc++][math] Fix undue overflowing of `std::hypot(x,y,z)` (#93350) The 3-dimentionsional `std::hypot(x,y,z)` was sub-optimally implemented. This lead to possible over-/underflows in (intermediate) results which can be circumvented by this proposed change. The idea is to to scale the arguments (see linked issue for full discussion). Tests have been added for problematic over- and underflows. Closes #92782 --- libcxx/include/__math/hypot.h | 89 ++++++++++++++++++ libcxx/include/cmath | 25 +---- .../test/libcxx/transitive_includes/cxx17.csv | 3 + .../test/libcxx/transitive_includes/cxx20.csv | 3 + .../test/libcxx/transitive_includes/cxx23.csv | 3 + .../test/libcxx/transitive_includes/cxx26.csv | 3 + .../test/std/numerics/c.math/cmath.pass.cpp | 91 +++++++++++++++---- libcxx/test/support/fp_compare.h | 45 ++++----- 8 files changed, 197 insertions(+), 65 deletions(-) diff --git a/libcxx/include/__math/hypot.h b/libcxx/include/__math/hypot.h index 1bf193a9ab7ee9..61fd260c594095 100644 --- a/libcxx/include/__math/hypot.h +++ b/libcxx/include/__math/hypot.h @@ -15,10 +15,21 @@ #include <__type_traits/is_same.h> #include <__type_traits/promote.h> +#if _LIBCPP_STD_VER >= 17 +# include <__algorithm/max.h> +# include <__math/abs.h> +# include <__math/roots.h> +# include <__utility/pair.h> +# include +#endif + #if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER) # pragma GCC system_header #endif +_LIBCPP_PUSH_MACROS +#include <__undef_macros> + _LIBCPP_BEGIN_NAMESPACE_STD namespace __math { @@ -41,8 +52,86 @@ inline _LIBCPP_HIDE_FROM_ABI typename __promote<_A1, _A2>::type hypot(_A1 __x, _ return __math::hypot((__result_type)__x, (__result_type)__y); } +#if _LIBCPP_STD_VER >= 17 +// Factors needed to determine if over-/underflow might happen for `std::hypot(x,y,z)`. +// returns [overflow_threshold, overflow_scale] +template +_LIBCPP_HIDE_FROM_ABI std::pair<_Real, _Real> __hypot_factors() { + static_assert(std::numeric_limits<_Real>::is_iec559); + + if constexpr (std::is_same_v<_Real, float>) { + static_assert(-125 == std::numeric_limits<_Real>::min_exponent); + static_assert(+128 == std::numeric_limits<_Real>::max_exponent); + return {0x1.0p+62f, 0x1.0p-70f}; + } else if constexpr (std::is_same_v<_Real, double>) { + static_assert(-1021 == std::numeric_limits<_Real>::min_exponent); + static_assert(+1024 == std::numeric_limits<_Real>::max_exponent); + return {0x1.0p+510, 0x1.0p-600}; + } else { // long double + static_assert(std::is_same_v<_Real, long double>); + + // preprocessor guard necessary, otherwise literals (e.g. `0x1.0p+8'190l`) throw warnings even when shielded by `if + // constexpr` +# if __DBL_MAX_EXP__ == __LDBL_MAX_EXP__ + static_assert(sizeof(_Real) == sizeof(double)); + return static_cast>(__math::__hypot_factors()); +# else + static_assert(sizeof(_Real) > sizeof(double)); + static_assert(-16381 == std::numeric_limits<_Real>::min_exponent); + static_assert(+16384 == std::numeric_limits<_Real>::max_exponent); + return {0x1.0p+8190l, 0x1.0p-9000l}; +# endif + } +} + +// Computes the three-dimensional hypotenuse: `std::hypot(x,y,z)`. +// The naive implementation might over-/underflow which is why this implementation is more involved: +// If the square of an argument might run into issues, we scale the arguments appropriately. +// See https://github.com/llvm/llvm-project/issues/92782 for a detailed discussion and summary. +template +_LIBCPP_HIDE_FROM_ABI _Real __hypot(_Real __x, _Real __y, _Real __z) { + const _Real __max_abs = std::max(__math::fabs(__x), std::max(__math::fabs(__y), __math::fabs(__z))); + const auto [__overflow_threshold, __overflow_scale] = __math::__hypot_factors<_Real>(); + _Real __scale; + if (__max_abs > __overflow_threshold) { // x*x + y*y + z*z might overflow + __scale = __overflow_scale; + __x *= __scale; + __y *= __scale; + __z *= __scale; + } else if (__max_abs < 1 / __overflow_threshold) { // x*x + y*y + z*z might underflow + __scale = 1 / __overflow_scale; + __x *= __scale; + __y *= __scale; + __z *= __scale; + } else + __scale = 1; + return __math::sqrt(__x * __x + __y * __y + __z * __z) / __scale; +} + +inline _LIBCPP_HIDE_FROM_ABI float hypot(float __x, float __y, float __z) { return __math::__hypot(__x, __y, __z); } + +inline _LIBCPP_HIDE_FROM_ABI double hypot(double __x, double __y, double __z) { return __math::__hypot(__x, __y, __z); } + +inline _LIBCPP_HIDE_FROM_ABI long double hypot(long double __x, long double __y, long double __z) { + return __math::__hypot(__x, __y, __z); +} + +template && is_arithmetic_v<_A2> && is_arithmetic_v<_A3>, int> = 0 > +_LIBCPP_HIDE_FROM_ABI typename __promote<_A1, _A2, _A3>::type hypot(_A1 __x, _A2 __y, _A3 __z) _NOEXCEPT { + using __result_type = typename __promote<_A1, _A2, _A3>::type; + static_assert(!( + std::is_same_v<_A1, __result_type> && std::is_same_v<_A2, __result_type> && std::is_same_v<_A3, __result_type>)); + return __math::__hypot( + static_cast<__result_type>(__x), static_cast<__result_type>(__y), static_cast<__result_type>(__z)); +} +#endif + } // namespace __math _LIBCPP_END_NAMESPACE_STD +_LIBCPP_POP_MACROS #endif // _LIBCPP___MATH_HYPOT_H diff --git a/libcxx/include/cmath b/libcxx/include/cmath index 3c22604a683c33..6480c4678ce33d 100644 --- a/libcxx/include/cmath +++ b/libcxx/include/cmath @@ -313,6 +313,7 @@ constexpr long double lerp(long double a, long double b, long double t) noexcept */ #include <__config> +#include <__math/hypot.h> #include <__type_traits/enable_if.h> #include <__type_traits/is_arithmetic.h> #include <__type_traits/is_constant_evaluated.h> @@ -553,30 +554,6 @@ using ::scalbnl _LIBCPP_USING_IF_EXISTS; using ::tgammal _LIBCPP_USING_IF_EXISTS; using ::truncl _LIBCPP_USING_IF_EXISTS; -#if _LIBCPP_STD_VER >= 17 -inline _LIBCPP_HIDE_FROM_ABI float hypot(float __x, float __y, float __z) { - return sqrt(__x * __x + __y * __y + __z * __z); -} -inline _LIBCPP_HIDE_FROM_ABI double hypot(double __x, double __y, double __z) { - return sqrt(__x * __x + __y * __y + __z * __z); -} -inline _LIBCPP_HIDE_FROM_ABI long double hypot(long double __x, long double __y, long double __z) { - return sqrt(__x * __x + __y * __y + __z * __z); -} - -template -inline _LIBCPP_HIDE_FROM_ABI -typename enable_if_t< is_arithmetic<_A1>::value && is_arithmetic<_A2>::value && is_arithmetic<_A3>::value, - __promote<_A1, _A2, _A3> >::type -hypot(_A1 __lcpp_x, _A2 __lcpp_y, _A3 __lcpp_z) _NOEXCEPT { - typedef typename __promote<_A1, _A2, _A3>::type __result_type; - static_assert( - !(is_same<_A1, __result_type>::value && is_same<_A2, __result_type>::value && is_same<_A3, __result_type>::value), - ""); - return std::hypot((__result_type)__lcpp_x, (__result_type)__lcpp_y, (__result_type)__lcpp_z); -} -#endif - template ::value, int> = 0> _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR bool __constexpr_isnan(_A1 __lcpp_x) _NOEXCEPT { #if __has_builtin(__builtin_isnan) diff --git a/libcxx/test/libcxx/transitive_includes/cxx17.csv b/libcxx/test/libcxx/transitive_includes/cxx17.csv index 2c028462144eee..8099d2b79c4bee 100644 --- a/libcxx/test/libcxx/transitive_includes/cxx17.csv +++ b/libcxx/test/libcxx/transitive_includes/cxx17.csv @@ -130,6 +130,9 @@ chrono type_traits chrono vector chrono version cinttypes cstdint +cmath cstddef +cmath cstdint +cmath initializer_list cmath limits cmath type_traits cmath version diff --git a/libcxx/test/libcxx/transitive_includes/cxx20.csv b/libcxx/test/libcxx/transitive_includes/cxx20.csv index 982c2013e34170..384e51b101f311 100644 --- a/libcxx/test/libcxx/transitive_includes/cxx20.csv +++ b/libcxx/test/libcxx/transitive_includes/cxx20.csv @@ -135,6 +135,9 @@ chrono type_traits chrono vector chrono version cinttypes cstdint +cmath cstddef +cmath cstdint +cmath initializer_list cmath limits cmath type_traits cmath version diff --git a/libcxx/test/libcxx/transitive_includes/cxx23.csv b/libcxx/test/libcxx/transitive_includes/cxx23.csv index 8ffb71d8b566b0..46b833d143f39a 100644 --- a/libcxx/test/libcxx/transitive_includes/cxx23.csv +++ b/libcxx/test/libcxx/transitive_includes/cxx23.csv @@ -83,6 +83,9 @@ chrono string_view chrono vector chrono version cinttypes cstdint +cmath cstddef +cmath cstdint +cmath initializer_list cmath limits cmath version codecvt cctype diff --git a/libcxx/test/libcxx/transitive_includes/cxx26.csv b/libcxx/test/libcxx/transitive_includes/cxx26.csv index 8ffb71d8b566b0..46b833d143f39a 100644 --- a/libcxx/test/libcxx/transitive_includes/cxx26.csv +++ b/libcxx/test/libcxx/transitive_includes/cxx26.csv @@ -83,6 +83,9 @@ chrono string_view chrono vector chrono version cinttypes cstdint +cmath cstddef +cmath cstdint +cmath initializer_list cmath limits cmath version codecvt cctype diff --git a/libcxx/test/std/numerics/c.math/cmath.pass.cpp b/libcxx/test/std/numerics/c.math/cmath.pass.cpp index 93790844997923..19b5fd0cf89966 100644 --- a/libcxx/test/std/numerics/c.math/cmath.pass.cpp +++ b/libcxx/test/std/numerics/c.math/cmath.pass.cpp @@ -12,14 +12,17 @@ // +#include #include #include #include #include +#include "fp_compare.h" #include "test_macros.h" #include "hexfloat.h" #include "truncate_fp.h" +#include "type_algorithms.h" // convertible to int/float/double/etc template @@ -1113,6 +1116,56 @@ void test_fmin() assert(std::fmin(1,0) == 0); } +#if TEST_STD_VER >= 17 +struct TestHypot3 { + template + void operator()() const { + const auto check = [](Real elem, Real abs_tol) { + assert(std::isfinite(std::hypot(elem, Real(0), Real(0)))); + assert(fptest_close(std::hypot(elem, Real(0), Real(0)), elem, abs_tol)); + assert(std::isfinite(std::hypot(elem, elem, Real(0)))); + assert(fptest_close(std::hypot(elem, elem, Real(0)), std::sqrt(Real(2)) * elem, abs_tol)); + assert(std::isfinite(std::hypot(elem, elem, elem))); + assert(fptest_close(std::hypot(elem, elem, elem), std::sqrt(Real(3)) * elem, abs_tol)); + }; + + { // check for overflow + const auto [elem, abs_tol] = []() -> std::array { + if constexpr (std::is_same_v) + return {1e20f, 1e16f}; + else if constexpr (std::is_same_v) + return {1e300, 1e287}; + else { // long double +# if __DBL_MAX_EXP__ == __LDBL_MAX_EXP__ + return {1e300l, 1e287l}; // 64-bit +# else + return {1e4000l, 1e3985l}; // 80- or 128-bit +# endif + } + }(); + check(elem, abs_tol); + } + + { // check for underflow + const auto [elem, abs_tol] = []() -> std::array { + if constexpr (std::is_same_v) + return {1e-20f, 1e-24f}; + else if constexpr (std::is_same_v) + return {1e-287, 1e-300}; + else { // long double +# if __DBL_MAX_EXP__ == __LDBL_MAX_EXP__ + return {1e-287l, 1e-300l}; // 64-bit +# else + return {1e-3985l, 1e-4000l}; // 80- or 128-bit +# endif + } + }(); + check(elem, abs_tol); + } + } +}; +#endif + void test_hypot() { static_assert((std::is_same::value), ""); @@ -1135,25 +1188,31 @@ void test_hypot() static_assert((std::is_same::value), ""); assert(std::hypot(3,4) == 5); -#if TEST_STD_VER > 14 - static_assert((std::is_same::value), ""); - static_assert((std::is_same::value), ""); - static_assert((std::is_same::value), ""); - static_assert((std::is_same::value), ""); - static_assert((std::is_same::value), ""); - static_assert((std::is_same::value), ""); - static_assert((std::is_same::value), ""); - static_assert((std::is_same::value), ""); - static_assert((std::is_same::value), ""); - static_assert((std::is_same::value), ""); - static_assert((std::is_same::value), ""); - static_assert((std::is_same::value), ""); - static_assert((std::is_same::value), ""); - static_assert((std::is_same::value), ""); - static_assert((std::is_same::value), ""); +#if TEST_STD_VER >= 17 + // clang-format off + static_assert((std::is_same_v)); + static_assert((std::is_same_v)); + static_assert((std::is_same_v)); + static_assert((std::is_same_v)); + static_assert((std::is_same_v)); + static_assert((std::is_same_v)); + static_assert((std::is_same_v)); + static_assert((std::is_same_v)); + static_assert((std::is_same_v)); + static_assert((std::is_same_v)); + static_assert((std::is_same_v)); + static_assert((std::is_same_v)); + static_assert((std::is_same_v)); + static_assert((std::is_same_v)); + static_assert((std::is_same_v)); + // clang-format on assert(std::hypot(2,3,6) == 7); assert(std::hypot(1,4,8) == 9); + + // Check for undue over-/underflows of intermediate results. + // See discussion at https://github.com/llvm/llvm-project/issues/92782. + types::for_each(types::floating_point_types(), TestHypot3()); #endif } diff --git a/libcxx/test/support/fp_compare.h b/libcxx/test/support/fp_compare.h index 1d1933b0bcd813..3088a211dadc3b 100644 --- a/libcxx/test/support/fp_compare.h +++ b/libcxx/test/support/fp_compare.h @@ -9,39 +9,34 @@ #ifndef SUPPORT_FP_COMPARE_H #define SUPPORT_FP_COMPARE_H -#include // for std::abs -#include // for std::max +#include // for std::abs +#include // for std::max #include +#include <__config> // See https://www.boost.org/doc/libs/1_70_0/libs/test/doc/html/boost_test/testing_tools/extended_comparison/floating_point/floating_points_comparison_theory.html -template -bool fptest_close(T val, T expected, T eps) -{ - constexpr T zero = T(0); - assert(eps >= zero); +template +bool fptest_close(T val, T expected, T eps) { + _LIBCPP_CONSTEXPR T zero = T(0); + assert(eps >= zero); - // Handle the zero cases - if (eps == zero) return val == expected; - if (val == zero) return std::abs(expected) <= eps; - if (expected == zero) return std::abs(val) <= eps; + // Handle the zero cases + if (eps == zero) + return val == expected; + if (val == zero) + return std::abs(expected) <= eps; + if (expected == zero) + return std::abs(val) <= eps; - return std::abs(val - expected) < eps - && std::abs(val - expected)/std::abs(val) < eps; + return std::abs(val - expected) < eps && std::abs(val - expected) / std::abs(val) < eps; } -template -bool fptest_close_pct(T val, T expected, T percent) -{ - constexpr T zero = T(0); - assert(percent >= zero); - - // Handle the zero cases - if (percent == zero) return val == expected; - T eps = (percent / T(100)) * std::max(std::abs(val), std::abs(expected)); - - return fptest_close(val, expected, eps); +template +bool fptest_close_pct(T val, T expected, T percent) { + assert(percent >= T(0)); + T eps = (percent / T(100)) * std::max(std::abs(val), std::abs(expected)); + return fptest_close(val, expected, eps); } - #endif // SUPPORT_FP_COMPARE_H