From 901fb86d2c6557de7858de4f714162970196bf20 Mon Sep 17 00:00:00 2001
From: Chris Cai
Date: Sun, 5 May 2024 22:15:45 -0700
Subject: [PATCH] logging

---
 fairscale/nn/misc/flatten_params_wrapper.py | 43 +++++++++++++++++----
 1 file changed, 36 insertions(+), 7 deletions(-)

diff --git a/fairscale/nn/misc/flatten_params_wrapper.py b/fairscale/nn/misc/flatten_params_wrapper.py
index a7b3a09e1..0e87bd8be 100644
--- a/fairscale/nn/misc/flatten_params_wrapper.py
+++ b/fairscale/nn/misc/flatten_params_wrapper.py
@@ -7,7 +7,9 @@
 # Licensed under the MIT License.
 
 from contextlib import contextmanager
+import functools
 from itertools import chain
+from re import split
 import tempfile
 import typing
 from typing import (
@@ -81,7 +83,7 @@ def __init__(self, params: Sequence[nn.Parameter], requires_grad: bool = True):
         self._param_infos: List[Tuple[str, nn.Module, str]] = []
         self._shared_param_infos: List[Tuple[str, str, nn.Module, str, nn.Module, str]] = []
 
-    def get_param_views(self, external_data: Optional[Tensor] = None) -> Iterator[Tensor]:
+    def get_param_views(self, require_backward_grad_sync, external_data: Optional[Tensor] = None) -> Iterator[Tensor]:
         """Return a generator of views that map to the original parameters."""
         # Note, self.data could be sharded, so its numel is <= to the sum.
         assert self.data.numel() <= sum(
@@ -95,7 +97,34 @@ def get_param_views(self, external_data: Optional[Tensor] = None) -> Iterator[Te
         # logger.info(f"CHRISLOG: {data.numel()=}")
         # logger.info(f"CHRISLOG: {self._param_numels=}")
         # logger.info(f"CHRISLOG: {self._param_shapes=}")
-        return (t.view(s) for (t, s) in zip(data.split(self._param_numels), self._param_shapes))
+
+        # logger.info(f"CHRISLOG: {data.is_leaf=}, {data.grad_fn=}")
+
+        # def post_accumulate_grad_hook(
+        #     param
+        # ):
+        #     logger.info(f"CHRISLOG: cleaning up {param.grad=}")
+        #     param.grad = None
+
+        # data.register_post_accumulate_grad_hook(
+        #     functools.partial(
+        #         post_accumulate_grad_hook
+        #     )
+        # )
+        # logger.info("CHRISLOG: registered post_accumulate_grad_hook for bf16 grad cleanup on data")
+
+        split_outputs = data.split(self._param_numels)
+        # for split_output in split_outputs:
+        #     logger.info(f"CHRISLOG: {require_backward_grad_sync=} {split_output.is_leaf=}, {split_output.grad_fn=}, {split_output.grad=}") #
+        #     if not require_backward_grad_sync:
+        #         split_output.register_hook(
+        #             functools.partial(
+        #                 post_accumulate_grad_hook
+        #             )
+        #         )
+        #         logger.info("CHRISLOG: registered post_accumulate_grad_hook for bf16 grad cleanup on split_output")
+
+        return (t.view(s) for (t, s) in zip(split_outputs, self._param_shapes))
 
     def metadata(self) -> Tuple[List[str], List[torch.Size], List[int]]:
         """Return tuple of (names, shapes, numels) metadata for this flat parameter."""
@@ -382,6 +411,7 @@ def _hook(
                 self.fp32_grads[param_index] = grad.to(torch.float32)
             else:
                 self.fp32_grads[param_index].add_(grad.data)
+            #logger.info(f"CHRISLOG: after post-backward hook, self.fp32_grads[param_index] is None: {self.fp32_grads[param_index] is None}")
 
 
     def _unflatten_params_as_views(self) -> None:
@@ -392,18 +422,17 @@ def _unflatten_params_as_views(self) -> None:
         # logger.info(f"CHRISLOG: {self._require_backward_grad_sync=}")
         if self._require_backward_grad_sync:
             #logger.info("CHRISLOG: calling self.get_param_views() without torch.no_grad()")
-            ps = self.get_param_views()
+            ps = self.get_param_views(require_backward_grad_sync=self._require_backward_grad_sync)
         else:
             with torch.no_grad():
                 #logger.info("CHRISLOG: calling self.get_param_views() with torch.no_grad()")
-                ps = self.get_param_views()
+                ps = self.get_param_views(require_backward_grad_sync=self._require_backward_grad_sync)
         param_views = []
         for (_, m, n), p in zip(self._param_infos, ps):
             setattr(p, '_fsdp_weight', True)
             setattr(m, n, p)  # This will set as plain attr
             #logger.info(f"CHRISLOG: {n=}, {p.requires_grad=}, {p.grad_fn=}, {p.grad=}")
 
-            import functools
 
             p.register_hook(
                 functools.partial(
@@ -528,7 +557,7 @@ def forward(self, *inputs: Any, **kwinputs: Any) -> Any:
             self._unflatten_params_as_views()
         return self.module(*inputs, **kwinputs)
 
-    def get_param_views(self, external_data_list: Optional[List[Optional[Tensor]]] = None) -> Iterator[Tensor]:
+    def get_param_views(self, require_backward_grad_sync, external_data_list: Optional[List[Optional[Tensor]]] = None) -> Iterator[Tensor]:
         """Used to get a generator over all views from a list of external data list."""
         params = self.flat_params
         if external_data_list is None:
@@ -539,7 +568,7 @@ def get_param_views(self, external_data_list: Optional[List[Optional[Tensor]]] =
 
         gens = []
         for p, data in zip(params, external_data_list):
-            gens.append(p.get_param_views(data))
+            gens.append(p.get_param_views(require_backward_grad_sync, data))
         return chain(*gens)
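
For context: the patch threads a require_backward_grad_sync flag into FlatParameter.get_param_views, which splits the flat buffer by per-parameter numels and views each chunk back to its original shape; the commented-out CHRISLOG blocks experiment with register_hook / register_post_accumulate_grad_hook to free bf16 gradients right after they are accumulated. The standalone sketch below is not part of the patch: it only illustrates the split+view pattern and a post-accumulate-grad hook on a toy flat tensor, with hypothetical shapes and assuming PyTorch >= 2.1 for register_post_accumulate_grad_hook.

    # Illustrative sketch only -- not code from this patch. Toy shapes, PyTorch >= 2.1.
    import math
    import torch

    shapes = [(2, 3), (4,)]                  # hypothetical original parameter shapes
    numels = [math.prod(s) for s in shapes]  # elements per parameter: 6 and 4

    # Stand-in for FlatParameter.data: one flat leaf tensor holding all parameters.
    flat = torch.randn(sum(numels), requires_grad=True)

    # Same pattern as get_param_views(): split by numel, then view back to each shape.
    # The resulting views are non-leaf tensors (they carry a grad_fn), which is what
    # the CHRISLOG logging of is_leaf / grad_fn was probing.
    views = [t.view(s) for t, s in zip(flat.split(numels), shapes)]

    # Post-accumulate-grad hook in the spirit of the commented-out experiment:
    # fires after .grad has been accumulated on the leaf and frees it immediately.
    def free_grad(param: torch.Tensor) -> None:
        param.grad = None

    flat.register_post_accumulate_grad_hook(free_grad)

    loss = sum(v.sum() for v in views)
    loss.backward()
    print(flat.grad)  # None -- the hook cleared the accumulated gradient

In the patch itself the hook experiments remain commented out; only the require_backward_grad_sync plumbing through both get_param_views signatures is active.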