From 8f82e2b0136195c6d48c4917dad1e7c8c77d7c1e Mon Sep 17 00:00:00 2001
From: Joe Early
Date: Fri, 21 Jun 2024 10:46:44 +0100
Subject: [PATCH] Add unit test

---
 tests/trainer/test_trainer.py | 36 +++++++++++++++++++++++++++++++++++
 1 file changed, 36 insertions(+)

diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 59e8b26782..0ee0fabb5d 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -1251,6 +1251,42 @@ def test_accumulate_time_across_ranks(
 
         assert batch_time_accum == datetime.timedelta(seconds=0.1 * (1 + 0))
 
+    @pytest.mark.gpu
+    @pytest.mark.world_size(2)
+    def test_rank_dependent_dataloader_lengths(
+        self,
+        model: ComposerModel,
+        max_duration: Time[int],
+    ):
+        # Change rank 1 dataloader size to create different sized dataloaders on each rank
+        batch_size = 2
+        num_samples = 10
+        if dist.get_local_rank() == 1:
+            num_samples += 4
+        # Create train and eval dataloaders (will have rank-dependent lengths)
+        train_dataset = RandomClassificationDataset(size=num_samples)
+        train_dataloader = DataLoader(
+            dataset=train_dataset, batch_size=batch_size, sampler=dist.get_sampler(train_dataset)
+        )
+        eval_dataset = RandomClassificationDataset(size=num_samples)
+        eval_dataloader = DataLoader(
+            dataset=eval_dataset, batch_size=batch_size, sampler=dist.get_sampler(eval_dataset)
+        )
+        # Fit (train + eval)
+        trainer = Trainer(
+            model=model,
+            max_duration=max_duration,
+            train_dataloader=train_dataloader,
+            eval_dataloader=eval_dataloader,
+        )
+        trainer.fit()
+        # Check the correct number of samples and batches have been processed
+        assert trainer.state.timestamp.sample.value == num_samples
+        assert trainer.state.timestamp.batch.value == math.ceil(num_samples / batch_size)
+        assert trainer.state.eval_timestamp.sample.value == num_samples
+        assert trainer.state.eval_timestamp.batch.value == math.ceil(num_samples / batch_size)
+
+
 @world_size(1, 2)
 @device('cpu', 'gpu', 'gpu-amp', precision=True)
 class TestTrainerEquivalence():
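
Note (illustrative, not part of the patch): the asserts in the new test reduce
to simple per-rank arithmetic. A minimal standalone sketch of that arithmetic
follows; the rank/sample pairs (rank 0 keeps 10 samples, rank 1 adds 4 for 14)
and batch_size=2 are taken directly from the test body, and nothing below
touches Composer's API.

    import math

    batch_size = 2
    # Per-rank dataset sizes from the test: rank 0 keeps 10, rank 1 adds 4.
    for rank, num_samples in enumerate((10, 14)):
        # ceil() allows for a trailing partial batch, though 10 and 14 both
        # divide evenly by batch_size=2 here (5 and 7 full batches).
        expected_batches = math.ceil(num_samples / batch_size)
        print(f'rank {rank}: samples={num_samples}, batches={expected_batches}')

Running the sketch prints 'rank 0: samples=10, batches=5' and 'rank 1:
samples=14, batches=7', i.e. the per-rank sample and batch counts that
test_rank_dependent_dataloader_lengths expects to find in
trainer.state.timestamp and trainer.state.eval_timestamp after fit().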