Commit 4b5abe2

use new field to accumulate per-parameter grads in fp32 and copy into flatten_parameter.unsharded_main_grad in last microbatch backward()

chrisxcai committed May 2, 2024
1 parent 3429f33 commit 4b5abe2
Showing 2 changed files with 40 additions and 6 deletions.
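
In outline, the change registers a tensor hook on each unflattened parameter view that accumulates that view's gradient in fp32 across microbatch backward passes; the FSDP post-backward hook then concatenates the flattened fp32 buffers into the flat parameter's unsharded_main_grad. The snippet below is a minimal standalone sketch of that flow with made-up parameters and a fake microbatch loop, not the fairscale classes themselves.

import torch

# Two stand-in "parameters" in half precision, as mixed precision would see them.
params = [torch.randn(4, 3, dtype=torch.float16, requires_grad=True),
          torch.randn(3, dtype=torch.float16, requires_grad=True)]
fp32_grads = [None] * len(params)  # one fp32 accumulator per parameter

def make_hook(index):
    def hook(grad):
        if fp32_grads[index] is None:
            fp32_grads[index] = grad.to(torch.float32)   # first microbatch: upcast
        else:
            fp32_grads[index].add_(grad)                 # later microbatches: accumulate in fp32
    return hook

for i, p in enumerate(params):
    p.register_hook(make_hook(i))

for _ in range(4):  # pretend each iteration is one microbatch's backward()
    (params[0].sum() + params[1].sum()).backward()
    for p in params:
        p.grad = None  # the low-precision .grad is no longer needed

# "Last microbatch": copy the accumulated fp32 grads into one flat tensor,
# playing the role of flat_param.unsharded_main_grad.
unsharded_main_grad = torch.cat([g.flatten() for g in fp32_grads])
assert unsharded_main_grad.dtype == torch.float32
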
17 changes: 12 additions & 5 deletions fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -1717,11 +1717,18 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
self._use_fp32_param_shard([param])

if self.fp32_reduce_scatter:
if getattr(param, "unsharded_main_grad", None) is None:
param.unsharded_main_grad = param.grad.to(torch.float32)
else:
param.unsharded_main_grad.add_(param.grad.data)

#logger.info(f"CHRISLOG:{param.unsharded_main_grad.size()=}")
# logger.info(f"CHRISLOG:{len(self._fsdp_wrapped_module.fp32_grads)=}")
# grad_sizes = [grad.size() for grad in self._fsdp_wrapped_module.fp32_grads]
# logger.info(f"CHRISLOG:{grad_sizes=}")

new_unsharded_main_grad_in_fp32 = torch.cat([grad.flatten() for grad in self._fsdp_wrapped_module.fp32_grads])
logger.info(f"CHRISLOG: assigning new unsharded_main_grad with size {new_unsharded_main_grad_in_fp32.size()}, type:{new_unsharded_main_grad_in_fp32.dtype}, original grad size {param.grad.size()}")
# if getattr(param, "unsharded_main_grad", None) is None:
# param.unsharded_main_grad = param.grad.to(torch.float32)
# else:
# param.unsharded_main_grad.add_(param.grad.data)
param.unsharded_main_grad = new_unsharded_main_grad_in_fp32
param.grad = None

if not self._require_backward_grad_sync:
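
The concatenation in the new post-backward code relies on the flat parameter's layout: the per-view fp32 grads, flattened in view order, add up to exactly the flat parameter's numel. Below is a self-contained check of that invariant with made-up shapes; it is a sketch of the layout assumption, not fairscale code.

import torch

shapes = [torch.Size([4, 3]), torch.Size([3]), torch.Size([3, 2])]
numels = [s.numel() for s in shapes]
flat_param = torch.zeros(sum(numels), dtype=torch.float16)

# Views in the order get_param_views() would yield them.
views = [t.view(s) for t, s in zip(flat_param.split(numels), shapes)]

# Stand-ins for the fp32 grads accumulated by the per-view hooks.
fp32_grads = [torch.randn(s, dtype=torch.float32) for s in shapes]

# Same operation as the new unsharded_main_grad assignment: flatten and cat.
unsharded_main_grad = torch.cat([g.flatten() for g in fp32_grads])
assert unsharded_main_grad.numel() == flat_param.numel()
assert unsharded_main_grad.dtype == torch.float32
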
29 changes: 28 additions & 1 deletion fairscale/nn/misc/flatten_params_wrapper.py
@@ -92,6 +92,9 @@ def get_param_views(self, external_data: Optional[Tensor] = None) -> Iterator[Tensor]:
raise ValueError(
f"Incorrect numel of supplied data: got {data.numel()} but expected {sum(self._param_numels)}"
)
logger.info(f"CHRISLOG: {data.numel()=}")
logger.info(f"CHRISLOG: {self._param_numels=}")
logger.info(f"CHRISLOG: {self._param_shapes=}")
return (t.view(s) for (t, s) in zip(data.split(self._param_numels), self._param_shapes))

def metadata(self) -> Tuple[List[str], List[torch.Size], List[int]]:
@@ -164,6 +167,7 @@ def __init__(
self._fpw_module = module
self.is_flattened = False
self._require_backward_grad_sync = True
self.fp32_grads = []

# Handle param_list being None.
if param_list is None:
@@ -367,6 +371,19 @@ def _unflatten_params(self, external_data: Optional[List[Optional[Tensor]]] = None) -> None:
delattr(self, n)
self.flat_params = []


def _hook(
self,
grad,
param_index,
):
logger.info(f"CHRISLOG: before post-backward hook, self.fp32_grads[param_index] is None: {self.fp32_grads[param_index] is None}")
if self.fp32_grads[param_index] is None:
self.fp32_grads[param_index] = grad.to(torch.float32)
else:
self.fp32_grads[param_index].add_(grad.data)
logger.info(f"CHRISLOG: after post-backward hook, self.fp32_grads[param_index] is None: {self.fp32_grads[param_index] is None}")

def _unflatten_params_as_views(self) -> None:
"""Unlike ``_unflatten_params``, this function unflatten into views and keep
self.flat_param unchanged.
@@ -385,8 +402,18 @@ def _unflatten_params_as_views(self) -> None:
for (_, m, n), p in zip(self._param_infos, ps):
setattr(p, '_fsdp_weight', True)
setattr(m, n, p) # This will set as plain attr
# logger.info(f"CHRISLOG: {n=}, {p.requires_grad=}, {p.grad_fn=}")
#logger.info(f"CHRISLOG: {n=}, {p.requires_grad=}, {p.grad_fn=}, {p.grad=}")

import functools
p.register_hook(
functools.partial(
self._hook,
param_index=len(param_views) - 1
)
)
param_views.append(p)
if len(self.fp32_grads) == 0:
self.fp32_grads = [None] * len(param_views)

# Save param views for easy access if anyone still wants to access
# parameters of the module.
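
For reference, the same registration pattern condensed outside the wrapper: one hook per unflattened view, with functools.partial binding the view's slot in a shared list of fp32 accumulators. The views and shapes below are stand-ins, and the index comes from enumerate rather than the wrapper's own bookkeeping, so this is a sketch of the pattern, not the FlattenParamsWrapper code.

import functools
import torch

def _hook(grad, fp32_grads, param_index):
    # Same accumulation rule as FlattenParamsWrapper._hook in the diff:
    # upcast on first use, then add in fp32.
    if fp32_grads[param_index] is None:
        fp32_grads[param_index] = grad.to(torch.float32)
    else:
        fp32_grads[param_index].add_(grad)

# Stand-ins for the unflattened parameter views (shapes are made up).
param_views = [torch.randn(4, 3, dtype=torch.float16, requires_grad=True),
               torch.randn(3, dtype=torch.float16, requires_grad=True)]
fp32_grads = [None] * len(param_views)

for index, view in enumerate(param_views):
    # functools.partial pins the slot index at registration time, so each
    # view's hook writes into its own entry of fp32_grads.
    view.register_hook(functools.partial(_hook, fp32_grads=fp32_grads,
                                         param_index=index))

(param_views[0].sum() + param_views[1].sum()).backward()
assert all(g is not None and g.dtype == torch.float32 for g in fp32_grads)
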
