Skip to content

Commit

Permalink
Add unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
JAEarly committed Jun 21, 2024
1 parent 6f07e83 commit 8f82e2b
Showing 1 changed file with 36 additions and 0 deletions.
36 changes: 36 additions & 0 deletions tests/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1251,6 +1251,42 @@ def test_accumulate_time_across_ranks(
assert batch_time_accum == datetime.timedelta(seconds=0.1 * (1 + 0))


# @pytest.mark.gpu
# @pytest.mark.world_size(2)
def test_rank_dependent_dataloader_lengths(
self,
model: ComposerModel,
max_duration: Time[int],
):
# Change rank 1 dataloader size to create different sized dataloaders on each rank
batch_size = 2
num_samples = 10
if dist.get_local_rank() == 1:
num_samples += 4
# Create train and eval dataloaders (will have rank-dependent lengths)
train_dataset = RandomClassificationDataset(size=num_samples)
train_dataloader = DataLoader(
dataset=train_dataset, batch_size=batch_size, sampler=dist.get_sampler(train_dataset)
)
eval_dataset = RandomClassificationDataset(size=num_samples)
eval_dataloader = DataLoader(
dataset=eval_dataset, batch_size=batch_size, sampler=dist.get_sampler(eval_dataset)
)
# Fit (train + eval)
trainer = Trainer(
model=model,
max_duration=max_duration,
train_dataloader=train_dataloader,
eval_dataloader=eval_dataloader,
)
trainer.fit()
# Check the correct number of samples and batches have been processed
assert trainer.state.timestamp.sample.value == num_samples
assert trainer.state.timestamp.batch.value == math.ceil(num_samples / batch_size)
assert trainer.state.eval_timestamp.sample.value == num_samples
assert trainer.state.eval_timestamp.batch.value == math.ceil(num_samples / batch_size)


@world_size(1, 2)
@device('cpu', 'gpu', 'gpu-amp', precision=True)
class TestTrainerEquivalence():
Expand Down

0 comments on commit 8f82e2b

Please sign in to comment.