Triton performs poorly on older GPU architectures such as Turing and Volta (triton-lang/triton#2377).
CUTLASS resources:
- Discord: NVIDIA/cutlass#1087
- Useful GitHub issues:
The Flash Attention repository (https://github.com/Dao-AILab/flash-attention) is a really good example of CUTLASS in practice.
The Triton implementation spends most of its time on the CPU. See the profile screenshot below:
The CUDA stream (bottom) is almost entirely empty.
The plot above is for 0% sparsity. For small problem sizes, the Triton implementation performs very poorly, most likely because of the CPU overhead. For large problem sizes, it still doesn't match Dense, most likely because Triton's support for Turing GPUs is poor.
All of the above suggests that I need to use CUDA and C++ to remove all the overhead.
Forward pass:
- $Y = X W$
  - Shape: (M N) = (M K) (N K)
  - Layout: Row = Row x Row
  - Sparse version: Dense = Sparse x Dense
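
The forward product can be sketched in a few lines of pure Python. This is only an illustration under stated assumptions, not the actual kernel: it reads the shape line as "W is stored as an (N, K) array", treats X as the sparse operand, and stands in for a real sparse format by simply skipping zero entries of a dense list. The names `forward` and `W_nk` are hypothetical.

```python
def forward(X, W_nk):
    """Compute Y[m][n] = sum_k X[m][k] * W_nk[n][k], skipping zeros in X.

    X is (M, K) and plays the role of the sparse operand.
    W_nk is (N, K): the weight stored transposed (an assumption of this sketch).
    """
    M, K, N = len(X), len(X[0]), len(W_nk)
    Y = [[0.0] * N for _ in range(M)]
    for m in range(M):
        for k in range(K):
            x = X[m][k]
            if x == 0.0:
                continue  # a zero entry of X contributes nothing to Y
            for n in range(N):
                Y[m][n] += x * W_nk[n][k]
    return Y


X = [[1.0, 0.0, 2.0],
     [0.0, 3.0, 0.0]]       # (M, K) = (2, 3), half of the entries are zero
W_nk = [[1.0, 2.0, 3.0],
        [4.0, 5.0, 6.0],
        [7.0, 8.0, 9.0],
        [1.0, 0.0, 1.0]]    # (N, K) = (4, 3)

Y = forward(X, W_nk)        # (M, N) = (2, 4)
print(Y)
```

The inner loop over `n` only runs for nonzero `X[m][k]`, which is the work a Sparse x Dense kernel would save relative to the dense product.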
Backward pass:
- $\frac{\partial L}{\partial X} = \frac{\partial L}{\partial Y} W^T$
  - Shape: (M K) = (M N) (K N)
  - Layout: Row = Row x Col
  - Sparse version: Sparse = Dense x Dense
- $\frac{\partial L}{\partial W}^T = X^T \frac{\partial L}{\partial Y}$
  - Shape: (K N) = (K M) (N M)
  - Layout: Col = Col x Col
  - Sparse version: Dense = Sparse x Dense
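
As a sanity check on the two backward products, here is a small pure-Python sketch that computes both gradients and compares them against finite differences. It is dense only and ignores the sparse variants. It assumes, as one reading of the shape lines, that the second operand of each product is passed in transposed form (so W is stored as an (N, K) array); `matmul_nt`, `W_nk` and the other names are hypothetical.

```python
def matmul_nt(A, B_t):
    """C[i][j] = sum_p A[i][p] * B_t[j][p]; the second operand is passed transposed."""
    return [[sum(a * b for a, b in zip(ra, rb)) for rb in B_t] for ra in A]


def transpose(A):
    return [list(col) for col in zip(*A)]


def loss(X, W_nk, G):
    """Scalar L = sum(Y * G) with Y = X W, chosen so that dL/dY = G."""
    Y = matmul_nt(X, W_nk)
    return sum(y * g for ry, rg in zip(Y, G) for y, g in zip(ry, rg))


def numeric_grad(f, A, eps=1e-3):
    """Central finite differences of the scalar f() w.r.t. each entry of A."""
    G = [[0.0] * len(A[0]) for _ in A]
    for i in range(len(A)):
        for j in range(len(A[0])):
            old = A[i][j]
            A[i][j] = old + eps
            up = f()
            A[i][j] = old - eps
            down = f()
            A[i][j] = old  # restore the original entry
            G[i][j] = (up - down) / (2 * eps)
    return G


def max_abs_diff(A, B):
    return max(abs(a - b) for ra, rb in zip(A, B) for a, b in zip(ra, rb))


X = [[1.0, 0.0, 2.0], [0.0, 3.0, 0.0]]                    # (M, K) = (2, 3)
W_nk = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0],
        [7.0, 8.0, 9.0], [1.0, 0.0, 1.0]]                 # (N, K) = (4, 3)
dY = [[1.0, -1.0, 0.5, 2.0], [0.0, 1.0, -2.0, 1.0]]       # (M, N) = (2, 4)

# dL/dX: (M K) = (M N) (K N)
dX = matmul_nt(dY, transpose(W_nk))
# (dL/dW)^T: (K N) = (K M) (N M)
dW_kn = matmul_nt(transpose(X), transpose(dY))

num_dX = numeric_grad(lambda: loss(X, W_nk, dY), X)
num_dW_nk = numeric_grad(lambda: loss(X, W_nk, dY), W_nk)

err_dX = max_abs_diff(dX, num_dX)
err_dW = max_abs_diff(dW_kn, transpose(num_dW_nk))
print(err_dX, err_dW)
```

Because the loss is linear in both X and W, the finite-difference estimates should agree with the analytic gradients up to floating-point rounding.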