diff --git a/fairscale/nn/misc/flatten_params_wrapper.py b/fairscale/nn/misc/flatten_params_wrapper.py
index bc1209bbb..30b88360d 100644
--- a/fairscale/nn/misc/flatten_params_wrapper.py
+++ b/fairscale/nn/misc/flatten_params_wrapper.py
@@ -372,7 +372,7 @@ def _unflatten_params_as_views(self) -> None:
         self.flat_param unchanged.
         """
         assert self.is_flattened
-        #logger.info(f"CHRISLOG: {self._require_backward_grad_sync=}")
+        # logger.info(f"CHRISLOG: {self._require_backward_grad_sync=}")
         if self._require_backward_grad_sync:
             #logger.info("CHRISLOG: calling self.get_param_views() without torch.no_grad()")
             ps = self.get_param_views()
@@ -385,7 +385,7 @@ def _unflatten_params_as_views(self) -> None:
         for (_, m, n), p in zip(self._param_infos, ps):
             setattr(p, '_fsdp_weight', True)
             setattr(m, n, p)  # This will set as plain attr
-            #logger.info(f"CHRISLOG: {n=}, {p.requires_grad=}")
+            # logger.info(f"CHRISLOG: {n=}, {p.requires_grad=}, {p.grad_fn=}")
             param_views.append(p)

         # Save param views for easy access if anyone still wants to access