Fixing error condition check when device microbatch size times seq parallelism dim is not 1 due to floating point precision (mosaicml#3200)

* ..

* ..

* lint
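
Background on the fix: per the check below, sequence parallelism requires a microbatch size of 1 distributed over the sequence parallel group, i.e. the per-device microbatch is the fraction 1/N for a group of size N, and multiplying it back by N should recover 1. Because 1/N is generally not exactly representable in binary floating point, the old exact `!= 1` comparison could raise spuriously; the change compares the product to 1 within a small tolerance (1e-4). A minimal sketch of the pattern, using a hypothetical helper name that is not part of the Composer API:

```python
# Illustrative sketch of the tolerance-based check adopted in this commit.
# `_check_seq_parallel_microbatch` is a hypothetical helper, not a function
# that exists in composer/trainer/trainer.py.

def _check_seq_parallel_microbatch(
    device_microbatch_size: float,
    seq_parallel_world_size: int,
    tol: float = 1e-4,
) -> None:
    """Raise if the microbatch distributed over the sequence parallel group is not ~1."""
    if seq_parallel_world_size <= 1:
        return  # sequence parallelism is not in use; nothing to validate

    product = device_microbatch_size * seq_parallel_world_size
    # Exact equality (`product != 1`) can misfire: 1 / seq_parallel_world_size is
    # usually not exactly representable as a binary float, so multiplying it back
    # by the world size may yield a value a hair away from 1.0.
    if abs(product - 1) > tol:
        raise ValueError(
            'Sequence parallelism requires a microbatch size of 1 distributed '
            'over the sequence parallel group.',
        )
```

For some group sizes, `(1 / N) * N` evaluates to a float just below 1.0, which the previous exact comparison would reject even though the configuration is valid; the tolerance-based comparison accepts it.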
ShashankMosaicML committed Apr 19, 2024
1 parent 960654d commit 7b25d90
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions composer/trainer/trainer.py
@@ -409,7 +409,9 @@ def _validate_evaluator(evaluator: Evaluator, device: Device):
     if hasattr(
         evaluator.dataloader,
         'seq_parallel_world_size',
-    ) and evaluator.dataloader.seq_parallel_world_size > 1 and evaluator.dataloader.device_eval_batch_size * evaluator.dataloader.seq_parallel_world_size != 1:  # type: ignore
+    ) and evaluator.dataloader.seq_parallel_world_size > 1 and abs(  # type: ignore
+        evaluator.dataloader.device_eval_batch_size * evaluator.dataloader.seq_parallel_world_size - 1,  # type: ignore
+    ) > 1e-4:
         raise ValueError(
             'Sequence parallelism requires a microbatch size of 1 distributed over the sequence parallel group.',
         )
@@ -1126,7 +1128,9 @@ def __init__(
         if train_dataloader is not None and hasattr(
             train_dataloader,
             'seq_parallel_world_size',
-        ) and train_dataloader.seq_parallel_world_size > 1 and device_train_microbatch_size * train_dataloader.seq_parallel_world_size != 1:  # type: ignore
+        ) and train_dataloader.seq_parallel_world_size > 1 and abs(  # type: ignore
+            device_train_microbatch_size * train_dataloader.seq_parallel_world_size - 1,  # type: ignore
+        ) > 1e-4:
             raise ValueError(
                 '`Sequence parallelism requires a microbatch size of 1 distributed over the sequence parallel group.',
             )
@@ -2181,7 +2185,9 @@ def fit(
         if train_dataloader is not None and hasattr(
             train_dataloader,
             'seq_parallel_world_size',
-        ) and train_dataloader.seq_parallel_world_size > 1 and device_train_microbatch_size * train_dataloader.seq_parallel_world_size != 1:  # type: ignore
+        ) and train_dataloader.seq_parallel_world_size > 1 and abs(  # type: ignore
+            device_train_microbatch_size * train_dataloader.seq_parallel_world_size - 1,  # type: ignore
+        ) > 1e-4:
             raise ValueError(
                 '`Sequence parallelism requires a microbatch size of 1 distributed over the sequence parallel group.',
             )
