Commit
support for grad acc
ngoyal2707 committed Aug 30, 2024
1 parent d0b506f commit 0dfa5e5
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -1821,7 +1821,10 @@ def _post_reduction_hook(self, param: Parameter, reduced_grad: torch.Tensor) ->
                 param._saved_grad_shard.data += reduced_grad.data
             reduced_grad = param._saved_grad_shard.data
         elif (param.grad is None) and self.fp32_reduce_scatter:
-            param.main_grad = reduced_grad.data
+            if getattr(param, "main_grad", None) is not None:
+                param.main_grad.add_(reduced_grad.data)
+            else:
+                param.main_grad = reduced_grad.data
 
         # Optionally move gradients to CPU, typically used if one is running the optimizer on the CPU. Once the full
         # backwards pass completes, we will set `.grad` to the CPU copy.
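
For context: with fp32_reduce_scatter, the reduce-scatter hook writes the fp32 gradient into param.main_grad. Under gradient accumulation the hook fires once per micro-batch, so the old plain assignment overwrote the gradient from the previous micro-batch; the patch accumulates into main_grad when it already exists. Below is a minimal standalone sketch (not fairscale code; Param and post_reduction_hook are hypothetical stand-ins) illustrating the fixed behavior:

import torch

class Param:
    """Stand-in for an FSDP-managed parameter with an optional fp32 main_grad buffer."""
    def __init__(self) -> None:
        self.grad = None       # the patched branch only runs when .grad is None
        self.main_grad = None  # fp32 gradient buffer filled by reduce-scatter

def post_reduction_hook(param: Param, reduced_grad: torch.Tensor) -> None:
    # Mirrors the patched branch: accumulate across micro-batches instead of overwriting.
    if getattr(param, "main_grad", None) is not None:
        param.main_grad.add_(reduced_grad)
    else:
        param.main_grad = reduced_grad

param = Param()
post_reduction_hook(param, torch.ones(4))  # micro-batch 1
post_reduction_hook(param, torch.ones(4))  # micro-batch 2 (gradient accumulation)
print(param.main_grad)                     # tensor([2., 2., 2., 2.]) -- summed, not overwritten

Before this commit, the second call would have replaced main_grad, silently dropping the first micro-batch's contribution.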
