Skip to content

Commit

Permalink
Improve SIMD vtable API (#492)
Browse files Browse the repository at this point in the history
  • Loading branch information
solidpixel authored Aug 15, 2024
1 parent 44e3b94 commit 454b6de
Show file tree
Hide file tree
Showing 10 changed files with 544 additions and 345 deletions.
183 changes: 128 additions & 55 deletions Source/UnitTest/test_simd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1947,43 +1947,78 @@ TEST(vmask4, not)
}

/** @brief Test vint4 table permute. */
TEST(vint4, vtable_8bt_32bi_32entry)
TEST(vint4, vtable4_16x8)
{
vint4 table0(0x00010203, 0x04050607, 0x08090a0b, 0x0c0d0e0f);
vint4 table1(0x10111213, 0x14151617, 0x18191a1b, 0x1c1d1e1f);
uint8_t data[16] = {
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f
};

vint4 table0p, table1p;
vtable_prepare(table0, table1, table0p, table1p);
vtable4_16x8 table;
vtable_prepare(table, data);

vint4 index(0, 7, 4, 31);
vint4 index(0, 7, 4, 15);

vint4 result = vtable_8bt_32bi(table0p, table1p, index);
vint4 result = vtable_lookup_32bit(table, index);

EXPECT_EQ(result.lane<0>(), 3);
EXPECT_EQ(result.lane<1>(), 4);
EXPECT_EQ(result.lane<2>(), 7);
EXPECT_EQ(result.lane<3>(), 28);
EXPECT_EQ(result.lane<0>(), 0);
EXPECT_EQ(result.lane<1>(), 7);
EXPECT_EQ(result.lane<2>(), 4);
EXPECT_EQ(result.lane<3>(), 15);
}

/** @brief Test vint4 table permute. */
TEST(vint4, vtable_8bt_32bi_64entry)
TEST(vint4, vtable4_32x8)
{
vint4 table0(0x00010203, 0x04050607, 0x08090a0b, 0x0c0d0e0f);
vint4 table1(0x10111213, 0x14151617, 0x18191a1b, 0x1c1d1e1f);
vint4 table2(0x20212223, 0x24252627, 0x28292a2b, 0x2c2d2e2f);
vint4 table3(0x30313233, 0x34353637, 0x38393a3b, 0x3c3d3e3f);
uint8_t data[32] = {
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17,
0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f
};

vtable4_32x8 table;
vtable_prepare(table, data);

vint4 index(0, 7, 4, 31);

vint4 result = vtable_lookup_32bit(table, index);

vint4 table0p, table1p, table2p, table3p;
vtable_prepare(table0, table1, table2, table3, table0p, table1p, table2p, table3p);
EXPECT_EQ(result.lane<0>(), 0);
EXPECT_EQ(result.lane<1>(), 7);
EXPECT_EQ(result.lane<2>(), 4);
EXPECT_EQ(result.lane<3>(), 31);
}

/** @brief Test vint4 table permute. */
TEST(vint4, vtable4_64x8)
{
uint8_t data[64] = {
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17,
0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f,
0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27,
0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f,
0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37,
0x38, 0x39, 0x3a, 0x3b, 0x3c, 0x3d, 0x3e, 0x3f
};

vtable4_64x8 table;
vtable_prepare(table, data);

vint4 index(0, 7, 38, 63);

vint4 result = vtable_8bt_32bi(table0p, table1p, table2p, table3p, index);
vint4 result = vtable_lookup_32bit(table, index);

uint8_t* hack = reinterpret_cast<uint8_t*>(&table);
std::cout << "38: " << hack[38] << "\n";
std::cout << "63: " << hack[63] << "\n";

EXPECT_EQ(result.lane<0>(), 3);
EXPECT_EQ(result.lane<1>(), 4);
EXPECT_EQ(result.lane<2>(), 37);
EXPECT_EQ(result.lane<3>(), 60);
EXPECT_EQ(result.lane<0>(), 0);
EXPECT_EQ(result.lane<1>(), 7);
EXPECT_EQ(result.lane<2>(), 38);
EXPECT_EQ(result.lane<3>(), 63);
}

/** @brief Test vint4 rgba byte interleave. */
Expand Down Expand Up @@ -3657,57 +3692,95 @@ TEST(vmask8, not)
}

/** @brief Test vint8 table permute. */
TEST(vint8, vtable_8bt_32bi_32entry)
TEST(vint8, vtable8_16x8)
{
vint4 table0(0x00010203, 0x04050607, 0x08090a0b, 0x0c0d0e0f);
vint4 table1(0x10111213, 0x14151617, 0x18191a1b, 0x1c1d1e1f);
uint8_t data[16] = {
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f
};

vint8 table0p, table1p;
vtable_prepare(table0, table1, table0p, table1p);
vtable8_16x8 table;
vtable_prepare(table, data);

vint8 index = vint8_lit(0, 7, 4, 15, 16, 20, 23, 31);
vint8 index = vint8_lit(0, 7, 4, 15, 1, 2, 14, 4);

vint8 result = vtable_8bt_32bi(table0p, table1p, index);
vint8 result = vtable_lookup_32bit(table, index);

alignas(32) int ra[8];
store(result, ra);

EXPECT_EQ(ra[0], 3);
EXPECT_EQ(ra[1], 4);
EXPECT_EQ(ra[2], 7);
EXPECT_EQ(ra[3], 12);
EXPECT_EQ(ra[4], 19);
EXPECT_EQ(ra[5], 23);
EXPECT_EQ(ra[6], 20);
EXPECT_EQ(ra[7], 28);
EXPECT_EQ(ra[0], 0);
EXPECT_EQ(ra[1], 7);
EXPECT_EQ(ra[2], 4);
EXPECT_EQ(ra[3], 15);
EXPECT_EQ(ra[4], 1);
EXPECT_EQ(ra[5], 2);
EXPECT_EQ(ra[6], 14);
EXPECT_EQ(ra[7], 4);
}

/** @brief Test vint4 table permute. */
TEST(vint8, vtable_8bt_32bi_64entry)
/** @brief Test vint8 table permute. */
TEST(vint8, vtable8_32x8)
{
vint4 table0(0x00010203, 0x04050607, 0x08090a0b, 0x0c0d0e0f);
vint4 table1(0x10111213, 0x14151617, 0x18191a1b, 0x1c1d1e1f);
vint4 table2(0x20212223, 0x24252627, 0x28292a2b, 0x2c2d2e2f);
vint4 table3(0x30313233, 0x34353637, 0x38393a3b, 0x3c3d3e3f);
uint8_t data[32] = {
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17,
0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f
};

vtable8_32x8 table;
vtable_prepare(table, data);

vint8 index = vint8_lit(0, 7, 4, 15, 16, 20, 23, 31);

vint8 table0p, table1p, table2p, table3p;
vtable_prepare(table0, table1, table2, table3, table0p, table1p, table2p, table3p);
vint8 result = vtable_lookup_32bit(table, index);

alignas(32) int ra[8];
store(result, ra);

EXPECT_EQ(ra[0], 0);
EXPECT_EQ(ra[1], 7);
EXPECT_EQ(ra[2], 4);
EXPECT_EQ(ra[3], 15);
EXPECT_EQ(ra[4], 16);
EXPECT_EQ(ra[5], 20);
EXPECT_EQ(ra[6], 23);
EXPECT_EQ(ra[7], 31);
}

/** @brief Test vint8 table permute. */
TEST(vint8, vtable8_64x8)
{
uint8_t data[64] = {
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17,
0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f,
0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27,
0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f,
0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37,
0x38, 0x39, 0x3a, 0x3b, 0x3c, 0x3d, 0x3e, 0x3f
};

vtable8_64x8 table;
vtable_prepare(table, data);

vint8 index = vint8_lit(0, 7, 4, 15, 16, 20, 38, 63);

vint8 result = vtable_8bt_32bi(table0p, table1p, table2p, table3p, index);
vint8 result = vtable_lookup_32bit(table, index);

alignas(32) int ra[8];
store(result, ra);

EXPECT_EQ(ra[0], 3);
EXPECT_EQ(ra[1], 4);
EXPECT_EQ(ra[2], 7);
EXPECT_EQ(ra[3], 12);
EXPECT_EQ(ra[4], 19);
EXPECT_EQ(ra[5], 23);
EXPECT_EQ(ra[6], 37);
EXPECT_EQ(ra[7], 60);
EXPECT_EQ(ra[0], 0);
EXPECT_EQ(ra[1], 7);
EXPECT_EQ(ra[2], 4);
EXPECT_EQ(ra[3], 15);
EXPECT_EQ(ra[4], 16);
EXPECT_EQ(ra[5], 20);
EXPECT_EQ(ra[6], 38);
EXPECT_EQ(ra[7], 63);
}

#endif
Expand Down
27 changes: 9 additions & 18 deletions Source/astcenc_decompress_symbolic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,13 +98,8 @@ void unpack_weights(
if (!is_dual_plane)
{
// Build full 64-entry weight lookup table
vint4 tab0 = vint4::load(scb.weights + 0);
vint4 tab1 = vint4::load(scb.weights + 16);
vint4 tab2 = vint4::load(scb.weights + 32);
vint4 tab3 = vint4::load(scb.weights + 48);

vint tab0p, tab1p, tab2p, tab3p;
vtable_prepare(tab0, tab1, tab2, tab3, tab0p, tab1p, tab2p, tab3p);
vtable_64x8 table;
vtable_prepare(table, scb.weights);

for (unsigned int i = 0; i < bsd.texel_count; i += ASTCENC_SIMD_WIDTH)
{
Expand All @@ -118,7 +113,7 @@ void unpack_weights(
vint texel_weights(di.texel_weights_tr[j] + i);
vint texel_weights_int(di.texel_weight_contribs_int_tr[j] + i);

summed_value += vtable_8bt_32bi(tab0p, tab1p, tab2p, tab3p, texel_weights) * texel_weights_int;
summed_value += vtable_lookup_32bit(table, texel_weights) * texel_weights_int;
}

store(lsr<4>(summed_value), weights_plane1 + i);
Expand All @@ -128,16 +123,12 @@ void unpack_weights(
{
// Build a 32-entry weight lookup table per plane
// Plane 1
vint4 tab0_plane1 = vint4::load(scb.weights + 0);
vint4 tab1_plane1 = vint4::load(scb.weights + 16);
vint tab0_plane1p, tab1_plane1p;
vtable_prepare(tab0_plane1, tab1_plane1, tab0_plane1p, tab1_plane1p);
vtable_32x8 tab_plane1;
vtable_prepare(tab_plane1, scb.weights);

// Plane 2
vint4 tab0_plane2 = vint4::load(scb.weights + 32);
vint4 tab1_plane2 = vint4::load(scb.weights + 48);
vint tab0_plane2p, tab1_plane2p;
vtable_prepare(tab0_plane2, tab1_plane2, tab0_plane2p, tab1_plane2p);
vtable_32x8 tab_plane2;
vtable_prepare(tab_plane2, scb.weights + 32);

for (unsigned int i = 0; i < bsd.texel_count; i += ASTCENC_SIMD_WIDTH)
{
Expand All @@ -153,8 +144,8 @@ void unpack_weights(
vint texel_weights(di.texel_weights_tr[j] + i);
vint texel_weights_int(di.texel_weight_contribs_int_tr[j] + i);

sum_plane1 += vtable_8bt_32bi(tab0_plane1p, tab1_plane1p, texel_weights) * texel_weights_int;
sum_plane2 += vtable_8bt_32bi(tab0_plane2p, tab1_plane2p, texel_weights) * texel_weights_int;
sum_plane1 += vtable_lookup_32bit(tab_plane1, texel_weights) * texel_weights_int;
sum_plane2 += vtable_lookup_32bit(tab_plane2, texel_weights) * texel_weights_int;
}

store(lsr<4>(sum_plane1), weights_plane1 + i);
Expand Down
19 changes: 8 additions & 11 deletions Source/astcenc_ideal_endpoints_and_weights.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1023,9 +1023,8 @@ void compute_quantized_weights_for_decimation(
// safe data in compute_ideal_weights_for_decimation and arrays are always 64 elements
if (get_quant_level(quant_level) <= 16)
{
vint4 tab0 = vint4::load(qat.quant_to_unquant);
vint tab0p;
vtable_prepare(tab0, tab0p);
vtable_16x8 table;
vtable_prepare(table, qat.quant_to_unquant);

for (int i = 0; i < weight_count; i += ASTCENC_SIMD_WIDTH)
{
Expand All @@ -1038,8 +1037,8 @@ void compute_quantized_weights_for_decimation(
vint weightl = float_to_int(ix1);
vint weighth = min(weightl + vint(1), steps_m1);

vint ixli = vtable_8bt_32bi(tab0p, weightl);
vint ixhi = vtable_8bt_32bi(tab0p, weighth);
vint ixli = vtable_lookup_32bit(table, weightl);
vint ixhi = vtable_lookup_32bit(table, weighth);

vfloat ixl = int_to_float(ixli);
vfloat ixh = int_to_float(ixhi);
Expand All @@ -1055,10 +1054,8 @@ void compute_quantized_weights_for_decimation(
}
else
{
vint4 tab0 = vint4::load(qat.quant_to_unquant + 0);
vint4 tab1 = vint4::load(qat.quant_to_unquant + 16);
vint tab0p, tab1p;
vtable_prepare(tab0, tab1, tab0p, tab1p);
vtable_32x8 table;
vtable_prepare(table, qat.quant_to_unquant);

for (int i = 0; i < weight_count; i += ASTCENC_SIMD_WIDTH)
{
Expand All @@ -1071,8 +1068,8 @@ void compute_quantized_weights_for_decimation(
vint weightl = float_to_int(ix1);
vint weighth = min(weightl + vint(1), steps_m1);

vint ixli = vtable_8bt_32bi(tab0p, tab1p, weightl);
vint ixhi = vtable_8bt_32bi(tab0p, tab1p, weighth);
vint ixli = vtable_lookup_32bit(table, weightl);
vint ixhi = vtable_lookup_32bit(table, weighth);

vfloat ixl = int_to_float(ixli);
vfloat ixh = int_to_float(ixhi);
Expand Down
20 changes: 20 additions & 0 deletions Source/astcenc_vecmathlib.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,10 @@
using vint = vint8;
using vmask = vmask8;

using vtable_16x8 = vtable8_16x8;
using vtable_32x8 = vtable8_32x8;
using vtable_64x8 = vtable8_64x8;

constexpr auto loada = vfloat8::loada;
constexpr auto load1 = vfloat8::load1;

Expand All @@ -111,6 +115,10 @@
using vint = vint4;
using vmask = vmask4;

using vtable_16x8 = vtable4_16x8;
using vtable_32x8 = vtable4_32x8;
using vtable_64x8 = vtable4_64x8;

constexpr auto loada = vfloat4::loada;
constexpr auto load1 = vfloat4::load1;

Expand Down Expand Up @@ -138,6 +146,10 @@
using vint = vint8;
using vmask = vmask8;

using vtable_16x8 = vtable8_16x8;
using vtable_32x8 = vtable8_32x8;
using vtable_64x8 = vtable8_64x8;

constexpr auto loada = vfloat8::loada;
constexpr auto load1 = vfloat8::load1;

Expand All @@ -153,6 +165,10 @@
using vint = vint4;
using vmask = vmask4;

using vtable_16x8 = vtable4_16x8;
using vtable_32x8 = vtable4_32x8;
using vtable_64x8 = vtable4_64x8;

constexpr auto loada = vfloat4::loada;
constexpr auto load1 = vfloat4::load1;

Expand Down Expand Up @@ -185,6 +201,10 @@
using vint = vint4;
using vmask = vmask4;

using vtable_16x8 = vtable4_16x8;
using vtable_32x8 = vtable4_32x8;
using vtable_64x8 = vtable4_64x8;

constexpr auto loada = vfloat4::loada;
constexpr auto load1 = vfloat4::load1;
#endif
Expand Down
Loading

0 comments on commit 454b6de

Please sign in to comment.