Fix ComposerHFCausalLM instantiation with PeftModel (#593)
* Fix bug in hf_causal_lm that caused errors when evaluating peft models

* Move attention patch

* Fix typing
irenedea committed Sep 12, 2023
1 parent dc13748 commit e5c243c
Showing 1 changed file with 31 additions and 33 deletions.
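
For context, the bug fixed here is that ComposerHFCausalLM.__init__ called om_model_config.get(...) before checking whether it had been given a DictConfig, which fails when an already-instantiated PeftModel is passed in (as happens when evaluating peft models). Below is a minimal sketch of the call path this commit unblocks; it assumes peft is installed and uses gpt2 purely as an illustrative base model, not anything taken from this commit:

    import transformers
    from peft import LoraConfig, get_peft_model

    from llmfoundry.models.hf.hf_causal_lm import ComposerHFCausalLM

    # Illustrative base model and tokenizer; any HF causal LM would do.
    base = transformers.AutoModelForCausalLM.from_pretrained('gpt2')
    tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2')

    # Wrapping the base model yields a PeftModel, not a DictConfig.
    peft_model = get_peft_model(base, LoraConfig(task_type='CAUSAL_LM'))

    # With this commit, the DictConfig-only checks (trust_remote_code, the
    # llama attention patch) run inside the isinstance(om_model_config,
    # DictConfig) branch, so passing the PeftModel directly no longer raises.
    composer_model = ComposerHFCausalLM(peft_model, tokenizer)
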
llmfoundry/models/hf/hf_causal_lm.py (31 additions, 33 deletions)
@@ -16,6 +16,7 @@
                                   LanguageCrossEntropy, LanguagePerplexity)
 from composer.utils import dist
 from omegaconf import DictConfig
+from torch import nn
 from transformers import (AutoConfig, AutoModelForCausalLM,
                           PreTrainedTokenizerBase)

@@ -28,12 +29,9 @@
 try:
     from peft.peft_model import PeftModel
     model_types = PeftModel, transformers.PreTrainedModel
-    _om_model_config_type = Union[DictConfig, PeftModel,
-                                  transformers.PreTrainedModel]
 
 except ImportError:
     model_types = transformers.PreTrainedModel
-    _om_model_config_type = Union[DictConfig, transformers.PreTrainedModel]
 
 __all__ = ['ComposerHFCausalLM']
 
@@ -58,21 +56,10 @@ class ComposerHFCausalLM(HuggingFaceModelWithZLoss):
         tokenizer (PreTrainedTokenizer): The tokenizer that the model will use.
     """
 
-    def __init__(
-            self,
-            om_model_config: _om_model_config_type,  # type: ignore
-            tokenizer: PreTrainedTokenizerBase):
-
-        if not om_model_config.get('trust_remote_code',
-                                   True) and om_model_config.get(
-                                       'pretrained_model_name_or_path',
-                                       None).startswith('mosaicml/mpt'):
-            raise ValueError(
-                'trust_remote_code must be set to True for MPT models. Without this, the MPT model code will come from the transformers library, '
-                +
-                'which is not significantly slower and not compatible with the LLM foundry training code, rather than the code release by MosaicML.'
-            )
-
+    def __init__(self, om_model_config: Union[DictConfig,
+                                               transformers.PreTrainedModel,
+                                               nn.Module],
+                 tokenizer: PreTrainedTokenizerBase):
         # set up training and eval metrics
         train_metrics = [
             LanguageCrossEntropy(),
@@ -90,6 +77,15 @@ def __init__(
 
         # if we are passed a DictConfig, we need to instantiate the model
         if isinstance(om_model_config, DictConfig):
+            if not om_model_config.get('trust_remote_code',
+                                       True) and om_model_config.get(
+                                           'pretrained_model_name_or_path',
+                                           None).startswith('mosaicml/mpt'):
+                raise ValueError(
+                    'trust_remote_code must be set to True for MPT models. Without this, the MPT model code will come from the transformers library, '
+                    +
+                    'which is not significantly slower and not compatible with the LLM foundry training code, rather than the code release by MosaicML.'
+                )
 
             # load the model config
             trust_remote_code = om_model_config.get('trust_remote_code', True)
@@ -181,6 +177,23 @@ def __init__(
 
             z_loss = om_model_config.get('z_loss', 0.0)
 
+            attention_patch_type = om_model_config.get('attention_patch_type',
+                                                       None)
+            if attention_patch_type is not None:
+                if model.config.model_type != 'llama':
+                    raise ValueError(
+                        f'attention_patch_type is only supported for llama models, but got {model.config.model_type}'
+                    )
+
+                print(
+                    f'Patching llama attention with {attention_patch_type} attention'
+                )
+                from transformers.models.llama.modeling_llama import \
+                    LlamaAttention
+                LlamaAttention.forward = get_llama_attention_patch_fn(
+                    attention_patch_type)
+                model.config.use_cache = False
+
         # elif the model is either a PeftModel or a PreTrainedModel
         elif isinstance(om_model_config, model_types):
             model = om_model_config
@@ -193,21 +206,6 @@ def __init__(
                 f'om_model_config must be either a DictConfig, PeftModel, or PreTrainedModel, but got {type(om_model_config)}'
             )
 
-        attention_patch_type = om_model_config.get('attention_patch_type', None)
-        if attention_patch_type is not None:
-            if model.config.model_type != 'llama':
-                raise ValueError(
-                    f'attention_patch_type is only supported for llama models, but got {model.config.model_type}'
-                )
-
-            print(
-                f'Patching llama attention with {attention_patch_type} attention'
-            )
-            from transformers.models.llama.modeling_llama import LlamaAttention
-            LlamaAttention.forward = get_llama_attention_patch_fn(
-                attention_patch_type)
-            model.config.use_cache = False
-
         composer_model = super().__init__(model=model,
                                           shift_labels=True,
                                           tokenizer=tokenizer,
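
Since om_model_config.get only exists on a DictConfig, the attention-patch block now runs only in the config-driven branch. The sketch below shows that path under stated assumptions: the config keys mirror the diff above, but the model name, the 'pretrained' flag, and the 'triton' patch value are illustrative guesses rather than values taken from this commit:

    from omegaconf import DictConfig
    from transformers import AutoTokenizer

    from llmfoundry.models.hf.hf_causal_lm import ComposerHFCausalLM

    # Assumed config values for illustration; per the check added above,
    # attention_patch_type is only accepted for llama-family models.
    model_cfg = DictConfig({
        'pretrained_model_name_or_path': 'meta-llama/Llama-2-7b-hf',
        'pretrained': True,
        'attention_patch_type': 'triton',
    })
    tokenizer = AutoTokenizer.from_pretrained('meta-llama/Llama-2-7b-hf')

    # om_model_config.get('attention_patch_type', None) is now reached only
    # when om_model_config is a DictConfig, so PeftModel / PreTrainedModel
    # inputs skip it entirely.
    composer_model = ComposerHFCausalLM(model_cfg, tokenizer)
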