From 64a1c1d1bc4a838eee38cd1620ce4e79f6ea2334 Mon Sep 17 00:00:00 2001
From: Eitan Turok
Date: Mon, 16 Sep 2024 15:22:10 +0000
Subject: [PATCH] update tests

---
 tests/trainer/test_tp.py | 58 +++++++++++++++++++++++-----------------
 1 file changed, 33 insertions(+), 25 deletions(-)

diff --git a/tests/trainer/test_tp.py b/tests/trainer/test_tp.py
index a645633211..8442ca5cb8 100644
--- a/tests/trainer/test_tp.py
+++ b/tests/trainer/test_tp.py
@@ -22,6 +22,7 @@
     world_size,
 )
 
+
 @pytest.mark.gpu
 @world_size(4)
 @pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.3'), reason='requires PyTorch 2.3+')
@@ -149,7 +150,7 @@ def __init__(
 def get_trainer(
     parallelism_config: Optional[ParallelismConfig] = None,
     size: int = 4,
-    batch_size: int = 4,
+    batch_size: int = 1,
     num_classes: int = 2,
     num_features: int = 6,
     seed: int = 42,
@@ -180,13 +181,15 @@ def get_trainer(
         parallelism_config=parallelism_config,
         callbacks=[MemoryMonitor()],
         loggers=[InMemoryLogger()],
+        progress_bar=False,
+        log_to_console=False,
     )
     return trainer
 
 
 def get_ddp_trainer(
     size: int = 4,
-    batch_size: int = 4,
+    batch_size: int = 1,
     num_classes: int = 2,
     num_features: int = 6,
     seed: int = 42,
@@ -205,7 +208,7 @@
 
 def get_fsdp_trainer(
     size: int = 4,
-    batch_size: int = 4,
+    batch_size: int = 1,
     num_classes: int = 2,
     num_features: int = 6,
     seed: int = 42,
@@ -232,7 +235,7 @@
 
 def get_tp_fsdp_trainer(
     size: int = 4,
-    batch_size: int = 4,
+    batch_size: int = 1,
     num_classes: int = 2,
     num_features: int = 6,
     seed: int = 42,
@@ -306,34 +309,38 @@ def test_tp_gradients(world_size: int):
 
     # DDP gradients
     ddp_trainer = get_ddp_trainer()
-    ddp_out = forward_pass(ddp_trainer)
-    torch.sum(ddp_out).backward()
-    ddp_trainer.state.optimizers[0].step()
+    ddp_trainer.fit()
+    # ddp_out = forward_pass(ddp_trainer)
+    # torch.sum(ddp_out).backward()
+    # ddp_trainer.state.optimizers[0].step()
     ddp_trainer.close()
+    ddp_state_dict = ddp_trainer.state.state_dict()
 
-    ic('ddp_trainer')
-    ic(ddp_trainer.state.model.module)
-    for name, param in ddp_trainer.state.model.named_parameters():
-        if param.grad is not None:
-            ic(name, param.grad.shape, param.grad)
+    print('ddp_trainer')
+    # ic(ddp_trainer.state.model.module)
+    # for name, param in ddp_trainer.state.model.named_parameters():
+    #     if param.grad is not None:
+    #         ic(name, param.grad.shape, param.grad)
 
     # FSDP gradients
     fsdp_trainer = get_fsdp_trainer()
-    fsdp_out = forward_pass(fsdp_trainer)
-    torch.sum(fsdp_out).backward()
+    fsdp_trainer.fit()
+    # fsdp_out = forward_pass(fsdp_trainer)
+    # torch.sum(fsdp_out).backward()
     fsdp_trainer.close()
     fsdp_state_dict = fsdp_trainer.state.state_dict()
 
-    ic('fsdp_trainer')
-    ic(fsdp_trainer.state.model.module)
-    for name, param in fsdp_trainer.state.model.named_parameters():
-        if param.grad is not None:
-            ic(name, param.grad.shape, param.grad)
+    print('fsdp_trainer')
+    # ic(fsdp_trainer.state.model.module)
+    # for name, param in fsdp_trainer.state.model.named_parameters():
+    #     if param.grad is not None:
+    #         ic(name, param.grad.shape, param.grad)
 
     # TP-FSDP gradients
     tp_fsdp_trainer = get_tp_fsdp_trainer()
-    tp_fsdp_out = forward_pass(tp_fsdp_trainer)
-    torch.sum(tp_fsdp_out).backward()
+    tp_fsdp_trainer.fit()
+    # tp_fsdp_out = forward_pass(tp_fsdp_trainer)
+    # torch.sum(tp_fsdp_out).backward()
     tp_fsdp_trainer.close()
     tp_fsdp_state_dict = tp_fsdp_trainer.state.state_dict()
 
@@ -355,27 +362,28 @@ def get_stats(trainer: Trainer) -> dict[str, np.ndarray]:
 
 
 @pytest.mark.gpu
 @world_size(4)
+@pytest.mark.parametrize('batch_size', [1, 4])
 @pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.3'), reason='Requires PyTorch 2.3+')
 @pytest.mark.filterwarnings(r'ignore:.*\(TP\) is experimental.*:FutureWarning')
-def test_tp_fit(world_size: int):
+def test_tp_fit(batch_size: int, world_size: int):
     """Test that DDP, FSDP, TP-FSDP have the same trainer.fit(), i.e. output the same loss and accuracy."""
     size = 1024  # make enough data to train for multiple steps
 
     # DDP fit
-    ddp_trainer = get_ddp_trainer(size=size)
+    ddp_trainer = get_ddp_trainer(size=size, batch_size=batch_size)
     ddp_trainer.fit()
    ddp_trainer.close()
     ddp_stats = get_stats(ddp_trainer)
 
     # FSDP fit
-    fsdp_trainer = get_fsdp_trainer(size=size)
+    fsdp_trainer = get_fsdp_trainer(size=size, batch_size=batch_size)
     fsdp_trainer.fit()
     fsdp_trainer.close()
     fsdp_stats = get_stats(fsdp_trainer)
 
     # TP-FSDP fit
-    tp_fsdp_trainer = get_tp_fsdp_trainer(size=size)
+    tp_fsdp_trainer = get_tp_fsdp_trainer(size=size, batch_size=batch_size)
     tp_fsdp_trainer.fit()
     tp_fsdp_trainer.close()
     tp_fsdp_stats = get_stats(tp_fsdp_trainer)
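
Reviewer note (not part of the patch): the test_tp_fit hunk ends before any assertions, so the comparison of ddp_stats, fsdp_stats, and tp_fsdp_stats lives outside this diff. Below is a minimal sketch of how such a comparison could look, assuming only what the hunk header shows (get_stats returns dict[str, np.ndarray]); the helper name, metric keys, and tolerance are illustrative, not taken from the file.

    import numpy as np

    def _assert_stats_close(stats_a: dict[str, np.ndarray], stats_b: dict[str, np.ndarray], atol: float = 1e-5) -> None:
        # Two trainers should log matching per-step statistics (e.g. loss, accuracy arrays).
        assert stats_a.keys() == stats_b.keys()
        for key in stats_a:
            np.testing.assert_allclose(stats_a[key], stats_b[key], atol=atol)

    # e.g. at the end of test_tp_fit, after the three .fit() calls:
    # _assert_stats_close(ddp_stats, fsdp_stats)
    # _assert_stats_close(ddp_stats, tp_fsdp_stats)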