From 8bc2e1a33148c8691fbb9c6e981f7ceec9111b8b Mon Sep 17 00:00:00 2001
From: Shashank Rajput <144760128+ShashankMosaicML@users.noreply.github.com>
Date: Tue, 9 Apr 2024 20:40:54 -0400
Subject: [PATCH] Enabling the computation of validation loss and other
 metrics when using sequence parallelism (#3183)

* fix a bug in eval with seq parallelism
* print debug values
* ..
* ..
* ..
* potentially fixing the eval bug
* minor
* minor
* minor
* ..
* fixing is_sampler_distributed
* removing redundant condition
---
 composer/core/data_spec.py  | 2 +-
 composer/core/evaluator.py  | 4 +++-
 composer/trainer/trainer.py | 2 +-
 3 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/composer/core/data_spec.py b/composer/core/data_spec.py
index c0bdea5ea6..2289fa3ec8 100644
--- a/composer/core/data_spec.py
+++ b/composer/core/data_spec.py
@@ -221,7 +221,7 @@ def __init__(
         world_size = dist.get_world_size()
         # Check for Distributed Sampler if not using IterableDataset on more than 1 GPU
         if world_size > 1 and not isinstance(dataloader.dataset, torch.utils.data.IterableDataset):
-            is_sampler_distributed = dataloader.sampler and isinstance(dataloader.sampler, DistributedSampler)
+            is_sampler_distributed = isinstance(dataloader.sampler, DistributedSampler)
             is_batch_sampler_distributed = dataloader.batch_sampler is not None and isinstance(
                 dataloader.batch_sampler,
                 DistributedSampler,
diff --git a/composer/core/evaluator.py b/composer/core/evaluator.py
index 4a585543bc..4cf943acac 100644
--- a/composer/core/evaluator.py
+++ b/composer/core/evaluator.py
@@ -94,6 +94,8 @@ def __init__(
         self._eval_interval = None
         self.eval_interval = eval_interval
         self.auto_microbatching = _is_auto_microbatching(device_eval_microbatch_size)
+        if self.auto_microbatching and hasattr(self.dataloader, 'seq_parallel_world_size'):
+            raise ValueError('`device_eval_microbatch_size="auto"` is not compatible with sequence parallelism.')
         self.device_eval_microbatch_size = _get_initial_device_eval_microbatch_size(
             device_eval_microbatch_size,
             self.auto_microbatching,
@@ -177,7 +179,7 @@ def _get_initial_device_eval_microbatch_size(
                     ),
                 ) from e
         return batch_size
-    elif isinstance(device_eval_microbatch_size, int):
+    elif isinstance(device_eval_microbatch_size, Union[int, float]):
         return device_eval_microbatch_size
     else:
         raise ValueError("device_eval_microbatch_size must be an int or ``'auto'``")
diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py
index 4020f76559..3827533e51 100644
--- a/composer/trainer/trainer.py
+++ b/composer/trainer/trainer.py
@@ -409,7 +409,7 @@ def _validate_evaluator(evaluator: Evaluator, device: Device):
     if hasattr(
         evaluator.dataloader,
         'seq_parallel_world_size',
-    ) and evaluator.dataloader.seq_parallel_world_size > 1 and evaluator.dataloader.batch_size * evaluator.dataloader.seq_parallel_world_size != 1:  # type: ignore
+    ) and evaluator.dataloader.seq_parallel_world_size > 1 and evaluator.dataloader.device_eval_batch_size * evaluator.dataloader.seq_parallel_world_size != 1:  # type: ignore
        raise ValueError(
             'Sequence parallelism requires a microbatch size of 1 distributed over the sequence parallel group.',
         )
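
Editor's note (not part of the patch): a minimal, self-contained sketch of why the float path in
_get_initial_device_eval_microbatch_size and the check in trainer.py fit together. With sequence
parallelism, each rank in the sequence-parallel group holds only a fraction of one sample, so the
per-rank eval microbatch size becomes 1 / seq_parallel_world_size and the product must come back to
exactly 1. The helper name check_eval_microbatch_size is hypothetical, and the sketch uses the
equivalent tuple form isinstance(x, (int, float)) rather than Union.

from typing import Union

def check_eval_microbatch_size(device_eval_microbatch_size: Union[int, float, str]) -> Union[int, float]:
    # Mirrors the patched type check: floats are now accepted so that a single
    # sample can be split across the ranks of a sequence-parallel group.
    if device_eval_microbatch_size == 'auto':
        # The patch also forbids auto microbatching when the dataloader carries
        # a seq_parallel_world_size attribute.
        raise ValueError('`device_eval_microbatch_size="auto"` is not compatible with sequence parallelism.')
    elif isinstance(device_eval_microbatch_size, (int, float)):
        return device_eval_microbatch_size
    else:
        raise ValueError("device_eval_microbatch_size must be an int or 'auto'")

seq_parallel_world_size = 4                             # hypothetical sequence-parallel group size
device_eval_batch_size = 1 / seq_parallel_world_size    # 0.25: each rank sees a quarter of one sample
assert device_eval_batch_size * seq_parallel_world_size == 1  # the constraint enforced in trainer.py
print(check_eval_microbatch_size(device_eval_batch_size))     # 0.25 now passes the isinstance check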