
revert monolithic cpkt + include sharded cpkt
v-chen_data committed Jun 11, 2024
1 parent 78f00f1 commit e65e4fd
Showing 1 changed file with 34 additions and 115 deletions.
149 changes: 34 additions & 115 deletions tests/trainer/test_fsdp_checkpoint.py
@@ -289,20 +289,21 @@ def _compare_timestamps_between_state_dicts(state_dict1, state_dict2):
@pytest.mark.gpu
@pytest.mark.filterwarnings(r'ignore:.*scatter_full_optim_state_dict``is being deprecated.*:UserWarning')
@pytest.mark.parametrize(
'optimizer,autoresume,precision,save_weights_only,load_weights_only,load_monolith_rank0_only,use_tp',
'world_size,optimizer,autoresume,precision,save_weights_only,load_weights_only,load_monolith_rank0_only,use_tp',
[
pytest.param('adam', False, 'amp_bf16', False, False, False, False, marks=pytest.mark.world_size(2)),
pytest.param('adamw', False, 'amp_bf16', False, False, False, False, marks=pytest.mark.world_size(2)),
pytest.param('adam', True, 'amp_bf16', False, False, False, False, marks=pytest.mark.world_size(2)),
pytest.param('adam', False, 'amp_fp16', False, False, False, False, marks=pytest.mark.world_size(2)),
pytest.param('adam', False, 'amp_bf16', True, True, False, False,
pytest.param(2, 'adam', False, 'amp_bf16', False, False, False, False, marks=pytest.mark.world_size(2)),
pytest.param(2, 'adamw', False, 'amp_bf16', False, False, False, False, marks=pytest.mark.world_size(2)),
pytest.param(2, 'adam', True, 'amp_bf16', False, False, False, False, marks=pytest.mark.world_size(2)),
pytest.param(2, 'adam', False, 'amp_fp16', False, False, False, False, marks=pytest.mark.world_size(2)),
pytest.param(2, 'adam', False, 'amp_bf16', True, True, False, False,
marks=pytest.mark.world_size(2)), # save_weights_only requires load_weights_only
pytest.param('adam', False, 'amp_bf16', False, True, False, False, marks=pytest.mark.world_size(2)),
pytest.param('adam', False, 'amp_bf16', False, False, True, False, marks=pytest.mark.world_size(2)),
pytest.param('adam', False, 'amp_bf16', False, False, False, True, marks=pytest.mark.world_size(4)),
pytest.param(2, 'adam', False, 'amp_bf16', False, True, False, False, marks=pytest.mark.world_size(2)),
pytest.param(2, 'adam', False, 'amp_bf16', False, False, True, False, marks=pytest.mark.world_size(2)),
pytest.param(4, 'adam', False, 'amp_bf16', False, False, False, True, marks=pytest.mark.world_size(4)),
],
)
def test_fsdp_full_state_dict_load(
world_size,
tmp_path: pathlib.Path,
autoresume: bool,
precision: str,
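
For context, the reworked parametrization above passes world_size both as an explicit argument and via the pytest.mark.world_size marker, so the test body can read the rank count directly. A minimal, self-contained sketch of that pattern (the world_size marker is Composer's custom pytest marker; the test name and body here are illustrative, not the file's code):

import pytest

@pytest.mark.parametrize(
    'world_size,use_tp',
    [
        # world_size appears twice: once as a value the test body can read,
        # and once as a mark the distributed test launcher uses to pick the rank count.
        pytest.param(2, False, marks=pytest.mark.world_size(2)),
        pytest.param(4, True, marks=pytest.mark.world_size(4)),
    ],
)
def test_world_size_plumbing(world_size: int, use_tp: bool):
    # With the value available as an argument, the body can derive parallelism degrees,
    # e.g. a tensor-parallel degree of 2 leaves world_size // 2 ranks for data parallelism.
    dp_degree = world_size // (2 if use_tp else 1)
    assert dp_degree >= 1
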
@@ -318,106 +319,12 @@ def test_fsdp_full_state_dict_load(
run_name = None
save_folder = tmp_path
save_filename = 'rank{rank}.pt'

fsdp_config = FSDPConfig(
sharded_ckpt_prefix_dir='ba{batch}',
sync_module_states=load_monolith_rank0_only,
load_monolith_rank0_only=load_monolith_rank0_only,
)

tp_config = None
if use_tp:
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
tp_config = {
'tensor_parallel_degree': 2,
'layer_plan': {
'module.0': ColwiseParallel(),
'module.2': RowwiseParallel(),
},
}

trainer1 = get_trainer(
save_folder=str(save_folder),
save_filename=save_filename,
run_name=run_name,
precision=precision,
autoresume=autoresume,
optimizer=optimizer,
fsdp_config=fsdp_config,
tp_config=tp_config,
)
trainer1.fit()
state_dict_from_trainer1 = trainer1.state.state_dict()
trainer1.close()
load_path = str(save_folder / pathlib.Path('rank{rank}.pt'))
trainer2 = get_trainer(
save_folder=str(save_folder),
save_filename=save_filename,
load_path=load_path,
run_name=run_name,
precision=precision,
autoresume=autoresume,
max_duration='4ba',
optimizer=optimizer,
fsdp_config=fsdp_config,
save_weights_only=save_weights_only,
load_weights_only=load_weights_only,
tp_config=tp_config,
)
state_dict_from_trainer2 = trainer2.state.state_dict()

if dist.get_global_rank() == 0:
_compare_model_params_between_state_dicts(
state_dict_from_trainer1,
state_dict_from_trainer2,
)
if not load_weights_only:
_compare_optims_between_state_dicts(
state_dict_from_trainer1,
state_dict_from_trainer2,
)
_compare_metrics_between_state_dicts(
state_dict_from_trainer1,
state_dict_from_trainer2,
)
# Continue to fit to make sure we can continue training.
trainer2.fit()
trainer2.close()


@pytest.mark.gpu
@pytest.mark.filterwarnings(r'ignore:.*scatter_full_optim_state_dict``is being deprecated.*:UserWarning')
@pytest.mark.parametrize(
'world_size,optimizer,autoresume,precision,save_weights_only,load_weights_only,data_parallel_shard,use_tp',
[
pytest.param(4, 'adam', False, 'amp_bf16', False, False, 2, False, marks=pytest.mark.world_size(4)),
],
)
def test_fsdp_full_state_dict_load_with_hsdp(
world_size: int,
tmp_path: pathlib.Path,
autoresume: bool,
precision: str,
optimizer: str,
save_weights_only: bool,
load_weights_only: bool,
data_parallel_shard: int,
use_tp: bool,
):
if autoresume:
run_name = 'my-cool-autoresume-run'
else:
run_name = None
save_folder = tmp_path
save_filename = 'rank{rank}.pt'

data_parallel_replicate_degree = world_size // data_parallel_shard
fsdp_config = FSDPConfig(
sharded_ckpt_prefix_dir='ba{batch}',
sharding_strategy='HYBRID_SHARD',
data_parallel_shard_degree=data_parallel_shard,
data_parallel_replicate_degree=data_parallel_replicate_degree,
)

tp_config = None
if use_tp:
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
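
The hunk above covers the monolithic HSDP test being reverted. Its HYBRID_SHARD config arranges the ranks into data_parallel_replicate_degree groups of data_parallel_shard_degree ranks each, so the two degrees must multiply to the world size (hence the world_size // data_parallel_shard line). A minimal sketch of that layout as a 2D device mesh, assuming torch >= 2.2, a 4-GPU torchrun launch, and a hypothetical standalone script; Composer derives an equivalent layout internally from FSDPConfig, this is not Composer code:

# Hypothetical standalone script; launch with: torchrun --nproc_per_node=4 hsdp_mesh_sketch.py
import os
import torch
from torch.distributed.device_mesh import init_device_mesh

torch.cuda.set_device(int(os.environ['LOCAL_RANK']))

world_size = 4                      # the removed test ran under pytest.mark.world_size(4)
data_parallel_shard = 2             # shard degree from the parametrization above
replicate = world_size // data_parallel_shard  # -> 2 replicas

# One mesh dimension holds full replicas of the parameters, the other shards each replica;
# HYBRID_SHARD all-gathers within a shard group and all-reduces gradients across replicas.
mesh = init_device_mesh(
    'cuda',
    (replicate, data_parallel_shard),
    mesh_dim_names=('replicate', 'shard'),
)
print(mesh)
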
@@ -871,16 +778,18 @@ def mock_get_checkpoint_validation_function():
@pytest.mark.gpu
@pytest.mark.parametrize('use_remote', [pytest.param(True, marks=pytest.mark.remote), False])
@pytest.mark.parametrize(
'world_size,weights_only,optimizer,precision,autoresume,load_ignore_keys,use_symlink,use_tp',
'world_size,weights_only,optimizer,precision,autoresume,load_ignore_keys,use_symlink,use_tp,use_hsdp,data_parallel_shard_degree',
[
pytest.param(2, False, 'adamw', 'amp_bf16', False, None, False, False, marks=pytest.mark.world_size(2)),
pytest.param(2, True, 'adamw', 'amp_bf16', False, None, False, False, marks=pytest.mark.world_size(2)),
pytest.param(2, False, 'adam', 'amp_bf16', False, None, False, False, marks=pytest.mark.world_size(2)),
pytest.param(2, False, 'adamw', 'amp_fp16', False, None, False, False, marks=pytest.mark.world_size(2)),
pytest.param(2, False, 'adamw', 'amp_bf16', True, None, False, False, marks=pytest.mark.world_size(2)),
pytest.param(2, False, 'adamw', 'amp_bf16', False, ['rng'], False, False, marks=pytest.mark.world_size(2)),
pytest.param(2, False, 'adamw', 'amp_bf16', False, None, True, False, marks=pytest.mark.world_size(2)),
pytest.param(2, False, 'adamw', 'amp_bf16', False, None, False, True, marks=pytest.mark.world_size(4)),
pytest.param(2, False, 'adamw', 'amp_bf16', False, None, False, False, False, -1, marks=pytest.mark.world_size(2)),
pytest.param(2, True, 'adamw', 'amp_bf16', False, None, False, False, False, -1, marks=pytest.mark.world_size(2)),
pytest.param(2, False, 'adam', 'amp_bf16', False, None, False, False, False, -1, marks=pytest.mark.world_size(2)),
pytest.param(2, False, 'adamw', 'amp_fp16', False, None, False, False, False, -1, marks=pytest.mark.world_size(2)),
pytest.param(2, False, 'adamw', 'amp_bf16', True, None, False, False, False, -1, marks=pytest.mark.world_size(2)),
pytest.param(2, False, 'adamw', 'amp_bf16', False, ['rng'], False, False, False, -1, marks=pytest.mark.world_size(2)),
pytest.param(2, False, 'adamw', 'amp_bf16', False, None, True, False, False, -1, marks=pytest.mark.world_size(2)),
pytest.param(4, False, 'adamw', 'amp_bf16', False, None, False, True, False, -1, marks=pytest.mark.world_size(4)),
pytest.param(4, False, 'adamw', 'amp_bf16', False, None, False, True, True, 4, marks=pytest.mark.world_size(4)),
pytest.param(4, False, 'adamw', 'amp_bf16', False, None, False, True, True, 2, marks=pytest.mark.world_size(4)),
],
)
@pytest.mark.filterwarnings(r'ignore:TypedStorage is deprecated.:UserWarning')
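
The two use_hsdp rows added above run the sharded-checkpoint test on 4 ranks with data_parallel_shard_degree 4 and 2; in the remaining rows the -1 is a placeholder that the test ignores when use_hsdp is False (see the comment added in the test body further down). Purely illustrative arithmetic for what those two rows imply:

# (world_size, data_parallel_shard_degree) for the two use_hsdp=True rows above.
hsdp_rows = [(4, 4), (4, 2)]
for world_size, shard_degree in hsdp_rows:
    replicate_degree = world_size // shard_degree
    print(f'world_size={world_size}: shard={shard_degree}, replicate={replicate_degree}')
# world_size=4: shard=4, replicate=1  -> a single replica sharded over all ranks
# world_size=4: shard=2, replicate=2  -> two replicas, each sharded over two ranks
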
@@ -896,11 +805,14 @@ def test_fsdp_partitioned_state_dict_load(
load_ignore_keys: Union[list[str], None],
use_symlink: bool,
use_tp: bool,
use_hsdp: bool,
data_parallel_shard_degree: int,
use_remote,
s3_bucket,
s3_ephemeral_prefix,
request,
):
# data_parallel_shard_degree will only be used if use_hsdp
if weights_only and autoresume:
pytest.skip('Weights only with autoresume is not supported')
if use_tp and version.parse(torch.__version__) < version.parse('2.3.0'):
@@ -922,10 +834,17 @@ def test_fsdp_partitioned_state_dict_load(

save_filename = 'ba{batch}-rank{rank}.pt'

fsdp_config = FSDPConfig(state_dict_type='sharded', sharded_ckpt_prefix_dir='ba{batch}')
if use_hsdp:
fsdp_config = FSDPConfig(
sharding_strategy='HYBRID_SHARD',
sharded_ckpt_prefix_dir='ba{batch}',
data_parallel_shard_degree=data_parallel_shard_degree,
data_parallel_replicate_degree=world_size//data_parallel_shard_degree,
)
else:
fsdp_config = FSDPConfig(state_dict_type='sharded', sharded_ckpt_prefix_dir='ba{batch}')
tp_config = None
if use_tp:
fsdp_config = FSDPConfig(state_dict_type='sharded', sharded_ckpt_prefix_dir='ba{batch}')
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
tp_config = {
'tensor_parallel_degree': 2,
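
For reference, the layer_plan assembled here has the same shape as the parallelize_plan accepted by torch.distributed.tensor.parallel.parallelize_module (the test above gates TP on torch >= 2.3). A minimal sketch of applying such a plan, assuming a plain nn.Sequential stand-in for the test model, a 2-rank torchrun launch, and a hypothetical standalone script, not the test's actual wiring:

# Hypothetical standalone script; launch with: torchrun --nproc_per_node=2 tp_plan_sketch.py
import os
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module

torch.cuda.set_device(int(os.environ['LOCAL_RANK']))
tp_mesh = init_device_mesh('cuda', (2,))  # tensor_parallel_degree = 2, as above

# Stand-in for the test model; '0' and '2' play the role of 'module.0' and 'module.2'.
model = nn.Sequential(nn.Linear(8, 32), nn.ReLU(), nn.Linear(32, 8)).cuda()
model = parallelize_module(
    model,
    tp_mesh,
    {
        '0': ColwiseParallel(),  # first Linear: shard output features across TP ranks
        '2': RowwiseParallel(),  # last Linear: shard input features, all-reduce the output
    },
)

# Forward pass works as usual; each rank holds only its shard of the two Linear weights.
out = model(torch.randn(4, 8, device='cuda'))
print(out.shape)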