diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
index 1f1d29ea8..c82fd9443 100644
--- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
+++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -1765,9 +1765,9 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
 
         if self.fp32_reduce_scatter:
             if self.optimize_backward_concat:
-                param.unsharded_main_grad = self._fsdp_wrapped_module.fp32_grads
+                param.unsharded_main_grad = self._fsdp_wrapped_module.fp32_flat_grad
                 # Clean up accumulated grads between data batches
-                self._fsdp_wrapped_module.fp32_grads = None
+                self._fsdp_wrapped_module.fp32_flat_grad = None
             else:
                 if getattr(param, "unsharded_main_grad", None) is None:
                     param.unsharded_main_grad = param.grad.to(torch.float32)
diff --git a/fairscale/nn/misc/flatten_params_wrapper.py b/fairscale/nn/misc/flatten_params_wrapper.py
index d35541e9a..727d08a6d 100644
--- a/fairscale/nn/misc/flatten_params_wrapper.py
+++ b/fairscale/nn/misc/flatten_params_wrapper.py
@@ -174,7 +174,7 @@ def __init__(
         self._require_backward_grad_sync = True
         # If optimize_backward_concat == True, used to accumulate the
         # fp32 gradients for the flattened parameters
-        self.fp32_grads = None
+        self.fp32_flat_grad = None
 
         # Handle param_list being None.
         if param_list is None:
@@ -386,12 +386,12 @@ def _grad_accumulation_hook(
         self,
         grad,
         start,
         end,
     ):
         """
-        start: int, the starting index(inclusive) of the grad of this parameter in self.fp32_grads
-        end: int, the ending index(exclusive) of the grad of this parameter in self.fp32_grads
+        start: int, the starting index(inclusive) of the grad of this parameter in self.fp32_flat_grad
+        end: int, the ending index(exclusive) of the grad of this parameter in self.fp32_flat_grad
         """
-        assert self.fp32_grads is not None
-        self.fp32_grads[start:end].add_(grad.flatten())
+        assert self.fp32_flat_grad is not None
+        self.fp32_flat_grad[start:end].add_(grad.flatten())
         return grad
 
     def _unflatten_params_as_views(self) -> None:
@@ -421,7 +421,7 @@ def _unflatten_params_as_views(self) -> None:
             setattr(m, n, p)  # This will set as plain attr
             if self.optimize_backward_concat:
                 # Register post backward hook to accumulate the gradients
-                # in self.fp32_grads
+                # in self.fp32_flat_grad
                 param_end = param_start + torch.numel(p)
                 p.register_hook(
                     functools.partial(
@@ -433,10 +433,10 @@ def _unflatten_params_as_views(self) -> None:
                 param_start = param_end
             param_views.append(p)
 
-        if self.optimize_backward_concat and self.fp32_grads is None:
+        if self.optimize_backward_concat and self.fp32_flat_grad is None:
             # Allocate GPU memory for flattened fp32 grad accumulation
             total_numels = sum([torch.numel(p) for p in param_views])
-            self.fp32_grads = torch.zeros(total_numels, dtype=torch.float32, device=torch.cuda.current_device())
+            self.fp32_flat_grad = torch.zeros(total_numels, dtype=torch.float32, device=torch.cuda.current_device())
 
         # Save param views for easy access if anyone still wants to access
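
For context only (not part of the patch): the renamed fp32_flat_grad buffer implements the pattern of registering a tensor hook per parameter view and accumulating each grad into its slice of one flat fp32 tensor. The sketch below is a minimal standalone illustration of that pattern using plain PyTorch; the names make_hook, flat_fp32_buf, and the toy model are hypothetical and are not fairscale APIs.

```python
import torch
import torch.nn as nn

# Illustrative only: accumulate per-parameter grads into one flat fp32 buffer
# via tensor hooks, mirroring the idea behind _grad_accumulation_hook above.
def make_hook(flat_fp32_buf, start, end):
    def hook(grad):
        # Add this parameter's (possibly lower-precision) grad into its
        # [start:end) slice of the flat fp32 accumulation buffer.
        flat_fp32_buf[start:end].add_(grad.flatten().to(torch.float32))
        return grad
    return hook

model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))
params = [p for p in model.parameters() if p.requires_grad]

# One flat fp32 buffer sized to hold every parameter's gradient.
total_numel = sum(p.numel() for p in params)
flat_fp32_buf = torch.zeros(total_numel, dtype=torch.float32)

offset = 0
for p in params:
    end = offset + p.numel()
    p.register_hook(make_hook(flat_fp32_buf, offset, end))
    offset = end

# Two "micro-batches": grads accumulate into flat_fp32_buf across backwards.
for _ in range(2):
    out = model(torch.randn(3, 4))
    out.sum().backward()

# flat_fp32_buf now holds the fp32 sum of grads; reset it between data batches,
# analogous to clearing fp32_flat_grad in _post_backward_hook above.
flat_fp32_buf.zero_()
```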