Skip to content

Commit

Permalink
[Rotary] Make sure frequency calculation is in fp32
Browse files Browse the repository at this point in the history
  • Loading branch information
tridao committed Jul 2, 2023
1 parent 9818f85 commit 62e9814
Show file tree
Hide file tree
Showing 8 changed files with 393 additions and 42 deletions.
2 changes: 1 addition & 1 deletion csrc/ft_attention/decoder_masked_multihead_attention.cu
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
if (smem_sz >= 48 * 1024) { \
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_sz); \
} \
dim3 grid(params.num_heads, params.batch_size); \
dim3 grid(params.nnz_head_idx == nullptr ? params.num_heads : params.nnz_heads, params.batch_size); \
kernel<<<grid, THDS_PER_BLOCK, smem_sz, stream>>>(params)

////////////////////////////////////////////////////////////////////////////////////////////////////
Expand Down
6 changes: 6 additions & 0 deletions csrc/ft_attention/decoder_masked_multihead_attention.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,12 @@ struct Multihead_attention_params_base {
const float* qkv_scale_out = nullptr;
const float* attention_out_scale = nullptr;
int int8_mode = 0;

const T *rotary_cos = nullptr;
const T *rotary_sin = nullptr;

const int *nnz_head_idx = nullptr;
int nnz_heads = 0;
};

template<typename T, bool CROSS_ATTENTION>
Expand Down
33 changes: 26 additions & 7 deletions csrc/ft_attention/decoder_masked_multihead_attention_template.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -941,7 +941,8 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
// The "beam-aware" batch idx
const int bbi = bi / params.beam_width;
// The head.
const int hi = blockIdx.x;
// const int hi = blockIdx.x;
const int hi = params.nnz_head_idx == nullptr ? blockIdx.x : params.nnz_head_idx[blockIdx.x];
// Combine the batch and the head indices.
const int bhi = bi * params.num_heads + hi;
// Combine the "beam-aware" batch idx and the head indices.
Expand Down Expand Up @@ -1061,10 +1062,18 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
const int padd_len = (params.total_padding_tokens == nullptr) ? 0 : params.total_padding_tokens[bi];
if (params.rotary_embedding_dim > 0 && !params.neox_rotary_style) {
if (handle_kv) {
apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base);
if (params.rotary_cos == nullptr) {
apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base);
} else {
apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, tlength - padd_len, params.rotary_cos, params.rotary_sin);
}
}
else {
apply_rotary_embedding(q, tidx, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base);
if (params.rotary_cos == nullptr) {
apply_rotary_embedding(q, tidx, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base);
} else {
apply_rotary_embedding(q, tidx, params.rotary_embedding_dim, tlength - padd_len, params.rotary_cos, params.rotary_sin);
}
}
}
else if (params.rotary_embedding_dim > 0 && params.neox_rotary_style) {
Expand Down Expand Up @@ -1098,14 +1107,24 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
if (handle_kv) {
mmha::vec_from_smem_transpose(k, k_smem, transpose_idx, smem_pitch);

mmha::apply_rotary_embedding(
q, k, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base);
if (params.rotary_cos == nullptr) {
mmha::apply_rotary_embedding(
q, k, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base);
} else {
mmha::apply_rotary_embedding(
q, k, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength - padd_len, params.rotary_cos, params.rotary_sin);
}

mmha::write_smem_transpose(k, k_smem, transpose_idx, smem_pitch);
}
else {
mmha::apply_rotary_embedding(
q, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength, params.rotary_base);
if (params.rotary_cos == nullptr) {
mmha::apply_rotary_embedding(
q, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength, params.rotary_base);
} else {
mmha::apply_rotary_embedding(
q, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength, params.rotary_cos, params.rotary_sin);
}
}
mmha::write_smem_transpose(q, q_smem, transpose_idx, smem_pitch);
}
Expand Down
241 changes: 236 additions & 5 deletions csrc/ft_attention/decoder_masked_multihead_attention_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -1272,10 +1272,10 @@ inline __device__ void zero(T& dst)

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ float2 rotary_embedding_coefficient(const int zid, const int rot_embed_dim, const float t_step, const float base)
inline __device__ float2 rotary_embedding_coefficient(const int zid, const int rot_embed_dim, const int t_step, const float base)
{
const float inv_freq = t_step / pow(base, zid / (float)rot_embed_dim);
return {cos(inv_freq), sin(inv_freq)};
const float pos_idx_inv_freq = t_step / pow(base, zid / (float)rot_embed_dim);
return {cos(pos_idx_inv_freq), sin(pos_idx_inv_freq)};
}

inline __device__ float2 rotary_embedding_transform(const float2 v, const float2 coef)
Expand Down Expand Up @@ -1447,8 +1447,7 @@ inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, int tid, int ro
q = rotary_embedding_transform(q, coef);
}

inline __device__ void
apply_rotary_embedding(__nv_bfloat162& q, __nv_bfloat162& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, __nv_bfloat162& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
{
if (2 * tid >= rot_embed_dim) {
return;
Expand Down Expand Up @@ -1517,6 +1516,238 @@ inline __device__ void apply_rotary_embedding(bf16_8_t& q, bf16_8_t& k, int tid,
}
#endif // ENABLE_BF16

template <typename T>
inline __device__ float2 rotary_embedding_coefficient(const int zid, const int t_step, const T* rotary_cos, const T* rotary_sin)
{
// zid is the index of the dimension (0, 2, 4, ..., rotary_dim).
// rotary_cos/sin stores those at index 0, 1, 2, ..., rotary_dim / 2.
return {float(rotary_cos[zid / 2]), float(rotary_sin[zid / 2])};
}

// fp16 is special because we use uint16_t for reading the data, for backward compatibility.
template <>
inline __device__ float2 rotary_embedding_coefficient<uint16_t>(const int zid, const int t_step, const uint16_t* rotary_cos, const uint16_t* rotary_sin)
{
// zid is the index of the dimension (0, 2, 4, ..., rotary_dim).
// rotary_cos/sin stores those at index 0, 1, 2, ..., rotary_dim / 2.
return {float(reinterpret_cast<const __half*>(rotary_cos)[zid / 2]),
float(reinterpret_cast<const __half*>(rotary_sin)[zid / 2])};
}

inline __device__ void apply_rotary_embedding(float& q, int zid, int rot_embed_dim, int t_step, const float* rotary_cos, const float* rotary_sin)
{
return;
}

inline __device__ void apply_rotary_embedding(float& q, float& k, int zid, int rot_embed_dim, int t_step, const float* rotary_cos, const float* rotary_sin)
{
return;
}

inline __device__ void apply_rotary_embedding(float2& q, int tid, int rot_embed_dim, int t_step, const float* rotary_cos, const float* rotary_sin)
{
if (2 * tid >= rot_embed_dim) {
return;
}
const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
q = rotary_embedding_transform(q, coef);
}

inline __device__ void apply_rotary_embedding(float2& q, float2& k, int tid, int rot_embed_dim, int t_step, const float* rotary_cos, const float* rotary_sin)
{
if (2 * tid >= rot_embed_dim) {
return;
}
const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
q = rotary_embedding_transform(q, coef);
k = rotary_embedding_transform(k, coef);
}

inline __device__ void apply_rotary_embedding(float4& q, int tid, int rot_embed_dim, int t_step, const float* rotary_cos, const float* rotary_sin)
{
if (4 * tid >= rot_embed_dim) {
return;
}

Float4_& q_ = *reinterpret_cast<Float4_*>(&q);
const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
q_.x = rotary_embedding_transform(q_.x, coef0);
const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
q_.y = rotary_embedding_transform(q_.y, coef1);
}

inline __device__ void apply_rotary_embedding(float4& q, float4& k, int tid, int rot_embed_dim, int t_step, const float* rotary_cos, const float* rotary_sin)
{
if (4 * tid >= rot_embed_dim) {
return;
}

Float4_& q_ = *reinterpret_cast<Float4_*>(&q);
Float4_& k_ = *reinterpret_cast<Float4_*>(&k);
const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
q_.x = rotary_embedding_transform(q_.x, coef0);
k_.x = rotary_embedding_transform(k_.x, coef0);
const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
q_.y = rotary_embedding_transform(q_.y, coef1);
k_.y = rotary_embedding_transform(k_.y, coef1);
}

inline __device__ void apply_rotary_embedding(uint32_t& q, int tid, int rot_embed_dim, int t_step, const uint16_t* rotary_cos, const uint16_t* rotary_sin)
{
if (2 * tid >= rot_embed_dim) {
return;
}
const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
q = rotary_embedding_transform(q, coef);
}

inline __device__ void apply_rotary_embedding(uint32_t& q, uint32_t& k, int tid, int rot_embed_dim, int t_step, const uint16_t* rotary_cos, const uint16_t* rotary_sin)
{
if (2 * tid >= rot_embed_dim) {
return;
}
const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
q = rotary_embedding_transform(q, coef);
k = rotary_embedding_transform(k, coef);
}

inline __device__ void apply_rotary_embedding(uint2& q, int tid, int rot_embed_dim, int t_step, const uint16_t* rotary_cos, const uint16_t* rotary_sin)
{
if (4 * tid >= rot_embed_dim) {
return;
}
const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
q.x = rotary_embedding_transform(q.x, coef0);
const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
q.y = rotary_embedding_transform(q.y, coef1);
}

inline __device__ void apply_rotary_embedding(uint2& q, uint2& k, int tid, int rot_embed_dim, int t_step, const uint16_t* rotary_cos, const uint16_t* rotary_sin)
{
if (4 * tid >= rot_embed_dim) {
return;
}
const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
q.x = rotary_embedding_transform(q.x, coef0);
k.x = rotary_embedding_transform(k.x, coef0);
const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
q.y = rotary_embedding_transform(q.y, coef1);
k.y = rotary_embedding_transform(k.y, coef1);
}

inline __device__ void apply_rotary_embedding(uint4& q, int tid, int rot_embed_dim, int t_step, const uint16_t* rotary_cos, const uint16_t* rotary_sin)
{
if (8 * tid >= rot_embed_dim) {
return;
}
const auto coef0 = rotary_embedding_coefficient(8 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
q.x = rotary_embedding_transform(q.x, coef0);
const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
q.y = rotary_embedding_transform(q.y, coef1);
const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
q.z = rotary_embedding_transform(q.z, coef2);
const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
q.w = rotary_embedding_transform(q.w, coef3);
}

inline __device__ void apply_rotary_embedding(uint4& q, uint4& k, int tid, int rot_embed_dim, int t_step, const uint16_t* rotary_cos, const uint16_t* rotary_sin)
{
if (8 * tid >= rot_embed_dim) {
return;
}
const auto coef0 = rotary_embedding_coefficient(8 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
q.x = rotary_embedding_transform(q.x, coef0);
k.x = rotary_embedding_transform(k.x, coef0);
const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
q.y = rotary_embedding_transform(q.y, coef1);
k.y = rotary_embedding_transform(k.y, coef1);
const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
q.z = rotary_embedding_transform(q.z, coef2);
k.z = rotary_embedding_transform(k.z, coef2);
const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
q.w = rotary_embedding_transform(q.w, coef3);
k.w = rotary_embedding_transform(k.w, coef3);
}

#ifdef ENABLE_BF16
inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, int tid, int rot_embed_dim, int t_step, const __nv_bfloat16* rotary_cos, const __nv_bfloat16* rotary_sin)
{
if (2 * tid >= rot_embed_dim) {
return;
}
const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
q = rotary_embedding_transform(q, coef);
}

inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, __nv_bfloat162& k, int tid, int rot_embed_dim, int t_step, const __nv_bfloat16* rotary_cos, const __nv_bfloat16* rotary_sin)
{
if (2 * tid >= rot_embed_dim) {
return;
}
const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
q = rotary_embedding_transform(q, coef);
k = rotary_embedding_transform(k, coef);
}

inline __device__ void apply_rotary_embedding(bf16_4_t& q, int tid, int rot_embed_dim, int t_step, const __nv_bfloat16* rotary_cos, const __nv_bfloat16* rotary_sin)
{
if (4 * tid >= rot_embed_dim) {
return;
}
const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
q.x = rotary_embedding_transform(q.x, coef0);
const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
q.y = rotary_embedding_transform(q.y, coef1);
}

inline __device__ void apply_rotary_embedding(bf16_4_t& q, bf16_4_t& k, int tid, int rot_embed_dim, int t_step, const __nv_bfloat16* rotary_cos, const __nv_bfloat16* rotary_sin)
{
if (4 * tid >= rot_embed_dim) {
return;
}
const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
q.x = rotary_embedding_transform(q.x, coef0);
k.x = rotary_embedding_transform(k.x, coef0);
const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
q.y = rotary_embedding_transform(q.y, coef1);
k.y = rotary_embedding_transform(k.y, coef1);
}

inline __device__ void apply_rotary_embedding(bf16_8_t& q, int tid, int rot_embed_dim, int t_step, const __nv_bfloat16* rotary_cos, const __nv_bfloat16* rotary_sin)
{
if (8 * tid >= rot_embed_dim) {
return;
}
const auto coef0 = rotary_embedding_coefficient(8 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
q.x = rotary_embedding_transform(q.x, coef0);
const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
q.y = rotary_embedding_transform(q.y, coef1);
const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
q.z = rotary_embedding_transform(q.z, coef2);
const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
q.w = rotary_embedding_transform(q.w, coef3);
}

inline __device__ void apply_rotary_embedding(bf16_8_t& q, bf16_8_t& k, int tid, int rot_embed_dim, int t_step, const __nv_bfloat16* rotary_cos, const __nv_bfloat16* rotary_sin)
{
if (8 * tid >= rot_embed_dim) {
return;
}
const auto coef0 = rotary_embedding_coefficient(8 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
q.x = rotary_embedding_transform(q.x, coef0);
k.x = rotary_embedding_transform(k.x, coef0);
const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
q.y = rotary_embedding_transform(q.y, coef1);
k.y = rotary_embedding_transform(k.y, coef1);
const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
q.z = rotary_embedding_transform(q.z, coef2);
k.z = rotary_embedding_transform(k.z, coef2);
const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
q.w = rotary_embedding_transform(q.w, coef3);
k.w = rotary_embedding_transform(k.w, coef3);
}
#endif // ENABLE_BF16

template<typename Vec_T, typename T>
__device__ __inline__ void vec_from_smem_transpose(Vec_T& vec, T* smem, int transpose_idx, int smem_pitch);

Expand Down
Loading

0 comments on commit 62e9814

Please sign in to comment.