Commit

better comments
ez2rok committed Sep 20, 2024
1 parent 0dac98a commit 895f08e
Showing 1 changed file with 16 additions and 14 deletions.
30 changes: 16 additions & 14 deletions tests/trainer/test_tp.py
@@ -301,9 +301,11 @@ def compare_models(
def get_stats(trainer: Trainer) -> dict[str, np.ndarray]:
logger = trainer.logger.destinations[0]
stats = {
-'loss_array': logger.get_timeseries('loss/train/total')['loss/train/total'], # type: ignore
-'accuracy_array': logger.get_timeseries('metrics/train/MulticlassAccuracy')
-    ['metrics/train/MulticlassAccuracy'], # type: ignore
+'loss_array':
+    logger.get_timeseries('loss/train/total')['loss/train/total'], # type: ignore
+'accuracy_array':
+    logger.get_timeseries('metrics/train/MulticlassAccuracy') # type: ignore
+    ['metrics/train/MulticlassAccuracy'],
}
return stats

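For reference, a minimal sketch of how the two time series returned by get_stats might be compared between trainers; the helper name and tolerance below are illustrative and not part of this diff.

import numpy as np

def assert_stats_close(stats_a: dict[str, np.ndarray], stats_b: dict[str, np.ndarray], atol: float = 1e-6) -> None:
    # Compare the loss and accuracy histories produced by get_stats for two trainers.
    for key in ('loss_array', 'accuracy_array'):
        np.testing.assert_allclose(stats_a[key], stats_b[key], atol=atol)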
@@ -313,7 +315,7 @@ def get_stats(trainer: Trainer) -> dict[str, np.ndarray]:
@pytest.mark.parametrize('replication', [2])
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.3'), reason='Requires PyTorch 2.3+')
@pytest.mark.filterwarnings(r'ignore:.*\(TP\) is experimental.*:FutureWarning')
-def test_tp_forwards_backwards(world_size: int, replication: int):
+def test_tp_forwards_backwards_correctness(world_size: int, replication: int):
"""Test that training with DDP, FSDP, TP-FSDP results in the same:
- initial weights
- forward pass
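As a rough illustration of the comparison this docstring describes, a hedged sketch for two ordinary (unsharded) models on the same batch; with TP-FSDP the parameters are DTensors and would first need to be gathered, which this sketch omits.

import torch

def assert_same_forward_backward(model_a: torch.nn.Module, model_b: torch.nn.Module, batch: torch.Tensor) -> None:
    # Models initialized with identical weights should agree on the forward pass ...
    out_a, out_b = model_a(batch), model_b(batch)
    torch.testing.assert_close(out_a, out_b)
    # ... and, after backprop on the same scalar, on every gradient.
    out_a.sum().backward()
    out_b.sum().backward()
    for p_a, p_b in zip(model_a.parameters(), model_b.parameters()):
        torch.testing.assert_close(p_a.grad, p_b.grad)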
@@ -365,7 +367,7 @@ def test_tp_forwards_backwards(world_size: int, replication: int):
@pytest.mark.parametrize('batch_size', [1, 4])
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.3'), reason='Requires PyTorch 2.3+')
@pytest.mark.filterwarnings(r'ignore:.*\(TP\) is experimental.*:FutureWarning')
-def test_tp_fit(world_size: int, batch_size: int, replication: int):
+def test_tp_fit_correctness(world_size: int, batch_size: int, replication: int):
"""Test that training with DDP, FSDP, TP-FSDP results in the same:
- updated weights
- loss
@@ -446,9 +448,9 @@ def test_tp_fit(world_size: int, batch_size: int, replication: int):
def test_tp_train(world_size: int):
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel

-# Normally, each TP rank receives the same data via data replication
-# In this test, we do not do this: each TP rank gets different data
-# This is okay - we are testing the TP mechanism, not actual TP correctness
+# For TP to produce the correct result, each TP rank must receive the same data
+# In this test, TP ranks receive different data as we are testing the TP
+# mechanism, not actual TP correctness.
model = SimpleModel()
dataset = RandomClassificationDataset(size=8)
dataloader = DataLoader(dataset, batch_size=2, sampler=dist.get_sampler(dataset))
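For context, a sketch of how the ColwiseParallel/RowwiseParallel imports above are typically wired into the Trainer in these tests; the layer names and parallelism_config keys are assumptions about Composer's TP configuration, not shown in this hunk.

layer_plan = {
    'fc1': ColwiseParallel(),
    'fc2': RowwiseParallel(),
}
trainer = Trainer(
    model=model,
    train_dataloader=dataloader,
    parallelism_config={
        'tp': {'layer_plan': layer_plan, 'tensor_parallel_degree': 2},  # shard fc1/fc2 across 2 ranks
        'fsdp': {},
    },
    max_duration='3ba',
)
trainer.fit()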
@@ -478,9 +480,9 @@ def test_tp_train(world_size: int):
def test_tp_with_param_groups(world_size: int):
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel

-# Normally, each TP rank receives the same data via data replication
-# In this test, we do not do this: each TP rank gets different data
-# This is okay - we are testing the TP mechanism, not actual TP correctness
+# For TP to produce the correct result, each TP rank must receive the same data
+# In this test, TP ranks receive different data as we are testing the TP
+# mechanism, not actual TP correctness.
model = SimpleModel()
dataset = RandomClassificationDataset(size=8)
dataloader = DataLoader(dataset, batch_size=2, sampler=dist.get_sampler(dataset))
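The param groups referred to in this test's name are ordinary PyTorch optimizer parameter groups; a minimal sketch, assuming SimpleModel exposes fc1/fc2 layers as in the fixtures used elsewhere in this file.

# Two parameter groups with different learning rates, one per layer.
optimizer = torch.optim.SGD([
    {'params': model.fc1.parameters(), 'lr': 0.1},
    {'params': model.fc2.parameters(), 'lr': 0.5},
])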
@@ -519,9 +521,9 @@ def test_tp_with_param_groups(world_size: int):
def test_tp_with_subset_of_params(world_size: int):
from torch.distributed.tensor.parallel import ColwiseParallel

-# Normally, each TP rank receives the same data via data replication
-# In this test, we do not do this: each TP rank gets different data
-# This is okay - we are testing the TP mechanism, not actual TP correctness
+# For TP to produce the correct result, each TP rank must receive the same data
+# In this test, TP ranks receive different data as we are testing the TP
+# mechanism, not actual TP correctness.
model = SimpleModel()
dataset = RandomClassificationDataset(size=8)
dataloader = DataLoader(dataset, batch_size=2, sampler=dist.get_sampler(dataset))
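Applying TP to only a subset of parameters amounts to a layer plan that names a single layer; a minimal sketch (layer name assumed):

# Only fc1 is tensor-parallelized; the rest of the model is untouched.
layer_plan = {'fc1': ColwiseParallel()}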

