fix test
dakinggg committed Oct 14, 2023
1 parent c1d00e9 commit 0ea0d7d
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions tests/test_huggingface_flash.py
@@ -132,9 +132,6 @@ def test_flash2(model_name: str, use_flash_attention_2: bool):
             'init_device': 'cpu',
         }
 
-        if use_flash_attention_2:
-            model_cfg['use_flash_attention_2'] = True
-
         tokenizer_name = 'meta-llama/Llama-2-7b-hf'
         from transformers.models.llama.modeling_llama import (
             LlamaAttention, LlamaFlashAttention2)
@@ -146,7 +143,6 @@ def test_flash2(model_name: str, use_flash_attention_2: bool):
         model_cfg = {
             'name': 'hf_causal_lm',
             'pretrained_model_name_or_path': 'mistralai/Mistral-7B-v0.1',
-            'use_flash_attention_2': True,
             'config_overrides': {
                 'num_hidden_layers': 2,
                 'intermediate_size': 64,
@@ -164,6 +160,9 @@ def test_flash2(model_name: str, use_flash_attention_2: bool):
         attention_attr = 'self_attn'
     else:
         raise ValueError(f'Unknown model: {model_name}')
 
+    if use_flash_attention_2:
+        model_cfg['use_flash_attention_2'] = True
+
     model_cfg = om.create(model_cfg)
 
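Taken together, the change stops hardcoding 'use_flash_attention_2': True in the Mistral config and instead applies the flag to model_cfg once, after the per-model branches, so both models honor the test's use_flash_attention_2 parameter. Below is a minimal sketch of the resulting config-building flow; field values come from the visible hunks, while the helper name and any Llama fields not shown in the diff are assumptions for illustration.

# Sketch (not the repository's exact code) of how the test builds model_cfg
# after this commit. Fields not visible in the hunks are assumed.
from omegaconf import OmegaConf as om


def build_model_cfg(model_name: str, use_flash_attention_2: bool):
    if model_name == 'llama2':
        model_cfg = {
            'name': 'hf_causal_lm',
            'pretrained_model_name_or_path': 'meta-llama/Llama-2-7b-hf',
            'config_overrides': {
                'num_hidden_layers': 2,
                'intermediate_size': 64,
            },
            'init_device': 'cpu',
        }
    elif model_name == 'mistral':
        model_cfg = {
            'name': 'hf_causal_lm',
            'pretrained_model_name_or_path': 'mistralai/Mistral-7B-v0.1',
            # 'use_flash_attention_2': True is no longer hardcoded here.
            'config_overrides': {
                'num_hidden_layers': 2,
                'intermediate_size': 64,
            },
        }
    else:
        raise ValueError(f'Unknown model: {model_name}')

    # Applied once, after the per-model branches, so the flag follows the
    # test's parametrization for every model instead of only for Llama.
    if use_flash_attention_2:
        model_cfg['use_flash_attention_2'] = True

    return om.create(model_cfg)

The first hunk's context also shows the Llama branch importing LlamaAttention and LlamaFlashAttention2, which the test presumably uses to assert which attention implementation is instantiated for each value of the flag.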