Commit

rm world_size param
v-chen_data committed Jun 21, 2024
1 parent bb6150d commit 635f92e
Showing 1 changed file with 32 additions and 46 deletions.
tests/trainer/test_fsdp_checkpoint.py (32 additions & 46 deletions)
@@ -289,28 +289,27 @@ def _compare_timestamps_between_state_dicts(state_dict1, state_dict2):
 @pytest.mark.gpu
 @pytest.mark.filterwarnings(r'ignore:.*scatter_full_optim_state_dict``is being deprecated.*:UserWarning')
 @pytest.mark.parametrize(
-    'world_size,optimizer,autoresume,precision,save_weights_only,load_weights_only,load_monolith_rank0_only,use_tp,use_hsdp',
+    'optimizer,autoresume,precision,save_weights_only,load_weights_only,load_monolith_rank0_only,use_tp,use_hsdp',
     [
-        pytest.param(2, 'adam', False, 'amp_bf16', False, False, False, False, False, marks=pytest.mark.world_size(2)),
-        pytest.param(2, 'adamw', False, 'amp_bf16', False, False, False, False, False, marks=pytest.mark.world_size(2)),
-        pytest.param(2, 'adam', True, 'amp_bf16', False, False, False, False, False, marks=pytest.mark.world_size(2)),
-        pytest.param(2, 'adam', False, 'amp_fp16', False, False, False, False, False, marks=pytest.mark.world_size(2)),
-        pytest.param(2, 'adam', False, 'amp_bf16', True, True, False, False, False,
+        pytest.param('adam', False, 'amp_bf16', False, False, False, False, False, marks=pytest.mark.world_size(2)),
+        pytest.param('adamw', False, 'amp_bf16', False, False, False, False, False, marks=pytest.mark.world_size(2)),
+        pytest.param('adam', True, 'amp_bf16', False, False, False, False, False, marks=pytest.mark.world_size(2)),
+        pytest.param('adam', False, 'amp_fp16', False, False, False, False, False, marks=pytest.mark.world_size(2)),
+        pytest.param('adam', False, 'amp_bf16', True, True, False, False, False,
                      marks=pytest.mark.world_size(2)),  # save_weights_only requires load_weights_only
-        pytest.param(2, 'adam', False, 'amp_bf16', False, True, False, False, False, marks=pytest.mark.world_size(2)),
-        pytest.param(2, 'adam', False, 'amp_bf16', False, False, True, False, False, marks=pytest.mark.world_size(2)),
-        pytest.param(4, 'adam', False, 'amp_bf16', False, False, False, True, False, marks=pytest.mark.world_size(4)),
-        pytest.param(4, 'adam', False, 'amp_bf16', False, False, False, False, True, marks=pytest.mark.world_size(4)),
-        pytest.param(4, 'adamw', False, 'amp_bf16', False, False, False, False, True, marks=pytest.mark.world_size(4)),
-        pytest.param(4, 'adam', True, 'amp_bf16', False, False, False, False, True, marks=pytest.mark.world_size(4)),
-        pytest.param(4, 'adam', False, 'amp_fp16', False, False, False, False, True, marks=pytest.mark.world_size(4)),
-        pytest.param(4, 'adam', False, 'amp_bf16', True, True, False, False, True,
+        pytest.param('adam', False, 'amp_bf16', False, True, False, False, False, marks=pytest.mark.world_size(2)),
+        pytest.param('adam', False, 'amp_bf16', False, False, True, False, False, marks=pytest.mark.world_size(2)),
+        pytest.param('adam', False, 'amp_bf16', False, False, False, True, False, marks=pytest.mark.world_size(4)),
+        pytest.param('adam', False, 'amp_bf16', False, False, False, False, True, marks=pytest.mark.world_size(4)),
+        pytest.param('adamw', False, 'amp_bf16', False, False, False, False, True, marks=pytest.mark.world_size(4)),
+        pytest.param('adam', True, 'amp_bf16', False, False, False, False, True, marks=pytest.mark.world_size(4)),
+        pytest.param('adam', False, 'amp_fp16', False, False, False, False, True, marks=pytest.mark.world_size(4)),
+        pytest.param('adam', False, 'amp_bf16', True, True, False, False, True,
                      marks=pytest.mark.world_size(4)),  # save_weights_only requires load_weights_only
-        pytest.param(4, 'adam', False, 'amp_bf16', False, True, False, False, True, marks=pytest.mark.world_size(4)),
+        pytest.param('adam', False, 'amp_bf16', False, True, False, False, True, marks=pytest.mark.world_size(4)),
     ],
 )
 def test_fsdp_full_state_dict_load(
-    world_size,
     tmp_path: pathlib.Path,
     autoresume: bool,
     precision: str,
@@ -334,7 +333,7 @@ def test_fsdp_full_state_dict_load(
         fsdp_config = FSDPConfig(
             sharding_strategy='HYBRID_SHARD',
             sharded_ckpt_prefix_dir='ba{batch}',
-            data_parallel_shard_degree=world_size // 2,
+            data_parallel_shard_degree=2,
             data_parallel_replicate_degree=2,
         )
     else:
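
Note: the removed world_size argument is redundant because the pytest.mark.world_size(N) mark attached to each pytest.param already controls how many ranks the case launches with, so the test body can read the world size from the distributed environment instead of receiving it as a parameter. A minimal sketch of the pattern, assuming a custom world_size mark registered in conftest.py and the standard WORLD_SIZE environment variable exported by the launcher; the test name is illustrative, not from the repository:

import os

import pytest


@pytest.mark.parametrize(
    'optimizer',
    [
        pytest.param('adam', marks=pytest.mark.world_size(2)),
        pytest.param('adamw', marks=pytest.mark.world_size(4)),
    ],
)
def test_example(optimizer: str):
    # The launcher exports WORLD_SIZE for each rank it spawns, so the
    # value no longer needs to travel through parametrize.
    world_size = int(os.environ.get('WORLD_SIZE', '1'))
    assert world_size in (2, 4)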
@@ -796,15 +795,14 @@ def mock_get_checkpoint_validation_function():
 @pytest.mark.gpu
 @pytest.mark.parametrize('use_remote', [pytest.param(True, marks=pytest.mark.remote), False])
 @pytest.mark.parametrize(
-    'world_size,weights_only,optimizer,precision,autoresume,load_ignore_keys,use_symlink,use_tp,use_hsdp',
+    'weights_only,optimizer,precision,autoresume,load_ignore_keys,use_symlink,use_tp,use_hsdp',
     [
-        pytest.param(2, False, 'adamw', 'amp_bf16', False, None, False, False, False, marks=pytest.mark.world_size(2)),
-        pytest.param(2, True, 'adamw', 'amp_bf16', False, None, False, False, False, marks=pytest.mark.world_size(2)),
-        pytest.param(2, False, 'adam', 'amp_bf16', False, None, False, False, False, marks=pytest.mark.world_size(2)),
-        pytest.param(2, False, 'adamw', 'amp_fp16', False, None, False, False, False, marks=pytest.mark.world_size(2)),
-        pytest.param(2, False, 'adamw', 'amp_bf16', True, None, False, False, False, marks=pytest.mark.world_size(2)),
+        pytest.param(False, 'adamw', 'amp_bf16', False, None, False, False, False, marks=pytest.mark.world_size(2)),
+        pytest.param(True, 'adamw', 'amp_bf16', False, None, False, False, False, marks=pytest.mark.world_size(2)),
+        pytest.param(False, 'adam', 'amp_bf16', False, None, False, False, False, marks=pytest.mark.world_size(2)),
+        pytest.param(False, 'adamw', 'amp_fp16', False, None, False, False, False, marks=pytest.mark.world_size(2)),
+        pytest.param(False, 'adamw', 'amp_bf16', True, None, False, False, False, marks=pytest.mark.world_size(2)),
         pytest.param(
-            2,
             False,
             'adamw',
             'amp_bf16',
@@ -815,33 +813,21 @@ def mock_get_checkpoint_validation_function():
             False,
             marks=pytest.mark.world_size(2),
         ),
-        pytest.param(2, False, 'adamw', 'amp_bf16', False, None, True, False, False, marks=pytest.mark.world_size(2)),
-        pytest.param(2, False, 'adamw', 'amp_bf16', False, None, False, True, False, marks=pytest.mark.world_size(4)),
-        pytest.param(4, False, 'adamw', 'amp_bf16', False, None, False, False, True, marks=pytest.mark.world_size(4)),
-        pytest.param(4, True, 'adamw', 'amp_bf16', False, None, False, False, True, marks=pytest.mark.world_size(4)),
-        pytest.param(4, False, 'adam', 'amp_bf16', False, None, False, False, True, marks=pytest.mark.world_size(4)),
-        pytest.param(4, False, 'adamw', 'amp_fp16', False, None, False, False, True, marks=pytest.mark.world_size(4)),
-        pytest.param(4, False, 'adamw', 'amp_bf16', True, None, False, False, True, marks=pytest.mark.world_size(4)),
-        pytest.param(
-            4,
-            False,
-            'adamw',
-            'amp_bf16',
-            False,
-            ['rng'],
-            False,
-            False,
-            True,
-            marks=pytest.mark.world_size(4),
-        ),
-        pytest.param(4, False, 'adamw', 'amp_bf16', False, None, True, False, True, marks=pytest.mark.world_size(4)),
+        pytest.param(False, 'adamw', 'amp_bf16', False, None, True, False, False, marks=pytest.mark.world_size(2)),
+        pytest.param(False, 'adamw', 'amp_bf16', False, None, False, True, False, marks=pytest.mark.world_size(4)),
+        pytest.param(False, 'adamw', 'amp_bf16', False, None, False, False, True, marks=pytest.mark.world_size(4)),
+        pytest.param(True, 'adamw', 'amp_bf16', False, None, False, False, True, marks=pytest.mark.world_size(4)),
+        pytest.param(False, 'adam', 'amp_bf16', False, None, False, False, True, marks=pytest.mark.world_size(4)),
+        pytest.param(False, 'adamw', 'amp_fp16', False, None, False, False, True, marks=pytest.mark.world_size(4)),
+        pytest.param(False, 'adamw', 'amp_bf16', True, None, False, False, True, marks=pytest.mark.world_size(4)),
+        pytest.param(False, 'adamw', 'amp_bf16', False, ['rng'], False, False, True, marks=pytest.mark.world_size(4)),
+        pytest.param(False, 'adamw', 'amp_bf16', False, None, True, False, True, marks=pytest.mark.world_size(4)),
     ],
 )
 @pytest.mark.filterwarnings(r'ignore:TypedStorage is deprecated.:UserWarning')
 @pytest.mark.filterwarnings(r'ignore:.*metrics are not saved with sharded state dict.*:UserWarning')
 @pytest.mark.filterwarnings(r'ignore:Please use DTensor instead and we are deprecating ShardedTensor.:UserWarning')
 def test_fsdp_partitioned_state_dict_load(
-    world_size,
     tmp_path: pathlib.Path,
     autoresume: bool,
     precision: str,
@@ -883,7 +869,7 @@ def test_fsdp_partitioned_state_dict_load(
         sharding_strategy='HYBRID_SHARD',
         sharded_ckpt_prefix_dir='ba{batch}',
         state_dict_type='sharded',
-        data_parallel_shard_degree=world_size // 2,
+        data_parallel_shard_degree=2,
         data_parallel_replicate_degree=2,
         sync_module_states=True,
     )
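
Note: both FSDPConfig hunks make the same substitution. With the parameter gone, world_size // 2 can no longer be computed inside the test, and every HYBRID_SHARD case in this file runs under pytest.mark.world_size(4), so hardcoding a shard degree of 2 is equivalent. A small sketch of the invariant the hardcoded value relies on; the helper is hypothetical, not part of the repository:

import os


def check_hsdp_degrees(shard_degree: int, replicate_degree: int) -> None:
    # HYBRID_SHARD arranges ranks as a (replicate, shard) grid, so the two
    # degrees must multiply out to the number of launched ranks.
    world_size = int(os.environ.get('WORLD_SIZE', '1'))
    assert shard_degree * replicate_degree == world_size, (
        f'{shard_degree} * {replicate_degree} != {world_size}'
    )


# Under the world_size(4) mark: 2 * 2 == 4, matching the old world_size // 2.
check_hsdp_degrees(shard_degree=2, replicate_degree=2)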
