diff --git a/include/xsimd/types/xsimd_batch_constant.hpp b/include/xsimd/types/xsimd_batch_constant.hpp index ce56dab44..bf2b9569e 100644 --- a/include/xsimd/types/xsimd_batch_constant.hpp +++ b/include/xsimd/types/xsimd_batch_constant.hpp @@ -28,6 +28,8 @@ namespace xsimd template struct batch_bool_constant { + + public: static constexpr std::size_t size = sizeof...(Values); using arch_type = typename batch_type::arch_type; using value_type = bool; @@ -47,11 +49,67 @@ namespace xsimd private: static constexpr int mask_helper(int acc) noexcept { return acc; } + template static constexpr int mask_helper(int acc, int mask, Tys... masks) noexcept { return mask_helper(acc | mask, (masks << 1)...); } + + struct logical_or + { + constexpr bool operator()(bool x, bool y) const { return x || y; } + }; + struct logical_and + { + constexpr bool operator()(bool x, bool y) const { return x && y; } + }; + struct logical_xor + { + constexpr bool operator()(bool x, bool y) const { return x ^ y; } + }; + + template + static constexpr batch_bool_constant::type::value, std::tuple_element::type::value)...> + apply(detail::index_sequence) + { + return {}; + } + + template + static constexpr auto apply(batch_bool_constant, batch_bool_constant) + -> decltype(apply...>, std::tuple...>>(detail::make_index_sequence())) + { + static_assert(sizeof...(Values) == sizeof...(OtherValues), "compatible constant batches"); + return apply...>, std::tuple...>>(detail::make_index_sequence()); + } + + public: +#define MAKE_BINARY_OP(OP, NAME) \ + template \ + constexpr auto operator OP(batch_bool_constant other) const \ + ->decltype(apply(*this, other)) \ + { \ + return apply(*this, other); \ + } + + MAKE_BINARY_OP(|, logical_or) + MAKE_BINARY_OP(||, logical_or) + MAKE_BINARY_OP(&, logical_and) + MAKE_BINARY_OP(&&, logical_and) + MAKE_BINARY_OP(^, logical_xor) + +#undef MAKE_BINARY_OP + + constexpr batch_bool_constant operator!() const + { + return {}; + } + + constexpr batch_bool_constant operator~() const + { + return {}; + } }; /** @@ -88,6 +146,89 @@ namespace xsimd { return values[i]; } + + struct arithmetic_add + { + constexpr value_type operator()(value_type x, value_type y) const { return x + y; } + }; + struct arithmetic_sub + { + constexpr value_type operator()(value_type x, value_type y) const { return x - y; } + }; + struct arithmetic_mul + { + constexpr value_type operator()(value_type x, value_type y) const { return x * y; } + }; + struct arithmetic_div + { + constexpr value_type operator()(value_type x, value_type y) const { return x / y; } + }; + struct arithmetic_mod + { + constexpr value_type operator()(value_type x, value_type y) const { return x % y; } + }; + struct binary_and + { + constexpr value_type operator()(value_type x, value_type y) const { return x & y; } + }; + struct binary_or + { + constexpr value_type operator()(value_type x, value_type y) const { return x | y; } + }; + struct binary_xor + { + constexpr value_type operator()(value_type x, value_type y) const { return x ^ y; } + }; + + template + static constexpr batch_constant::type::value, std::tuple_element::type::value)...> + apply(detail::index_sequence) + { + return {}; + } + + template + static constexpr auto apply(batch_constant, batch_constant) + -> decltype(apply...>, std::tuple...>>(detail::make_index_sequence())) + { + static_assert(sizeof...(Values) == sizeof...(OtherValues), "compatible constant batches"); + return apply...>, std::tuple...>>(detail::make_index_sequence()); + } + + public: +#define MAKE_BINARY_OP(OP, NAME) \ + template \ + constexpr auto operator OP(batch_constant other) const \ + ->decltype(apply(*this, other)) \ + { \ + return apply(*this, other); \ + } + + MAKE_BINARY_OP(+, arithmetic_add) + MAKE_BINARY_OP(-, arithmetic_sub) + MAKE_BINARY_OP(*, arithmetic_mul) + MAKE_BINARY_OP(/, arithmetic_div) + MAKE_BINARY_OP(%, arithmetic_mod) + MAKE_BINARY_OP(&, binary_and) + MAKE_BINARY_OP(|, binary_or) + MAKE_BINARY_OP(^, binary_xor) + +#undef MAKE_BINARY_OP + + constexpr batch_constant operator-() const + { + return {}; + } + + constexpr batch_constant operator+() const + { + return {}; + } + + constexpr batch_constant operator~() const + { + return {}; + } }; namespace detail diff --git a/test/test_batch_constant.cpp b/test/test_batch_constant.cpp index 27f87017b..76e64c78c 100644 --- a/test/test_batch_constant.cpp +++ b/test/test_batch_constant.cpp @@ -64,22 +64,69 @@ struct constant_batch_test CHECK_BATCH_EQ((batch_type)b, expected); } + template struct constant { static constexpr value_type get(size_t /*index*/, size_t /*size*/) { - return 3; + return V; } }; void test_init_from_constant() const { array_type expected; - std::fill(expected.begin(), expected.end(), constant::get(0, 0)); - constexpr auto b = xsimd::make_batch_constant(); + std::fill(expected.begin(), expected.end(), constant<3>::get(0, 0)); + constexpr auto b = xsimd::make_batch_constant>(); INFO("batch(value_type)"); CHECK_BATCH_EQ((batch_type)b, expected); } + + void test_ops() const + { + constexpr auto n12 = xsimd::make_batch_constant>(); + constexpr auto n3 = xsimd::make_batch_constant>(); + + constexpr auto n12_add_n3 = n12 + n3; + constexpr auto n15 = xsimd::make_batch_constant>(); + static_assert(std::is_same::value, "n12 + n3 == n15"); + + constexpr auto n12_sub_n3 = n12 - n3; + constexpr auto n9 = xsimd::make_batch_constant>(); + static_assert(std::is_same::value, "n12 - n3 == n9"); + + constexpr auto n12_mul_n3 = n12 * n3; + constexpr auto n36 = xsimd::make_batch_constant>(); + static_assert(std::is_same::value, "n12 * n3 == n36"); + + constexpr auto n12_div_n3 = n12 / n3; + constexpr auto n4 = xsimd::make_batch_constant>(); + static_assert(std::is_same::value, "n12 / n3 == n4"); + + constexpr auto n12_mod_n3 = n12 % n3; + constexpr auto n0 = xsimd::make_batch_constant>(); + static_assert(std::is_same::value, "n12 % n3 == n0"); + + constexpr auto n12_land_n3 = n12 & n3; + static_assert(std::is_same::value, "n12 & n3 == n0"); + + constexpr auto n12_lor_n3 = n12 | n3; + static_assert(std::is_same::value, "n12 | n3 == n15"); + + constexpr auto n12_lxor_n3 = n12 ^ n3; + static_assert(std::is_same::value, "n12 ^ n3 == n15"); + + constexpr auto n12_uadd = +n12; + static_assert(std::is_same::value, "+n12 == n12"); + + constexpr auto n12_inv = ~n12; + constexpr auto n12_inv_ = xsimd::make_batch_constant>(); + static_assert(std::is_same::value, "~n12 == n12_inv"); + + constexpr auto n12_usub = -n12; + constexpr auto n12_usub_ = xsimd::make_batch_constant>(); + static_assert(std::is_same::value, "-n12 == n12_usub"); + } }; TEST_CASE_TEMPLATE("[constant batch]", B, BATCH_INT_TYPES) @@ -93,6 +140,11 @@ TEST_CASE_TEMPLATE("[constant batch]", B, BATCH_INT_TYPES) } SUBCASE("init_from_constant") { Test.test_init_from_constant(); } + + SUBCASE("operators") + { + Test.test_ops(); + } } template @@ -144,6 +196,53 @@ struct constant_bool_batch_test INFO("batch_bool_constant(value_type)"); CHECK_BATCH_EQ((batch_bool_type)b, expected); } + + struct inv_split + { + static constexpr bool get(size_t index, size_t size) + { + return !split().get(index, size); + } + }; + + template + struct constant + { + static constexpr bool get(size_t /*index*/, size_t /*size*/) + { + return Val; + } + }; + + void test_ops() const + { + constexpr auto all_true = xsimd::make_batch_bool_constant>(); + constexpr auto all_false = xsimd::make_batch_bool_constant>(); + + constexpr auto x = xsimd::make_batch_bool_constant(); + constexpr auto y = xsimd::make_batch_bool_constant(); + + constexpr auto x_or_y = x | y; + static_assert(std::is_same::value, "x | y == true"); + + constexpr auto x_lor_y = x || y; + static_assert(std::is_same::value, "x || y == true"); + + constexpr auto x_and_y = x & y; + static_assert(std::is_same::value, "x & y == false"); + + constexpr auto x_land_y = x && y; + static_assert(std::is_same::value, "x && y == false"); + + constexpr auto x_xor_y = x ^ y; + static_assert(std::is_same::value, "x ^ y == true"); + + constexpr auto not_x = !x; + static_assert(std::is_same::value, "!x == y"); + + constexpr auto inv_x = ~x; + static_assert(std::is_same::value, "~x == y"); + } }; TEST_CASE_TEMPLATE("[constant bool batch]", B, BATCH_INT_TYPES) @@ -155,5 +254,9 @@ TEST_CASE_TEMPLATE("[constant bool batch]", B, BATCH_INT_TYPES) { Test.test_init_from_generator_split(); } + SUBCASE("operators") + { + Test.test_ops(); + } } #endif