diff --git a/composer/trainer/mosaic_fsdp_utils.py b/composer/trainer/mosaic_fsdp_utils.py index df60a34c7f..279b1434f4 100644 --- a/composer/trainer/mosaic_fsdp_utils.py +++ b/composer/trainer/mosaic_fsdp_utils.py @@ -177,11 +177,16 @@ def _get_process_group(pg, process_group_cache=None): return process_group_cache[ranks] log.info(f'Instantiating custom process groups with {ranks=} on rank={dist.get_global_rank()}.') - process_group = distributed.new_group(ranks) + + ranks_per_subgroup_list = list(set(dist.all_gather_object(ranks))) + ( + current_group, + _subgroups, + ) = distributed.distributed_c10d.new_subgroups_by_enumeration(ranks_per_subgroup_list) if process_group_cache is not None: - process_group_cache[ranks] = process_group - return process_group + process_group_cache[ranks] = current_group + return current_group def _set_custom_fsdp_module_kwargs(module_kwargs: Dict, process_group_cache: Dict[Tuple[int], Any]) -> Dict: