Added reshard hook for frozen params in backward #1159

Open · wants to merge 5 commits into base: ngoyal_changes_for_pp_fp8
Changes from 1 commit
74 changes: 63 additions & 11 deletions fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -9,6 +9,7 @@
from dataclasses import dataclass
from enum import Enum, auto
import functools
import itertools
import logging
from math import inf
import os
@@ -47,7 +48,6 @@
from fairscale.utils.containers import apply_to_tensors
from fairscale.utils.parallel import (
ProcessGroupName,
chunk_and_pad,
enable_pytorch_sync_bn,
get_process_group_cached,
validate_process_group,
@@ -1457,6 +1457,7 @@ def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
# Register backward hooks to reshard params and reduce-scatter grads.
# These need to be re-registered every forward pass.
self._register_post_backward_hooks()
self._register_post_backward_reshard_hooks(args, kwargs)

outputs = self.module(*args, **kwargs)

@@ -1655,6 +1656,37 @@ def _register_post_backward_hooks(self) -> None:
p._shard_bwd_hooks.append((grad_acc, handle))
# p._shard_bwd_hook = (grad_acc, handle)

def _register_post_backward_reshard_hooks(
self, args: Tuple[Any, ...], kwargs: Dict[str, Any]
) -> None:
if not hasattr(torch.autograd.graph, "register_multi_grad_hook"):
return # unsupported
if not torch.is_grad_enabled():
return
from torch.utils._pytree import tree_flatten
from torch.autograd.graph import register_multi_grad_hook
# Construct `inp_tensors` lazily to avoid CPU overhead in typical case
# where each parameter requires gradient
inp_tensors: Optional[List[torch.Tensor]] = None
for param in self.params:
# Only register for parameters that do not require gradient
if param.requires_grad:
continue
if inp_tensors is None:
args_list, _ = tree_flatten(args)
kwargs_list, _ = tree_flatten(kwargs)
inp_tensors = [
obj
for obj in itertools.chain(args_list, kwargs_list)
if torch.is_tensor(obj) and obj.requires_grad
]
hook_handle = register_multi_grad_hook(
inp_tensors, functools.partial(self._post_backward_reshard_hook, param)
)
if not hasattr(param, "_shard_bwd_hooks"):
param._shard_bwd_hooks = []
param._shard_bwd_hooks.append((hook_handle,))

@torch.no_grad()
def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
"""
@@ -1697,12 +1729,8 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
if param.grad.requires_grad:
raise RuntimeError("FSDP only works with gradients that don't require gradients")

if self._require_backward_grad_sync or self.reshard_after_forward:
# Free full params. As a special case, we don't free the full params
# when in a ``no_sync`` context (as inversely indicated by
# ``self._require_backward_grad_sync``), since the params will not
# get updated before the next forward. This saves networking
# bandwidth but uses more GPU memory.
if self._should_free_in_backward():
# Free full params.
self._free_full_params([param])

if self.mixed_precision:
@@ -1829,6 +1857,22 @@ def _post_reduction_hook(self, param: Parameter, reduced_grad: torch.Tensor) ->
# Don't let this memory get reused until after the transfer.
reduced_grad.data.record_stream(torch.cuda.current_stream())

@torch.no_grad()
def _post_backward_reshard_hook(self, param: Parameter, *unused: Any) -> None:
if self._should_free_in_backward():
self._free_full_params([param])
if self.mixed_precision:
self._free_fp16_param_shard([param])
self._use_fp32_param_shard([param])

def _should_free_in_backward(self):
# As a special case, we don't free the full params
# when in a ``no_sync`` context (as inversely indicated by
# ``self._require_backward_grad_sync``), since the params will not
# get updated before the next forward. This saves networking
# bandwidth but uses more GPU memory.
return self._require_backward_grad_sync or self.reshard_after_forward

def _queue_wait_for_post_backward(self) -> None:
"""Try to queue a `wait_for_post_backward` callback.

@@ -1878,16 +1922,24 @@ def _wait_for_post_backward(self) -> None:
def _finalize_parameters(fsdp_module: FullyShardedDataParallel) -> None:
"""Helper used below on all fsdp modules."""
for p in fsdp_module.params:
if not p.requires_grad:
continue
if hasattr(p, "_shard_bwd_hook"):
p_assert(len(p._shard_bwd_hook) == 2, f"WFPB: incorrect hook num: {len(p._shard_bwd_hook)}")
# p._shard_bwd_hook[1].remove()
# delattr(p, "_shard_bwd_hook")
if hasattr(p, "_shard_bwd_hooks") and self._require_backward_grad_sync:
for _, handle in p._shard_bwd_hooks:
handle.remove()
for hook_state in p._shard_bwd_hooks:
if len(hook_state) == 1:
hook_state[0].remove()
elif len(hook_state) == 2:
hook_state[1].remove()
p._shard_bwd_hooks.clear()
if not p.requires_grad:
# For the 1st layer, if the forward inputs did not require
# gradient, then we cannot run a reshard hook for it, and
# we instead free here.
if p._full_param_padded.untyped_storage().size() > 0:
fsdp_module._post_backward_reshard_hook(p)
continue

# Leave the gradient accumulation state as-is if not synchronizing this pass. This ensures p.grad
# remains the unsharded gradient accumulated from prior no-sync passes, and p._saved_grad_shard
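Note on the mechanism (not part of the diff): the registration in `_register_post_backward_reshard_hooks` relies on `torch.autograd.graph.register_multi_grad_hook`, which invokes its callback once gradients for all of the listed tensors have been computed in a backward pass. A minimal, self-contained sketch of that primitive, assuming PyTorch 2.0+:

```python
# Minimal sketch of torch.autograd.graph.register_multi_grad_hook (PyTorch 2.0+):
# the callback runs once grads for *all* registered tensors are ready, which is
# how the PR detects that backward has passed a module whose params are frozen.
import torch
from torch.autograd.graph import register_multi_grad_hook

a = torch.randn(3, requires_grad=True)
b = torch.randn(3, requires_grad=True)

def on_all_grads_ready(grads):
    # `grads` is a sequence of Optional[Tensor], one entry per registered tensor.
    print("all grads ready:", [tuple(g.shape) for g in grads if g is not None])

handle = register_multi_grad_hook((a, b), on_all_grads_ready)
(a * b).sum().backward()  # callback fires here, after both grads are computed
handle.remove()           # same `.remove()` used by _finalize_parameters above
```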
96 changes: 96 additions & 0 deletions tests/nn/data_parallel/test_fsdp_freezing_weights.py
@@ -12,6 +12,8 @@

from enum import Enum
from itertools import product
from unittest import mock
import copy
import tempfile

import pytest
@@ -275,3 +277,97 @@ def test_freezing_weights(temp_files, nested_trunk):
nprocs=world_size,
)
temp_file_idx += 3


@skip_if_single_gpu
def test_reshard_frozen_weights():
world_size = 2
for flatten_parameters, reshard_after_forward, inp_requires_grad in product(
[False, True], [False, True], [False, True]
):
print(
"Testing FSDP reshard frozen weights with "
f"flatten_parameters={flatten_parameters}, "
f"reshard_after_forward={reshard_after_forward}, "
f"inp_requires_grad={inp_requires_grad}"
)
mp.spawn(
_distributed_worker_reshard,
(world_size, flatten_parameters, reshard_after_forward, inp_requires_grad),
nprocs=world_size,
)


def _distributed_worker_reshard(
rank: int,
world_size: int,
flatten_parameters: bool,
reshard_after_forward: bool,
inp_requires_grad: bool,
):
import os
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12355"
torch.cuda.set_device(rank)
torch.distributed.init_process_group(backend="nccl", rank=rank, world_size=world_size)

torch.manual_seed(0)

num_linears = 6
modules = []
for _ in range(num_linears):
modules += [nn.Linear(5, 5, device="cuda"), nn.ReLU()]
model = nn.Sequential(*modules)
# Freeze every other linear
for i in range(num_linears):
if i % 2 == 0:
for param in model[i * 2].parameters(recurse=False):
param.requires_grad = False
num_frozen_linears = num_linears // 2

ref_model = DistributedDataParallel(copy.deepcopy(model), device_ids=[rank])
ref_optim = torch.optim.AdamW(ref_model.parameters(), lr=1e-2)

for i, module in enumerate(model):
if isinstance(module, nn.Linear):
model[i] = FSDP(
module,
flatten_parameters=flatten_parameters,
reshard_after_forward=reshard_after_forward,
)
fsdp_model = FSDP(
model,
flatten_parameters=flatten_parameters,
reshard_after_forward=reshard_after_forward,
)
fsdp_optim = torch.optim.AdamW(fsdp_model.parameters(), lr=1e-2)

orig_post_backward_reshard_hook = FSDP._post_backward_reshard_hook
reshard_hook_count = 0

def post_backward_reshard_hook_with_count(*args, **kwargs):
nonlocal reshard_hook_count
reshard_hook_count += 1
return orig_post_backward_reshard_hook(*args, **kwargs)

with mock.patch(
"fairscale.nn.data_parallel.FullyShardedDataParallel._post_backward_reshard_hook",
post_backward_reshard_hook_with_count,
):
inp = torch.randn((8, 5), device="cuda", requires_grad=inp_requires_grad)
for i in range(6):
losses = []
for model, optim in ((fsdp_model, fsdp_optim), (ref_model, ref_optim)):
optim.zero_grad()
loss = model(inp).sum()
losses.append(loss)
loss.backward()
optim.step()
expected_reshard_hook_count = num_frozen_linears
if not flatten_parameters:
expected_reshard_hook_count *= 2 # weight and bias per linear
assert (
reshard_hook_count == expected_reshard_hook_count
), f"Expected {expected_reshard_hook_count} but got {reshard_hook_count}"
assert losses[0].eq(losses[1]).all().item(), f"Expected {losses[1]} but got {losses[0]}"
reshard_hook_count = 0
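
For completeness, a hypothetical end-to-end sketch (not part of the PR) of the behavior this test verifies: after backward, a frozen linear wrapped in FSDP should have its full parameters freed again even though no gradient hook fires for its parameters. It assumes a multi-GPU NCCL process group is already initialized on each rank (e.g. by a spawn helper like `_distributed_worker_reshard` above) and peeks at the same `params` / `_full_param_padded` internals used in the diff.

```python
# Hypothetical sketch, assuming torch.distributed (NCCL) is already initialized
# per rank and the CUDA device is set, as in _distributed_worker_reshard above.
import torch
import torch.nn as nn
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

frozen = nn.Linear(5, 5, device="cuda")
for p in frozen.parameters():
    p.requires_grad = False  # freeze the layer

fsdp_frozen = FSDP(frozen, flatten_parameters=True, reshard_after_forward=True)

inp = torch.randn(8, 5, device="cuda", requires_grad=True)
loss = fsdp_frozen(inp).sum()
loss.backward()  # reshard hook (or the fallback in _finalize_parameters) frees full params

for p in fsdp_frozen.params:
    # Freed full params have their padded storage resized to 0.
    assert p._full_param_padded.untyped_storage().size() == 0
```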