From 7b25d90902b922613e490856bebe534b79dc8a22 Mon Sep 17 00:00:00 2001
From: Shashank Rajput <144760128+ShashankMosaicML@users.noreply.github.com>
Date: Fri, 19 Apr 2024 12:59:25 -0400
Subject: [PATCH] Fixing error condition check when device microbatch size
 times seq parallelism dim is not 1 due to floating point precision (#3200)

* ..

* ..

* lint

---
 composer/trainer/trainer.py | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py
index bf40bcddde..bf296408f7 100644
--- a/composer/trainer/trainer.py
+++ b/composer/trainer/trainer.py
@@ -409,7 +409,9 @@ def _validate_evaluator(evaluator: Evaluator, device: Device):
     if hasattr(
         evaluator.dataloader,
         'seq_parallel_world_size',
-    ) and evaluator.dataloader.seq_parallel_world_size > 1 and evaluator.dataloader.device_eval_batch_size * evaluator.dataloader.seq_parallel_world_size != 1:  # type: ignore
+    ) and evaluator.dataloader.seq_parallel_world_size > 1 and abs(  # type: ignore
+        evaluator.dataloader.device_eval_batch_size * evaluator.dataloader.seq_parallel_world_size - 1,  # type: ignore
+    ) > 1e-4:
         raise ValueError(
             'Sequence parallelism requires a microbatch size of 1 distributed over the sequence parallel group.',
         )
@@ -1126,7 +1128,9 @@ def __init__(
         if train_dataloader is not None and hasattr(
             train_dataloader,
             'seq_parallel_world_size',
-        ) and train_dataloader.seq_parallel_world_size > 1 and device_train_microbatch_size * train_dataloader.seq_parallel_world_size != 1:  # type: ignore
+        ) and train_dataloader.seq_parallel_world_size > 1 and abs(  # type: ignore
+            device_train_microbatch_size * train_dataloader.seq_parallel_world_size - 1,  # type: ignore
+        ) > 1e-4:
             raise ValueError(
                 '`Sequence parallelism requires a microbatch size of 1 distributed over the sequence parallel group.',
             )
@@ -2181,7 +2185,9 @@ def fit(
         if train_dataloader is not None and hasattr(
             train_dataloader,
             'seq_parallel_world_size',
-        ) and train_dataloader.seq_parallel_world_size > 1 and device_train_microbatch_size * train_dataloader.seq_parallel_world_size != 1:  # type: ignore
+        ) and train_dataloader.seq_parallel_world_size > 1 and abs(  # type: ignore
+            device_train_microbatch_size * train_dataloader.seq_parallel_world_size - 1,  # type: ignore
+        ) > 1e-4:
             raise ValueError(
                 '`Sequence parallelism requires a microbatch size of 1 distributed over the sequence parallel group.',
             )
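
Why the tolerance: with sequence parallelism a single sample is split across the sequence parallel group, so device_train_microbatch_size (and device_eval_batch_size) can be a fractional float. The product with seq_parallel_world_size that should equal 1 may then carry floating point rounding error, so the patch replaces the exact `!= 1` comparison with an absolute tolerance of 1e-4. Below is a minimal standalone sketch of that check; the function name and arguments are hypothetical stand-ins for the Trainer/dataloader attributes, not Composer's API.

# Minimal sketch of the tolerance-based check introduced above (illustration only;
# microbatch_size and seq_parallel_world_size are hypothetical stand-ins).
def validate_seq_parallel_microbatch(microbatch_size: float, seq_parallel_world_size: int) -> None:
    """Raise unless the microbatch distributed over the sequence parallel group is ~1."""
    if seq_parallel_world_size > 1 and abs(microbatch_size * seq_parallel_world_size - 1) > 1e-4:
        raise ValueError(
            'Sequence parallelism requires a microbatch size of 1 distributed over the sequence parallel group.',
        )

validate_seq_parallel_microbatch(0.25, 4)   # passes: 0.25 * 4 is exactly 1.0
validate_seq_parallel_microbatch(1 / 3, 3)  # passes: any rounding error is far below the 1e-4 tolerance
validate_seq_parallel_microbatch(0.5, 3)    # raises ValueError: product is 1.5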