diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
index f5c024d5b..09648894c 100644
--- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
+++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -687,7 +687,7 @@ def _cast_buffers(
     @property
     def params_with_grad(self) -> List[Parameter]:
         """[p for p in self.parameters() if p.grad is not None]"""
-        return [p for p in self.parameters() if p.grad is not None]
+        return [p for p in self.parameters() if (p.grad is not None or p.main_grad is not None)]
 
     @torch.no_grad()
     def clip_grad_norm_(
@@ -1714,6 +1714,14 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
         # Switch to FP32 shard after backward.
         self._use_fp32_param_shard([param])
 
+        if self.fp32_reduce_scatter:
+            if getattr(param, "unsharded_main_grad", None) is None:
+                param.unsharded_main_grad = param.grad.to(torch.float32)
+            else:
+                param.unsharded_main_grad.add_(param.grad.data)
+
+            param.grad = None
+
         if not self._require_backward_grad_sync:
             return
 
@@ -1721,15 +1729,19 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
        # reductions in post_backward stream.
         self._streams["post_backward"].wait_stream(torch.cuda.current_stream())
         with torch.cuda.stream(self._streams["post_backward"]):
-            orig_grad_data = param.grad.data
 
             if self.fp32_reduce_scatter:
                 # Cast grad to FP32.
-                param.grad.data = param.grad.data.float()
+                orig_grad_data = param.unsharded_main_grad.data
+            else:
+                orig_grad_data = param.grad.data
 
             if self.gradient_predivide_factor > 1:
                 # Average grad by world_size for consistency with PyTorch DDP.
-                param.grad.data.div_(self.gradient_predivide_factor)
+                if getattr(param, "unsharded_main_grad", None) is not None:
+                    param.unsharded_main_grad.data.div_(self.gradient_predivide_factor)
+                else:
+                    param.grad.data.div_(self.gradient_predivide_factor)
 
             if param._is_sharded:
                 assert self._reducer is not None
@@ -1737,7 +1749,13 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
                 # param._saved_grad_shard. If this FSDP module was called multiple times it's possible that multiple
                 # gradient reductions will happen in an undefined order. But addition commutes, so this order doesn't
                 # matter, neglecting rounding.
-                grad = param.grad.data
+                if getattr(param, "unsharded_main_grad", None) is not None:
+                    grad = param.unsharded_main_grad.data
+                    param.unsharded_main_grad = None
+                else:
+                    grad = param.grad.data
+                    param.grad = None
+
                 # Clear grad on the tensor, so any repeated gradient computations do not interfere with this reduction.
                 #
                 # The effect on memory consumption is not usually significant. No extra memory is allocated if this
@@ -1749,7 +1767,6 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
                 # This ensures the `default` stream will wait for the `post_backward` stream to complete the last
                 # reduction for this module, before scheduling additional reduction work. Then at most there are two
                 # unsharded gradients allocated; one for a pending reduction, and one for gradient computation.
-                param.grad = None
                 callback_fn = functools.partial(self._post_reduction_hook, param)
                 self._reducer.reduce_scatter_async(
                     grad, group=self.process_group_reduce_scatter, callback_fn=callback_fn
@@ -1759,7 +1776,10 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
                 # world_size == 1. This could be relaxed in the future, in which
                 # case grads should be all-reduced here.
                 assert self.world_size == 1
-                self._post_reduction_hook(param, param.grad)
+                if getattr(param, "unsharded_main_grad", None) is not None:
+                    self._post_reduction_hook(param, param.unsharded_main_grad)
+                else:
+                    self._post_reduction_hook(param, param.grad)
 
             # After _post_backward_hook returns, orig_grad_data will eventually
             # go out of scope, at which point it could otherwise be freed for
@@ -1785,7 +1805,7 @@ def _post_reduction_hook(self, param: Parameter, reduced_grad: torch.Tensor) ->
         # non-blocking. The downside is a bit more D2H transfer in that case.
         if self.fp32_reduce_scatter:
             orig_param_grad_data = reduced_grad.data
-            reduced_grad.data = reduced_grad.data.to(dtype=param.data.dtype)
+            # reduced_grad.data = reduced_grad.data.to(dtype=param.data.dtype)
             # Don't let this memory get reused until after the transfer.
             orig_param_grad_data.record_stream(torch.cuda.current_stream())
 
@@ -1799,6 +1819,8 @@ def _post_reduction_hook(self, param: Parameter, reduced_grad: torch.Tensor) ->
                 ), f"{param._saved_grad_shard.shape} vs {reduced_grad.shape}"
                 param._saved_grad_shard.data += reduced_grad.data
             reduced_grad = param._saved_grad_shard.data
+        elif (param.grad is None) and self.fp32_reduce_scatter:
+            param.main_grad = reduced_grad.data
 
         # Optionally move gradients to CPU, typically used if one is running the optimizer on the CPU. Once the full
         # backwards pass completes, we will set `.grad` to the CPU copy.
@@ -1887,7 +1909,7 @@ def _finalize_parameters(fsdp_module: FullyShardedDataParallel) -> None:
                     if p.shape != p._saved_grad_shard.shape:
                         self._use_fp32_param_shard([p])
                     if p._saved_grad_shard.dtype != p.dtype:
-                        p.grad = p._saved_grad_shard.to(p.dtype)
+                        p.main_grad = p._saved_grad_shard
                     else:
                         p.grad = p._saved_grad_shard
 
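
Not part of the patch: with fp32_reduce_scatter enabled, the reduced FP32 gradient now lands in p.main_grad instead of p.grad, so the optimizer must look there. Below is a minimal sketch of such a consumer; the helper name sgd_step_with_main_grad is hypothetical and only illustrates the main_grad-first lookup.

import torch

def sgd_step_with_main_grad(params, lr=1e-3):
    """Toy SGD update that prefers the FP32 main_grad over p.grad when present."""
    with torch.no_grad():
        for p in params:
            grad = getattr(p, "main_grad", None)
            if grad is None:
                grad = p.grad
            if grad is None:
                continue  # parameter received no gradient this step
            # Apply the update using the (possibly FP32) gradient, cast to the param dtype.
            p.data.add_(grad.to(p.dtype), alpha=-lr)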