Removing min_params (mosaicml#2494)
* Removing min_params

* formatting?

* removing overlap with another commit
bcui19 committed Aug 30, 2023
1 parent 5fd5ffe commit b8cc2ac
Showing 8 changed files with 4 additions and 24 deletions.
16 changes: 3 additions & 13 deletions composer/trainer/dist_strategy.py
@@ -340,11 +340,7 @@ def sync_hook(*args):
f'Consider using `amp` or `bf16` for precision or setting param_dtype in mixed_precision to `None` '
f'with sharding strategy `{sharding_map_key}.`')

if fsdp_config.get('min_params') is not None:
warnings.warn(DeprecationWarning('`min_params` in FSDP config will be deprecated in composer version 0.16.0.'))

backward_prefetch = backward_prefetch_map[fsdp_config['backward_prefetch'].upper()]
min_params = int(float(fsdp_config.get('min_params', 1e9)))
activation_checkpointing = fsdp_config['activation_checkpointing']
activation_cpu_offload = fsdp_config['activation_cpu_offload']
sync_module_states = fsdp_config['sync_module_states']
@@ -441,20 +437,15 @@ def _param_init_fn(module: torch.nn.Module) -> None:

# Choose which modules to FSDP wrap according to the following priority:
# If module has attribute `module._fsdp_wrap = ...`, always respect it
# Otherwise wrap if root object `obj.fsdp_wrap_fn(module)` is true
# Or if unwrapped params in module in greater than or equal to fsdp_config.min_params
# Otherwise wrap if root object `obj.fsdp_wrap_fn(module)` is true.
def __auto_wrap_policy(module: torch.nn.Module, recurse: bool, nonwrapped_numel: int) -> bool:
if recurse:
return True
should_be_wrapped = False
if hasattr(module, '_fsdp_wrap'):
should_be_wrapped = bool(module._fsdp_wrap)
else:
is_large = nonwrapped_numel >= min_params
if hasattr(obj, 'fsdp_wrap_fn') and isinstance(obj.fsdp_wrap_fn, Callable):
should_be_wrapped = obj.fsdp_wrap_fn(module) or is_large
else:
should_be_wrapped = is_large
elif hasattr(obj, 'fsdp_wrap_fn') and isinstance(obj.fsdp_wrap_fn, Callable):
should_be_wrapped = obj.fsdp_wrap_fn(module)

if should_be_wrapped and auto_microbatching:
module.register_forward_hook(sync_hook)
@@ -540,7 +531,6 @@ def _check_fn(module: torch.nn.Module) -> bool:
print(f'FSDP: Using cpu_offload={cpu_offload}')
print(f'FSDP: Using mixed_precision={mixed_precision}')
print(f'FSDP: Using backward_prefetch={backward_prefetch}')
print(f'FSDP: Using min_params={min_params}')
print(f'FSDP: Using activation_checkpointing={activation_checkpointing}')
print(f'FSDP: Using activation_cpu_offload={activation_cpu_offload}')
print(f'FSDP: Using sync_module_states={sync_module_states}')
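
To make the behavior change above easier to follow, here is a minimal standalone sketch of the wrap decision after this commit. The `should_wrap` name and the `root` parameter are illustrative stand-ins for the `obj` closed over by `__auto_wrap_policy`; this is a sketch of the logic shown in the hunk, not the literal Composer function.

    from typing import Callable

    import torch

    def should_wrap(root: torch.nn.Module, module: torch.nn.Module) -> bool:
        # Rule 1: an explicit `_fsdp_wrap` attribute on the module always wins.
        if hasattr(module, '_fsdp_wrap'):
            return bool(module._fsdp_wrap)
        # Rule 2: otherwise defer to the root object's `fsdp_wrap_fn`, if it has one.
        if hasattr(root, 'fsdp_wrap_fn') and isinstance(root.fsdp_wrap_fn, Callable):
            return bool(root.fsdp_wrap_fn(module))
        # With `min_params` removed there is no size-based fallback: unmarked
        # modules under a root without `fsdp_wrap_fn` are left unwrapped.
        return False
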
5 changes: 1 addition & 4 deletions docs/source/notes/distributed_training.rst
@@ -184,7 +184,6 @@ The full spec and defaults for Composer's `fsdp_config` is here:
fsdp_config = {
'sharding_strategy': str = 'FULL_SHARD' | 'SHARD_GRAD_OP' | 'NO_SHARD', # Default: 'FULL_SHARD'
'min_params': float # Default: 1e8
'cpu_offload': bool = True | False, # Default: False, cpu_offload not supported yet
'mixed_precision': str = 'FULL' | 'DEFAULT' | 'PURE', # Default: 'DEFAULT'
# Note: you can explicitly provide a dictionary too
@@ -279,7 +278,6 @@ An example code snippet for using FSDP with composer is provided below:
fsdp_config = {
'sharding_strategy': 'FULL_SHARD',
'min_params': 1e8,
'cpu_offload': False, # Not supported yet
'mixed_precision': 'DEFAULT',
'backward_prefetch': 'BACKWARD_POST',
@@ -310,9 +308,8 @@ To make auto-wrapping easier on users, Composer uses a custom auto wrap policy t

1) If any module is attributed with :code:`module._fsdp_wrap = True | False`, that choice will be respected.
2) If the root module (e.g. `GPT`) defines a function :code:`def fsdp_wrap_fn(module: torch.nn.Module) -> bool`, then that function will be used to evaluate the root module's children.
3) If any module has more parameters than :code:`fsdp_config['min_params']`, it will be wrapped.

These rules are meant to make it easy for users to modify existing models for usage with FSDP. You can either add attributes to modules you want to wrap (#1), define a filter (#2), or make no changes at all and just use the size-based policy via :code:`fsdp_config['min_params'] = ...` (#3).
These rules are meant to make it easy for users to modify existing models for use with FSDP. You can either add attributes to the modules you want to wrap (#1) or define a filter (#2).

In `gpt.py <https://github.com/mosaicml/examples/blob/6972fe3000d5a5480d8757ff710965514155e8db/llm/llm/gpt.py>`__, you can see that `we used rule #2 <https://github.com/mosaicml/examples/blob/6972fe3000d5a5480d8757ff710965514155e8db/llm/llm/gpt.py#L172>`__ to specify that all :code:`GPTBlock` modules within :code:`GPT` should be wrapped. Alternatively, we could have easily attributed each of the blocks with :code:`block._fsdp_wrap = True` and it would have accomplished the same thing. Whatever style you prefer, it's up to you!
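
As a concrete illustration of the two remaining rules, a hedged sketch follows; the `Block` and `Model` classes are hypothetical and not taken from the Composer or examples repositories.

    import torch
    import torch.nn as nn

    class Block(nn.Module):
        """A hypothetical transformer-style block."""

        def __init__(self, dim: int = 128):
            super().__init__()
            self.ff = nn.Linear(dim, dim)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.ff(x)

    class Model(nn.Module):
        """Hypothetical root module showing both remaining wrap rules."""

        def __init__(self, n_blocks: int = 4, dim: int = 128):
            super().__init__()
            self.blocks = nn.Sequential(*[Block(dim) for _ in range(n_blocks)])
            self.head = nn.Linear(dim, dim)
            self.head._fsdp_wrap = True  # rule #1: wrap the head explicitly

        def fsdp_wrap_fn(self, module: nn.Module) -> bool:
            # rule #2: wrap every Block; only consulted for children
            # that do not carry an explicit `_fsdp_wrap` attribute
            return isinstance(module, Block)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.head(self.blocks(x))
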

1 change: 0 additions & 1 deletion docs/source/trainer/using_the_trainer.rst
@@ -417,7 +417,6 @@ To enable FSDP, simply pass in as shown below:
fsdp_config = {
'sharding_strategy': 'FULL_SHARD',
'min_params': 1e9,
'cpu_offload': False, # Not supported yet
'mixed_precision': 'DEFAULT',
'backward_prefetch': 'BACKWARD_POST',
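
The config in this file simply drops the `min_params: 1e9` entry. Models that relied on size-based wrapping can approximate the old behavior with an explicit `fsdp_wrap_fn`; a hedged sketch (the `Encoder` class and the 1e8 threshold are illustrative, not part of Composer):

    import torch
    import torch.nn as nn

    class Encoder(nn.Module):
        """Hypothetical root module migrating away from size-based wrapping."""

        def __init__(self, dim: int = 512, n_layers: int = 6):
            super().__init__()
            self.layers = nn.ModuleList([nn.Linear(dim, dim) for _ in range(n_layers)])

        def fsdp_wrap_fn(self, module: nn.Module) -> bool:
            # Wrap any child whose parameter count clears an old-style threshold.
            return sum(p.numel() for p in module.parameters()) >= 1e8

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            for layer in self.layers:
                x = layer(x)
            return x
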
1 change: 0 additions & 1 deletion tests/models/test_hf_model.py
@@ -799,7 +799,6 @@ def test_hf_fsdp(tiny_bert_config, tiny_bert_tokenizer):

fsdp_config = {
'sharding_strategy': 'FULL_SHARD',
'min_params': 1e8,
'cpu_offload': False,
'mixed_precision': 'PURE',
'backward_prefetch': 'BACKWARD_PRE',
1 change: 0 additions & 1 deletion tests/test_events.py
@@ -98,7 +98,6 @@ def test_event_calls(self, world_size, device, deepspeed_zero_stage, use_fsdp, p
if use_fsdp:
fsdp_config = {
'sharding_strategy': 'FULL_SHARD',
'min_params': 1e8,
'cpu_offload': False,
'mixed_precision': 'PURE',
'backward_prefetch': 'BACKWARD_PRE',
1 change: 0 additions & 1 deletion tests/trainer/test_ddp.py
@@ -176,7 +176,6 @@ def test_ddp(device: str, world_size: int, deepspeed: bool, fsdp: bool, tmp_path
if fsdp:
fsdp_config = {
'sharding_strategy': 'FULL_SHARD',
'min_params': 1e8,
'cpu_offload': False,
'mixed_precision': 'PURE',
'backward_prefetch': 'BACKWARD_PRE',
1 change: 0 additions & 1 deletion tests/trainer/test_fsdp_checkpoint.py
@@ -92,7 +92,6 @@ def get_trainer(
optimizers=optim,
train_dataloader=dataloader,
fsdp_config={
'min_params': 16,
'state_dict_type': fsdp_state_dict_type,
'sharding_strategy': sharding_strategy,
'sharded_ckpt_prefix_dir': fsdp_sharded_ckpt_prefix_dir,
2 changes: 0 additions & 2 deletions tests/trainer/test_trainer.py
@@ -607,7 +607,6 @@ def test_fsdp(

fsdp_config = {
'sharding_strategy': 'FULL_SHARD',
'min_params': 1e8,
'cpu_offload': False,
'mixed_precision': 'PURE',
'backward_prefetch': 'BACKWARD_PRE',
@@ -652,7 +651,6 @@ def test_fsdp_torch_compile(
):
fsdp_config = {
'sharding_strategy': 'FULL_SHARD',
'min_params': 1e8,
'cpu_offload': False,
'mixed_precision': 'PURE',
'backward_prefetch': 'BACKWARD_PRE',
