Optimize single-token attention when Tensor Cores are available
黄宇扬 committed Sep 4, 2024
1 parent 10f8461 commit f5db1e7
Showing 1 changed file with 109 additions and 23 deletions.
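The two half-precision matmul kernels used by single-token attention (Q·Kᵀ and scores·V) gain a Tensor Core path, compiled only for compute capability 7.0+ via `__CUDA_ARCH__`. During decode the number of query rows n is tiny, so both kernels tile n in strips of 8 and use the 8x32x16 `nvcuda::wmma` fragment shape. As a standalone illustration (not fastllm code; the kernel and parameter names here are invented), the core WMMA idiom the diff relies on looks like this:

    #include <mma.h>
    #include <cuda_fp16.h>
    using namespace nvcuda;

    // One warp computes an 8x32 fp32 tile of C = A * B^T from one 8x16
    // tile of A (row-major) and one 16x32 tile of B^T; the col_major
    // load reads B^T directly from a 32x16 row-major B. Launch with a
    // single warp (32 threads) for this sketch.
    __global__ void Wmma8x32x16Sketch(const half *A, const half *B, float *C) {
    #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
        wmma::fragment<wmma::matrix_a, 8, 32, 16, half, wmma::row_major> a;
        wmma::fragment<wmma::matrix_b, 8, 32, 16, half, wmma::col_major> b;
        wmma::fragment<wmma::accumulator, 8, 32, 16, float> c;
        wmma::fill_fragment(c, 0.0f);
        wmma::load_matrix_sync(a, A, 16);    // leading dimension of A is 16
        wmma::load_matrix_sync(b, B, 16);    // col_major: strides over rows of B^T
        wmma::mma_sync(c, a, b, c);          // c += a * b on Tensor Cores
        wmma::store_matrix_sync(C, c, 32, wmma::mem_row_major);
    #endif
    }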
132 changes: 109 additions & 23 deletions src/devices/cuda/fastllm-cuda.cu
@@ -1817,32 +1817,58 @@ __global__ void FastllmHalfMatMulTransBBatchKernel(uint8_t** pointer, float alpha) {
     int input1Stride = (int)((size_t)pointer[id * 8 + 7]);
 
     int tid = threadIdx.x;
-    /*
-    const int pera = 8, perb = 8;
-    __shared__ float sa[pera][128], sb[perb][128], sc[pera][perb];
-    for (int sta = 0; sta < n; sta += pera) {
-        for (int stb = 0; stb < k; stb += perb) {
-            for (int i = 0; i < pera; i++) {
-                if (sta + i < n) {
-                    sa[i][tid] = (float)input0[(sta + i) * input0Stride + tid];
-                } else {
-                    sa[i][tid] = 0;
-                }
-            }
-            for (int i = 0; i < perb; i++) {
-                if (stb + i < k) {
-                    sb[i][tid] = (float)input1[(stb + i) * input1Stride + tid];
-                } else {
-                    sb[i][tid] = 0;
-                }
-            }
-            __syncthreads();
-        }
-    }
-    */
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
+    if (m == 128) {
+        int wid = tid >> 5;
+        int perN = 8, perK = 128;
+
+        const int BN = 8, BK = 128;
+        __shared__ float curC[BN][BK];
+        half hscale = (half)alpha;
+
+        for (int stN = 0; stN < n; stN += perN) {
+            int endN = min(n, stN + perN);
+            for (int stK = 0; stK < k; stK += perK) {
+                int endK = min(k, stK + perK);
+                wmma::fragment<wmma::matrix_a, 8, 32, 16, half, wmma::row_major> frag_a[8];
+                wmma::fragment<wmma::matrix_b, 8, 32, 16, half, wmma::col_major> frag_b[8];
+                wmma::fragment<wmma::accumulator, 8, 32, 16, float> frag_c;
+
+                wmma::fill_fragment(frag_c, 0.0);
+                __syncthreads();
+
+#pragma unroll
+                for (int j = 0; j < 8; j++) {
+                    wmma::load_matrix_sync(frag_a[j], &input0[(stN) * input0Stride + j * 16], input0Stride);
+                }
+                __syncthreads();
+
+#pragma unroll
+                for (int j = 0; j < 8; j++) {
+                    wmma::load_matrix_sync(frag_b[j], &input1[(stK + wid * 32) * input1Stride + j * 16], input1Stride);
+                }
+                __syncthreads();
+
+#pragma unroll
+                for (int j = 0; j < 8; j++) {
+                    wmma::mma_sync(frag_c, frag_a[j], frag_b[j], frag_c);
+                }
+                __syncthreads();
+
+                wmma::store_matrix_sync(&curC[0][wid * 32], frag_c, BK, wmma::mem_row_major);
+                __syncthreads();
+
+                if (stK + tid < endK) {
+                    for (int i = 0; stN + i < endN; i++) {
+                        output[(stN + i) * k + stK + tid] = (half)(curC[i][tid] * alpha);
+                    }
+                }
+                __syncthreads();
+            }
+            __syncthreads();
+        }
+        return;
+    }
+#endif
     int pera = 4, perb = 4;
     half cura[4][4], curb[4][4];
     float curc[4][4];
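In the added block above, the guard `m == 128` pins the reduction depth to the head dimension, and the 128-thread block splits into 4 warps (`wid = tid >> 5`): each warp owns a 32-column slice of the k (key) dimension, and the depth-128 reduction runs as 8 chained 8x32x16 MMAs (`frag_a[j]`/`frag_b[j]`, 8 × 16 = 128). Loading `frag_b` as col_major is what folds the Bᵀ of Q·Kᵀ into the fragment load itself; note the `half hscale` local goes unused, since the scaling stays in fp32 via `curC[i][tid] * alpha`. For reference, what this path computes, restated as plain scalar code (a hedged restatement with invented names and packed strides, not a fastllm function):

    #include <cuda_fp16.h>

    // score[i][j] = alpha * sum_l q[i][l] * key[j][l], i.e. alpha * Q * K^T
    void MatMulTransBRef(const half *q, const half *key, half *score,
                         int n, int m, int k, float alpha) {
        for (int i = 0; i < n; i++)
            for (int j = 0; j < k; j++) {
                float s = 0.0f;
                for (int l = 0; l < m; l++)
                    s += (float)q[i * m + l] * (float)key[j * m + l];
                score[i * k + j] = (half)(s * alpha);  // fp32 accumulate, scale, cast
            }
    }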
@@ -1999,8 +2025,69 @@ __global__ void FastllmHalfMatMulKernel(uint8_t** pointer, float alpha) {
     int k = (int)((size_t)pointer[id * 8 + 5]);
     int input0Stride = (int)((size_t)pointer[id * 8 + 6]);
     int input1Stride = (int)((size_t)pointer[id * 8 + 7]);
 
     int tid = threadIdx.x;
 
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
+    if (k == 128) {
+        int wid = tid >> 5;
+        int perN = 8, perM = 128;
+        for (int i = 0; i < n; i++) {
+            output[i * k + tid] = (half)0;
+        }
+
+        __shared__ half curA[8][128];
+        __shared__ float curC[8][128];
+
+        for (int stN = 0; stN < n; stN += perN) {
+            int endN = min(stN + perN, n);
+            wmma::fragment<wmma::accumulator, 8, 32, 16, float> frag_c;
+            wmma::fill_fragment(frag_c, 0.0);
+
+            for (int stM = 0; stM < m; stM += perM) {
+                int endM = min(stM + perM, m);
+                if (stM + tid < m) {
+                    for (int i = 0; stN + i < endN; i++) {
+                        curA[i][tid] = input0[(stN + i) * input0Stride + stM + tid];
+                    }
+                } else {
+                    for (int i = 0; stN + i < endN; i++) {
+                        curA[i][tid] = (half)0.0;
+                    }
+                }
+
+                wmma::fragment<wmma::matrix_a, 8, 32, 16, half, wmma::row_major> frag_a[8];
+                wmma::fragment<wmma::matrix_b, 8, 32, 16, half, wmma::row_major> frag_b[8];
+                __syncthreads();
+
+#pragma unroll
+                for (int j = 0; j < 8; j++) {
+                    wmma::load_matrix_sync(frag_a[j], &curA[0][16 * j], 128);
+                }
+                __syncthreads();
+
+#pragma unroll
+                for (int j = 0; j < 8; j++) {
+                    wmma::load_matrix_sync(frag_b[j], &input1[(stM + 16 * j) * input1Stride + wid * 32], input1Stride);
+                }
+                __syncthreads();
+
+#pragma unroll
+                for (int j = 0; j < 8; j++) {
+                    wmma::mma_sync(frag_c, frag_a[j], frag_b[j], frag_c);
+                }
+                __syncthreads();
+            }
+            wmma::store_matrix_sync(&curC[0][wid * 32], frag_c, 128, wmma::mem_row_major);
+            __syncthreads();
+
+            for (int i = 0; stN + i < endN; i++) {
+                output[(stN + i) * k + tid] = (half)((float)output[(stN + i) * k + tid] + (float)curC[i][tid] * alpha);
+            }
+            __syncthreads();
+        }
+        return;
+    }
+#endif
     int pera = 4, perb = 4;
     float cura[4][4], curb[4][4], curc[4][4];
     int cnta = (n - 1) / pera + 1, cntb = (k - 1) / perb + 1;
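This second path (scores·V, guarded on `k == 128` so each of the 128 threads owns one output column) differs from the first in two ways: the A tile is staged through shared memory (`curA`), zero-padded when n < 8 so that `load_matrix_sync` always reads a well-formed 8x128 tile, and `frag_b` is loaded row_major straight from global memory since no transpose is needed. The fp32 accumulator survives the whole stM loop, so the m dimension is reduced in 128-wide chunks before a single store per n-strip. Restated as scalar code (again a hedged sketch with invented names and packed strides):

    // out[i][j] = alpha * sum_l s[i][l] * v[l][j], i.e. alpha * scores * V
    void MatMulRef(const half *s, const half *v, half *out,
                   int n, int m, int k, float alpha) {
        for (int i = 0; i < n; i++)
            for (int j = 0; j < k; j++) {
                float acc = 0.0f;
                for (int l = 0; l < m; l++)
                    acc += (float)s[i * m + l] * (float)v[l * k + j];
                out[i * k + j] = (half)(acc * alpha);
            }
    }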
@@ -2057,7 +2144,6 @@ __global__ void FastllmHalfMatMulKernel(uint8_t** pointer, float alpha) {
         }
     }
     /*
-    int tid = threadIdx.x;
     for (int i = 0; i < n; i++) {
         half *curInput0 = input0 + i * input0Stride;
         for (int j = tid; j < k; j += THREAD_PER_BLOCK) {
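Both kernels unpack their per-batch operands from a flat `uint8_t**` array, eight slots per batch entry. The diff only shows the reads for slots 5-7 (k, input0Stride, input1Stride); the sketch below guesses the full host-side packing, with slots 0-4 inferred from the kernels' reads rather than shown in this commit:

    // Hypothetical packing for batch entry `id` (slots 0-4 are inferred):
    uint8_t *p[8];
    p[0] = (uint8_t *) input0;                 // half* A (Q, or attention scores)
    p[1] = (uint8_t *) input1;                 // half* B (K, or V)
    p[2] = (uint8_t *) output;                 // half* C
    p[3] = (uint8_t *) (size_t) n;             // rows of A
    p[4] = (uint8_t *) (size_t) m;             // reduction depth
    p[5] = (uint8_t *) (size_t) k;             // columns of C
    p[6] = (uint8_t *) (size_t) input0Stride;  // row stride of A
    p[7] = (uint8_t *) (size_t) input1Stride;  // row stride of B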
