From ad6fd558869d3b02feaa759a3ee78059c624d320 Mon Sep 17 00:00:00 2001
From: Irene Dea
Date: Mon, 2 Oct 2023 15:47:52 -0700
Subject: [PATCH] Fix llama attn

---
 llmfoundry/models/hf/hf_causal_lm.py                     | 7 +++++--
 llmfoundry/models/layers/llama_attention_monkeypatch.py  | 1 +
 scripts/train/train.py                                   | 1 -
 3 files changed, 6 insertions(+), 3 deletions(-)

diff --git a/llmfoundry/models/hf/hf_causal_lm.py b/llmfoundry/models/hf/hf_causal_lm.py
index d5ef2435f9..b364b73bef 100644
--- a/llmfoundry/models/hf/hf_causal_lm.py
+++ b/llmfoundry/models/hf/hf_causal_lm.py
@@ -24,8 +24,6 @@
 
 from llmfoundry.models.hf.hf_fsdp import hf_get_init_device
 from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithZLoss
-from llmfoundry.models.layers.llama_attention_monkeypatch import \
-    get_llama_attention_patch_fn
 from llmfoundry.models.utils import init_empty_weights
 
 try:
@@ -136,6 +134,9 @@ def __init__(self, om_model_config: Union[DictConfig,
         # Rank 0 will still be pretrained, and distribute the weights appropriately
         if dist.get_local_rank() != 0 and init_device == 'mixed':
             om_model_config.pretrained = False
+
+        if config.model_type == 'llama':
+            transformers.utils.is_flash_attn_available = lambda : False
 
         # initialize the model on the correct device
         if resolved_init_device == 'cpu':
@@ -193,6 +194,8 @@ def __init__(self, om_model_config: Union[DictConfig,
             log.debug(
                 f'Patching llama attention with {attention_patch_type} attention'
             )
+            from llmfoundry.models.layers.llama_attention_monkeypatch import \
+                get_llama_attention_patch_fn
             from transformers.models.llama.modeling_llama import \
                 LlamaAttention
             LlamaAttention.forward = get_llama_attention_patch_fn(
diff --git a/llmfoundry/models/layers/llama_attention_monkeypatch.py b/llmfoundry/models/layers/llama_attention_monkeypatch.py
index 88f61e3fef..665648d06d 100644
--- a/llmfoundry/models/layers/llama_attention_monkeypatch.py
+++ b/llmfoundry/models/layers/llama_attention_monkeypatch.py
@@ -186,6 +186,7 @@ def llama_attention_patch_triton(
     past_key_value: Optional[Tuple[torch.Tensor]] = None,
     output_attentions: bool = False,
     use_cache: bool = False,
+    padding_mask: Optional[torch.LongTensor] = None,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     if use_cache:
         raise NotImplementedError(
diff --git a/scripts/train/train.py b/scripts/train/train.py
index b31c15467e..96d383687b 100644
--- a/scripts/train/train.py
+++ b/scripts/train/train.py
@@ -28,7 +28,6 @@
                                            process_init_device,
                                            update_batch_size_info)
 
-
 def validate_config(cfg: DictConfig):
     """Validates compatible model and dataloader selection."""
     loaders = [cfg.train_loader]
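
For readers skimming the patch: it makes three related changes. The import of get_llama_attention_patch_fn in hf_causal_lm.py becomes a lazy import, performed only when attention_patch_type is set; transformers' flash-attention availability check is forced to return False for llama models, presumably so the monkeypatched attention path is the one actually used; and the triton patch function gains a padding_mask parameter so its signature stays compatible with callers that pass that argument. The sketch below illustrates the same lazy-patch pattern on a toy class; ToyAttention, get_toy_attention_patch_fn, and maybe_patch_attention are hypothetical names, not llmfoundry or transformers APIs.

# A minimal, self-contained sketch (hypothetical names, not llmfoundry code)
# of the monkeypatch pattern used in the diff above.
from typing import Optional

import torch


class ToyAttention:
    """Stand-in for transformers' LlamaAttention."""

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        return hidden_states


def get_toy_attention_patch_fn(patch_type: str):
    """Return a replacement forward, analogous to get_llama_attention_patch_fn."""

    def patched_forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        # Accept (and ignore) an extra keyword argument that newer callers may
        # pass, mirroring the padding_mask parameter added in the diff.
        padding_mask: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        # A real patch would dispatch to a triton/flash kernel based on patch_type.
        return hidden_states

    return patched_forward


def maybe_patch_attention(attention_patch_type: Optional[str]) -> None:
    if attention_patch_type is None:
        return
    # The real diff imports get_llama_attention_patch_fn only at this point
    # (a lazy import inside the branch that applies the patch); this toy keeps
    # the helper in the same module, so it is referenced directly.
    patch_fn = get_toy_attention_patch_fn(attention_patch_type)
    # Replace the class-level forward so every instance uses the patched version.
    ToyAttention.forward = patch_fn


maybe_patch_attention('triton')
out = ToyAttention().forward(torch.zeros(1, 4, 8), padding_mask=None)
assert out.shape == (1, 4, 8)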