Add main_grad
jianyuh committed Oct 2, 2023
1 parent f3ae46e commit ad54660
Showing 1 changed file with 18 additions and 3 deletions.
fairscale/nn/data_parallel/fully_sharded_data_parallel.py — 21 changes: 18 additions & 3 deletions
@@ -1713,6 +1713,13 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
 
         # Switch to FP32 shard after backward.
         self._use_fp32_param_shard([param])
+        if self.fp32_reduce_scatter:
+            if getattr(param, "main_grad", None) is None:
+                param.main_grad = param.grad.to(torch.float32)
+            else:
+                param.main_grad.add_(param.grad.data)
+
+            param.grad = None
 
         if not self._require_backward_grad_sync:
             return
@@ -1721,23 +1728,31 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
         # reductions in post_backward stream.
         self._streams["post_backward"].wait_stream(torch.cuda.current_stream())
         with torch.cuda.stream(self._streams["post_backward"]):
-            orig_grad_data = param.grad.data
 
             if self.fp32_reduce_scatter:
                 # Cast grad to FP32.
                 param.grad.data = param.grad.data.float()
 
+            orig_grad_data = param.grad.data
+
             if self.gradient_predivide_factor > 1:
                 # Average grad by world_size for consistency with PyTorch DDP.
-                param.grad.data.div_(self.gradient_predivide_factor)
+                if getattr(param, "main_grad", None) is not None:
+                    param.main_grad.data.div_(self.gradient_predivide_factor)
+                else:
+                    param.grad.data.div_(self.gradient_predivide_factor)
 
             if param._is_sharded:
                 assert self._reducer is not None
                 # Save the unsharded grad for reduction. We will asynchronously accumulate the reduced gradient into
                 # param._saved_grad_shard. If this FSDP module was called multiple times it's possible that multiple
                 # gradient reductions will happen in an undefined order. But addition commutes, so this order doesn't
                 # matter, neglecting rounding.
-                grad = param.grad.data
+                if getattr(param, "main_grad", None) is not None:
+                    grad = param.main_grad.data
+                    param.main_grad = None
+                else:
+                    grad = param.grad.data
                 # Clear grad on the tensor, so any repeated gradient computations do not interfere with this reduction.
                 #
                 # The effect on memory consumption is not usually significant. No extra memory is allocated if this
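For readers unfamiliar with the pattern, the change above accumulates low-precision gradients into an FP32 main_grad buffer and then drops param.grad, so the later pre-division and reduce-scatter operate on the FP32 copy. Below is a minimal standalone sketch of that accumulation step; it is not fairscale code, and the helper name accumulate_into_main_grad is invented for illustration.

import torch

def accumulate_into_main_grad(p: torch.nn.Parameter) -> None:
    # Illustrative mirror of the diff's logic (not fairscale code): fold the
    # low-precision p.grad into an FP32 p.main_grad, then free the bf16 grad.
    if getattr(p, "main_grad", None) is None:
        p.main_grad = p.grad.to(torch.float32)   # first micro-batch: cast to FP32
    else:
        p.main_grad.add_(p.grad.data)            # later micro-batches: accumulate in FP32
    p.grad = None                                # drop the low-precision grad immediately

# One bf16 parameter and two "micro-batches" of gradient accumulation.
param = torch.nn.Parameter(torch.ones(4, dtype=torch.bfloat16))
for _ in range(2):
    x = torch.full((4,), 2.0, dtype=torch.bfloat16)
    (param * x).sum().backward()                 # bf16 backward produces a bf16 param.grad
    accumulate_into_main_grad(param)

print(param.main_grad)                           # tensor([4., 4., 4., 4.]) in float32

In the commit itself this accumulation happens inside _post_backward_hook, and the FP32 buffer is what the pre-division and reduction paths read whenever main_grad is present.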
