[PyTorch] Port ExecuTorch bfdot improvement back to ATen BlasKernel, Try #2 (pytorch#137377)

ExecuTorch's fork of BlasKernel.cpp grew bfdot support, complete with a demonstration that it helps. Port it back to PyTorch. The first attempt was pytorch#136331.

Differential Revision: [D63923166](https://our.internmc.facebook.com/intern/diff/D63923166/)
Pull Request resolved: pytorch#137377
Approved by: https://github.com/malfet
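
As background, a minimal illustrative sketch of the BFDOT idea being ported (hypothetical code, not the kernel in this diff; the function name bf16_dot_sketch is made up): each vbfdotq_f32 multiplies pairs of bfloat16 elements and accumulates the products into fp32 lanes, which are reduced at the end. It assumes a compiler targeting armv8.2-a+bf16 and, to keep it short, a length that is a multiple of 8; the actual kernel below also unrolls across several registers, handles tails, and is gated at runtime by cpuinfo_has_arm_bf16().

#include <arm_neon.h>
#include <cstdint>

// Hypothetical sketch; the real kernel in BlasKernel.cpp is more involved.
__attribute__((target("arch=armv8.2-a+bf16")))
float bf16_dot_sketch(const bfloat16_t* a, const bfloat16_t* b, int64_t len) {
  float32x4_t acc = vdupq_n_f32(0.0f);      // four fp32 accumulator lanes
  for (int64_t j = 0; j < len; j += 8) {    // assumes len % 8 == 0
    // BFDOT: multiply adjacent bf16 pairs and add each pair's product sum
    // into the corresponding fp32 lane of acc.
    acc = vbfdotq_f32(acc, vld1q_bf16(a + j), vld1q_bf16(b + j));
  }
  return vaddvq_f32(acc);                   // horizontal sum of the lanes
}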
swolchok authored and pytorchmergebot committed Oct 10, 2024
1 parent 080f02a commit 9c12198
Showing 1 changed file with 187 additions and 55 deletions.
242 changes: 187 additions & 55 deletions aten/src/ATen/native/BlasKernel.cpp
@@ -15,6 +15,7 @@

#if defined(__aarch64__) && !defined(C10_MOBILE)
#include <arm_neon.h>
#include <cpuinfo.h>
#endif

C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-function")
@@ -301,7 +302,7 @@ static constexpr auto kF16RegistersPerIterationShift = kF16ElementsPerIterationS
static constexpr auto kF16RegistersPerIteration = 1 << kF16RegistersPerIterationShift;
static_assert(kF16RegistersPerIteration == kF16ElementsPerIteration / kF16ElementsPerRegister);

static inline double reduce(float16x8_t x[kF16RegistersPerIteration]) {
static inline float reduce(float16x8_t x[kF16RegistersPerIteration]) {
int offset = kF16RegistersPerIteration;
c10::ForcedUnroll<kF16RegistersPerIterationShift>{}([&offset, &x](auto idx) {
offset /= 2;
@@ -311,7 +312,7 @@ static inline double reduce(float16x8_t x[kF16RegistersPerIteration]) {
});
const float32x4_t t0 = vcvt_f32_f16(vget_low_f16(x[0]));
const float32x4_t t1 = vcvt_f32_f16(vget_high_f16(x[0]));
return (double)vaddvq_f32(vaddq_f32(t0, t1));
return vaddvq_f32(vaddq_f32(t0, t1));
}

static inline float16x8_t f16_fma(float16x8_t a, float16x8_t b, float16x8_t c) {
@@ -333,12 +334,12 @@ static float fp16_dot_with_fp16_arith(const float16_t* x, const float16_t* a, in
sum[k] = f16_fma(sum[k], temp_x, temp_a);
}
}
auto reducedSum = reduce(sum);
auto reduced_sum = reduce(sum);

for (int j = len_aligned; j < len; ++j) {
reducedSum += x[j] * a[j];
reduced_sum += x[j] * a[j];
}
return reducedSum;
return reduced_sum;
}

// Rather than unrolling to process multiple rows (transposed columns)
@@ -352,7 +353,7 @@ static void fp16_gemv_trans_fp16_arith_by_dot_products(const int m, const int n,
});
}

#endif
#endif // __ARM_FEATURE_FP16_SCALAR_ARITHMETIC

static inline float reduce(float32x4_t x) {
auto sum = vpaddq_f32(x, x);
@@ -412,7 +413,7 @@ static constexpr auto kF32RegistersPerIterationShift = 3;
static_assert(kF32RegistersPerIteration == kF32ElementsPerIteration / kF32ElementsPerRegister);
static_assert(kF32RegistersPerIteration == 1 << kF32RegistersPerIterationShift);

static inline double reduce(float32x4_t x[kF32RegistersPerIteration]) {
static inline float reduce(float32x4_t x[kF32RegistersPerIteration]) {
int offset = kF32RegistersPerIteration;
c10::ForcedUnroll<kF32RegistersPerIterationShift>{}([&offset, &x](auto idx) {
offset /= 2;
@@ -423,7 +424,7 @@ static inline double reduce(float32x4_t x[kF32RegistersPerIteration]) {
return vaddvq_f32(x[0]);
}

static C10_ALWAYS_INLINE void dot_with_fp32_arith_main_inner_loop(
static C10_ALWAYS_INLINE void dot_with_fp32_arith_main_inner_loop_no_bfdot(
const float16_t* vec1,
const float16_t* vec2,
float32x4_t sum[kF32RegistersPerIteration],
@@ -436,86 +437,217 @@ static C10_ALWAYS_INLINE void dot_with_fp32_arith_main_inner_loop(
sum[2 * registerPairIndex + 1] = f32_fma_high_f16(sum[2 * registerPairIndex + 1], temp_vec1, temp_vec2);
}

static C10_ALWAYS_INLINE void dot_with_fp32_arith_vectorized_tail_inner_loop(
const float16_t* vec1,
const float16_t* vec2,
float32x4_t* tailSum,
int idx) {
static C10_ALWAYS_INLINE void dot_with_fp32_arith_vectorized_tail_inner_loop_no_bfdot(
const float16_t* vec1,
const float16_t* vec2,
float32x4_t* tail_sum,
int idx) {
const auto temp_vec1 = vld1_f16(&vec1[idx]);
const auto temp_vec2 = vld1_f16(&vec2[idx]);
*tailSum = f32_fma_f16(*tailSum, temp_vec1, temp_vec2);
*tail_sum = f32_fma_f16(*tail_sum, temp_vec1, temp_vec2);
}

static C10_ALWAYS_INLINE float32x4_t to_bfloat16(uint16x4_t u16) {
static float32x4_t to_bfloat16(uint16x4_t u16) {
int32x4_t shift = vdupq_n_s32(16);
return vreinterpretq_f32_u32(vshlq_u32(vmovl_u16(u16), shift));
}

static C10_ALWAYS_INLINE float32x4_t f32_fma_bf16(float32x4_t a, uint16x4_t b, uint16x4_t c) {
static float32x4_t f32_fma_bf16(float32x4_t a, uint16x4_t b, uint16x4_t c) {
return f32_fma(a, to_bfloat16(b), to_bfloat16(c));
}

static C10_ALWAYS_INLINE void dot_with_fp32_arith_main_inner_loop(
const at::BFloat16* vec1,
const at::BFloat16* vec2,
float32x4_t sum[kF32RegistersPerIteration],
int registerPairIndex) {
// TODO: detect intrinsic availability, use them if they're available. __ARM_FEATURE_BF16
// Load a pair of f32 registers at a time.
const uint16x8_t temp_vec1 = vld1q_u16(reinterpret_cast<const uint16_t*>(&vec1[registerPairIndex * 2 * kF32ElementsPerRegister]));
const uint16x8_t temp_vec2 = vld1q_u16(reinterpret_cast<const uint16_t*>(&vec2[registerPairIndex * 2 * kF32ElementsPerRegister]));
#if defined(__clang__) && __clang_major__ > 15
// https://godbolt.org/z/z8P4Yncra
#define COMPILER_SUPPORTS_BF16_TARGET 1
#elif !defined(__clang__) && defined(__GNUC__) && __GNUC__ >= 10
// https://gcc.gnu.org/gcc-10/changes.html
// https://godbolt.org/z/cdGG7vn8o
#define COMPILER_SUPPORTS_BF16_TARGET 1
#else
#define COMPILER_SUPPORTS_BF16_TARGET 0
#endif

#if COMPILER_SUPPORTS_BF16_TARGET
#define TARGET_ARM_BF16_ATTRIBUTE __attribute__((target("arch=armv8.2-a+bf16")))

TARGET_ARM_BF16_ATTRIBUTE static C10_ALWAYS_INLINE float32x4_t
f32_dot_bf16(float32x4_t a, bfloat16x8_t b, bfloat16x8_t c) {
return vbfdotq_f32(a, b, c);
}

sum[2 * registerPairIndex] = f32_fma_bf16(sum[2 * registerPairIndex], vget_low_u16(temp_vec1), vget_low_u16(temp_vec2));
sum[2 * registerPairIndex + 1] = f32_fma_bf16(sum[2 * registerPairIndex + 1], vget_high_u16(temp_vec1), vget_high_u16(temp_vec2));
TARGET_ARM_BF16_ATTRIBUTE static C10_ALWAYS_INLINE void
dot_with_fp32_arith_main_inner_loop_bfdot(
const BFloat16* vec1,
const BFloat16* vec2,
float32x4_t sum[kF32RegistersPerIteration],
int registerPairIndex) {
const bfloat16x8_t temp_vec1 = vld1q_bf16(reinterpret_cast<const __bf16*>(
&vec1[registerPairIndex * 2 * kF32ElementsPerRegister]));
const bfloat16x8_t temp_vec2 = vld1q_bf16(reinterpret_cast<const __bf16*>(
&vec2[registerPairIndex * 2 * kF32ElementsPerRegister]));
sum[registerPairIndex] =
f32_dot_bf16(sum[registerPairIndex], temp_vec1, temp_vec2);
}

static C10_ALWAYS_INLINE void dot_with_fp32_arith_vectorized_tail_inner_loop(
const at::BFloat16* vec1,
const at::BFloat16* vec2,
float32x4_t* tailSum,
int idx) {
// See NOTE [GCC code duplication] below for why we have _bfdot and
// _no_bfdot versions of
// dot_with_fp32_arith_vectorized_tail_inner_loop.
TARGET_ARM_BF16_ATTRIBUTE C10_ALWAYS_INLINE
static void dot_with_fp32_arith_vectorized_tail_inner_loop_bfdot(
const at::BFloat16* vec1,
const at::BFloat16* vec2,
float32x4_t* tail_sum,
int idx) {
const auto temp_vec1 = vld1_u16(reinterpret_cast<const uint16_t*>(&vec1[idx]));
const auto temp_vec2 = vld1_u16(reinterpret_cast<const uint16_t*>(&vec2[idx]));
*tailSum = f32_fma_bf16(*tailSum, temp_vec1, temp_vec2);
*tail_sum = f32_fma_bf16(*tail_sum, temp_vec1, temp_vec2);
}

template <typename T>
float dot_with_fp32_arith(const T* vec1, const T* vec2, int64_t len) {
#else
#define TARGET_ARM_BF16_ATTRIBUTE
#endif // COMPILER_SUPPORTS_BF16_TARGET

static C10_ALWAYS_INLINE void dot_with_fp32_arith_main_inner_loop_no_bfdot(
const BFloat16* vec1,
const BFloat16* vec2,
float32x4_t sum[kF32RegistersPerIteration],
int registerPairIndex) {
const uint16x8_t temp_vec1 = vld1q_u16(reinterpret_cast<const uint16_t*>(
&vec1[registerPairIndex * 2 * kF32ElementsPerRegister]));
const uint16x8_t temp_vec2 = vld1q_u16(reinterpret_cast<const uint16_t*>(
&vec2[registerPairIndex * 2 * kF32ElementsPerRegister]));

sum[2 * registerPairIndex] = f32_fma_bf16(
sum[2 * registerPairIndex],
vget_low_u16(temp_vec1),
vget_low_u16(temp_vec2));
sum[2 * registerPairIndex + 1] = f32_fma_bf16(
sum[2 * registerPairIndex + 1],
vget_high_u16(temp_vec1),
vget_high_u16(temp_vec2));
}

static C10_ALWAYS_INLINE void dot_with_fp32_arith_vectorized_tail_inner_loop_no_bfdot(
const at::BFloat16* vec1,
const at::BFloat16* vec2,
float32x4_t* tail_sum,
int idx) {
const auto temp_vec1 = vld1_u16(reinterpret_cast<const uint16_t*>(&vec1[idx]));
const auto temp_vec2 = vld1_u16(reinterpret_cast<const uint16_t*>(&vec2[idx]));
*tail_sum = f32_fma_bf16(*tail_sum, temp_vec1, temp_vec2);
}

namespace {
#if COMPILER_SUPPORTS_BF16_TARGET
template <int n>
struct ForcedUnrollTargetBFloat16 {
template <typename Func>
TARGET_ARM_BF16_ATTRIBUTE C10_ALWAYS_INLINE void operator()(const Func& f) const {
ForcedUnrollTargetBFloat16<n - 1>{}(f);
f(n - 1);
}
};

template <>
struct ForcedUnrollTargetBFloat16<1> {
template <typename Func>
TARGET_ARM_BF16_ATTRIBUTE C10_ALWAYS_INLINE void operator()(const Func& f) const {
f(0);
}
};

C10_ALWAYS_INLINE TARGET_ARM_BF16_ATTRIBUTE auto
dot_with_fp32_arith_main_loop_bfdot(
const BFloat16* vec1,
const BFloat16* vec2,
int64_t len) {
float32x4_t sum[kF32RegistersPerIteration] = {vdupq_n_f32(0)};
const auto len_aligned = len & ~(kF32ElementsPerIteration - 1);
for (int j = 0; j < len_aligned ; j += kF32ElementsPerIteration) {
const auto* vec1_ = vec1 + j;
const auto* vec2_ = vec2 + j;
c10::ForcedUnroll<kF32RegisterPairsPerIteration>{}([vec1_, vec2_, &sum](auto k) {
dot_with_fp32_arith_main_inner_loop(vec1_, vec2_, sum, k);
ForcedUnrollTargetBFloat16<kF32RegisterPairsPerIteration>{}([vec1_, vec2_, &sum](auto k)
C10_ALWAYS_INLINE_ATTRIBUTE TARGET_ARM_BF16_ATTRIBUTE {
dot_with_fp32_arith_main_inner_loop_bfdot(vec1_, vec2_, sum, k);
});
}
auto reducedSum = reduce(sum);

// First-tier tail fixup: make sure we handle workloads that can
// benefit from vectorization, but don't fit into our fully unrolled
// loop above.
float32x4_t tailSum = vdupq_n_f32(0);
const auto len_aligned_4 = len & ~3;
for (int j = len_aligned; j < len_aligned_4; j += 4) {
dot_with_fp32_arith_vectorized_tail_inner_loop(vec1, vec2, &tailSum, j);
}
auto reducedTail = vpaddq_f32(tailSum, tailSum);
reducedSum += vgetq_lane_f32(vpaddq_f32(reducedTail, reducedTail), 0);
return reduce(sum);
}
#endif // COMPILER_SUPPORTS_BF16_TARGET

// Second-tier tail fixup: handle all workloads.
for (int j = len_aligned_4; j < len; ++j) {
reducedSum += vec1[j] * vec2[j];
template <typename T>
C10_ALWAYS_INLINE auto
dot_with_fp32_arith_main_loop_no_bfdot(
const T* vec1,
const T* vec2,
int64_t len) {
float32x4_t sum[kF32RegistersPerIteration] = {vdupq_n_f32(0)};
const auto len_aligned = len & ~(kF32ElementsPerIteration - 1);
for (int j = 0; j < len_aligned ; j += kF32ElementsPerIteration) {
const auto* vec1_ = vec1 + j;
const auto* vec2_ = vec2 + j;
c10::ForcedUnroll<kF32RegisterPairsPerIteration>{}([vec1_, vec2_, &sum](auto k) C10_ALWAYS_INLINE_ATTRIBUTE {
dot_with_fp32_arith_main_inner_loop_no_bfdot(vec1_, vec2_, sum, k);
});
}
return reducedSum;
return reduce(sum);
}

// NOTE [GCC code duplication]: The first attempt at landing BFDOT support with
// TARGET_ARM_BF16_ATTRIBUTE failed because unlike clang, GCC will not
// allow inlining a non-bf16-specific function into a bf16-specific
// function. We can work around this by duplicating the code into the
// bfdot and non-bfdot callsites. The code is in this macro to avoid
// actual copy/paste.
#define DOT_WITH_FP32_ARITH_TAIL_AFTER_MAIN_LOOP_BODY(bfdot_suffix) \
/* First-tier tail fixup: make sure we handle workloads that can */ \
/* benefit from vectorization, but don't fit into our fully unrolled */ \
/* loop above. */ \
float32x4_t tail_sum = vdupq_n_f32(0); \
const auto len_aligned = len & ~(kF32ElementsPerIteration - 1); \
const auto len_aligned_4 = len & ~3; \
for (int j = len_aligned; j < len_aligned_4; j += 4) { \
dot_with_fp32_arith_vectorized_tail_inner_loop##bfdot_suffix(vec1, vec2, &tail_sum, j); \
} \
auto reduced_tail = vpaddq_f32(tail_sum, tail_sum); \
reduced_sum += vgetq_lane_f32(vpaddq_f32(reduced_tail, reduced_tail), 0); \
\
/* Second-tier tail fixup: handle all workloads. */ \
for (int j = len_aligned_4; j < len; ++j) { \
reduced_sum += vec1[j] * vec2[j]; \
} \
return reduced_sum

#if COMPILER_SUPPORTS_BF16_TARGET
TARGET_ARM_BF16_ATTRIBUTE float
dot_with_fp32_arith_bfdot(const BFloat16* vec1, const BFloat16* vec2, int64_t len) {
auto reduced_sum = dot_with_fp32_arith_main_loop_bfdot(vec1, vec2, len);
DOT_WITH_FP32_ARITH_TAIL_AFTER_MAIN_LOOP_BODY(_bfdot);
}
#endif // COMPILER_SUPPORTS_BF16_TARGET

template <typename T>
C10_ALWAYS_INLINE float
dot_with_fp32_arith_no_bfdot(const T* vec1, const T* vec2, int64_t len) {
auto reduced_sum = dot_with_fp32_arith_main_loop_no_bfdot(vec1, vec2, len);
DOT_WITH_FP32_ARITH_TAIL_AFTER_MAIN_LOOP_BODY(_no_bfdot);
}
#undef DOT_WITH_FP32_ARITH_TAIL_AFTER_MAIN_LOOP_BODY
} // namespace

float fp16_dot_with_fp32_arith(const float16_t* vec1, const float16_t* vec2, int64_t len) {
return dot_with_fp32_arith(vec1, vec2, len);
return dot_with_fp32_arith_no_bfdot(vec1, vec2, len);
}

float bf16_dot_with_fp32_arith(const at::BFloat16* vec1, const at::BFloat16* vec2, int64_t len) {
return dot_with_fp32_arith(vec1, vec2, len);
#if COMPILER_SUPPORTS_BF16_TARGET
if (cpuinfo_has_arm_bf16()) {
return dot_with_fp32_arith_bfdot(vec1, vec2, len);
} else
#endif
{
return dot_with_fp32_arith_no_bfdot(vec1, vec2, len);
}
}

// On my Apple M1 Macbook (which is ARM v8.5 and thus has the
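
For readers unfamiliar with the non-BFDOT fallback above: to_bfloat16() widens bf16 lanes to fp32 by shifting the 16-bit patterns into the high half of each 32-bit lane, which works because a bfloat16 value is exactly the top 16 bits of the corresponding float32. A scalar illustration of that bit trick follows (hypothetical code, not part of this commit; bf16_bits_to_float is a made-up name).

#include <cstdint>
#include <cstdio>
#include <cstring>

// Widen a bfloat16 bit pattern to float32: place it in the top 16 bits and
// zero the bottom 16 bits, mirroring the vshlq_u32(vmovl_u16(u16), shift)
// sequence in to_bfloat16() above.
static float bf16_bits_to_float(uint16_t bits) {
  uint32_t widened = static_cast<uint32_t>(bits) << 16;
  float out;
  std::memcpy(&out, &widened, sizeof(out));
  return out;
}

int main() {
  // 0x3F80 is bfloat16 1.0; widened to 0x3F800000 it reads back as 1.0f.
  std::printf("%f\n", bf16_bits_to_float(0x3F80));  // prints 1.000000
  return 0;
}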
