From a60bf3a16624b0cd3c1026657dfd61c9f74b16be Mon Sep 17 00:00:00 2001 From: Mihir Patel Date: Mon, 3 Jun 2024 22:56:29 -0400 Subject: [PATCH] Dataclasses for ParallelismConfig (#3346) * v1 paralleism * fix * add doc strings * lint * fix tests * clean u ptest * fix error * check if dict instances are configs * fix tests * fix lint * fix tests * fix test --------- Co-authored-by: Saaketh Narayan Co-authored-by: Your Name --- composer/callbacks/checkpoint_saver.py | 2 +- composer/core/state.py | 48 +++++----- composer/distributed/__init__.py | 2 - composer/distributed/dist_strategy.py | 74 +++++++------- .../{mosaic_fsdp.py => mosaic_parallelism.py} | 65 +------------ composer/trainer/trainer.py | 68 ++++++++----- composer/utils/__init__.py | 9 ++ composer/utils/checkpoint.py | 8 +- composer/utils/object_store/__init__.py | 8 +- composer/utils/parallelism.py | 96 +++++++++++++++++++ tests/test_events.py | 4 - tests/trainer/test_fsdp_checkpoint.py | 52 +++++----- 12 files changed, 250 insertions(+), 186 deletions(-) rename composer/distributed/{mosaic_fsdp.py => mosaic_parallelism.py} (72%) create mode 100644 composer/utils/parallelism.py diff --git a/composer/callbacks/checkpoint_saver.py b/composer/callbacks/checkpoint_saver.py index ab3fafd58d..263558fc2b 100644 --- a/composer/callbacks/checkpoint_saver.py +++ b/composer/callbacks/checkpoint_saver.py @@ -469,7 +469,7 @@ def _save_checkpoint(self, state: State, logger: Logger): keep_placeholders=True, ).lstrip('/') assert state.fsdp_config is not None - remote_prefix = state.fsdp_config['sharded_ckpt_prefix_dir'] + remote_prefix = state.fsdp_config.sharded_ckpt_prefix_dir assert remote_prefix is not None ckpt_filename = checkpoint._TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME remote_file_name = os.path.join(pathlib.Path(remote_file_name).parent, remote_prefix, ckpt_filename) diff --git a/composer/core/state.py b/composer/core/state.py index 9e96b07127..083b977811 100644 --- a/composer/core/state.py +++ b/composer/core/state.py @@ -43,7 +43,10 @@ from composer.core.time import Time, Timestamp, TimeUnit, ensure_time from composer.devices import Device from composer.utils import ( + FSDPConfig, + ParallelismConfig, ParallelismType, + TPConfig, VersionedDeprecationWarning, batch_get, batch_set, @@ -197,8 +200,8 @@ def _ensure_backwards_compatible_checkpointing(state_dict: dict[str, Any]): def _create_device_mesh( device: Device, - fsdp_config: Optional[dict[str, Any]], - tp_config: Optional[dict[str, Any]], + fsdp_config: Optional[FSDPConfig], + tp_config: Optional[TPConfig], ) -> Optional[DeviceMesh]: if version.parse(torch.__version__.split('.dev')[0]) < version.parse('2.3.0'): # Device mesh has correctness issues before torch 2.3.0 @@ -210,13 +213,13 @@ def _create_device_mesh( # Gather dimensions and names for the device mesh dims: list[int] = [] names: list[str] = [] - if fsdp_config['data_parallel_replicate_degree'] is not None: - dims.append(fsdp_config['data_parallel_replicate_degree']) + if fsdp_config.data_parallel_replicate_degree is not None: + dims.append(fsdp_config.data_parallel_replicate_degree) names.append(ParallelismType.DATA_PARALLEL_REPLICATE.value) - dims.append(fsdp_config['data_parallel_shard_degree']) + dims.append(fsdp_config.data_parallel_shard_degree) names.append(ParallelismType.DATA_PARALLEL_SHARD.value) if tp_config is not None: - dims.append(tp_config['tensor_parallel_degree']) + dims.append(tp_config.tensor_parallel_degree) names.append(ParallelismType.TENSOR_PARALLEL.value) # Fill in the unspecified dimensions @@ 
-329,7 +332,7 @@ class State(Serializable): algorithms (Algorithm | Sequence[Algorithm], optional): The algorithms used for training. callbacks (Callback | Sequence[Callback], optional): The callbacks used for training. deepspeed_config (dict[str, Any], optional): The configuration dictionary for deepspeed. - parallelism_config (dict[str, Any], optional): The configuration dictionary for parallelism. + parallelism_config (ParallelismConfig, optional): The configuration dictionary for parallelism. Attributes: batch (types.Batch): The batch. This will be the entire batch during the :attr:`.Event.AFTER_DATALOADER`, or a @@ -496,7 +499,7 @@ def __init__( # Distributed training configs deepspeed_config: Optional[dict[str, Any]] = None, - parallelism_config: Optional[dict[str, Any]] = None, + parallelism_config: Optional[ParallelismConfig] = None, ): self.rank_zero_seed = rank_zero_seed self.model = model @@ -540,9 +543,8 @@ def __init__( self.profiler: Optional[Profiler] = None self.deepspeed_config = deepspeed_config - parallelism_config = parallelism_config or {} - self.fsdp_config = parallelism_config.get('fsdp', None) - self.tp_config = parallelism_config.get('tp', None) + self.fsdp_config = parallelism_config.fsdp if parallelism_config is not None else None + self.tp_config = parallelism_config.tp if parallelism_config is not None else None self._validate_parallelism_configs() @@ -552,9 +554,9 @@ def __init__( if self.device_mesh.mesh_dim_names is not None and ParallelismType.DATA_PARALLEL_REPLICATE.value in self.device_mesh.mesh_dim_names: fsdp_mesh_dim_names.append(ParallelismType.DATA_PARALLEL_REPLICATE.value) fsdp_mesh_dim_names.append(ParallelismType.DATA_PARALLEL_SHARD.value) - self.fsdp_config['device_mesh'] = self.device_mesh[tuple(fsdp_mesh_dim_names)] # type: ignore + self.fsdp_config.device_mesh = self.device_mesh[tuple(fsdp_mesh_dim_names)] # type: ignore if self.tp_config is not None and self.device_mesh is not None: - self.tp_config['device_mesh'] = self.device_mesh[ParallelismType.TENSOR_PARALLEL.value] + self.tp_config.device_mesh = self.device_mesh[ParallelismType.TENSOR_PARALLEL.value] # Set defaults for transient variables (to make pyright happy) self.batch: Any = None @@ -598,11 +600,11 @@ def _validate_parallelism_configs(self): if self.fsdp_config is None: raise ValueError( 'Tensor parallelism (TP) currently requires FSDP to be enabled. ' - 'An empty `fsdp_config` can be specified to enable FSDP with ' - 'default settings. Additionally, PyTorch currently errors if FSDP ' + "An empty `parallelism_config['fsdp'] = {}` config can be specified to enable " + 'FSDP with default settings. Additionally, PyTorch currently errors if FSDP ' 'data_parallel_shard_degree is not at least 2.', ) - if not self.fsdp_config['use_orig_params']: + if not self.fsdp_config.use_orig_params: raise ValueError( 'Tensor parallelism (TP) currently requires FSDP with use_orig_params=True, ' 'which is the default and recommended setting.', @@ -614,10 +616,10 @@ def _validate_parallelism_configs(self): raise ValueError('load_fsdp_monolith_rank0_only is not compatible with tensor parallelism (TP).') assert self.fsdp_config is not None error_message = '' - if self.fsdp_config['sync_module_states'] == False: + if self.fsdp_config.sync_module_states == False: error_message += textwrap.dedent( - "load_monolith_rank0_only requires fsdp_config['sync_module_states'] to be True. " - "Either set fsdp_config['sync_module_states'] = True or set load_monolith_rank0_only = False. 
", + "load_monolith_rank0_only requires parallelism_config['fsdp']['sync_module_states'] to be True. " + "Either set parallelism_config['fsdp']['sync_module_states'] = True or set load_monolith_rank0_only = False.", ) # Broadcast rank 0 meta check to all ranks so error can be raised on all ranks rank0_on_meta = 0 @@ -654,7 +656,7 @@ def _validate_parallelism_configs(self): textwrap.dedent( 'Saving metrics is not allowed with sharded state dict as metric tensors will ' 'be sharded and break on load. If you wish to save metric state, set ' - 'fsdp_config["state_dict_type"] = "full" to disable sharded checkpoints.', + "parallelism_config['fsdp']['state_dict_type'] = 'full' to disable sharded checkpoints.", ), ) @@ -881,7 +883,7 @@ def fsdp_state_dict_type(self): if not self.fsdp_enabled: return None if self.fsdp_config is not None: - return self.fsdp_config['state_dict_type'] + return self.fsdp_config.state_dict_type return 'full' @property @@ -906,8 +908,8 @@ def load_fsdp_monolith_rank0_only(self): @property def load_monolith_rank0_only(self): return ( - self.fsdp_config is not None and self.fsdp_config['auto_wrap'] and - self.fsdp_config['state_dict_type'] == 'full' and self.fsdp_config['load_monolith_rank0_only'] == True + self.fsdp_config is not None and self.fsdp_config.auto_wrap and + self.fsdp_config.state_dict_type == 'full' and self.fsdp_config.load_monolith_rank0_only == True ) def _get_integrations_state_dict(self) -> dict[str, Any]: diff --git a/composer/distributed/__init__.py b/composer/distributed/__init__.py index 9a994dd762..ccbf500f34 100644 --- a/composer/distributed/__init__.py +++ b/composer/distributed/__init__.py @@ -11,7 +11,6 @@ prepare_fsdp_module, prepare_tp_module, ) -from composer.distributed.mosaic_fsdp import set_fsdp_default __all__ = [ 'fix_batch_precision_for_deepspeed', @@ -21,5 +20,4 @@ 'prepare_ddp_module', 'prepare_fsdp_module', 'prepare_tp_module', - 'set_fsdp_default', ] diff --git a/composer/distributed/dist_strategy.py b/composer/distributed/dist_strategy.py index 606cf64010..ad08e172b8 100644 --- a/composer/distributed/dist_strategy.py +++ b/composer/distributed/dist_strategy.py @@ -24,14 +24,14 @@ from composer.core import Precision, State from composer.devices import Device from composer.distributed.meta_safe_apply import meta_safe_apply -from composer.distributed.mosaic_fsdp import ( +from composer.distributed.mosaic_parallelism import ( BACKWARD_PREFETCH_MAP, SHARDING_MAP, get_cpu_offload, get_mixed_precision, set_custom_fsdp_module_kwargs, ) -from composer.utils import StringEnum, dist, ensure_tuple +from composer.utils import FSDPConfig, StringEnum, TPConfig, dist, ensure_tuple __all__ = ['DDPSyncStrategy', 'ddp_sync_context', 'prepare_ddp_module', 'prepare_fsdp_module', 'prepare_tp_module'] @@ -181,24 +181,24 @@ def _recreate_fsdp_param_groups_from_unwrapped_opt_info( def prepare_tp_module( model: torch.nn.Module, - tp_config: dict[str, Any], + tp_config: TPConfig, ) -> None: """Prepare a module (assumed ComposerModel) for use with tensor parallel.""" from torch.distributed.tensor.parallel import parallelize_module - device_mesh = tp_config['device_mesh'] - layer_plan = tp_config['layer_plan'] + device_mesh = tp_config.device_mesh + assert device_mesh is not None # For type checking, set in State.__init__ parallelize_module( module=model, device_mesh=device_mesh, - parallelize_plan=layer_plan, + parallelize_plan=tp_config.layer_plan, ) def prepare_fsdp_module( model: torch.nn.Module, optimizers: Optional[Union[torch.optim.Optimizer, 
Sequence[torch.optim.Optimizer]]], - fsdp_config: dict[str, Any], + fsdp_config: FSDPConfig, precision: Precision, device: Device, auto_microbatching: bool, @@ -216,7 +216,7 @@ def prepare_fsdp_module( te_rng_seed(int): The seed to use for the Transformer Engine activation checkpointing RNG. Defaults to 1234. """ # Check sync_module_states is True for mixed initialization or HSDP - if fsdp_config['sync_module_states'] == False: + if fsdp_config.sync_module_states == False: rank_on_meta = 1 if next(model.parameters()).device.type == 'meta' else 0 all_ranks_meta = device.tensor_to_device(torch.tensor([rank_on_meta], dtype=torch.uint8)) dist.all_reduce(all_ranks_meta, reduce_operation='MIN') @@ -226,7 +226,7 @@ def prepare_fsdp_module( raise ValueError( 'Detected mixed initialization where some ranks have model on cpu or ' 'gpu and some ranks are on meta. Either keep all ranks on the same ' - "device or set fsdp_config['sync_module_states'] = True. Otherwise, " + "device or set parallelism_config['fsdp']['sync_module_states'] = True. Otherwise, " 'some weights may be randomly initialized when loading a checkpoint.', ) @@ -263,7 +263,7 @@ def sync_hook(*args): num_param_groups = len(optim.param_groups) if num_param_groups > 1: - if not fsdp_config['use_orig_params']: + if not fsdp_config.use_orig_params: raise RuntimeError( 'Multiple optimizer groups with FSDP are only supported with ' 'use_orig_params=True.', @@ -297,17 +297,19 @@ def sync_hook(*args): optim.param_groups.clear() optim.state.clear() - sharding_map_key = fsdp_config['sharding_strategy'].upper() + sharding_map_key = fsdp_config.sharding_strategy.upper() sharding_strategy = SHARDING_MAP[sharding_map_key] kwargs = {} - if version.parse(torch.__version__.split('.dev')[0]) >= version.parse('2.2.0') and 'device_mesh' in fsdp_config: - if fsdp_config['process_group'] is not None: + if version.parse( + torch.__version__.split('.dev')[0], + ) >= version.parse('2.2.0') and fsdp_config.device_mesh is not None: + if fsdp_config.process_group is not None: warnings.warn( 'process_group and device_mesh are set for FSDP, so ignoring device_mesh. Please set process_group to None.', ) else: - ndim = fsdp_config['device_mesh'].ndim + ndim = fsdp_config.device_mesh.ndim if ndim == 1 and sharding_strategy == ShardingStrategy.HYBRID_SHARD: sharding_strategy = ShardingStrategy.FULL_SHARD warnings.warn('HYBRID_SHARD is not supported with 1D device mesh. Using FULL_SHARD instead.') @@ -320,12 +322,12 @@ def sync_hook(*args): elif ndim == 2 and sharding_strategy == ShardingStrategy.FULL_SHARD: sharding_strategy = ShardingStrategy.HYBRID_SHARD warnings.warn('FULL_SHARD is not supported with 2D device mesh. 
Using HYBRID_SHARD instead.') - kwargs['device_mesh'] = fsdp_config['device_mesh'] + kwargs['device_mesh'] = fsdp_config.device_mesh - cpu_offload = get_cpu_offload(cpu_offload=fsdp_config['cpu_offload']) + cpu_offload = get_cpu_offload(cpu_offload=fsdp_config.cpu_offload) - mixed_precision = fsdp_config['mixed_precision'] - keep_low_precision_grads = fsdp_config['keep_low_precision_grads'] + mixed_precision = fsdp_config.mixed_precision + keep_low_precision_grads = fsdp_config.keep_low_precision_grads mixed_precision, param_dtype, _, _ = get_mixed_precision( precision, mixed_precision=mixed_precision, @@ -357,22 +359,22 @@ def sync_hook(*args): ) process_group = None - if fsdp_config['process_group'] is not None: - process_group_dict = {'process_group': fsdp_config['process_group']} + if fsdp_config.process_group is not None: + process_group_dict = {'process_group': fsdp_config.process_group} process_group = set_custom_fsdp_module_kwargs(process_group_dict, process_group_cache)['process_group'] - backward_prefetch = BACKWARD_PREFETCH_MAP[fsdp_config['backward_prefetch'].upper()] - activation_checkpointing = fsdp_config['activation_checkpointing'] - activation_cpu_offload = fsdp_config['activation_cpu_offload'] - sync_module_states = fsdp_config['sync_module_states'] - forward_prefetch = fsdp_config['forward_prefetch'] - limit_all_gathers = fsdp_config['limit_all_gathers'] - ignored_modules = fsdp_config['ignored_modules'] - state_dict_type = fsdp_config['state_dict_type'] - activation_checkpointing_reentrant = fsdp_config['activation_checkpointing_reentrant'] - te_checkpoint_wrapper = fsdp_config['te_checkpoint_wrapper'] if precision == Precision.AMP_FP8 else False - te_shard_fp8_weight = fsdp_config['te_shard_fp8_weight'] if precision == Precision.AMP_FP8 else False - sharded_ckpt_prefix_dir = fsdp_config['sharded_ckpt_prefix_dir'] - use_orig_params = fsdp_config['use_orig_params'] + backward_prefetch = BACKWARD_PREFETCH_MAP[fsdp_config.backward_prefetch.upper()] + activation_checkpointing = fsdp_config.activation_checkpointing + activation_cpu_offload = fsdp_config.activation_cpu_offload + sync_module_states = fsdp_config.sync_module_states + forward_prefetch = fsdp_config.forward_prefetch + limit_all_gathers = fsdp_config.limit_all_gathers + ignored_modules = fsdp_config.ignored_modules + state_dict_type = fsdp_config.state_dict_type + activation_checkpointing_reentrant = fsdp_config.activation_checkpointing_reentrant + te_checkpoint_wrapper = fsdp_config.te_checkpoint_wrapper if precision == Precision.AMP_FP8 else False + te_shard_fp8_weight = fsdp_config.te_shard_fp8_weight if precision == Precision.AMP_FP8 else False + sharded_ckpt_prefix_dir = fsdp_config.sharded_ckpt_prefix_dir + use_orig_params = fsdp_config.use_orig_params # We choose to not wrap the ComposerModel directly, but instead wrap any submodules like `ComposerModel.model` # This makes it safer to call ComposerModel-specific functions like 'eval_forward' that @@ -591,7 +593,7 @@ def _auto_wrap_policy_new(module: torch.nn.Module, recurse: bool, nonwrapped_num if hasattr(fsdp_obj, '_exec_order_data'): if hasattr(fsdp_obj._exec_order_data, '_forward_prefetch_limit'): - fsdp_obj._exec_order_data._forward_prefetch_limit = fsdp_config['forward_prefetch_limit'] + fsdp_obj._exec_order_data._forward_prefetch_limit = fsdp_config.forward_prefetch_limit else: warnings.warn( 'FSDP._exec_order_data does not have attribute _forward_prefetch_limit ' @@ -599,7 +601,7 @@ def _auto_wrap_policy_new(module: torch.nn.Module, recurse: bool, 
nonwrapped_num 'config being ignored. Please open an issue to Composer to report this.', ) if hasattr(fsdp_obj._exec_order_data, '_backward_prefetch_limit'): - fsdp_obj._exec_order_data._backward_prefetch_limit = fsdp_config['backward_prefetch_limit'] + fsdp_obj._exec_order_data._backward_prefetch_limit = fsdp_config.backward_prefetch_limit else: warnings.warn( 'FSDP._exec_order_data does not have attribute _backward_prefetch_limit ' @@ -712,7 +714,7 @@ def _check_fn(module: torch.nn.Module) -> bool: setattr(model, obj_name, fsdp_obj) # Print FSDP wrapped model and FSDP config if `verbose=True` - if fsdp_config['verbose']: + if fsdp_config.verbose: log.info(f'FSDP: Wrapped model: {model}') log.info(f'FSDP: Using sharding_strategy={sharding_strategy}') log.info(f'FSDP: Using cpu_offload={cpu_offload}') diff --git a/composer/distributed/mosaic_fsdp.py b/composer/distributed/mosaic_parallelism.py similarity index 72% rename from composer/distributed/mosaic_fsdp.py rename to composer/distributed/mosaic_parallelism.py index d59b2fe3b2..66c06d911b 100644 --- a/composer/distributed/mosaic_fsdp.py +++ b/composer/distributed/mosaic_parallelism.py @@ -19,7 +19,7 @@ ) from composer.core import Precision -from composer.utils import VersionedDeprecationWarning, dist +from composer.utils import dist log = logging.getLogger(__name__) @@ -40,69 +40,6 @@ } -def set_fsdp_default(fsdp_config: dict[str, Any]): - """Modify fsdp_config to set default values for missing keys.""" - if 'process_group' in fsdp_config: - warnings.warn( - VersionedDeprecationWarning( - 'process_group is deprecated. Please specify `data_parallel_shard_degree` and `data_parallel_replicate_degree` instead.', - remove_version='0.24', - ), - ) - - if 'device_mesh' in fsdp_config: - warnings.warn( - VersionedDeprecationWarning( - 'device_mesh is deprecated. Please specify `data_parallel_shard_degree` and `data_parallel_replicate_degree` instead.', - remove_version='0.24', - ), - ) - if 'data_parallel_shard_degree' in fsdp_config or 'data_parallel_replicate_degree' in fsdp_config: - raise ValueError( - 'Cannot specify both `device_mesh` and `data_parallel_shard_degree` or `data_parallel_replicate_degree`. 
Please remove `device_mesh`.', - ) - device_mesh = fsdp_config.pop('device_mesh') - if len(device_mesh) == 1: - fsdp_config['data_parallel_shard_degree'] = device_mesh[0] - elif len(device_mesh) == 2: - fsdp_config['data_parallel_replicate_degree'] = device_mesh[0] - fsdp_config['data_parallel_shard_degree'] = device_mesh[1] - else: - raise ValueError( - f'device_mesh must be of length 1 or 2 but received length {len(device_mesh)} with device mesh {device_mesh}.', - ) - - fsdp_config.setdefault('activation_checkpointing', False) - fsdp_config.setdefault('activation_checkpointing_reentrant', True) - fsdp_config.setdefault('activation_cpu_offload', False) - fsdp_config.setdefault('auto_wrap', True) - fsdp_config.setdefault('te_checkpoint_wrapper', False) - fsdp_config.setdefault('te_shard_fp8_weight', False) - fsdp_config.setdefault('backward_prefetch', 'BACKWARD_POST') - fsdp_config.setdefault('backward_prefetch_limit', 1) - fsdp_config.setdefault('cpu_offload', False) - fsdp_config.setdefault('data_parallel_shard_degree', -1) - fsdp_config.setdefault('data_parallel_replicate_degree', None) - fsdp_config.setdefault('forward_prefetch', False) - fsdp_config.setdefault('forward_prefetch_limit', 1) - fsdp_config.setdefault('ignored_modules', None) - fsdp_config.setdefault('keep_low_precision_grads', False) - fsdp_config.setdefault('limit_all_gathers', True) - fsdp_config.setdefault('load_monolith_rank0_only', False) - fsdp_config.setdefault('load_planner', None) - fsdp_config.setdefault('mixed_precision', 'DEFAULT') - fsdp_config.setdefault('process_group', None) - fsdp_config.setdefault('save_planner', None) - fsdp_config.setdefault('sharded_ckpt_prefix_dir', 'ep{epoch}-ba{batch}') - fsdp_config.setdefault('sharding_strategy', 'FULL_SHARD') - fsdp_config.setdefault('state_dict_type', 'full') - fsdp_config.setdefault('sync_module_states', False) - fsdp_config.setdefault('use_orig_params', True) - fsdp_config.setdefault('verbose', False) - - return fsdp_config - - def _get_torch_dtype(dtype: Union[Precision, str]): """Convert common string representations of dtypes to torch dtypes.""" dtype = dtype.value if isinstance(dtype, Precision) else dtype diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index 353cd97258..eb5080eaee 100644 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -84,7 +84,6 @@ prepare_ddp_module, prepare_fsdp_module, prepare_tp_module, - set_fsdp_default, ) from composer.loggers import ( ConsoleLogger, @@ -104,12 +103,18 @@ from composer.trainer._scale_schedule import scale_pytorch_scheduler from composer.trainer._scaler import ClosureGradScaler from composer.utils import ( + MLFLOW_EXPERIMENT_ID_FORMAT_KEY, + MLFLOW_RUN_ID_FORMAT_KEY, ExportFormat, + FSDPConfig, MissingConditionalImportError, ObjectStore, + ParallelismConfig, + TPConfig, Transform, VersionedDeprecationWarning, checkpoint, + create_fsdp_config, dist, ensure_tuple, export_with_logger, @@ -118,6 +123,7 @@ get_composer_env_dict, get_device, get_file, + is_model_deepspeed, is_xla_installed, map_collection, maybe_create_object_store_from_uri, @@ -127,8 +133,6 @@ partial_format, reproducibility, ) -from composer.utils.misc import is_model_deepspeed -from composer.utils.object_store.mlflow_object_store import MLFLOW_EXPERIMENT_ID_FORMAT_KEY, MLFLOW_RUN_ID_FORMAT_KEY if is_xla_installed(): import torch_xla.core.xla_model as xm @@ -912,7 +916,7 @@ class Trainer: disable FSDP, set to ``None``. 
(default: ``None``) fsdp_auto_wrap (bool, optional): option to let trainer wrap the module, or if the module is already wrapped outside, allow the user to disable auto-wrapping. - parallelism_config (dict[str, Any], optional): Configuration for parallelism options. + parallelism_config (Union[dict[str, Any], ParallelismConfig], optional): Configuration for parallelism options. Currently supports fsdp and tensor parallelism, whose respective configs are specified as the keys ``fsdp`` and ``tp``. (default: ``None``) @@ -1069,7 +1073,7 @@ def __init__( deepspeed_config: Optional[dict[str, Any]] = None, fsdp_config: Optional[dict[str, Any]] = None, fsdp_auto_wrap: bool = True, - parallelism_config: Optional[dict[str, Any]] = None, + parallelism_config: Optional[Union[dict[str, Any], ParallelismConfig]] = None, # System/Numerics device: Optional[Union[str, Device]] = None, @@ -1185,7 +1189,12 @@ def __init__( ) if parallelism_config is None: parallelism_config = {} - if parallelism_config.get('fsdp') is not None: + if isinstance(parallelism_config, ParallelismConfig): + raise ValueError( + 'fsdp_config cannot be specified if parallelism_config is a ParallelismConfig object. ' + 'Please instead pass fsdp_config as a FSDPConfig object when constructing ParallelismConfig.', + ) + elif parallelism_config.get('fsdp') is not None: raise ValueError( 'fsdp_config is specified in both fsdp_config and parallelism_config. Please specify it in only in parallelism_config.', ) @@ -1199,22 +1208,30 @@ def __init__( ) if parallelism_config is None: parallelism_config = {} - if parallelism_config.get('fsdp') is None: - parallelism_config['fsdp'] = {} - parallelism_config['fsdp']['auto_wrap'] = fsdp_auto_wrap - if parallelism_config is not None: - # Set defaults and create shallow copies of configs to avoid changing user's config - parallelism_config = {**parallelism_config} - if parallelism_config.get('fsdp', None) is not None: - parallelism_config['fsdp'] = set_fsdp_default({**parallelism_config['fsdp']}) - if parallelism_config.get('tp', None) is not None: - parallelism_config['tp'] = {**parallelism_config['tp']} - # Remove empty configs - for key in list(parallelism_config.keys()): - if parallelism_config[key] == None: - del parallelism_config[key] - if len(parallelism_config) == 0: - parallelism_config = None + if isinstance(parallelism_config, ParallelismConfig): + raise ValueError( + 'fsdp_auto_wrap cannot be specified if parallelism_config is a ParallelismConfig object. 
' + 'Please instead pass fsdp_auto_wrap to FSDPConfig as part of ParallelismConfig.', + ) + else: + if parallelism_config.get('fsdp') is None: + parallelism_config['fsdp'] = {} + parallelism_config['fsdp']['auto_wrap'] = fsdp_auto_wrap + if parallelism_config is not None and not isinstance(parallelism_config, ParallelismConfig): + parallelism_config_args = {} + if 'fsdp' in parallelism_config and parallelism_config['fsdp'] is not None: + if isinstance(parallelism_config['fsdp'], FSDPConfig): + parallelism_config_args['fsdp'] = parallelism_config['fsdp'] + else: + parallelism_config_args['fsdp'] = create_fsdp_config(parallelism_config['fsdp']) + if 'tp' in parallelism_config and parallelism_config['tp'] is not None: + if isinstance(parallelism_config['tp'], TPConfig): + parallelism_config_args['tp'] = parallelism_config['tp'] + else: + parallelism_config['tp'] = TPConfig(**parallelism_config['tp']) + parallelism_config = ParallelismConfig( + **parallelism_config_args, + ) if len(parallelism_config_args) > 0 else None if deepspeed_config is not None and parallelism_config is not None: raise ValueError( 'Both deepspeed_config and parallelism_config are specified but incompatible. Please specify only one.', @@ -1650,8 +1667,7 @@ def __init__( ) # FSDP wrap if not using monolith checkpoint on rank 0 only - if self.state.fsdp_config is not None and self.state.fsdp_config['auto_wrap' - ] and not self.state.load_monolith_rank0_only: + if self.state.fsdp_config is not None and self.state.fsdp_config.auto_wrap and not self.state.load_monolith_rank0_only: with reproducibility.seed_context(self.state.rank_zero_seed): prepare_fsdp_module( model, @@ -1823,8 +1839,8 @@ def __init__( # FSDP wrap if model is not yet wrapped and FSDP is enabled. This can happen if # load_monolith_rank0_only=True but no checkpoint was loaded. 
if ( - not self.state.fsdp_enabled and self.state.fsdp_config is not None and - self.state.fsdp_config['auto_wrap'] and self.state.load_monolith_rank0_only + not self.state.fsdp_enabled and self.state.fsdp_config is not None and self.state.fsdp_config.auto_wrap and + self.state.load_monolith_rank0_only ): with reproducibility.seed_context(self.state.rank_zero_seed): prepare_fsdp_module(model, optimizers, self.state.fsdp_config, precision, device, auto_microbatching) diff --git a/composer/utils/__init__.py b/composer/utils/__init__.py index 7829c9fe76..9618d5f837 100644 --- a/composer/utils/__init__.py +++ b/composer/utils/__init__.py @@ -61,6 +61,8 @@ partial_format, ) from composer.utils.object_store import ( + MLFLOW_EXPERIMENT_ID_FORMAT_KEY, + MLFLOW_RUN_ID_FORMAT_KEY, GCSObjectStore, LibcloudObjectStore, MLFlowObjectStore, @@ -71,6 +73,7 @@ SFTPObjectStore, UCObjectStore, ) +from composer.utils.parallelism import FSDPConfig, ParallelismConfig, TPConfig, create_fsdp_config from composer.utils.retrying import retry from composer.utils.string_enum import StringEnum from composer.utils.warnings import VersionedDeprecationWarning @@ -147,4 +150,10 @@ 'KNOWN_COMPRESSORS', 'STR_TO_DTYPE', 'ParallelismType', + 'create_fsdp_config', + 'FSDPConfig', + 'TPConfig', + 'ParallelismConfig', + 'MLFLOW_EXPERIMENT_ID_FORMAT_KEY', + 'MLFLOW_RUN_ID_FORMAT_KEY', ] diff --git a/composer/utils/checkpoint.py b/composer/utils/checkpoint.py index 2d176b135a..f2342eeb4c 100644 --- a/composer/utils/checkpoint.py +++ b/composer/utils/checkpoint.py @@ -685,7 +685,7 @@ def load_sharded_checkpoint( dist_cp.load_state_dict( state_dict=state_dict, storage_reader=storage_reader, - planner=state.fsdp_config['load_planner'], + planner=state.fsdp_config.load_planner, no_dist=(not dist.is_initialized()), ) @@ -1053,7 +1053,7 @@ def get_save_filename( # Sharded checkpoints get their own little folder. 
assert state.fsdp_config is not None - remote_prefix = state.fsdp_config['sharded_ckpt_prefix_dir'] + remote_prefix = state.fsdp_config.sharded_ckpt_prefix_dir assert remote_prefix is not None save_dirpath = Path(Path(filename).parent) / Path(remote_prefix) save_dirpath = format_name_with_dist_and_time(str(save_dirpath), state.run_name, state.timestamp) @@ -1145,7 +1145,7 @@ def _save_checkpoint( if expect_file: if version.parse(torch.__version__) >= version.parse('2.3.0'): - save_planner = state.fsdp_config['save_planner'] + save_planner = state.fsdp_config.save_planner if save_planner is None: from composer.trainer._patch_pytorch import SavePlannerWithDedupFix @@ -1160,7 +1160,7 @@ def _save_checkpoint( dist_cp.save_state_dict( state_dict=state_dict, storage_writer=dist_cp.FileSystemWriter(dirname), - planner=state.fsdp_config['save_planner'], + planner=state.fsdp_config.save_planner, process_group=process_group, ) log.debug('Finished pytorch save state dict') diff --git a/composer/utils/object_store/__init__.py b/composer/utils/object_store/__init__.py index e623c385f0..3c70257e08 100644 --- a/composer/utils/object_store/__init__.py +++ b/composer/utils/object_store/__init__.py @@ -5,7 +5,11 @@ from composer.utils.object_store.gcs_object_store import GCSObjectStore from composer.utils.object_store.libcloud_object_store import LibcloudObjectStore -from composer.utils.object_store.mlflow_object_store import MLFlowObjectStore +from composer.utils.object_store.mlflow_object_store import ( + MLFLOW_EXPERIMENT_ID_FORMAT_KEY, + MLFLOW_RUN_ID_FORMAT_KEY, + MLFlowObjectStore, +) from composer.utils.object_store.object_store import ObjectStore, ObjectStoreTransientError from composer.utils.object_store.oci_object_store import OCIObjectStore from composer.utils.object_store.s3_object_store import S3ObjectStore @@ -22,4 +26,6 @@ 'OCIObjectStore', 'GCSObjectStore', 'UCObjectStore', + 'MLFLOW_EXPERIMENT_ID_FORMAT_KEY', + 'MLFLOW_RUN_ID_FORMAT_KEY', ] diff --git a/composer/utils/parallelism.py b/composer/utils/parallelism.py new file mode 100644 index 0000000000..4dc921b63a --- /dev/null +++ b/composer/utils/parallelism.py @@ -0,0 +1,96 @@ +# Copyright 2022 MosaicML Composer authors +# SPDX-License-Identifier: Apache-2.0 + +"""Parallelism configs.""" + +import warnings +from dataclasses import dataclass +from typing import Any, Optional + +from torch.distributed._tensor.device_mesh import DeviceMesh + +from composer.utils.warnings import VersionedDeprecationWarning + + +@dataclass +class FSDPConfig: + """Configuration for Fully Sharded Data Parallelism (FSDP).""" + activation_checkpointing: bool = False + activation_checkpointing_reentrant: bool = True + activation_cpu_offload: bool = False + auto_wrap: bool = True + te_checkpoint_wrapper: bool = False + te_shard_fp8_weight: bool = False + backward_prefetch: str = 'BACKWARD_POST' + backward_prefetch_limit: int = 1 + cpu_offload: bool = False + data_parallel_shard_degree: int = -1 + data_parallel_replicate_degree: Optional[int] = None + device_mesh: Optional[DeviceMesh] = None + forward_prefetch: bool = False + forward_prefetch_limit: int = 1 + ignored_modules: Optional[Any] = None + keep_low_precision_grads: bool = False + limit_all_gathers: bool = True + load_monolith_rank0_only: bool = False + load_planner: Optional[Any] = None + mixed_precision: str = 'DEFAULT' + process_group: Optional[Any] = None + save_planner: Optional[Any] = None + sharded_ckpt_prefix_dir: str = 'ep{epoch}-ba{batch}' + sharding_strategy: str = 'FULL_SHARD' + state_dict_type: 
str = 'full' + sync_module_states: bool = False + use_orig_params: bool = True + verbose: bool = False + + +def create_fsdp_config(fsdp_config: dict[str, Any]): + """Modify fsdp_config to set default values for missing keys.""" + fsdp_config = {**fsdp_config} # Shallow copy to avoid modifying input + if 'process_group' in fsdp_config: + warnings.warn( + VersionedDeprecationWarning( + 'process_group is deprecated. Please specify `data_parallel_shard_degree` and `data_parallel_replicate_degree` instead.', + remove_version='0.24', + ), + ) + + if 'device_mesh' in fsdp_config: + warnings.warn( + VersionedDeprecationWarning( + 'device_mesh is deprecated. Please specify `data_parallel_shard_degree` and `data_parallel_replicate_degree` instead.', + remove_version='0.24', + ), + ) + if 'data_parallel_shard_degree' in fsdp_config or 'data_parallel_replicate_degree' in fsdp_config: + raise ValueError( + 'Cannot specify both `device_mesh` and `data_parallel_shard_degree` or `data_parallel_replicate_degree`. Please remove `device_mesh`.', + ) + device_mesh = fsdp_config.pop('device_mesh') + if len(device_mesh) == 1: + fsdp_config['data_parallel_shard_degree'] = device_mesh[0] + elif len(device_mesh) == 2: + fsdp_config['data_parallel_replicate_degree'] = device_mesh[0] + fsdp_config['data_parallel_shard_degree'] = device_mesh[1] + else: + raise ValueError( + f'device_mesh must be of length 1 or 2 but received length {len(device_mesh)} with device mesh {device_mesh}.', + ) + + return FSDPConfig(**fsdp_config) + + +@dataclass +class TPConfig: + """Configuration for tensor parallelism (TP).""" + device_mesh: Optional[DeviceMesh] = None + tensor_parallel_degree: int = 1 + layer_plan: Any = None + + +@dataclass +class ParallelismConfig: + """Configuration for parallelism.""" + fsdp: Optional[FSDPConfig] = None + tp: Optional[TPConfig] = None diff --git a/tests/test_events.py b/tests/test_events.py index 633457ff57..235d0941f1 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -112,12 +112,8 @@ def test_event_calls(self, world_size, device, deepspeed_zero_stage, use_fsdp, p parallelism_config = { 'fsdp': { 'sharding_strategy': 'FULL_SHARD', - 'cpu_offload': False, 'mixed_precision': 'PURE', 'backward_prefetch': 'BACKWARD_PRE', - 'activation_checkpointing': False, - 'activation_ocpu_offload': False, - 'verbose': False, }, } diff --git a/tests/trainer/test_fsdp_checkpoint.py b/tests/trainer/test_fsdp_checkpoint.py index 109ff63901..30ec369dc6 100644 --- a/tests/trainer/test_fsdp_checkpoint.py +++ b/tests/trainer/test_fsdp_checkpoint.py @@ -30,7 +30,7 @@ from composer.models import ComposerClassifier from composer.optim import DecoupledAdamW from composer.trainer import Trainer -from composer.utils import dist, parse_uri +from composer.utils import FSDPConfig, dist, parse_uri from composer.utils.checkpoint import is_checkpoint_legacy_sharded from composer.utils.file_helpers import get_file from composer.utils.object_store import S3ObjectStore @@ -79,20 +79,6 @@ def param_init_fn(self, module): torch.nn.init.zeros_(module.bias) -@dataclasses.dataclass(frozen=True) -class FSDPConfig: - state_dict_type: str = 'full' - sharding_strategy: str = 'FULL_SHARD' - sharded_ckpt_prefix_dir: str = 'ba{batch}' - sync_module_states: bool = True - use_orig_params: bool = True - load_monolith_rank0_only: bool = False - save_planner: Optional[Any] = None - load_planner: Optional[Any] = None - data_parallel_shard_degree: int = -1 - process_group: Optional[str] = None - - def get_trainer( model_init_device: str = 
'cpu', save_folder: Optional[str] = None, @@ -118,7 +104,7 @@ def get_trainer( tp_config: Optional[dict[str, Any]] = None, ): if fsdp_config is None: - fsdp_config = FSDPConfig() + fsdp_config = FSDPConfig(sharded_ckpt_prefix_dir='ba{batch}') model = SimpleMLP( num_features=num_features, num_classes=num_classes, @@ -139,7 +125,7 @@ def get_trainer( else: raise ValueError(f'Unsupported optimizer name {optimizer}') - parallelism_config = {'fsdp': dataclasses.asdict(fsdp_config)} + parallelism_config: dict[str, Union[FSDPConfig, dict[str, Any]]] = {'fsdp': fsdp_config} if tp_config is not None: parallelism_config['tp'] = tp_config @@ -336,7 +322,11 @@ def test_fsdp_full_state_dict_load( save_folder = tmp_path save_filename = 'rank{rank}.pt' - fsdp_config = FSDPConfig(load_monolith_rank0_only=load_monolith_rank0_only) + fsdp_config = FSDPConfig( + sharded_ckpt_prefix_dir='ba{batch}', + sync_module_states=load_monolith_rank0_only, + load_monolith_rank0_only=load_monolith_rank0_only, + ) tp_config = None if use_tp: from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel @@ -413,7 +403,10 @@ def test_fsdp_mixed_with_sync( get_trainer( model_init_device=['cpu', 'meta'][dist.get_global_rank()], save_folder=str(tmp_path), - fsdp_config=FSDPConfig(sync_module_states=sync_module_states), + fsdp_config=FSDPConfig( + sync_module_states=sync_module_states, + sharded_ckpt_prefix_dir='ba{batch}', + ), ) @@ -567,6 +560,7 @@ def test_fsdp_load_old_checkpoint( state_dict_type=state_dict_type, sharding_strategy=sharding_strategy, process_group='mod1' if requires_pgs else None, + sharded_ckpt_prefix_dir='ba{batch}', ) trainer = get_trainer( @@ -689,7 +683,10 @@ def test_fsdp_full_state_dict_load_with_ema( save_folder = tmp_path save_filename = 'ba{batch}-rank{rank}.pt' - fsdp_config = FSDPConfig(sharding_strategy='SHARD_GRAD_OP') + fsdp_config = FSDPConfig( + sharding_strategy='SHARD_GRAD_OP', + sharded_ckpt_prefix_dir='ba{batch}', + ) trainer1 = get_trainer( save_folder=str(save_folder), @@ -749,7 +746,10 @@ def mock_get_checkpoint_validation_function(): tmp_paths = dist.all_gather_object(os.path.abspath(tmp_path)) save_folder = os.path.join(tmp_paths[0], 'checkpoints') - fsdp_config = FSDPConfig(state_dict_type=state_dict_type) + fsdp_config = FSDPConfig( + state_dict_type=state_dict_type, + sharded_ckpt_prefix_dir='ba{batch}', + ) # First trainer saves checkpoints. 
trainer = get_trainer(save_folder=save_folder, fsdp_config=fsdp_config, max_duration='1ba') @@ -830,10 +830,10 @@ def test_fsdp_partitioned_state_dict_load( save_filename = 'ba{batch}-rank{rank}.pt' - fsdp_config = FSDPConfig(state_dict_type='sharded') + fsdp_config = FSDPConfig(state_dict_type='sharded', sharded_ckpt_prefix_dir='ba{batch}') tp_config = None if use_tp: - fsdp_config = FSDPConfig(state_dict_type='sharded') + fsdp_config = FSDPConfig(state_dict_type='sharded', sharded_ckpt_prefix_dir='ba{batch}') from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel tp_config = { 'tensor_parallel_degree': 2, @@ -987,7 +987,7 @@ def test_elastic_resumption( run_name=run_name, max_duration='4ba', load_weights_only=False, - fsdp_config=FSDPConfig(state_dict_type='sharded'), + fsdp_config=FSDPConfig(state_dict_type='sharded', sharded_ckpt_prefix_dir='ba{batch}'), ) def get_mono_state_dict_from_sharded_one(trainer): @@ -1059,7 +1059,7 @@ def test_cleanup_sharded_checkpoints( max_duration=f'{batches_to_train}ba', save_interval='1ba', save_num_checkpoints_to_keep=num_ckpts_to_keep, - fsdp_config=FSDPConfig(state_dict_type='sharded'), + fsdp_config=FSDPConfig(state_dict_type='sharded', sharded_ckpt_prefix_dir='ba{batch}'), ) run_name = trainer1.state.run_name trainer1.fit() @@ -1149,6 +1149,7 @@ def set_up_planner( state_dict_type='sharded', load_planner=load_planner, save_planner=save_planner, + sharded_ckpt_prefix_dir='ba{batch}', ) trainer1 = get_trainer( @@ -1232,6 +1233,7 @@ def test_fsdp_monolith_resumption( use_orig_params=use_orig_params, sync_module_states=sync_module_states, state_dict_type='full', + sharded_ckpt_prefix_dir='ba{batch}', ) # All ranks use rank 0 folder
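
Usage sketch for the dataclass-based configuration introduced above in composer/utils/parallelism.py: a minimal, hedged example assuming only the Trainer signature and composer.utils exports shown in this diff. `model` and `train_dataloader` are placeholders, and the field values are arbitrary; the dict form in the second call remains accepted and is normalized by the Trainer through create_fsdp_config and TPConfig.

    from composer import Trainer
    from composer.utils import FSDPConfig, ParallelismConfig

    # Dataclass form: build FSDPConfig directly and wrap it in ParallelismConfig.
    # Per the checks added in trainer.py, the legacy fsdp_config / fsdp_auto_wrap
    # Trainer arguments must not be combined with a ParallelismConfig object.
    fsdp_config = FSDPConfig(
        sharding_strategy='FULL_SHARD',
        state_dict_type='sharded',
    )
    trainer = Trainer(
        model=model,                        # placeholder: any ComposerModel
        train_dataloader=train_dataloader,  # placeholder dataloader
        max_duration='1ep',
        parallelism_config=ParallelismConfig(fsdp=fsdp_config),
    )

    # Dict form: still supported; converted internally to the dataclasses above.
    trainer = Trainer(
        model=model,
        train_dataloader=train_dataloader,
        max_duration='1ep',
        parallelism_config={'fsdp': {'sharding_strategy': 'FULL_SHARD', 'state_dict_type': 'sharded'}},
    )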