From 501950f8f9af108582de4431082db927b06ec765 Mon Sep 17 00:00:00 2001
From: Eitan Turok
Date: Fri, 6 Sep 2024 16:45:48 +0000
Subject: [PATCH] simpler test case

---
Review notes: line breaks and indentation were restored from a flattened
copy, so blank context lines are a best-effort reconstruction. test_2 now
closes its trainer, drops duplicated/unused locals, and gains a docstring;
hunk sizes and the diffstat were recomputed, while the blob ids on the
index line still describe the originally committed revision.

 TEST.py | 104 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++----
 1 file changed, 98 insertions(+), 6 deletions(-)

diff --git a/TEST.py b/TEST.py
index 9fa9d76fea..0bbc094c3b 100644
--- a/TEST.py
+++ b/TEST.py
@@ -39,8 +39,8 @@ from tests.common.markers import world_size
 from tests.trainer.test_checkpoint import TestCheckpointResumption, _assert_checkpoints_equivalent
 
 
-
-from tests.trainer.test_fsdp_checkpoint import _compare_model_params_between_state_dicts, _compare_optims_between_state_dicts, _compare_metrics_between_state_dicts, get_trainer
+from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
+from tests.trainer.test_fsdp_checkpoint import _compare_model_params_between_state_dicts, _compare_optims_between_state_dicts, _compare_metrics_between_state_dicts
 
 from icecream import install
 from icecream import ic
@@ -48,6 +48,7 @@ ic.configureOutput(includeContext=True)
 
 
 def test_1(use_tp: bool):
+    from tests.trainer.test_fsdp_checkpoint import get_trainer
 
     tmp_path: pathlib.Path = 'tmp'
     autoresume: bool = False
@@ -151,11 +152,102 @@ def test_1(use_tp: bool):
     trainer2.close()
 
 
+def test_2(use_tp: bool):
+    """Fit a small ``SimpleMLP`` under an FSDP config, optionally adding tensor parallelism.
+
+    Args:
+        use_tp: If ``True``, a tensor-parallel config is passed to the ``Trainer`` as
+            well, and ``trainer1.state.tp_config`` is asserted to be a ``TPConfig``.
+    """
+    from tests.trainer.test_fsdp_checkpoint import SimpleMLP
+
+    tmp_path: str = 'tmp'
+    autoresume: bool = False
+    precision: str = 'amp_bf16'
+    save_weights_only: bool = False
+    load_weights_only: bool = False
+
+    run_name = None
+    save_folder = tmp_path
+    save_filename = 'rank{rank}.pt'
+
+    fsdp_config = FSDPConfig(sharded_ckpt_prefix_dir='ba{batch}')
+    tp_config = None
+    if use_tp:
+        tp_config = {
+            'tensor_parallel_degree': 2,
+            'layer_plan': {'module.0': ColwiseParallel(), 'module.2': RowwiseParallel()},
+        }
+
+    model_init_device: str = 'cpu'
+    save_overwrite: bool = False
+    num_features: int = 4
+    num_classes: int = 2
+    load_path: Optional[str] = None
+    max_duration: Optional[int | str | Time] = '2ba'
+    save_interval: str | int | Time | Callable[[State, Event], bool] = '2ba'
+    load_ignore_keys: Optional[list[str] | Callable[[dict], None]] = None
+    algorithms: Optional[Algorithm | Sequence[Algorithm]] = None
+    save_num_checkpoints_to_keep: int = -1
+    train_metrics: Optional[Any] = None
+    val_metrics: Optional[Any] = None
+
+    model = SimpleMLP(num_features=num_features, num_classes=num_classes, train_metrics=train_metrics, val_metrics=val_metrics)
+    model.module.to(model_init_device)
+    dataset = RandomClassificationDataset(shape=(num_features,), num_classes=num_classes, size=128)
+    dataloader = DataLoader(dataset, sampler=dist.get_sampler(dataset), batch_size=8)
+    optim = torch.optim.Adam(params=model.parameters())
+
+    parallelism_config: dict[str, Union[FSDPConfig, dict[str, Any]]] = {'fsdp': fsdp_config}
+    if tp_config is not None:
+        parallelism_config['tp'] = tp_config
+
+    trainer1 = Trainer(
+        algorithms=algorithms,
+        model=model,
+        optimizers=optim,
+        train_dataloader=dataloader,
+        parallelism_config=parallelism_config,
+        save_folder=str(save_folder),
+        max_duration=max_duration,
+        save_interval=save_interval,
+        save_filename=save_filename,
+        save_overwrite=save_overwrite,
+        precision=precision,
+        load_path=load_path,
+        progress_bar=False,
+        log_to_console=False,
+        autoresume=autoresume,
+        run_name=run_name,
+        save_latest_filename='latest-rank{rank}.pt',
+        save_weights_only=save_weights_only,
+        load_weights_only=load_weights_only,
+        save_num_checkpoints_to_keep=save_num_checkpoints_to_keep,
+        load_ignore_keys=load_ignore_keys,
+    )
+
+    # Release the trainer even when an assertion or fit() raises; test_1
+    # closes its trainer the same way (trainer2.close()).
+    try:
+        if use_tp:
+            assert trainer1.state.tp_config is not None
+            assert isinstance(trainer1.state.tp_config, TPConfig)
+
+        ic('Before trainer 1 fit')
+        print('Before trainer 1 fit')
+        trainer1.fit()
+        print('After trainer 1 fit')
+    finally:
+        trainer1.close()
+
+
 if __name__ == '__main__':
-    # print('*'*70, '\nuse_tp=False\n', '*'*70)
-    # test_1(use_tp=False)
-    # print('*'*70, '\nDone\n', '*'*70)
+    test = test_2
+
+    print('*'*70, '\nuse_tp=False\n', '*'*70)
+    test(use_tp=False)
+    print('*'*70, '\nDone\n', '*'*70)
 
     print('*'*70, '\nuse_tp=True\n', '*'*70)
-    test_1(use_tp=True)
+    test(use_tp=True)
     print('*'*70, '\nDone\n', '*'*70)