Skip to content

Commit

Permalink
PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
dakinggg committed Apr 24, 2024
1 parent 1111ec7 commit 62b0907
Showing 1 changed file with 20 additions and 14 deletions.
34 changes: 20 additions & 14 deletions scripts/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,22 @@ def validate_config(cfg: DictConfig):
f'MoEs with expert parallelism (moe_world_size {moe_world_size} > 1) require `use_orig_params=True`.'
)

def _initialize_gloo_and_nccl(dist_timeout: Union[int, float]):
# First, initialize with a gloo process group and test a barrier
log.debug('Initializing dist with cpu...')
dist.initialize_dist('cpu', timeout=dist_timeout)
log.debug('Testing barrier with cpu...')
dist.barrier()
log.debug('Barrier test passed with cpu. Destroying process group...')
tdist.destroy_process_group()
log.debug('Process group destroyed.')

# Now, initialize with the correct device
log.debug('Initializing dist with device...')
dist.initialize_dist(get_device(None), timeout=dist_timeout)
log.debug('Testing barrier with device...')
dist.barrier()
log.debug('Barrier test passed with device.')

def main(cfg: DictConfig) -> Trainer:
# Run user provided code if specified
Expand Down Expand Up @@ -188,21 +204,11 @@ def main(cfg: DictConfig) -> Trainer:
logging.getLogger(__name__).setLevel(
python_log_level.upper()) # Train script

# First, initialize with a gloo process group and test a barrier
log.debug('Initializing dist with cpu...')
dist.initialize_dist('cpu', timeout=dist_timeout)
log.debug('Testing barrier with cpu...')
dist.barrier()
log.debug('Barrier test passed with cpu. Destroying process group...')
tdist.destroy_process_group()
log.debug('Process group destroyed.')
# We have experienced an issue where the first barrier with NCCL does not timeout properly,
# and can hang forever if something is wrong. To attempt to mitigate this, we will first
# initialize with a gloo process group and test a barrier, then destroy the process group

# Now, initialize with the correct device
log.debug('Initializing dist with device...')
dist.initialize_dist(get_device(None), timeout=dist_timeout)
log.debug('Testing barrier with device...')
dist.barrier()
log.debug('Barrier test passed with device.')
_initialize_gloo_and_nccl(dist_timeout=dist_timeout)

# Get global and device batch size information from distributed/single node setting
cfg = update_batch_size_info(cfg)
Expand Down

0 comments on commit 62b0907

Please sign in to comment.