precommit
dakinggg committed Oct 14, 2023
1 parent adac5c0 commit 6ae2393
Showing 1 changed file with 12 additions and 8 deletions.
20 changes: 12 additions & 8 deletions tests/test_huggingface_flash.py
@@ -12,7 +12,8 @@
 
 from llmfoundry import COMPOSER_MODEL_REGISTRY
 from llmfoundry.models.hf.hf_fsdp import rgetattr
-from llmfoundry.models.layers.attention import is_flash_v1_installed, is_flash_v2_installed
+from llmfoundry.models.layers.attention import (is_flash_v1_installed,
+                                                is_flash_v2_installed)
 from llmfoundry.utils.builders import build_tokenizer
 
 # Before importing any transformers models, we need to disable transformers flash attention if
@@ -134,9 +135,10 @@ def test_flash2(model_name: str, use_flash_attention_2: bool):
         model_cfg['use_flash_attention_2'] = True
 
         tokenizer_name = 'meta-llama/Llama-2-7b-hf'
-        from transformers.models.llama.modeling_llama import \
-            LlamaFlashAttention2, LlamaAttention
-        flash_attn_class = LlamaFlashAttention2 if is_flash_v2_installed() else LlamaAttention
+        from transformers.models.llama.modeling_llama import (
+            LlamaAttention, LlamaFlashAttention2)
+        flash_attn_class = LlamaFlashAttention2 if is_flash_v2_installed(
+        ) else LlamaAttention
         attention_layers_attr = 'model.model.layers'
         attention_attr = 'self_attn'
     elif model_name == 'mistral':
@@ -153,9 +155,10 @@ def test_flash2(model_name: str, use_flash_attention_2: bool):
         }
 
         tokenizer_name = 'mistralai/Mistral-7B-v0.1'
-        from transformers.models.mistral.modeling_mistral import \
-            MistralFlashAttention2, MistralAttention
-        flash_attn_class = MistralFlashAttention2 if is_flash_v2_installed() else MistralAttention
+        from transformers.models.mistral.modeling_mistral import (
+            MistralAttention, MistralFlashAttention2)
+        flash_attn_class = MistralFlashAttention2 if is_flash_v2_installed(
+        ) else MistralAttention
         attention_layers_attr = 'model.model.layers'
         attention_attr = 'self_attn'
     else:
@@ -172,7 +175,8 @@ def test_flash2(model_name: str, use_flash_attention_2: bool):
     model = COMPOSER_MODEL_REGISTRY[model_cfg['name']](model_cfg, tokenizer)
 
     # check that it actually used flash attention 2
-    assert model.model.config._flash_attn_2_enabled if is_flash_v2_installed() else not model.model.config._flash_attn_2_enabled
+    assert model.model.config._flash_attn_2_enabled if is_flash_v2_installed(
+    ) else not model.model.config._flash_attn_2_enabled
     attention_layer = rgetattr(
         rgetattr(model, attention_layers_attr)[0], attention_attr)
     assert isinstance(attention_layer, flash_attn_class)
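
The logic reformatted in this test is the conditional choice of the expected attention class based on whether flash-attn v2 is available. A minimal standalone sketch of that selection, assuming llm-foundry is installed and a transformers release that still exposes LlamaFlashAttention2 (both are assumptions of this sketch, not part of the commit):

# Hypothetical standalone sketch; mirrors the selection logic in test_flash2 above.
from llmfoundry.models.layers.attention import is_flash_v2_installed
from transformers.models.llama.modeling_llama import (LlamaAttention,
                                                      LlamaFlashAttention2)

# Expect the flash-attention implementation only when flash-attn v2 is available;
# otherwise fall back to the eager LlamaAttention class.
expected_attn_class = (LlamaFlashAttention2
                       if is_flash_v2_installed() else LlamaAttention)
print(f'Expected attention class: {expected_attn_class.__name__}')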