
Merge branch 'dev' into dev
mvpatel2000 authored Apr 10, 2024
2 parents c07ad62 + 52776a7 commit 3606052
Showing 2 changed files with 49 additions and 9 deletions.
8 changes: 7 additions & 1 deletion composer/trainer/mosaic_fsdp.py
@@ -63,7 +63,13 @@ def patch_pytorch():
 
     elif version.parse(torch.__version__) < version.parse('2.2.9'):
         # Monkey patch for torch < 2.3.0 ie torch == 2.2.1/2.2.2 currently
-        pass
+
+        # Fix memory leak for FSDP.optim_state_dict_to_load
+        # https://github.com/pytorch/pytorch/issues/116553
+        from torch.distributed.fsdp import _optim_utils
+
+        from composer.trainer.mosaic_fsdp_utils import _shard_orig_param_state
+        _optim_utils._shard_orig_param_state = _shard_orig_param_state
 
     elif version.parse(torch.__version__) < version.parse('2.3.1'):
         # Monkey patch for torch < 2.3.1 ie torch == 2.3.0
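Note on the hunk above: the added lines monkey-patch the helper at module level. Because Python resolves the name through _optim_utils's module globals at call time, rebinding the attribute is enough for FSDP.optim_state_dict_to_load to pick up the fixed implementation. A minimal sketch of how the patch is expected to take effect (assuming torch 2.2.1 or 2.2.2 is installed, so this elif branch runs):

# sketch: verify the swap after calling composer's patch entry point
from composer.trainer.mosaic_fsdp import patch_pytorch
from torch.distributed.fsdp import _optim_utils

patch_pytorch()  # on torch 2.2.1/2.2.2 this rebinds _optim_utils._shard_orig_param_state

# the module attribute now points at composer's copy of the helper (the one with the extra .clone())
assert _optim_utils._shard_orig_param_state.__module__ == 'composer.trainer.mosaic_fsdp_utils'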
50 changes: 42 additions & 8 deletions composer/trainer/mosaic_fsdp_utils.py
@@ -177,16 +177,11 @@ def _get_process_group(pg, process_group_cache=None):
         return process_group_cache[ranks]
 
     log.info(f'Instantiating custom process groups with {ranks=} on rank={dist.get_global_rank()}.')
-
-    ranks_per_subgroup_list = list(set(dist.all_gather_object(ranks)))
-    (
-        current_group,
-        _subgroups,
-    ) = distributed.distributed_c10d.new_subgroups_by_enumeration(ranks_per_subgroup_list)
+    process_group = distributed.new_group(ranks)
 
     if process_group_cache is not None:
-        process_group_cache[ranks] = current_group
-    return current_group
+        process_group_cache[ranks] = process_group
+    return process_group
 
 
 def _set_custom_fsdp_module_kwargs(module_kwargs: Dict, process_group_cache: Dict[Tuple[int], Any]) -> Dict:
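The hunk above drops the all_gather_object + new_subgroups_by_enumeration enumeration in favor of a single distributed.new_group(ranks) call, cached under the ranks tuple. torch.distributed.new_group is itself a collective over the default group, so every rank must reach the call with the same ranks argument (even ranks that will not be members of the new group), and the cache keeps repeat lookups from creating duplicate groups. A minimal sketch of that cache-then-create pattern (get_or_create_group is a hypothetical standalone helper, not part of this diff):

# sketch of the caching pattern used by _get_process_group
from typing import Any, Dict, Tuple

import torch.distributed as distributed

def get_or_create_group(ranks: Tuple[int, ...], cache: Dict[Tuple[int, ...], Any]):
    """Return a cached process group for `ranks`, creating it once per ranks tuple."""
    if ranks in cache:
        return cache[ranks]
    # new_group is a collective: all ranks in the default group must call it,
    # with the same ranks, even ranks that are not members of the new group.
    group = distributed.new_group(list(ranks))
    cache[ranks] = group
    return group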
@@ -1124,3 +1119,42 @@ def _same_storage(a, b):
     if isinstance(b, DTensor):
         b = b._local_tensor
     return a.untyped_storage().data_ptr() == b.untyped_storage().data_ptr()
+
+if version.parse(torch.__version__) >= version.parse('2.2.1') and version.parse(
+        torch.__version__,) < version.parse('2.2.9'):
+
+    from torch.distributed.fsdp._optim_utils import FSDPParamInfo
+    from torch.distributed.checkpoint._state_dict_utils import _gather_state_dict
+
+    @no_type_check
+    def _shard_orig_param_state(
+        fsdp_param_info: FSDPParamInfo,
+        fqn: str,
+        optim_state: Dict[str, Any],
+    ) -> Dict[str, Any]:
+        if not optim_state:
+            return {}
+        fsdp_state = fsdp_param_info.state
+        flat_param = fsdp_param_info.handle.flat_param
+        param_idx = fsdp_param_info.param_indices[fqn]
+        shard_param_info = flat_param._shard_param_infos[param_idx]  # type: ignore[attr-defined]
+        optim_state = _gather_state_dict(
+            optim_state, pg=fsdp_state.process_group, device=fsdp_state.compute_device,
+        )
+        if not shard_param_info.in_shard:
+            return {}
+        # Flatten and shard the state.
+        new_optim_state: Dict[str, Any] = {}
+        intra_param_start_idx = shard_param_info.intra_param_start_idx
+        intra_param_end_idx = shard_param_info.intra_param_end_idx
+        for state_name, value in optim_state.items():
+            if (
+                torch.is_tensor(value)
+                and value.dim() > 0
+                and fsdp_state.sharding_strategy != ShardingStrategy.NO_SHARD
+            ):
+                # This clone() is the patch to fix the OOM
+                # https://github.com/pytorch/pytorch/pull/117261
+                value = value.flatten()[intra_param_start_idx:intra_param_end_idx + 1].clone()  # type: ignore[operator]
+            new_optim_state[state_name] = value
+        return new_optim_state
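The .clone() is what releases the gathered memory: slicing the flattened, fully gathered state tensor returns a view that keeps the entire unsharded storage alive, while .clone() copies only this rank's shard so the gathered buffer can be freed. A small standalone illustration of the difference (the tensor here is a made-up stand-in for the gathered optimizer state, not composer or FSDP code):

# sketch: a sliced view pins the whole gathered buffer; clone() keeps only the shard
import torch

gathered = torch.randn(1_000_000)          # stands in for the unsharded (gathered) optimizer state
view = gathered.flatten()[:128]            # view: still shares the full 4 MB storage
shard = gathered.flatten()[:128].clone()   # copy: owns just 512 bytes

print(view.untyped_storage().size())       # 4000000 -- the whole gathered buffer stays reachable
print(shard.untyped_storage().size())      # 512 -- the gathered buffer can be garbage collected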
