Commit cb6864a by ShashankMosaicML, committed Dec 2, 2023 (1 parent: e25ed63)
Showing 2 changed files with 13 additions and 14 deletions.
tests/models/layers/test_flash_attn.py (22 changes: 12 additions & 10 deletions)

@@ -12,12 +12,13 @@
 
 
 @pytest.mark.gpu
+@pytest.mark.skipif(
+    not is_flash_v2_installed(),
+    reason='GQA natively only supported by Flash Attention after v2.')
 @pytest.mark.parametrize('kv_n_heads', [1, 4, 8])
 def test_gqa_kv_repetition(kv_n_heads: int):
     # Test that flash attention v2 with GQA (kv_n_heads < n_heads) works the same
     # whether we repeat the kv_n_heads explicitly or flash attention v2 handles it on its own.
-    if not is_flash_v2_installed():
-        pytest.skip('GQA natively only supported by Flash Attention after v2.')
     d = 128
     n_heads = 8
     seqlen_1 = 6
@@ -82,12 +83,13 @@ def test_gqa_kv_repetition(kv_n_heads: int):
 
 
 @pytest.mark.gpu
+@pytest.mark.skipif(
+    not is_flash_v2_installed(v2_version='v2.1.2'),
+    reason=
+    'Using sequence id with flash attention requires flash attention v2.1.2 or higher.'
+)
 def test_seq_id_masking_FA_v2():
     # Test that flash attention v2 with sequence id masking works correctly.
-    if not is_flash_v2_installed(v2_version='v2.1.2'):
-        pytest.skip(
-            'Using sequence id with flash attention requires flash attention v2.1.2 or higher.'
-        )
     d = 128
     n_heads = 4
     kv_n_heads = 4
@@ -167,13 +169,13 @@ def test_seq_id_masking_FA_v2():
 
 
 @pytest.mark.gpu
+@pytest.mark.skipif(
+    not is_flash_v2_installed(v2_version='v2.3.0'),
+    reason=
+    'Sliding window attention only supported by Flash Attention after v2.3.0.')
 @pytest.mark.parametrize('sliding_window_size', [1, 4, 8])
 def test_sliding_window(sliding_window_size: int):
     # Test that sliding window attention works as expected.
-    if not is_flash_v2_installed('v2.3.0'):
-        pytest.skip(
-            'Sliding window attention only supported by Flash Attention after v2.3.0.'
-        )
     dtype = torch.bfloat16
     device = 'cuda'
     d = 128
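
The change in this file is a pure refactor of how the flash-attn version requirements are expressed: each check moves from a runtime pytest.skip(...) call inside the test body to a @pytest.mark.skipif(...) decorator, so the skip is decided at collection time and its reason is attached to the test item rather than being raised partway through the test. A minimal sketch of the pattern, using a stand-in is_flash_v2_installed helper in place of the repo's real one:

import pytest

def is_flash_v2_installed(v2_version: str = '2.0.0') -> bool:
    # Stand-in for the repo's helper of the same name; here it only checks that
    # flash_attn is importable, not its exact version.
    try:
        import flash_attn  # noqa: F401
        return True
    except ImportError:
        return False

@pytest.mark.skipif(
    not is_flash_v2_installed(),
    reason='GQA natively only supported by Flash Attention after v2.')
def test_example_requires_flash_v2():
    # Runs only when the skipif condition above is False at collection time.
    ...

Functionally the tests behave the same either way; the decorator form keeps the requirement visible next to the other markers (gpu, parametrize) and out of the test body.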
tests/models/test_model.py (5 changes: 1 addition & 4 deletions)

@@ -580,9 +580,7 @@ def test_mpt_creation(norm_type: str, no_bias: bool, tie_word_embeddings: bool):
         'factor': 1.0,
     },
 }])
-@pytest.mark.parametrize('tie_word_embeddings', [True, False])
-def test_sequence_id_based_masking(attention_impl: str, pos_emb_config: dict,
-                                   tie_word_embeddings: bool):
+def test_sequence_id_based_masking(attention_impl: str, pos_emb_config: dict):
     # Testing the output of concatenated sequence with sequence id masking vs individual sequences.
     alibi = pos_emb_config['alibi']
     if alibi and attention_impl == 'flash':
@@ -620,7 +618,6 @@ def test_sequence_id_based_masking(attention_impl: str, pos_emb_config: dict,
             'name': 'baseline_',
             'init_std': 0.02,
         },
-        tie_word_embeddings=tie_word_embeddings,
     )
     mpt = MPTForCausalLM(hf_config)
     mpt.eval()
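
For context on what this test covers (independent of the removed tie_word_embeddings parametrization): it concatenates several sequences into one packed batch, marks token ownership with a sequence-id tensor, and checks that the masked packed run matches running each sequence individually. A hypothetical illustration of how such a sequence-id tensor induces a block-diagonal attention mask (names and shapes here are assumptions, not the repo's API):

import torch

# Which original sequence each token of the packed batch belongs to.
sequence_id = torch.tensor([[0, 0, 0, 1, 1, 2, 2, 2]])  # (batch=1, seq_len=8)

# Tokens may only attend to tokens with the same sequence id; an actual attention
# implementation would apply a causal mask on top of this block-diagonal one.
same_seq = sequence_id.unsqueeze(-1) == sequence_id.unsqueeze(-2)  # (1, 8, 8) bool
print(same_seq[0].int())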
