Commit

precommit
dakinggg committed Oct 14, 2023
1 parent a0e709b commit 4d944b9
Showing 3 changed files with 27 additions and 22 deletions.
12 changes: 6 additions & 6 deletions llmfoundry/models/hf/hf_causal_lm.py
@@ -94,12 +94,12 @@ def __init__(self, om_model_config: Union[DictConfig,
         # load the model config
         trust_remote_code = om_model_config.get('trust_remote_code', True)
         use_auth_token = om_model_config.get('use_auth_token', False)
-        use_flash_attention_2 = om_model_config.get('use_flash_attention_2', is_flash_v2_installed())
+        use_flash_attention_2 = om_model_config.get('use_flash_attention_2',
+                                                    is_flash_v2_installed())
         if use_flash_attention_2 and not is_flash_v2_installed():
             raise ValueError(
-                'use_flash_attention_2 is set to True, but flash-attention 2 is not installed. ' +
-                'Please install flash_attn==2.3.2`.'
-            )
+                'use_flash_attention_2 is set to True, but flash-attention 2 is not installed. '
+                + 'Please install flash_attn==2.3.2`.')

         config = AutoConfig.from_pretrained(
             om_model_config.pretrained_model_name_or_path,
@@ -111,9 +111,9 @@ def __init__(self, om_model_config: Union[DictConfig,
         # supports enabling flash attention 2 when using the from_pretrained API.
         # We need to support it for both from_pretrained and from_config, so we have to
         # set the private attribute here. This will just skip all of transformers'
-        # validation logic that it is ok to use flash attention 2, so we replicate 
+        # validation logic that it is ok to use flash attention 2, so we replicate
         # the most importance piece (is it installed) above.
-        config._flash_attn_2_enabled=use_flash_attention_2
+        config._flash_attn_2_enabled = use_flash_attention_2

         # set config overrides
         for k, v in om_model_config.get('config_overrides', {}).items():
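A note on the `_flash_attn_2_enabled` hunk above: the in-code comment explains that transformers only wires flash attention 2 up through the `from_pretrained` path, so llm-foundry sets the private attribute on the config directly so that `from_config` works too. A minimal sketch of that pattern, assuming a transformers version from around the time of this commit (where the private `_flash_attn_2_enabled` flag is what gets checked) and an illustrative stand-in for `is_flash_v2_installed`:

```python
# Sketch only: mirrors the pattern in the hunk above, not llm-foundry's code.
import importlib.util

from transformers import AutoConfig, AutoModelForCausalLM

# Illustrative stand-in for llmfoundry's is_flash_v2_installed() check; it only
# tests that the flash_attn package is importable, not that it is version 2.x.
flash_v2_installed = importlib.util.find_spec('flash_attn') is not None

config = AutoConfig.from_pretrained('mistralai/Mistral-7B-v0.1')

# Private transformers attribute (as of this commit's timeframe) that selects
# the FlashAttention2 modules. Setting it on the config skips transformers'
# own validation, which is why the hunk above re-checks installation itself.
config._flash_attn_2_enabled = flash_v2_installed

# Because the flag lives on the config, it applies to from_config as well as
# from_pretrained, which is the point made in the comment above.
model = AutoModelForCausalLM.from_config(config)
```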
5 changes: 3 additions & 2 deletions setup.py
@@ -115,8 +115,9 @@
     dep for key, deps in extra_deps.items() for dep in deps if 'gpu' not in key)
 extra_deps['all'] = set(dep for key, deps in extra_deps.items() for dep in deps
                         if key not in {'gpu-flash2', 'all-cpu'})
-extra_deps['all-flash2'] = set(
-    dep for key, deps in extra_deps.items() for dep in deps if key not in {'gpu', 'all', 'all-cpu'})
+extra_deps['all-flash2'] = set(dep for key, deps in extra_deps.items()
+                               for dep in deps
+                               if key not in {'gpu', 'all', 'all-cpu'})

 setup(
     name=_PACKAGE_NAME,
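The comprehension above builds the `all-flash2` extra from every other extras group except the flash-attention-1 GPU stack (`gpu`) and the aggregate groups. A toy illustration of that logic; the keys and package names below are invented for the example and are not taken from setup.py:

```python
# Toy illustration of the 'all-flash2' composition; keys and pins are made up.
extra_deps = {
    'gpu': ['flash-attn<2'],              # flash-attention v1 stack, excluded
    'gpu-flash2': ['flash-attn>=2.3.2'],
    'dev': ['pytest'],
}

extra_deps['all-flash2'] = set(dep for key, deps in extra_deps.items()
                               for dep in deps
                               if key not in {'gpu', 'all', 'all-cpu'})

print(sorted(extra_deps['all-flash2']))  # ['flash-attn>=2.3.2', 'pytest']
```

In the real setup.py this is what lets an `all-flash2` install pull in the flash-attention-2 dependencies while leaving out the v1 GPU extra.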
32 changes: 18 additions & 14 deletions tests/test_huggingface_flash.py
@@ -6,14 +6,14 @@
 import pytest
 import torch
 import transformers
-from composer.utils import reproducibility
 from composer.core.precision import get_precision_context
+from composer.utils import reproducibility
+from omegaconf import OmegaConf as om

+from llmfoundry import COMPOSER_MODEL_REGISTRY
 from llmfoundry.models.hf.hf_fsdp import rgetattr
 from llmfoundry.models.layers.attention import is_flash_v1_installed
 from llmfoundry.utils.builders import build_tokenizer
-from llmfoundry import COMPOSER_MODEL_REGISTRY
-from omegaconf import OmegaConf as om

 # Before importing any transformers models, we need to disable transformers flash attention if
 # we are in an environment with flash attention version <2. Transformers hard errors on a not properly
@@ -108,10 +108,11 @@ def test_patch_equivalence(patch_fn_name: str, explicit_mask: bool,

     assert torch.allclose(attn_output, new_output, atol=atol, rtol=rtol)

+
 @pytest.mark.gpu
 @pytest.mark.parametrize('model', ['llama2', 'mistral'])
-def test_flash2(model: str):
-    if model == 'llama2':
+def test_flash2(model_name: str):
+    if model_name == 'llama2':
         if 'HUGGING_FACE_HUB_TOKEN' not in os.environ:
             pytest.skip(
                 'The CI cluster does not have access to the Llama models, so skip this test.'
@@ -129,11 +129,12 @@ def test_flash2(model: str):
             'init_device': 'cpu',
         }
         tokenizer_name = 'meta-llama/Llama-2-7b-hf'
-        from transformers.models.llama.modeling_llama import LlamaFlashAttention2
+        from transformers.models.llama.modeling_llama import \
+            LlamaFlashAttention2
         flash_attn_class = LlamaFlashAttention2
         attention_layers_attr = 'model.model.layers'
         attention_attr = 'self_attn'
-    elif model == 'mistral':
+    elif model_name == 'mistral':
         model_cfg = {
             'name': 'hf_causal_lm',
             'pretrained_model_name_or_path': 'mistralai/Mistral-7B-v0.1',
@@ -147,10 +149,13 @@ def test_flash2(model: str):
         }

         tokenizer_name = 'mistralai/Mistral-7B-v0.1'
-        from transformers.models.mistral.modeling_mistral import MistralFlashAttention2
+        from transformers.models.mistral.modeling_mistral import \
+            MistralFlashAttention2
         flash_attn_class = MistralFlashAttention2
         attention_layers_attr = 'model.model.layers'
         attention_attr = 'self_attn'
+    else:
+        raise ValueError(f'Unknown model: {model_name}')

     model_cfg = om.create(model_cfg)

@@ -164,10 +169,13 @@ def test_flash2(model: str):

     # check that it actually used flash attention 2
     assert model.model.config._flash_attn_2_enabled
-    attention_layer = rgetattr(rgetattr(model, attention_layers_attr)[0], attention_attr)
+    attention_layer = rgetattr(
+        rgetattr(model, attention_layers_attr)[0], attention_attr)
     assert isinstance(attention_layer, flash_attn_class)

-    tokenized_input = tokenizer(['Hello world blah blah', 'Goodbye world'], return_tensors='pt', padding=True)
+    tokenized_input = tokenizer(['Hello world blah blah', 'Goodbye world'],
+                                return_tensors='pt',
+                                padding=True)
     tokenized_input['labels'] = tokenized_input['input_ids'].clone()
     print(tokenized_input)

@@ -179,7 +187,3 @@ def test_flash2(model: str):
         outputs = model(tokenized_input)
         loss = outputs.loss
         loss.backward()
-
-
-
-
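One helper worth spelling out from the test above: `rgetattr` (imported from `llmfoundry.models.hf.hf_fsdp`) resolves a dotted attribute path such as `'model.model.layers'`. A hypothetical stand-in, shown only to clarify what the assertion in test_flash2 is doing, might look like:

```python
# Hypothetical stand-in for llmfoundry's rgetattr, for illustration only;
# the real helper lives in llmfoundry.models.hf.hf_fsdp.
import functools


def rgetattr_sketch(obj, attr_path: str):
    """Resolve a dotted attribute path, e.g. 'model.model.layers', one hop at a time."""
    return functools.reduce(getattr, attr_path.split('.'), obj)


# Mirrors the check in test_flash2: take the first decoder layer's attention
# module and confirm it is the expected FlashAttention2 class, e.g.
#   attention_layer = rgetattr_sketch(model, 'model.model.layers')[0].self_attn
#   assert isinstance(attention_layer, LlamaFlashAttention2)
```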