diff --git a/composer/trainer/dist_strategy.py b/composer/trainer/dist_strategy.py
index 9eab9c365e..966ce8ba3d 100644
--- a/composer/trainer/dist_strategy.py
+++ b/composer/trainer/dist_strategy.py
@@ -340,11 +340,7 @@ def sync_hook(*args):
                 f'Consider using `amp` or `bf16` for precision or setting param_dtype in mixed_precision to `None` '
                 f'with sharding strategy `{sharding_map_key}.`')

-    if fsdp_config.get('min_params') is not None:
-        warnings.warn(DeprecationWarning('`min_params` in FSDP config will be deprecated in composer version 0.16.0.'))
-
     backward_prefetch = backward_prefetch_map[fsdp_config['backward_prefetch'].upper()]
-    min_params = int(float(fsdp_config.get('min_params', 1e9)))
     activation_checkpointing = fsdp_config['activation_checkpointing']
     activation_cpu_offload = fsdp_config['activation_cpu_offload']
     sync_module_states = fsdp_config['sync_module_states']
@@ -441,20 +437,15 @@ def _param_init_fn(module: torch.nn.Module) -> None:

             # Choose which modules to FSDP wrap according to the following priority:
             # If module has attribute `module._fsdp_wrap = ...`, always respect it
-            # Otherwise wrap if root object `obj.fsdp_wrap_fn(module)` is true
-            # Or if unwrapped params in module in greater than or equal to fsdp_config.min_params
+            # Otherwise wrap if root object `obj.fsdp_wrap_fn(module)` is true.
             def __auto_wrap_policy(module: torch.nn.Module, recurse: bool, nonwrapped_numel: int) -> bool:
                 if recurse:
                     return True
                 should_be_wrapped = False
                 if hasattr(module, '_fsdp_wrap'):
                     should_be_wrapped = bool(module._fsdp_wrap)
-                else:
-                    is_large = nonwrapped_numel >= min_params
-                    if hasattr(obj, 'fsdp_wrap_fn') and isinstance(obj.fsdp_wrap_fn, Callable):
-                        should_be_wrapped = obj.fsdp_wrap_fn(module) or is_large
-                    else:
-                        should_be_wrapped = is_large
+                elif hasattr(obj, 'fsdp_wrap_fn') and isinstance(obj.fsdp_wrap_fn, Callable):
+                    should_be_wrapped = obj.fsdp_wrap_fn(module)

                 if should_be_wrapped and auto_microbatching:
                     module.register_forward_hook(sync_hook)
@@ -540,7 +531,6 @@ def _check_fn(module: torch.nn.Module) -> bool:
            print(f'FSDP: Using cpu_offload={cpu_offload}')
            print(f'FSDP: Using mixed_precision={mixed_precision}')
            print(f'FSDP: Using backward_prefetch={backward_prefetch}')
-           print(f'FSDP: Using min_params={min_params}')
            print(f'FSDP: Using activation_checkpointing={activation_checkpointing}')
            print(f'FSDP: Using activation_cpu_offload={activation_cpu_offload}')
            print(f'FSDP: Using sync_module_states={sync_module_states}')
diff --git a/docs/source/notes/distributed_training.rst b/docs/source/notes/distributed_training.rst
index f1675a8a77..756f933b53 100644
--- a/docs/source/notes/distributed_training.rst
+++ b/docs/source/notes/distributed_training.rst
@@ -184,7 +184,6 @@ The full spec and defaults for Composer's `fsdp_config` is here:

     fsdp_config = {
         'sharding_strategy': str = 'FULL_SHARD' | 'SHARD_GRAD_OP' | 'NO_SHARD', # Default: 'FULL_SHARD'
-        'min_params': float # Default: 1e8
         'cpu_offload': bool = True | False, # Default: False, cpu_offload not supported yet
         'mixed_precision': str = 'FULL' | 'DEFAULT' | 'PURE', # Default: 'DEFAULT'
         # Note: you can explicitly provide a dictionary too
@@ -279,7 +278,6 @@ An example code snippet for using FSDP with composer is provided below:

     fsdp_config = {
         'sharding_strategy': 'FULL_SHARD',
-        'min_params': 1e8,
         'cpu_offload': False, # Not supported yet
         'mixed_precision': 'DEFAULT',
         'backward_prefetch': 'BACKWARD_POST',
@@ -310,9 +308,8 @@ To make auto-wrapping easier on users, Composer uses a custom auto wrap policy t

 1) If any module is attributed with :code:`module._fsdp_wrap = True | False`, that choice will be respected.
 2) If the root module (e.g. `GPT`) defines a function :code:`def fsdp_wrap_fn(module: torch.nn.Module) -> bool`, then that function will be used to evaluate the root module's children.
-3) If any module has more parameters than :code:`fsdp_config['min_params']`, it will be wrapped.

-These rules are meant to make it easy for users to modify existing models for usage with FSDP. You can either add attributes to modules you want to wrap (#1), define a filter (#2), or make no changes at all and just use the size-based policy via :code:`fsdp_config['min_params'] = ...` (#3).
+These rules are meant to make it easy for users to modify existing models for usage with FSDP. You can either add attributes to modules you want to wrap (#1) or define a filter (#2).

 In `gpt.py `__, you can see that `we used rule #2 `__ to specify that all :code:`GPTBlock` modules within :code:`GPT` should be wrapped. Alternatively, we could have easily attributed each of the blocks with :code:`block._fsdp_wrap = True` and it would have accomplished the same thing. Whatever style you prefer, it's up to you!

diff --git a/docs/source/trainer/using_the_trainer.rst b/docs/source/trainer/using_the_trainer.rst
index 59ba3294e1..385fd10677 100644
--- a/docs/source/trainer/using_the_trainer.rst
+++ b/docs/source/trainer/using_the_trainer.rst
@@ -417,7 +417,6 @@ To enable FSDP, simply pass in as shown below:

     fsdp_config = {
         'sharding_strategy': 'FULL_SHARD',
-        'min_params': 1e9,
         'cpu_offload': False, # Not supported yet
         'mixed_precision': 'DEFAULT',
         'backward_prefetch': 'BACKWARD_POST',
diff --git a/tests/models/test_hf_model.py b/tests/models/test_hf_model.py
index 384064da33..2ca6e83bae 100644
--- a/tests/models/test_hf_model.py
+++ b/tests/models/test_hf_model.py
@@ -799,7 +799,6 @@ def test_hf_fsdp(tiny_bert_config, tiny_bert_tokenizer):

     fsdp_config = {
         'sharding_strategy': 'FULL_SHARD',
-        'min_params': 1e8,
         'cpu_offload': False,
         'mixed_precision': 'PURE',
         'backward_prefetch': 'BACKWARD_PRE',
diff --git a/tests/test_events.py b/tests/test_events.py
index 544a2cc49b..63c686179f 100644
--- a/tests/test_events.py
+++ b/tests/test_events.py
@@ -98,7 +98,6 @@ def test_event_calls(self, world_size, device, deepspeed_zero_stage, use_fsdp, p
         if use_fsdp:
             fsdp_config = {
                 'sharding_strategy': 'FULL_SHARD',
-                'min_params': 1e8,
                 'cpu_offload': False,
                 'mixed_precision': 'PURE',
                 'backward_prefetch': 'BACKWARD_PRE',
diff --git a/tests/trainer/test_ddp.py b/tests/trainer/test_ddp.py
index 24017cfd38..f34ba3862d 100644
--- a/tests/trainer/test_ddp.py
+++ b/tests/trainer/test_ddp.py
@@ -176,7 +176,6 @@ def test_ddp(device: str, world_size: int, deepspeed: bool, fsdp: bool, tmp_path
     if fsdp:
         fsdp_config = {
             'sharding_strategy': 'FULL_SHARD',
-            'min_params': 1e8,
             'cpu_offload': False,
             'mixed_precision': 'PURE',
             'backward_prefetch': 'BACKWARD_PRE',
diff --git a/tests/trainer/test_fsdp_checkpoint.py b/tests/trainer/test_fsdp_checkpoint.py
index 552540a8a8..7d860eb5d6 100644
--- a/tests/trainer/test_fsdp_checkpoint.py
+++ b/tests/trainer/test_fsdp_checkpoint.py
@@ -92,7 +92,6 @@ def get_trainer(
         optimizers=optim,
         train_dataloader=dataloader,
         fsdp_config={
-            'min_params': 16,
             'state_dict_type': fsdp_state_dict_type,
             'sharding_strategy': sharding_strategy,
             'sharded_ckpt_prefix_dir': fsdp_sharded_ckpt_prefix_dir,
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 09311aa5d2..98d05197cc 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -607,7 +607,6 @@ def test_fsdp(

     fsdp_config = {
         'sharding_strategy': 'FULL_SHARD',
-        'min_params': 1e8,
         'cpu_offload': False,
         'mixed_precision': 'PURE',
         'backward_prefetch': 'BACKWARD_PRE',
@@ -652,7 +651,6 @@ def test_fsdp_torch_compile(
 ):
     fsdp_config = {
         'sharding_strategy': 'FULL_SHARD',
-        'min_params': 1e8,
         'cpu_offload': False,
         'mixed_precision': 'PURE',
         'backward_prefetch': 'BACKWARD_PRE',
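
Note: with the size-based fallback removed, `__auto_wrap_policy` only wraps a child module when that module carries an explicit `_fsdp_wrap` attribute (rule #1) or when the root module defines `fsdp_wrap_fn` (rule #2). A minimal sketch of what opting in looks like under the remaining two rules is below; the `MyModel` and `Block` names are hypothetical and not part of this diff:

    import torch
    import torch.nn as nn


    class Block(nn.Module):
        # Hypothetical transformer-style block that we want FSDP to wrap.
        def __init__(self, dim: int = 128):
            super().__init__()
            self.ff = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x + self.ff(x)


    class MyModel(nn.Module):
        # Hypothetical root module handed to Composer (e.g. as the child of a ComposerModel).
        def __init__(self, dim: int = 128, n_layers: int = 4):
            super().__init__()
            self.blocks = nn.ModuleList([Block(dim) for _ in range(n_layers)])
            self.head = nn.Linear(dim, 10)
            # Rule #1: an explicit per-module attribute is always respected by the auto wrap policy.
            self.head._fsdp_wrap = False

        # Rule #2: consulted for children that carry no _fsdp_wrap attribute.
        def fsdp_wrap_fn(self, module: nn.Module) -> bool:
            return isinstance(module, Block)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            for block in self.blocks:
                x = block(x)
            return self.head(x)

A model that provides neither hook is simply left unwrapped by the auto wrap policy, so configs that previously relied on `min_params` need to adopt one of these two mechanisms, consistent with the docs changes above.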