diff --git a/tests/trainer/test_tp.py b/tests/trainer/test_tp.py
index 6b95b40a0c..a645633211 100644
--- a/tests/trainer/test_tp.py
+++ b/tests/trainer/test_tp.py
@@ -21,8 +21,6 @@
     SimpleModel,
     world_size,
 )
-from icecream import install
-install()
 
 @pytest.mark.gpu
 @world_size(4)
@@ -304,20 +302,20 @@ def test_tp_forward(world_size: int):
 @pytest.mark.filterwarnings(r'ignore:.*\(TP\) is experimental.*:FutureWarning')
 def test_tp_gradients(world_size: int):
     """Test that DDP, FSDP, TP-FSDP output the same gradients."""
+    from icecream import ic
 
     # DDP gradients
     ddp_trainer = get_ddp_trainer()
     ddp_out = forward_pass(ddp_trainer)
     torch.sum(ddp_out).backward()
-    # ddp_trainer.state.optimizers[0].step()
+    ddp_trainer.state.optimizers[0].step()
     ddp_trainer.close()
-    ddp_state_dict = ddp_trainer.state.state_dict()
 
     ic('ddp_trainer')
-    # ic(ddp_trainer.state.model.module)
-    # for name, param in ddp_trainer.state.model.named_parameters():
-    #     if param.grad is not None:
-    #         ic(name, param.grad.shape, param.grad)
+    ic(ddp_trainer.state.model.module)
+    for name, param in ddp_trainer.state.model.named_parameters():
+        if param.grad is not None:
+            ic(name, param.grad.shape, param.grad)
 
     # FSDP gradients
     fsdp_trainer = get_fsdp_trainer()
@@ -327,10 +325,10 @@ def test_tp_gradients(world_size: int):
     fsdp_state_dict = fsdp_trainer.state.state_dict()
 
     ic('fsdp_trainer')
-    # ic(fsdp_trainer.state.model.module)
-    # for name, param in fsdp_trainer.state.model.named_parameters():
-    #     if param.grad is not None:
-    #         ic(name, param.grad.shape, param.grad)
+    ic(fsdp_trainer.state.model.module)
+    for name, param in fsdp_trainer.state.model.named_parameters():
+        if param.grad is not None:
+            ic(name, param.grad.shape, param.grad)
 
     # TP-FSDP gradients
     tp_fsdp_trainer = get_tp_fsdp_trainer()
@@ -339,7 +337,7 @@ def test_tp_gradients(world_size: int):
     tp_fsdp_trainer.close()
     tp_fsdp_state_dict = tp_fsdp_trainer.state.state_dict()
 
-    ic('tp_fsdp_trainer')
+    print('tp_fsdp_trainer')
     # ic(tp_fsdp_trainer.state.model.module)
     # for name, param in tp_fsdp_trainer.state.model.named_parameters():
     #     if param.grad is not None:
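
The `ic(...)` dumps in the patch only eyeball the per-parameter gradients. Once the parameter names line up across the DDP, FSDP, and TP-FSDP wrappers, the same loop can assert equality instead of printing. The sketch below is not part of the patch: it shows that comparison pattern on two plain `nn.Sequential` models (stand-ins for `SimpleModel`), and the helper names `collect_grads` and `compare_grads` are hypothetical.

# Sketch (not from the patch): compare per-parameter gradients between two
# models by name, using torch.testing.assert_close instead of ic() prints.
import torch
import torch.nn as nn


def collect_grads(model: nn.Module) -> dict[str, torch.Tensor]:
    """Map parameter name -> detached gradient for every param that has one."""
    return {
        name: param.grad.detach().clone()
        for name, param in model.named_parameters()
        if param.grad is not None
    }


def compare_grads(a: dict[str, torch.Tensor], b: dict[str, torch.Tensor]) -> None:
    """Assert both models produced gradients for the same params with the same values."""
    assert a.keys() == b.keys()
    for name in a:
        torch.testing.assert_close(a[name], b[name], msg=f'gradient mismatch for {name}')


if __name__ == '__main__':
    torch.manual_seed(0)
    model_a = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 4))
    model_b = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 4))
    model_b.load_state_dict(model_a.state_dict())  # identical weights

    x = torch.randn(16, 8)
    torch.sum(model_a(x)).backward()
    torch.sum(model_b(x)).backward()

    compare_grads(collect_grads(model_a), collect_grads(model_b))

Applying this to the trainers in the test would mean collecting gradients from `ddp_trainer.state.model`, `fsdp_trainer.state.model`, and `tp_fsdp_trainer.state.model` after the backward pass; note that sharded FSDP/TP modules may expose different parameter names and shapes than the DDP module, so the names would need to be reconciled first.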