Commit 761001c

real nice
eitanturok committed Sep 19, 2024
1 parent 3fa3807 commit 761001c
Showing 2 changed files with 36 additions and 64 deletions.
7 changes: 5 additions & 2 deletions tests/common/datasets.py
@@ -68,7 +68,6 @@ def __init__(
         self.shape: Sequence[int] = shape
         self.num_classes: int = num_classes
         self.device: Optional[torch.device] = device
-        self.generator: Optional[torch.Generator] = None
         self.x: Optional[torch.Tensor] = None
         self.y: Optional[torch.Tensor] = None
 
@@ -79,7 +78,11 @@ def __getitem__(self, index: int):
         # Note: lazily generate data so it runs after Composer seeds everything, giving the same
         # dataset across multiple calls when using the same seed.
         if self.x is None:
-            self.x = torch.randn(self.size, *self.shape, device=self.device)
+            self.x = torch.randn(
+                self.size,
+                *self.shape,
+                device=self.device,
+            )
+            ic(self.x, self.x.device)
         if self.y is None:
             self.y = torch.randint(0, self.num_classes, size=(self.size,), device=self.device)
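Note on the lazy-generation comment above: deferring torch.randn until the first __getitem__ call means the data is drawn after Composer has seeded all RNGs, so two identically seeded runs see identical datasets. A minimal sketch of the same pattern (hypothetical LazyDataset, not part of this diff):

import torch

class LazyDataset:
    """Generates data on first access so it picks up the active global seed."""

    def __init__(self, size: int, shape: tuple):
        self.size = size
        self.shape = shape
        self.x = None

    def __getitem__(self, index: int) -> torch.Tensor:
        if self.x is None:  # deferred until after seeding has happened
            self.x = torch.randn(self.size, *self.shape)
        return self.x[index]

torch.manual_seed(42)
a = LazyDataset(8, (4,))[0]
torch.manual_seed(42)
b = LazyDataset(8, (4,))[0]
assert torch.equal(a, b)  # same seed -> same lazily generated data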
93 changes: 31 additions & 62 deletions tests/trainer/test_tp.py
@@ -2,7 +2,9 @@
 # SPDX-License-Identifier: Apache-2.0
 
 from contextlib import contextmanager
-from typing import Any, Optional, Sequence
+from typing import Any, Optional, Sequence, TypeVar
 
+E = TypeVar('E', bound=BaseException)
+
 import numpy as np
 import pytest
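The bound TypeVar added here lets a helper accept an exception class while type checkers track which one. A small illustration of the idiom (the expect helper is hypothetical, not from this commit):

from typing import TypeVar

E = TypeVar('E', bound=BaseException)

def expect(exception: type[E], msg: str) -> E:
    # `type[E]` requires an exception *class*; the return type tracks it.
    return exception(msg)

err = expect(ValueError, 'boom')  # inferred as ValueError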
@@ -105,6 +107,7 @@ def get_trainer(
"""Trainer for a simple model with any parallelism_config."""

reproducibility.seed_all(seed)

if replication:
dataset: Dataset = RandomClassificationDatasetReplicated(
shape=(num_features,),
@@ -246,6 +249,7 @@ def get_tp_fsdp_trainer(


 def forward_pass(trainer):
+    reproducibility.seed_all(42)
     batch = next(iter(trainer.state.train_dataloader))
     output = trainer.state.model.forward(batch)
     return output
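Seeding inside forward_pass (Composer's reproducibility.seed_all) pins the RNG state immediately before the batch is drawn, so the DDP, FSDP, and TP-FSDP trainers all see the same data even if their setup consumed randomness differently. Roughly, the guarantee is the following (illustrative sketch):

import torch

def draw_batch(seed: int) -> torch.Tensor:
    torch.manual_seed(seed)  # stands in for reproducibility.seed_all(42)
    return torch.randint(0, 10, (4,))

assert torch.equal(draw_batch(42), draw_batch(42))  # identical batches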
@@ -260,7 +264,7 @@ def _replace_state_dict_name(state_dict: dict[str, Any], old_name: str, new_name
     return state_dict
 
 
-def _compare_modules(module1, module2, check_grad: bool = False):
+def _compare_modules(module1: dict[str, Any], module2: dict[str, Any], check_grad: bool = False):
     module_type = 'Gradients' if check_grad else 'Parameters'
 
     for (param1_name, param1), (param2_name, param2) in zip(module1.items(), module2.items()):
@@ -289,9 +293,8 @@ def compare_models(ddp_trainer: Trainer, fsdp_trainer: Trainer, tp_fsdp_trainer:
     # However, calling `tp_fsdp_trainer.state.state_dict()` directly causes a NCCL timeout
     # due to this pytorch bug: https://github.com/pytorch/pytorch/issues/134095/.
     # As a workaround, we use `tp_fsdp_trainer.state.model.named_parameters()` instead.
-    # This issues only exists in `tp_fsdp_trainer.state.state_dict()` when we use TP and
-    # FSDP together; it does not arise when calling `ddp_trainer.state.state_dict()` or
-    # `fsdp_trainer.state.state_dict()`.
+    # This issue only exists with `tp_fsdp_trainer.state.state_dict()` and does not
+    # arise when calling `ddp_trainer.state.state_dict()` or `fsdp_trainer.state.state_dict()`.
     with FSDP.summon_full_params(fsdp_trainer.state.model, with_grads=check_grad):
         with FSDP.summon_full_params(tp_fsdp_trainer.state.model, with_grads=check_grad):
             ddp_params = dict(ddp_trainer.state.model.named_parameters())
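For reference, the workaround pattern here is to gather the full (unsharded) parameters and read them off the module instead of calling state_dict(). A hedged sketch of that pattern, assuming an already-wrapped FSDP model (not the test's exact code):

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def snapshot_params(model: FSDP) -> dict[str, torch.Tensor]:
    # summon_full_params temporarily materializes the full parameters on each
    # rank, so named_parameters() yields complete tensors without going
    # through state_dict(), which can hang under TP + FSDP.
    with FSDP.summon_full_params(model):
        return {name: p.detach().clone() for name, p in model.named_parameters()}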
@@ -311,6 +314,15 @@ def compare_models(ddp_trainer: Trainer, fsdp_trainer: Trainer, tp_fsdp_trainer:
             _compare_modules(ddp_params, fsdp_params, check_grad=check_grad)
 
 
+@contextmanager
+def fail_without_replication(replication: int, exception: type[E], error_error_msg: str):
+    if replication:
+        yield
+    else:
+        with pytest.raises(exception, match=error_error_msg):
+            yield
+
+
 @pytest.mark.gpu
 @world_size(4)
 @pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.3'), reason='requires PyTorch 2.3+')
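Usage-wise, the new helper lets one test body serve both modes: with replication disabled it asserts that the comparison raises, and with replication enabled it lets the body run normally. For example (illustrative values, not from the diff):

# replication=0: the body is expected to raise the given exception.
with fail_without_replication(0, AssertionError, 'not close'):
    raise AssertionError('Gradients are not close enough: rank mismatch')

# replication=2: the body is expected to succeed.
with fail_without_replication(2, AssertionError, 'not close'):
    pass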
@@ -525,15 +537,6 @@ def test_tp_hang(world_size: int):
     # print(tp_fsdp_state_dict_6)
 
 
-@contextmanager
-def fail_without_replication(replication: int, exception: Exception, error_error_msg: str):
-    if replication:
-        yield
-    else:
-        with pytest.raises(exception, match=error_error_msg):
-            yield
-
-
 @pytest.mark.gpu
 @world_size(4)
 @pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.3'), reason='Requires PyTorch 2.3+')
@@ -546,67 +549,33 @@ def test_tp_forwards_backwards(world_size: int, replication: int = 0):
     - updated weights
     """
 
-    # # Initialize trainers with DDP, FSDP, TP-FSDP
-    # ddp_trainer = get_ddp_trainer(replication=replication)
-    # fsdp_trainer = get_fsdp_trainer(replication=replication)
-    # tp_fsdp_trainer = get_tp_fsdp_trainer(replication=replication)
-
-    # # Ensure initial model weights are the same
-    # compare_models(ddp_trainer, fsdp_trainer, tp_fsdp_trainer)
-
-    # # Forward pass
-    # ddp_out = forward_pass(ddp_trainer)
-    # fsdp_out = forward_pass(fsdp_trainer)
-    # tp_fsdp_out = forward_pass(tp_fsdp_trainer)
-
-    # # Ensure output of the forward pass is the same
-    # _compare_modules({'': ddp_out}, {'': fsdp_out})
-    # _compare_modules({'': ddp_out}, {'': tp_fsdp_out})
-    # _compare_modules({'': fsdp_out}, {'': tp_fsdp_out})
-
-    # # Compute gradients
-    # torch.sum(ddp_out).backward()
-    # torch.sum(fsdp_out).backward()
-    # torch.sum(tp_fsdp_out).backward()
-
-    # # Ensure the gradients are the same
-    # compare_models(ddp_trainer, fsdp_trainer, tp_fsdp_trainer, check_grad=True)
-
-    # # Update the model weights
-    # ddp_trainer.state.optimizers[0].step()
-    # fsdp_trainer.state.optimizers[0].step()
-    # tp_fsdp_trainer.state.optimizers[0].step()
-
-    # # Ensure the updated weights are the same
-    # compare_models(ddp_trainer, fsdp_trainer, tp_fsdp_trainer)
-
-    # DDP
+    # Initialize trainers with DDP, FSDP, TP-FSDP
     ddp_trainer = get_ddp_trainer(replication=replication)
-    ddp_out = forward_pass(ddp_trainer)
-    torch.sum(ddp_out).backward()
-
-    # FSDP
     fsdp_trainer = get_fsdp_trainer(replication=replication)
-    fsdp_out = forward_pass(fsdp_trainer)
-    torch.sum(fsdp_out).backward()
-
-    # TP-FSDP
     tp_fsdp_trainer = get_tp_fsdp_trainer(replication=replication)
-    tp_fsdp_out = forward_pass(tp_fsdp_trainer)
-    torch.sum(tp_fsdp_out).backward()
 
     # Ensure initial model weights are the same
     compare_models(ddp_trainer, fsdp_trainer, tp_fsdp_trainer)
 
+    # Forward pass
+    ddp_out = forward_pass(ddp_trainer)
+    fsdp_out = forward_pass(fsdp_trainer)
+    tp_fsdp_out = forward_pass(tp_fsdp_trainer)
+
     # Ensure output of the forward pass is the same
+    with FSDP.summon_full_params(fsdp_trainer.state.model):
+        with FSDP.summon_full_params(tp_fsdp_trainer.state.model):
+            _compare_modules({'': ddp_out}, {'': fsdp_out})
+            _compare_modules({'': ddp_out}, {'': tp_fsdp_out})
+            _compare_modules({'': fsdp_out}, {'': tp_fsdp_out})
 
+    # Compute gradients
+    torch.sum(ddp_out).backward()
+    torch.sum(fsdp_out).backward()
+    torch.sum(tp_fsdp_out).backward()
+
     # Ensure the gradients are the same
-    # We expect this test to fail without replication, i.e. replication=0
+    # We expect this test to fail without replication, i.e. when replication=0
     error_error_msg = 'Gradients are not close enough:*'
     with fail_without_replication(replication, AssertionError, error_error_msg):
         compare_models(ddp_trainer, fsdp_trainer, tp_fsdp_trainer, check_grad=True)
@@ -617,7 +586,7 @@ def test_tp_forwards_backwards(world_size: int, replication: int = 0):
     tp_fsdp_trainer.state.optimizers[0].step()
 
     # Ensure the updated weights are the same
-    # We expect this test to fail without replication, i.e. replication=0
+    # We expect this test to fail without replication, i.e. when replication=0
     error_error_msg = 'Parameters are not close enough:*'
     with fail_without_replication(replication, AssertionError, error_error_msg):
         compare_models(ddp_trainer, fsdp_trainer, tp_fsdp_trainer)
@@ -903,6 +872,6 @@ def test_tp_fsdp_state_dict(world_size: int):
 if __name__ == '__main__':
     import warnings
     warnings.filterwarnings('ignore')
-    # test_tp_forwards_backwards(4, replication=2)
+    test_tp_forwards_backwards(4, replication=2)
     test_tp_forwards_backwards(4, replication=0)
     # test_tp_fsdp_state_dict(4)
