[not to be merged yet] added temp changes for fp32 main grad, might not work for TE (#1151)

* added temp changes for fp32 main grad, might not work for TE

* post rebase

* changes to keep reduced grad in fp32 (#1152)

* fix .grad=None issue when param is not sharded (#1153)

* fixed broken clipping (#1154)

Co-authored-by: Naman Goyal <[email protected]>

---------

Co-authored-by: Naman Goyal <[email protected]>
Co-authored-by: Vedanuj Goswami <[email protected]>
Co-authored-by: Jiecao Yu <[email protected]>
Authored by 4 people on Dec 7, 2023
1 parent a8189f0 commit 3b7cc24
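
The idea behind the changes in this commit, sketched below as a minimal standalone example: when fp32 reduce-scatter is enabled, each parameter's low-precision gradient is accumulated into a float32 buffer and .grad is cleared, so that reduction, clipping, and the optimizer all work on the full-precision copy. The helper name accumulate_main_grad and the bare bf16 parameter are assumptions for illustration only; this is not the fairscale FSDP code itself.

import torch

def accumulate_main_grad(param: torch.nn.Parameter) -> None:
    """Fold the freshly computed low-precision grad into an fp32 main_grad buffer."""
    if param.grad is None:
        return
    if getattr(param, "main_grad", None) is None:
        # First backward for this param: allocate the fp32 buffer.
        param.main_grad = param.grad.to(torch.float32)
    else:
        # Later backwards (e.g. gradient accumulation): add in fp32.
        param.main_grad.add_(param.grad.to(torch.float32))
    # Drop the low-precision grad so downstream code only sees the fp32 copy.
    param.grad = None

if __name__ == "__main__":
    p = torch.nn.Parameter(torch.zeros(4, dtype=torch.bfloat16))
    for _ in range(2):  # simulate two accumulation micro-batches
        p.grad = torch.full_like(p, 0.1)  # stand-in for a real backward pass
        accumulate_main_grad(p)
    print(p.main_grad.dtype, p.main_grad)  # torch.float32, roughly 0.2 everywhere
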
1 changed file: fairscale/nn/data_parallel/fully_sharded_data_parallel.py (31 additions, 9 deletions)
@@ -687,7 +687,7 @@ def _cast_buffers(
@property
def params_with_grad(self) -> List[Parameter]:
"""[p for p in self.parameters() if p.grad is not None]"""
return [p for p in self.parameters() if p.grad is not None]
return [p for p in self.parameters() if (p.grad is not None or p.main_grad is not None)]

@torch.no_grad()
def clip_grad_norm_(
@@ -1714,30 +1714,48 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
# Switch to FP32 shard after backward.
self._use_fp32_param_shard([param])

if self.fp32_reduce_scatter:
if getattr(param, "unsharded_main_grad", None) is None:
param.unsharded_main_grad = param.grad.to(torch.float32)
else:
param.unsharded_main_grad.add_(param.grad.data)

param.grad = None

if not self._require_backward_grad_sync:
return

# Wait for all work in the current stream to finish, then start the
# reductions in post_backward stream.
self._streams["post_backward"].wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self._streams["post_backward"]):
orig_grad_data = param.grad.data

if self.fp32_reduce_scatter:
# Cast grad to FP32.
param.grad.data = param.grad.data.float()
orig_grad_data = param.unsharded_main_grad.data
else:
orig_grad_data = param.grad.data

if self.gradient_predivide_factor > 1:
# Average grad by world_size for consistency with PyTorch DDP.
param.grad.data.div_(self.gradient_predivide_factor)
if getattr(param, "unsharded_main_grad", None) is not None:
param.unsharded_main_grad.data.div_(self.gradient_predivide_factor)
else:
param.grad.data.div_(self.gradient_predivide_factor)

if param._is_sharded:
assert self._reducer is not None
# Save the unsharded grad for reduction. We will asynchronously accumulate the reduced gradient into
# param._saved_grad_shard. If this FSDP module was called multiple times it's possible that multiple
# gradient reductions will happen in an undefined order. But addition commutes, so this order doesn't
# matter, neglecting rounding.
grad = param.grad.data
if getattr(param, "unsharded_main_grad", None) is not None:
grad = param.unsharded_main_grad.data
param.unsharded_main_grad = None
else:
grad = param.grad.data
param.grad = None

# Clear grad on the tensor, so any repeated gradient computations do not interfere with this reduction.
#
# The effect on memory consumption is not usually significant. No extra memory is allocated if this
@@ -1749,7 +1767,6 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
# This ensures the `default` stream will wait for the `post_backward` stream to complete the last
# reduction for this module, before scheduling additional reduction work. Then at most there are two
# unsharded gradients allocated; one for a pending reduction, and one for gradient computation.
param.grad = None
callback_fn = functools.partial(self._post_reduction_hook, param)
self._reducer.reduce_scatter_async(
grad, group=self.process_group_reduce_scatter, callback_fn=callback_fn
@@ -1759,7 +1776,10 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
# world_size == 1. This could be relaxed in the future, in which
# case grads should be all-reduced here.
assert self.world_size == 1
self._post_reduction_hook(param, param.grad)
if getattr(param, "unsharded_main_grad", None) is not None:
self._post_reduction_hook(param, param.unsharded_main_grad)
else:
self._post_reduction_hook(param, param.grad)

# After _post_backward_hook returns, orig_grad_data will eventually
# go out of scope, at which point it could otherwise be freed for
@@ -1785,7 +1805,7 @@ def _post_reduction_hook(self, param: Parameter, reduced_grad: torch.Tensor) ->
# non-blocking. The downside is a bit more D2H transfer in that case.
if self.fp32_reduce_scatter:
orig_param_grad_data = reduced_grad.data
reduced_grad.data = reduced_grad.data.to(dtype=param.data.dtype)
# reduced_grad.data = reduced_grad.data.to(dtype=param.data.dtype)
# Don't let this memory get reused until after the transfer.
orig_param_grad_data.record_stream(torch.cuda.current_stream())

@@ -1799,6 +1819,8 @@ def _post_reduction_hook(self, param: Parameter, reduced_grad: torch.Tensor) ->
), f"{param._saved_grad_shard.shape} vs {reduced_grad.shape}"
param._saved_grad_shard.data += reduced_grad.data
reduced_grad = param._saved_grad_shard.data
elif (param.grad is None) and self.fp32_reduce_scatter:
param.main_grad = reduced_grad.data

# Optionally move gradients to CPU, typically used if one is running the optimizer on the CPU. Once the full
# backwards pass completes, we will set `.grad` to the CPU copy.
@@ -1887,7 +1909,7 @@ def _finalize_parameters(fsdp_module: FullyShardedDataParallel) -> None:
if p.shape != p._saved_grad_shard.shape:
self._use_fp32_param_shard([p])
if p._saved_grad_shard.dtype != p.dtype:
p.grad = p._saved_grad_shard.to(p.dtype)
p.main_grad = p._saved_grad_shard
else:
p.grad = p._saved_grad_shard
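
On the consumer side, a hypothetical sketch of how clipping-style code can honor the new attribute, preferring the fp32 main_grad when it is set and falling back to .grad otherwise, in the spirit of the params_with_grad and clipping fixes above. The function name and the exact clipping formula here are assumptions for illustration, not fairscale's clip_grad_norm_ implementation.

import torch
from typing import Iterable

def clip_grads_preferring_main_grad(params: Iterable[torch.nn.Parameter], max_norm: float) -> torch.Tensor:
    """Clip by total 2-norm, reading main_grad (fp32) when present, else .grad."""
    grads = []
    for p in params:
        g = getattr(p, "main_grad", None)
        if g is None:
            g = p.grad
        if g is not None:
            grads.append(g)
    if not grads:
        return torch.tensor(0.0)
    # Compute the total norm in fp32 regardless of the grads' storage dtype.
    total_norm = torch.norm(torch.stack([torch.norm(g.float()) for g in grads]))
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1.0:
        for g in grads:
            g.mul_(clip_coef)
    return total_norm

A caller would invoke something like clip_grads_preferring_main_grad(model.parameters(), max_norm=1.0) after the backward pass and before the optimizer step.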
