
Commit

update
eitanturok committed Sep 13, 2024
1 parent 1d5928d commit b830740
Showing 1 changed file with 11 additions and 13 deletions.
24 changes: 11 additions & 13 deletions tests/trainer/test_tp.py
@@ -21,8 +21,6 @@
SimpleModel,
world_size,
)
from icecream import install
install()

@pytest.mark.gpu
@world_size(4)
@@ -304,20 +302,20 @@ def test_tp_forward(world_size: int):
@pytest.mark.filterwarnings(r'ignore:.*\(TP\) is experimental.*:FutureWarning')
def test_tp_gradients(world_size: int):
"""Test that DDP, FSDP, TP-FSDP output the same gradients."""
# from icecream import ic

# DDP gradients
ddp_trainer = get_ddp_trainer()
ddp_out = forward_pass(ddp_trainer)
torch.sum(ddp_out).backward()
# ddp_trainer.state.optimizers[0].step()
ddp_trainer.state.optimizers[0].step()
ddp_trainer.close()
ddp_state_dict = ddp_trainer.state.state_dict()

ic('ddp_trainer')
# ic(ddp_trainer.state.model.module)
# for name, param in ddp_trainer.state.model.named_parameters():
# if param.grad is not None:
# ic(name, param.grad.shape, param.grad)
ic(ddp_trainer.state.model.module)
for name, param in ddp_trainer.state.model.named_parameters():
if param.grad is not None:
ic(name, param.grad.shape, param.grad)

# FSDP gradients
fsdp_trainer = get_fsdp_trainer()
@@ -327,10 +325,10 @@ def test_tp_gradients(world_size: int):
fsdp_state_dict = fsdp_trainer.state.state_dict()

ic('fsdp_trainer')
# ic(fsdp_trainer.state.model.module)
# for name, param in fsdp_trainer.state.model.named_parameters():
# if param.grad is not None:
# ic(name, param.grad.shape, param.grad)
ic(fsdp_trainer.state.model.module)
for name, param in fsdp_trainer.state.model.named_parameters():
if param.grad is not None:
ic(name, param.grad.shape, param.grad)

# TP-FSDP gradients
tp_fsdp_trainer = get_tp_fsdp_trainer()
@@ -339,7 +337,7 @@ def test_tp_gradients(world_size: int):
tp_fsdp_trainer.close()
tp_fsdp_state_dict = tp_fsdp_trainer.state.state_dict()

ic('tp_fsdp_trainer')
print('tp_fsdp_trainer')
# ic(tp_fsdp_trainer.state.model.module)
# for name, param in tp_fsdp_trainer.state.model.named_parameters():
# if param.grad is not None:
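For context, test_tp_gradients builds DDP, FSDP, and TP-FSDP trainers, runs a forward and backward pass on each, and (after this commit) prints the DDP and FSDP per-parameter gradients with ic. Below is a minimal sketch of how those gradients could be compared programmatically, assuming the helpers get_ddp_trainer and forward_pass from this test file; collect_grads, assert_same_grads, and the tolerances are illustrative assumptions, not code from this commit.

import torch

def collect_grads(trainer):
    # Map parameter name -> gradient tensor, skipping parameters with no grad.
    return {
        name: param.grad.detach().clone()
        for name, param in trainer.state.model.named_parameters()
        if param.grad is not None
    }

def assert_same_grads(grads_a, grads_b, rtol=1e-4, atol=1e-5):
    # Require matching parameter names, then compare tensors within tolerance.
    assert grads_a.keys() == grads_b.keys()
    for name in grads_a:
        torch.testing.assert_close(grads_a[name], grads_b[name], rtol=rtol, atol=atol)

In the test above this would run after each backward() and before close(), e.g. assert_same_grads(collect_grads(ddp_trainer), collect_grads(fsdp_trainer)). Note that the TP-FSDP trainer holds sharded parameters, so its gradients would need to be gathered into full tensors first; the sketch does not handle that.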

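Separately, the diff removes the module-level from icecream import install / install() lines while ic is still called elsewhere in the test. For reference, icecream supports both styles: an explicit import that only the importing module sees, and install(), which injects ic into builtins so every module can use it without an import. A small standalone sketch, unrelated to the test file itself:

from icecream import ic

x = 3
ic(x)  # prints something like: ic| x: 3

# Alternative: install ic into builtins once (e.g. in conftest.py);
# afterwards any module can call ic() without importing it.
from icecream import install
install()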