Commit c78c34b: doc masking

ipiszy committed Sep 27, 2024
1 parent 53a4f34

Showing 9 changed files with 308 additions and 39 deletions.
3 changes: 3 additions & 0 deletions hopper/flash.h
@@ -137,6 +137,9 @@ struct Flash_fwd_params : public Qkv_params {
float * __restrict__ descale_q_ptr;
float * __restrict__ descale_k_ptr;
float * __restrict__ descale_v_ptr;

// Whether to optimize for document masking.
bool optimize_for_doc_masking;
};

////////////////////////////////////////////////////////////////////////////////////////////////////
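For context, "document masking" here refers to the packed varlen setup where several documents are concatenated into one sequence and attention is restricted to tokens of the same document via cumulative sequence lengths. A minimal sketch of that layout (variable names are illustrative, not from this commit):

```python
import torch

# Three documents packed back to back; attention never crosses a document boundary.
doc_lens = torch.tensor([5, 3, 8], dtype=torch.int32)           # illustrative lengths
cu_seqlens = torch.zeros(len(doc_lens) + 1, dtype=torch.int32)
cu_seqlens[1:] = torch.cumsum(doc_lens, dim=0)                  # tensor([0, 5, 8, 16])
max_seqlen = int(doc_lens.max())                                # 8
```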
28 changes: 21 additions & 7 deletions hopper/flash_api.cpp
@@ -45,7 +45,8 @@ void set_params_fprop(Flash_fwd_params &params,
int window_size_left,
int window_size_right,
bool seqlenq_ngroups_swapped=false,
bool unpadded_lse=false) {
bool unpadded_lse=false,
bool optimize_for_doc_masking=false) {

// Reset the parameters
params = {};
@@ -154,6 +155,7 @@ void set_params_fprop(Flash_fwd_params &params,
#endif

params.unpadded_lse = unpadded_lse;
params.optimize_for_doc_masking = optimize_for_doc_masking;
}

void set_params_dgrad(Flash_bwd_params &params,
@@ -189,7 +191,10 @@ void set_params_dgrad(Flash_bwd_params &params,
float softmax_scale,
int window_size_left,
int window_size_right,
bool deterministic) {
bool deterministic,
bool seqlenq_ngroups_swapped=false,
bool unpadded_lse=false,
bool optimize_for_doc_masking=false) {

set_params_fprop(params,
b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded,
@@ -203,7 +208,10 @@ void set_params_dgrad(Flash_bwd_params &params,
p_dropout,
softmax_scale,
window_size_left,
window_size_right);
window_size_right,
seqlenq_ngroups_swapped,
unpadded_lse,
optimize_for_doc_masking);

// Set the pointers and strides.
params.do_ptr = dout.data_ptr();
@@ -448,7 +456,8 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
const float softmax_scale,
bool is_causal,
int window_size_left,
int window_size_right) {
int window_size_right,
bool optimize_for_doc_masking) {

auto dprops = at::cuda::getCurrentDeviceProperties();
bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
@@ -567,7 +576,8 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
window_size_left,
window_size_right,
/*seqlenq_ngroups_swapped=*/false,
/*unpadded_lse=*/true);
/*unpadded_lse=*/true,
/*optimize_for_doc_masking=*/optimize_for_doc_masking);
params.total_q = total_q;
params.total_k = total_k;

@@ -823,7 +833,8 @@ mha_varlen_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x
const bool is_causal,
int window_size_left,
int window_size_right,
const bool deterministic) {
const bool deterministic,
const bool optimize_for_doc_masking) {

#ifdef FLASHATTENTION_DISABLE_BACKWARD
TORCH_CHECK(false, "This flash attention build does not support backward.");
@@ -989,7 +1000,10 @@ mha_varlen_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x
softmax_scale,
/*window_size_left=*/window_size_left,
/*window_size_right=*/window_size_right,
deterministic);
deterministic,
/*seqlenq_ngroups_swapped=*/false,
/*unpadded_lse=*/true,
optimize_for_doc_masking);
params.total_q = total_q;
params.total_k = total_k;
params.softmax_lse_log2_ptr = softmax_lse_log2.data_ptr();
10 changes: 10 additions & 0 deletions hopper/flash_attn_interface.py
@@ -80,6 +80,7 @@ def _flash_attn_varlen_forward(
window_size=(-1, -1),
seqused_q=None,
seqused_k=None,
optimize_for_doc_masking=False,
):
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
@@ -98,6 +99,7 @@ def _flash_attn_varlen_forward(
causal,
window_size[0],
window_size[1],
optimize_for_doc_masking,
)
# if out.isnan().any() or softmax_lse.isnan().any():
# breakpoint()
@@ -124,6 +126,7 @@ def _flash_attn_varlen_backward(
deterministic=False,
seqused_q=None,
seqused_k=None,
optimize_for_doc_masking=False,
):
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
# dq, dk, dv are allocated by us so they should already be contiguous
@@ -155,6 +158,7 @@ def _flash_attn_varlen_backward(
window_size[0],
window_size[1],
deterministic,
optimize_for_doc_masking,
)
# if dk.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any():
# breakpoint()
@@ -238,6 +242,7 @@ def forward(
deterministic=False,
seqused_q=None,
seqused_k=None,
optimize_for_doc_masking=False,
):
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
@@ -254,6 +259,7 @@ def forward(
window_size=window_size,
seqused_q=seqused_q,
seqused_k=seqused_k,
optimize_for_doc_masking=optimize_for_doc_masking,
)
ctx.save_for_backward(
q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k,
@@ -265,6 +271,7 @@ def forward(
ctx.causal = causal
ctx.window_size = window_size
ctx.deterministic = deterministic
ctx.optimize_for_doc_masking = optimize_for_doc_masking
return out, softmax_lse

@staticmethod
@@ -291,6 +298,7 @@ def backward(ctx, dout, *args):
ctx.deterministic,
seqused_q,
seqused_k,
ctx.optimize_for_doc_masking,
)
dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
dk = dk[..., : dout.shape[-1]]
@@ -389,6 +397,7 @@ def flash_attn_varlen_func(
deterministic=False,
seqused_q=None,
seqused_k=None,
optimize_for_doc_masking=False,
):
"""
Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
@@ -444,4 +453,5 @@ def flash_attn_varlen_func(
deterministic,
seqused_q,
seqused_k,
optimize_for_doc_masking,
)
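A hedged usage sketch of the extended Python API. It assumes the interface module built from hopper/ is importable as flash_attn_interface, that the varlen arguments keep their usual names (cu_seqlens_q/k, max_seqlen_q/k), and that the call returns (out, softmax_lse) as in the forward shown above; shapes and lengths are illustrative.

```python
import torch
from flash_attn_interface import flash_attn_varlen_func  # assumed import path

torch.manual_seed(0)
doc_lens = torch.tensor([5, 3, 8], dtype=torch.int32, device="cuda")
cu_seqlens = torch.nn.functional.pad(doc_lens.cumsum(0, dtype=torch.int32), (1, 0))
total, nheads, headdim = int(cu_seqlens[-1]), 8, 128
q = torch.randn(total, nheads, headdim, device="cuda", dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)

out, softmax_lse = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
    max_seqlen_q=int(doc_lens.max()), max_seqlen_k=int(doc_lens.max()),
    causal=True,
    optimize_for_doc_masking=True,  # new flag threaded through in this commit
)
```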
18 changes: 12 additions & 6 deletions hopper/flash_fwd_kernel.h
@@ -110,15 +110,18 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp,
int work_idx = 0;

TileScheduler scheduler(&shared_storage.tile_count_semaphore);
for (auto work_tile_info = scheduler.get_initial_work();
for (auto work_tile_info = scheduler.get_initial_work(scheduler_params);
work_tile_info.is_valid(scheduler_params);
work_tile_info = scheduler.template get_next_work</*IsProducer=*/true>(scheduler_params, work_tile_info)) {
auto block_coord = work_tile_info.get_block_coord(scheduler_params);
auto [m_block, bidh, bidb] = block_coord;

seqlen_traits_q.init(bidb);
bool within_max_seqlen = seqlen_traits_q.init(bidb);
seqlen_traits_k.init(bidb);
if (m_block * kBlockM >= seqlen_traits_q.actual_seq_len) {
if (!within_max_seqlen) {
work_tile_info.move_to_next_batch();
}
continue;
}
const int n_block_max = collective_mainloop.get_n_block_max(
@@ -154,7 +157,7 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp,

int work_idx = 0;
CUTLASS_PRAGMA_NO_UNROLL
for (auto work_tile_info = scheduler.get_initial_work();
for (auto work_tile_info = scheduler.get_initial_work(scheduler_params);
work_tile_info.is_valid(scheduler_params);
work_tile_info = scheduler.template get_next_work</*IsProducer=*/false>(scheduler_params, work_tile_info)) {
// Attention output (GEMM-II) accumulator.
Expand All @@ -164,9 +167,12 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp,
auto block_coord = work_tile_info.get_block_coord(scheduler_params);
auto [m_block, bidh, bidb] = block_coord;

seqlen_traits_q.init(bidb);
bool within_max_seqlen = seqlen_traits_q.init(bidb);
seqlen_traits_k.init(bidb);
if (m_block * kBlockM >= seqlen_traits_q.actual_seq_len) {
if (!within_max_seqlen) {
work_tile_info.move_to_next_batch();
}
continue;
}
const int n_block_max = collective_mainloop.get_n_block_max(
@@ -296,7 +302,7 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp,
int work_idx = 0;

TileScheduler scheduler(&shared_storage.tile_count_semaphore);
for (auto work_tile_info = scheduler.get_initial_work();
for (auto work_tile_info = scheduler.get_initial_work(scheduler_params);
work_tile_info.is_valid(scheduler_params);
work_tile_info = scheduler.template get_next_work</*IsProducer=*/true>(scheduler_params, work_tile_info)) {
auto block_coord = work_tile_info.get_block_coord(scheduler_params);
@@ -345,7 +351,7 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp,

int work_idx = 0;
CUTLASS_PRAGMA_NO_UNROLL
for (auto work_tile_info = scheduler.get_initial_work();
for (auto work_tile_info = scheduler.get_initial_work(scheduler_params);
work_tile_info.is_valid(scheduler_params);
work_tile_info = scheduler.template get_next_work</*IsProducer=*/false>(scheduler_params, work_tile_info)) {
// Attention output (GEMM-II) accumulator.
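For readability, a rough Python rendering of the new per-tile skip logic in the producer and consumer loops above. The scheduler, get_block_coord, init, and move_to_next_batch names mirror the C++ calls shown in this diff; the loop body itself is only a sketch.

```python
def process_tiles(scheduler, scheduler_params, seqlen_traits_q, seqlen_traits_k, kBlockM):
    work = scheduler.get_initial_work(scheduler_params)      # now takes scheduler_params
    while work.is_valid(scheduler_params):
        m_block, bidh, bidb = work.get_block_coord(scheduler_params)
        within_max_seqlen = seqlen_traits_q.init(bidb)       # init() now also reports a range check
        seqlen_traits_k.init(bidb)
        if m_block * kBlockM >= seqlen_traits_q.actual_seq_len:
            if not within_max_seqlen:
                work.move_to_next_batch()                    # skip the rest of this batch
            work = scheduler.get_next_work(scheduler_params, work)
            continue
        # ... mainloop / epilogue for this tile ...
        work = scheduler.get_next_work(scheduler_params, work)
```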
39 changes: 25 additions & 14 deletions hopper/flash_fwd_launch_template.h
@@ -29,14 +29,19 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
// print(typename Kernel_traits::SmemLayoutVt{}); printf("\n"); print(typename Kernel_traits::SmemLayoutVt_tmp{});
using CollectiveMainloop = flash::CollectiveMainloopFwd<Kernel_traits, Is_causal, Is_local, Seqlen_traits>;
using CollectiveEpilogue = flash::CollectiveEpilogueFwd<Kernel_traits, Seqlen_traits>;
using Scheduler = std::conditional_t<
using RegularScheduler = std::conditional_t<
Seqlen_traits::kUseVarSeqLen || Is_local,
flash::SingleTileScheduler,
std::conditional_t<!Is_causal,
flash::StaticPersistentTileScheduler,
flash::DynamicPersistentTileScheduler<Kernel_traits::kNThreads - cutlass::NumThreadsPerWarpGroup, Kernel_traits::NumProducerThreads>
>>;
// using Scheduler = flash::SingleTileScheduler;
using Scheduler = std::conditional_t<
Kernel_traits::kUseDocMasking,
flash::DocMaskingStaticPersistentTileScheduler,
RegularScheduler
>;
Seqlen_traits seqlen_traits_q(
params.total_q, params.seqlen_q, params.cu_seqlens_q, params.seqused_q);
Seqlen_traits seqlen_traits_k(
@@ -80,7 +85,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {

int num_blocks_m = cutlass::ceil_div(params.seqlen_q, Kernel_traits::kBlockM);
num_blocks_m = cutlass::ceil_div(num_blocks_m, size<0>(ClusterShape{})) * size<0>(ClusterShape{});
typename Scheduler::Arguments scheduler_args = {num_blocks_m, params.h, params.b, params.tile_count_semaphore};
typename Scheduler::Arguments scheduler_args = {num_blocks_m, params.h, params.b, seqlen_traits_q.max_seq_len, params.tile_count_semaphore};
typename Scheduler::Params scheduler_params = Scheduler::to_underlying_arguments(scheduler_args);

// Get the ptr to kernel function.
@@ -120,10 +125,12 @@ void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
BOOL_SWITCH(params.is_local, Is_local, [&] {
SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] {
run_flash_fwd<
Flash_fwd_kernel_traits<Headdim, 192, 128, 16, 2, false, 1, T>,
Is_causal, Is_local && !Is_causal, Seqlen_traits
>(params, stream);
BOOL_SWITCH(params.optimize_for_doc_masking, Use_doc_masking, [&] {
run_flash_fwd<
Flash_fwd_kernel_traits<Headdim, 192, 128, 16, 2, false, 1, T, Use_doc_masking>,
Is_causal, Is_local && !Is_causal, Seqlen_traits
>(params, stream);
});
});
});
});
@@ -137,10 +144,12 @@ void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] {
// Only use Cluster if number of tiles along seqlen_q is even and not Is_causal
BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0 && !Is_causal && !Is_local && !Seqlen_traits::kUseVarSeqLen, UseCluster, [&] {
run_flash_fwd<
Flash_fwd_kernel_traits<Headdim, 128, (Is_causal || Is_local) ? 128 : 176, 12, 2, false, UseCluster ? 2 : 1, T>,
Is_causal, Is_local && !Is_causal, Seqlen_traits
>(params, stream);
BOOL_SWITCH(params.optimize_for_doc_masking, Use_doc_masking, [&] {
run_flash_fwd<
Flash_fwd_kernel_traits<Headdim, 128, (Is_causal || Is_local) ? 128 : 176, 12, 2, false, UseCluster ? 2 : 1, T, Use_doc_masking>,
Is_causal, Is_local && !Is_causal, Seqlen_traits
>(params, stream);
});
});
});
});
@@ -155,10 +164,12 @@ void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] {
// Only use Cluster if number of tiles along seqlen_q is even
BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0 && !Is_causal && !Is_local && !Seqlen_traits::kUseVarSeqLen, UseCluster, [&] {
run_flash_fwd<
Flash_fwd_kernel_traits<Headdim, 128, 80, 12, 2, false, UseCluster ? 2 : 1, T>,
Is_causal, Is_local && !Is_causal, Seqlen_traits
>(params, stream);
BOOL_SWITCH(params.optimize_for_doc_masking, Use_doc_masking, [&] {
run_flash_fwd<
Flash_fwd_kernel_traits<Headdim, 128, 80, 12, 2, false, UseCluster ? 2 : 1, T, Use_doc_masking>,
Is_causal, Is_local && !Is_causal, Seqlen_traits
>(params, stream);
});
});
});
});
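The scheduler selection at the top of this file, restated as plain Python for readability (a sketch of the std::conditional_t chain; the scheduler class names are the ones referenced in this diff):

```python
def pick_scheduler(use_doc_masking, use_var_seq_len, is_local, is_causal):
    if use_doc_masking:
        return "DocMaskingStaticPersistentTileScheduler"
    if use_var_seq_len or is_local:
        return "SingleTileScheduler"
    if not is_causal:
        return "StaticPersistentTileScheduler"
    return "DynamicPersistentTileScheduler"
```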
8 changes: 6 additions & 2 deletions hopper/kernel_traits.h
@@ -59,13 +59,15 @@ struct SharedStorageQKVOVt {

// If Share_Q_K_smem is true, that forces Is_Q_in_regs to be true
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, int kStages_, bool Is_Q_in_regs_=false,
int kClusterM_ = 1, typename elem_type=cutlass::half_t>
int kClusterM_ = 1, typename elem_type=cutlass::half_t, bool kUseDocMasking_=false>
struct Flash_fwd_kernel_traits {
using Element = elem_type;
using ElementAccum = float;
using OutputType = elem_type;
using index_t = int64_t;

static constexpr bool kUseDocMasking = kUseDocMasking_;

// The number of threads.
static constexpr int kNWarps = kNWarps_;
static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp;
@@ -141,14 +143,16 @@ struct Flash_fwd_kernel_traits {

// Traits struct for fp8 kernel with in-kernel transpose
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, int kStages_, bool Is_Q_in_regs_=false,
int kClusterM_ = 1, typename elem_type=cutlass::float_e4m3_t>
int kClusterM_ = 1, typename elem_type=cutlass::float_e4m3_t, bool kUseDocMasking_ = false>
struct Flash_fwd_kernel_traits_fp8 {
using Element = elem_type;
static_assert(cutlass::sizeof_bits_v<Element> == 8);
using ElementAccum = float;
using OutputType = cutlass::half_t;
using index_t = int64_t;

static constexpr bool kUseDocMasking = kUseDocMasking_;

// The number of threads.
static constexpr int kNWarps = kNWarps_;
static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp;
10 changes: 7 additions & 3 deletions hopper/seq_len.h
@@ -21,6 +21,8 @@ template <bool UseVarSeqLen> class SeqLenTraits {
int *seq_used = nullptr;
// seq len of the current batch.
int actual_seq_len = -1;
// max seq len per batch.
int max_seq_len = -1;

// Whether this is for fixed-seq-len or var-seq-len.
static constexpr bool kUseVarSeqLen = UseVarSeqLen;
@@ -53,7 +55,8 @@ template <bool UseVarSeqLen> class SeqLenTraits {

CUTLASS_HOST SeqLenTraits(
int sum_s, int max_seq_len, int *cu_seq_len = nullptr, int *seq_used = nullptr):
sum_s(sum_s), cu_seq_len(cu_seq_len), seq_used(seq_used), actual_seq_len(max_seq_len) {}
sum_s(sum_s), cu_seq_len(cu_seq_len), seq_used(seq_used),
actual_seq_len(max_seq_len), max_seq_len(max_seq_len) {}

// Returns the layout of a tensor in MKHB format in global memory.
// padded: only useful for var-seq-len for dq_accum and softmax_d.
@@ -75,7 +78,7 @@ template <bool UseVarSeqLen> class SeqLenTraits {
make_stride(int64_t(h * m), int64_t(m), cute::_1()));
}

CUTLASS_DEVICE void init(int bidb) {}
CUTLASS_DEVICE bool init(int bidb) {return true;}

template <typename MTensor, typename Shape>
CUTLASS_DEVICE auto get_local_tile_tensor(
@@ -124,9 +127,10 @@ CUTLASS_HOST_DEVICE auto VarSeqLenTraits::get_lse_gmem_layout(
}

template <>
CUTLASS_DEVICE void VarSeqLenTraits::init(int bidb) {
CUTLASS_DEVICE bool VarSeqLenTraits::init(int bidb) {
actual_seq_len =
seq_used ? seq_used[bidb] : (cu_seq_len[bidb + 1] - cu_seq_len[bidb]);
return cu_seq_len[bidb] < max_seq_len;
}

template <>
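The updated VarSeqLenTraits::init, restated in Python as an illustrative sketch (in the C++ version actual_seq_len is a member and only the bool is returned; when seq_used is absent the length comes from consecutive cu_seqlens entries):

```python
def varlen_init(bidb, cu_seq_len, seq_used, max_seq_len):
    actual_seq_len = (
        seq_used[bidb] if seq_used is not None
        else cu_seq_len[bidb + 1] - cu_seq_len[bidb]
    )
    within_max_seqlen = cu_seq_len[bidb] < max_seq_len
    return actual_seq_len, within_max_seqlen
```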