diff --git a/composer/trainer/mosaic_fsdp.py b/composer/trainer/mosaic_fsdp.py
index 8265d51d17..fecdb62c93 100644
--- a/composer/trainer/mosaic_fsdp.py
+++ b/composer/trainer/mosaic_fsdp.py
@@ -20,31 +20,35 @@ def patch_pytorch():
         raise NotImplementedError(f'Not supported for torch < 1.13.1')
 
     elif version.parse(torch.__version__) < version.parse('2.0.0'):
-        # FullyShardedDataParallel monkey path for torch < 2.0 ie torch == 1.13.1
+        # Monkey patch for torch < 2.0 ie torch == 1.13.1
 
-        # monkey patch _auto_wrap with _custom_auto_wrap fn
+        # Monkey patch _auto_wrap with _custom_auto_wrap fn
         FullyShardedDataParallel._auto_wrap = custom_auto_wrap_t1p13p1  # type: ignore
 
     elif version.parse(torch.__version__) < version.parse('2.0.1'):
         raise NotImplementedError(f'Not supported for torch == 2.0.0')
 
-    elif version.parse(torch.__version__) == version.parse('2.0.1'):
+    elif version.parse(torch.__version__) < version.parse('2.0.2'):
         # Monkey patch for torch == 2.0.1
 
         # Monkey patch __init__ where __init__ calls the custom _auto_wrap fn
         from composer.trainer.mosaic_fsdp_utils import init_fn_t2p0p1
-        FullyShardedDataParallel.__init__ = init_fn_t2p0p1
+        FullyShardedDataParallel.__init__ = init_fn_t2p0p1  # type: ignore
 
         # Monkey patch sharding method
         ChunkShardingSpec.build_metadata = build_metadata
         ChunkShardingSpec.shard = shard
 
-    elif version.parse(torch.__version__) < version.parse('2.2.0'):
-        # Monkey path for torch < 2.2.0 ie torch == 2.1.0
+    elif version.parse(torch.__version__) < version.parse('2.1.1'):
+        # Monkey patch for torch < 2.1.1 ie torch == 2.1.0
+
+        # Monkey patch __init__ where __init__ calls the custom _auto_wrap fn
+        from composer.trainer.mosaic_fsdp_utils import init_fn_t2p1p0
+        FullyShardedDataParallel.__init__ = init_fn_t2p1p0  # type: ignore
 
         # Monkey patch sharding method
         ChunkShardingSpec.build_metadata = build_metadata
         ChunkShardingSpec.shard = shard
 
-    elif version.parse(torch.__version__) >= version.parse('2.2.0'):
-        raise NotImplementedError(f'Not supported for torch >= 2.2.0')
+    elif version.parse(torch.__version__) >= version.parse('2.1.1'):
+        raise NotImplementedError(f'FullyShardedDataParallel is not supported for torch >= 2.1.1')
diff --git a/composer/trainer/mosaic_fsdp_utils.py b/composer/trainer/mosaic_fsdp_utils.py
index 0924f9ec46..090d92e227 100644
--- a/composer/trainer/mosaic_fsdp_utils.py
+++ b/composer/trainer/mosaic_fsdp_utils.py
@@ -7,7 +7,9 @@
 """Utilities for monkey patching FSDP."""
 
 import functools
+import inspect
 import warnings
+from functools import partial
 from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Optional, Set, Tuple, Union, cast
 
 import torch
@@ -277,6 +279,8 @@ def _custom_recursive_wrap_t1p13p1(
         remainder = num_params - total_wrapped_params
         module_kwargs = auto_wrap_policy(module=module, recurse=False, unwrapped_params=remainder)
         if not only_wrap_children and module_kwargs:
+            # CHANGE: We modify the original code to support custom FSDP kwargs and add
+            # the process_group_cache to avoid instantiating a new process group.
             module_kwargs = module_kwargs if isinstance(module_kwargs, dict) else {}
             module_kwargs = _set_custom_fsdp_module_kwargs(module_kwargs, process_group_cache)
 
@@ -334,6 +338,7 @@ def custom_auto_wrap_t1p13p1(
                       'instances with mixed precision disabled since some batch norm '
                       'kernels do not support low precision.')
         auto_wrap_kwargs['auto_wrap_policy'] = auto_wrap_policy
+    # CHANGE: Add process group cache and call our custom _recursive_wrap
     auto_wrap_kwargs['process_group_cache'] = {}
     _custom_recursive_wrap_t1p13p1(**auto_wrap_kwargs, **fsdp_kwargs)
 
@@ -420,6 +425,8 @@ def _custom_recursive_wrap_t2p0p1(
         remainder = nonwrapped_numel - total_wrapped_numel
         module_kwargs = auto_wrap_policy(module=module, recurse=False, nonwrapped_numel=remainder)
         if not only_wrap_children and module_kwargs:
+            # CHANGE: We modify the original code to support custom FSDP kwargs and add
+            # the process_group_cache to avoid instantiating a new process group.
             module_kwargs = module_kwargs if isinstance(module_kwargs, dict) else {}
             module_kwargs = _set_custom_fsdp_module_kwargs(module_kwargs, process_group_cache)
 
@@ -488,11 +495,14 @@ def _custom_auto_wrap_t2p0p1(
                       'instances with mixed precision disabled since some batch norm '
                       'kernels do not support low precision.')
         auto_wrap_kwargs['auto_wrap_policy'] = auto_wrap_policy
+
+    # CHANGE: Add process group cache and call our custom _recursive_wrap
    auto_wrap_kwargs['process_group_cache'] = {}
    _custom_recursive_wrap_t2p0p1(**auto_wrap_kwargs, **fsdp_kwargs)
 
 
-if version.parse(torch.__version__) == version.parse('2.0.1'):
+if version.parse(torch.__version__) >= version.parse('2.0.1') and version.parse(
+        torch.__version__) < version.parse('2.0.2'):
 
     from torch.distributed.fsdp._init_utils import ProcessGroupType
     from torch.distributed.fsdp.wrap import _FSDPPolicy
@@ -571,7 +581,7 @@ def init_fn_t2p0p1(
                 # process groups.
                 fsdp_kwargs['process_group'] = (self.process_group, self._inter_node_pg)
 
-            # call the custom _auto_wrap function
+            # CHANGE: Call our custom _auto_wrap function
             _custom_auto_wrap_t2p0p1(auto_wrap_kwargs, fsdp_kwargs, FullyShardedDataParallel)
 
         backward_prefetch_limit = 1
@@ -608,6 +618,304 @@ def init_fn_t2p0p1(
         _register_all_state_dict_hooks(self)
 
 
+def _custom_recursive_wrap_t2p1p0(
+    module: nn.Module,
+    auto_wrap_policy: Callable,
+    wrapper_cls: Callable,
+    ignored_modules: Set[nn.Module],
+    ignored_params: Set[nn.Parameter],
+    process_group_cache: Dict[Tuple[int], Any],
+    only_wrap_children: bool = False,
+    **kwargs: Any,
+) -> Tuple[nn.Module, int]:
+    """Supports custom wrapping of modules with FSDP kwargs.
+
+    Torch version must be 2.1.0.
+
+    Modified version of https://github.com/pytorch/pytorch/blob/8292b03c47fd71beb23ae834971e044aef6f4d7c/torch/distributed/fsdp/_wrap_utils.py#L25
+    to support custom FSDP kwargs, e.g. process groups.
+
+    Wraps submodules of ``module`` for which ``auto_wrap_policy`` returns
+    ``True`` with ``wrapper_cls``.
+
+    Args:
+        module (nn.Module): Module to recursively wrap.
+        auto_wrap_policy (Callable): A callable representing a policy that
+            determines which modules to recursively wrap with ``wrapper_cls``.
+        wrapper_cls: wrapper_cls
+        ignored_modules (Set[torch.nn.Module]): Modules to ignore when
+            wrapping.
+        ignored_params (Set[torch.nn.Parameter]): Parameters to ignore when
+            wrapping; these should be the parameters contained in the modules
+            in ``ignored_modules``.
+        process_group_cache (Dict[Tuple[int], Any]): a cache of process_group to
+            use instead of potentially instantiating a new process_group
+        only_wrap_children: wrap only children
+    Returns:
+        (nn.Module, int):
+            ``module`` after wrapping and the numel recursively wrapped.
+    """
+    from torch.distributed.fsdp.wrap import _wrap
+
+    assert auto_wrap_policy is not None, 'Must specify auto_wrap_policy.'
+    assert wrapper_cls is not None, 'Must specify wrapper_cls'
+    # Make sure no child is already wrapped.
+    for _, child in module.named_modules():
+        if child in ignored_modules:
+            continue
+        try:
+            assert not isinstance(child, cast(type, wrapper_cls))
+        except TypeError:
+            # wrapper_cls is a function as opposed to a class type, so we bypass the above check.
+            pass
+
+    # We count all params, assuming none of them are already wrapped.
+    nonwrapped_numel = sum(p.numel() for p in module.parameters() if p not in ignored_params)
+
+    assert auto_wrap_policy is not None
+    if auto_wrap_policy(module=module, recurse=True, nonwrapped_numel=nonwrapped_numel):
+        total_wrapped_numel = 0
+        # Iterate through the children, recursively wrap if necessary
+        for name, child in module.named_children():
+            if child in ignored_modules:
+                continue
+            wrapped_child, num_wrapped_params = _custom_recursive_wrap_t2p1p0(
+                module=child,
+                auto_wrap_policy=auto_wrap_policy,
+                wrapper_cls=wrapper_cls,
+                ignored_modules=ignored_modules,
+                ignored_params=ignored_params,
+                process_group_cache=process_group_cache,
+                **kwargs,
+            )
+            setattr(module, name, wrapped_child)
+            # Keep track of how many parameters have been wrapped
+            total_wrapped_numel += num_wrapped_params
+        # decide if we need to wrap the current module,
+        # since the left over parameters exceed the number of params to wrap
+        remainder = nonwrapped_numel - total_wrapped_numel
+
+        module_kwargs = auto_wrap_policy(module=module, recurse=False, nonwrapped_numel=remainder)
+        if not only_wrap_children and module_kwargs:
+            # CHANGE: We modify the original code to support custom FSDP kwargs and add
+            # the process_group_cache to avoid instantiating a new process group.
+            module_kwargs = module_kwargs if isinstance(module_kwargs, dict) else {}
+            module_kwargs = _set_custom_fsdp_module_kwargs(module_kwargs, process_group_cache)
+
+            final_kwargs = {**kwargs, **module_kwargs}
+
+            if final_kwargs.get('process_group', None) is not None:
+                _pg_ranks = distributed.get_process_group_ranks(final_kwargs['process_group'])
+                _meta_init = any(p.device.type == 'meta' for p in module.parameters())
+                if _meta_init and len(_pg_ranks) != dist.get_world_size() and final_kwargs.get('use_orig_params'):
+                    raise NotImplementedError(
+                        f'FSDP with custom process groups cannot use `use_orig_params: True` when using meta init.')
+
+            # Leaf node or final wrapping of the remainder both happen here.
+            return _wrap(module, wrapper_cls, **final_kwargs), nonwrapped_numel
+        else:
+            return module, total_wrapped_numel
+    return module, 0
+
+
+if version.parse(torch.__version__) > version.parse('2.0.2') and version.parse(
+        torch.__version__) < version.parse('2.1.1'):
+
+    from torch.distributed.fsdp._init_utils import ProcessGroupType
+    from torch.distributed.fsdp.wrap import ModuleWrapPolicy, _Policy
+
+    def _custom_auto_wrap_t2p1p0(
+        root_module: nn.Module,
+        policy: Union[Callable, _Policy],
+        ignored_modules: Set[nn.Module],
+        ignored_params: Set[nn.Parameter],
+        root_kwargs: Dict[str, Any],
+        fsdp_fn: Callable,  # e.g. `FullyShardedDataParallel` or `fully_shard`
+    ):
+        """Modified version of https://github.com/pytorch/pytorch/blob/f13101640f548f8fa139c03dfa6711677278c391/torch/distributed/fsdp/wrap.py#L487.
+
+        Calls custom _recursive_wrap fn and adds process group cache.
+
+        Auto wraps modules in ``root_module`` 's tree according to ``policy``
+        following a post-order traversal.
+
+        Precondition: ``root_kwargs`` should contain all arguments except
+        ``module``. This function accepts the kwargs dict directly since it gets
+        forwarded into the post-order traversal function.
+        """
+        from torch.distributed.fsdp._common_utils import _override_module_mixed_precision
+        from torch.distributed.fsdp._wrap_utils import (_check_nested_wrapping, _validate_frozen_params,
+                                                        _warn_on_overridden_mixed_precision)
+        from torch.distributed.fsdp.wrap import (_construct_wrap_fn, _or_policy, _post_order_apply,
+                                                 _run_mixed_precision_override_policy, _wrap_module_cls_individually)
+
+        mixed_precision = root_kwargs['mixed_precision']
+        is_wrapper = inspect.isclass(fsdp_fn)
+        # TODO: We may relax this no-nested-wrapping constraint to support manual
+        # wrapping followed by auto wrapping.
+        _check_nested_wrapping(root_module)
+
+        if isinstance(policy, _Policy):
+            root_kwargs['auto_wrap_policy' if is_wrapper else 'policy'] = None
+            target_module_to_kwargs = policy._run_policy(root_module, ignored_modules, root_kwargs)
+            if mixed_precision is not None:
+                target_module_to_kwargs = _run_mixed_precision_override_policy(
+                    root_module,
+                    mixed_precision._module_classes_to_ignore,
+                    ignored_modules,
+                    root_kwargs,
+                    target_module_to_kwargs,
+                )
+                overridden_module_classes = _override_module_mixed_precision(root_module,
+                                                                             mixed_precision._module_classes_to_ignore)
+                _warn_on_overridden_mixed_precision(overridden_module_classes)
+            use_orig_params = root_kwargs.get('use_orig_params', False)
+            _validate_frozen_params(
+                root_module,
+                set(target_module_to_kwargs.keys()),
+                ignored_params,
+                use_orig_params,
+            )
+            wrap_fn = _construct_wrap_fn(root_module, target_module_to_kwargs, fsdp_fn)
+            _post_order_apply(root_module, wrap_fn)
+            return
+
+        recursive_wrap_kwargs = {
+            'module': root_module,
+            'auto_wrap_policy': policy,
+            'wrapper_cls': fsdp_fn,
+            'ignored_modules': ignored_modules,
+            'ignored_params': ignored_params,
+            'only_wrap_children': True,
+        }
+        if mixed_precision is not None:
+            # Wrap modules of the ignored types separately and register forward
+            # hooks to cast to fp32 and back to the original dtype, respectively
+            overridden_module_classes = _override_module_mixed_precision(root_module,
+                                                                         mixed_precision._module_classes_to_ignore)
+            policy = functools.partial(
+                _or_policy,
+                policies=[
+                    policy,
+                    partial(
+                        _wrap_module_cls_individually,
+                        module_classes=mixed_precision._module_classes_to_ignore,
+                    ),
+                ],
+            )
+            recursive_wrap_kwargs['auto_wrap_policy'] = policy
+            _warn_on_overridden_mixed_precision(overridden_module_classes)
+
+        # CHANGE: Add process group cache and call our custom _recursive_wrap
+        recursive_wrap_kwargs['process_group_cache'] = {}
+
+        _custom_recursive_wrap_t2p1p0(**recursive_wrap_kwargs, **root_kwargs)  # type: ignore[arg-type]
+
+    def init_fn_t2p1p0(
+        self,
+        module: nn.Module,
+        process_group: ProcessGroupType = None,
+        sharding_strategy: Optional[ShardingStrategy] = None,
+        cpu_offload: Optional[CPUOffload] = None,
+        auto_wrap_policy: Optional[Union[Callable, ModuleWrapPolicy]] = None,
+        backward_prefetch: Optional[BackwardPrefetch] = BackwardPrefetch.BACKWARD_PRE,
+        mixed_precision: Optional[MixedPrecision] = None,
+        ignored_modules: Optional[Iterable[torch.nn.Module]] = None,
+        param_init_fn: Optional[Callable[[nn.Module], None]] = None,
+        device_id: Optional[Union[int, torch.device]] = None,
+        sync_module_states: bool = False,
+        forward_prefetch: bool = False,
+        limit_all_gathers: bool = True,
+        use_orig_params: bool = False,
+        ignored_states: Union[Optional[Iterable[torch.nn.Parameter]], Optional[Iterable[torch.nn.Module]]] = None,
+    ):
+        """Modified version of https://github.com/pytorch/pytorch/blob/8ed169b1628285924e10fc98de53dbb75c92c43e/torch/distributed/fsdp/fully_sharded_data_parallel.py#L399C1."""
+        from torch.distributed.fsdp._dynamo_utils import _annotate_modules_for_dynamo
+        from torch.distributed.fsdp._init_utils import (HYBRID_SHARDING_STRATEGIES, _check_orig_params_flattened,
+                                                        _init_buffer_state, _init_core_state, _init_device_handle,
+                                                        _init_ignored_module_states, _init_param_handle_from_module,
+                                                        _init_prefetching_state, _init_process_group_state,
+                                                        _init_runtime_state, _init_state_dict_state)
+        from torch.distributed.fsdp._state_dict_utils import _register_all_state_dict_hooks
+        from torch.distributed.fsdp._unshard_param_utils import _register_flat_param
+
+        torch._C._log_api_usage_once('torch.distributed.fsdp')
+        super(FullyShardedDataParallel, self).__init__()
+        _init_ignored_module_states(self, module, ignored_modules, ignored_states)
+        _init_device_handle(self, module, self._ignored_params, device_id)
+
+        # Add module annotations for Dynamo support (see function for details)
+        _annotate_modules_for_dynamo(module, self._ignored_modules, use_orig_params)
+
+        # Initializes self.process_group, along with rank and world size. This will
+        # also set another attribute, _inter_node_pg, to control the process group
+        # over which sharding occurs, if sharding_strategy is {HYBRID_SHARD, _HYBRID_SHARD_ZERO2}.
+        # Note that this is done before auto_wrapping, so that child FSDP modules simply pick up
+        # the same process group state as the root FSDP module.
+        _init_process_group_state(self, process_group, sharding_strategy, auto_wrap_policy)
+        if auto_wrap_policy is not None:
+            root_kwargs = {
+                'process_group': process_group,
+                'sharding_strategy': sharding_strategy,
+                'cpu_offload': cpu_offload,
+                'backward_prefetch': backward_prefetch,
+                'mixed_precision': mixed_precision,
+                'param_init_fn': param_init_fn,
+                'device_id': device_id,
+                'sync_module_states': sync_module_states,
+                'forward_prefetch': forward_prefetch,
+                'limit_all_gathers': limit_all_gathers,
+                'use_orig_params': use_orig_params,
+                'ignored_states': self._ignored_params,
+            }
+            if sharding_strategy in HYBRID_SHARDING_STRATEGIES:
+                # Share root process groups with children to maintain
+                # the invariant that all FSDP modules will have the same
+                # process groups.
+                root_kwargs['process_group'] = (self.process_group, self._inter_node_pg)
+
+            # CHANGE: Call our custom _auto_wrap function
+            _custom_auto_wrap_t2p1p0(
+                module,
+                auto_wrap_policy,
+                self._ignored_modules,
+                self._ignored_params,
+                root_kwargs,
+                FullyShardedDataParallel,
+            )
+
+        backward_prefetch_limit = 1
+        forward_prefetch_limit = 1
+        _init_core_state(
+            self,
+            sharding_strategy,
+            mixed_precision,
+            cpu_offload,
+            limit_all_gathers,
+            use_orig_params,
+            backward_prefetch_limit,
+            forward_prefetch_limit,
+        )
+        _init_runtime_state(self)
+        _init_prefetching_state(self, backward_prefetch, forward_prefetch)
+        _init_buffer_state(self, module)
+        _init_param_handle_from_module(
+            self,
+            module,
+            device_id,
+            param_init_fn,
+            sync_module_states,
+        )
+        self._fsdp_wrapped_module = module
+        if not use_orig_params:
+            _check_orig_params_flattened(self, self._ignored_params)
+            _register_flat_param(self, self)
+
+        # `_state_dict_type` controls the `state_dict()` behavior, which is
+        # implemented using post-save and pre-load hooks
+        _init_state_dict_state(self)
+        _register_all_state_dict_hooks(self)
+
+
 def get_split_size(dim_size: int, chunks: int) -> int:
     """Gets the minimum size per chunk.
diff --git a/pyproject.toml b/pyproject.toml
index edc3544198..c95c22aff2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -39,6 +39,7 @@ include = [
 exclude = [
     "build/**",
     "node_modules/**",
+    'composer/trainer/mosaic_fsdp_utils.py'
 ]
 
 # Disable checks for missing imports, as a conditional install of composer will not include them
 # Any incorrect imports will be discovered through test cases
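A note on the `process_group_cache` threaded through the patched wrap functions above: it maps a tuple of ranks to an already-constructed process group so that auto-wrapping many submodules with the same custom process group reuses one group instead of creating a new one per wrapped module (the actual lookup happens inside `_set_custom_fsdp_module_kwargs`, which this diff does not change). A minimal sketch of that caching pattern, using a hypothetical `_get_cached_process_group` helper rather than composer's real code:

from typing import Any, Dict, Iterable, Tuple

import torch.distributed as dist


def _get_cached_process_group(ranks: Iterable[int], process_group_cache: Dict[Tuple[int, ...], Any]):
    """Hypothetical helper: return one process group per unique rank tuple."""
    key = tuple(sorted(ranks))
    if key not in process_group_cache:
        # dist.new_group is a collective call, so every rank must reach it with the
        # same rank list; caching keeps that to a single call per distinct group.
        process_group_cache[key] = dist.new_group(ranks=list(key))
    return process_group_cache[key]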