From a471278520f48c789878cd7b95d2d0ec5b8095fd Mon Sep 17 00:00:00 2001
From: Cheng Li
Date: Tue, 9 Apr 2024 15:25:48 -0700
Subject: [PATCH] Add FP8 TransformerEngine activation checkpointing (#3156)

* add te checkpoint wrapper

* remove extra wrapper

* add option to shard fp8 weights

* rename

* rename

* add log info

* fix te checkpoint

* update format

* add comment

---------

Co-authored-by: Charles Tang
Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com>
---
 composer/trainer/dist_strategy.py | 56 ++++++++++++++++++++++++++++---
 composer/trainer/trainer.py       | 10 +++++-
 2 files changed, 60 insertions(+), 6 deletions(-)

diff --git a/composer/trainer/dist_strategy.py b/composer/trainer/dist_strategy.py
index 8fab1a8c54..8255a80f39 100644
--- a/composer/trainer/dist_strategy.py
+++ b/composer/trainer/dist_strategy.py
@@ -147,6 +147,8 @@ def set_fsdp_default(fsdp_config: Dict[str, Any]):
     fsdp_config.setdefault('activation_checkpointing', False)
     fsdp_config.setdefault('activation_checkpointing_reentrant', True)
     fsdp_config.setdefault('activation_cpu_offload', False)
+    fsdp_config.setdefault('te_checkpoint_wrapper', False)
+    fsdp_config.setdefault('te_shard_fp8_weight', False)
     fsdp_config.setdefault('backward_prefetch', 'BACKWARD_POST')
     fsdp_config.setdefault('backward_prefetch_limit', 1)
     fsdp_config.setdefault('cpu_offload', False)
@@ -214,6 +216,7 @@ def prepare_fsdp_module(
     precision: Precision,
     device: Device,
     auto_microbatching: bool,
+    te_rng_seed: int = 1234,
 ) -> None:
     """Prepare a module (assumed ComposerModel) and optimizer for use with
     :class:`torch.distributed.fsdp.FullyShardedDataParallel`.
@@ -224,6 +227,7 @@ def prepare_fsdp_module(
         precision: (Precision): The precision being used by the Trainer, used to fill in defaults for FSDP `mixed_precision` settings.
         device (Device): The device being used by the Trainer.
         auto_microbatching (bool, optional): Whether or not auto microbatching is enabled.
+        te_rng_seed (int): The seed to use for the Transformer Engine activation checkpointing RNG. Defaults to 1234.
""" patch_pytorch() @@ -388,6 +392,8 @@ def sync_hook(*args): ignored_modules = fsdp_config['ignored_modules'] state_dict_type = fsdp_config['state_dict_type'] activation_checkpointing_reentrant = fsdp_config['activation_checkpointing_reentrant'] + te_checkpoint_wrapper = fsdp_config['te_checkpoint_wrapper'] if precision == Precision.AMP_FP8 else False + te_shard_fp8_weight = fsdp_config['te_shard_fp8_weight'] if precision == Precision.AMP_FP8 else False sharded_ckpt_prefix_dir = fsdp_config['sharded_ckpt_prefix_dir'] use_orig_params = fsdp_config['use_orig_params'] @@ -598,6 +604,14 @@ def _auto_wrap_policy_new(module: torch.nn.Module, recurse: bool, nonwrapped_num **kwargs, ) + if te_shard_fp8_weight: + try: + from transformer_engine.pytorch.distributed import prepare_te_modules_for_fsdp + except ModuleNotFoundError: + raise ModuleNotFoundError('Please install transformer-engine to use prepare_te_modules_for_fsdp') + log.info(f'Calling prepare_te_modules_for_fsdp to enable TE weights sharding') + prepare_te_modules_for_fsdp(fsdp_obj) + if hasattr(fsdp_obj, '_exec_order_data'): if hasattr(fsdp_obj._exec_order_data, '_forward_prefetch_limit'): fsdp_obj._exec_order_data._forward_prefetch_limit = fsdp_config['forward_prefetch_limit'] @@ -624,13 +638,39 @@ def _auto_wrap_policy_new(module: torch.nn.Module, recurse: bool, nonwrapped_num # Activation Checkpointing if activation_checkpointing or activation_cpu_offload: + # FP8 TE requires using the TE checkpoint function, FSDP activation checkpointing only works with TE non-reentrant checkpointing + if te_checkpoint_wrapper: + assert not activation_checkpointing_reentrant, 'TE checkpoint only works with non-reentrant checkpointing' if version.parse(torch.__version__) > version.parse('2.1.0.dev'): from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import offload_wrapper if not activation_checkpointing_reentrant: - first_wrap_fn = lambda m: checkpoint_wrapper( - m, - checkpoint_impl=CheckpointImpl.NO_REENTRANT, - ) if activation_checkpointing else (lambda module: module) + if te_checkpoint_wrapper: + try: + import transformer_engine.pytorch as te + except ModuleNotFoundError: + raise ModuleNotFoundError( + 'Please install transformer-engine to use TE checkpoint wrapper', + ) + + # RNG state tracker for checkpointing + CUDA_RNG_STATES_TRACKER = te.distributed.CudaRNGStatesTracker() + CUDA_RNG_STATES_TRACKER.add('fsdp-rng', te_rng_seed) + + def get_cuda_rng_tracker(): + return CUDA_RNG_STATES_TRACKER + + first_wrap_fn = lambda m: checkpoint_wrapper( + m, + context_fn=te.distributed.get_activation_recompute_contexts, + checkpoint_fn=te.distributed.checkpoint, + use_reentrant=False, + get_rng_state_tracker=get_cuda_rng_tracker, + ) + else: + first_wrap_fn = lambda m: checkpoint_wrapper( + m, + checkpoint_impl=CheckpointImpl.NO_REENTRANT, + ) if activation_checkpointing else (lambda module: module) second_wrap_fn = ( lambda module: offload_wrapper( first_wrap_fn(module) @@ -638,7 +678,11 @@ def _auto_wrap_policy_new(module: torch.nn.Module, recurse: bool, nonwrapped_num ) ) if activation_cpu_offload else first_wrap_fn else: - first_wrap_fn = checkpoint_wrapper if activation_checkpointing else (lambda module: module) + + first_wrap_fn = lambda m: checkpoint_wrapper( + m, + checkpoint_impl=CheckpointImpl.REENTRANT, + ) if activation_checkpointing else (lambda module: module) second_wrap_fn = ( lambda module: offload_wrapper( first_wrap_fn(module) @@ -699,6 +743,8 @@ def _check_fn(module: torch.nn.Module) -> bool: log.info(f'FSDP: Using 
                 log.info(f'FSDP: Using activation_checkpointing={activation_checkpointing}')
                 log.info(f'FSDP: Using activation_cpu_offload={activation_cpu_offload}')
+                log.info(f'FSDP: Using te_checkpoint_wrapper={te_checkpoint_wrapper}')
+                log.info(f'FSDP: Using te_shard_fp8_weight={te_shard_fp8_weight}')
                 log.info(f'FSDP: Using sync_module_states={sync_module_states}')
                 log.info(f'FSDP: Using forward_prefetch={forward_prefetch}')
                 log.info(f'FSDP: Using limit_all_gathers={limit_all_gathers}')
diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py
index 6ff090988e..4020f76559 100644
--- a/composer/trainer/trainer.py
+++ b/composer/trainer/trainer.py
@@ -1550,7 +1550,15 @@ def __init__(
         # FSDP wrap if not using monolith checkpoint on rank 0 only
         if self.state.fsdp_config is not None and fsdp_auto_wrap and not self.state.load_fsdp_monolith_rank0_only:
             with reproducibility.seed_context(self.state.rank_zero_seed):
-                prepare_fsdp_module(model, optimizers, self.state.fsdp_config, precision, device, auto_microbatching)
+                prepare_fsdp_module(
+                    model,
+                    optimizers,
+                    self.state.fsdp_config,
+                    precision,
+                    device,
+                    auto_microbatching,
+                    self.state.seed,
+                )

         # Configure Deepspeed
         if self.state.deepspeed_config is not None:
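
Usage sketch (not part of the patch): the snippet below shows how the two new fsdp_config flags introduced here might be enabled from the Composer Trainer. It is a minimal sketch under the assumption that `model` is a ComposerModel built from TransformerEngine layers and `train_dataloader` is an ordinary PyTorch DataLoader; both names are placeholders. Per the gating in dist_strategy.py above, the te_* flags only take effect when precision is AMP_FP8.

# Minimal usage sketch. Assumptions: `model` is a ComposerModel whose blocks use
# TransformerEngine modules, and `train_dataloader` is any torch DataLoader.
from composer import Trainer

trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    max_duration='1ep',
    precision='amp_fp8',  # the te_* flags are ignored unless precision is AMP_FP8
    fsdp_config={
        'sharding_strategy': 'FULL_SHARD',
        'activation_checkpointing': True,
        # TE checkpointing requires the non-reentrant path (see the assert in the patch)
        'activation_checkpointing_reentrant': False,
        'te_checkpoint_wrapper': True,  # checkpoint activations via te.distributed.checkpoint
        'te_shard_fp8_weight': True,    # call prepare_te_modules_for_fsdp after FSDP wrapping
    },
)
trainer.fit()

Note that the Trainer passes its own self.state.seed through as te_rng_seed, so the 'fsdp-rng' CUDA RNG state tracker is seeded consistently with the rest of the run rather than with the 1234 default.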