update tests
eitanturok committed Sep 16, 2024
1 parent b830740 commit 64a1c1d
Showing 1 changed file with 33 additions and 25 deletions.
58 changes: 33 additions & 25 deletions tests/trainer/test_tp.py
@@ -22,6 +22,7 @@
     world_size,
 )
 
+
 @pytest.mark.gpu
 @world_size(4)
 @pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.3'), reason='requires PyTorch 2.3+')
@@ -149,7 +150,7 @@ def get_trainer(
 def get_trainer(
     parallelism_config: Optional[ParallelismConfig] = None,
     size: int = 4,
-    batch_size: int = 4,
+    batch_size: int = 1,
     num_classes: int = 2,
     num_features: int = 6,
     seed: int = 42,
@@ -180,13 +181,15 @@ def get_trainer(
         parallelism_config=parallelism_config,
         callbacks=[MemoryMonitor()],
         loggers=[InMemoryLogger()],
+        progress_bar=False,
+        log_to_console=False,
     )
     return trainer
 
 
 def get_ddp_trainer(
     size: int = 4,
-    batch_size: int = 4,
+    batch_size: int = 1,
     num_classes: int = 2,
     num_features: int = 6,
     seed: int = 42,
@@ -205,7 +208,7 @@ def get_ddp_trainer(
 
 def get_fsdp_trainer(
     size: int = 4,
-    batch_size: int = 4,
+    batch_size: int = 1,
     num_classes: int = 2,
     num_features: int = 6,
     seed: int = 42,
@@ -232,7 +235,7 @@
 
 def get_tp_fsdp_trainer(
     size: int = 4,
-    batch_size: int = 4,
+    batch_size: int = 1,
     num_classes: int = 2,
     num_features: int = 6,
     seed: int = 42,
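The bodies of these helper constructors are collapsed in this view, so the parallelism configuration that distinguishes `get_tp_fsdp_trainer` is not shown here. As a rough sketch only, not the code in this commit, and assuming Composer's `ParallelismConfig`, `TPConfig`, and `FSDPConfig` are importable from `composer.utils`, a 2-way tensor-parallel plan layered over FSDP might be assembled like this; the layer names, field names, and degree are assumptions:

```python
# Illustrative only -- not part of this commit. Assumes Composer-style
# ParallelismConfig/TPConfig/FSDPConfig dataclasses and torch>=2.3 TP styles.
from composer.utils import FSDPConfig, ParallelismConfig, TPConfig
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel


def make_tp_fsdp_config(tensor_parallel_degree: int = 2) -> ParallelismConfig:
    # Shard the first linear layer column-wise and the second row-wise so the
    # intermediate activation stays sharded between them ('fc1'/'fc2' are
    # hypothetical layer names, not taken from this diff).
    layer_plan = {
        'fc1': ColwiseParallel(),
        'fc2': RowwiseParallel(),
    }
    return ParallelismConfig(
        fsdp=FSDPConfig(),
        tp=TPConfig(layer_plan=layer_plan, tensor_parallel_degree=tensor_parallel_degree),
    )
```

A config built this way could then be passed as the `parallelism_config` argument of `get_trainer` above.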
@@ -306,34 +309,38 @@ def test_tp_gradients(world_size: int):
 
     # DDP gradients
     ddp_trainer = get_ddp_trainer()
-    ddp_out = forward_pass(ddp_trainer)
-    torch.sum(ddp_out).backward()
-    ddp_trainer.state.optimizers[0].step()
+    ddp_trainer.fit()
+    # ddp_out = forward_pass(ddp_trainer)
+    # torch.sum(ddp_out).backward()
+    # ddp_trainer.state.optimizers[0].step()
     ddp_trainer.close()
     ddp_state_dict = ddp_trainer.state.state_dict()
 
-    ic('ddp_trainer')
-    ic(ddp_trainer.state.model.module)
-    for name, param in ddp_trainer.state.model.named_parameters():
-        if param.grad is not None:
-            ic(name, param.grad.shape, param.grad)
+    print('ddp_trainer')
+    # ic(ddp_trainer.state.model.module)
+    # for name, param in ddp_trainer.state.model.named_parameters():
+    #     if param.grad is not None:
+    #         ic(name, param.grad.shape, param.grad)
 
     # FSDP gradients
     fsdp_trainer = get_fsdp_trainer()
-    fsdp_out = forward_pass(fsdp_trainer)
-    torch.sum(fsdp_out).backward()
+    fsdp_trainer.fit()
+    # fsdp_out = forward_pass(fsdp_trainer)
+    # torch.sum(fsdp_out).backward()
     fsdp_trainer.close()
     fsdp_state_dict = fsdp_trainer.state.state_dict()
 
-    ic('fsdp_trainer')
-    ic(fsdp_trainer.state.model.module)
-    for name, param in fsdp_trainer.state.model.named_parameters():
-        if param.grad is not None:
-            ic(name, param.grad.shape, param.grad)
+    print('fsdp_trainer')
+    # ic(fsdp_trainer.state.model.module)
+    # for name, param in fsdp_trainer.state.model.named_parameters():
+    #     if param.grad is not None:
+    #         ic(name, param.grad.shape, param.grad)
 
     # TP-FSDP gradients
    tp_fsdp_trainer = get_tp_fsdp_trainer()
-    tp_fsdp_out = forward_pass(tp_fsdp_trainer)
-    torch.sum(tp_fsdp_out).backward()
+    tp_fsdp_trainer.fit()
+    # tp_fsdp_out = forward_pass(tp_fsdp_trainer)
+    # torch.sum(tp_fsdp_out).backward()
     tp_fsdp_trainer.close()
     tp_fsdp_state_dict = tp_fsdp_trainer.state.state_dict()
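The assertions that compare these three state dicts fall below the fold of this hunk. A minimal sketch of how such a comparison could look, assuming Composer's trainer state dict keeps model weights under a `'model'` key and ignoring the extra handling that sharded FSDP/TP tensors would need:

```python
# Illustrative only -- not part of this commit. Naive parameter-by-parameter
# comparison of two trainer state dicts.
import torch


def assert_model_states_close(state_dict_a: dict, state_dict_b: dict) -> None:
    model_a, model_b = state_dict_a['model'], state_dict_b['model']
    assert model_a.keys() == model_b.keys(), 'parameter names differ'
    for name, tensor_a in model_a.items():
        torch.testing.assert_close(tensor_a, model_b[name], msg=f'mismatch in {name}')
```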

@@ -355,27 +362,28 @@ def get_stats(trainer: Trainer) -> dict[str, np.ndarray]:
 
 @pytest.mark.gpu
 @world_size(4)
+@pytest.mark.parametrize('batch_size', [1, 4])
 @pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.3'), reason='Requires PyTorch 2.3+')
 @pytest.mark.filterwarnings(r'ignore:.*\(TP\) is experimental.*:FutureWarning')
-def test_tp_fit(world_size: int):
+def test_tp_fit(batch_size: int, world_size: int):
     """Test that DDP, FSDP, TP-FSDP have the same trainer.fit(), i.e. output the same loss and accuracy."""
 
     size = 1024  # make enough data to train for multiple steps
 
     # DDP fit
-    ddp_trainer = get_ddp_trainer(size=size)
+    ddp_trainer = get_ddp_trainer(size=size, batch_size=batch_size)
     ddp_trainer.fit()
     ddp_trainer.close()
     ddp_stats = get_stats(ddp_trainer)
 
     # FSDP fit
-    fsdp_trainer = get_fsdp_trainer(size=size)
+    fsdp_trainer = get_fsdp_trainer(size=size, batch_size=batch_size)
     fsdp_trainer.fit()
     fsdp_trainer.close()
     fsdp_stats = get_stats(fsdp_trainer)
 
     # TP-FSDP fit
-    tp_fsdp_trainer = get_tp_fsdp_trainer(size=size)
+    tp_fsdp_trainer = get_tp_fsdp_trainer(size=size, batch_size=batch_size)
     tp_fsdp_trainer.fit()
     tp_fsdp_trainer.close()
     tp_fsdp_stats = get_stats(tp_fsdp_trainer)
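The body of `get_stats` is collapsed in this view; only its signature appears in the hunk header above. A sketch of what it might do, assuming the stats come from the `InMemoryLogger` attached in `get_trainer`; the metric keys and the logger access pattern are assumptions, not taken from the diff:

```python
# Illustrative only -- not part of this commit. Hypothetical get_stats body.
import numpy as np

from composer.loggers import InMemoryLogger
from composer.trainer import Trainer


def get_stats(trainer: Trainer) -> dict[str, np.ndarray]:
    # Pull the in-memory logger back out of the trainer and convert its
    # recorded (timestamp, value) pairs into plain arrays.
    logger = next(d for d in trainer.logger.destinations if isinstance(d, InMemoryLogger))
    return {
        'loss_array': np.array([float(v) for _, v in logger.data['loss/train/total']]),
        'accuracy_array': np.array([float(v) for _, v in logger.data['metrics/train/MulticlassAccuracy']]),
    }
```

With stats in this form, the DDP, FSDP, and TP-FSDP runs can be compared with `np.testing.assert_allclose` on each array.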
