From 05e568bccd9058f855e279417725609236788e34 Mon Sep 17 00:00:00 2001 From: Vitaliy Chiley <6439018+vchiley@users.noreply.github.com> Date: Fri, 12 Apr 2024 09:48:24 -0700 Subject: [PATCH] Revert "Update mosaic_fsdp_utils.py (#3185)" (#3187) This reverts commit 2a262b4e469e306a792f127705577f13425350ee. --- composer/trainer/mosaic_fsdp_utils.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/composer/trainer/mosaic_fsdp_utils.py b/composer/trainer/mosaic_fsdp_utils.py index df60a34c7f..279b1434f4 100644 --- a/composer/trainer/mosaic_fsdp_utils.py +++ b/composer/trainer/mosaic_fsdp_utils.py @@ -177,11 +177,16 @@ def _get_process_group(pg, process_group_cache=None): return process_group_cache[ranks] log.info(f'Instantiating custom process groups with {ranks=} on rank={dist.get_global_rank()}.') - process_group = distributed.new_group(ranks) + + ranks_per_subgroup_list = list(set(dist.all_gather_object(ranks))) + ( + current_group, + _subgroups, + ) = distributed.distributed_c10d.new_subgroups_by_enumeration(ranks_per_subgroup_list) if process_group_cache is not None: - process_group_cache[ranks] = process_group - return process_group + process_group_cache[ranks] = current_group + return current_group def _set_custom_fsdp_module_kwargs(module_kwargs: Dict, process_group_cache: Dict[Tuple[int], Any]) -> Dict: