[FlashAttention] replace FLOAT4 with LDST128BITS macro (#41)
DefTruth authored Sep 21, 2024
1 parent 068e6fe commit 4be041f
Showing 1 changed file with 13 additions and 12 deletions.
25 changes: 13 additions & 12 deletions flash-attn/flash_attn_2_fwd_f16_mma_m16n8k16.cu
@@ -14,6 +14,7 @@
 #define FLOAT4(value) (reinterpret_cast<float4*>(&(value))[0])
 #define HALF2(value) (reinterpret_cast<half2*>(&(value))[0])
 #define BFLOAT2(value) (reinterpret_cast<__nv_bfloat162*>(&(value))[0])
+#define LDST128BITS(value) (reinterpret_cast<float4*>(&(value))[0])
 
 // Load matrix to REGISTER
 #define LDMATRIX_X4(R0, R1, R2, R3, addr) \
@@ -62,7 +63,7 @@ __global__ void flash_attn_2_fwd_f16_mma_m16n8k16_kernel(
     int new_dim_x = dim_x % 16;
     int new_dim_y = (dim_y / 16 * (d / 16) * 16) + (dim_x / 16 * 16) + (dim_y % 16);
 
-    FLOAT4(Qi[new_dim_y * 16 + new_dim_x]) = FLOAT4(Q[qkv_offset + (i * tile_size) + x]);
+    LDST128BITS(Qi[new_dim_y * 16 + new_dim_x]) = LDST128BITS(Q[qkv_offset + (i * tile_size) + x]);
   }
   __syncthreads();
@@ -92,7 +93,7 @@ __global__ void flash_attn_2_fwd_f16_mma_m16n8k16_kernel(
     int new_dim_x = dim_x % 16;
     int new_dim_y = (dim_y / 16 * (d / 16) * 16) + (dim_x / 16 * 16) + (dim_y % 16);
 
-    FLOAT4(Kj[new_dim_y * 16 + new_dim_x]) = FLOAT4(K[qkv_offset + (j * tile_size) + x]);
+    LDST128BITS(Kj[new_dim_y * 16 + new_dim_x]) = LDST128BITS(K[qkv_offset + (j * tile_size) + x]);
   }
   __syncthreads();

@@ -124,7 +125,7 @@ __global__ void flash_attn_2_fwd_f16_mma_m16n8k16_kernel(
 
   // Read V from global memory to shared memory
   for (int x = threadIdx.x * 8; x < tile_size; x += 1024) {
-    FLOAT4(reg[0]) = FLOAT4(V[qkv_offset + (j * tile_size) + x]);
+    LDST128BITS(reg[0]) = LDST128BITS(V[qkv_offset + (j * tile_size) + x]);
 
     int dim_x = x % d;
     int dim_y = x / d;
@@ -142,10 +143,10 @@ __global__ void flash_attn_2_fwd_f16_mma_m16n8k16_kernel(
   // adapt from https://github.com/jundaf2/INT8-Flash-Attention-FMHA-Quantization/blob/main/inc/fmha_i8.cuh
   // Softmax phase (m, l calculate)
   // FETCHING REGISTER
-  FLOAT4(reg[0]) = FLOAT4(RC[0][0]);
-  FLOAT4(reg[8]) = FLOAT4(RC[2][0]);
-  FLOAT4(reg[16]) = FLOAT4(RC[4][0]);
-  FLOAT4(reg[24]) = FLOAT4(RC[6][0]);
+  LDST128BITS(reg[0]) = LDST128BITS(RC[0][0]);
+  LDST128BITS(reg[8]) = LDST128BITS(RC[2][0]);
+  LDST128BITS(reg[16]) = LDST128BITS(RC[4][0]);
+  LDST128BITS(reg[24]) = LDST128BITS(RC[6][0]);
 
   // thread level reduce max
   #pragma unroll
@@ -197,10 +198,10 @@ __global__ void flash_attn_2_fwd_f16_mma_m16n8k16_kernel(
   }
 
   // FETCHING REGISTER for P
-  FLOAT4(RC[0][0]) = FLOAT4(reg[0]);
-  FLOAT4(RC[2][0]) = FLOAT4(reg[8]);
-  FLOAT4(RC[4][0]) = FLOAT4(reg[16]);
-  FLOAT4(RC[6][0]) = FLOAT4(reg[24]);
+  LDST128BITS(RC[0][0]) = LDST128BITS(reg[0]);
+  LDST128BITS(RC[2][0]) = LDST128BITS(reg[8]);
+  LDST128BITS(RC[4][0]) = LDST128BITS(reg[16]);
+  LDST128BITS(RC[6][0]) = LDST128BITS(reg[24]);
 
   // P @ V
   for (int k = 0; k < d / 16; k++) {
@@ -220,7 +221,7 @@ __global__ void flash_attn_2_fwd_f16_mma_m16n8k16_kernel(
               RD[2], RD[3]);
   }
 
-  FLOAT4(reg[0]) = FLOAT4(RD[0]);
+  LDST128BITS(reg[0]) = LDST128BITS(RD[0]);
   #pragma unroll
   for(int tc_yi = 0; tc_yi < 2; tc_yi++) {
     float thread_max_new = max(thread_max_old[tc_yi], thread_max[tc_yi]);
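Note: FLOAT4 and LDST128BITS expand to the same 128-bit reinterpret_cast; the rename only makes the intent explicit, since on half data the macro moves 8 elements (8 x 16 bits = 128 bits) per access rather than 4 floats. A minimal standalone sketch of the pattern (the copy_halves_128b kernel and its launch shape are illustrative assumptions, not part of this commit):

#include <cuda_fp16.h>

#define LDST128BITS(value) (reinterpret_cast<float4*>(&(value))[0])

// Each thread moves 8 contiguous half values (8 x 16 bits = 128 bits)
// with a single vectorized load and a single vectorized store.
// src/dst must be 16-byte aligned for the float4 access to be valid.
__global__ void copy_halves_128b(half* dst, half* src, int n) {
  int i = (blockIdx.x * blockDim.x + threadIdx.x) * 8;
  if (i + 8 <= n) {
    LDST128BITS(dst[i]) = LDST128BITS(src[i]);
  }
}

A launch such as copy_halves_128b<<<(n / 8 + 255) / 256, 256>>>(dst, src, n) then moves the whole buffer in 128-bit transactions, which is the same trick the kernel above uses when staging Q, K, and V tiles through shared memory.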
