rename fp32_grads to fp32_flat_grad
chrisxcai committed Jun 10, 2024
1 parent 54fe983 commit 759959d
Showing 2 changed files with 10 additions and 10 deletions.
4 changes: 2 additions & 2 deletions fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -1765,9 +1765,9 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
 
         if self.fp32_reduce_scatter:
             if self.optimize_backward_concat:
-                param.unsharded_main_grad = self._fsdp_wrapped_module.fp32_grads
+                param.unsharded_main_grad = self._fsdp_wrapped_module.fp32_flat_grad
                 # Clean up accumulated grads between data batches
-                self._fsdp_wrapped_module.fp32_grads = None
+                self._fsdp_wrapped_module.fp32_flat_grad = None
             else:
                 if getattr(param, "unsharded_main_grad", None) is None:
                     param.unsharded_main_grad = param.grad.to(torch.float32)
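For readers skimming the rename: in this hunk the FSDP post-backward hook hands the wrapped module's accumulated flat fp32 gradient to the reduction path and then drops the reference so the next batch starts fresh. A minimal sketch of that consume-and-reset pattern (the function and its wiring below are hypothetical; only the fp32_flat_grad attribute name comes from the diff):

```python
def consume_flat_fp32_grad(module, param) -> None:
    # Sketch only: hand the accumulated flat fp32 gradient to whatever
    # reduction / optimizer step runs next (hypothetical wiring) ...
    param.unsharded_main_grad = module.fp32_flat_grad
    # ... then clear the reference so gradients for the next batch go
    # into a freshly allocated buffer.
    module.fp32_flat_grad = None
```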
16 changes: 8 additions & 8 deletions fairscale/nn/misc/flatten_params_wrapper.py
@@ -174,7 +174,7 @@ def __init__(
         self._require_backward_grad_sync = True
         # If optimize_backward_concat == True, used to accumulate the
         # fp32 gradients for the flattened parameters
-        self.fp32_grads = None
+        self.fp32_flat_grad = None
 
         # Handle param_list being None.
         if param_list is None:
@@ -386,12 +386,12 @@ def _grad_accumulation_hook(
         end,
     ):
         """
-        start: int, the starting index(inclusive) of the grad of this parameter in self.fp32_grads
-        end: int, the ending index(exclusive) of the grad of this parameter in self.fp32_grads
+        start: int, the starting index(inclusive) of the grad of this parameter in self.fp32_flat_grad
+        end: int, the ending index(exclusive) of the grad of this parameter in self.fp32_flat_grad
         """
 
-        assert self.fp32_grads is not None
-        self.fp32_grads[start:end].add_(grad.flatten())
+        assert self.fp32_flat_grad is not None
+        self.fp32_flat_grad[start:end].add_(grad.flatten())
         return grad
 
     def _unflatten_params_as_views(self) -> None:
@@ -421,7 +421,7 @@ def _unflatten_params_as_views(self) -> None:
             setattr(m, n, p)  # This will set as plain attr
             if self.optimize_backward_concat:
                 # Register post backward hook to accumulate the gradients
-                # in self.fp32_grads
+                # in self.fp32_flat_grad
                 param_end = param_start + torch.numel(p)
                 p.register_hook(
                     functools.partial(
@@ -433,10 +433,10 @@ def _unflatten_params_as_views(self) -> None:
                 param_start = param_end
             param_views.append(p)
 
-        if self.optimize_backward_concat and self.fp32_grads is None:
+        if self.optimize_backward_concat and self.fp32_flat_grad is None:
             # Allocate GPU memory for flattened fp32 grad accumulation
             total_numels = sum([torch.numel(p) for p in param_views])
-            self.fp32_grads = torch.zeros(total_numels, dtype=torch.float32, device=torch.cuda.current_device())
+            self.fp32_flat_grad = torch.zeros(total_numels, dtype=torch.float32, device=torch.cuda.current_device())
 
 
         # Save param views for easy access if anyone still wants to access
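Taken together, the two files implement one pattern: preallocate a single flat fp32 buffer, register a per-parameter grad hook that adds each flattened gradient into its slice, and consume/reset the buffer after backward. A self-contained sketch of that pattern in plain PyTorch (not the fairscale API; attach_flat_fp32_grad and the toy model are illustrative):

```python
import functools

import torch
import torch.nn as nn


def attach_flat_fp32_grad(module: nn.Module) -> torch.Tensor:
    """Register grad hooks that accumulate every parameter's gradient,
    flattened and cast to fp32, into one preallocated buffer."""
    params = [p for p in module.parameters() if p.requires_grad]
    total_numel = sum(p.numel() for p in params)
    fp32_flat_grad = torch.zeros(total_numel, dtype=torch.float32)

    def hook(grad: torch.Tensor, start: int, end: int) -> torch.Tensor:
        # Add this parameter's grad into its slice of the flat buffer;
        # repeated backward passes accumulate, as in gradient accumulation.
        fp32_flat_grad[start:end].add_(grad.detach().flatten().float())
        return grad

    offset = 0
    for p in params:
        p.register_hook(functools.partial(hook, start=offset, end=offset + p.numel()))
        offset += p.numel()
    return fp32_flat_grad


# Usage: after backward, the buffer holds the flattened fp32 gradients.
model = nn.Linear(4, 2)
flat_grad = attach_flat_fp32_grad(model)
model(torch.randn(3, 4)).sum().backward()
print(flat_grad.shape)  # torch.Size([10]) == weight (8) + bias (2)
```

As in the diff, a consumer of this buffer would read it once per optimizer step and then reset it (zero or reallocate) so gradients are not carried across data batches.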
