Shashank/seq id flash attn #738

Merged

Commits (69 total, all by ShashankMosaicML; changes shown from 49 commits)
04dd334  Merge pull request #1 from mosaicml/main (Oct 9, 2023)
87b2fdc  Merge pull request #8 from mosaicml/main (Oct 27, 2023)
c9a42e4  Merge pull request #12 from mosaicml/main (Nov 6, 2023)
ddea9ee  Merge branch 'mosaicml:main' into main (Nov 6, 2023)
0bcd8ee  Merge pull request #13 from mosaicml/main (Nov 8, 2023)
f209b58  Merge pull request #14 from mosaicml/main (Nov 14, 2023)
879edb2  .. (Nov 14, 2023)
66d85f1  .. (Nov 14, 2023)
c1c2fbd  .. (Nov 15, 2023)
e8b9381  .. (Nov 15, 2023)
57502f4  .. (Nov 15, 2023)
3041bf6  .. (Nov 15, 2023)
9b305df  .. (Nov 15, 2023)
b5a3c1f  .. (Nov 15, 2023)
2e60014  .. (Nov 15, 2023)
bf946c7  .. (Nov 15, 2023)
e99780e  .. (Nov 15, 2023)
897edbc  .. (Nov 15, 2023)
88cf6d3  .. (Nov 15, 2023)
c0a4d97  .. (Nov 15, 2023)
ec4378d  Merge pull request #15 from mosaicml/main (Nov 15, 2023)
ad36ecf  merged and resolved conflicts (Nov 15, 2023)
c09d5c9  .. (Nov 15, 2023)
efc4f4e  .. (Nov 15, 2023)
4940e70  .. (Nov 15, 2023)
5e3ccf9  .. (Nov 16, 2023)
84fa710  Update llmfoundry/models/layers/attention.py (Nov 17, 2023)
a560f31  Update llmfoundry/models/mpt/modeling_mpt.py (Nov 17, 2023)
a70f05e  Update llmfoundry/models/mpt/modeling_mpt.py (Nov 17, 2023)
42a541d  .. (Nov 17, 2023)
3d1d022  .. (Nov 18, 2023)
c94e7fe  .. (Nov 18, 2023)
e96b234  .. (Nov 18, 2023)
f02034b  .. (Nov 18, 2023)
88c6808  .. (Nov 18, 2023)
511a405  Merge branch 'main' into shashank/seq_id_flash_attn (Nov 18, 2023)
6af9aba  .. (Nov 18, 2023)
5d7805d  .. (Nov 18, 2023)
55625ff  Merge branch 'main' into shashank/seq_id_flash_attn (Nov 21, 2023)
44148b1  Merge branch 'main' into shashank/seq_id_flash_attn (Nov 22, 2023)
a8f63d4  .. (Nov 22, 2023)
af6520a  Merge branch 'main' into shashank/seq_id_flash_attn (Nov 22, 2023)
538169c  .. (Nov 22, 2023)
20af30a  .. (Nov 22, 2023)
eeb9e1c  .. (Nov 22, 2023)
1b7d38d  .. (Nov 22, 2023)
b0a3c1b  .. (Nov 22, 2023)
6a4f73e  .. (Nov 22, 2023)
f05bfe6  Merge branch 'main' into shashank/seq_id_flash_attn (Nov 22, 2023)
c275365  .. (Nov 25, 2023)
e82c723  .. (Nov 25, 2023)
a964aea  .. (Nov 26, 2023)
5765724  Merge branch 'main' into shashank/seq_id_flash_attn (Nov 29, 2023)
4240245  .. (Nov 30, 2023)
4b25da2  .. (Nov 30, 2023)
67deef8  .. (Nov 30, 2023)
b855100  .. (Nov 30, 2023)
371e3a2  .. (Nov 30, 2023)
fa2a2ee  Merge branch 'main' into shashank/seq_id_flash_attn (Nov 30, 2023)
8339cd3  Merge branch 'main' into shashank/seq_id_flash_attn (Dec 1, 2023)
6c59dce  .. (Dec 1, 2023)
805313b  .. (Dec 1, 2023)
f1251c4  .. (Dec 1, 2023)
5fca723  Merge branch 'main' into shashank/seq_id_flash_attn (Dec 1, 2023)
14a2553  merging from main (Dec 2, 2023)
cdc220f  merging from main (Dec 2, 2023)
e25ed63  .. (Dec 2, 2023)
cb6864a  .. (Dec 2, 2023)
9bc7ce1  Merge branch 'main' into shashank/seq_id_flash_attn (Dec 2, 2023)
127 changes: 88 additions & 39 deletions llmfoundry/models/layers/attention.py
@@ -90,8 +90,13 @@ def scaled_multihead_dot_product_attention(
training: bool = False,
needs_weights: bool = False,
multiquery: bool = False,
query_attention_mask_in_length: Optional[torch.Tensor] = None,
key_attention_mask_in_length: Optional[torch.Tensor] = None,
should_repeat_kv_for_gqa: Optional[bool] = True,
sliding_window_size: int = -1,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor,
torch.Tensor]]]:
del query_attention_mask_in_length, key_attention_mask_in_length, should_repeat_kv_for_gqa, sliding_window_size

if multiquery:
warnings.warn(
@@ -219,6 +224,10 @@ def flash_attn_fn(
training: bool = False,
needs_weights: bool = False,
multiquery: bool = False,
key_attention_mask_in_length: Optional[torch.Tensor] = None,
query_attention_mask_in_length: Optional[torch.Tensor] = None,
should_repeat_kv_for_gqa: Optional[bool] = True,
sliding_window_size: int = -1,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor,
torch.Tensor]]]:
try:
@@ -260,47 +269,69 @@ def flash_attn_fn(

batch_size, seqlen = query.shape[:2]

if key_padding_mask is None:
key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool)
query_padding_mask = key_padding_mask[:, -query.size(1):]
if query_attention_mask_in_length is None:
if key_padding_mask is None:
key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool)
query_padding_mask = key_padding_mask[:, -query.size(1):]

query_unpad, indices_q, cu_seqlens_q, max_seqlen_q = bert_padding.unpad_input(
query, query_padding_mask)
query_unpad = rearrange(query_unpad, 'nnz (h d) -> nnz h d', h=n_heads)
query_unpad, indices_q, cu_seqlens_q, max_seqlen_q = bert_padding.unpad_input(
query, query_padding_mask)
query_unpad = rearrange(query_unpad, 'nnz (h d) -> nnz h d', h=n_heads)

key_unpad, _, cu_seqlens_k, max_seqlen_k = bert_padding.unpad_input(
key, key_padding_mask)
key_unpad = rearrange(key_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads)
key_unpad, _, cu_seqlens_k, max_seqlen_k = bert_padding.unpad_input(
key, key_padding_mask)
key_unpad = rearrange(key_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads)

value_unpad, _, _, _ = bert_padding.unpad_input(value, key_padding_mask)
value_unpad = rearrange(value_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads)

# multi-query case
if kv_n_heads == 1:
# Expanding a tensor does not allocate new memory, but only creates a new
# view on the existing tensor where a dimension of size one is expanded
# to a larger size by setting the stride to 0.
# - pytorch docs
#
# hopefully the kernels can utilize this and we're not just wasting BW here
key_unpad = key_unpad.expand(key_unpad.size(0), n_heads,
key_unpad.size(-1))
value_unpad = value_unpad.expand(value_unpad.size(0), n_heads,
value_unpad.size(-1))
# grouped query case
elif kv_n_heads < n_heads:
# Each query belongs to a group of kv heads of group size n_heads // kv_n_heads
# We repeat each kv head group-size times so we can use the underlying MHA kernels

# since repeat_kv_for_gqa expects input dims of (b, s, kv_n_heads, d)
# we use .view to modify {key, value}_unpad appropriately

key_unpad = repeat_kv_for_gqa(
key_unpad.view(1, key_unpad.size(0), kv_n_heads, -1),
n_heads // kv_n_heads).view(key_unpad.size(0), n_heads, -1)
value_unpad = repeat_kv_for_gqa(
value_unpad.view(1, value_unpad.size(0), kv_n_heads, -1),
n_heads // kv_n_heads).view(value_unpad.size(0), n_heads, -1)
value_unpad, _, _, _ = bert_padding.unpad_input(value, key_padding_mask)
value_unpad = rearrange(value_unpad,
'nnz (h d) -> nnz h d',
h=kv_n_heads)
else:
if key_attention_mask_in_length is None:
raise ValueError(
'key_attention_mask_in_length must not be None if query_attention_mask_in_length is not None.'
)
query_unpad, indices_q, cu_seqlens_q, max_seqlen_q = bert_padding.unpad_input_for_concatenated_sequences(
query, query_attention_mask_in_length)
query_unpad = rearrange(query_unpad, 'nnz (h d) -> nnz h d', h=n_heads)

key_unpad, _, cu_seqlens_k, max_seqlen_k = bert_padding.unpad_input_for_concatenated_sequences(
key, key_attention_mask_in_length)
key_unpad = rearrange(key_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads)

value_unpad, _, _, _ = bert_padding.unpad_input_for_concatenated_sequences(
value, key_attention_mask_in_length)
value_unpad = rearrange(value_unpad,
'nnz (h d) -> nnz h d',
h=kv_n_heads)

if should_repeat_kv_for_gqa:
# multi-query case
if kv_n_heads == 1:
# Expanding a tensor does not allocate new memory, but only creates a new
# view on the existing tensor where a dimension of size one is expanded
# to a larger size by setting the stride to 0.
# - pytorch docs
#
# hopefully the kernels can utilize this and we're not just wasting BW here
key_unpad = key_unpad.expand(key_unpad.size(0), n_heads,
key_unpad.size(-1))
value_unpad = value_unpad.expand(value_unpad.size(0), n_heads,
value_unpad.size(-1))
# grouped query case
elif kv_n_heads < n_heads:
# Each query belongs to a group of kv heads of group size n_heads // kv_n_heads
# We repeat each kv head group-size times so we can use the underlying MHA kernels

# since repeat_kv_for_gqa expects input dims of (b, s, kv_n_heads, d)
# we use .view to modify {key, value}_unpad appropriately

key_unpad = repeat_kv_for_gqa(
key_unpad.view(1, key_unpad.size(0), kv_n_heads, -1),
n_heads // kv_n_heads).view(key_unpad.size(0), n_heads, -1)
value_unpad = repeat_kv_for_gqa(
value_unpad.view(1, value_unpad.size(0), kv_n_heads, -1),
n_heads // kv_n_heads).view(value_unpad.size(0), n_heads, -1)

dropout_p = dropout_p if training else 0.0
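
Note on the new unpadding path: when query_attention_mask_in_length / key_attention_mask_in_length are provided, the hunk above routes through flash-attn's bert_padding.unpad_input_for_concatenated_sequences, whose mask-in-length layout is a (batch, seqlen) integer tensor where row b holds the lengths of the sub-sequences packed into row b, left-aligned and zero-padded. The sketch below shows one way such a tensor could be derived from a per-token sequence_id; it is only illustrative (the helper name and the padding handling are assumptions, not code from this PR).

# Illustrative sketch, not part of this diff: derive a mask-in-length tensor
# from sequence_id (batch, seqlen) with sub-sequence ids 0..k-1 and an
# attention_mask marking non-padded positions.
import torch

def attention_mask_in_length_from_sequence_id(
        sequence_id: torch.Tensor,     # (batch, seqlen), int64
        attention_mask: torch.Tensor,  # (batch, seqlen), bool
) -> torch.Tensor:
    seqlen = sequence_id.size(1)
    # One-hot over sub-sequence ids; clamp guards against -1 used for padding.
    one_hot = torch.nn.functional.one_hot(sequence_id.clamp(min=0),
                                          num_classes=seqlen)
    # Zero out padded positions, then count tokens per sub-sequence id.
    one_hot = one_hot * attention_mask.unsqueeze(-1)
    # Row b now reads e.g. [3, 5, 4, 0, 0, ...] for three packed sequences of
    # lengths 3, 5 and 4, the layout unpad_input_for_concatenated_sequences expects.
    return one_hot.sum(dim=1)
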

@@ -331,7 +362,8 @@ def flash_attn_fn(
dropout_p=dropout_p,
softmax_scale=softmax_scale,
causal=reset_is_causal,
return_attn_probs=needs_weights)
return_attn_probs=needs_weights,
window_size=(sliding_window_size, sliding_window_size))
else:
raise RuntimeError(
'flash-attn==1.0.9 or flash-attn==2.3.2 is required.')
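
The window_size tuple forwarded above maps onto flash-attn's sliding-window support, available from v2.3.0. A minimal fixed-length illustration follows, assuming flash-attn >= 2.3.0 and a CUDA device (shapes and values are made up for the example): with causal=True, window_size=(w, w) limits each query to the previous w keys, since the right half of the window is removed by the causal mask.

# Minimal sketch, assuming flash-attn >= 2.3.0 on a CUDA device.
import torch
from flash_attn import flash_attn_func

b, s, h, d, w = 2, 16, 4, 64, 4
q = torch.randn(b, s, h, d, device='cuda', dtype=torch.bfloat16)
k = torch.randn(b, s, h, d, device='cuda', dtype=torch.bfloat16)
v = torch.randn(b, s, h, d, device='cuda', dtype=torch.bfloat16)
# Each query attends to at most the w previous keys (plus itself).
out = flash_attn_func(q, k, v, causal=True, window_size=(w, w))  # (b, s, h, d)
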
@@ -357,8 +389,13 @@ def triton_flash_attn_fn(
training: bool = False,
needs_weights: bool = False,
multiquery: bool = False,
query_attention_mask_in_length: Optional[torch.Tensor] = None,
key_attention_mask_in_length: Optional[torch.Tensor] = None,
should_repeat_kv_for_gqa: Optional[bool] = True,
sliding_window_size: int = -1,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor,
torch.Tensor]]]:
del query_attention_mask_in_length, key_attention_mask_in_length, should_repeat_kv_for_gqa, sliding_window_size
try:
from llmfoundry.models.layers.flash_attn_triton import flash_attn_func
except:
@@ -490,6 +527,7 @@ def __init__(
fc_type: str = 'torch',
device: Optional[str] = None,
bias: bool = True,
sliding_window_size: int = -1,
):
super().__init__()

@@ -500,6 +538,7 @@ def __init__(
self.d_model = d_model
self.n_heads = n_heads
self.kv_n_heads = kv_n_heads
self.sliding_window_size = sliding_window_size

self.head_dim = d_model // n_heads

@@ -569,6 +608,8 @@ def forward(
rotary_emb_w_meta_info: Optional[dict] = None,
is_causal: bool = True,
needs_weights: bool = False,
query_attention_mask_in_length: Optional[torch.Tensor] = None,
key_attention_mask_in_length: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[
torch.Tensor, torch.Tensor]]]:
qkv = self.Wqkv(x)
@@ -640,6 +681,10 @@ def forward(
dropout_p=self.attn_dropout_p,
training=self.training,
needs_weights=needs_weights,
query_attention_mask_in_length=query_attention_mask_in_length,
key_attention_mask_in_length=key_attention_mask_in_length,
should_repeat_kv_for_gqa=not is_flash_v2_installed(),
sliding_window_size=self.sliding_window_size,
)

return self.out_proj(context), attn_weights, past_key_value
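
On should_repeat_kv_for_gqa=not is_flash_v2_installed(): flash-attn v2 kernels accept k/v with fewer heads than q (MQA/GQA) natively, so the expand/repeat step in flash_attn_fn is only needed on the v1 path; the torch and triton functions simply del the flag. For reference, a standalone sketch that is functionally equivalent to the repetition done by repeat_kv_for_gqa (illustrative only, not the repo's implementation):

import torch

def repeat_kv_heads(hidden: torch.Tensor, n_rep: int) -> torch.Tensor:
    # hidden: (batch, seqlen, kv_n_heads, head_dim). Repeat each kv head
    # n_rep times along the head dimension so an MHA kernel sees
    # kv_n_heads * n_rep == n_heads key/value heads.
    if n_rep == 1:
        return hidden
    return hidden.repeat_interleave(n_rep, dim=2)
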
@@ -665,6 +710,7 @@ def __init__(
fc_type: str = 'torch',
device: Optional[str] = None,
bias: bool = True,
sliding_window_size: int = -1,
):
super().__init__(
d_model=d_model,
@@ -679,6 +725,7 @@ def __init__(
fc_type=fc_type,
device=device,
bias=bias,
sliding_window_size=sliding_window_size,
)


@@ -702,6 +749,7 @@ def __init__(
fc_type: str = 'torch',
device: Optional[str] = None,
bias: bool = True,
sliding_window_size: int = -1,
):
super().__init__(
d_model=d_model,
@@ -716,6 +764,7 @@ def __init__(
fc_type=fc_type,
device=device,
bias=bias,
sliding_window_size=sliding_window_size,
)


5 changes: 5 additions & 0 deletions llmfoundry/models/layers/blocks.py
@@ -21,6 +21,7 @@
'softmax_scale': None,
'prefix_lm': False,
'attn_uses_sequence_id': False,
'sliding_window_size': -1,
'alibi': False,
'alibi_bias_max': 8,
'rope': False,
@@ -113,6 +114,8 @@ def forward(
attention_mask: Optional[torch.ByteTensor] = None,
is_causal: bool = True,
output_attentions: bool = False,
query_attention_mask_in_length: Optional[torch.Tensor] = None,
key_attention_mask_in_length: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[
torch.Tensor, torch.Tensor]]]:
a = self.norm_1(x)
@@ -124,6 +127,8 @@
attention_mask=attention_mask,
is_causal=is_causal,
needs_weights=output_attentions,
query_attention_mask_in_length=query_attention_mask_in_length,
key_attention_mask_in_length=key_attention_mask_in_length,
)
x = x + self.resid_attn_dropout(b)
m = x
15 changes: 12 additions & 3 deletions llmfoundry/models/mpt/configuration_mpt.py
@@ -91,6 +91,7 @@ def __init__(
When the model is in `train` mode, this requires passing an extra `sequence_id` argument which indicates
which sub-sequence each token belongs to.
Defaults to ``False`` meaning any provided `sequence_id` will be ignored.
sliding_window_size (int): Window size for sliding window local attention. Defaults to -1, which means no sliding window. Query at position i will only attend to keys between [i + seqlen_k - seqlen_q - window_size, i + seqlen_k - seqlen_q + window_size] inclusive. Only works for flash attention v2.3.0 or higher.
alibi (bool): Whether to use the alibi bias instead of position embeddings.
alibi_bias_max (int): The maximum value of the alibi bias.
rope (bool): Whether to use rotary positional embeddings.
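
A quick reading of the windowing rule documented above: in the usual self-attention case seqlen_q == seqlen_k, so the bound reduces to [i - window_size, i + window_size], and causal masking then leaves [max(0, i - window_size), i]. The small check below is illustrative only and builds the equivalent boolean mask:

import torch

seqlen, window_size = 8, 2
i = torch.arange(seqlen).unsqueeze(1)  # query positions
j = torch.arange(seqlen).unsqueeze(0)  # key positions
mask = (j <= i) & (j >= i - window_size)  # True where attention is allowed
# e.g. row 5 is True only at columns 3, 4 and 5.
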
@@ -221,10 +222,12 @@ def _validate_config(self) -> None:
]:
raise NotImplementedError(
'alibi only implemented with torch and triton attention.')
if self.attn_config['attn_uses_sequence_id'] and self.attn_config[
'attn_impl'] not in ['torch', 'triton']:
if self.attn_config['attn_uses_sequence_id'] and not (
self.attn_config['attn_impl'] in ['torch', 'triton'] or
(self.attn_config['attn_impl'] == 'flash' and
is_flash_v2_installed(v2_version='v2.1.2'))):
raise NotImplementedError(
'attn_uses_sequence_id only implemented with torch and triton attention.'
'attn_uses_sequence_id only implemented with torch, triton, and flash (v2.1.2 or higher) attention.'
)
if self.attn_config['rope'] and (self.attn_config['rope_impl']
not in ['dail', 'hf']):
@@ -251,6 +254,12 @@ def _validate_config(self) -> None:
raise ImportError(
'If using the dail implementation of rope, the flash_attn library v2.0.1 or higher must be installed. Please check the instructions at https://github.com/mosaicml/llm-foundry/blob/main/TUTORIAL.md#what-kinds-of-positional-embeddings-does-llm-foundry-support'
)
if self.attn_config['sliding_window_size'] != -1 and not (
self.attn_config['attn_impl'] == 'flash' and
is_flash_v2_installed(v2_version='v2.3.0')):
raise NotImplementedError(
'sliding window only implemented with flash attention v2.3.0 or higher.'
)
if self.embedding_fraction > 1 or self.embedding_fraction <= 0:
raise ValueError(
'model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!'
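
Putting the new options together, a configuration along the following lines should pass the updated _validate_config, assuming flash-attn >= 2.3.0 is installed; the surrounding model sizes are illustrative, and only the attn_config keys come from this diff.

from llmfoundry.models.mpt import MPTConfig

cfg = MPTConfig(
    d_model=768,
    n_heads=12,
    n_layers=12,
    attn_config={
        'attn_impl': 'flash',
        'attn_uses_sequence_id': True,   # needs flash-attn >= 2.1.2
        'sliding_window_size': 256,      # needs flash-attn >= 2.3.0
    },
)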