From 2bbeb9c7a6ebab1dd14347fa6a2f0020be9781a9 Mon Sep 17 00:00:00 2001 From: serge-sans-paille Date: Mon, 26 Feb 2024 19:49:24 +0100 Subject: [PATCH] avg --- .../xsimd/arch/generic/xsimd_generic_math.hpp | 34 ++++++++++++++++++- include/xsimd/arch/xsimd_avx2.hpp | 17 ++++++++++ include/xsimd/arch/xsimd_avx512bw.hpp | 17 ++++++++++ include/xsimd/arch/xsimd_sse2.hpp | 20 +++++++++++ include/xsimd/types/xsimd_api.hpp | 15 ++++++++ test/test_batch.cpp | 24 +++++++++++++ test/test_xsimd_api.cpp | 7 ++++ 7 files changed, 133 insertions(+), 1 deletion(-) diff --git a/include/xsimd/arch/generic/xsimd_generic_math.hpp b/include/xsimd/arch/generic/xsimd_generic_math.hpp index 8fa887dc5..d99196f61 100644 --- a/include/xsimd/arch/generic/xsimd_generic_math.hpp +++ b/include/xsimd/arch/generic/xsimd_generic_math.hpp @@ -26,7 +26,7 @@ namespace xsimd using namespace types; // abs - template ::value, void>::type*/> + template inline batch abs(batch const& self, requires_arch) noexcept { if (std::is_unsigned::value) @@ -45,6 +45,38 @@ namespace xsimd return hypot(z.real(), z.imag()); } + // avg + namespace detail { + template + inline batch avg(batch const& x, batch const& y, std::true_type, std::false_type) noexcept + { + return (incr(x) + y) >> 1; + } + + template + inline batch avg(batch const& x, batch const& y, std::true_type, std::true_type) noexcept + { + using uT = typename std::make_unsigned::type; + // Inspired by + // https://stackoverflow.com/questions/5697500/take-the-average-of-two-signed-numbers-in-c + return select((x ^ y) < 0, + (x + y) / 2, + bitwise_cast(avg(decr(bitwise_cast(x)), bitwise_cast(y))) + ); + } + + template + inline batch avg(batch const& x, batch const& y, std::false_type, std::true_type) noexcept + { + return (x + y) / 2; + } + } + + template + inline batch avg(batch const& x, batch const& y, requires_arch) noexcept { + return detail::avg(x, y, typename std::is_integral::type{}, typename std::is_signed::type{}); + } + // batch_cast template inline batch batch_cast(batch const& self, batch const&, requires_arch) noexcept diff --git a/include/xsimd/arch/xsimd_avx2.hpp b/include/xsimd/arch/xsimd_avx2.hpp index a5b07ec9d..452ac4855 100644 --- a/include/xsimd/arch/xsimd_avx2.hpp +++ b/include/xsimd/arch/xsimd_avx2.hpp @@ -76,6 +76,23 @@ namespace xsimd } } + // avg + template ::value, void>::type> + inline batch avg(batch const& self, batch const& other, requires_arch) noexcept + { + XSIMD_IF_CONSTEXPR(sizeof(T) == 1) + { + return _mm256_avg_epu8(self, other); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 2) + { + return _mm256_avg_epu16(self, other); + } + else { + return avg(self, other, generic{}); + } + } + // bitwise_and template ::value, void>::type> inline batch bitwise_and(batch const& self, batch const& other, requires_arch) noexcept diff --git a/include/xsimd/arch/xsimd_avx512bw.hpp b/include/xsimd/arch/xsimd_avx512bw.hpp index 94a194dab..4d113a96c 100644 --- a/include/xsimd/arch/xsimd_avx512bw.hpp +++ b/include/xsimd/arch/xsimd_avx512bw.hpp @@ -112,6 +112,23 @@ namespace xsimd } } + // avg + template ::value, void>::type> + inline batch avg(batch const& self, batch const& other, requires_arch) noexcept + { + XSIMD_IF_CONSTEXPR(sizeof(T) == 1) + { + return _mm512_avg_epu8(self, other); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 2) + { + return _mm512_avg_epu16(self, other); + } + else { + return avg(self, other, generic{}); + } + } + // bitwise_lshift template ::value, void>::type> inline batch bitwise_lshift(batch const& self, int32_t other, requires_arch) noexcept diff --git a/include/xsimd/arch/xsimd_sse2.hpp b/include/xsimd/arch/xsimd_sse2.hpp index 0a34cb1e9..0eb4b8809 100644 --- a/include/xsimd/arch/xsimd_sse2.hpp +++ b/include/xsimd/arch/xsimd_sse2.hpp @@ -60,6 +60,9 @@ namespace xsimd inline batch insert(batch const& self, T val, index, requires_arch) noexcept; template inline batch shuffle(batch const& x, batch const& y, batch_constant, Indices...>, requires_arch) noexcept; + template + inline batch avg(batch const& , batch const&, requires_arch) noexcept; + // abs template @@ -148,6 +151,23 @@ namespace xsimd return _mm_movemask_epi8(self) != 0; } + // avg + template ::value, void>::type> + inline batch avg(batch const& self, batch const& other, requires_arch) noexcept + { + XSIMD_IF_CONSTEXPR(sizeof(T) == 1) + { + return _mm_avg_epu8(self, other); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 2) + { + return _mm_avg_epu16(self, other); + } + else { + return avg(self, other, generic{}); + } + } + // batch_bool_cast template inline batch_bool batch_bool_cast(batch_bool const& self, batch_bool const&, requires_arch) noexcept diff --git a/include/xsimd/types/xsimd_api.hpp b/include/xsimd/types/xsimd_api.hpp index 0420f0a09..0bedf5fb4 100644 --- a/include/xsimd/types/xsimd_api.hpp +++ b/include/xsimd/types/xsimd_api.hpp @@ -202,6 +202,21 @@ namespace xsimd return kernel::atanh(x, A {}); } + /** + * @ingroup batch_math + * + * Computes the average of batches \c x and \c y + * @param x batch of T + * @param y batch of T + * @return the average of elements between \c x and \c y. + */ + template + inline batch avg(batch const& x, batch const& y) noexcept + { + detail::static_check_supported_config(); + return kernel::avg(x, y, A {}); + } + /** * @ingroup batch_conversion * diff --git a/test/test_batch.cpp b/test/test_batch.cpp index 9c3217361..ab058f08d 100644 --- a/test/test_batch.cpp +++ b/test/test_batch.cpp @@ -737,6 +737,25 @@ struct batch_test } } + void test_avg() const + { + array_type expected; + std::transform(lhs.cbegin(), lhs.cend(), rhs.cbegin(), expected.begin(), + [](const value_type &l, const value_type &r) -> value_type { + if (std::is_integral::value) { + if (std::is_signed::value) + return ((long long)l + r) / 2; + else + return (l + r + 1) / 2; + } else { + return (l + r) / 2; + } + }); + batch_type res = avg(batch_lhs(), batch_rhs()); + INFO("avg"); + CHECK_BATCH_EQ(res, expected); + } + void test_horizontal_operations() const { // reduce_add @@ -938,6 +957,11 @@ TEST_CASE_TEMPLATE("[batch]", B, BATCH_TYPES) Test.test_abs(); } + SUBCASE("avg") + { + Test.test_avg(); + } + SUBCASE("horizontal_operations") { Test.test_horizontal_operations(); diff --git a/test/test_xsimd_api.cpp b/test/test_xsimd_api.cpp index 440ef015c..f6215d656 100644 --- a/test/test_xsimd_api.cpp +++ b/test/test_xsimd_api.cpp @@ -1156,6 +1156,13 @@ struct xsimd_api_all_types_functions CHECK_EQ(extract(xsimd::add(T(val0), T(val1))), val0 + val1); } + void test_avg() + { + value_type val0(1); + value_type val1(3); + CHECK_EQ(extract(xsimd::avg(T(val0), T(val1))), (val0 + val1) / value_type(2)); + } + void test_decr() { value_type val0(1);