
Commit

chrisxcai committed Jun 9, 2024
1 parent 1fa3fb1 commit 54fe983
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion fairscale/nn/misc/flatten_params_wrapper.py
@@ -382,10 +382,14 @@ def _unflatten_params(self, external_data: Optional[List[Optional[Tensor]]] = No
def _grad_accumulation_hook(
self,
grad,
#param_index,
start,
end,
):
"""
start: int, the starting index (inclusive) of the grad of this parameter in self.fp32_grads
end: int, the ending index (exclusive) of the grad of this parameter in self.fp32_grads
"""

assert self.fp32_grads is not None
self.fp32_grads[start:end].add_(grad.flatten())
return grad
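
For context, a minimal sketch of how such a hook could be attached to each unflattened parameter view, with (start, end) computed from the parameters' flattened sizes. This is illustrative only, not the actual wiring in this file; Tensor.register_hook is standard PyTorch, and the function and variable names below are hypothetical stand-ins.

import torch


def attach_fp32_accumulation_hooks(param_views, fp32_grads):
    # Walk the parameter views in order; each one owns the slice
    # [offset, offset + numel) of the flat fp32 buffer.
    offset = 0
    for p in param_views:
        start, end = offset, offset + p.numel()

        def hook(grad, start=start, end=end):
            # Accumulate this parameter's grad into its fp32 slice.
            fp32_grads[start:end].add_(grad.flatten())
            return grad

        p.register_hook(hook)
        offset = end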
@@ -430,6 +434,7 @@ def _unflatten_params_as_views(self) -> None:
param_views.append(p)

if self.optimize_backward_concat and self.fp32_grads is None:
# Allocate GPU memory for flattened fp32 grad accumulation
total_numels = sum([torch.numel(p) for p in param_views])
self.fp32_grads = torch.zeros(total_numels, dtype=torch.float32, device=torch.cuda.current_device())

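As a usage note, once gradients have been accumulated into the flat buffer over one or more backward passes, the buffer can be sliced back per parameter with torch.split. The snippet below is an assumption about how a caller might consume the flat fp32 grads, not code from this commit.

import torch


def per_param_fp32_grads(param_views, fp32_grads):
    # Split the flat fp32 buffer into one slice per parameter and reshape
    # each slice to that parameter's original shape.
    numels = [p.numel() for p in param_views]
    return [g.view_as(p) for p, g in zip(param_views, torch.split(fp32_grads, numels))]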
