diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
index 759b9f445..0eaa454f2 100644
--- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
+++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -1713,6 +1713,13 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
 
         # Switch to FP32 shard after backward.
         self._use_fp32_param_shard([param])
+        if self.fp32_reduce_scatter:
+            if getattr(param, "main_grad", None) is None:
+                param.main_grad = param.grad.to(torch.float32)
+            else:
+                param.main_grad.add_(param.grad.data)
+
+            param.grad = None
 
         if not self._require_backward_grad_sync:
             return
@@ -1721,15 +1728,19 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
         # reductions in post_backward stream.
         self._streams["post_backward"].wait_stream(torch.cuda.current_stream())
         with torch.cuda.stream(self._streams["post_backward"]):
-            orig_grad_data = param.grad.data
 
             if self.fp32_reduce_scatter:
                 # Cast grad to FP32.
                 param.grad.data = param.grad.data.float()
 
+            orig_grad_data = param.grad.data
+
             if self.gradient_predivide_factor > 1:
                 # Average grad by world_size for consistency with PyTorch DDP.
-                param.grad.data.div_(self.gradient_predivide_factor)
+                if getattr(param, "main_grad", None) is not None:
+                    param.main_grad.data.div_(self.gradient_predivide_factor)
+                else:
+                    param.grad.data.div_(self.gradient_predivide_factor)
 
             if param._is_sharded:
                 assert self._reducer is not None
@@ -1737,7 +1748,11 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
                 # param._saved_grad_shard. If this FSDP module was called multiple times it's possible that multiple
                 # gradient reductions will happen in an undefined order. But addition commutes, so this order doesn't
                 # matter, neglecting rounding.
-                grad = param.grad.data
+                if getattr(param, "main_grad", None) is not None:
+                    grad = param.main_grad.data
+                    param.main_grad = None
+                else:
+                    grad = param.grad.data
                 # Clear grad on the tensor, so any repeated gradient computations do not interfere with this reduction.
                 #
                 # The effect on memory consumption is not usually significant. No extra memory is allocated if this
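The core pattern in the first hunk is: accumulate each freshly computed (possibly FP16/BF16) gradient into an FP32 `main_grad` buffer attached to the parameter, then drop `param.grad` so only the FP32 copy survives until reduce-scatter. The sketch below shows that accumulation step in isolation; the `accumulate_main_grad` helper and the toy loop are illustrative only and are not part of fairscale.

```python
import torch


def accumulate_main_grad(param: torch.nn.Parameter) -> None:
    """Fold a freshly computed (possibly FP16/BF16) param.grad into an FP32 accumulator.

    Mirrors the pattern in the hunk above: the first backward allocates
    param.main_grad in FP32, later backwards add into it, and the
    low-precision param.grad is released afterwards.
    """
    if param.grad is None:
        return
    if getattr(param, "main_grad", None) is None:
        # First gradient for this parameter: allocate the FP32 buffer.
        param.main_grad = param.grad.to(torch.float32)
    else:
        # Subsequent gradients (e.g. gradient accumulation): add in FP32.
        param.main_grad.add_(param.grad.data)
    # Drop the low-precision gradient so only the FP32 copy survives.
    param.grad = None


if __name__ == "__main__":
    p = torch.nn.Parameter(torch.zeros(4, dtype=torch.float16))
    for _ in range(3):
        p.grad = torch.full_like(p, 0.5)  # stand-in for a backward pass
        accumulate_main_grad(p)
    print(p.main_grad)  # three gradients of 0.5 accumulated -> 1.5 per element, in float32
```

Accumulating in FP32 is what keeps multi-micro-batch gradient accumulation numerically stable when gradients themselves are produced in half precision; the later hunks then pre-divide and reduce-scatter this FP32 buffer instead of `param.grad`.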