Revert "deprecations"
This reverts commit 6858db9.
dakinggg committed Sep 25, 2024
1 parent e27bb7b commit 3b316c4
Showing 4 changed files with 109 additions and 4 deletions.
2 changes: 1 addition & 1 deletion llmfoundry/command_utils/eval.py
@@ -82,7 +82,7 @@ def evaluate_model(
warnings.warn(
VersionedDeprecationWarning(
'The argument fsdp_config is deprecated. Please use parallelism_config instead.',
-remove_version='0.14.0',
+remove_version='0.13.0',
),
)
if fsdp_config and parallelism_config:
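
For readers who hit this warning: the hunk above only shows the warning itself. Below is a minimal sketch of the surrounding compatibility shim; `_resolve_parallelism_config` is a hypothetical helper, and the `{'fsdp': ...}` mapping is an assumption rather than something shown in this diff.

import warnings
from typing import Any, Optional

from llmfoundry.utils.warnings import VersionedDeprecationWarning


def _resolve_parallelism_config(
    fsdp_config: Optional[dict[str, Any]],
    parallelism_config: Optional[dict[str, Any]],
) -> Optional[dict[str, Any]]:
    """Hypothetical helper mirroring the deprecation shim in evaluate_model."""
    if fsdp_config is not None:
        # Warn the same way the diff above does, pointing at the replacement argument.
        warnings.warn(
            VersionedDeprecationWarning(
                'The argument fsdp_config is deprecated. Please use parallelism_config instead.',
                remove_version='0.13.0',
            ),
        )
    if fsdp_config and parallelism_config:
        raise ValueError('Pass only one of fsdp_config or parallelism_config.')
    if fsdp_config is not None:
        # Assumed mapping: the legacy FSDP dict becomes the 'fsdp' entry.
        parallelism_config = {'fsdp': fsdp_config}
    return parallelism_config
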
2 changes: 2 additions & 0 deletions llmfoundry/models/hf/__init__.py
@@ -9,6 +9,7 @@
prepare_hf_model_for_fsdp,
)
from llmfoundry.models.hf.hf_t5 import ComposerHFT5
+from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithFSDP

__all__ = [
'BaseHuggingFaceModel',
@@ -17,4 +18,5 @@
'prepare_hf_causal_lm_model_for_fsdp',
'prepare_hf_enc_dec_model_for_fsdp',
'prepare_hf_model_for_fsdp',
+'HuggingFaceModelWithFSDP',
]
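
With the re-export restored, the wrapper is reachable both from the package root and from its defining module. A quick sanity check, purely illustrative:

# Package-level re-export (restored by this commit).
from llmfoundry.models.hf import HuggingFaceModelWithFSDP
# Direct module path, as used in tests/models/test_model.py below.
from llmfoundry.models.hf import model_wrapper

assert HuggingFaceModelWithFSDP is model_wrapper.HuggingFaceModelWithFSDP
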
103 changes: 103 additions & 0 deletions llmfoundry/models/hf/model_wrapper.py
@@ -0,0 +1,103 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

"""Re-usable :class:`.ComposerModel` for LLM HF Models."""

from __future__ import annotations

import warnings
from collections import UserDict
from typing import TYPE_CHECKING, Mapping, Optional, Union

import transformers
from composer.models.huggingface import HuggingFaceModel
from torchmetrics import Metric
from transformers import PreTrainedTokenizerBase
from transformers.utils.generic import ModelOutput

from llmfoundry.models.hf.hf_fsdp import prepare_hf_model_for_fsdp
from llmfoundry.utils.warnings import VersionedDeprecationWarning

if TYPE_CHECKING:
from peft import PeftConfig, PeftModel

__all__ = ['HuggingFaceModelWithFSDP']

# HuggingFace hardcodes the ignore index to -100
_HF_IGNORE_INDEX = -100


class HuggingFaceModelWithFSDP(HuggingFaceModel):
"""Wrapper around HuggingFaceModel.
Handles preparation for FSDP wrapping.
"""

def __init__(
self,
model: Union[transformers.PreTrainedModel, 'PeftModel'],
tokenizer: Optional[PreTrainedTokenizerBase] = None,
metrics: Optional[list[Metric]] = None,
eval_metrics: Optional[list[Metric]] = None,
shift_labels: bool = False,
allow_embedding_resizing: bool = False,
init_device: Optional[str] = None,
peft_config: Optional['PeftConfig'] = None,
should_save_peft_only: bool = True,
):
warnings.warn(
VersionedDeprecationWarning(
'`HuggingFaceModelWithFSDP` is deprecated. In the future please use `BaseHuggingFaceModel`.',
remove_version='0.13.0',
),
)
super().__init__(
model,
tokenizer,
use_logits=True,
metrics=metrics,
eval_metrics=eval_metrics,
shift_labels=shift_labels,
allow_embedding_resizing=allow_embedding_resizing,
peft_config=peft_config,
should_save_peft_only=should_save_peft_only,
)

self.prepare_inner_model(self.model, init_device)

def forward(self, batch: Mapping):
if isinstance(batch, dict) or isinstance(batch, UserDict):
# Further input validation is left to the huggingface forward call
batch = {
k: v for k, v in batch.items() if k in self.model_forward_args
}
output = self.model(**batch) # type: ignore (thirdparty)
else:
raise ValueError(
'Unexpected batch type. Expected a dictionary with keys corresponding to the inputs to the forward function of the Huggingface model',
)
return output

def loss(self, outputs: ModelOutput, batch: Mapping):
if self.config.use_return_dict:
return outputs['loss']
# loss is at index 0 in the output tuple, logits are at index 1
return outputs[:2]

@staticmethod
def prepare_inner_model(
model: Union[transformers.PreTrainedModel, 'PeftModel'],
init_device: Optional[str] = None,
):
"""Prepare the inner model for FSDP wrapping.
Args:
model: The model to prepare.
init_device: The device to initialize the model on.
"""
# Note: We need to add the FSDP related attributes to the model AFTER the super init,
# so that the (possible) embedding resizing doesn't destroy them
prepare_hf_model_for_fsdp(model, init_device)

# This provides support for meta initialization when using FSDP
model.param_init_fn = lambda module: model._init_weights(module)
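
A minimal usage sketch of the restored wrapper, assuming a small Hugging Face checkpoint chosen purely for illustration ('gpt2'); constructing the wrapper emits the VersionedDeprecationWarning defined above.

import transformers

from llmfoundry.models.hf import HuggingFaceModelWithFSDP

# Illustrative checkpoint; any causal LM plus its matching tokenizer works the same way.
hf_model = transformers.AutoModelForCausalLM.from_pretrained('gpt2')
tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2')

# Emits the deprecation warning and prepares the inner model for FSDP wrapping.
composer_model = HuggingFaceModelWithFSDP(hf_model, tokenizer)

# forward() expects a dict-like batch; providing labels lets the HF model compute the loss.
batch = tokenizer('hello world', return_tensors='pt')
batch['labels'] = batch['input_ids'].clone()
outputs = composer_model.forward(batch)
loss = composer_model.loss(outputs, batch)
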
6 changes: 3 additions & 3 deletions tests/models/test_model.py
@@ -39,7 +39,7 @@

from llmfoundry import ComposerHFCausalLM
from llmfoundry.layers_registry import norms
-from llmfoundry.models.hf import BaseHuggingFaceModel
+from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithFSDP
from llmfoundry.models.layers import build_alibi_bias
from llmfoundry.models.layers.attention import (
check_alibi_support,
@@ -2560,7 +2560,7 @@ def test_hf_init(
False,
)

-model = BaseHuggingFaceModel(model, tokenizer)
+model = HuggingFaceModelWithFSDP(model, tokenizer)

batch = gen_random_batch(batch_size, test_cfg)

@@ -2609,7 +2609,7 @@ def test_head_dim_8_flash_mqa_attn(batch_size: int = 2):

mpt = MPTForCausalLM(hf_config)

-model = BaseHuggingFaceModel(mpt, tokenizer, shift_labels=True)
+model = HuggingFaceModelWithFSDP(mpt, tokenizer, shift_labels=True)

model = model.to(test_cfg.device)
batch = gen_random_batch(batch_size, test_cfg)
