init gradient test
eitanturok committed Sep 13, 2024
1 parent ca3a808 commit 7f0ab5b
Showing 1 changed file with 42 additions and 3 deletions.
45 changes: 42 additions & 3 deletions tests/trainer/test_tp.py
@@ -21,7 +21,8 @@
    SimpleModel,
    world_size,
)

from icecream import install
install()  # icecream's install() makes ic() available as a builtin for debug logging

@pytest.mark.gpu
@world_size(4)
@@ -274,7 +275,7 @@ def forward_pass(trainer):
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.3'), reason='Requires PyTorch 2.3+')
@pytest.mark.filterwarnings(r'ignore:.*\(TP\) is experimental.*:FutureWarning')
def test_tp_forward(world_size: int):
"""Test that the forward pass with DDP, FSDP, TP-FSDP all output the same tensor."""
"""Test that DDP, FSDP, TP-FSDP do the same forward pass."""

    # DDP forward pass
    ddp_trainer = get_ddp_trainer()
@@ -297,6 +298,44 @@ def test_tp_forward(world_size: int):
    ), f'Outputs have different values: {ddp_out=} and {tp_fsdp_out=}'


@pytest.mark.gpu
@world_size(4)
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.3'), reason='Requires PyTorch 2.3+')
@pytest.mark.filterwarnings(r'ignore:.*\(TP\) is experimental.*:FutureWarning')
def test_tp_gradients(world_size: int):
"""Test that DDP, FSDP, TP-FSDP output the same gradients."""

# DDP gradients
ddp_trainer = get_ddp_trainer()
ddp_out = forward_pass(ddp_trainer)
torch.sum(ddp_out).backward()

ic('ddp_trainer')
for name, param in ddp_trainer.state.model.named_parameters():
if param.grad is not None:
ic(name, param.shape, param.grad.shape, param, param.grad)

# FSDP gradients
fsdp_trainer = get_fsdp_trainer()
fsdp_out = forward_pass(fsdp_trainer)
torch.sum(fsdp_out).backward()

ic('fsdp_trainer')
for name, param in fsdp_trainer.state.model.named_parameters():
if param.grad is not None:
ic(name, param.shape, param.grad.shape, param, param.grad)

# TP-FSDP gradients
tp_fsdp_trainer = get_tp_fsdp_trainer()
tp_fsdp_out = forward_pass(tp_fsdp_trainer)
torch.sum(tp_fsdp_out).backward()

ic('tp_fsdp_trainer')
for name, param in tp_fsdp_trainer.state.model.named_parameters():
if param.grad is not None:
ic(name, param.shape, param.grad.shape, param, param.grad)
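
The new test only logs gradients with ic and does not yet compare them. One possible follow-up, sketched below and not part of this diff, would compare gradients across trainers by name; it assumes the wrapper prefixes ('module.' for DDP, '_fsdp_wrapped_module.' for FSDP) are the only naming differences and ignores FSDP sharding, which a real check would have to handle by gathering full gradients:

# Sketch only: hypothetical helper, assuming wrapper prefixes are the only
# naming difference between the DDP- and FSDP-wrapped models.
def _grad_dict(trainer):
    return {
        name.replace('_fsdp_wrapped_module.', '').replace('module.', ''): param.grad
        for name, param in trainer.state.model.named_parameters()
        if param.grad is not None
    }

ddp_grads = _grad_dict(ddp_trainer)
fsdp_grads = _grad_dict(fsdp_trainer)
for name in ddp_grads.keys() & fsdp_grads.keys():
    # Only compare gradients with matching shapes; sharded FSDP gradients
    # would need to be gathered before a full element-wise comparison.
    if ddp_grads[name].shape == fsdp_grads[name].shape:
        torch.testing.assert_close(ddp_grads[name], fsdp_grads[name])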


def get_stats(trainer: Trainer) -> dict[str, np.ndarray]:
    logger = trainer.logger.destinations[0]
    stats = {
@@ -311,7 +350,7 @@ def get_stats(trainer: Trainer) -> dict[str, np.ndarray]:
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.3'), reason='Requires PyTorch 2.3+')
@pytest.mark.filterwarnings(r'ignore:.*\(TP\) is experimental.*:FutureWarning')
def test_tp_fit(world_size: int):
"""Test that trainer.fit() with DDP, FSDP, TP-FSDP all output the same loss and accuracy."""
"""Test that DDP, FSDP, TP-FSDP have the same trainer.fit(), i.e. output the same loss and accuracy."""

    size = 1024  # make enough data to train for multiple steps

