diff --git a/llmfoundry/models/hf/hf_causal_lm.py b/llmfoundry/models/hf/hf_causal_lm.py
index a2e2ad3cdc..746499fdfb 100644
--- a/llmfoundry/models/hf/hf_causal_lm.py
+++ b/llmfoundry/models/hf/hf_causal_lm.py
@@ -16,6 +16,7 @@
                                   LanguageCrossEntropy, LanguagePerplexity)
 from composer.utils import dist
 from omegaconf import DictConfig
+from torch import nn
 from transformers import (AutoConfig, AutoModelForCausalLM,
                           PreTrainedTokenizerBase)
 
@@ -28,12 +29,9 @@
 try:
     from peft.peft_model import PeftModel
 
     model_types = PeftModel, transformers.PreTrainedModel
-    _om_model_config_type = Union[DictConfig, PeftModel,
-                                  transformers.PreTrainedModel]
 except ImportError:
     model_types = transformers.PreTrainedModel
-    _om_model_config_type = Union[DictConfig, transformers.PreTrainedModel]
 
 __all__ = ['ComposerHFCausalLM']
 
@@ -58,21 +56,10 @@ class ComposerHFCausalLM(HuggingFaceModelWithZLoss):
         tokenizer (PreTrainedTokenizer): The tokenizer that the model will use.
     """
 
-    def __init__(
-            self,
-            om_model_config: _om_model_config_type,  # type: ignore
-            tokenizer: PreTrainedTokenizerBase):
-
-        if not om_model_config.get('trust_remote_code',
-                                   True) and om_model_config.get(
-                                       'pretrained_model_name_or_path',
-                                       None).startswith('mosaicml/mpt'):
-            raise ValueError(
-                'trust_remote_code must be set to True for MPT models. Without this, the MPT model code will come from the transformers library, '
-                +
-                'which is significantly slower and not compatible with the LLM foundry training code, rather than the code release by MosaicML.'
-            )
-
+    def __init__(self, om_model_config: Union[DictConfig,
+                                              transformers.PreTrainedModel,
+                                              nn.Module],
+                 tokenizer: PreTrainedTokenizerBase):
         # set up training and eval metrics
         train_metrics = [
             LanguageCrossEntropy(),
@@ -90,6 +77,15 @@
 
         # if we are passed a DictConfig, we need to instantiate the model
         if isinstance(om_model_config, DictConfig):
+            if not om_model_config.get('trust_remote_code',
+                                       True) and om_model_config.get(
+                                           'pretrained_model_name_or_path',
+                                           None).startswith('mosaicml/mpt'):
+                raise ValueError(
+                    'trust_remote_code must be set to True for MPT models. Without this, the MPT model code will come from the transformers library, '
+                    +
+                    'which is significantly slower and not compatible with the LLM foundry training code, rather than the code release by MosaicML.'
+                )
             # load the model config
             trust_remote_code = om_model_config.get('trust_remote_code', True)
@@ -181,6 +177,23 @@
             z_loss = om_model_config.get('z_loss', 0.0)
 
+            attention_patch_type = om_model_config.get('attention_patch_type',
+                                                       None)
+            if attention_patch_type is not None:
+                if model.config.model_type != 'llama':
+                    raise ValueError(
+                        f'attention_patch_type is only supported for llama models, but got {model.config.model_type}'
+                    )
+
+                print(
+                    f'Patching llama attention with {attention_patch_type} attention'
+                )
+                from transformers.models.llama.modeling_llama import \
+                    LlamaAttention
+                LlamaAttention.forward = get_llama_attention_patch_fn(
+                    attention_patch_type)
+                model.config.use_cache = False
+
         # elif the model is either a PeftModel or a PreTrainedModel
         elif isinstance(om_model_config, model_types):
             model = om_model_config
@@ -193,21 +206,6 @@
                 f'om_model_config must be either a DictConfig, PeftModel, or PreTrainedModel, but got {type(om_model_config)}'
             )
 
-        attention_patch_type = om_model_config.get('attention_patch_type', None)
-        if attention_patch_type is not None:
-            if model.config.model_type != 'llama':
-                raise ValueError(
-                    f'attention_patch_type is only supported for llama models, but got {model.config.model_type}'
-                )
-
-            print(
-                f'Patching llama attention with {attention_patch_type} attention'
-            )
-            from transformers.models.llama.modeling_llama import LlamaAttention
-            LlamaAttention.forward = get_llama_attention_patch_fn(
-                attention_patch_type)
-            model.config.use_cache = False
-
         composer_model = super().__init__(model=model,
                                           shift_labels=True,
                                           tokenizer=tokenizer,
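
Why both blocks moved: om_model_config may now be a plain nn.Module, and the old code called om_model_config.get(...) unconditionally after the type dispatch, which would raise AttributeError for anything other than a DictConfig. After this change, the trust_remote_code guard and the attention_patch_type handling run only inside the isinstance(om_model_config, DictConfig) branch, where .get() is guaranteed to exist. A minimal sketch of the two construction paths, assuming the import path of the file touched here; the checkpoint names and config keys are illustrative only, and the DictConfig path accepts many more options than shown:

    from omegaconf import DictConfig
    from transformers import AutoModelForCausalLM, AutoTokenizer

    from llmfoundry.models.hf.hf_causal_lm import ComposerHFCausalLM

    # Path 1: a DictConfig. The class builds the model itself, so the
    # DictConfig-only checks (trust_remote_code, attention_patch_type) apply.
    mpt_tokenizer = AutoTokenizer.from_pretrained('mosaicml/mpt-7b')
    cfg = DictConfig({
        'pretrained_model_name_or_path': 'mosaicml/mpt-7b',
        'trust_remote_code': True,  # must stay True for mosaicml/mpt* checkpoints
    })
    composer_model = ComposerHFCausalLM(cfg, mpt_tokenizer)

    # Path 2: an already-instantiated model (any nn.Module after this diff).
    # No .get() is ever called on it, so the checks above are skipped.
    gpt2_tokenizer = AutoTokenizer.from_pretrained('gpt2')
    gpt2_model = AutoModelForCausalLM.from_pretrained('gpt2')
    composer_model = ComposerHFCausalLM(gpt2_model, gpt2_tokenizer)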
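The attention_patch_type block replaces LlamaAttention.forward at the class level, so every attention layer in the already-constructed model picks up the patched function without rebuilding anything; use_cache is then disabled because a patched forward may not return the key/value cache that generation expects. A generic, self-contained illustration of that class-level monkey-patch pattern (TinyAttention and get_patched_forward are stand-ins, not repo code):

    import torch
    from torch import nn


    class TinyAttention(nn.Module):
        """Stand-in for LlamaAttention."""

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x  # placeholder for the stock attention math


    def get_patched_forward(patch_type: str):
        # Same shape as get_llama_attention_patch_fn in the diff: map a patch
        # name to the function that will be installed as the class's forward.
        def patched_forward(self: nn.Module, x: torch.Tensor) -> torch.Tensor:
            return 2 * x  # placeholder for e.g. a fused attention kernel

        return patched_forward


    model = nn.Sequential(TinyAttention(), TinyAttention())  # built before patching

    TinyAttention.forward = get_patched_forward('triton')

    # Existing instances pick up the patch too, because forward is resolved
    # on the class at call time.
    print(model(torch.ones(3)))  # tensor([4., 4., 4.])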