From 4b5abe2541be244f1efe069c41ce902606073ea5 Mon Sep 17 00:00:00 2001
From: Chris Cai
Date: Wed, 1 May 2024 18:17:48 -0700
Subject: [PATCH] use new field to accumulate per-parameter grads in fp32 and
 copy into flatten_parameter.unsharded_main_grad in last microbatch backward()

---
 .../fully_sharded_data_parallel.py          | 17 +++++++----
 fairscale/nn/misc/flatten_params_wrapper.py | 29 ++++++++++++++++++-
 2 files changed, 40 insertions(+), 6 deletions(-)

diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
index 48c1bd0ff..27c39f66a 100644
--- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
+++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -1717,11 +1717,18 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
             self._use_fp32_param_shard([param])
 
         if self.fp32_reduce_scatter:
-            if getattr(param, "unsharded_main_grad", None) is None:
-                param.unsharded_main_grad = param.grad.to(torch.float32)
-            else:
-                param.unsharded_main_grad.add_(param.grad.data)
-
+            # logger.info(f"CHRISLOG:{param.unsharded_main_grad.size()=}")
+            # logger.info(f"CHRISLOG:{len(self._fsdp_wrapped_module.fp32_grads)=}")
+            # grad_sizes = [grad.size() for grad in self._fsdp_wrapped_module.fp32_grads]
+            # logger.info(f"CHRISLOG:{grad_sizes=}")
+
+            new_unsharded_main_grad_in_fp32 = torch.cat([grad.flatten() for grad in self._fsdp_wrapped_module.fp32_grads])
+            logger.info(f"CHRISLOG: assigning new unsharded_main_grad with size {new_unsharded_main_grad_in_fp32.size()}, type:{new_unsharded_main_grad_in_fp32.dtype}, original grad size {param.grad.size()}")
+            # if getattr(param, "unsharded_main_grad", None) is None:
+            #     param.unsharded_main_grad = param.grad.to(torch.float32)
+            # else:
+            #     param.unsharded_main_grad.add_(param.grad.data)
+            param.unsharded_main_grad = new_unsharded_main_grad_in_fp32
             param.grad = None
 
         if not self._require_backward_grad_sync:
diff --git a/fairscale/nn/misc/flatten_params_wrapper.py b/fairscale/nn/misc/flatten_params_wrapper.py
index 30b88360d..613869364 100644
--- a/fairscale/nn/misc/flatten_params_wrapper.py
+++ b/fairscale/nn/misc/flatten_params_wrapper.py
@@ -92,6 +92,9 @@ def get_param_views(self, external_data: Optional[Tensor] = None) -> Iterator[Tensor]:
             raise ValueError(
                 f"Incorrect numel of supplied data: got {data.numel()} but expected {sum(self._param_numels)}"
             )
+        logger.info(f"CHRISLOG: {data.numel()=}")
+        logger.info(f"CHRISLOG: {self._param_numels=}")
+        logger.info(f"CHRISLOG: {self._param_shapes=}")
         return (t.view(s) for (t, s) in zip(data.split(self._param_numels), self._param_shapes))
 
     def metadata(self) -> Tuple[List[str], List[torch.Size], List[int]]:
@@ -164,6 +167,7 @@ def __init__(
         self._fpw_module = module
         self.is_flattened = False
         self._require_backward_grad_sync = True
+        self.fp32_grads = []
 
         # Handle param_list being None.
         if param_list is None:
@@ -367,6 +371,19 @@ def _unflatten_params(self, external_data: Optional[List[Optional[Tensor]]] = None) -> None:
                 delattr(self, n)
         self.flat_params = []
 
+
+    def _hook(
+        self,
+        grad,
+        param_index,
+    ):
+        logger.info(f"CHRISLOG: before post-backward hook, self.fp32_grads[param_index] is None: {self.fp32_grads[param_index] is None}")
+        if self.fp32_grads[param_index] is None:
+            self.fp32_grads[param_index] = grad.to(torch.float32)
+        else:
+            self.fp32_grads[param_index].add_(grad.data)
+        logger.info(f"CHRISLOG: after post-backward hook, self.fp32_grads[param_index] is None: {self.fp32_grads[param_index] is None}")
+
     def _unflatten_params_as_views(self) -> None:
         """Unlike ``_unflatten_params``, this function unflatten into views and
         keep self.flat_param unchanged.
@@ -385,8 +402,18 @@ def _unflatten_params_as_views(self) -> None:
         for (_, m, n), p in zip(self._param_infos, ps):
             setattr(p, '_fsdp_weight', True)
             setattr(m, n, p)  # This will set as plain attr
-            # logger.info(f"CHRISLOG: {n=}, {p.requires_grad=}, {p.grad_fn=}")
+            # logger.info(f"CHRISLOG: {n=}, {p.requires_grad=}, {p.grad_fn=}, {p.grad=}")
+
+            import functools
+            p.register_hook(
+                functools.partial(
+                    self._hook,
+                    param_index=len(param_views),  # index of p; p is appended to param_views below
+                )
+            )
             param_views.append(p)
+        if len(self.fp32_grads) == 0:
+            self.fp32_grads = [None] * len(param_views)
 
         # Save param views for easy access if anyone still wants to access
         # parameters of the module.
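
For context, here is a minimal, standalone sketch of the pattern this patch wires into FlattenParamsWrapper and FSDP._post_backward_hook: per-parameter tensor hooks accumulate gradients into fp32 buffers across microbatch backward passes, and after the last backward the buffers are concatenated into a single flat fp32 main grad. This is plain PyTorch outside fairscale; the toy model, the names accumulate_fp32 / fp32_grads / unsharded_main_grad, and the bf16-on-CPU setup are illustrative assumptions, not part of the patch itself.

# Sketch only: mirrors the fp32_grads accumulation and the final torch.cat into
# unsharded_main_grad performed by this patch, using a plain nn.Linear.
import functools

import torch
import torch.nn as nn

model = nn.Linear(4, 2).to(torch.bfloat16)
params = list(model.parameters())
fp32_grads = [None] * len(params)  # analogous to FlattenParamsWrapper.fp32_grads


def accumulate_fp32(grad, param_index):
    # Runs on every backward(): upcast the low-precision grad and accumulate in
    # fp32, like FlattenParamsWrapper._hook in the patch.
    if fp32_grads[param_index] is None:
        fp32_grads[param_index] = grad.to(torch.float32)
    else:
        fp32_grads[param_index].add_(grad)
    return grad


for i, p in enumerate(params):
    p.register_hook(functools.partial(accumulate_fp32, param_index=i))

# Two "microbatches": gradients accumulate in fp32 across both backward passes.
for _ in range(2):
    x = torch.randn(8, 4, dtype=torch.bfloat16)
    model(x).sum().backward()
    for p in params:
        p.grad = None  # drop the low-precision grad; the fp32 copy is kept

# After the last microbatch, build the flat fp32 main grad, analogous to
# assigning param.unsharded_main_grad in _post_backward_hook.
unsharded_main_grad = torch.cat([g.flatten() for g in fp32_grads])
print(unsharded_main_grad.dtype, unsharded_main_grad.numel())

Keeping one fp32 buffer per original parameter (rather than accumulating directly into the flat grad) lets the tensor hooks fire independently of FSDP's flattening, which is why the patch only concatenates once, in the post-backward hook of the last microbatch.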