From e65e4fdf75794de33904e08ad76e013c132eb5e9 Mon Sep 17 00:00:00 2001
From: v-chen_data
Date: Tue, 11 Jun 2024 13:51:56 -0700
Subject: [PATCH] revert monolithic ckpt + include sharded ckpt

---
 tests/trainer/test_fsdp_checkpoint.py | 149 ++++++--------------------
 1 file changed, 34 insertions(+), 115 deletions(-)

diff --git a/tests/trainer/test_fsdp_checkpoint.py b/tests/trainer/test_fsdp_checkpoint.py
index 52a5a43c2e..ffa73e42d2 100644
--- a/tests/trainer/test_fsdp_checkpoint.py
+++ b/tests/trainer/test_fsdp_checkpoint.py
@@ -289,20 +289,21 @@ def _compare_timestamps_between_state_dicts(state_dict1, state_dict2):
 @pytest.mark.gpu
 @pytest.mark.filterwarnings(r'ignore:.*scatter_full_optim_state_dict``is being deprecated.*:UserWarning')
 @pytest.mark.parametrize(
-    'optimizer,autoresume,precision,save_weights_only,load_weights_only,load_monolith_rank0_only,use_tp',
+    'world_size,optimizer,autoresume,precision,save_weights_only,load_weights_only,load_monolith_rank0_only,use_tp',
     [
-        pytest.param('adam', False, 'amp_bf16', False, False, False, False, marks=pytest.mark.world_size(2)),
-        pytest.param('adamw', False, 'amp_bf16', False, False, False, False, marks=pytest.mark.world_size(2)),
-        pytest.param('adam', True, 'amp_bf16', False, False, False, False, marks=pytest.mark.world_size(2)),
-        pytest.param('adam', False, 'amp_fp16', False, False, False, False, marks=pytest.mark.world_size(2)),
-        pytest.param('adam', False, 'amp_bf16', True, True, False, False,
+        pytest.param(2, 'adam', False, 'amp_bf16', False, False, False, False, marks=pytest.mark.world_size(2)),
+        pytest.param(2, 'adamw', False, 'amp_bf16', False, False, False, False, marks=pytest.mark.world_size(2)),
+        pytest.param(2, 'adam', True, 'amp_bf16', False, False, False, False, marks=pytest.mark.world_size(2)),
+        pytest.param(2, 'adam', False, 'amp_fp16', False, False, False, False, marks=pytest.mark.world_size(2)),
+        pytest.param(2, 'adam', False, 'amp_bf16', True, True, False, False,
                      marks=pytest.mark.world_size(2)),  # save_weights_only requires load_weights_only
-        pytest.param('adam', False, 'amp_bf16', False, True, False, False, marks=pytest.mark.world_size(2)),
-        pytest.param('adam', False, 'amp_bf16', False, False, True, False, marks=pytest.mark.world_size(2)),
-        pytest.param('adam', False, 'amp_bf16', False, False, False, True, marks=pytest.mark.world_size(4)),
+        pytest.param(2, 'adam', False, 'amp_bf16', False, True, False, False, marks=pytest.mark.world_size(2)),
+        pytest.param(2, 'adam', False, 'amp_bf16', False, False, True, False, marks=pytest.mark.world_size(2)),
+        pytest.param(4, 'adam', False, 'amp_bf16', False, False, False, True, marks=pytest.mark.world_size(4)),
     ],
 )
 def test_fsdp_full_state_dict_load(
+    world_size,
     tmp_path: pathlib.Path,
     autoresume: bool,
     precision: str,
@@ -318,106 +319,12 @@ def test_fsdp_full_state_dict_load(
         run_name = None
     save_folder = tmp_path
     save_filename = 'rank{rank}.pt'
+
     fsdp_config = FSDPConfig(
         sharded_ckpt_prefix_dir='ba{batch}',
         sync_module_states=load_monolith_rank0_only,
         load_monolith_rank0_only=load_monolith_rank0_only,
     )
-
-    tp_config = None
-    if use_tp:
-        from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
-        tp_config = {
-            'tensor_parallel_degree': 2,
-            'layer_plan': {
-                'module.0': ColwiseParallel(),
-                'module.2': RowwiseParallel(),
-            },
-        }
-
-    trainer1 = get_trainer(
-        save_folder=str(save_folder),
-        save_filename=save_filename,
-        run_name=run_name,
-        precision=precision,
-        autoresume=autoresume,
-        optimizer=optimizer,
-        fsdp_config=fsdp_config,
-        tp_config=tp_config,
-    )
-    trainer1.fit()
-    state_dict_from_trainer1 = trainer1.state.state_dict()
-    trainer1.close()
-    load_path = str(save_folder / pathlib.Path('rank{rank}.pt'))
-    trainer2 = get_trainer(
-        save_folder=str(save_folder),
-        save_filename=save_filename,
-        load_path=load_path,
-        run_name=run_name,
-        precision=precision,
-        autoresume=autoresume,
-        max_duration='4ba',
-        optimizer=optimizer,
-        fsdp_config=fsdp_config,
-        save_weights_only=save_weights_only,
-        load_weights_only=load_weights_only,
-        tp_config=tp_config,
-    )
-    state_dict_from_trainer2 = trainer2.state.state_dict()
-
-    if dist.get_global_rank() == 0:
-        _compare_model_params_between_state_dicts(
-            state_dict_from_trainer1,
-            state_dict_from_trainer2,
-        )
-        if not load_weights_only:
-            _compare_optims_between_state_dicts(
-                state_dict_from_trainer1,
-                state_dict_from_trainer2,
-            )
-        _compare_metrics_between_state_dicts(
-            state_dict_from_trainer1,
-            state_dict_from_trainer2,
-        )
-    # Continue to fit to make sure we can continue training.
-    trainer2.fit()
-    trainer2.close()
-
-
-@pytest.mark.gpu
-@pytest.mark.filterwarnings(r'ignore:.*scatter_full_optim_state_dict``is being deprecated.*:UserWarning')
-@pytest.mark.parametrize(
-    'world_size,optimizer,autoresume,precision,save_weights_only,load_weights_only,data_parallel_shard,use_tp',
-    [
-        pytest.param(4, 'adam', False, 'amp_bf16', False, False, 2, False, marks=pytest.mark.world_size(4)),
-    ],
-)
-def test_fsdp_full_state_dict_load_with_hsdp(
-    world_size: int,
-    tmp_path: pathlib.Path,
-    autoresume: bool,
-    precision: str,
-    optimizer: str,
-    save_weights_only: bool,
-    load_weights_only: bool,
-    data_parallel_shard: int,
-    use_tp: bool,
-):
-    if autoresume:
-        run_name = 'my-cool-autoresume-run'
-    else:
-        run_name = None
-    save_folder = tmp_path
-    save_filename = 'rank{rank}.pt'
-
-    data_parallel_replicate_degree = world_size // data_parallel_shard
-    fsdp_config = FSDPConfig(
-        sharded_ckpt_prefix_dir='ba{batch}',
-        sharding_strategy='HYBRID_SHARD',
-        data_parallel_shard_degree=data_parallel_shard,
-        data_parallel_replicate_degree=data_parallel_replicate_degree,
-    )
-
     tp_config = None
     if use_tp:
         from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
@@ -871,16 +778,18 @@ def mock_get_checkpoint_validation_function():
 @pytest.mark.gpu
 @pytest.mark.parametrize('use_remote', [pytest.param(True, marks=pytest.mark.remote), False])
 @pytest.mark.parametrize(
-    'world_size,weights_only,optimizer,precision,autoresume,load_ignore_keys,use_symlink,use_tp',
+    'world_size,weights_only,optimizer,precision,autoresume,load_ignore_keys,use_symlink,use_tp,use_hsdp,data_parallel_shard_degree',
     [
-        pytest.param(2, False, 'adamw', 'amp_bf16', False, None, False, False, marks=pytest.mark.world_size(2)),
-        pytest.param(2, True, 'adamw', 'amp_bf16', False, None, False, False, marks=pytest.mark.world_size(2)),
-        pytest.param(2, False, 'adam', 'amp_bf16', False, None, False, False, marks=pytest.mark.world_size(2)),
-        pytest.param(2, False, 'adamw', 'amp_fp16', False, None, False, False, marks=pytest.mark.world_size(2)),
-        pytest.param(2, False, 'adamw', 'amp_bf16', True, None, False, False, marks=pytest.mark.world_size(2)),
-        pytest.param(2, False, 'adamw', 'amp_bf16', False, ['rng'], False, False, marks=pytest.mark.world_size(2)),
-        pytest.param(2, False, 'adamw', 'amp_bf16', False, None, True, False, marks=pytest.mark.world_size(2)),
-        pytest.param(2, False, 'adamw', 'amp_bf16', False, None, False, True, marks=pytest.mark.world_size(4)),
+        pytest.param(2, False, 'adamw', 'amp_bf16', False, None, False, False, False, -1, marks=pytest.mark.world_size(2)),
+        pytest.param(2, True, 'adamw', 'amp_bf16', False, None, False, False, False, -1, marks=pytest.mark.world_size(2)),
+        pytest.param(2, False, 'adam', 'amp_bf16', False, None, False, False, False, -1, marks=pytest.mark.world_size(2)),
+        pytest.param(2, False, 'adamw', 'amp_fp16', False, None, False, False, False, -1, marks=pytest.mark.world_size(2)),
+        pytest.param(2, False, 'adamw', 'amp_bf16', True, None, False, False, False, -1, marks=pytest.mark.world_size(2)),
+        pytest.param(2, False, 'adamw', 'amp_bf16', False, ['rng'], False, False, False, -1, marks=pytest.mark.world_size(2)),
+        pytest.param(2, False, 'adamw', 'amp_bf16', False, None, True, False, False, -1, marks=pytest.mark.world_size(2)),
+        pytest.param(4, False, 'adamw', 'amp_bf16', False, None, False, True, False, -1, marks=pytest.mark.world_size(4)),
+        pytest.param(4, False, 'adamw', 'amp_bf16', False, None, False, True, True, 4, marks=pytest.mark.world_size(4)),
+        pytest.param(4, False, 'adamw', 'amp_bf16', False, None, False, True, True, 2, marks=pytest.mark.world_size(4)),
     ],
 )
 @pytest.mark.filterwarnings(r'ignore:TypedStorage is deprecated.:UserWarning')
@@ -896,11 +805,14 @@ def test_fsdp_partitioned_state_dict_load(
     load_ignore_keys: Union[list[str], None],
     use_symlink: bool,
     use_tp: bool,
+    use_hsdp: bool,
+    data_parallel_shard_degree: int,
     use_remote,
     s3_bucket,
     s3_ephemeral_prefix,
     request,
 ):
+    # data_parallel_shard_degree is only used when use_hsdp is True
     if weights_only and autoresume:
         pytest.skip('Weights only with autoresume is not supported')
     if use_tp and version.parse(torch.__version__) < version.parse('2.3.0'):
@@ -922,10 +834,17 @@ def test_fsdp_partitioned_state_dict_load(
 
     save_filename = 'ba{batch}-rank{rank}.pt'
 
-    fsdp_config = FSDPConfig(state_dict_type='sharded', sharded_ckpt_prefix_dir='ba{batch}')
+    if use_hsdp:
+        fsdp_config = FSDPConfig(
+            sharding_strategy='HYBRID_SHARD',
+            sharded_ckpt_prefix_dir='ba{batch}',
+            data_parallel_shard_degree=data_parallel_shard_degree,
+            data_parallel_replicate_degree=world_size // data_parallel_shard_degree,
+        )
+    else:
+        fsdp_config = FSDPConfig(state_dict_type='sharded', sharded_ckpt_prefix_dir='ba{batch}')
     tp_config = None
     if use_tp:
-        fsdp_config = FSDPConfig(state_dict_type='sharded', sharded_ckpt_prefix_dir='ba{batch}')
         from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
         tp_config = {
             'tensor_parallel_degree': 2,
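
Editor's note on the HSDP branch added above: with sharding_strategy='HYBRID_SHARD', the shard degree and the replicate degree together must cover the full world size, which is why the patch derives data_parallel_replicate_degree as world_size // data_parallel_shard_degree. The sketch below is illustrative only and not part of the patch; the helper name is hypothetical and simply mirrors that arithmetic for the two new HSDP parametrizations.

# Illustrative sketch (not part of the patch): derive the HYBRID_SHARD replicate
# degree from the world size and the requested shard degree, mirroring the
# expression used in test_fsdp_partitioned_state_dict_load above.
def hsdp_replicate_degree(world_size: int, data_parallel_shard_degree: int) -> int:
    # The shard degree must evenly divide the world size so every rank belongs
    # to exactly one shard group.
    if world_size % data_parallel_shard_degree != 0:
        raise ValueError('data_parallel_shard_degree must evenly divide world_size')
    return world_size // data_parallel_shard_degree


# The two HSDP cases added to the parametrization:
assert hsdp_replicate_degree(4, 4) == 1  # shard across all 4 ranks, no replication
assert hsdp_replicate_degree(4, 2) == 2  # shard within groups of 2, replicate across 2 groups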
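
A second illustrative note: the new rows pass world_size both as a parametrized argument (so the test body can compute the replicate degree) and through the world_size mark (so the launcher starts the right number of ranks), and they use -1 as a placeholder shard degree for non-HSDP rows. Below is a minimal standalone sketch of that pattern, assuming a custom world_size marker like the one registered in Composer's test suite; it is not the Composer test harness itself.

# Minimal standalone sketch, assuming a registered world_size marker (otherwise
# pytest only emits an unknown-mark warning). The mark documents how many ranks
# to launch; the parametrized value is what the test body actually reads.
import pytest


@pytest.mark.parametrize(
    'world_size,use_hsdp,data_parallel_shard_degree',
    [
        pytest.param(2, False, -1, marks=pytest.mark.world_size(2)),  # plain sharded checkpoint
        pytest.param(4, True, 4, marks=pytest.mark.world_size(4)),  # HSDP, no replication
        pytest.param(4, True, 2, marks=pytest.mark.world_size(4)),  # HSDP, 2-way replication
    ],
)
def test_degree_arithmetic(world_size: int, use_hsdp: bool, data_parallel_shard_degree: int):
    if not use_hsdp:
        # Sentinel value: the shard degree is ignored without HSDP.
        assert data_parallel_shard_degree == -1
        return
    # Shard and replicate degrees must multiply back to the world size.
    assert world_size % data_parallel_shard_degree == 0
    replicate_degree = world_size // data_parallel_shard_degree
    assert data_parallel_shard_degree * replicate_degree == world_size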