diff --git a/scripts/train/train.py b/scripts/train/train.py index 34c16dfea1..66e369cad0 100644 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -114,6 +114,22 @@ def validate_config(cfg: DictConfig): f'MoEs with expert parallelism (moe_world_size {moe_world_size} > 1) require `use_orig_params=True`.' ) +def _initialize_gloo_and_nccl(dist_timeout: Union[int, float]): + # First, initialize with a gloo process group and test a barrier + log.debug('Initializing dist with cpu...') + dist.initialize_dist('cpu', timeout=dist_timeout) + log.debug('Testing barrier with cpu...') + dist.barrier() + log.debug('Barrier test passed with cpu. Destroying process group...') + tdist.destroy_process_group() + log.debug('Process group destroyed.') + + # Now, initialize with the correct device + log.debug('Initializing dist with device...') + dist.initialize_dist(get_device(None), timeout=dist_timeout) + log.debug('Testing barrier with device...') + dist.barrier() + log.debug('Barrier test passed with device.') def main(cfg: DictConfig) -> Trainer: # Run user provided code if specified @@ -188,21 +204,11 @@ def main(cfg: DictConfig) -> Trainer: logging.getLogger(__name__).setLevel( python_log_level.upper()) # Train script - # First, initialize with a gloo process group and test a barrier - log.debug('Initializing dist with cpu...') - dist.initialize_dist('cpu', timeout=dist_timeout) - log.debug('Testing barrier with cpu...') - dist.barrier() - log.debug('Barrier test passed with cpu. Destroying process group...') - tdist.destroy_process_group() - log.debug('Process group destroyed.') + # We have experienced an issue where the first barrier with NCCL does not timeout properly, + # and can hang forever if something is wrong. To attempt to mitigate this, we will first + # initialize with a gloo process group and test a barrier, then destroy the process group - # Now, initialize with the correct device - log.debug('Initializing dist with device...') - dist.initialize_dist(get_device(None), timeout=dist_timeout) - log.debug('Testing barrier with device...') - dist.barrier() - log.debug('Barrier test passed with device.') + _initialize_gloo_and_nccl(dist_timeout=dist_timeout) # Get global and device batch size information from distributed/single node setting cfg = update_batch_size_info(cfg)