Commit

ShashankMosaicML committed Nov 18, 2023
1 parent 3d1d022 commit c94e7fe
Showing 1 changed file with 162 additions and 0 deletions.
tests/test_flash_attn.py: 162 additions & 0 deletions
@@ -0,0 +1,162 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

import math

import pytest
import torch

from llmfoundry.models.layers.attention import (flash_attn_fn,
                                                is_flash_v2_installed)


@pytest.mark.gpu
@pytest.mark.parametrize('kv_n_heads', [1, 2, 4, 8])
def test_gqa_kv_repetition(kv_n_heads: int):
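    """Check that flash_attn_fn matches with and without explicit KV-head repetition for GQA.

    The same inputs are run through flash_attn_fn with should_repeat_kv_for_gqa set to
    True and then to False, and the outputs and input gradients are asserted to match.
    """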
    if not is_flash_v2_installed():
        pytest.skip(
            'GQA natively only supported by Flash Attention after v2.')
    d = 128
    n_heads = 8
    seqlen_1 = 6
    bsz = 2

    query_1 = torch.randn(bsz, seqlen_1, n_heads * d).to(torch.bfloat16).cuda()
    query_1.requires_grad = True
    key_1 = torch.randn(bsz, seqlen_1, kv_n_heads * d).to(torch.bfloat16).cuda()
    key_1.requires_grad = True
    value_1 = torch.randn(bsz, seqlen_1, kv_n_heads * d).to(torch.bfloat16).cuda()
    value_1.requires_grad = True

    output_1, _, _ = flash_attn_fn(
        query=query_1,
        key=key_1,
        value=value_1,
        n_heads=n_heads,
        kv_n_heads=kv_n_heads,
        past_key_value=None,
        softmax_scale=1 / math.sqrt(d),
        attn_bias=None,
        key_padding_mask=None,
        is_causal=True,
        dropout_p=0.0,
        training=False,
        needs_weights=False,
        multiquery=False,
        key_attention_mask_in_length=None,
        query_attention_mask_in_length=None,
        should_repeat_kv_for_gqa=True)

    output_1.sum().backward()

    query_2 = query_1.detach().clone()
    query_2.requires_grad = True
    key_2 = key_1.detach().clone()
    key_2.requires_grad = True
    value_2 = value_1.detach().clone()
    value_2.requires_grad = True

    output_2, _, _ = flash_attn_fn(
        query=query_2,
        key=key_2,
        value=value_2,
        n_heads=n_heads,
        kv_n_heads=kv_n_heads,
        past_key_value=None,
        softmax_scale=1 / math.sqrt(d),
        attn_bias=None,
        key_padding_mask=None,
        is_causal=True,
        dropout_p=0.0,
        training=False,
        needs_weights=False,
        multiquery=False,
        key_attention_mask_in_length=None,
        query_attention_mask_in_length=None,
        should_repeat_kv_for_gqa=False)

    output_2.sum().backward()
    assert torch.allclose(output_1, output_2)
    assert torch.allclose(query_1.grad, query_2.grad)
    assert torch.allclose(key_1.grad, key_2.grad)
    assert torch.allclose(value_1.grad, value_2.grad)


@pytest.mark.gpu
def test_seq_id_masking_FA_v2():
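    """Check sequence-id masking (attention_mask_in_length) against per-sequence attention.

    A packed batch whose rows each contain three sequences (lengths 3, 2, and 1) is run
    through flash_attn_fn with query/key attention_mask_in_length set; each sub-sequence
    is then run separately, and the corresponding output and gradient slices are
    asserted to match.
    """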
    if not is_flash_v2_installed(v2_version='v2.1.2'):
        pytest.skip(
            'Using sequence id with flash attention requires flash attention v2.1.2 or higher.'
        )
    d = 128
    n_heads = 4
    kv_n_heads = 4
    seqlen_1 = 6
    bsz = 2

    query_1 = torch.randn(bsz, seqlen_1, n_heads * d).to(torch.bfloat16).cuda()
    query_1.requires_grad = True
    key_1 = torch.randn(bsz, seqlen_1, kv_n_heads * d).to(torch.bfloat16).cuda()
    key_1.requires_grad = True
    value_1 = torch.randn(bsz, seqlen_1, kv_n_heads * d).to(torch.bfloat16).cuda()
    value_1.requires_grad = True

    # Each batch has 3 sequences of length 3, 2, and 1 respectively.
    seq_ranges = [(0, 3), (3, 5), (5, 6)]
    query_attention_mask_in_length_1 = torch.tensor(
        [[3, 2, 1, 0, 0, 0], [3, 2, 1, 0, 0, 0]]).to(torch.int64).cuda()
    key_attention_mask_in_length_1 = torch.tensor(
        [[3, 2, 1, 0, 0, 0], [3, 2, 1, 0, 0, 0]]).to(torch.int64).cuda()

    output_1, _, _ = flash_attn_fn(
        query=query_1,
        key=key_1,
        value=value_1,
        n_heads=n_heads,
        kv_n_heads=kv_n_heads,
        past_key_value=None,
        softmax_scale=1 / math.sqrt(d),
        attn_bias=None,
        key_padding_mask=None,
        is_causal=True,
        dropout_p=0.0,
        training=False,
        needs_weights=False,
        multiquery=False,
        key_attention_mask_in_length=key_attention_mask_in_length_1,
        query_attention_mask_in_length=query_attention_mask_in_length_1)

    output_1.sum().backward()

    for seq_range in seq_ranges:
        query_2 = query_1.detach().clone()[:, seq_range[0]:seq_range[1], :]
        query_2.requires_grad = True
        key_2 = key_1.detach().clone()[:, seq_range[0]:seq_range[1], :]
        key_2.requires_grad = True
        value_2 = value_1.detach().clone()[:, seq_range[0]:seq_range[1], :]
        value_2.requires_grad = True

        output_2, _, _ = flash_attn_fn(
            query=query_2,
            key=key_2,
            value=value_2,
            n_heads=n_heads,
            kv_n_heads=kv_n_heads,
            past_key_value=None,
            softmax_scale=1 / math.sqrt(d),
            attn_bias=None,
            key_padding_mask=None,
            is_causal=True,
            dropout_p=0.0,
            training=False,
            needs_weights=False,
            multiquery=False,
            key_attention_mask_in_length=None,
            query_attention_mask_in_length=None)

        output_2.sum().backward()
        assert torch.allclose(output_1[:, seq_range[0]:seq_range[1], :],
                              output_2)
        assert torch.allclose(query_1.grad[:, seq_range[0]:seq_range[1], :],
                              query_2.grad)
        assert torch.allclose(key_1.grad[:, seq_range[0]:seq_range[1], :],
                              key_2.grad)
        assert torch.allclose(value_1.grad[:, seq_range[0]:seq_range[1], :],
                              value_2.grad)
