logging
chrisxcai committed May 6, 2024
1 parent d2a88b7 commit 901fb86
Showing 1 changed file with 36 additions and 7 deletions.
43 changes: 36 additions & 7 deletions fairscale/nn/misc/flatten_params_wrapper.py
@@ -7,7 +7,9 @@
# Licensed under the MIT License.

from contextlib import contextmanager
+import functools
from itertools import chain
+from re import split
import tempfile
import typing
from typing import (
@@ -81,7 +83,7 @@ def __init__(self, params: Sequence[nn.Parameter], requires_grad: bool = True):
self._param_infos: List[Tuple[str, nn.Module, str]] = []
self._shared_param_infos: List[Tuple[str, str, nn.Module, str, nn.Module, str]] = []

-def get_param_views(self, external_data: Optional[Tensor] = None) -> Iterator[Tensor]:
+def get_param_views(self, require_backward_grad_sync, external_data: Optional[Tensor] = None) -> Iterator[Tensor]:
"""Return a generator of views that map to the original parameters."""
# Note, self.data could be sharded, so its numel is <= to the sum.
assert self.data.numel() <= sum(
@@ -95,7 +97,34 @@ def get_param_views(self, external_data: Optional[Tensor] = None) -> Iterator[Tensor]:
# logger.info(f"CHRISLOG: {data.numel()=}")
# logger.info(f"CHRISLOG: {self._param_numels=}")
# logger.info(f"CHRISLOG: {self._param_shapes=}")
-return (t.view(s) for (t, s) in zip(data.split(self._param_numels), self._param_shapes))
+
+# logger.info(f"CHRISLOG: {data.is_leaf=}, {data.grad_fn=}")
+
+# def post_accumulate_grad_hook(
+# param
+# ):
+# logger.info(f"CHRISLOG: cleaning up {param.grad=}")
+# param.grad = None
+
+# data.register_post_accumulate_grad_hook(
+# functools.partial(
+# post_accumulate_grad_hook
+# )
+# )
+# logger.info("CHRISLOG: registered post_accumulate_grad_hook for bf16 grad cleanup on data")
+
+split_outputs = data.split(self._param_numels)
+# for split_output in split_outputs:
+# logger.info(f"CHRISLOG: {require_backward_grad_sync=} {split_output.is_leaf=}, {split_output.grad_fn=}, {split_output.grad=}") #
+# if not require_backward_grad_sync:
+# split_output.register_hook(
+# functools.partial(
+# post_accumulate_grad_hook
+# )
+# )
+# logger.info("CHRISLOG: registered post_accumulate_grad_hook for bf16 grad cleanup on split_output")
+
+return (t.view(s) for (t, s) in zip(split_outputs, self._param_shapes))

def metadata(self) -> Tuple[List[str], List[torch.Size], List[int]]:
"""Return tuple of (names, shapes, numels) metadata for this flat parameter."""
@@ -382,6 +411,7 @@ def _hook(
self.fp32_grads[param_index] = grad.to(torch.float32)
else:
self.fp32_grads[param_index].add_(grad.data)
+
#logger.info(f"CHRISLOG: after post-backward hook, self.fp32_grads[param_index] is None: {self.fp32_grads[param_index] is None}")

def _unflatten_params_as_views(self) -> None:
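
[Note] The _hook in the hunk above keeps an fp32 master copy of each gradient while the module itself computes in lower precision: the first grad is upcast, later grads are accumulated in place. A minimal sketch of that pattern with a tensor hook (assumed shapes and names, not fairscale's API):

import torch

p = torch.zeros(4, dtype=torch.bfloat16, requires_grad=True)
fp32_grads = [None]  # one slot per parameter

def _hook(grad: torch.Tensor, param_index: int = 0) -> torch.Tensor:
    if fp32_grads[param_index] is None:
        fp32_grads[param_index] = grad.to(torch.float32)  # first accumulation: upcast
    else:
        fp32_grads[param_index].add_(grad)  # later microbatches accumulate in fp32
    return grad

p.register_hook(_hook)
p.sum().backward()
assert fp32_grads[0].dtype == torch.float32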
@@ -392,18 +422,17 @@ def _unflatten_params_as_views(self) -> None:
# logger.info(f"CHRISLOG: {self._require_backward_grad_sync=}")
if self._require_backward_grad_sync:
#logger.info("CHRISLOG: calling self.get_param_views() without torch.no_grad()")
-ps = self.get_param_views()
+ps = self.get_param_views(require_backward_grad_sync=self._require_backward_grad_sync)
else:
with torch.no_grad():
#logger.info("CHRISLOG: calling self.get_param_views() with torch.no_grad()")
-ps = self.get_param_views()
+ps = self.get_param_views(require_backward_grad_sync=self._require_backward_grad_sync)

param_views = []
for (_, m, n), p in zip(self._param_infos, ps):
setattr(p, '_fsdp_weight', True)
setattr(m, n, p) # This will set as plain attr
#logger.info(f"CHRISLOG: {n=}, {p.requires_grad=}, {p.grad_fn=}, {p.grad=}")

-import functools
p.register_hook(
functools.partial(
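
[Note] _unflatten_params_as_views replaces each submodule's registered nn.Parameter with a plain-attribute view into the flat buffer, so the module computes with the views and autograd routes gradients back to the flat tensor. A minimal sketch of that substitution on a single Linear (hypothetical example, not the wrapper's full logic):

import torch
from torch import nn

m = nn.Linear(3, 2, bias=False)
flat = m.weight.detach().clone().flatten().requires_grad_()

del m.weight                            # drop the registered nn.Parameter
setattr(m, "weight", flat.view(2, 3))   # plain attr: a view of the flat buffer

m(torch.randn(1, 3)).sum().backward()
assert flat.grad is not None            # the gradient lands on the flat tensor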
@@ -528,7 +557,7 @@ def forward(self, *inputs: Any, **kwinputs: Any) -> Any:
self._unflatten_params_as_views()
return self.module(*inputs, **kwinputs)

-def get_param_views(self, external_data_list: Optional[List[Optional[Tensor]]] = None) -> Iterator[Tensor]:
+def get_param_views(self, require_backward_grad_sync, external_data_list: Optional[List[Optional[Tensor]]] = None) -> Iterator[Tensor]:
"""Used to get a generator over all views from a list of external data list."""
params = self.flat_params
if external_data_list is None:
@@ -539,7 +568,7 @@ def get_param_views(self, external_data_list: Optional[List[Optional[Tensor]]] = None) -> Iterator[Tensor]:

gens = []
for p, data in zip(params, external_data_list):
-gens.append(p.get_param_views(data))
+gens.append(p.get_param_views(require_backward_grad_sync, data))

return chain(*gens)
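
[Note] At the wrapper level, get_param_views just chains the per-FlatParameter generators; each element is produced by split() + view() over a flat tensor, so every view aliases the flat storage. A small sketch of that view reconstruction (made-up sizes, not fairscale code):

from itertools import chain
import torch

flat = torch.arange(10.0)                     # flat storage for two params
numels, shapes = [6, 4], [(2, 3), (4,)]

views = (t.view(s) for t, s in zip(flat.split(numels), shapes))
all_views = list(chain(views))                # the wrapper chains such generators

assert [v.shape for v in all_views] == [torch.Size([2, 3]), torch.Size([4])]
all_views[0].zero_()                          # views alias the flat tensor
assert flat[:6].eq(0).all()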

