From baf9908a4bed67011caa3c2283f3c75f0bd5cf6f Mon Sep 17 00:00:00 2001
From: Vitaliy Chiley <6439018+vchiley@users.noreply.github.com>
Date: Wed, 10 Apr 2024 12:14:59 -0700
Subject: [PATCH] Update mosaic_fsdp_utils.py (#3185)

Porting https://github.com/mosaicml/llm-foundry/pull/1104 to composer.
---
 composer/trainer/mosaic_fsdp_utils.py | 11 +++--------
 1 file changed, 3 insertions(+), 8 deletions(-)

diff --git a/composer/trainer/mosaic_fsdp_utils.py b/composer/trainer/mosaic_fsdp_utils.py
index 365f474d4b..38458f2227 100644
--- a/composer/trainer/mosaic_fsdp_utils.py
+++ b/composer/trainer/mosaic_fsdp_utils.py
@@ -177,16 +177,11 @@ def _get_process_group(pg, process_group_cache=None):
         return process_group_cache[ranks]
 
     log.info(f'Instantiating custom process groups with {ranks=} on rank={dist.get_global_rank()}.')
-
-    ranks_per_subgroup_list = list(set(dist.all_gather_object(ranks)))
-    (
-        current_group,
-        _subgroups,
-    ) = distributed.distributed_c10d.new_subgroups_by_enumeration(ranks_per_subgroup_list)
+    process_group = distributed.new_group(ranks)
 
     if process_group_cache is not None:
-        process_group_cache[ranks] = current_group
-    return current_group
+        process_group_cache[ranks] = process_group
+    return process_group
 
 
 def _set_custom_fsdp_module_kwargs(module_kwargs: Dict, process_group_cache: Dict[Tuple[int], Any]) -> Dict: