
Commit ad6fd55: Fix llama attn

irenedea committed Oct 3, 2023 (1 parent: a0e64ba)
Showing 3 changed files with 6 additions and 3 deletions.
llmfoundry/models/hf/hf_causal_lm.py (7 changes: 5 additions & 2 deletions)
@@ -24,8 +24,6 @@
 
 from llmfoundry.models.hf.hf_fsdp import hf_get_init_device
 from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithZLoss
-from llmfoundry.models.layers.llama_attention_monkeypatch import \
-    get_llama_attention_patch_fn
 from llmfoundry.models.utils import init_empty_weights
 
 try:
@@ -136,6 +134,9 @@ def __init__(self, om_model_config: Union[DictConfig,
         # Rank 0 will still be pretrained, and distribute the weights appropriately
         if dist.get_local_rank() != 0 and init_device == 'mixed':
             om_model_config.pretrained = False
 
+        if config.model_type == 'llama':
+            transformers.utils.is_flash_attn_available = lambda : False
+
         # initialize the model on the correct device
         if resolved_init_device == 'cpu':
@@ -193,6 +194,8 @@ def __init__(self, om_model_config: Union[DictConfig,
             log.debug(
                 f'Patching llama attention with {attention_patch_type} attention'
             )
+            from llmfoundry.models.layers.llama_attention_monkeypatch import \
+                get_llama_attention_patch_fn
             from transformers.models.llama.modeling_llama import \
                 LlamaAttention
             LlamaAttention.forward = get_llama_attention_patch_fn(
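Taken together, the two changes to hf_causal_lm.py (a) make transformers report flash attention as unavailable when the config is for a llama model, presumably so the stock, patchable LlamaAttention.forward is the code path the built model actually uses, and (b) defer the monkeypatch import from module scope to the point where a patch is requested. Below is a minimal sketch of that combined flow; the helper name maybe_patch_llama_attention is hypothetical, and the argument passed to get_llama_attention_patch_fn is assumed to be attention_patch_type, since the hunk above truncates that call.

from typing import Optional

import transformers
from transformers.models.llama.modeling_llama import LlamaAttention


def maybe_patch_llama_attention(model_type: str,
                                attention_patch_type: Optional[str]) -> None:
    """Hypothetical helper mirroring this commit's changes to hf_causal_lm.py."""
    if model_type == 'llama':
        # Make transformers report flash attention as unavailable.
        transformers.utils.is_flash_attn_available = lambda: False

    if attention_patch_type is not None:
        # Imported here rather than at module scope, matching the moved import above.
        from llmfoundry.models.layers.llama_attention_monkeypatch import \
            get_llama_attention_patch_fn
        # Argument assumed; the hunk above cuts off this call's argument list.
        LlamaAttention.forward = get_llama_attention_patch_fn(attention_patch_type)

In the actual __init__, the flash-attention override runs before the model is constructed and the forward patch is applied afterwards; the helper above simply collapses both steps for illustration.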
llmfoundry/models/layers/llama_attention_monkeypatch.py (1 change: 1 addition & 0 deletions)
@@ -186,6 +186,7 @@ def llama_attention_patch_triton(
     past_key_value: Optional[Tuple[torch.Tensor]] = None,
     output_attentions: bool = False,
     use_cache: bool = False,
+    padding_mask: Optional[torch.LongTensor] = None,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     if use_cache:
         raise NotImplementedError(
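The one-line change to llama_attention_patch_triton adds a padding_mask keyword so the patched forward keeps matching the LlamaAttention.forward signature of the transformers release this commit targets, which passes a padding_mask argument into the attention layers. A hedged reconstruction of the full patched signature follows; the parameters before past_key_value (self, hidden_states, attention_mask, position_ids) are assumed from the standard LlamaAttention.forward of that era and are not shown in the hunk, and the body is elided.

from typing import Optional, Tuple

import torch


def llama_attention_patch_triton(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    # New in this commit: accepted so callers that pass padding_mask do not error.
    padding_mask: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    ...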
scripts/train/train.py (1 change: 0 additions & 1 deletion)
@@ -28,7 +28,6 @@
     process_init_device,
     update_batch_size_info)
 
-
 def validate_config(cfg: DictConfig):
     """Validates compatible model and dataloader selection."""
     loaders = [cfg.train_loader]
