Fix the FSDP.optim_state_dict_to_load OOM #3184

Merged 12 commits on Apr 10, 2024
Changes from 9 commits
8 changes: 7 additions & 1 deletion composer/trainer/mosaic_fsdp.py
@@ -63,7 +63,13 @@ def patch_pytorch():

    elif version.parse(torch.__version__) < version.parse('2.2.9'):
        # Monkey patch for torch < 2.3.0 ie torch == 2.2.1/2.2.2 currently
-       pass
+
+       # Fix memory leak for FSDP.optim_state_dict_to_load
+       # https://github.com/pytorch/pytorch/issues/116553
+       from torch.distributed.fsdp import _optim_utils
+
+       from composer.trainer.mosaic_fsdp_utils import _shard_orig_param_state
+       _optim_utils._shard_orig_param_state = _shard_orig_param_state

    elif version.parse(torch.__version__) < version.parse('2.3.1'):
        # Monkey patch for torch < 2.3.1 ie torch == 2.3.0
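Editor's note on why this one-line reassignment is sufficient: FSDP.optim_state_dict_to_load resolves _shard_orig_param_state through the _optim_utils module attribute at call time, so swapping the attribute redirects the call without touching torch's own code. The sketch below is an illustration of that mechanism only; the module and function names here are toy stand-ins, not torch's real internals, and it is not part of the PR.

import types

# Toy stand-in for torch.distributed.fsdp._optim_utils and its helper.
_optim_utils = types.SimpleNamespace()
_optim_utils._shard_orig_param_state = lambda: 'original: the slice keeps the gathered tensor alive'

def optim_state_dict_to_load():
    # The caller looks the helper up on the module object at call time,
    # so a later reassignment is picked up without changing this code.
    return _optim_utils._shard_orig_param_state()

def _shard_orig_param_state_patched():
    return 'patched: clone() lets the gathered tensor be freed'

_optim_utils._shard_orig_param_state = _shard_orig_param_state_patched  # the monkey patch
print(optim_state_dict_to_load())  # -> 'patched: clone() lets the gathered tensor be freed'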
43 changes: 43 additions & 0 deletions composer/trainer/mosaic_fsdp_utils.py
@@ -1124,3 +1124,46 @@ def _same_storage(a, b):
    if isinstance(b, DTensor):
        b = b._local_tensor
    return a.untyped_storage().data_ptr() == b.untyped_storage().data_ptr()

if version.parse(torch.__version__) >= version.parse('2.2.1') and version.parse(
        torch.__version__,) < version.parse('2.2.9'):

    from torch.distributed.fsdp._optim_utils import FSDPParamInfo
    from torch.distributed.checkpoint._state_dict_utils import _gather_state_dict

    @no_type_check
    def _shard_orig_param_state(
        fsdp_param_info: FSDPParamInfo,
        fqn: str,
        optim_state: Dict[str, Any],
    ) -> Dict[str, Any]:
        """
        Shard the optimizer state for the original parameter with the name ``fqn``.
        This API should only be used when ``use_orig_params`` is True.
        """
        if not optim_state:
            return {}
        fsdp_state = fsdp_param_info.state
        flat_param = fsdp_param_info.handle.flat_param
        param_idx = fsdp_param_info.param_indices[fqn]
        shard_param_info = flat_param._shard_param_infos[param_idx]  # type: ignore[attr-defined]
        optim_state = _gather_state_dict(
            optim_state, pg=fsdp_state.process_group, device=fsdp_state.compute_device
        )
        if not shard_param_info.in_shard:
            return {}
        # Flatten and shard the state.
        new_optim_state: Dict[str, Any] = {}
        intra_param_start_idx = shard_param_info.intra_param_start_idx
        intra_param_end_idx = shard_param_info.intra_param_end_idx
        for state_name, value in optim_state.items():
            if (
                torch.is_tensor(value)
                and value.dim() > 0
                and fsdp_state.sharding_strategy != ShardingStrategy.NO_SHARD
            ):
                # This clone() is the patch to fix the OOM
                # https://github.com/pytorch/pytorch/pull/117261
                value = value.flatten()[intra_param_start_idx : intra_param_end_idx + 1].clone()  # type: ignore[operator]
            new_optim_state[state_name] = value
        return new_optim_state
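Editor's note on why the added clone() fixes the OOM: _gather_state_dict materializes the full, unsharded optimizer state for each parameter, and slicing its flattened tensor returns a view that keeps the entire gathered storage alive for as long as the shard is held in new_optim_state. Cloning copies only the local shard, so the gathered buffer can be released. A small illustration follows; the sizes are arbitrary and the snippet is not taken from the PR.

import torch

# Stand-in for the all-gathered optimizer state of one original parameter.
gathered = torch.randn(1_000_000)                # ~4 MB of float32 storage

shard_view = gathered.flatten()[:1_000]          # view: shares the full 4 MB storage
shard_copy = gathered.flatten()[:1_000].clone()  # copy: owns only ~4 KB

print(shard_view.untyped_storage().nbytes())     # 4000000 -- whole gathered buffer retained
print(shard_copy.untyped_storage().nbytes())     # 4000    -- just the shard

# Without clone(), every entry kept in new_optim_state pins its full gathered
# tensor until the state dict is dropped, which is the OOM reported in
# https://github.com/pytorch/pytorch/issues/116553.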