Skip to content

Commit

Permalink
Generic, simple implementation for xsimd::expand
Browse files Browse the repository at this point in the history
Related to #975
  • Loading branch information
serge-sans-paille committed Nov 21, 2023
1 parent 0f47da8 commit 8dd0e7a
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 0 deletions.
18 changes: 18 additions & 0 deletions include/xsimd/arch/generic/xsimd_generic_memory.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,24 @@ namespace xsimd

using namespace types;

// expand
//
// Generic fallback for expand: a lane that is set in the mask receives the
// next not-yet-consumed element of x (taken in order, starting from the first
// one), while a lane that is not set in the mask is filled with zero.
template <typename A, typename T>
inline batch<T, A>
expand(batch<T, A> const& x, batch_bool<T, A> const& mask,
       kernel::requires_arch<generic>) noexcept
{
    constexpr std::size_t lane_count = batch_bool<T, A>::size;

    // Spill the mask to memory so that each lane can be tested individually.
    alignas(A::alignment()) bool selected[lane_count];
    mask.store_aligned(&selected[0]);

    // Selected lanes take consecutive source indices. Unselected lanes get
    // index 0 as a placeholder; the final select() replaces them with zero.
    alignas(A::alignment()) as_integer_t<T> source_index[lane_count];
    std::size_t next_source = 0;
    for (std::size_t lane = 0; lane < lane_count; ++lane)
    {
        if (selected[lane])
        {
            source_index[lane] = static_cast<as_integer_t<T>>(next_source++);
        }
        else
        {
            source_index[lane] = 0;
        }
    }

    auto gather_indices = batch<as_integer_t<T>, A>::load_aligned(&source_index[0]);
    return select(mask, swizzle(x, gather_indices), batch<T, A>(T(0)));
}

// extract_pair
template <class A, class T>
inline batch<T, A> extract_pair(batch<T, A> const& self, batch<T, A> const& other, std::size_t i, requires_arch<generic>) noexcept
Expand Down
13 changes: 13 additions & 0 deletions include/xsimd/types/xsimd_api.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -705,6 +705,19 @@ namespace xsimd
return kernel::exp2<A>(x, A {});
}

/**
* @ingroup batch_data_transfer
*
* Load contiguous elements from \c x and place them in slots selected by \c
* mask, zeroing the other slots. Elements of \c x are consumed in order: the
* first selected slot receives the first element of \c x, the second selected
* slot receives the second element, and so on.
*
* @param x batch providing the contiguous elements to distribute.
* @param mask batch of booleans selecting the slots that receive an element.
* @return the batch holding the expanded elements, with zero in every slot
* not selected by \c mask.
*/
template <class T, class A>
inline batch<T, A> expand(batch<T, A> const& x, batch_bool<T, A> const& mask) noexcept
{
detail::static_check_supported_config<T, A>();
return kernel::expand<A>(x, mask, A {});
}

/**
* @ingroup batch_math
*
Expand Down
96 changes: 96 additions & 0 deletions test/test_shuffle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,102 @@ TEST_CASE_TEMPLATE("[slide]", B, BATCH_INT_TYPES)

#endif

// Fixture exercising xsimd::expand on batch type B with several mask patterns.
// `input` holds 0, 1, 2, ... (converted to the batch value type); each test
// fills `mask` and the matching `expected` result before calling expand.
template <class B>
struct expand_test
{
    using batch_type = B;
    using value_type = typename B::value_type;
    using mask_batch_type = typename B::batch_bool_type;

    static constexpr size_t size = B::size;
    std::array<value_type, size> input;
    std::array<bool, size> mask;
    std::array<value_type, size> expected;

    expand_test()
    {
        // input[k] == k for every lane.
        size_t lane = 0;
        for (auto& value : input)
        {
            value = lane++;
        }
    }

    // Every lane selected: the output must equal the input.
    void full()
    {
        mask.fill(true);
        expected = input;

        auto b = xsimd::expand(
            batch_type::load_unaligned(input.data()),
            mask_batch_type::load_unaligned(mask.data()));
        CHECK_BATCH_EQ(b, expected);
    }

    // No lane selected: the output must be all zeros.
    void empty()
    {
        mask.fill(false);
        expected.fill(0);

        auto b = xsimd::expand(
            batch_type::load_unaligned(input.data()),
            mask_batch_type::load_unaligned(mask.data()));
        CHECK_BATCH_EQ(b, expected);
    }

    // Every other lane selected, starting with lane 0.
    void interleave()
    {
        size_t consumed = 0;
        for (size_t lane = 0; lane < size; ++lane)
        {
            mask[lane] = (lane % 2 == 0);
            if (mask[lane])
            {
                expected[lane] = input[consumed++];
            }
            else
            {
                expected[lane] = 0;
            }
        }

        auto b = xsimd::expand(
            batch_type::load_unaligned(input.data()),
            mask_batch_type::load_unaligned(mask.data()));
        CHECK_BATCH_EQ(b, expected);
    }

    // Every third lane selected, starting with lane 0.
    void generic()
    {
        size_t consumed = 0;
        for (size_t lane = 0; lane < size; ++lane)
        {
            mask[lane] = (lane % 3 == 0);
            if (mask[lane])
            {
                expected[lane] = input[consumed++];
            }
            else
            {
                expected[lane] = 0;
            }
        }

        auto b = xsimd::expand(
            batch_type::load_unaligned(input.data()),
            mask_batch_type::load_unaligned(mask.data()));
        CHECK_BATCH_EQ(b, expected);
    }
};

// Runs each expand_test scenario as its own subcase, for the floating point
// batch types and the 32/64-bit signed and unsigned integer batches.
TEST_CASE_TEMPLATE("[expand]", B, BATCH_FLOAT_TYPES, xsimd::batch<uint32_t>, xsimd::batch<int32_t>, xsimd::batch<uint64_t>, xsimd::batch<int64_t>)
{
    expand_test<B> fixture;

    SUBCASE("empty") { fixture.empty(); }
    SUBCASE("full") { fixture.full(); }
    SUBCASE("interleave") { fixture.interleave(); }
    SUBCASE("generic") { fixture.generic(); }
}

template <class B>
struct shuffle_test
{
Expand Down

0 comments on commit 8dd0e7a

Please sign in to comment.