diff --git a/llmfoundry/models/hf/hf_base.py b/llmfoundry/models/hf/hf_base.py
index 6b693f2d21..d193e1067f 100644
--- a/llmfoundry/models/hf/hf_base.py
+++ b/llmfoundry/models/hf/hf_base.py
@@ -69,7 +69,7 @@ def __init__(
         config_overrides: Optional[dict[str, Any]] = None,
         use_logits: bool = True,
         shift_labels: bool = False,
-        peft_config: Optional['PeftConfig'] = None,
+        peft_config: Optional[dict[str, Any]] = None,
         allow_embedding_resizing: bool = False,
         use_train_metrics: bool = True,
         additional_train_metrics: Optional[list] = None,
@@ -92,8 +92,6 @@ def __init__(
 
         model = self.transform_model(model)
 
-        self.prepare_inner_model(model, init_device)
-
         metrics, eval_metrics = self.build_metrics(
             use_train_metrics=use_train_metrics,
             additional_train_metrics=additional_train_metrics,
@@ -121,6 +119,10 @@ def __init__(
             should_save_peft_only=should_save_peft_only,
         )
 
+        # Prepare for FSDP needs to happen after the super init, so that any model
+        # architecture changes are completed
+        self.prepare_inner_model(self.model, init_device)
+
     def loss(self, outputs: ModelOutput, batch: Mapping):
         if self.config.use_return_dict:
             return outputs['loss']
diff --git a/tests/models/hf/test_hf_peft_wrapping.py b/tests/models/hf/test_hf_peft_wrapping.py
index 522fc5db57..56cb36c8c1 100644
--- a/tests/models/hf/test_hf_peft_wrapping.py
+++ b/tests/models/hf/test_hf_peft_wrapping.py
@@ -11,6 +11,7 @@ from composer import Trainer
 from peft import LoraConfig, get_peft_model
 
+from llmfoundry.models.hf.hf_causal_lm import ComposerHFCausalLM
 from llmfoundry.models.hf.hf_fsdp import prepare_hf_model_for_fsdp
 from llmfoundry.utils.builders import build_composer_model, build_tokenizer
 
 
@@ -36,6 +37,27 @@ def test_peft_wraps():
                 assert m._fsdp_wrap
 
 
+def test_causal_lm_peft_wraps():
+    model = ComposerHFCausalLM(
+        tokenizer=None,
+        pretrained_model_name_or_path='mosaicml/mpt-7b',
+        pretrained=False,
+        trust_remote_code=True,
+        config_overrides={'n_layers': 2},
+        peft_config={
+            'peft_type': 'LORA',
+            'task_type': 'CAUSAL_LM',
+        },
+    )
+
+    for n, m in model.named_modules():
+        if 'lora' in n and 'default' in n:
+            has_parameters = any(True for _ in m.parameters())
+            has_buffers = any(True for _ in m.buffers())
+            if has_parameters or has_buffers:
+                assert m._fsdp_wrap
+
+
 @pytest.mark.world_size(2)
 @pytest.mark.gpu
 @pytest.mark.parametrize(
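
Note (illustration, not part of the diff): per the comment added in `hf_base.py`, FSDP preparation has to run after `super().__init__()` so that any model architecture changes, such as the PEFT/LoRA wrapping exercised by the new test, are already in place. A minimal standalone sketch of that ordering, assuming the same tiny two-layer MPT configuration the tests use and that `peft` can resolve default LoRA target modules for MPT:

```python
import transformers
from peft import LoraConfig, get_peft_model

from llmfoundry.models.hf.hf_fsdp import prepare_hf_model_for_fsdp

# Tiny randomly initialized MPT, mirroring the test fixtures (illustrative only).
config = transformers.AutoConfig.from_pretrained(
    'mosaicml/mpt-7b',
    n_layers=2,
    trust_remote_code=True,
)
model = transformers.AutoModelForCausalLM.from_config(
    config,
    trust_remote_code=True,
)

# 1) Apply the PEFT transform first, so the LoRA modules exist on the model...
model = get_peft_model(model, LoraConfig(task_type='CAUSAL_LM'))

# 2) ...and only then mark modules for FSDP wrapping; in the opposite order the
#    freshly inserted LoRA modules would never receive `_fsdp_wrap`.
prepare_hf_model_for_fsdp(model, 'cpu')

# At least one active LoRA submodule should now be flagged for wrapping.
assert any(
    getattr(m, '_fsdp_wrap', False)
    for n, m in model.named_modules()
    if 'lora' in n and 'default' in n
)
```

The change to `hf_base.py` applies the same ordering inside the Composer wrapper by deferring `prepare_inner_model` until after the super init.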