diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index d6b23c04d0..34b8992d3e 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -739,6 +739,12 @@ def fsdp_wrap_fn(self, module: nn.Module) -> bool: def activation_checkpointing_fn(self, module: nn.Module) -> bool: act_ckpt_list = getattr(self.config, 'activation_checkpointing_target', None) or ['MPTBlock'] + if isinstance(act_ckpt_list, str): + act_ckpt_list = [act_ckpt_list] + elif not isinstance(act_ckpt_list, list): + raise ValueError( + f'activation_checkpointing_target must be either a single string or a list, but got {type(act_ckpt_list)}' + ) if 'MPTBlock' in act_ckpt_list or 'mptblock' in act_ckpt_list: if len(act_ckpt_list) > 1: diff --git a/tests/test_fsdp_act_checkpoint.py b/tests/test_fsdp_act_checkpoint.py index 3b9a746708..a7e41a3fc2 100644 --- a/tests/test_fsdp_act_checkpoint.py +++ b/tests/test_fsdp_act_checkpoint.py @@ -1,6 +1,8 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 +from typing import Union + import pytest from composer import Trainer from composer.utils import get_device, using_torch_2 @@ -14,11 +16,12 @@ @pytest.mark.world_size(2) @pytest.mark.gpu @pytest.mark.parametrize('activation_checkpointing', [True, False]) -@pytest.mark.parametrize( - 'activation_checkpointing_target', - [[], ['grouped_query_attention'], ['mptblock', 'grouped_query_attention']]) +@pytest.mark.parametrize('activation_checkpointing_target', [ + 'grouped_query_attention', [], ['grouped_query_attention'], + ['mptblock', 'grouped_query_attention'] +]) def test_fsdp_act_checkpoint(activation_checkpointing: bool, - activation_checkpointing_target: list): + activation_checkpointing_target: Union[list, str]): device = get_device('gpu') model_cfg = { 'name': 'mpt_causal_lm', @@ -66,7 +69,9 @@ def test_fsdp_act_checkpoint(activation_checkpointing: bool, module = trainer.state.model.model._fsdp_wrapped_module.transformer.blocks[ 0]._fsdp_wrapped_module._fpw_module assert isinstance(module, CheckpointWrapper) - elif activation_checkpointing_target == ['grouped_query_attention']: + elif activation_checkpointing_target == [ + 'grouped_query_attention' + ] or activation_checkpointing_target == 'grouped_query_attention': assert isinstance( trainer.state.model.model._fsdp_wrapped_module.transformer. blocks[0]._fsdp_wrapped_module.attn, CheckpointWrapper)