From 76f74b6968f5796dab47ed60df2ec08d9226689f Mon Sep 17 00:00:00 2001
From: Daniel King <43149077+dakinggg@users.noreply.github.com>
Date: Wed, 24 Apr 2024 14:19:17 -0700
Subject: [PATCH] First initialize dist with gloo (#1133)

---
 scripts/train/train.py | 64 ++++++++++++++++++++++++++++++------------
 1 file changed, 46 insertions(+), 18 deletions(-)

diff --git a/scripts/train/train.py b/scripts/train/train.py
index a49ae4e26d..cf0f34ebb5 100644
--- a/scripts/train/train.py
+++ b/scripts/train/train.py
@@ -10,6 +10,7 @@
 from typing import Any, Dict, List, Optional, Union
 
 import torch
+import torch.distributed
 from composer import Trainer
 from composer.core.callback import Callback
 from composer.profiler import (JSONTraceHandler, Profiler, TraceHandler,
@@ -114,6 +115,33 @@ def validate_config(cfg: DictConfig):
         )
 
 
+def _initialize_gloo_and_nccl(dist_timeout: Union[int, float]):
+    """Initialize GLOO process group (then destroyed) and device process group.
+
+    We have experienced an issue where the first barrier with NCCL does not timeout properly,
+    and can hang forever if something is wrong. To attempt to mitigate this, we will first
+    initialize with a gloo process group and test a barrier, then destroy the process group
+
+    Args:
+        dist_timeout (Union[int, float]): Timeout for initializing the process group
+    """
+    # First, initialize with a gloo process group and test a barrier
+    log.debug('Initializing dist with cpu...')
+    dist.initialize_dist('cpu', timeout=dist_timeout)
+    log.debug('Testing barrier with cpu...')
+    dist.barrier()
+    log.debug('Barrier test passed with cpu. Destroying process group...')
+    torch.distributed.destroy_process_group()
+    log.debug('Process group destroyed.')
+
+    # Now, initialize with the correct device
+    log.debug('Initializing dist with device...')
+    dist.initialize_dist(get_device(None), timeout=dist_timeout)
+    log.debug('Testing barrier with device...')
+    dist.barrier()
+    log.debug('Barrier test passed with device.')
+
+
 def main(cfg: DictConfig) -> Trainer:
     # Run user provided code if specified
     code_paths = pop_config(cfg,
@@ -170,7 +198,24 @@ def main(cfg: DictConfig) -> Trainer:
                                                  'dist_timeout',
                                                  must_exist=False,
                                                  default_value=600.0)
-    dist.initialize_dist(get_device(None), timeout=dist_timeout)
+    python_log_level: Optional[str] = pop_config(cfg,
+                                                 'python_log_level',
+                                                 must_exist=False,
+                                                 default_value='debug')
+    # Set logging level
+    if python_log_level is not None:
+        logging.basicConfig(
+            # Example of format string
+            # 2022-06-29 11:22:26,152: rank0[822018][MainThread]: INFO: Message here
+            format=
+            f'%(asctime)s: rank{dist.get_global_rank()}[%(process)d][%(threadName)s]: %(levelname)s: %(name)s: %(message)s'
+        )
+        logging.getLogger('llmfoundry').setLevel(
+            python_log_level.upper())  # Foundry module
+        logging.getLogger(__name__).setLevel(
+            python_log_level.upper())  # Train script
+
+    _initialize_gloo_and_nccl(dist_timeout=dist_timeout)
 
     # Get global and device batch size information from distributed/single node setting
     cfg = update_batch_size_info(cfg)
@@ -298,10 +343,6 @@ def main(cfg: DictConfig) -> Trainer:
                                       'log_to_console',
                                       must_exist=False,
                                       default_value=True)
-    python_log_level: Optional[str] = pop_config(cfg,
-                                                 'python_log_level',
-                                                 must_exist=False,
-                                                 default_value='debug')
     console_log_interval: Union[int, str] = pop_config(cfg,
                                                        'console_log_interval',
                                                        must_exist=False,
@@ -391,19 +432,6 @@ def main(cfg: DictConfig) -> Trainer:
             'FSDP is not applicable for single-GPU training. Reverting to DDP.')
         fsdp_config = None
 
-    # set logging level
-    if python_log_level is not None:
-        logging.basicConfig(
-            # Example of format string
-            # 2022-06-29 11:22:26,152: rank0[822018][MainThread]: INFO: Message here
-            format=
-            f'%(asctime)s: rank{dist.get_global_rank()}[%(process)d][%(threadName)s]: %(levelname)s: %(name)s: %(message)s'
-        )
-        logging.getLogger('llmfoundry').setLevel(
-            python_log_level.upper())  # Foundry module
-        logging.getLogger(__name__).setLevel(
-            python_log_level.upper())  # Train script
-
     # Initialize context
     init_context = process_init_device(model_config, fsdp_config)
     logged_cfg.update({'fsdp_config': fsdp_config}, merge=True)
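
For readers who want to try the same pattern outside of Composer, here is a minimal sketch using plain torch.distributed rather than the composer.utils.dist wrappers the patch relies on. It is illustrative only: the helper name _gloo_warmup_then_init is hypothetical (not part of llm-foundry), and it assumes a torchrun-style launch where RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT, and LOCAL_RANK are already set in the environment.

# Hedged sketch of the gloo-first initialization pattern from this patch,
# written against plain torch.distributed. Assumption: the process was
# launched with torchrun, so the env:// init method can discover peers.
import os
from datetime import timedelta

import torch
import torch.distributed as torch_dist


def _gloo_warmup_then_init(dist_timeout: float = 600.0) -> None:  # hypothetical helper
    timeout = timedelta(seconds=dist_timeout)

    # 1) Bring up a CPU-only gloo process group and test a barrier. If a rank
    #    is misconfigured, this barrier raises on timeout instead of hanging.
    torch_dist.init_process_group(backend='gloo', timeout=timeout)
    torch_dist.barrier()

    # 2) Tear the gloo group down once the barrier has succeeded.
    torch_dist.destroy_process_group()

    # 3) Re-initialize with the backend that training will actually use and
    #    barrier again on the real device.
    if torch.cuda.is_available():
        torch.cuda.set_device(int(os.environ.get('LOCAL_RANK', '0')))
        torch_dist.init_process_group(backend='nccl', timeout=timeout)
    else:
        torch_dist.init_process_group(backend='gloo', timeout=timeout)
    torch_dist.barrier()

In the patch itself, _initialize_gloo_and_nccl performs these same three steps through Composer's wrappers: dist.initialize_dist('cpu', ...) and dist.barrier(), then torch.distributed.destroy_process_group(), then dist.initialize_dist(get_device(None), ...) followed by a second barrier.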