Shashank/seq id flash attn #738

Merged

Changes from 26 commits

Commits (69)
04dd334
Merge pull request #1 from mosaicml/main
ShashankMosaicML Oct 9, 2023
87b2fdc
Merge pull request #8 from mosaicml/main
ShashankMosaicML Oct 27, 2023
c9a42e4
Merge pull request #12 from mosaicml/main
ShashankMosaicML Nov 6, 2023
ddea9ee
Merge branch 'mosaicml:main' into main
ShashankMosaicML Nov 6, 2023
0bcd8ee
Merge pull request #13 from mosaicml/main
ShashankMosaicML Nov 8, 2023
f209b58
Merge pull request #14 from mosaicml/main
ShashankMosaicML Nov 14, 2023
879edb2
..
ShashankMosaicML Nov 14, 2023
66d85f1
..
ShashankMosaicML Nov 14, 2023
c1c2fbd
..
ShashankMosaicML Nov 15, 2023
e8b9381
..
ShashankMosaicML Nov 15, 2023
57502f4
..
ShashankMosaicML Nov 15, 2023
3041bf6
..
ShashankMosaicML Nov 15, 2023
9b305df
..
ShashankMosaicML Nov 15, 2023
b5a3c1f
..
ShashankMosaicML Nov 15, 2023
2e60014
..
ShashankMosaicML Nov 15, 2023
bf946c7
..
ShashankMosaicML Nov 15, 2023
e99780e
..
ShashankMosaicML Nov 15, 2023
897edbc
..
ShashankMosaicML Nov 15, 2023
88cf6d3
..
ShashankMosaicML Nov 15, 2023
c0a4d97
..
ShashankMosaicML Nov 15, 2023
ec4378d
Merge pull request #15 from mosaicml/main
ShashankMosaicML Nov 15, 2023
ad36ecf
merged and resolved conflicts
ShashankMosaicML Nov 15, 2023
c09d5c9
..
ShashankMosaicML Nov 15, 2023
efc4f4e
..
ShashankMosaicML Nov 15, 2023
4940e70
..
ShashankMosaicML Nov 15, 2023
5e3ccf9
..
ShashankMosaicML Nov 16, 2023
84fa710
Update llmfoundry/models/layers/attention.py
ShashankMosaicML Nov 17, 2023
a560f31
Update llmfoundry/models/mpt/modeling_mpt.py
ShashankMosaicML Nov 17, 2023
a70f05e
Update llmfoundry/models/mpt/modeling_mpt.py
ShashankMosaicML Nov 17, 2023
42a541d
..
ShashankMosaicML Nov 17, 2023
3d1d022
..
ShashankMosaicML Nov 18, 2023
c94e7fe
..
ShashankMosaicML Nov 18, 2023
e96b234
..
ShashankMosaicML Nov 18, 2023
f02034b
..
ShashankMosaicML Nov 18, 2023
88c6808
..
ShashankMosaicML Nov 18, 2023
511a405
Merge branch 'main' into shashank/seq_id_flash_attn
ShashankMosaicML Nov 18, 2023
6af9aba
..
ShashankMosaicML Nov 18, 2023
5d7805d
..
ShashankMosaicML Nov 18, 2023
55625ff
Merge branch 'main' into shashank/seq_id_flash_attn
ShashankMosaicML Nov 21, 2023
44148b1
Merge branch 'main' into shashank/seq_id_flash_attn
ShashankMosaicML Nov 22, 2023
a8f63d4
..
ShashankMosaicML Nov 22, 2023
af6520a
Merge branch 'main' into shashank/seq_id_flash_attn
ShashankMosaicML Nov 22, 2023
538169c
..
ShashankMosaicML Nov 22, 2023
20af30a
..
ShashankMosaicML Nov 22, 2023
eeb9e1c
..
ShashankMosaicML Nov 22, 2023
1b7d38d
..
ShashankMosaicML Nov 22, 2023
b0a3c1b
..
ShashankMosaicML Nov 22, 2023
6a4f73e
..
ShashankMosaicML Nov 22, 2023
f05bfe6
Merge branch 'main' into shashank/seq_id_flash_attn
ShashankMosaicML Nov 22, 2023
c275365
..
ShashankMosaicML Nov 25, 2023
e82c723
..
ShashankMosaicML Nov 25, 2023
a964aea
..
ShashankMosaicML Nov 26, 2023
5765724
Merge branch 'main' into shashank/seq_id_flash_attn
ShashankMosaicML Nov 29, 2023
4240245
..
ShashankMosaicML Nov 30, 2023
4b25da2
..
ShashankMosaicML Nov 30, 2023
67deef8
..
ShashankMosaicML Nov 30, 2023
b855100
..
ShashankMosaicML Nov 30, 2023
371e3a2
..
ShashankMosaicML Nov 30, 2023
fa2a2ee
Merge branch 'main' into shashank/seq_id_flash_attn
ShashankMosaicML Nov 30, 2023
8339cd3
Merge branch 'main' into shashank/seq_id_flash_attn
ShashankMosaicML Dec 1, 2023
6c59dce
..
ShashankMosaicML Dec 1, 2023
805313b
..
ShashankMosaicML Dec 1, 2023
f1251c4
..
ShashankMosaicML Dec 1, 2023
5fca723
Merge branch 'main' into shashank/seq_id_flash_attn
ShashankMosaicML Dec 1, 2023
14a2553
merging from main
ShashankMosaicML Dec 2, 2023
cdc220f
merging from main
ShashankMosaicML Dec 2, 2023
e25ed63
..
ShashankMosaicML Dec 2, 2023
cb6864a
..
ShashankMosaicML Dec 2, 2023
9bc7ce1
Merge branch 'main' into shashank/seq_id_flash_attn
ShashankMosaicML Dec 2, 2023
114 changes: 76 additions & 38 deletions llmfoundry/models/layers/attention.py
@@ -90,8 +90,12 @@ def scaled_multihead_dot_product_attention(
training: bool = False,
needs_weights: bool = False,
multiquery: bool = False,
query_attention_mask_in_length: Optional[torch.Tensor] = None,
key_attention_mask_in_length: Optional[torch.Tensor] = None,
should_repeat_kv_for_gqa: Optional[bool] = True,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor,
torch.Tensor]]]:
del query_attention_mask_in_length, key_attention_mask_in_length, should_repeat_kv_for_gqa

if multiquery:
warnings.warn(
@@ -219,6 +223,9 @@ def flash_attn_fn(
training: bool = False,
needs_weights: bool = False,
multiquery: bool = False,
key_attention_mask_in_length: Optional[torch.Tensor] = None,
query_attention_mask_in_length: Optional[torch.Tensor] = None,
should_repeat_kv_for_gqa: Optional[bool] = True,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor,
torch.Tensor]]]:
try:
@@ -260,47 +267,69 @@ def flash_attn_fn(

batch_size, seqlen = query.shape[:2]

if key_padding_mask is None:
key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool)
query_padding_mask = key_padding_mask[:, -query.size(1):]
if query_attention_mask_in_length is None:
if key_padding_mask is None:
key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool)
query_padding_mask = key_padding_mask[:, -query.size(1):]

query_unpad, indices_q, cu_seqlens_q, max_seqlen_q = bert_padding.unpad_input(
query, query_padding_mask)
query_unpad = rearrange(query_unpad, 'nnz (h d) -> nnz h d', h=n_heads)
query_unpad, indices_q, cu_seqlens_q, max_seqlen_q = bert_padding.unpad_input(
query, query_padding_mask)
query_unpad = rearrange(query_unpad, 'nnz (h d) -> nnz h d', h=n_heads)

key_unpad, _, cu_seqlens_k, max_seqlen_k = bert_padding.unpad_input(
key, key_padding_mask)
key_unpad = rearrange(key_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads)
key_unpad, _, cu_seqlens_k, max_seqlen_k = bert_padding.unpad_input(
key, key_padding_mask)
key_unpad = rearrange(key_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads)

value_unpad, _, _, _ = bert_padding.unpad_input(value, key_padding_mask)
value_unpad = rearrange(value_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads)

# multi-query case
if kv_n_heads == 1:
# Expanding a tensor does not allocate new memory, but only creates a new
# view on the existing tensor where a dimension of size one is expanded
# to a larger size by setting the stride to 0.
# - pytorch docs
#
# hopefully the kernels can utilize this and we're not just wasting BW here
key_unpad = key_unpad.expand(key_unpad.size(0), n_heads,
key_unpad.size(-1))
value_unpad = value_unpad.expand(value_unpad.size(0), n_heads,
value_unpad.size(-1))
# grouped query case
elif kv_n_heads < n_heads:
# Each query belongs to a group of kv heads of group size n_heads // kv_n_heads
# We repeat each kv head by the group size number to use the underlying MHA kernels

# since repeat_kv_for_gqa expects input dims of (b, s, kv_n_heads, d)
# we use .view to modify {key, value}_unpad appropriately

key_unpad = repeat_kv_for_gqa(
key_unpad.view(batch_size, seqlen, kv_n_heads, -1),
n_heads // kv_n_heads).view(batch_size * seqlen, n_heads, -1)
value_unpad = repeat_kv_for_gqa(
value_unpad.view(batch_size, seqlen, kv_n_heads, -1),
n_heads // kv_n_heads).view(batch_size * seqlen, n_heads, -1)
value_unpad, _, _, _ = bert_padding.unpad_input(value, key_padding_mask)
value_unpad = rearrange(value_unpad,
'nnz (h d) -> nnz h d',
h=kv_n_heads)
else:
if key_attention_mask_in_length is None:
raise ValueError(
'key_attention_mask_in_length must not be None if query_attention_mask_in_length is not None.'
)
query_unpad, indices_q, cu_seqlens_q, max_seqlen_q = bert_padding.unpad_input_for_concatenated_sequences(
query, query_attention_mask_in_length)
query_unpad = rearrange(query_unpad, 'nnz (h d) -> nnz h d', h=n_heads)

key_unpad, _, cu_seqlens_k, max_seqlen_k = bert_padding.unpad_input_for_concatenated_sequences(
key, key_attention_mask_in_length)
key_unpad = rearrange(key_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads)

value_unpad, _, _, _ = bert_padding.unpad_input_for_concatenated_sequences(
value, key_attention_mask_in_length)
value_unpad = rearrange(value_unpad,
'nnz (h d) -> nnz h d',
h=kv_n_heads)

if should_repeat_kv_for_gqa:
# multi-query case
if kv_n_heads == 1:
# Expanding a tensor does not allocate new memory, but only creates a new
# view on the existing tensor where a dimension of size one is expanded
# to a larger size by setting the stride to 0.
# - pytorch docs
#
# hopefully the kernels can utilize this and we're not just wasting BW here
key_unpad = key_unpad.expand(key_unpad.size(0), n_heads,
key_unpad.size(-1))
value_unpad = value_unpad.expand(value_unpad.size(0), n_heads,
value_unpad.size(-1))
# grouped query case
elif kv_n_heads < n_heads:
# Each query belongs to a group of kv heads of group size n_heads // kv_n_heads
# We repeat each kv head by the group size number to use the underlying MHA kernels

# since repeat_kv_for_gqa expects input dims of (b, s, kv_n_heads, d)
# we use .view to modify {key, value}_unpad appropriately

key_unpad = repeat_kv_for_gqa(
key_unpad.view(batch_size, seqlen, kv_n_heads, -1),
n_heads // kv_n_heads).view(batch_size * seqlen, n_heads, -1)
value_unpad = repeat_kv_for_gqa(
value_unpad.view(batch_size, seqlen, kv_n_heads, -1),
n_heads // kv_n_heads).view(batch_size * seqlen, n_heads, -1)

dropout_p = dropout_p if training else 0.0

@@ -357,8 +386,12 @@ def triton_flash_attn_fn(
training: bool = False,
needs_weights: bool = False,
multiquery: bool = False,
query_attention_mask_in_length: Optional[torch.Tensor] = None,
key_attention_mask_in_length: Optional[torch.Tensor] = None,
should_repeat_kv_for_gqa: Optional[bool] = True,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor,
torch.Tensor]]]:
del query_attention_mask_in_length, key_attention_mask_in_length, should_repeat_kv_for_gqa
try:
from llmfoundry.models.layers.flash_attn_triton import flash_attn_func
except:
@@ -569,6 +602,8 @@ def forward(
rotary_emb_w_meta_info: Optional[dict] = None,
is_causal: bool = True,
needs_weights: bool = False,
query_attention_mask_in_length: Optional[torch.Tensor] = None,
key_attention_mask_in_length: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[
torch.Tensor, torch.Tensor]]]:
qkv = self.Wqkv(x)
@@ -640,6 +675,9 @@ def forward(
dropout_p=self.attn_dropout_p,
training=self.training,
needs_weights=needs_weights,
query_attention_mask_in_length=query_attention_mask_in_length,
key_attention_mask_in_length=key_attention_mask_in_length,
should_repeat_kv_for_gqa=not is_flash_v2_installed(),
)

return self.out_proj(context), attn_weights, past_key_value
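For reference, a minimal sketch of the head-expansion logic in flash_attn_fn above. This is an editor-added illustration, not part of the diff, and it assumes repeat_kv_for_gqa behaves like torch.repeat_interleave over the kv-head dimension:

import torch

b, s, d = 2, 4, 8
n_heads, kv_n_heads = 8, 2
key_unpad = torch.randn(b * s, kv_n_heads, d)  # unpadded keys, shape (nnz, kv_n_heads, head_dim)

if kv_n_heads == 1:
    # multi-query: expand returns a stride-0 view, so no extra memory is allocated
    key_unpad = key_unpad.expand(key_unpad.size(0), n_heads, key_unpad.size(-1))
elif kv_n_heads < n_heads:
    # grouped-query: repeat each kv head n_heads // kv_n_heads times so plain MHA kernels apply
    key_unpad = key_unpad.view(b, s, kv_n_heads, d).repeat_interleave(
        n_heads // kv_n_heads, dim=2).view(b * s, n_heads, d)

print(key_unpad.shape)  # torch.Size([8, 8, 8])

When flash-attn v2 is installed, the attention module passes should_repeat_kv_for_gqa=False (see the forward() change above), since the v2 kernels handle grouped-query heads natively and this repeat is skipped.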
4 changes: 4 additions & 0 deletions llmfoundry/models/layers/blocks.py
@@ -113,6 +113,8 @@ def forward(
attention_mask: Optional[torch.ByteTensor] = None,
is_causal: bool = True,
output_attentions: bool = False,
query_attention_mask_in_length: Optional[torch.Tensor] = None,
key_attention_mask_in_length: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[
torch.Tensor, torch.Tensor]]]:
a = self.norm_1(x)
@@ -124,6 +126,8 @@ def forward(
attention_mask=attention_mask,
is_causal=is_causal,
needs_weights=output_attentions,
query_attention_mask_in_length=query_attention_mask_in_length,
key_attention_mask_in_length=key_attention_mask_in_length,
)
x = x + self.resid_attn_dropout(b)
m = x
4 changes: 2 additions & 2 deletions llmfoundry/models/mpt/configuration_mpt.py
@@ -222,9 +222,9 @@ def _validate_config(self) -> None:
raise NotImplementedError(
'alibi only implemented with torch and triton attention.')
if self.attn_config['attn_uses_sequence_id'] and self.attn_config[
'attn_impl'] not in ['torch', 'triton']:
'attn_impl'] not in ['torch', 'triton'] and not (self.attn_config['attn_impl'] == 'flash' and is_flash_v2_installed(v2_version='v2.1.2')):
raise NotImplementedError(
'attn_uses_sequence_id only implemented with torch and triton attention.'
'attn_uses_sequence_id only implemented with torch, triton, and flash (v2.1.2 or higher) attention.'
)
if self.attn_config['rope'] and (self.attn_config['rope_impl']
not in ['dail', 'hf']):
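As an illustration of the relaxed check above, a config along the following lines should now pass validation when flash-attn v2.1.2 or newer is installed. This is a sketch under stated assumptions: the attn_config keys come from the diff, while the remaining keyword arguments are example values and assume MPTConfig fills in the other attn_config defaults.

from llmfoundry.models.mpt.configuration_mpt import MPTConfig

config = MPTConfig(
    d_model=128,
    n_heads=4,
    n_layers=2,
    max_seq_len=64,
    attn_config={
        'attn_impl': 'flash',           # previously incompatible with attn_uses_sequence_id
        'attn_uses_sequence_id': True,  # mask attention across packed (concatenated) sequences
    },
)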
67 changes: 42 additions & 25 deletions llmfoundry/models/mpt/modeling_mpt.py
@@ -131,6 +131,44 @@ def gen_rotary_embedding(rope_head_dim: int, rope_impl: str, rope_theta: int,
)
raise ValueError('rope_impl needs to be either dail or hf')

def gen_attention_mask_in_length(sequence_id: Union[None, torch.Tensor], S: int, attn_uses_sequence_id: bool, attn_impl: str):
# Generates the attention masks used for sequence masking in flash attention
query_attention_mask_in_length = None
key_attention_mask_in_length = None
if (sequence_id is not None) and attn_uses_sequence_id and (attn_impl == 'flash'):
query_attention_mask_in_length = torch.nn.functional.one_hot(sequence_id[:, -S:], num_classes=S).sum(dim=1)
# We use S as the number of classes while creating key_attention_mask_in_length instead of sequence_id.shape[-1]
# because in case of inference, sequence_id.shape[-1] can become very large. In that case, the one_hot vectors
# would've become very large as well.
key_attention_mask_in_length = torch.nn.functional.one_hot(sequence_id, num_classes=S).sum(dim=1)
# Since Flash Attention expects the masks to have the same shape as the keys, we pad them with zeros.
key_attention_mask_in_length = torch.nn.functional.pad(key_attention_mask_in_length, (0, sequence_id.shape[-1] - S), value=0)

return query_attention_mask_in_length, key_attention_mask_in_length
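To make the one-hot construction above concrete, here is a toy run (an editor illustration, not part of the diff). Each row of the result lists per-sequence lengths, which is the attention-mask-in-length format that flash-attn's bert_padding.unpad_input_for_concatenated_sequences consumes; the key mask is built the same way from the full sequence_id and then zero-padded to the key length.

import torch

# batch 0 packs three sequences of lengths 2, 3, 1; batch 1 holds a single sequence of length 6
sequence_id = torch.tensor([[0, 0, 1, 1, 1, 2],
                            [0, 0, 0, 0, 0, 0]])
S = sequence_id.shape[-1]

query_attention_mask_in_length = torch.nn.functional.one_hot(
    sequence_id[:, -S:], num_classes=S).sum(dim=1)
print(query_attention_mask_in_length)
# tensor([[2, 3, 1, 0, 0, 0],
#         [6, 0, 0, 0, 0, 0]])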

def apply_sequence_id(attn_bias: torch.Tensor,
sequence_id: torch.LongTensor,
max_seq_len: int) -> torch.Tensor:
seq_len = sequence_id.shape[-1]
if seq_len > max_seq_len:
raise ValueError(
f'sequence_id sequence length cannot exceed max_seq_len={max_seq_len}'
)

# select seq_len subset of attn mask
attn_bias = attn_bias[..., :seq_len, :seq_len]

# Restrict attention to tokens that share the same value
# in sequence_id
cannot_attend = torch.logical_not(
torch.eq(
sequence_id.view(-1, seq_len, 1),
sequence_id.view(-1, 1, seq_len),
)).unsqueeze(1)
min_val = torch.finfo(attn_bias.dtype).min
attn_bias = attn_bias.masked_fill(cannot_attend, min_val)

return attn_bias
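Similarly, a small sketch (again editor-added, not part of the diff) of what the bias masking in apply_sequence_id produces for a packed batch; positions that connect tokens from different sequences are filled with the dtype minimum, so they receive effectively zero weight after the softmax.

import torch

sequence_id = torch.tensor([[0, 0, 1, 1]])
seq_len = sequence_id.shape[-1]
attn_bias = torch.zeros(1, 1, seq_len, seq_len)

cannot_attend = torch.logical_not(
    torch.eq(sequence_id.view(-1, seq_len, 1),
             sequence_id.view(-1, 1, seq_len))).unsqueeze(1)
attn_bias = attn_bias.masked_fill(cannot_attend,
                                  torch.finfo(attn_bias.dtype).min)
print(attn_bias[0, 0])
# the 2x2 blocks on the diagonal stay 0; all cross-sequence entries become the minimum float value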

class MPTPreTrainedModel(PreTrainedModel):
config_class = MPTConfig
@@ -286,7 +324,7 @@ def _attn_bias(
# If using torch or triton, we incorporate sequence_id (if appropriate)
if self.attn_uses_sequence_id and sequence_id is not None:
assert isinstance(attn_bias, torch.Tensor) # pyright
attn_bias = self._apply_sequence_id(attn_bias, sequence_id)
attn_bias = apply_sequence_id(attn_bias, sequence_id, self.config.max_seq_len)

# If using torch or triton, we incorporate attention_mask. This will output
# None in place of attention_mask since it will not be further needed in the
@@ -343,29 +381,6 @@ def _apply_prefix_mask(self, attn_bias: torch.Tensor,

return attn_bias

def _apply_sequence_id(self, attn_bias: torch.Tensor,
sequence_id: torch.LongTensor) -> torch.Tensor:
seq_len = sequence_id.shape[-1]
if seq_len > self.config.max_seq_len:
raise ValueError(
f'sequence_id sequence length cannot exceed max_seq_len={self.config.max_seq_len}'
)

# select seq_len subset of attn mask
attn_bias = attn_bias[..., :seq_len, :seq_len]

# Restrict attention to tokens that share the same value
# in sequence_id
cannot_attend = torch.logical_not(
torch.eq(
sequence_id.view(-1, seq_len, 1),
sequence_id.view(-1, 1, seq_len),
)).unsqueeze(1)
min_val = torch.finfo(attn_bias.dtype).min
attn_bias = attn_bias.masked_fill(cannot_attend, min_val)

return attn_bias

def forward(
self,
input_ids: torch.LongTensor,
@@ -509,7 +524,7 @@ def forward(
prefix_mask=prefix_mask,
sequence_id=sequence_id,
)

query_attention_mask_in_length, key_attention_mask_in_length = gen_attention_mask_in_length(sequence_id=sequence_id, S=S, attn_uses_sequence_id=self.attn_uses_sequence_id, attn_impl=self.attn_impl)
# initialize the past key values cache if it should be used
presents = () if use_cache else None
if use_cache and past_key_values is None:
@@ -532,6 +547,8 @@ def forward(
attention_mask=attention_mask,
is_causal=self.is_causal,
output_attentions=bool(output_attentions),
query_attention_mask_in_length=query_attention_mask_in_length,
key_attention_mask_in_length=key_attention_mask_in_length,
)
if presents is not None:
presents += (present,)
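Finally, an end-to-end usage sketch under stated assumptions: the class and argument names mirror modeling_mpt.py and configuration_mpt.py, the specific dimensions are example values, and running with attn_impl='flash' requires a CUDA device plus flash-attn v2.1.2 or newer.

import torch
from llmfoundry.models.mpt.configuration_mpt import MPTConfig
from llmfoundry.models.mpt.modeling_mpt import MPTForCausalLM

config = MPTConfig(
    d_model=128, n_heads=4, n_layers=2, max_seq_len=32,
    attn_config={'attn_impl': 'flash', 'attn_uses_sequence_id': True},
)
model = MPTForCausalLM(config).to(device='cuda', dtype=torch.bfloat16)

input_ids = torch.randint(0, config.vocab_size, (1, 6), device='cuda')
sequence_id = torch.tensor([[0, 0, 1, 1, 1, 2]], device='cuda')  # three packed sequences
with torch.no_grad():
    out = model(input_ids=input_ids, sequence_id=sequence_id)
print(out.logits.shape)  # (1, 6, vocab_size)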