Commit

Fix the FSDP.optim_state_dict_to_load OOM (#3184)
* up

* up

* up

* a

* a

* up

* up

* comments

* up

* lint

* line
bigning authored Apr 10, 2024
1 parent 2a262b4 commit 52776a7
Showing 2 changed files with 46 additions and 1 deletion.
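
For context, the patched code path runs when training resumes from a checkpoint and the saved optimizer state must be re-sharded for each rank. The sketch below shows a typical call site; it is illustrative only and not part of this commit, and `model`, `optimizer`, and `full_osd` are hypothetical placeholders (assuming the torch 2.2 signature of FSDP.optim_state_dict_to_load).

# Hypothetical call site (not part of this commit) that exercises the patched
# path: FSDP.optim_state_dict_to_load re-shards a gathered optimizer state
# dict, walking each parameter through _optim_utils._shard_orig_param_state.
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def load_optimizer_checkpoint(model: FSDP, optimizer: torch.optim.Optimizer, full_osd: dict) -> None:
    sharded_osd = FSDP.optim_state_dict_to_load(
        model=model,
        optim=optimizer,
        optim_state_dict=full_osd,
    )
    optimizer.load_state_dict(sharded_osd)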
composer/trainer/mosaic_fsdp.py: 7 additions, 1 deletion
@@ -63,7 +63,13 @@ def patch_pytorch():
 
     elif version.parse(torch.__version__) < version.parse('2.2.9'):
         # Monkey patch for torch < 2.3.0 ie torch == 2.2.1/2.2.2 currently
-        pass
+
+        # Fix memory leak for FSDP.optim_state_dict_to_load
+        # https://github.com/pytorch/pytorch/issues/116553
+        from torch.distributed.fsdp import _optim_utils
+
+        from composer.trainer.mosaic_fsdp_utils import _shard_orig_param_state
+        _optim_utils._shard_orig_param_state = _shard_orig_param_state
 
     elif version.parse(torch.__version__) < version.parse('2.3.1'):
         # Monkey patch for torch < 2.3.1 ie torch == 2.3.0
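Once patch_pytorch() has run under torch 2.2.1/2.2.2, the swap above can be confirmed directly; the snippet below is a hypothetical sanity check, not part of this commit.

# Hypothetical sanity check (not part of this commit): after patch_pytorch()
# runs under torch 2.2.1/2.2.2, the private PyTorch helper should point at
# Composer's replacement defined in mosaic_fsdp_utils.py below.
from torch.distributed.fsdp import _optim_utils

from composer.trainer import mosaic_fsdp, mosaic_fsdp_utils

mosaic_fsdp.patch_pytorch()
assert _optim_utils._shard_orig_param_state is mosaic_fsdp_utils._shard_orig_param_state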
composer/trainer/mosaic_fsdp_utils.py: 39 additions, 0 deletions
@@ -1119,3 +1119,42 @@ def _same_storage(a, b):
         if isinstance(b, DTensor):
             b = b._local_tensor
         return a.untyped_storage().data_ptr() == b.untyped_storage().data_ptr()
+
+if version.parse(torch.__version__) >= version.parse('2.2.1') and version.parse(
+        torch.__version__,) < version.parse('2.2.9'):
+
+    from torch.distributed.fsdp._optim_utils import FSDPParamInfo
+    from torch.distributed.checkpoint._state_dict_utils import _gather_state_dict
+
+    @no_type_check
+    def _shard_orig_param_state(
+        fsdp_param_info: FSDPParamInfo,
+        fqn: str,
+        optim_state: Dict[str, Any],
+    ) -> Dict[str, Any]:
+        if not optim_state:
+            return {}
+        fsdp_state = fsdp_param_info.state
+        flat_param = fsdp_param_info.handle.flat_param
+        param_idx = fsdp_param_info.param_indices[fqn]
+        shard_param_info = flat_param._shard_param_infos[param_idx]  # type: ignore[attr-defined]
+        optim_state = _gather_state_dict(
+            optim_state, pg=fsdp_state.process_group, device=fsdp_state.compute_device,
+        )
+        if not shard_param_info.in_shard:
+            return {}
+        # Flatten and shard the state.
+        new_optim_state: Dict[str, Any] = {}
+        intra_param_start_idx = shard_param_info.intra_param_start_idx
+        intra_param_end_idx = shard_param_info.intra_param_end_idx
+        for state_name, value in optim_state.items():
+            if (
+                torch.is_tensor(value)
+                and value.dim() > 0
+                and fsdp_state.sharding_strategy != ShardingStrategy.NO_SHARD
+            ):
+                # This clone() is the patch to fix the OOM
+                # https://github.com/pytorch/pytorch/pull/117261
+                value = value.flatten()[intra_param_start_idx : intra_param_end_idx + 1].clone()  # type: ignore[operator]
+            new_optim_state[state_name] = value
+        return new_optim_state
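
The added clone() is the whole fix: slicing the flattened, gathered tensor yields a view that keeps the entire gathered storage alive on every rank, whereas clone() copies only the local shard so the large buffer can be freed. Below is a standalone illustration of that storage behavior; it is not code from this commit and the sizes are arbitrary.

# Standalone illustration (not from this commit): a slice of a flattened tensor
# is a view, so it pins the entire underlying storage; clone() copies only the
# shard into its own, much smaller storage.
import torch

full = torch.zeros(1_000_000)          # stands in for the gathered optimizer state
shard_view = full.flatten()[:1_000]    # view: shares storage with `full`
shard_copy = shard_view.clone()        # owns fresh storage sized to the shard

print(shard_view.untyped_storage().nbytes())  # 4_000_000 bytes: the whole gathered tensor
print(shard_copy.untyped_storage().nbytes())  # 4_000 bytes: just the shard

# Keeping `shard_view` in the re-sharded state dict would keep all 4 MB alive
# on every rank even after `full` is dropped; keeping `shard_copy` does not.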
