Commit 501950f

simpler test case

ez2rok committed Sep 6, 2024
1 parent ca646e7 commit 501950f
Showing 1 changed file with 89 additions and 6 deletions.

TEST.py (95 changes: 89 additions & 6 deletions)
@@ -39,15 +39,16 @@
from tests.common.markers import world_size
from tests.trainer.test_checkpoint import TestCheckpointResumption, _assert_checkpoints_equivalent


from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
from tests.trainer.test_fsdp_checkpoint import (
    _compare_model_params_between_state_dicts,
    _compare_optims_between_state_dicts,
    _compare_metrics_between_state_dicts,
)
from icecream import ic, install

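# icecream debugging: install() makes ic() a builtin everywhere, and
# includeContext=True prefixes each ic() line with file/line/function info.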
install()
ic.configureOutput(includeContext=True)

def test_1(use_tp: bool):
    from tests.trainer.test_fsdp_checkpoint import get_trainer

    tmp_path: pathlib.Path = pathlib.Path('tmp')
    autoresume: bool = False
@@ -151,11 +152,93 @@ def test_1(use_tp: bool):
    trainer2.close()


def test_2(use_tp: bool):
    from tests.trainer.test_fsdp_checkpoint import SimpleMLP

    tmp_path: pathlib.Path = pathlib.Path('tmp')
    autoresume: bool = False
    precision: str = 'amp_bf16'
    optimizer: str = 'adam'
    save_weights_only: bool = False
    load_weights_only: bool = False

    run_name = None
    save_folder = tmp_path
    save_filename = 'rank{rank}.pt'

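    # 'ba{batch}' uses Composer's checkpoint-name templating, so the sharded
    # checkpoint saved at batch N lands under a 'baN' prefix directory.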
    fsdp_config = FSDPConfig(sharded_ckpt_prefix_dir='ba{batch}')
    tp_config = None
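    # Megatron-style pairing: shard the first Linear column-wise and the second
    # row-wise, so the hidden activation stays sharded between them. (Assumption:
    # SimpleMLP keeps its layers in an nn.Sequential attribute named `module`,
    # making 'module.0' and 'module.2' the two Linear layers around an activation.)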
    if use_tp:
        tp_config = {
            'tensor_parallel_degree': 2,
            'layer_plan': {'module.0': ColwiseParallel(), 'module.2': RowwiseParallel()},
        }

    model_init_device: str = 'cpu'
    save_overwrite: bool = False
    num_features: int = 4
    num_classes: int = 2
    load_path: Optional[str] = None
    max_duration: Optional[int | str | Time] = '2ba'
    save_interval: str | int | Time | Callable[[State, Event], bool] = '2ba'
    load_ignore_keys: Optional[list[str] | Callable[[dict], None]] = None
    algorithms: Optional[Algorithm | Sequence[Algorithm]] = None
    save_num_checkpoints_to_keep: int = -1
    train_metrics: Optional[Any] = None
    val_metrics: Optional[Any] = None

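    # Tiny 4-feature / 2-class MLP on random data; metrics are left off to keep
    # the repro minimal.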
    model = SimpleMLP(num_features=num_features, num_classes=num_classes, train_metrics=train_metrics, val_metrics=val_metrics)
    model.module.to(model_init_device)
    dataset = RandomClassificationDataset(shape=(num_features,), num_classes=num_classes, size=128)
    dataloader = DataLoader(dataset, sampler=dist.get_sampler(dataset), batch_size=8)
    optim = torch.optim.Adam(params=model.parameters())

    parallelism_config: dict[str, Union[FSDPConfig, dict[str, Any]]] = {'fsdp': fsdp_config}
    if tp_config is not None:
        parallelism_config['tp'] = tp_config

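    # Train for 2 batches ('2ba') and save a checkpoint at the 2-batch mark.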
    trainer1 = Trainer(
        algorithms=algorithms,
        model=model,
        optimizers=optim,
        train_dataloader=dataloader,
        parallelism_config=parallelism_config,
        save_folder=str(save_folder),
        max_duration=max_duration,
        save_interval=save_interval,
        save_filename=save_filename,
        save_overwrite=save_overwrite,
        precision=precision,
        load_path=load_path,
        progress_bar=False,
        log_to_console=False,
        autoresume=autoresume,
        run_name=run_name,
        save_latest_filename='latest-rank{rank}.pt',
        save_weights_only=save_weights_only,
        load_weights_only=load_weights_only,
        save_num_checkpoints_to_keep=save_num_checkpoints_to_keep,
        load_ignore_keys=load_ignore_keys,
    )

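    # When TP is requested, the dict above should have been parsed into a
    # TPConfig on the trainer state.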
    if use_tp:
        assert trainer1.state.tp_config is not None
        assert isinstance(trainer1.state.tp_config, TPConfig)

    ic('Before trainer 1 fit')
    print('Before trainer 1 fit')
    trainer1.fit()
    print('After trainer 1 fit')

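# NOTE (assumption about how this script is launched): tensor_parallel_degree=2
# needs two ranks, e.g. `composer -n 2 TEST.py` or `torchrun --nproc_per_node=2 TEST.py`.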
if __name__ == '__main__':
    # print('*' * 70, '\nuse_tp=False\n', '*' * 70)
    # test_1(use_tp=False)
    # print('*' * 70, '\nDone\n', '*' * 70)
    test = test_2

    print('*' * 70, '\nuse_tp=False\n', '*' * 70)
    test(use_tp=False)
    print('*' * 70, '\nDone\n', '*' * 70)

    print('*' * 70, '\nuse_tp=True\n', '*' * 70)
    test(use_tp=True)
    print('*' * 70, '\nDone\n', '*' * 70)
