add loss_backward_retain_graph to __init__()
Summary:
Mask2Former (M2F) Executorch QAT model has its 5 top-level submodules prepared separately (https://fburl.com/code/44qk8qu3).
This is because the model graph differs between a) QAT training, b) QAT evaluation, and c) ET model export time.
- We find empirically that training such an ET QAT model requires calling **loss.backward(retain_graph=True)** in the train step; otherwise the training step fails as in P1447579952.
- We therefore add a new **loss_backward_retain_graph** argument to AutoUnit.__init__() so the user can control the **retain_graph** kwarg of loss.backward() (see the usage sketch below).
- Note this change is backward-compatible: the new argument defaults to False, which preserves the existing behavior.
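
A minimal usage sketch follows (not part of this diff). The unit subclass, module, loss, and optimizer below are hypothetical placeholders, and the compute_loss / configure_optimizers_and_lr_scheduler signatures are assumed from the AutoUnit interface; only the loss_backward_retain_graph keyword comes from this change.

import torch
from torchtnt.framework.auto_unit import AutoUnit


class MyQATAutoUnit(AutoUnit[torch.Tensor]):
    # Hypothetical AutoUnit subclass, used only to illustrate the new kwarg.
    def compute_loss(self, state, data):
        # Placeholder loss: mean squared output of the wrapped module.
        outputs = self.module(data)
        return outputs.pow(2).mean(), outputs

    def configure_optimizers_and_lr_scheduler(self, module):
        # Placeholder optimizer; no LR scheduler.
        return torch.optim.SGD(module.parameters(), lr=1e-3), None


unit = MyQATAutoUnit(
    module=torch.nn.Linear(8, 8),
    # New in this change: forwarded to loss.backward(retain_graph=...) in the
    # train step so the autograd graph is kept alive across backward calls.
    loss_backward_retain_graph=True,
)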

Differential Revision: D58901158
stephenyan1231 authored and facebook-github-bot committed Jun 26, 2024
1 parent 12c5637 commit 46b0a55
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion torchtnt/framework/auto_unit.py
@@ -463,6 +463,7 @@ def __init__(
         activation_checkpoint_params: Optional[ActivationCheckpointParams] = None,
         training: bool = True,
         enable_compiled_autograd: bool = False,
+        loss_backward_retain_graph: bool = False,
     ) -> None:
         super().__init__(
             module=module,
@@ -526,6 +527,7 @@ def __init__(

         self.enable_compiled_autograd = enable_compiled_autograd
         self.training = training
+        self.loss_backward_retain_graph = loss_backward_retain_graph

         self.optimizer: Optional[torch.optim.Optimizer] = None
         self.lr_scheduler: Optional[TLRScheduler] = None
@@ -625,7 +627,7 @@ def maybe_enable_compiled_autograd(
             with get_timing_context(
                 state, f"{self.__class__.__name__}.backward"
             ):
-                loss.backward()
+                loss.backward(retain_graph=self.loss_backward_retain_graph)

         total_grad_norm = None
         if should_update_weights:
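
For context (not part of this commit): a standalone PyTorch sketch of what retain_graph controls. After a plain backward() the autograd graph is freed, so a second backward over the same graph raises a RuntimeError unless the first call passed retain_graph=True.

import torch

x = torch.randn(4, requires_grad=True)
y = (x * 2).sum()

y.backward(retain_graph=True)  # keep the graph alive after this backward
y.backward()                   # OK: the graph was retained above

z = (x * 3).sum()
z.backward()
# z.backward()  # would raise RuntimeError: graph freed after the first backward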
