diff --git a/composer/utils/checkpoint.py b/composer/utils/checkpoint.py
index ba92c8e99b..c7ef4be43a 100644
--- a/composer/utils/checkpoint.py
+++ b/composer/utils/checkpoint.py
@@ -1156,7 +1156,7 @@ def _save_checkpoint(
         ignore_keys(state_dict)
         # Ensure state exists
         state_dict['state'] = state_dict.get('state', {})
-    ic('see')
+
     if state.fsdp_sharded_state_dict_enabled and not weights_only:
         # Only rank 0 saves RNG
         if dist.get_global_rank() > 0:
diff --git a/composer/utils/dist.py b/composer/utils/dist.py
index b357c75ee2..3c984d4181 100644
--- a/composer/utils/dist.py
+++ b/composer/utils/dist.py
@@ -461,7 +461,7 @@ def all_gather_object(obj: TObj, group=None) -> list[TObj]:
             all_gather_object_list_hpu(obj_gather_list, obj, group=group)
         else:
             ic('before all_gather_object')
-            ic(obj_gather_list, obj, group)
+            ic(obj_gather_list, obj.keys(), group)
             dist.all_gather_object(obj_gather_list, obj, group=group)
             ic('after all_gather_object')
         # torch.distributed will replace the None's in obj_gather_list with the gathered objects on rank 0
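
For context on the second hunk: `ic()` here is presumably the icecream debug helper, and the change logs only `obj.keys()` rather than the full object so the debug output is not flooded with an entire state dict before the collective. Below is a minimal sketch of that pattern outside Composer; the `gather_with_key_logging` helper is illustrative and not part of this PR.

```python
# Sketch only: assumes torch.distributed is already initialized
# (e.g. via dist.init_process_group) and that icecream is installed.
import torch.distributed as dist
from icecream import ic


def gather_with_key_logging(obj: dict) -> list:
    """Gather `obj` from every rank, logging only its keys to keep debug output small."""
    obj_gather_list = [None for _ in range(dist.get_world_size())]
    ic(obj.keys())  # keys only; ic(obj) would dump the whole (possibly huge) payload
    dist.all_gather_object(obj_gather_list, obj)
    return obj_gather_list
```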