diff --git a/composer/trainer/mosaic_fsdp.py b/composer/trainer/mosaic_fsdp.py index f78038b307..c6ceebb72f 100644 --- a/composer/trainer/mosaic_fsdp.py +++ b/composer/trainer/mosaic_fsdp.py @@ -63,7 +63,13 @@ def patch_pytorch(): elif version.parse(torch.__version__) < version.parse('2.2.9'): # Monkey patch for torch < 2.3.0 ie torch == 2.2.1/2.2.2 currently - pass + + # Fix memory leak for FSDP.optim_state_dict_to_load + # https://github.com/pytorch/pytorch/issues/116553 + from torch.distributed.fsdp import _optim_utils + + from composer.trainer.mosaic_fsdp_utils import _shard_orig_param_state + _optim_utils._shard_orig_param_state = _shard_orig_param_state elif version.parse(torch.__version__) < version.parse('2.3.1'): # Monkey patch for torch < 2.3.1 ie torch == 2.3.0 diff --git a/composer/trainer/mosaic_fsdp_utils.py b/composer/trainer/mosaic_fsdp_utils.py index 38458f2227..df60a34c7f 100644 --- a/composer/trainer/mosaic_fsdp_utils.py +++ b/composer/trainer/mosaic_fsdp_utils.py @@ -1119,3 +1119,42 @@ def _same_storage(a, b): if isinstance(b, DTensor): b = b._local_tensor return a.untyped_storage().data_ptr() == b.untyped_storage().data_ptr() + +if version.parse(torch.__version__) >= version.parse('2.2.1') and version.parse( + torch.__version__,) < version.parse('2.2.9'): + + from torch.distributed.fsdp._optim_utils import FSDPParamInfo + from torch.distributed.checkpoint._state_dict_utils import _gather_state_dict + + @no_type_check + def _shard_orig_param_state( + fsdp_param_info: FSDPParamInfo, + fqn: str, + optim_state: Dict[str, Any], + ) -> Dict[str, Any]: + if not optim_state: + return {} + fsdp_state = fsdp_param_info.state + flat_param = fsdp_param_info.handle.flat_param + param_idx = fsdp_param_info.param_indices[fqn] + shard_param_info = flat_param._shard_param_infos[param_idx] # type: ignore[attr-defined] + optim_state = _gather_state_dict( + optim_state, pg=fsdp_state.process_group, device=fsdp_state.compute_device, + ) + if not shard_param_info.in_shard: + return {} + # Flatten and shard the state. + new_optim_state: Dict[str, Any] = {} + intra_param_start_idx = shard_param_info.intra_param_start_idx + intra_param_end_idx = shard_param_info.intra_param_end_idx + for state_name, value in optim_state.items(): + if ( + torch.is_tensor(value) + and value.dim() > 0 + and fsdp_state.sharding_strategy != ShardingStrategy.NO_SHARD + ): + # This clone() is the patch to fix the OOM + # https://github.com/pytorch/pytorch/pull/117261 + value = value.flatten()[intra_param_start_idx : intra_param_end_idx + 1].clone() # type: ignore[operator] + new_optim_state[state_name] = value + return new_optim_state