Add local attention in Hopper FAv3 #1197

Closed · wants to merge 3 commits
6 changes: 6 additions & 0 deletions .eggs/README.txt
@@ -0,0 +1,6 @@
This directory contains eggs that were downloaded by setuptools to build, test, and run plug-ins.

This directory caches those eggs to prevent repeated downloads.

However, it is safe to delete this directory.

15 changes: 10 additions & 5 deletions flash_attn/bert_padding.py
@@ -95,19 +95,23 @@ def backward(ctx, grad_output, grad_residual):
index_first_axis_residual = IndexFirstAxisResidual.apply


def unpad_input(hidden_states, attention_mask):
def unpad_input(hidden_states, attention_mask, unused_mask=None):
"""
Arguments:
hidden_states: (batch, seqlen, ...)
attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused.
Return:
hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence.
hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask.
indices: (used_nnz), the indices of non-masked tokens from the flattened input sequence.
cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
max_seqlen_in_batch: int
seqused: (batch), optionally returns the number of tokens selected in attention_mask + unused_mask if unused_mask is not None.
"""
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask
seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32)
used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
# TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
@@ -120,6 +124,7 @@ def unpad_input(hidden_states, attention_mask):
indices,
cu_seqlens,
max_seqlen_in_batch,
used_seqlens_in_batch,
)


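Note on the change above: unpad_input now accepts an optional unused_mask and returns a fifth value with the per-sequence count of actually-used tokens. A minimal usage sketch under assumed shapes and mask values (only the signature and the order of return values come from the diff):

import torch
from flash_attn.bert_padding import unpad_input

batch, seqlen, dim = 2, 8, 16
hidden_states = torch.randn(batch, seqlen, dim)
# 1 = token that is attended to, 0 = padding.
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 0, 0, 0],
                               [1, 1, 1, 1, 1, 1, 1, 1]])
# 1 = slot that is allocated (e.g. reserved cache space) but not attended to yet.
unused_mask = torch.tensor([[0, 0, 0, 0, 0, 1, 1, 0],
                            [0, 0, 0, 0, 0, 0, 0, 0]])

hidden_unpad, indices, cu_seqlens, max_seqlen, seqused = unpad_input(
    hidden_states, attention_mask, unused_mask
)
# cu_seqlens and the packed rows are built from attention_mask + unused_mask,
# while seqused counts only the tokens in attention_mask, so downstream kernels
# can keep the unused slots allocated yet skip them during attention.

Callers that do not need the extra mask simply discard the new fifth return value, as the updated call sites in flash_blocksparse_attention.py and bert.py do.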
2 changes: 1 addition & 1 deletion flash_attn/flash_blocksparse_attention.py
@@ -99,7 +99,7 @@ def forward(
key_padding_mask_bool = key_padding_mask.bool_matrix
nheads = qkv.shape[-2]
x = rearrange(qkv, "b s three h d -> b s (three h d)")
x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask_bool)
x_unpad, indices, cu_seqlens, max_s, _ = unpad_input(x, key_padding_mask_bool)
x_unpad = rearrange(x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads)
output_unpad = flash_blocksparse_attn_func(
x_unpad,
2 changes: 1 addition & 1 deletion flash_attn/models/bert.py
@@ -172,7 +172,7 @@ def forward(self, hidden_states, key_padding_mask=None, subset_mask=None):
hidden_states = hidden_states[subset_mask]
else:
batch, seqlen = hidden_states.shape[:2]
hidden_states, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
hidden_states, indices, cu_seqlens, max_seqlen_in_batch, _ = unpad_input(
hidden_states, key_padding_mask
)
mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch}
10 changes: 7 additions & 3 deletions hopper/epilogue_bwd_sm90_tma.hpp
@@ -80,6 +80,7 @@ struct CollectiveEpilogueBwd {
Element* ptr_dV;
StridedKV const stride_dV;
int const* cu_seqlens = nullptr;
int const* seqused = nullptr;
};

// Device side kernel params
@@ -91,6 +92,7 @@ struct CollectiveEpilogueBwd {
StridedKV const stride_dV;
TMA_dKV tma_store_dK, tma_store_dV;
int const* cu_seqlens = nullptr;
int const* seqused = nullptr;
};

static Params
@@ -113,7 +115,7 @@ struct CollectiveEpilogueBwd {
select<1, 2>(TileShape_MNK{}),
_1{}); // no mcast for dKV
return {args.ptr_dK, args.shape_dK, args.stride_dK, args.ptr_dV, args.stride_dV,
tma_store_dK, tma_store_dV, args.cu_seqlens};
tma_store_dK, tma_store_dV, args.cu_seqlens, args.seqused};
}

/// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
@@ -185,7 +187,9 @@ struct CollectiveEpilogueBwd {
cutlass::arch::NamedBarrier::sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
bool const is_varlen = params.cu_seqlens != nullptr;
int const offset = !is_varlen ? 0 : params.cu_seqlens[bidb];
int const seqlen = !is_varlen ? get<0>(params.shape_dK) : params.cu_seqlens[bidb + 1] - params.cu_seqlens[bidb];
int const seqlen = !is_varlen ? get<0>(params.shape_dK) : (
params.seqused ? params.seqused[bidb] : params.cu_seqlens[bidb + 1] - params.cu_seqlens[bidb]
);

Tensor mdK = make_tensor(make_gmem_ptr(params.ptr_dK), params.shape_dK, params.stride_dK)(_, _, bidh, !is_varlen ? bidb : 0);
Tensor gdK = local_tile(cute::domain_offset(make_coord(offset, _0{}), mdK), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K)
@@ -236,7 +240,7 @@ struct CollectiveEpilogueBwd {
auto [n_block, bidh, bidb] = block_coord;
bool const is_varlen = Varlen && params.cu_seqlens != nullptr;
int const offset = !is_varlen ? 0 : params.cu_seqlens[bidb];
int const seqlen = !is_varlen ? get<0>(params.shape_dK) : params.cu_seqlens[bidb + 1] - offset;
int const seqlen = !is_varlen ? get<0>(params.shape_dK) : (params.seqused ? params.seqused[bidb] : params.cu_seqlens[bidb + 1] - offset);

Tensor mdK = make_tensor(make_gmem_ptr(params.ptr_dK), params.shape_dK, params.stride_dK)(_, _, bidh, !is_varlen ? bidb : 0);
Tensor gdK = local_tile(cute::domain_offset(make_coord(offset, _0{}), mdK), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K)
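For readers skimming the epilogue change: in the varlen path, the dK/dV store now takes the per-sequence length from seqused when it is provided, and only falls back to the difference of consecutive cu_seqlens entries otherwise. A plain-Python rendering of that selection, with names mirroring the kernel parameters (illustrative only, not the device code):

def resolve_seqlen(bidb, shape_dK_rows, cu_seqlens=None, seqused=None):
    # Fixed-length batch: every sequence spans the full dK row extent.
    if cu_seqlens is None:
        return shape_dK_rows
    # Varlen with an explicit per-sequence row count (the new seqused path).
    if seqused is not None:
        return seqused[bidb]
    # Varlen fallback: cumulative-length difference, as before this change.
    return cu_seqlens[bidb + 1] - cu_seqlens[bidb]

The offset into the packed dK/dV tensors is still cu_seqlens[bidb]; only the number of valid rows written for each sequence changes.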
5 changes: 4 additions & 1 deletion hopper/flash.h
@@ -68,7 +68,9 @@ struct Flash_fwd_params : public Qkv_params {
int * __restrict__ cu_seqlens_q;
int * __restrict__ cu_seqlens_k;

// If provided, the actual length of each k sequence.
// If provided, the actual length of each q / o sequence.
int * __restrict__ seqused_q;
// If provided, the actual length of each k / v sequence.
int * __restrict__ seqused_k;

int *__restrict__ blockmask;
@@ -116,6 +118,7 @@ struct Flash_fwd_params : public Qkv_params {
bool is_bf16;
bool is_e4m3;
bool is_causal;
bool is_local;

// If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
// Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
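The forward params in flash.h gain seqused_q alongside seqused_k, plus a new is_local flag, matching the PR's goal of local (sliding-window) attention. As a reference for what is_local restricts, the sliding-window rule documented in the existing flash-attn Python interface can be written as a dense boolean mask; the sketch below is purely illustrative (the kernel never materializes such a mask, and the negative-window case, meaning unbounded on that side, is omitted):

import torch

def local_attention_mask(seqlen_q, seqlen_k, window_left, window_right):
    # True = this (query, key) pair is attended.
    i = torch.arange(seqlen_q).unsqueeze(1)   # query positions, shape (seqlen_q, 1)
    j = torch.arange(seqlen_k).unsqueeze(0)   # key positions, shape (1, seqlen_k)
    shift = seqlen_k - seqlen_q               # aligns the last query with the last key
    lower = j >= i + shift - window_left      # at most window_left keys to the left
    upper = j <= i + shift + window_right     # at most window_right keys to the right
    return lower & upper

# Example: window (2, 0) lets each query see itself and the two previous keys,
# i.e. causal attention clipped to a width-3 window.
print(local_attention_mask(4, 4, window_left=2, window_right=0).int())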