diff --git a/csrc/ft_attention/decoder_masked_multihead_attention.cu b/csrc/ft_attention/decoder_masked_multihead_attention.cu
index 5b5966a31..6e5d5a27b 100644
--- a/csrc/ft_attention/decoder_masked_multihead_attention.cu
+++ b/csrc/ft_attention/decoder_masked_multihead_attention.cu
@@ -34,7 +34,7 @@
         if (smem_sz >= 48 * 1024) {                                                                \
             cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_sz);    \
         }                                                                                          \
-        dim3 grid(params.num_heads, params.batch_size);                                            \
+        dim3 grid(params.nnz_head_idx == nullptr ? params.num_heads : params.nnz_heads, params.batch_size); \
         kernel<<<grid, THDS_PER_BLOCK, smem_sz, stream>>>(params)
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/csrc/ft_attention/decoder_masked_multihead_attention.h b/csrc/ft_attention/decoder_masked_multihead_attention.h
index 590b02ccb..e9dcd376b 100644
--- a/csrc/ft_attention/decoder_masked_multihead_attention.h
+++ b/csrc/ft_attention/decoder_masked_multihead_attention.h
@@ -113,6 +113,12 @@ struct Multihead_attention_params_base {
     const float* qkv_scale_out       = nullptr;
     const float* attention_out_scale = nullptr;
     int          int8_mode           = 0;
+
+    const T *rotary_cos = nullptr;
+    const T *rotary_sin = nullptr;
+
+    const int *nnz_head_idx = nullptr;
+    int nnz_heads = 0;
 };
 
 template<typename T>
diff --git a/csrc/ft_attention/decoder_masked_multihead_attention_template.hpp b/csrc/ft_attention/decoder_masked_multihead_attention_template.hpp
index 8da5929da..f99b818f2 100644
--- a/csrc/ft_attention/decoder_masked_multihead_attention_template.hpp
+++ b/csrc/ft_attention/decoder_masked_multihead_attention_template.hpp
@@ -941,7 +941,8 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
     if (params.rotary_embedding_dim > 0 && !params.neox_rotary_style) {
         if (handle_kv) {
-            apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base);
+            if (params.rotary_cos == nullptr) {
+                apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base);
+            } else {
+                apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, tlength - padd_len, params.rotary_cos, params.rotary_sin);
+            }
         } else {
-            apply_rotary_embedding(q, tidx, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base);
+            if (params.rotary_cos == nullptr) {
+                apply_rotary_embedding(q, tidx, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base);
+            } else {
+                apply_rotary_embedding(q, tidx, params.rotary_embedding_dim, tlength - padd_len, params.rotary_cos, params.rotary_sin);
+            }
         }
     }
     else if (params.rotary_embedding_dim > 0 && params.neox_rotary_style) {
@@ -1098,14 +1107,24 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
= rot_embed_dim) { return;
diff --git a/csrc/ft_attention/decoder_masked_multihead_attention_utils.h b/csrc/ft_attention/decoder_masked_multihead_attention_utils.h
--- a/csrc/ft_attention/decoder_masked_multihead_attention_utils.h
+++ b/csrc/ft_attention/decoder_masked_multihead_attention_utils.h
@@ -1517,6 +1516,238 @@ inline __device__ void apply_rotary_embedding(bf16_8_t& q, bf16_8_t& k, int tid,
 }
 #endif // ENABLE_BF16
 
+template<typename T>
+inline __device__ float2 rotary_embedding_coefficient(const int zid, const int t_step, const T* rotary_cos, const T* rotary_sin)
+{
+    // zid is the index of the dimension (0, 2, 4, ..., rotary_dim).
+    // rotary_cos/sin stores those at index 0, 1, 2, ..., rotary_dim / 2.
+    return {float(rotary_cos[zid / 2]), float(rotary_sin[zid / 2])};
+}
+
+// fp16 is special because we use uint16_t for reading the data, for backward compatibility.
+template <>
+inline __device__ float2 rotary_embedding_coefficient(const int zid, const int t_step, const uint16_t* rotary_cos, const uint16_t* rotary_sin)
+{
+    // zid is the index of the dimension (0, 2, 4, ..., rotary_dim).
+    // rotary_cos/sin stores those at index 0, 1, 2, ..., rotary_dim / 2.
+    return {float(reinterpret_cast<const half*>(rotary_cos)[zid / 2]),
+            float(reinterpret_cast<const half*>(rotary_sin)[zid / 2])};
+}
+
+inline __device__ void apply_rotary_embedding(float& q, int zid, int rot_embed_dim, int t_step, const float* rotary_cos, const float* rotary_sin)
+{
+    return;
+}
+
+inline __device__ void apply_rotary_embedding(float& q, float& k, int zid, int rot_embed_dim, int t_step, const float* rotary_cos, const float* rotary_sin)
+{
+    return;
+}
+
+inline __device__ void apply_rotary_embedding(float2& q, int tid, int rot_embed_dim, int t_step, const float* rotary_cos, const float* rotary_sin)
+{
+    if (2 * tid >= rot_embed_dim) {
+        return;
+    }
+    const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    q = rotary_embedding_transform(q, coef);
+}
+
+inline __device__ void apply_rotary_embedding(float2& q, float2& k, int tid, int rot_embed_dim, int t_step, const float* rotary_cos, const float* rotary_sin)
+{
+    if (2 * tid >= rot_embed_dim) {
+        return;
+    }
+    const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    q = rotary_embedding_transform(q, coef);
+    k = rotary_embedding_transform(k, coef);
+}
+
+inline __device__ void apply_rotary_embedding(float4& q, int tid, int rot_embed_dim, int t_step, const float* rotary_cos, const float* rotary_sin)
+{
+    if (4 * tid >= rot_embed_dim) {
+        return;
+    }
+
+    Float4_& q_ = *reinterpret_cast<Float4_*>(&q);
+    const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    q_.x = rotary_embedding_transform(q_.x, coef0);
+    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    q_.y = rotary_embedding_transform(q_.y, coef1);
+}
+
+inline __device__ void apply_rotary_embedding(float4& q, float4& k, int tid, int rot_embed_dim, int t_step, const float* rotary_cos, const float* rotary_sin)
+{
+    if (4 * tid >= rot_embed_dim) {
+        return;
+    }
+
+    Float4_& q_ = *reinterpret_cast<Float4_*>(&q);
+    Float4_& k_ = *reinterpret_cast<Float4_*>(&k);
+    const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    q_.x = rotary_embedding_transform(q_.x, coef0);
+    k_.x = rotary_embedding_transform(k_.x, coef0);
+    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    q_.y = rotary_embedding_transform(q_.y, coef1);
+    k_.y = rotary_embedding_transform(k_.y, coef1);
+}
+
+inline __device__ void apply_rotary_embedding(uint32_t& q, int tid, int rot_embed_dim, int t_step, const uint16_t* rotary_cos, const uint16_t* rotary_sin)
+{
+    if (2 * tid >= rot_embed_dim) {
+        return;
+    }
+    const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    q = rotary_embedding_transform(q, coef);
+}
+
+inline __device__ void apply_rotary_embedding(uint32_t& q, uint32_t& k, int tid, int rot_embed_dim, int t_step, const uint16_t* rotary_cos, const uint16_t* rotary_sin)
+{
+    if (2 * tid >= rot_embed_dim) {
+        return;
+    }
+    const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    q = rotary_embedding_transform(q, coef);
+    k = rotary_embedding_transform(k, coef);
+}
+
+inline __device__ void apply_rotary_embedding(uint2& q, int tid, int rot_embed_dim, int t_step, const uint16_t* rotary_cos, const uint16_t* rotary_sin)
+{
+    if (4 * tid >= rot_embed_dim) {
+        return;
+    }
+    const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    q.x = rotary_embedding_transform(q.x, coef0);
+    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    q.y = rotary_embedding_transform(q.y, coef1);
+}
+
+inline __device__ void apply_rotary_embedding(uint2& q, uint2& k, int tid, int rot_embed_dim, int t_step, const uint16_t* rotary_cos, const uint16_t* rotary_sin)
+{
+    if (4 * tid >= rot_embed_dim) {
+        return;
+    }
+    const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    q.x = rotary_embedding_transform(q.x, coef0);
+    k.x = rotary_embedding_transform(k.x, coef0);
+    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    q.y = rotary_embedding_transform(q.y, coef1);
+    k.y = rotary_embedding_transform(k.y, coef1);
+}
+
+inline __device__ void apply_rotary_embedding(uint4& q, int tid, int rot_embed_dim, int t_step, const uint16_t* rotary_cos, const uint16_t* rotary_sin)
+{
+    if (8 * tid >= rot_embed_dim) {
+        return;
+    }
+    const auto coef0 = rotary_embedding_coefficient(8 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    q.x = rotary_embedding_transform(q.x, coef0);
+    const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    q.y = rotary_embedding_transform(q.y, coef1);
+    const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    q.z = rotary_embedding_transform(q.z, coef2);
+    const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    q.w = rotary_embedding_transform(q.w, coef3);
+}
+
+inline __device__ void apply_rotary_embedding(uint4& q, uint4& k, int tid, int rot_embed_dim, int t_step, const uint16_t* rotary_cos, const uint16_t* rotary_sin)
+{
+    if (8 * tid >= rot_embed_dim) {
+        return;
+    }
+    const auto coef0 = rotary_embedding_coefficient(8 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    q.x = rotary_embedding_transform(q.x, coef0);
+    k.x = rotary_embedding_transform(k.x, coef0);
+    const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    q.y = rotary_embedding_transform(q.y, coef1);
+    k.y = rotary_embedding_transform(k.y, coef1);
+    const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    q.z = rotary_embedding_transform(q.z, coef2);
+    k.z = rotary_embedding_transform(k.z, coef2);
+    const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    q.w = rotary_embedding_transform(q.w, coef3);
+    k.w = rotary_embedding_transform(k.w, coef3);
+}
+
+#ifdef ENABLE_BF16
+inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, int tid, int rot_embed_dim, int t_step, const __nv_bfloat16* rotary_cos, const __nv_bfloat16* rotary_sin)
+{
+    if (2 * tid >= rot_embed_dim) {
+        return;
+    }
+    const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    q = rotary_embedding_transform(q, coef);
+}
+
+inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, __nv_bfloat162& k, int tid, int rot_embed_dim, int t_step, const __nv_bfloat16* rotary_cos, const __nv_bfloat16* rotary_sin)
+{
+    if (2 * tid >= rot_embed_dim) {
+        return;
+    }
+    const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    q = rotary_embedding_transform(q, coef);
+    k = rotary_embedding_transform(k, coef);
+}
+
+inline __device__ void apply_rotary_embedding(bf16_4_t& q, int tid, int rot_embed_dim, int t_step, const __nv_bfloat16* rotary_cos, const __nv_bfloat16* rotary_sin)
+{
+    if (4 * tid >= rot_embed_dim) {
+        return;
+    }
+    const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    q.x = rotary_embedding_transform(q.x, coef0);
+    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    q.y = rotary_embedding_transform(q.y, coef1);
+}
+
+inline __device__ void apply_rotary_embedding(bf16_4_t& q, bf16_4_t& k, int tid, int rot_embed_dim, int t_step, const __nv_bfloat16* rotary_cos, const __nv_bfloat16* rotary_sin)
+{
+    if (4 * tid >= rot_embed_dim) {
+        return;
+    }
+    const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    q.x = rotary_embedding_transform(q.x, coef0);
+    k.x = rotary_embedding_transform(k.x, coef0);
+    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    q.y = rotary_embedding_transform(q.y, coef1);
+    k.y = rotary_embedding_transform(k.y, coef1);
+}
+
+inline __device__ void apply_rotary_embedding(bf16_8_t& q, int tid, int rot_embed_dim, int t_step, const __nv_bfloat16* rotary_cos, const __nv_bfloat16* rotary_sin)
+{
+    if (8 * tid >= rot_embed_dim) {
+        return;
+    }
+    const auto coef0 = rotary_embedding_coefficient(8 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    q.x = rotary_embedding_transform(q.x, coef0);
+    const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    q.y = rotary_embedding_transform(q.y, coef1);
+    const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    q.z = rotary_embedding_transform(q.z, coef2);
+    const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    q.w = rotary_embedding_transform(q.w, coef3);
+}
+
+inline __device__ void apply_rotary_embedding(bf16_8_t& q, bf16_8_t& k, int tid, int rot_embed_dim, int t_step, const __nv_bfloat16* rotary_cos, const __nv_bfloat16* rotary_sin)
+{
+    if (8 * tid >= rot_embed_dim) {
+        return;
+    }
+    const auto coef0 = rotary_embedding_coefficient(8 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    q.x = rotary_embedding_transform(q.x, coef0);
+    k.x = rotary_embedding_transform(k.x, coef0);
+    const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    q.y = rotary_embedding_transform(q.y, coef1);
+    k.y = rotary_embedding_transform(k.y, coef1);
+    const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    q.z = rotary_embedding_transform(q.z, coef2);
+    k.z = rotary_embedding_transform(k.z, coef2);
+    const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    q.w = rotary_embedding_transform(q.w, coef3);
+    k.w = rotary_embedding_transform(k.w, coef3);
+}
+#endif // ENABLE_BF16
+
 template<typename Vec_T, typename T>
 __device__ __inline__ void vec_from_smem_transpose(Vec_T& vec, T* smem, int transpose_idx, int smem_pitch);
diff --git a/csrc/ft_attention/ft_attention.cpp b/csrc/ft_attention/ft_attention.cpp
index de19a4da1..3b6ad55ce 100644
--- a/csrc/ft_attention/ft_attention.cpp
+++ b/csrc/ft_attention/ft_attention.cpp
@@ -57,13 +57,17 @@ void set_params(Masked_multihead_attention_params<T> &params,
                 const float rotary_base,
                 const bool neox_rotary_style,
                 const int qkv_batch_stride,
+                const int nnz_heads,
                 T *q_ptr,
                 T *k_ptr,
                 T *v_ptr,
                 T *k_cache_ptr,
                 T *v_cache_ptr,
                 int *length_per_sample,
-                T *out_ptr) {
+                T *rotary_cos,
+                T *rotary_sin,
+                T *out_ptr,
+                int *nnz_head_idx) {
     // Reset the parameters
     memset(&params, 0, sizeof(params));
     params.q = q_ptr;
@@ -81,6 +85,7 @@ void set_params(Masked_multihead_attention_params<T> &params,
     params.beam_width = 1;
     params.memory_max_len = memory_max_seqlen;
     params.num_heads = nheads;
+    params.nnz_heads = nnz_heads;
     params.hidden_size_per_head = headdim;
     params.rotary_embedding_dim = rotary_embedding_dim;
     params.rotary_base = rotary_base;
@@ -99,6 +104,9 @@ void set_params(Masked_multihead_attention_params<T> &params,
     params.finished = nullptr;
     params.memory_length_per_sample = nullptr;
     params.length_per_sample = length_per_sample;
+    params.rotary_cos = rotary_cos;
+    params.rotary_sin = rotary_sin;
+    params.nnz_head_idx = nnz_head_idx;
 }
 
 torch::Tensor single_query_attention(const torch::Tensor q,
@@ -107,8 +115,11 @@ torch::Tensor single_query_attention(const torch::Tensor q,
                                      torch::Tensor k_cache,
                                      torch::Tensor v_cache,
                                      c10::optional<torch::Tensor> length_per_sample_,
+                                     c10::optional<torch::Tensor> rotary_cos_,
+                                     c10::optional<torch::Tensor> rotary_sin_,
+                                     c10::optional<torch::Tensor> nnz_head_idx_,
                                      const int timestep,
-                                     const int rotary_embedding_dim = 0,
+                                     int rotary_embedding_dim = 0,
                                      const float rotary_base = 10000.0f,
                                      const bool neox_rotary_style=true) {
     CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); CHECK_DEVICE(k_cache); CHECK_DEVICE(v_cache);
@@ -116,6 +127,9 @@ torch::Tensor single_query_attention(const torch::Tensor q,
     int nheads = v_cache.size(1);
     int memory_max_seqlen = v_cache.size(2);
     int headdim = v_cache.size(3);
+    auto input_type = q.scalar_type();
+    TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
+
     CHECK_SHAPE(q, batch_size, nheads, headdim);
     CHECK_SHAPE(k, batch_size, nheads, headdim);
     CHECK_SHAPE(v, batch_size, nheads, headdim);
@@ -129,6 +143,12 @@ torch::Tensor single_query_attention(const torch::Tensor q,
     TORCH_CHECK(q.stride(0) == k.stride(0) && q.stride(0) == v.stride(0));
     CHECK_CONTIGUOUS(v_cache); CHECK_CONTIGUOUS(k_cache);
+    TORCH_CHECK(q.scalar_type() == input_type);
+    TORCH_CHECK(k.scalar_type() == input_type);
+    TORCH_CHECK(v.scalar_type() == input_type);
+    TORCH_CHECK(k_cache.scalar_type() == input_type);
+    TORCH_CHECK(v_cache.scalar_type() == input_type);
+
     if (length_per_sample_.has_value()) {
         auto length_per_sample = length_per_sample_.value();
         CHECK_DEVICE(length_per_sample);
@@ -137,6 +157,32 @@ torch::Tensor single_query_attention(const torch::Tensor q,
         TORCH_CHECK(length_per_sample.dtype() == torch::kInt32);
     }
 
+    if (rotary_cos_.has_value()) {
+        auto rotary_cos = rotary_cos_.value();
+        CHECK_DEVICE(rotary_cos);
+        int rotary_seqlen = rotary_cos.size(0);
+        rotary_embedding_dim = rotary_cos.size(1) * 2;
+        CHECK_SHAPE(rotary_cos, rotary_seqlen, rotary_embedding_dim / 2);
+        CHECK_CONTIGUOUS(rotary_cos);
+        TORCH_CHECK(rotary_cos.scalar_type() == input_type);
+
+        TORCH_CHECK(rotary_sin_.has_value());
+        auto rotary_sin = rotary_sin_.value();
+        CHECK_DEVICE(rotary_sin);
+        CHECK_SHAPE(rotary_sin, rotary_seqlen, rotary_embedding_dim / 2);
+        CHECK_CONTIGUOUS(rotary_sin);
+        TORCH_CHECK(rotary_sin.scalar_type() == input_type);
+    }
+
+    if (nnz_head_idx_.has_value()) {
+        auto nnz_head_idx = nnz_head_idx_.value();
+        CHECK_DEVICE(nnz_head_idx);
+        int nnz_heads = nnz_head_idx.size(0);
+        CHECK_SHAPE(nnz_head_idx, nnz_heads);
+        CHECK_CONTIGUOUS(nnz_head_idx);
+        TORCH_CHECK(nnz_head_idx.dtype() == torch::kInt32);
+    }
+
     // Otherwise the kernel will be launched from cuda:0 device
     // Cast to char to avoid compiler warning about narrowing
     at::cuda::CUDAGuard device_guard{(char)q.get_device()};
@@ -148,6 +194,7 @@ torch::Tensor single_query_attention(const torch::Tensor q,
         Masked_multihead_attention_params<DataType> params;
         set_params(params, batch_size, nheads, memory_max_seqlen, headdim, timestep,
                    rotary_embedding_dim, rotary_base, neox_rotary_style, q.stride(0),
+                   nnz_head_idx_.has_value() ? nnz_head_idx_.value().size(0) : 0,
                    reinterpret_cast<DataType*>(q.data_ptr()),
                    reinterpret_cast<DataType*>(k.data_ptr()),
                    reinterpret_cast<DataType*>(v.data_ptr()),
@@ -155,7 +202,13 @@ torch::Tensor single_query_attention(const torch::Tensor q,
                    reinterpret_cast<DataType*>(v_cache.data_ptr()),
                    length_per_sample_.has_value()
                        ? length_per_sample_.value().data_ptr<int>() : nullptr,
-                   reinterpret_cast<DataType*>(out.data_ptr()));
+                   rotary_cos_.has_value()
+                       ? reinterpret_cast<DataType*>(rotary_cos_.value().data_ptr()) : nullptr,
+                   rotary_sin_.has_value()
+                       ? reinterpret_cast<DataType*>(rotary_sin_.value().data_ptr()) : nullptr,
+                   reinterpret_cast<DataType*>(out.data_ptr()),
+                   nnz_head_idx_.has_value() ? nnz_head_idx_.value().data_ptr<int>() : nullptr
+                   );
         auto stream = at::cuda::getCurrentCUDAStream();
         masked_multihead_attention(params, stream);
     });
@@ -165,6 +218,8 @@ torch::Tensor single_query_attention(const torch::Tensor q,
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     m.def("single_query_attention", &single_query_attention, "Attention with a single query",
           py::arg("q"), py::arg("k"), py::arg("v"), py::arg("k_cache"), py::arg("v_cache"),
-          py::arg("length_per_sample_"), py::arg("timestep"), py::arg("rotary_embedding_dim")=0,
+          py::arg("length_per_sample_"), py::arg("rotary_cos_"),
+          py::arg("rotary_sin_"), py::arg("nnz_head_idx_"),
+          py::arg("timestep"), py::arg("rotary_embedding_dim")=0,
           py::arg("rotary_base")=10000.0f, py::arg("neox_rotary_style")=true);
 }
diff --git a/flash_attn/layers/rotary.py b/flash_attn/layers/rotary.py
index 44ceab5e3..0cb56328d 100644
--- a/flash_attn/layers/rotary.py
+++ b/flash_attn/layers/rotary.py
@@ -169,16 +169,28 @@ class RotaryEmbedding(torch.nn.Module):
     Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
     """
 
-    def __init__(self, dim: int, base=10000.0, interleaved=False, scale_base=None, device=None):
+    def __init__(self, dim: int, base=10000.0, interleaved=False, scale_base=None,
+                 pos_idx_in_fp32=True, device=None):
         """
             interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
                 of 1st half and 2nd half (GPT-NeoX style).
+            pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32,
+                otherwise they might be in lower precision.
+                This option was added because previously (before 2023-07-02), when we construct
+                the position indices, we use the dtype of self.inv_freq. In most cases this would
+                be fp32, but if the model is trained in pure bf16 (not mixed precision), then
+                self.inv_freq would be bf16, and the position indices are also in bf16.
+                Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the
+                embeddings for some positions will coincide.
+                To maintain compatibility with models previously trained in pure bf16,
+                we add this option.
""" super().__init__() + self.dim = dim self.base = float(base) + self.pos_idx_in_fp32 = pos_idx_in_fp32 # Generate and save the inverse frequency buffer (non trainable) - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, - dtype=torch.float32) / dim)) + inv_freq = self._compute_inv_freq(device) self.register_buffer("inv_freq", inv_freq) self.interleaved = interleaved self.scale_base = scale_base @@ -192,31 +204,48 @@ def __init__(self, dim: int, base=10000.0, interleaved=False, scale_base=None, d self._cos_k_cached = None self._sin_k_cached = None - def _update_cos_sin_cache(self, x, seqlen_offset=0): - """x: (batch, seqlen, nheads, headdim) or (batch, seqlen, 3, nheads, headdim) - """ - seqlen = x.shape[1] + seqlen_offset + def _compute_inv_freq(self, device=None): + return 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device, + dtype=torch.float32) / self.dim)) + + + def _update_cos_sin_cache(self, seqlen, device=None, dtype=None): # Reset the tables if the sequence length has changed, # or if we're on a new device (possibly due to tracing for instance) - if (seqlen > self._seq_len_cached or self._cos_cached.device != x.device - or self._cos_cached.dtype != x.dtype): + if (seqlen > self._seq_len_cached or self._cos_cached.device != device + or self._cos_cached.dtype != dtype): self._seq_len_cached = seqlen - t = torch.arange(seqlen, device=x.device, dtype=self.inv_freq.dtype) - # Don't do einsum, it converts fp32 to fp16 + # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16 + # And the output of arange can be quite large, so bf16 would lose a lot of precision. + # However, for compatibility reason, we add an option to use the dtype of self.inv_freq. + if self.pos_idx_in_fp32: + t = torch.arange(seqlen, device=device, dtype=torch.float32) + # We want fp32 here as well since inv_freq will be multiplied with t, and the output + # will be large. Having it in bf16 will lose a lot of precision and cause the + # cos & sin output to change significantly. 
+                # We want to recompute self.inv_freq if it was not loaded in fp32
+                if self.inv_freq.dtype != torch.float32:
+                    inv_freq = self._compute_inv_freq(device=device)
+                else:
+                    inv_freq = self.inv_freq
+            else:
+                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
+                inv_freq = self.inv_freq
+            # Don't do einsum, it converts fp32 to fp16 under AMP
             # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
-            freqs = torch.outer(t, self.inv_freq.to(device=t.device))
+            freqs = torch.outer(t, inv_freq)
             if self.scale is None:
-                self._cos_cached = torch.cos(freqs).to(x.dtype)
-                self._sin_cached = torch.sin(freqs).to(x.dtype)
+                self._cos_cached = torch.cos(freqs).to(dtype)
+                self._sin_cached = torch.sin(freqs).to(dtype)
             else:
                 power = ((torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
                           - seqlen // 2) / self.scale_base)
                 scale = self.scale.to(device=power.device) ** rearrange(power, 's -> s 1')
                 # We want the multiplication by scale to happen in fp32
-                self._cos_cached = (torch.cos(freqs) * scale).to(x.dtype)
-                self._sin_cached = (torch.sin(freqs) * scale).to(x.dtype)
-                self._cos_k_cached = (torch.cos(freqs) / scale).to(x.dtype)
-                self._sin_k_cached = (torch.sin(freqs) / scale).to(x.dtype)
+                self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
+                self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
+                self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
+                self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
 
     def forward(self, qkv: torch.Tensor, seqlen_offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
         """
@@ -224,7 +253,7 @@ def forward(self, qkv: torch.Tensor, seqlen_offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
             seqlen_offset: can be used in generation where the qkv being passed in is only the last
             token in the batch.
         """
-        self._update_cos_sin_cache(qkv, seqlen_offset)
+        self._update_cos_sin_cache(qkv.shape[1] + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
         if self.scale is None:
             return apply_rotary_emb_qkv_(
                 qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
diff --git a/flash_attn/modules/mha.py b/flash_attn/modules/mha.py
index 9d6866833..684935dac 100644
--- a/flash_attn/modules/mha.py
+++ b/flash_attn/modules/mha.py
@@ -515,8 +515,13 @@ def forward(self, x, x_kv=None, key_padding_mask=None, cu_seqlens=None, max_seqlen=None,
             rotary_emb_base = self.rotary_emb.base if self.rotary_emb_dim > 0 else 0
             context = ft_attention.single_query_attention(
                 *rearrange(qkv, 'b 1 three h d -> b three h d').unbind(dim=1),
-                k_cache[batch_start:batch_end], v_cache[batch_start:batch_end],
-                lengths_per_sample, inference_params.sequence_len_offset,
+                k_cache[batch_start:batch_end],
+                v_cache[batch_start:batch_end],
+                lengths_per_sample,
+                None,  # rotary_cos_
+                None,  # rotary_sin_
+                None,  # nnz_head_idx
+                inference_params.sequence_len_offset,
                 self.rotary_emb_dim, rotary_emb_base,
                 # neox_rotary_style
                 (not self.rotary_emb.interleaved) if self.rotary_emb_dim > 0 else True
@@ -637,8 +642,13 @@ def forward(self, x, seqlen=None, inference_params=None, **kwargs):
             rotary_emb_base = self.rotary_emb.base if self.rotary_emb_dim > 0 else 0
             context = ft_attention.single_query_attention(
                 *rearrange(qkv, 'b 1 three h d -> b three h d').unbind(dim=1),
-                k_cache[batch_start:batch_end], v_cache[batch_start:batch_end],
-                lengths_per_sample, inference_params.sequence_len_offset,
+                k_cache[batch_start:batch_end],
+                v_cache[batch_start:batch_end],
+                lengths_per_sample,
+                None,  # rotary_cos_
+                None,  # rotary_sin_
+                None,  # nnz_head_idx
+                inference_params.sequence_len_offset,
                 self.rotary_emb_dim, rotary_emb_base,
                 # neox_rotary_style
                 (not self.rotary_emb.interleaved) if self.rotary_emb_dim > 0 else True
diff --git a/tests/models/test_llama.py b/tests/models/test_llama.py
index f798122fd..7f99ef198 100644
--- a/tests/models/test_llama.py
+++ b/tests/models/test_llama.py
@@ -267,10 +267,11 @@ def test_llama_generation(model_name):
 
     del model
 
     hf_error = (logits_hf - logits_ref).abs().max().item()
-    assert (logits_parallel - logits_ref).abs().max().item() < 2 * hf_error
     print(f'HF fp16 logits max diff: {hf_error}')
     print(f'Logits max diff: {(logits - logits_ref).abs().max().item() }')
-    assert (logits - logits_ref).abs().max().item() < 2 * hf_error
     print(f'Logits CG max diff: {(logits_cg - logits_ref).abs().max().item() }')
+
+    assert (logits_parallel - logits_ref).abs().max().item() < 2 * hf_error
+    assert (logits - logits_ref).abs().max().item() < 2 * hf_error
     assert torch.equal(logits_cg, logits)
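
Notes (reviewer additions, not part of the patch). The new cos/sin path does not change the rotation itself: the non-neox branch of the kernel still applies the GPT-J-style interleaved rotation, and `rotary_embedding_coefficient` now merely looks the angle up at row `t_step` of the `(seqlen, rotary_dim / 2)` tables instead of recomputing it from `rotary_base`. A minimal PyTorch sketch of what `rotary_embedding_transform` does to one head slice; `rotary_interleaved_ref` is a hypothetical name used only for illustration:

```python
import torch

def rotary_interleaved_ref(x, rotary_cos, rotary_sin, t_step):
    """x: (rotary_dim,) slice of one head; rotary_cos/rotary_sin: (seqlen, rotary_dim // 2).

    Each pair (x[2i], x[2i+1]) is rotated by the angle whose cos/sin sit at
    row t_step, column i -- mirroring rotary_embedding_coefficient +
    rotary_embedding_transform in the CUDA code.
    """
    cos = rotary_cos[t_step].float()
    sin = rotary_sin[t_step].float()
    x1, x2 = x[0::2].float(), x[1::2].float()
    out = torch.empty_like(x)
    out[0::2] = (x1 * cos - x2 * sin).to(x.dtype)
    out[1::2] = (x1 * sin + x2 * cos).to(x.dtype)
    return out
```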
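The extension's calling convention changes as well: `single_query_attention` now takes `rotary_cos_`, `rotary_sin_`, and `nnz_head_idx_` between `length_per_sample_` and `timestep`, so existing callers pass three `None`s, as the `mha.py` hunks above do. When the tables are provided, `rotary_embedding_dim` is inferred as `rotary_cos.size(1) * 2` (which is why that parameter loses its `const`), and a non-null `nnz_head_idx` shrinks the launch grid to `nnz_heads` so only the listed heads are computed. A hedged usage sketch; cache setup is elided, and `build_rotary_tables`/`decode_step` are illustrative helpers, not part of the patch:

```python
import torch
import ft_attention  # the extension compiled from csrc/ft_attention

def build_rotary_tables(rotary_dim, seqlen, base=10000.0, dtype=torch.float16, device="cuda"):
    """Precompute cos/sin tables: one row per position, rotary_dim // 2 angles per row."""
    inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2,
                                            dtype=torch.float32, device=device) / rotary_dim))
    freqs = torch.outer(torch.arange(seqlen, dtype=torch.float32, device=device), inv_freq)
    return freqs.cos().to(dtype), freqs.sin().to(dtype)

def decode_step(q, k, v, k_cache, v_cache, timestep, rotary_cos, rotary_sin):
    # Argument order follows the pybind11 definition in ft_attention.cpp:
    # the three new optional tensors sit between length_per_sample_ and timestep.
    return ft_attention.single_query_attention(
        q, k, v, k_cache, v_cache,
        None,        # length_per_sample_
        rotary_cos,  # rotary_cos_ (or None to fall back to the rotary_base path)
        rotary_sin,  # rotary_sin_ (must accompany rotary_cos)
        None,        # nnz_head_idx_: optional int32 (nnz_heads,) head subset
        timestep)
```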
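On the `pos_idx_in_fp32` option in `rotary.py`: the docstring's bf16 rounding claim is easy to check empirically. A small demonstration, assuming only PyTorch, that adjacent integer positions collapse to the same bf16 value, which is what made cos/sin rows coincide for models trained in pure bf16:

```python
import torch

# Positions 0..1999 as bf16 (what the old code produced when inv_freq was bf16)
# versus fp32 (what pos_idx_in_fp32=True now forces).
t_bf16 = torch.arange(2000, dtype=torch.bfloat16)
t_fp32 = torch.arange(2000, dtype=torch.float32)

# bf16 keeps only ~8 bits of mantissa, so large neighboring integers round to
# the same representable value; their rotary angles are then identical.
print("bf16 adjacent collisions:", (t_bf16[1:] == t_bf16[:-1]).sum().item())  # > 0
print("fp32 adjacent collisions:", (t_fp32[1:] == t_fp32[:-1]).sum().item())  # 0
```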