Skip to content

Commit

Permalink
Support basic operations on batch constant
Browse files Browse the repository at this point in the history
Add support for &&, ||, ^ and ! for batch_bool_constant.
Add support for +, -, *, / and unary - for batch_constant.

Fix #930
  • Loading branch information
serge-sans-paille committed Oct 11, 2023
1 parent b0668e4 commit 9dc3d60
Show file tree
Hide file tree
Showing 2 changed files with 247 additions and 3 deletions.
141 changes: 141 additions & 0 deletions include/xsimd/types/xsimd_batch_constant.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ namespace xsimd
template <class batch_type, bool... Values>
struct batch_bool_constant
{

public:
static constexpr std::size_t size = sizeof...(Values);
using arch_type = typename batch_type::arch_type;
using value_type = bool;
Expand All @@ -47,11 +49,67 @@ namespace xsimd

private:
static constexpr int mask_helper(int acc) noexcept { return acc; }

template <class... Tys>
static constexpr int mask_helper(int acc, int mask, Tys... masks) noexcept
{
return mask_helper(acc | mask, (masks << 1)...);
}

struct logical_or
{
constexpr bool operator()(bool x, bool y) const { return x || y; }
};
struct logical_and
{
constexpr bool operator()(bool x, bool y) const { return x && y; }
};
struct logical_xor
{
constexpr bool operator()(bool x, bool y) const { return x ^ y; }
};

template <class F, class SelfPack, class OtherPack, size_t... Indices>
static constexpr batch_bool_constant<batch_type, F()(std::tuple_element<Indices, SelfPack>::type::value, std::tuple_element<Indices, OtherPack>::type::value)...>
apply(detail::index_sequence<Indices...>)
{
return {};
}

template <class F, bool... OtherValues>
static constexpr auto apply(batch_bool_constant<batch_type, Values...>, batch_bool_constant<batch_type, OtherValues...>)
-> decltype(apply<F, std::tuple<std::integral_constant<bool, Values>...>, std::tuple<std::integral_constant<bool, OtherValues>...>>(detail::make_index_sequence<sizeof...(Values)>()))
{
static_assert(sizeof...(Values) == sizeof...(OtherValues), "compatible constant batches");
return apply<F, std::tuple<std::integral_constant<bool, Values>...>, std::tuple<std::integral_constant<bool, OtherValues>...>>(detail::make_index_sequence<sizeof...(Values)>());
}

public:
#define MAKE_BINARY_OP(OP, NAME) \
template <bool... OtherValues> \
constexpr auto operator OP(batch_bool_constant<batch_type, OtherValues...> other) const \
->decltype(apply<NAME>(*this, other)) \
{ \
return apply<NAME>(*this, other); \
}

MAKE_BINARY_OP(|, logical_or)
MAKE_BINARY_OP(||, logical_or)
MAKE_BINARY_OP(&, logical_and)
MAKE_BINARY_OP(&&, logical_and)
MAKE_BINARY_OP(^, logical_xor)

#undef MAKE_BINARY_OP

constexpr batch_bool_constant<batch_type, !Values...> operator!() const
{
return {};
}

constexpr batch_bool_constant<batch_type, !Values...> operator~() const
{
return {};
}
};

/**
Expand Down Expand Up @@ -88,6 +146,89 @@ namespace xsimd
{
return values[i];
}

struct arithmetic_add
{
constexpr value_type operator()(value_type x, value_type y) const { return x + y; }
};
struct arithmetic_sub
{
constexpr value_type operator()(value_type x, value_type y) const { return x - y; }
};
struct arithmetic_mul
{
constexpr value_type operator()(value_type x, value_type y) const { return x * y; }
};
struct arithmetic_div
{
constexpr value_type operator()(value_type x, value_type y) const { return x / y; }
};
struct arithmetic_mod
{
constexpr value_type operator()(value_type x, value_type y) const { return x % y; }
};
struct binary_and
{
constexpr value_type operator()(value_type x, value_type y) const { return x & y; }
};
struct binary_or
{
constexpr value_type operator()(value_type x, value_type y) const { return x | y; }
};
struct binary_xor
{
constexpr value_type operator()(value_type x, value_type y) const { return x ^ y; }
};

template <class F, class SelfPack, class OtherPack, size_t... Indices>
static constexpr batch_constant<batch_type, F()(std::tuple_element<Indices, SelfPack>::type::value, std::tuple_element<Indices, OtherPack>::type::value)...>
apply(detail::index_sequence<Indices...>)
{
return {};
}

template <class F, value_type... OtherValues>
static constexpr auto apply(batch_constant<batch_type, Values...>, batch_constant<batch_type, OtherValues...>)
-> decltype(apply<F, std::tuple<std::integral_constant<value_type, Values>...>, std::tuple<std::integral_constant<value_type, OtherValues>...>>(detail::make_index_sequence<sizeof...(Values)>()))
{
static_assert(sizeof...(Values) == sizeof...(OtherValues), "compatible constant batches");
return apply<F, std::tuple<std::integral_constant<value_type, Values>...>, std::tuple<std::integral_constant<value_type, OtherValues>...>>(detail::make_index_sequence<sizeof...(Values)>());
}

public:
#define MAKE_BINARY_OP(OP, NAME) \
template <value_type... OtherValues> \
constexpr auto operator OP(batch_constant<batch_type, OtherValues...> other) const \
->decltype(apply<NAME>(*this, other)) \
{ \
return apply<NAME>(*this, other); \
}

MAKE_BINARY_OP(+, arithmetic_add)
MAKE_BINARY_OP(-, arithmetic_sub)
MAKE_BINARY_OP(*, arithmetic_mul)
MAKE_BINARY_OP(/, arithmetic_div)
MAKE_BINARY_OP(%, arithmetic_mod)
MAKE_BINARY_OP(&, binary_and)
MAKE_BINARY_OP(|, binary_or)
MAKE_BINARY_OP(^, binary_xor)

#undef MAKE_BINARY_OP

constexpr batch_constant<batch_type, (value_type)-Values...> operator-() const
{
return {};
}

constexpr batch_constant<batch_type, (value_type) + Values...> operator+() const
{
return {};
}

constexpr batch_constant<batch_type, (value_type)~Values...> operator~() const
{
return {};
}
};

namespace detail
Expand Down
109 changes: 106 additions & 3 deletions test/test_batch_constant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,22 +64,69 @@ struct constant_batch_test
CHECK_BATCH_EQ((batch_type)b, expected);
}

template <value_type V>
struct constant
{
static constexpr value_type get(size_t /*index*/, size_t /*size*/)
{
return 3;
return V;
}
};

void test_init_from_constant() const
{
array_type expected;
std::fill(expected.begin(), expected.end(), constant::get(0, 0));
constexpr auto b = xsimd::make_batch_constant<batch_type, constant>();
std::fill(expected.begin(), expected.end(), constant<3>::get(0, 0));
constexpr auto b = xsimd::make_batch_constant<batch_type, constant<3>>();
INFO("batch(value_type)");
CHECK_BATCH_EQ((batch_type)b, expected);
}

void test_ops() const
{
constexpr auto n12 = xsimd::make_batch_constant<batch_type, constant<12>>();
constexpr auto n3 = xsimd::make_batch_constant<batch_type, constant<3>>();

constexpr auto n12_add_n3 = n12 + n3;
constexpr auto n15 = xsimd::make_batch_constant<batch_type, constant<15>>();
static_assert(std::is_same<decltype(n12_add_n3), decltype(n15)>::value, "n12 + n3 == n15");

constexpr auto n12_sub_n3 = n12 - n3;
constexpr auto n9 = xsimd::make_batch_constant<batch_type, constant<9>>();
static_assert(std::is_same<decltype(n12_sub_n3), decltype(n9)>::value, "n12 - n3 == n9");

constexpr auto n12_mul_n3 = n12 * n3;
constexpr auto n36 = xsimd::make_batch_constant<batch_type, constant<36>>();
static_assert(std::is_same<decltype(n12_mul_n3), decltype(n36)>::value, "n12 * n3 == n36");

constexpr auto n12_div_n3 = n12 / n3;
constexpr auto n4 = xsimd::make_batch_constant<batch_type, constant<4>>();
static_assert(std::is_same<decltype(n12_div_n3), decltype(n4)>::value, "n12 / n3 == n4");

constexpr auto n12_mod_n3 = n12 % n3;
constexpr auto n0 = xsimd::make_batch_constant<batch_type, constant<0>>();
static_assert(std::is_same<decltype(n12_mod_n3), decltype(n0)>::value, "n12 % n3 == n0");

constexpr auto n12_land_n3 = n12 & n3;
static_assert(std::is_same<decltype(n12_land_n3), decltype(n0)>::value, "n12 & n3 == n0");

constexpr auto n12_lor_n3 = n12 | n3;
static_assert(std::is_same<decltype(n12_lor_n3), decltype(n15)>::value, "n12 | n3 == n15");

constexpr auto n12_lxor_n3 = n12 ^ n3;
static_assert(std::is_same<decltype(n12_lxor_n3), decltype(n15)>::value, "n12 ^ n3 == n15");

constexpr auto n12_uadd = +n12;
static_assert(std::is_same<decltype(n12_uadd), decltype(n12)>::value, "+n12 == n12");

constexpr auto n12_inv = ~n12;
constexpr auto n12_inv_ = xsimd::make_batch_constant<batch_type, constant<(value_type)~12>>();
static_assert(std::is_same<decltype(n12_inv), decltype(n12_inv_)>::value, "~n12 == n12_inv");

constexpr auto n12_usub = -n12;
constexpr auto n12_usub_ = xsimd::make_batch_constant<batch_type, constant<(value_type)-12>>();
static_assert(std::is_same<decltype(n12_inv), decltype(n12_inv_)>::value, "-n12 == n12_usub");
}
};

TEST_CASE_TEMPLATE("[constant batch]", B, BATCH_INT_TYPES)
Expand All @@ -93,6 +140,11 @@ TEST_CASE_TEMPLATE("[constant batch]", B, BATCH_INT_TYPES)
}

SUBCASE("init_from_constant") { Test.test_init_from_constant(); }

SUBCASE("operators")
{
Test.test_ops();
}
}

template <class B>
Expand Down Expand Up @@ -144,6 +196,53 @@ struct constant_bool_batch_test
INFO("batch_bool_constant(value_type)");
CHECK_BATCH_EQ((batch_bool_type)b, expected);
}

struct inv_split
{
static constexpr bool get(size_t index, size_t size)
{
return !split().get(index, size);
}
};

template <bool Val>
struct constant
{
static constexpr bool get(size_t /*index*/, size_t /*size*/)
{
return Val;
}
};

void test_ops() const
{
constexpr auto all_true = xsimd::make_batch_bool_constant<batch_type, constant<true>>();
constexpr auto all_false = xsimd::make_batch_bool_constant<batch_type, constant<false>>();

constexpr auto x = xsimd::make_batch_bool_constant<batch_type, split>();
constexpr auto y = xsimd::make_batch_bool_constant<batch_type, inv_split>();

constexpr auto x_or_y = x | y;
static_assert(std::is_same<decltype(x_or_y), decltype(all_true)>::value, "x | y == true");

constexpr auto x_lor_y = x || y;
static_assert(std::is_same<decltype(x_lor_y), decltype(all_true)>::value, "x || y == true");

constexpr auto x_and_y = x & y;
static_assert(std::is_same<decltype(x_and_y), decltype(all_false)>::value, "x & y == false");

constexpr auto x_land_y = x && y;
static_assert(std::is_same<decltype(x_land_y), decltype(all_false)>::value, "x && y == false");

constexpr auto x_xor_y = x ^ y;
static_assert(std::is_same<decltype(x_xor_y), decltype(all_true)>::value, "x ^ y == true");

constexpr auto not_x = !x;
static_assert(std::is_same<decltype(not_x), decltype(y)>::value, "!x == y");

constexpr auto inv_x = ~x;
static_assert(std::is_same<decltype(inv_x), decltype(y)>::value, "~x == y");
}
};

TEST_CASE_TEMPLATE("[constant bool batch]", B, BATCH_INT_TYPES)
Expand All @@ -155,5 +254,9 @@ TEST_CASE_TEMPLATE("[constant bool batch]", B, BATCH_INT_TYPES)
{
Test.test_init_from_generator_split();
}
SUBCASE("operators")
{
Test.test_ops();
}
}
#endif

0 comments on commit 9dc3d60

Please sign in to comment.