Commit

Enabling the computation of validation loss and other metrics when using sequence parallelism (#3183)

* fix a bug in eval with seq parallelism

* print debug values

* ..

* ..

* ..

* potentially fixing the eval bug

* minor

* minor

* minor

* ..

* fixing is_sampler_distributed

* removing redundant condition
ShashankMosaicML authored and Chuck Tang committed May 16, 2024
1 parent af1b353 commit 8bc2e1a
Showing 3 changed files with 5 additions and 3 deletions.
composer/core/data_spec.py (2 changes: 1 addition & 1 deletion)
@@ -221,7 +221,7 @@ def __init__(
         world_size = dist.get_world_size()
         # Check for Distributed Sampler if not using IterableDataset on more than 1 GPU
         if world_size > 1 and not isinstance(dataloader.dataset, torch.utils.data.IterableDataset):
-            is_sampler_distributed = dataloader.sampler and isinstance(dataloader.sampler, DistributedSampler)
+            is_sampler_distributed = isinstance(dataloader.sampler, DistributedSampler)
             is_batch_sampler_distributed = dataloader.batch_sampler is not None and isinstance(
                 dataloader.batch_sampler,
                 DistributedSampler,
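
The dropped `dataloader.sampler and` guard was redundant: isinstance already returns False when the sampler is missing or of another type, and the short-circuiting form could leave a non-boolean value in is_sampler_distributed. A minimal standalone sketch (illustrative only, not Composer code):

from torch.utils.data import DistributedSampler

sampler = None  # hypothetical "no sampler" case

# Old expression: short-circuits and evaluates to None rather than a bool.
old_value = sampler and isinstance(sampler, DistributedSampler)

# New expression: isinstance already returns False for None, so the extra
# truthiness guard added nothing and the result is always a strict bool.
new_value = isinstance(sampler, DistributedSampler)

print(old_value, new_value)  # None False
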
composer/core/evaluator.py (4 changes: 3 additions & 1 deletion)
@@ -94,6 +94,8 @@ def __init__(
         self._eval_interval = None
         self.eval_interval = eval_interval
         self.auto_microbatching = _is_auto_microbatching(device_eval_microbatch_size)
+        if self.auto_microbatching and hasattr(self.dataloader, 'seq_parallel_world_size'):
+            raise ValueError('`device_eval_microbatch_size="auto"` is not compatible with sequence parallelism.')
         self.device_eval_microbatch_size = _get_initial_device_eval_microbatch_size(
             device_eval_microbatch_size,
             self.auto_microbatching,
@@ -177,7 +179,7 @@ def _get_initial_device_eval_microbatch_size(
                 ),
             ) from e
         return batch_size
-    elif isinstance(device_eval_microbatch_size, int):
+    elif isinstance(device_eval_microbatch_size, Union[int, float]):
         return device_eval_microbatch_size
     else:
         raise ValueError("device_eval_microbatch_size must be an int or ``'auto'``")
composer/trainer/trainer.py (2 changes: 1 addition & 1 deletion)
@@ -409,7 +409,7 @@ def _validate_evaluator(evaluator: Evaluator, device: Device):
     if hasattr(
         evaluator.dataloader,
         'seq_parallel_world_size',
-    ) and evaluator.dataloader.seq_parallel_world_size > 1 and evaluator.dataloader.batch_size * evaluator.dataloader.seq_parallel_world_size != 1: # type: ignore
+    ) and evaluator.dataloader.seq_parallel_world_size > 1 and evaluator.dataloader.device_eval_batch_size * evaluator.dataloader.seq_parallel_world_size != 1: # type: ignore
         raise ValueError(
             'Sequence parallelism requires a microbatch size of 1 distributed over the sequence parallel group.',
         )
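
The check now reads device_eval_batch_size from the evaluator's dataloader instead of its raw batch_size. The constraint it enforces is that exactly one sequence is spread across the whole sequence parallel group, which the fractional microbatch sizes above make possible. A small worked sketch of the arithmetic (values are illustrative):

seq_parallel_world_size = 4    # hypothetical: each sequence is split across 4 ranks
device_eval_batch_size = 0.25  # so each rank holds a quarter of one sequence

# The validator above accepts the configuration only when the product is exactly 1,
# i.e. one full sequence per sequence parallel group.
assert device_eval_batch_size * seq_parallel_world_size == 1
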
