
Commit

tests pass
eitanturok committed Sep 19, 2024
1 parent 0255b6b commit 288d4f5
Showing 2 changed files with 86 additions and 132 deletions.
3 changes: 0 additions & 3 deletions tests/common/datasets.py
@@ -4,7 +4,6 @@

import pytest
import torch
from icecream import ic
from PIL import Image
from torch.utils.data import DataLoader, Dataset, IterableDataset
from torchvision.datasets import VisionDataset
@@ -83,10 +82,8 @@ def __getitem__(self, index: int):
*self.shape,
device=self.device,
)
ic(self.x, self.x.device)
if self.y is None:
self.y = torch.randint(0, self.num_classes, size=(self.size,), device=self.device)
ic(self.y, self.y.device)
return self.x[index], self.y[index]


215 changes: 86 additions & 129 deletions tests/trainer/test_tp.py
@@ -26,7 +26,6 @@
SimpleModel,
world_size,
)
from tests.trainer.test_fsdp_checkpoint import get_mono_state_dict_from_sharded_one


class RandomClassificationDatasetReplicated(Dataset):
@@ -72,6 +71,9 @@ def __getitem__(self, idx):
if self.x is None and self.y is None:
self._generate_data()

assert self.x is not None
assert self.y is not None

rank_idx = idx // self.world_size
return self.x[rank_idx], self.y[rank_idx]
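
The `idx // self.world_size` mapping above is what makes this dataset replicated rather than sharded: with the usual rank-interleaved sampling, the indices handed to the different ranks in the same step all collapse onto the same local sample. A minimal standalone illustration (plain Python, hypothetical values):

# Illustration only: under rank-interleaved sampling, ranks 0-3 receive
# global indices 0, 1, 2, 3 in the same step, and all of them map to
# local sample 0, so every rank trains on identical data.
world_size = 4
for idx in range(8):
    rank_idx = idx // world_size
    print(f'global idx {idx} -> local sample {rank_idx}')  # 0-3 -> 0, 4-7 -> 1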

@@ -263,7 +265,13 @@ def _replace_state_dict_name(state_dict: dict[str, Any], old_name: str, new_name
return state_dict


def _compare_modules(module1: dict[str, Any], module2: dict[str, Any], check_grad: bool = False):
def compare_modules(
module1: dict[str, Any],
module2: dict[str, Any],
check_grad: bool = False,
atol: Optional[float] = None,
rtol: Optional[float] = None
):
module_type = 'Gradients' if check_grad else 'Parameters'

for (param1_name, param1), (param2_name, param2) in zip(module1.items(), module2.items()):
@@ -282,11 +290,20 @@ def _compare_modules(module1: dict[str, Any], module2: dict[str, Any], check_gra
torch.testing.assert_close(
param1,
param2,
atol=atol,
rtol=rtol,
msg=f'{module_type} are not close enough:\n{param1=}\n{param2=}',
)


def compare_models(ddp_trainer: Trainer, fsdp_trainer: Trainer, tp_fsdp_trainer: Trainer, check_grad: bool = False):
def compare_models(
ddp_trainer: Trainer,
fsdp_trainer: Trainer,
tp_fsdp_trainer: Trainer,
check_grad: bool = False,
atol: Optional[float] = None,
rtol: Optional[float] = None
):

# Normally, we compare various models by their state_dict().
# However, calling `tp_fsdp_trainer.state.state_dict()` directly causes a NCCL timeout
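
A minimal sketch of the workaround this comment leads into, assuming (the rest of the comment and the gathering code are collapsed in this view) that parameters are materialized via FSDP.summon_full_params and collected with named_parameters rather than through state_dict(); the helper name here is hypothetical:

# Sketch only -- gather comparable parameter dicts without calling state_dict().
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def gather_full_params(model):  # hypothetical helper, not from this diff
    with FSDP.summon_full_params(model, with_grads=True):
        return {name: param.detach().clone() for name, param in model.named_parameters()}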
@@ -308,9 +325,11 @@ def compare_models(ddp_trainer: Trainer, fsdp_trainer: Trainer, tp_fsdp_trainer:
fsdp_params = _replace_state_dict_name(fsdp_params, '_fsdp_wrapped_module.', '')
tp_fsdp_params = _replace_state_dict_name(tp_fsdp_params, '_fsdp_wrapped_module.', '')

_compare_modules(ddp_params, fsdp_params, check_grad=check_grad)
_compare_modules(tp_fsdp_params, fsdp_params, check_grad=check_grad)
_compare_modules(ddp_params, fsdp_params, check_grad=check_grad)
ic(ddp_params, fsdp_params, tp_fsdp_params)

compare_modules(ddp_params, fsdp_params, check_grad=check_grad, atol=atol, rtol=rtol)
compare_modules(tp_fsdp_params, fsdp_params, check_grad=check_grad, atol=atol, rtol=rtol)
compare_modules(ddp_params, fsdp_params, check_grad=check_grad, atol=atol, rtol=rtol)
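
For reference, a small self-contained usage of the updated compare_modules helper; the tensors here are made up, and torch.testing.assert_close expects atol and rtol to be passed together (or both left as None, in which case dtype-based defaults are used):

import torch

# Hypothetical parameter dicts that agree to within the given tolerances.
params_a = {'fc1.weight': torch.ones(4, 4), 'fc1.bias': torch.zeros(4)}
params_b = {'fc1.weight': torch.ones(4, 4) + 1e-6, 'fc1.bias': torch.zeros(4)}

compare_modules(params_a, params_b, atol=1e-5, rtol=1e-3)  # explicit tolerances
compare_modules(params_a, params_b)                        # dtype-default tolerances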


@contextmanager
@@ -333,14 +352,16 @@ def get_stats(trainer: Trainer) -> dict[str, np.ndarray]:

@pytest.mark.gpu
@world_size(4)
@pytest.mark.parametrize('replication', [0, 2])
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.3'), reason='Requires PyTorch 2.3+')
@pytest.mark.filterwarnings(r'ignore:.*\(TP\) is experimental.*:FutureWarning')
def test_tp_forwards_backwards(world_size: int, replication: int = 0):
def test_tp_forwards_backwards_correctness(world_size: int, replication: int):
"""Test that training with DDP, FSDP, TP-FSDP results in the same:
- initial weights
- forward pass
- gradients
- updated weights
for a single step.
"""

# Initialize trainers with DDP, FSDP, TP-FSDP
@@ -359,9 +380,9 @@ def test_tp_forwards_backwards(world_size: int, replication: int = 0):
# Ensure output of the forward pass is the same
with FSDP.summon_full_params(fsdp_trainer.state.model):
with FSDP.summon_full_params(tp_fsdp_trainer.state.model):
_compare_modules({'': ddp_out}, {'': fsdp_out})
_compare_modules({'': ddp_out}, {'': tp_fsdp_out})
_compare_modules({'': fsdp_out}, {'': tp_fsdp_out})
compare_modules({'': ddp_out}, {'': fsdp_out})
compare_modules({'': ddp_out}, {'': tp_fsdp_out})
compare_modules({'': fsdp_out}, {'': tp_fsdp_out})

# Compute gradients
torch.sum(ddp_out).backward()
@@ -388,68 +409,17 @@

@pytest.mark.gpu
@world_size(4)
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.3'), reason='Requires PyTorch 2.3+')
@pytest.mark.filterwarnings(r'ignore:.*\(TP\) is experimental.*:FutureWarning')
def test_tp_fit_weights(world_size: int):
"""Test that DDP, FSDP, TP-FSDP have the same weights after calling fit, i.e. forward, backward pass."""
# from icecream import ic

# DDP gradients
print('ddp_trainer')
ddp_trainer = get_ddp_trainer()
ddp_trainer.fit()
ddp_state_dict = ddp_trainer.state.state_dict()
print(f'{ddp_state_dict=}')
ddp_trainer.close()

# FSDP gradients
print('fsdp_trainer')
fsdp_trainer = get_fsdp_trainer()
fsdp_trainer.fit()
fsdp_state_dict = fsdp_trainer.state.state_dict()
print(f'{fsdp_state_dict=}')
fsdp_state_dict_2 = get_mono_state_dict_from_sharded_one(fsdp_trainer)
print(f'{fsdp_state_dict_2=}')
fsdp_trainer.close()

# TP-FSDP gradients
print('tp_fsdp_trainer')
tp_fsdp_trainer = get_tp_fsdp_trainer()
tp_fsdp_trainer.fit()
tp_fsdp_state_dict = tp_fsdp_trainer.state.state_dict()
print(f'{tp_fsdp_state_dict=}')
tp_fsdp_state_dict_2 = get_mono_state_dict_from_sharded_one(tp_fsdp_trainer)
print(f'{tp_fsdp_state_dict_2=}')
tp_fsdp_trainer.close()

# for name, param in tp_fsdp_trainer.state.model.named_parameters():
# if param.grad is not None:
# print(name, param.grad.shape, param.grad)

# if dist.get_local_rank() == 0:
# pass

# todo:
#! rename keys, e.g. module.2.weight -> fc2.weight
#! compare model, optimizer states
#! use _compare_optims_between_state_dicts, _compare_model_params_between_state_dicts from test_fsdp_checkpoint

# removes 'module.' from all state dict keys in-place
# consume_prefix_in_state_dict_if_present(tp_fsdp_state_dict_2['model'], 'module.')
# print(f'{tp_fsdp_state_dict_2=}')
# consume_prefix_in_state_dict_if_present(tp_fsdp_state_dict_2['optimizers'], 'module.')
# print(f'{tp_fsdp_state_dict_2=}')

# assert fsdp_state_dict_2 == tp_fsdp_state_dict_2


@pytest.mark.gpu
@world_size(4)
@pytest.mark.parametrize('replication', [0, 2])
@pytest.mark.parametrize('batch_size', [1, 4])
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.3'), reason='Requires PyTorch 2.3+')
@pytest.mark.filterwarnings(r'ignore:.*\(TP\) is experimental.*:FutureWarning')
def test_tp_fit(world_size: int, batch_size: int, replication: int = 0):
"""Test that DDP, FSDP, TP-FSDP have the same trainer.fit(), i.e. output the same loss and accuracy."""
def test_tp_fit_correctness(world_size: int, batch_size: int, replication: int):
"""Test that training with DDP, FSDP, TP-FSDP results in the same:
- updated weights
- loss
- accuracy
after training for multiple steps via trainer.fit().
"""

# Initialize
train_steps = 20 # number of steps to train for
@@ -474,52 +444,51 @@ def test_tp_fit(world_size: int, batch_size: int, replication: int = 0):
tp_fsdp_trainer.close()
tp_fsdp_stats = get_stats(tp_fsdp_trainer)

# # Ensure the updated models weights are the same
# # We expect this test to fail without replication, i.e. when replication=0
# error_error_msg = 'Parameters are not close enough:*'
# with fail_without_replication(replication, AssertionError, error_error_msg):
# compare_models(ddp_trainer, fsdp_trainer, tp_fsdp_trainer)

# Compare loss between DDP, FSDP, TP-FSDP
np.testing.assert_allclose(
ddp_stats['loss_array'],
fsdp_stats['loss_array'],
atol=6e-2,
err_msg='Loss arrays of DDP and FSDP are not close enough.',
)
np.testing.assert_allclose(
ddp_stats['loss_array'],
tp_fsdp_stats['loss_array'],
atol=6e-2,
err_msg='Loss arrays of DDP and TP-FSDP are not close enough.',
)
np.testing.assert_allclose(
fsdp_stats['loss_array'],
tp_fsdp_stats['loss_array'],
atol=6e-2,
err_msg='Loss arrays of FSDP and TP-FSDP are not close enough.',
)
ic(ddp_stats, fsdp_stats, tp_fsdp_stats)

# Compare accuracy between DDP, FSDP, TP-FSDP
loss_atol = 1 / n_samples # can make a mistake on at most one sample
np.testing.assert_allclose(
ddp_stats['accuracy_array'],
fsdp_stats['accuracy_array'],
atol=loss_atol,
err_msg='Accuracy arrays of DDP and FSDP are not close enough',
)
np.testing.assert_allclose(
ddp_stats['accuracy_array'],
tp_fsdp_stats['accuracy_array'],
atol=loss_atol,
err_msg='Accuracy arrays of DDP and FSDP-TP are not close enough',
)
np.testing.assert_allclose(
fsdp_stats['accuracy_array'],
tp_fsdp_stats['accuracy_array'],
atol=loss_atol,
err_msg='Accuracy arrays of FSDP and FSDP-TP are not close enough',
)
# Ensure the updated model weights are the same
# Drop tolerance due to precision issues across different parallelism strategies
# We expect this test to fail without replication, i.e. when replication=0
error_error_msg = 'Parameters are not close enough:*'
with fail_without_replication(replication, AssertionError, error_error_msg):
compare_models(ddp_trainer, fsdp_trainer, tp_fsdp_trainer, atol=1e-5, rtol=1e-3)

# Compare loss between DDP, FSDP, TP-FSDP
np.testing.assert_allclose(
ddp_stats['loss_array'],
fsdp_stats['loss_array'],
atol=6e-5,
err_msg='Loss arrays of DDP and FSDP are not close enough.',
)
np.testing.assert_allclose(
ddp_stats['loss_array'],
tp_fsdp_stats['loss_array'],
atol=6e-5,
err_msg='Loss arrays of DDP and TP-FSDP are not close enough.',
)
np.testing.assert_allclose(
fsdp_stats['loss_array'],
tp_fsdp_stats['loss_array'],
atol=6e-5,
err_msg='Loss arrays of FSDP and TP-FSDP are not close enough.',
)

# Compare accuracy between DDP, FSDP, TP-FSDP
np.testing.assert_allclose(
ddp_stats['accuracy_array'],
fsdp_stats['accuracy_array'],
err_msg='Accuracy arrays of DDP and FSDP are not close enough',
)
np.testing.assert_allclose(
ddp_stats['accuracy_array'],
tp_fsdp_stats['accuracy_array'],
err_msg='Accuracy arrays of DDP and FSDP-TP are not close enough',
)
np.testing.assert_allclose(
fsdp_stats['accuracy_array'],
tp_fsdp_stats['accuracy_array'],
err_msg='Accuracy arrays of FSDP and FSDP-TP are not close enough',
)
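
The fail_without_replication context manager used above is visible here only through its call site and the bare @contextmanager decorator earlier in the diff; a plausible shape for it, assuming it simply expects the comparison to raise when replication is 0 and to succeed otherwise, would be:

# Assumed sketch of fail_without_replication -- the real body is collapsed in
# this view. With replication == 0 the wrapped block must raise exc_type; with
# replication > 0 it must run cleanly.
from contextlib import contextmanager, nullcontext

import pytest

@contextmanager
def fail_without_replication(replication: int, exc_type: type, match: str):
    ctx = pytest.raises(exc_type, match=match) if replication == 0 else nullcontext()
    with ctx:
        yield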


@world_size(4)
@@ -758,22 +727,10 @@ def test_tp_with_subset_of_params(world_size: int):

@world_size(4)
@pytest.mark.gpu
@pytest.mark.skip('This is broken.')
@pytest.mark.skip('This is broken due to https://github.com/pytorch/pytorch/issues/134095/.')
@pytest.mark.filterwarnings(r'ignore:.*\(TP\) is experimental.*:FutureWarning')
def test_tp_fsdp_state_dict(world_size: int):
tp_fsdp_trainer = get_tp_fsdp_trainer(replication=2)
tp_fsdp_state_dict1 = tp_fsdp_trainer.state.state_dict()
ic(tp_fsdp_state_dict1)
tp_fsdp_state_dict1 = tp_fsdp_trainer.state.state_dict() # work sometimes, fails sometimes
with FSDP.summon_full_params(tp_fsdp_trainer.state.model, with_grads=True):
tp_fsdp_state_dict2 = tp_fsdp_trainer.state.state_dict()
ic(tp_fsdp_state_dict2)


if __name__ == '__main__':
import warnings
warnings.filterwarnings('ignore')

world_size = 4
batch_size = 4
test_tp_fit(world_size, batch_size, replication=2)
test_tp_fit(world_size, batch_size, replication=0)
tp_fsdp_state_dict2 = tp_fsdp_trainer.state.state_dict() # fails always
