Skip to content

Commit

Permalink
avg
Browse files Browse the repository at this point in the history
  • Loading branch information
serge-sans-paille committed Feb 26, 2024
1 parent 836b4c3 commit 2bbeb9c
Show file tree
Hide file tree
Showing 7 changed files with 133 additions and 1 deletion.
34 changes: 33 additions & 1 deletion include/xsimd/arch/generic/xsimd_generic_math.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ namespace xsimd

using namespace types;
// abs
template <class A, class T, class /*=typename std::enable_if<std::is_integral<T>::value, void>::type*/>
template <class A, class T, class>
inline batch<T, A> abs(batch<T, A> const& self, requires_arch<generic>) noexcept
{
if (std::is_unsigned<T>::value)
Expand All @@ -45,6 +45,38 @@ namespace xsimd
return hypot(z.real(), z.imag());
}

// avg
namespace detail {
template <class A, class T>
inline batch<T, A> avg(batch<T, A> const& x, batch<T, A> const& y, std::true_type, std::false_type) noexcept
{
return (incr(x) + y) >> 1;
}

template <class A, class T>
inline batch<T, A> avg(batch<T, A> const& x, batch<T, A> const& y, std::true_type, std::true_type) noexcept
{
using uT = typename std::make_unsigned<T>::type;
// Inspired by
// https://stackoverflow.com/questions/5697500/take-the-average-of-two-signed-numbers-in-c
return select((x ^ y) < 0,
(x + y) / 2,
bitwise_cast<T>(avg(decr(bitwise_cast<uT>(x)), bitwise_cast<uT>(y)))
);
}

template <class A, class T>
inline batch<T, A> avg(batch<T, A> const& x, batch<T, A> const& y, std::false_type, std::true_type) noexcept
{
return (x + y) / 2;
}
}

template <class A, class T>
inline batch<T, A> avg(batch<T, A> const& x, batch<T, A> const& y, requires_arch<generic>) noexcept {
return detail::avg(x, y, typename std::is_integral<T>::type{}, typename std::is_signed<T>::type{});
}

// batch_cast
template <class A, class T>
inline batch<T, A> batch_cast(batch<T, A> const& self, batch<T, A> const&, requires_arch<generic>) noexcept
Expand Down
17 changes: 17 additions & 0 deletions include/xsimd/arch/xsimd_avx2.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,23 @@ namespace xsimd
}
}

// avg
template <class A, class T, class = typename std::enable_if<std::is_unsigned<T>::value, void>::type>
inline batch<T, A> avg(batch<T, A> const& self, batch<T, A> const& other, requires_arch<avx2>) noexcept
{
XSIMD_IF_CONSTEXPR(sizeof(T) == 1)
{
return _mm256_avg_epu8(self, other);
}
else XSIMD_IF_CONSTEXPR(sizeof(T) == 2)
{
return _mm256_avg_epu16(self, other);
}
else {
return avg(self, other, generic{});
}
}

// bitwise_and
template <class A, class T, class = typename std::enable_if<std::is_integral<T>::value, void>::type>
inline batch<T, A> bitwise_and(batch<T, A> const& self, batch<T, A> const& other, requires_arch<avx2>) noexcept
Expand Down
17 changes: 17 additions & 0 deletions include/xsimd/arch/xsimd_avx512bw.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,23 @@ namespace xsimd
}
}

// avg
template <class A, class T, class = typename std::enable_if<std::is_unsigned<T>::value, void>::type>
inline batch<T, A> avg(batch<T, A> const& self, batch<T, A> const& other, requires_arch<avx512bw>) noexcept
{
XSIMD_IF_CONSTEXPR(sizeof(T) == 1)
{
return _mm512_avg_epu8(self, other);
}
else XSIMD_IF_CONSTEXPR(sizeof(T) == 2)
{
return _mm512_avg_epu16(self, other);
}
else {
return avg(self, other, generic{});
}
}

// bitwise_lshift
template <class A, class T, class = typename std::enable_if<std::is_integral<T>::value, void>::type>
inline batch<T, A> bitwise_lshift(batch<T, A> const& self, int32_t other, requires_arch<avx512bw>) noexcept
Expand Down
20 changes: 20 additions & 0 deletions include/xsimd/arch/xsimd_sse2.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ namespace xsimd
inline batch<T, A> insert(batch<T, A> const& self, T val, index<I>, requires_arch<generic>) noexcept;
template <class A, typename T, typename ITy, ITy... Indices>
inline batch<T, A> shuffle(batch<T, A> const& x, batch<T, A> const& y, batch_constant<batch<ITy, A>, Indices...>, requires_arch<generic>) noexcept;
template <class A, class T>
inline batch<T, A> avg(batch<T, A> const& , batch<T, A> const&, requires_arch<generic>) noexcept;


// abs
template <class A>
Expand Down Expand Up @@ -148,6 +151,23 @@ namespace xsimd
return _mm_movemask_epi8(self) != 0;
}

// avg
template <class A, class T, class = typename std::enable_if<std::is_unsigned<T>::value, void>::type>
inline batch<T, A> avg(batch<T, A> const& self, batch<T, A> const& other, requires_arch<sse2>) noexcept
{
XSIMD_IF_CONSTEXPR(sizeof(T) == 1)
{
return _mm_avg_epu8(self, other);
}
else XSIMD_IF_CONSTEXPR(sizeof(T) == 2)
{
return _mm_avg_epu16(self, other);
}
else {
return avg(self, other, generic{});
}
}

// batch_bool_cast
template <class A, class T_out, class T_in>
inline batch_bool<T_out, A> batch_bool_cast(batch_bool<T_in, A> const& self, batch_bool<T_out, A> const&, requires_arch<sse2>) noexcept
Expand Down
15 changes: 15 additions & 0 deletions include/xsimd/types/xsimd_api.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,21 @@ namespace xsimd
return kernel::atanh<A>(x, A {});
}

/**
* @ingroup batch_math
*
* Computes the average of batches \c x and \c y
* @param x batch of T
* @param y batch of T
* @return the average of elements between \c x and \c y.
*/
template <class T, class A>
inline batch<T, A> avg(batch<T, A> const& x, batch<T, A> const& y) noexcept
{
detail::static_check_supported_config<T, A>();
return kernel::avg<A>(x, y, A {});
}

/**
* @ingroup batch_conversion
*
Expand Down
24 changes: 24 additions & 0 deletions test/test_batch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -737,6 +737,25 @@ struct batch_test
}
}

void test_avg() const
{
array_type expected;
std::transform(lhs.cbegin(), lhs.cend(), rhs.cbegin(), expected.begin(),
[](const value_type &l, const value_type &r) -> value_type {
if (std::is_integral<value_type>::value) {
if (std::is_signed<value_type>::value)
return ((long long)l + r) / 2;
else
return (l + r + 1) / 2;
} else {
return (l + r) / 2;
}
});
batch_type res = avg(batch_lhs(), batch_rhs());
INFO("avg");
CHECK_BATCH_EQ(res, expected);
}

void test_horizontal_operations() const
{
// reduce_add
Expand Down Expand Up @@ -938,6 +957,11 @@ TEST_CASE_TEMPLATE("[batch]", B, BATCH_TYPES)
Test.test_abs();
}

SUBCASE("avg")
{
Test.test_avg();
}

SUBCASE("horizontal_operations")
{
Test.test_horizontal_operations();
Expand Down
7 changes: 7 additions & 0 deletions test/test_xsimd_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1156,6 +1156,13 @@ struct xsimd_api_all_types_functions
CHECK_EQ(extract(xsimd::add(T(val0), T(val1))), val0 + val1);
}

void test_avg()
{
value_type val0(1);
value_type val1(3);
CHECK_EQ(extract(xsimd::avg(T(val0), T(val1))), (val0 + val1) / value_type(2));
}

void test_decr()
{
value_type val0(1);
Expand Down

0 comments on commit 2bbeb9c

Please sign in to comment.