Skip to content

Commit

Permalink
[SYCL][ESIMD] BFN function implementation (intel#8708)
Browse files Browse the repository at this point in the history
API follows a similar one offered by CM.
Example: d = esimd::bfn<~bfn_t::x & ~bfn_t::y & ~bfn_t::z>(s0, s1, s2);
  • Loading branch information
turinevgeny authored Mar 24, 2023
1 parent 4167545 commit 25d0475
Show file tree
Hide file tree
Showing 5 changed files with 134 additions and 1 deletion.
3 changes: 2 additions & 1 deletion llvm/lib/SYCLLowerIR/ESIMD/LowerESIMD.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -671,7 +671,8 @@ class ESIMDIntrinDescTable {
{"test.src.tmpl.arg", {t(0), t1(1), t8(2), t16(3), t32(4), c8(17)}}},
{"slm_init", {"slm.init", {a(0)}}},
{"bf_cvt", {"bf.cvt", {a(0)}}},
{"tf32_cvt", {"tf32.cvt", {a(0)}}}};
{"tf32_cvt", {"tf32.cvt", {a(0)}}},
{"bfn", {"bfn", {a(0), a(1), a(2), t(0)}}}};
}

const IntrinTable &getTable() { return Table; }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -714,6 +714,19 @@ __esimd_dpasw_nosrc0(__ESIMD_DNS::vector_type_t<T1, N1> src1,
}
#endif // !__SYCL_DEVICE_ONLY__

template <uint8_t FuncControl, typename T, int N>
__ESIMD_INTRIN __ESIMD_raw_vec_t(T, N)
__esimd_bfn(__ESIMD_raw_vec_t(T, N) src0, __ESIMD_raw_vec_t(T, N) src1,
__ESIMD_raw_vec_t(T, N) src2)
#ifdef __SYCL_DEVICE_ONLY__
;
#else // !__SYCL_DEVICE_ONLY__
{
__ESIMD_UNSUPPORTED_ON_HOST;
return __ESIMD_DNS::vector_type_t<T, N>();
}
#endif // !__SYCL_DEVICE_ONLY__

#undef __ESIMD_raw_vec_t
#undef __ESIMD_cpp_vec_t

Expand Down
100 changes: 100 additions & 0 deletions sycl/include/sycl/ext/intel/experimental/esimd/math.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1727,6 +1727,106 @@ __ESIMD_API __ESIMD_NS::simd<T, N> dpasw2(
}
/// @} sycl_esimd_systolic_array_api

/// @addtogroup sycl_esimd_logical
/// @{

/// This enum is used to encode all possible logical operations performed
/// on the 3 input operands. It is used as a template argument of the bfn()
/// function.
/// Example: d = bfn<~bfn_t::x & ~bfn_t::y & ~bfn_t::z>(s0, s1, s2);
enum class bfn_t : uint8_t { x = 0xAA, y = 0xCC, z = 0xF0 };

static constexpr bfn_t operator~(bfn_t x) {
uint8_t val = static_cast<uint8_t>(x);
uint8_t res = ~val;
return static_cast<bfn_t>(res);
}

static constexpr bfn_t operator|(bfn_t x, bfn_t y) {
uint8_t arg0 = static_cast<uint8_t>(x);
uint8_t arg1 = static_cast<uint8_t>(y);
uint8_t res = arg0 | arg1;
return static_cast<bfn_t>(res);
}

static constexpr bfn_t operator&(bfn_t x, bfn_t y) {
uint8_t arg0 = static_cast<uint8_t>(x);
uint8_t arg1 = static_cast<uint8_t>(y);
uint8_t res = arg0 & arg1;
return static_cast<bfn_t>(res);
}

static constexpr bfn_t operator^(bfn_t x, bfn_t y) {
uint8_t arg0 = static_cast<uint8_t>(x);
uint8_t arg1 = static_cast<uint8_t>(y);
uint8_t res = arg0 ^ arg1;
return static_cast<bfn_t>(res);
}

/// Performs binary function computation with three vector operands.
/// @tparam FuncControl boolean function control expressed with bfn_t
/// enum values.
/// @tparam T type of the input vector element.
/// @tparam N size of the input vector.
/// @param s0 First boolean function argument.
/// @param s1 Second boolean function argument.
/// @param s2 Third boolean function argument.
template <bfn_t FuncControl, typename T, int N>
__ESIMD_API std::enable_if_t<std::is_integral_v<T>, __ESIMD_NS::simd<T, N>>
bfn(__ESIMD_NS::simd<T, N> src0, __ESIMD_NS::simd<T, N> src1,
__ESIMD_NS::simd<T, N> src2) {
if constexpr ((sizeof(T) == 8) || ((sizeof(T) == 1) && (N % 4 == 0)) ||
((sizeof(T) == 2) && (N % 2 == 0))) {
// Bitcast Nx8-byte vectors to 2xN vectors of 4-byte integers.
// Bitcast Nx1-byte vectors to N/4 vectors of 4-byte integers.
// Bitcast Nx2-byte vectors to N/2 vectors of 4-byte integers.
auto Result = __ESIMD_ENS::bfn<FuncControl>(
src0.template bit_cast_view<int32_t>().read(),
src1.template bit_cast_view<int32_t>().read(),
src2.template bit_cast_view<int32_t>().read());
return Result.template bit_cast_view<T>();
} else if constexpr (sizeof(T) == 2 || sizeof(T) == 4) {
constexpr uint8_t FC = static_cast<uint8_t>(FuncControl);
return __esimd_bfn<FC, T, N>(src0.data(), src1.data(), src2.data());
} else if constexpr (N % 2 == 0) {
// Bitcast Nx1-byte vectors (N is even) to N/2 vectors of 2-byte integers.
auto Result = __ESIMD_ENS::bfn<FuncControl>(
src0.template bit_cast_view<int16_t>().read(),
src1.template bit_cast_view<int16_t>().read(),
src2.template bit_cast_view<int16_t>().read());
return Result.template bit_cast_view<T>();
} else {
// Odd number of 1-byte elements.
__ESIMD_NS::simd<T, N + 1> Src0, Src1, Src2;
Src0.template select<N, 1>() = src0;
Src1.template select<N, 1>() = src1;
Src2.template select<N, 1>() = src2;
auto Result = __ESIMD_ENS::bfn<FuncControl>(Src0, Src1, Src2);
return Result.template select<N, 1>();
}
}

/// Performs binary function computation with three scalar operands.
/// @tparam FuncControl boolean function control expressed with bfn_t enum
/// values.
/// @tparam T type of the input vector element.
/// @param s0 First boolean function argument.
/// @param s1 Second boolean function argument.
/// @param s2 Third boolean function argument.
template <bfn_t FuncControl, typename T>
ESIMD_NODEBUG ESIMD_INLINE std::enable_if_t<
__ESIMD_DNS::is_esimd_scalar<T>::value && std::is_integral_v<T>, T>
bfn(T src0, T src1, T src2) {
__ESIMD_NS::simd<T, 1> Src0 = src0;
__ESIMD_NS::simd<T, 1> Src1 = src1;
__ESIMD_NS::simd<T, 1> Src2 = src2;
__ESIMD_NS::simd<T, 1> Result =
esimd::bfn<FuncControl, T, 1>(Src0, Src1, Src2);
return Result[0];
}

/// @} sycl_esimd_logical

} // namespace ext::intel::experimental::esimd
} // __SYCL_INLINE_VER_NAMESPACE(_V1)
} // namespace sycl
8 changes: 8 additions & 0 deletions sycl/test/esimd/intrins_trans.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -302,4 +302,12 @@ SYCL_EXTERNAL void test_math_intrins() SYCL_ESIMD_FUNCTION {
// CHECK-LABEL: %{{[a-zA-Z0-9.]+}} = call <8 x float> @llvm.genx.ieee.sqrt.v8f32(<8 x float> %{{[a-zA-Z0-9.]+}})
use(y);
}
{
vec<int, 8> x0 = get8i();
vec<int, 8> x1 = get8i();
vec<int, 8> x2 = get8i();
auto res = __esimd_bfn<0xff, int, 8>(x0, x1, x2);
// CHECK-LABEL: %{{[a-zA-Z0-9.]+}} = call <8 x i32> @llvm.genx.bfn.v8i32.v8i32(<8 x i32> %{{[a-zA-Z0-9.]+}}, <8 x i32> %{{[a-zA-Z0-9.]+}}, <8 x i32> %{{[a-zA-Z0-9.]+}}, i8 -1)
use(res);
}
}
11 changes: 11 additions & 0 deletions sycl/test/esimd/math_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
using namespace sycl;
using namespace sycl::ext::intel;
using namespace sycl::ext::intel::esimd;
using namespace sycl::ext::intel::experimental::esimd;

// Math sin,cos,log,exp functions are translated into scalar __spirv_ocl_ calls
SYCL_ESIMD_FUNCTION SYCL_EXTERNAL simd<float, 16> sycl_math(simd<float, 16> x) {
Expand Down Expand Up @@ -52,3 +53,13 @@ esimd_math_emu(simd<float, 16> x) {
v = esimd::exp(v);
return v;
}

// Logical BNF function from esimd namespace is translated into __esimd_ calls,
// which later translate into GenX intrinsics.
SYCL_ESIMD_FUNCTION SYCL_EXTERNAL simd<int, 16>
esimd_bfn(simd<int, 16> x, simd<int, 16> y, simd<int, 16> z) {
simd<int, 16> v =
experimental::esimd::bfn<~bfn_t::x & ~bfn_t::y & ~bfn_t::z>(x, y, z);
//CHECK: call spir_func noundef <16 x i32> @_Z11__esimd_bfn
return v;
}

0 comments on commit 25d0475

Please sign in to comment.