Merge branch 'main' into milo/foundry-type-cleanup
milocress committed Apr 24, 2024
2 parents bf21b14 + 76f74b6 commit c16e359
Showing 1 changed file with 42 additions and 14 deletions.
scripts/train/train.py
@@ -9,6 +9,7 @@
from typing import Any, Dict, List, Optional, Union

import torch
+import torch.distributed
from composer import ComposerModel, Trainer
from composer.core.callback import Callback
from composer.profiler import (JSONTraceHandler, Profiler, TraceHandler,
@@ -143,6 +144,33 @@ def _log_num_params(model: ComposerModel, logged_cfg: Dict[str, Any]):
    })


+def _initialize_gloo_and_nccl(dist_timeout: Union[int, float]):
+    """Initialize a GLOO process group (then destroy it) and the device process group.
+
+    We have experienced an issue where the first barrier with NCCL does not time
+    out properly and can hang forever if something is wrong. To mitigate this, we
+    first initialize a GLOO process group and test a barrier, then destroy that
+    process group before initializing the device process group.
+
+    Args:
+        dist_timeout (Union[int, float]): Timeout for initializing the process group.
+    """
+    # First, initialize with a gloo process group and test a barrier
+    log.debug('Initializing dist with cpu...')
+    dist.initialize_dist('cpu', timeout=dist_timeout)
+    log.debug('Testing barrier with cpu...')
+    dist.barrier()
+    log.debug('Barrier test passed with cpu. Destroying process group...')
+    torch.distributed.destroy_process_group()
+    log.debug('Process group destroyed.')
+
+    # Now, initialize with the correct device
+    log.debug('Initializing dist with device...')
+    dist.initialize_dist(get_device(None), timeout=dist_timeout)
+    log.debug('Testing barrier with device...')
+    dist.barrier()
+    log.debug('Barrier test passed with device.')


def main(cfg: DictConfig) -> Trainer:
    code_paths = cfg.get('code_paths', [])
    # Import any user provided code
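
The helper added in this hunk exists because a hung first NCCL barrier can wait forever, while a GLOO (CPU) barrier respects its timeout and fails fast. For readers outside Composer's dist utilities, the same pattern can be sketched with raw torch.distributed; the function name, the 600-second default, and the torchrun-style env:// rendezvous below are illustrative assumptions, not part of this commit.

import datetime
import os

import torch
import torch.distributed as torch_dist


def init_with_gloo_check(timeout_seconds: float = 600.0) -> None:
    timeout = datetime.timedelta(seconds=timeout_seconds)
    # GLOO runs over TCP on the CPU, so a misconfigured rank shows up here as a
    # timeout error rather than as a silently hung NCCL collective.
    torch_dist.init_process_group(backend='gloo', timeout=timeout)
    torch_dist.barrier()
    torch_dist.destroy_process_group()
    # Connectivity is proven; build the NCCL group actually used for training.
    torch.cuda.set_device(int(os.environ['LOCAL_RANK']))
    torch_dist.init_process_group(backend='nccl', timeout=timeout)
    torch_dist.barrier()
    # Launched via something like: torchrun --nproc_per_node=8 this_script.py
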
@@ -191,7 +219,20 @@ def main(cfg: DictConfig) -> Trainer:

    # Initialize pytorch distributed training process groups
    dist_timeout: Union[int, float] = train_cfg.dist_timeout
-    dist.initialize_dist(get_device(None), timeout=dist_timeout)

+    # Set logging level
+    logging.basicConfig(
+        # Example of format string
+        # 2022-06-29 11:22:26,152: rank0[822018][MainThread]: INFO: Message here
+        format=
+        f'%(asctime)s: rank{dist.get_global_rank()}[%(process)d][%(threadName)s]: %(levelname)s: %(name)s: %(message)s'
+    )
+    logging.getLogger('llmfoundry').setLevel(
+        python_log_level.upper())  # Foundry module
+    logging.getLogger(__name__).setLevel(
+        python_log_level.upper())  # Train script

+    _initialize_gloo_and_nccl(dist_timeout=dist_timeout)

    # Mandatory model training configs
    model_config = train_cfg.model
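
The format string added above is what produces lines like the "2022-06-29 11:22:26,152: rank0[822018][MainThread]: INFO: Message here" example in the comment. A standalone illustration, with rank 0 and an INFO level hard-coded in place of dist.get_global_rank() and python_log_level (assumptions for a single-process run, not part of this diff):

import logging

rank = 0  # stand-in for composer.utils.dist.get_global_rank()
log_level = 'INFO'  # stand-in for python_log_level.upper()

logging.basicConfig(
    format=(f'%(asctime)s: rank{rank}[%(process)d][%(threadName)s]: '
            '%(levelname)s: %(name)s: %(message)s'))
logging.getLogger('llmfoundry').setLevel(log_level)  # Foundry module
logging.getLogger(__name__).setLevel(log_level)  # this script

logging.getLogger(__name__).info('Message here')
# Prints something like:
# 2024-04-24 11:22:26,152: rank0[822018][MainThread]: INFO: __main__: Message here

Note that the rank is interpolated into the format string once, via the f-string, at configuration time; it is not looked up per log record.
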
@@ -230,19 +271,6 @@ def main(cfg: DictConfig) -> Trainer:
            'FSDP is not applicable for single-GPU training. Reverting to DDP.')
        fsdp_config = None

-    # set logging level
-    if train_cfg.python_log_level is not None:
-        logging.basicConfig(
-            # Example of format string
-            # 2022-06-29 11:22:26,152: rank0[822018][MainThread]: INFO: Message here
-            format=
-            f'%(asctime)s: rank{dist.get_global_rank()}[%(process)d][%(threadName)s]: %(levelname)s: %(name)s: %(message)s'
-        )
-        logging.getLogger('llmfoundry').setLevel(
-            train_cfg.python_log_level.upper())  # Foundry module
-        logging.getLogger(__name__).setLevel(
-            train_cfg.python_log_level.upper())  # Train script

    # Initialize context
    init_context = process_init_device(model_config, fsdp_config)
    logged_cfg.update({'fsdp_config': fsdp_config}, merge=True)
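
The two context lines near the top of this hunk are the tail of a single-GPU fallback whose condition sits outside the hunk. A sketch of that guard, with the surrounding if-condition and the example fsdp_config value assumed rather than taken from this diff:

import warnings
from typing import Any, Dict, Optional

from composer.utils import dist

fsdp_config: Optional[Dict[str, Any]] = {'sharding_strategy': 'FULL_SHARD'}  # illustrative value

if dist.get_world_size() == 1 and fsdp_config is not None:
    # Sharding across one GPU buys nothing, so drop the FSDP config entirely.
    warnings.warn(
        'FSDP is not applicable for single-GPU training. Reverting to DDP.')
    fsdp_config = None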
