From ad7aa1fe6f3b96e7967a86cae92dd1540207ab4d Mon Sep 17 00:00:00 2001
From: Chris Cai
Date: Thu, 9 May 2024 00:34:04 -0700
Subject: [PATCH] logging

---
 fairscale/nn/data_parallel/fully_sharded_data_parallel.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
index 58f59b1f3..7fe9d569d 100644
--- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
+++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -1723,6 +1723,14 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
 
             # logger.info(f"CHRISLOG:{grad_sizes=}")
             new_unsharded_main_grad_in_fp32 = torch.cat([grad.flatten() for grad in self._fsdp_wrapped_module.fp32_grads])
+            baseline_grad = param.grad.to(torch.float32)
+
+
+            logger.info(f"CHRISLOG: baseline grad {baseline_grad=}, {baseline_grad.size()=}")
+            logger.info(f"CHRISLOG: new grad {new_unsharded_main_grad_in_fp32=}, {new_unsharded_main_grad_in_fp32.size()=}")
+            assert torch.allclose(baseline_grad, new_unsharded_main_grad_in_fp32, atol=0, rtol=0)
+            logger.info(f"CHRISLOG: baseline grad and new grad passed allclose check")
+
             # logger.info(f"CHRISLOG: assigning new unsharded_main_grad with size {new_unsharded_main_grad_in_fp32.size()}, type:{new_unsharded_main_grad_in_fp32.dtype}, original grad size {param.grad.size()}")
             # if getattr(param, "unsharded_main_grad", None) is None:
             #     param.unsharded_main_grad = param.grad.to(torch.float32)
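
Below is a standalone sketch (not part of the patch) of the check the hunk performs: the flat gradient is upcast to FP32 and compared against the concatenation of the flattened per-parameter FP32 gradients. The names fp32_grads and flat_grad are hypothetical stand-ins for self._fsdp_wrapped_module.fp32_grads and param.grad; with atol=0 and rtol=0, torch.allclose only passes on exact element-wise equality.

import logging

import torch

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Hypothetical per-parameter FP32 gradients, standing in for
# self._fsdp_wrapped_module.fp32_grads in the patched hook.
fp32_grads = [torch.randn(4, 3), torch.randn(5)]

# Baseline: the flat gradient of the wrapped parameter, upcast to FP32.
# In the real hook this comes from param.grad; here it is built from the same
# tensors so the exact comparison is expected to pass.
flat_grad = torch.cat([g.flatten() for g in fp32_grads])
baseline_grad = flat_grad.to(torch.float32)

# Candidate: concatenation of the flattened FP32 per-parameter gradients.
new_unsharded_main_grad_in_fp32 = torch.cat([g.flatten() for g in fp32_grads])

# atol=0 and rtol=0 turn allclose into an exact element-wise equality check,
# so any precision mismatch between the two gradient paths trips the assert.
assert torch.allclose(baseline_grad, new_unsharded_main_grad_in_fp32, atol=0, rtol=0)
logger.info("baseline grad and new grad passed allclose check")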