use torch.no_grad() to avoid calling cat() during FSDP backward except for last microbatch
chrisxcai committed Apr 29, 2024
1 parent d0b506f commit d1102ce
Showing 2 changed files with 16 additions and 4 deletions.
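
The commit targets gradient accumulation with FSDP: every microbatch except the last accumulates gradients locally under no_sync(), and only the final microbatch runs a synchronized backward. Below is a minimal sketch of that usage pattern, assuming a torch.distributed process group is already initialized; MyModel and loader are placeholders for the user's model and microbatch source, not part of this commit.

import torch
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

model = FSDP(MyModel().cuda())                 # MyModel is a placeholder module
optim = torch.optim.SGD(model.parameters(), lr=1e-3)

for microbatches in loader:                    # loader yields lists of microbatches (assumed)
    # All microbatches except the last: accumulate grads locally, no reduce-scatter.
    # This is the path where _require_backward_grad_sync is False and, with this
    # commit, param views are rebuilt under torch.no_grad().
    for mb in microbatches[:-1]:
        with model.no_sync():
            loss = model(mb).sum()
            loss.backward()
    # Last microbatch: normal backward with gradient synchronization.
    loss = model(microbatches[-1]).sum()
    loss.backward()
    optim.step()
    optim.zero_grad()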
7 changes: 4 additions & 3 deletions fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -1099,12 +1099,14 @@ def no_sync(self) -> Generator:
             if isinstance(m, FullyShardedDataParallel):
                 old_flags.append((m, m._require_backward_grad_sync))
                 m._require_backward_grad_sync = False
+                m._fsdp_wrapped_module._require_backward_grad_sync = False
         try:
             yield
         finally:
             for m, old_flag in old_flags:
                 assert m._require_backward_grad_sync is False
                 m._require_backward_grad_sync = old_flag
+                m._fsdp_wrapped_module._require_backward_grad_sync = old_flag

     @contextlib.contextmanager
     def summon_full_params(self, recurse: bool = True, volatile: bool = False) -> Generator:
@@ -1458,7 +1460,6 @@ def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
         # Register backward hooks to reshard params and reduce-scatter grads.
         # These need to be re-registered every forward pass.
         self._register_post_backward_hooks()
-
         outputs = self.module(*args, **kwargs)

         if self.reshard_after_forward:
@@ -1851,7 +1852,7 @@ def _wait_for_post_backward(self) -> None:
         # the `requires_grad` field set. If `requires_grad=False` for
         # all the params, the post_backward hook will not fire and the
         # state will remain in `TrainingState.BACKWARD_PRE`.
-        if any([p.requires_grad for p in self.params]):
+        if any([p.requires_grad for p in self.params]) and self._fsdp_wrapped_module._require_backward_grad_sync:
             self.assert_state(TrainingState.BACKWARD_POST)
         else:
             self.assert_state(TrainingState.BACKWARD_PRE)
@@ -1928,7 +1929,7 @@ def _finalize_parameters(fsdp_module: FullyShardedDataParallel) -> None:
                     # the `requires_grad` field set. If `requires_grad=False` for
                     # all the params, the post_backward hook will not fire and the
                     # state will remain in `TrainingState.BACKWARD_PRE`.
-                    if any([p.requires_grad for p in m.params]):
+                    if any([p.requires_grad for p in m.params]) and self._fsdp_wrapped_module._require_backward_grad_sync:
                         m.assert_state(TrainingState.BACKWARD_POST)
                     else:
                         m.assert_state(TrainingState.BACKWARD_PRE)
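
The no_sync() change above mirrors _require_backward_grad_sync onto the wrapped FlattenParamsWrapper so that code running inside the wrapper (second file, below) can branch on it. A toy stand-in for that propagate-and-restore pattern, with hypothetical Outer/Inner classes in place of the real FullyShardedDataParallel and FlattenParamsWrapper:

import contextlib

class Inner:
    """Stands in for FlattenParamsWrapper: it only reads the flag."""
    def __init__(self) -> None:
        self._require_backward_grad_sync = True

class Outer:
    """Stands in for FullyShardedDataParallel: it owns the flag and the wrapped module."""
    def __init__(self) -> None:
        self._require_backward_grad_sync = True
        self._fsdp_wrapped_module = Inner()

    @contextlib.contextmanager
    def no_sync(self):
        old_flag = self._require_backward_grad_sync
        # Disable sync on both the outer module and the wrapped module,
        # then restore the previous value on exit, as the diff above does.
        self._require_backward_grad_sync = False
        self._fsdp_wrapped_module._require_backward_grad_sync = False
        try:
            yield
        finally:
            self._require_backward_grad_sync = old_flag
            self._fsdp_wrapped_module._require_backward_grad_sync = old_flag

m = Outer()
with m.no_sync():
    assert m._fsdp_wrapped_module._require_backward_grad_sync is False
assert m._fsdp_wrapped_module._require_backward_grad_sync is True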
13 changes: 12 additions & 1 deletion fairscale/nn/misc/flatten_params_wrapper.py
@@ -37,6 +37,8 @@
 if TYPE_CHECKING:
     from collections import OrderedDict  # noqa: F401

+from logging import getLogger
+logger = getLogger()

 class FlatParameter(nn.Parameter):
     """A parameter that is initialized from a list of parameters and can be
@@ -161,6 +163,7 @@ def __init__(
         super().__init__()
         self._fpw_module = module
         self.is_flattened = False
+        self._require_backward_grad_sync = True

         # Handle param_list being None.
         if param_list is None:
@@ -369,7 +372,15 @@ def _unflatten_params_as_views(self) -> None:
         self.flat_param unchanged.
         """
         assert self.is_flattened
-        ps = self.get_param_views()
+        logger.info(f"CHRISLOG: {self._require_backward_grad_sync=}")
+        if self._require_backward_grad_sync:
+            logger.info("CHRISLOG: calling self.get_param_views() without torch.no_grad()")
+            ps = self.get_param_views()
+        else:
+            with torch.no_grad():
+                logger.info("CHRISLOG: calling self.get_param_views() with torch.no_grad()")
+                ps = self.get_param_views()

         param_views = []
         for (_, m, n), p in zip(self._param_infos, ps):
+            setattr(p, '_fsdp_weight', True)
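
The _unflatten_params_as_views() change is the core of the commit: when gradient sync is off (all but the last microbatch), the parameter views are produced under torch.no_grad(), so autograd does not record the view/split ops and the backward pass does not need to cat() their gradients back together. A small self-contained illustration of how no_grad() keeps such ops out of the graph (plain PyTorch, not fairscale code):

import torch

flat = torch.randn(10, requires_grad=True)   # stands in for a FlatParameter

# With grad tracking, splitting and re-concatenating records autograd nodes,
# which keep an extra graph alive until backward runs.
views = flat.split([4, 6])
tracked = torch.cat([v.view(-1) for v in views])
print(tracked.grad_fn)          # e.g. <CatBackward0 ...>

# Under torch.no_grad(), the same ops record nothing: no grad_fn, no graph
# accumulating across microbatches that skip gradient synchronization.
with torch.no_grad():
    views = flat.split([4, 6])
    untracked = torch.cat([v.view(-1) for v in views])
print(untracked.grad_fn)        # None
print(untracked.requires_grad)  # False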