Merge branch 'dev' into shashank/fix_seq_parallel_eval
ShashankMosaicML authored Apr 9, 2024
2 parents e874476 + a471278 commit 6f96e01
Showing 2 changed files with 60 additions and 6 deletions.
56 changes: 51 additions & 5 deletions composer/trainer/dist_strategy.py
@@ -147,6 +147,8 @@ def set_fsdp_default(fsdp_config: Dict[str, Any]):
fsdp_config.setdefault('activation_checkpointing', False)
fsdp_config.setdefault('activation_checkpointing_reentrant', True)
fsdp_config.setdefault('activation_cpu_offload', False)
fsdp_config.setdefault('te_checkpoint_wrapper', False)
fsdp_config.setdefault('te_shard_fp8_weight', False)
fsdp_config.setdefault('backward_prefetch', 'BACKWARD_POST')
fsdp_config.setdefault('backward_prefetch_limit', 1)
fsdp_config.setdefault('cpu_offload', False)
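Note (not part of the commit): both new FSDP config keys default to False, so existing configs keep their behavior unless they opt in. A minimal sketch of the effect, assuming set_fsdp_default is importable from composer.trainer.dist_strategy as shown in this hunk:

from composer.trainer.dist_strategy import set_fsdp_default

fsdp_config = {}
set_fsdp_default(fsdp_config)
# The new Transformer Engine flags are filled in alongside the existing defaults.
assert fsdp_config['te_checkpoint_wrapper'] is False
assert fsdp_config['te_shard_fp8_weight'] is False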
@@ -214,6 +216,7 @@ def prepare_fsdp_module(
    precision: Precision,
    device: Device,
    auto_microbatching: bool,
    te_rng_seed: int = 1234,
) -> None:
    """Prepare a module (assumed ComposerModel) and optimizer for use with :class:`torch.distributed.fsdp.FullyShardedDataParallel`.
@@ -224,6 +227,7 @@
        precision (Precision): The precision being used by the Trainer, used to fill in defaults for FSDP `mixed_precision` settings.
        device (Device): The device being used by the Trainer.
        auto_microbatching (bool, optional): Whether or not auto microbatching is enabled.
        te_rng_seed (int): The seed to use for the Transformer Engine activation checkpointing RNG. Defaults to 1234.
    """
    patch_pytorch()

@@ -388,6 +392,8 @@ def sync_hook(*args):
ignored_modules = fsdp_config['ignored_modules']
state_dict_type = fsdp_config['state_dict_type']
activation_checkpointing_reentrant = fsdp_config['activation_checkpointing_reentrant']
te_checkpoint_wrapper = fsdp_config['te_checkpoint_wrapper'] if precision == Precision.AMP_FP8 else False
te_shard_fp8_weight = fsdp_config['te_shard_fp8_weight'] if precision == Precision.AMP_FP8 else False
sharded_ckpt_prefix_dir = fsdp_config['sharded_ckpt_prefix_dir']
use_orig_params = fsdp_config['use_orig_params']

@@ -598,6 +604,14 @@ def _auto_wrap_policy_new(module: torch.nn.Module, recurse: bool, nonwrapped_num
    **kwargs,
)

if te_shard_fp8_weight:
    try:
        from transformer_engine.pytorch.distributed import prepare_te_modules_for_fsdp
    except ModuleNotFoundError:
        raise ModuleNotFoundError('Please install transformer-engine to use prepare_te_modules_for_fsdp')
    log.info(f'Calling prepare_te_modules_for_fsdp to enable TE weights sharding')
    prepare_te_modules_for_fsdp(fsdp_obj)

if hasattr(fsdp_obj, '_exec_order_data'):
    if hasattr(fsdp_obj._exec_order_data, '_forward_prefetch_limit'):
        fsdp_obj._exec_order_data._forward_prefetch_limit = fsdp_config['forward_prefetch_limit']
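Aside (not part of the commit): prepare_te_modules_for_fsdp is Transformer Engine's helper for registering TE modules' FP8 weight copies with FSDP so they are sharded rather than kept replicated, and it is called on the already-wrapped FSDP object, as in the te_shard_fp8_weight branch above. A hedged standalone sketch of that pattern outside Composer; the toy model, sizes, and process-group setup are illustrative assumptions:

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
import transformer_engine.pytorch as te
from transformer_engine.pytorch.distributed import prepare_te_modules_for_fsdp

# Assumes torch.distributed is already initialized and a CUDA device is available.
model = torch.nn.Sequential(te.Linear(1024, 1024), te.Linear(1024, 1024)).cuda()

fsdp_model = FSDP(model, use_orig_params=True)  # wrap with FSDP first
prepare_te_modules_for_fsdp(fsdp_model)         # then let TE expose its FP8 weights for sharding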
@@ -624,21 +638,51 @@ def _auto_wrap_policy_new(module: torch.nn.Module, recurse: bool, nonwrapped_num

# Activation Checkpointing
if activation_checkpointing or activation_cpu_offload:
    # FP8 TE requires using the TE checkpoint function, FSDP activation checkpointing only works with TE non-reentrant checkpointing
    if te_checkpoint_wrapper:
        assert not activation_checkpointing_reentrant, 'TE checkpoint only works with non-reentrant checkpointing'
    if version.parse(torch.__version__) > version.parse('2.1.0.dev'):
        from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import offload_wrapper
        if not activation_checkpointing_reentrant:
            first_wrap_fn = lambda m: checkpoint_wrapper(
                m,
                checkpoint_impl=CheckpointImpl.NO_REENTRANT,
            ) if activation_checkpointing else (lambda module: module)
            if te_checkpoint_wrapper:
                try:
                    import transformer_engine.pytorch as te
                except ModuleNotFoundError:
                    raise ModuleNotFoundError(
                        'Please install transformer-engine to use TE checkpoint wrapper',
                    )

                # RNG state tracker for checkpointing
                CUDA_RNG_STATES_TRACKER = te.distributed.CudaRNGStatesTracker()
                CUDA_RNG_STATES_TRACKER.add('fsdp-rng', te_rng_seed)

                def get_cuda_rng_tracker():
                    return CUDA_RNG_STATES_TRACKER

                first_wrap_fn = lambda m: checkpoint_wrapper(
                    m,
                    context_fn=te.distributed.get_activation_recompute_contexts,
                    checkpoint_fn=te.distributed.checkpoint,
                    use_reentrant=False,
                    get_rng_state_tracker=get_cuda_rng_tracker,
                )
            else:
                first_wrap_fn = lambda m: checkpoint_wrapper(
                    m,
                    checkpoint_impl=CheckpointImpl.NO_REENTRANT,
                ) if activation_checkpointing else (lambda module: module)
            second_wrap_fn = (
                lambda module: offload_wrapper(
                    first_wrap_fn(module)
                    if activation_checkpointing else module, # type: ignore reportGeneralTypeIssues
                )
            ) if activation_cpu_offload else first_wrap_fn
        else:
            first_wrap_fn = checkpoint_wrapper if activation_checkpointing else (lambda module: module)

            first_wrap_fn = lambda m: checkpoint_wrapper(
                m,
                checkpoint_impl=CheckpointImpl.REENTRANT,
            ) if activation_checkpointing else (lambda module: module)
            second_wrap_fn = (
                lambda module: offload_wrapper(
                    first_wrap_fn(module)
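Aside (not from the commit): the CudaRNGStatesTracker above exists so that dropout inside checkpointed TE blocks sees the same CUDA RNG state during activation recompute as during the original forward pass; te.distributed.checkpoint saves and restores the named state via the get_rng_state_tracker callback. A rough sketch of the idea, assuming TE's tracker follows the Megatron-style add/fork API:

import torch
import transformer_engine.pytorch as te

tracker = te.distributed.CudaRNGStatesTracker()
tracker.add('fsdp-rng', 1234)  # one named CUDA RNG state, seeded once

dropout = torch.nn.Dropout(p=0.1)
x = torch.randn(4, 4, device='cuda')

with tracker.fork('fsdp-rng'):  # ops here draw from (and advance) the tracked state
    y = dropout(x)

states = tracker.get_states()  # snapshot of the named states; checkpointing restores this before recompute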
@@ -699,6 +743,8 @@ def _check_fn(module: torch.nn.Module) -> bool:
log.info(f'FSDP: Using backward_prefetch={backward_prefetch}')
log.info(f'FSDP: Using activation_checkpointing={activation_checkpointing}')
log.info(f'FSDP: Using activation_cpu_offload={activation_cpu_offload}')
log.info(f'FSDP: Using te_checkpoint_wrapper={te_checkpoint_wrapper}')
log.info(f'FSDP: Using te_shard_fp8_weight={te_shard_fp8_weight}')
log.info(f'FSDP: Using sync_module_states={sync_module_states}')
log.info(f'FSDP: Using forward_prefetch={forward_prefetch}')
log.info(f'FSDP: Using limit_all_gathers={limit_all_gathers}')
10 changes: 9 additions & 1 deletion composer/trainer/trainer.py
@@ -1550,7 +1550,15 @@ def __init__(
# FSDP wrap if not using monolith checkpoint on rank 0 only
if self.state.fsdp_config is not None and fsdp_auto_wrap and not self.state.load_fsdp_monolith_rank0_only:
    with reproducibility.seed_context(self.state.rank_zero_seed):
        prepare_fsdp_module(model, optimizers, self.state.fsdp_config, precision, device, auto_microbatching)
        prepare_fsdp_module(
            model,
            optimizers,
            self.state.fsdp_config,
            precision,
            device,
            auto_microbatching,
            self.state.seed,
        )

# Configure Deepspeed
if self.state.deepspeed_config is not None:
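For reference (not part of the commit): a hedged end-to-end sketch of how a user would exercise the new path through the Composer Trainer; the model, dataloader, and flag values are placeholders, and 'amp_fp8' additionally requires FP8-capable GPUs plus transformer-engine:

from composer import Trainer

trainer = Trainer(
    model=composer_model,                  # placeholder ComposerModel
    train_dataloader=train_dataloader,     # placeholder dataloader
    max_duration='1ep',
    precision='amp_fp8',                   # the te_* flags are only honored under FP8 precision (see dist_strategy.py above)
    seed=42,                               # now also forwarded to prepare_fsdp_module as the TE checkpointing RNG seed
    fsdp_config={
        'sharding_strategy': 'FULL_SHARD',
        'activation_checkpointing': True,
        'activation_checkpointing_reentrant': False,  # TE checkpoint requires non-reentrant checkpointing
        'te_checkpoint_wrapper': True,
        'te_shard_fp8_weight': True,
    },
)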
