diff --git a/flash-attn/flash_attn_2_fwd_f16_mma_m16n8k16.cu b/flash-attn/flash_attn_2_fwd_f16_mma_m16n8k16.cu index dde683f..d3709dc 100644 --- a/flash-attn/flash_attn_2_fwd_f16_mma_m16n8k16.cu +++ b/flash-attn/flash_attn_2_fwd_f16_mma_m16n8k16.cu @@ -14,6 +14,7 @@ #define FLOAT4(value) (reinterpret_cast(&(value))[0]) #define HALF2(value) (reinterpret_cast(&(value))[0]) #define BFLOAT2(value) (reinterpret_cast<__nv_bfloat162*>(&(value))[0]) +#define LDST128BITS(value) (reinterpret_cast(&(value))[0]) // Load matrix to REGISTER #define LDMATRIX_X4(R0, R1, R2, R3, addr) \ @@ -62,7 +63,7 @@ __global__ void flash_attn_2_fwd_f16_mma_m16n8k16_kernel( int new_dim_x = dim_x % 16; int new_dim_y = (dim_y / 16 * (d / 16) * 16) + (dim_x / 16 * 16) + (dim_y % 16); - FLOAT4(Qi[new_dim_y * 16 + new_dim_x]) = FLOAT4(Q[qkv_offset + (i * tile_size) + x]); + LDST128BITS(Qi[new_dim_y * 16 + new_dim_x]) = LDST128BITS(Q[qkv_offset + (i * tile_size) + x]); } __syncthreads(); @@ -92,7 +93,7 @@ __global__ void flash_attn_2_fwd_f16_mma_m16n8k16_kernel( int new_dim_x = dim_x % 16; int new_dim_y = (dim_y / 16 * (d / 16) * 16) + (dim_x / 16 * 16) + (dim_y % 16); - FLOAT4(Kj[new_dim_y * 16 + new_dim_x]) = FLOAT4(K[qkv_offset + (j * tile_size) + x]); + LDST128BITS(Kj[new_dim_y * 16 + new_dim_x]) = LDST128BITS(K[qkv_offset + (j * tile_size) + x]); } __syncthreads(); @@ -124,7 +125,7 @@ __global__ void flash_attn_2_fwd_f16_mma_m16n8k16_kernel( // Read V from global memory to shared memory for (int x = threadIdx.x * 8; x < tile_size; x += 1024) { - FLOAT4(reg[0]) = FLOAT4(V[qkv_offset + (j * tile_size) + x]); + LDST128BITS(reg[0]) = LDST128BITS(V[qkv_offset + (j * tile_size) + x]); int dim_x = x % d; int dim_y = x / d; @@ -142,10 +143,10 @@ __global__ void flash_attn_2_fwd_f16_mma_m16n8k16_kernel( // adapt from https://github.com/jundaf2/INT8-Flash-Attention-FMHA-Quantization/blob/main/inc/fmha_i8.cuh // Softmax phase (m, l calculate) // FETCHING REGISTER - FLOAT4(reg[0]) = FLOAT4(RC[0][0]); - FLOAT4(reg[8]) = FLOAT4(RC[2][0]); - FLOAT4(reg[16]) = FLOAT4(RC[4][0]); - FLOAT4(reg[24]) = FLOAT4(RC[6][0]); + LDST128BITS(reg[0]) = LDST128BITS(RC[0][0]); + LDST128BITS(reg[8]) = LDST128BITS(RC[2][0]); + LDST128BITS(reg[16]) = LDST128BITS(RC[4][0]); + LDST128BITS(reg[24]) = LDST128BITS(RC[6][0]); // thread level reduce max #pragma unroll @@ -197,10 +198,10 @@ __global__ void flash_attn_2_fwd_f16_mma_m16n8k16_kernel( } // FETCHING REGISTER for P - FLOAT4(RC[0][0]) = FLOAT4(reg[0]); - FLOAT4(RC[2][0]) = FLOAT4(reg[8]); - FLOAT4(RC[4][0]) = FLOAT4(reg[16]); - FLOAT4(RC[6][0]) = FLOAT4(reg[24]); + LDST128BITS(RC[0][0]) = LDST128BITS(reg[0]); + LDST128BITS(RC[2][0]) = LDST128BITS(reg[8]); + LDST128BITS(RC[4][0]) = LDST128BITS(reg[16]); + LDST128BITS(RC[6][0]) = LDST128BITS(reg[24]); // P @ V for (int k = 0; k < d / 16; k++) { @@ -220,7 +221,7 @@ __global__ void flash_attn_2_fwd_f16_mma_m16n8k16_kernel( RD[2], RD[3]); } - FLOAT4(reg[0]) = FLOAT4(RD[0]); + LDST128BITS(reg[0]) = LDST128BITS(RD[0]); #pragma unroll for(int tc_yi = 0; tc_yi < 2; tc_yi++) { float thread_max_new = max(thread_max_old[tc_yi], thread_max[tc_yi]);