add loss_backward_retain_graph to __init__() (#856)
Summary:
Pull Request resolved: #856

Expose the **retain_graph** kwarg of **loss.backward()** by adding a new argument, **loss_backward_retain_graph**, to **AutoUnit.__init__()**

Differential Revision: D58901158
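
A minimal sketch of how a caller might opt in to the new argument. The `MyUnit` subclass, the linear model, and the SGD optimizer below are illustrative assumptions about a typical `AutoUnit` setup, not part of this commit:

```python
# Illustrative sketch only; MyUnit, the model, and the optimizer are assumptions.
from typing import Any, Optional, Tuple

import torch
from torchtnt.framework.auto_unit import AutoUnit
from torchtnt.framework.state import State

Batch = Tuple[torch.Tensor, torch.Tensor]


class MyUnit(AutoUnit[Batch]):
    def compute_loss(self, state: State, data: Batch) -> Tuple[torch.Tensor, Any]:
        inputs, targets = data
        outputs = self.module(inputs)
        loss = torch.nn.functional.cross_entropy(outputs, targets)
        return loss, outputs

    def configure_optimizers_and_lr_scheduler(
        self, module: torch.nn.Module
    ) -> Tuple[torch.optim.Optimizer, Optional[Any]]:
        return torch.optim.SGD(module.parameters(), lr=0.01), None


unit = MyUnit(
    module=torch.nn.Linear(8, 2),
    # New in this diff: forwarded to loss.backward(retain_graph=...) during the backward pass.
    loss_backward_retain_graph=True,
)
```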
stephenyan1231 authored and facebook-github-bot committed Jun 27, 2024
1 parent 12c5637 commit f185b87
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion torchtnt/framework/auto_unit.py
@@ -434,6 +434,10 @@ class AutoUnit(
         activation_checkpoint_params: params for enabling activation checkpointing
         training: if True, the optimizer and optionally LR scheduler will be created after the class is initialized.
         enable_compiled_autograd: if True, `compiled_autograd` will be used to compile the backward, this is an experimental flag.
+        loss_backward_retain_graph: If ``None`` or ``False``, the graph used to compute
+            the grads will be freed during loss backward pass. Note that in nearly all cases setting
+            this option to True is not needed and often can be worked around
+            in a much more efficient way.
 
     Note:
         Certain strategies, like :class:`~torchtnt.utils.prepare_module.FSDPStrategy` also support mixed precision as an argument, so can be configured through that class as well.
@@ -463,6 +467,7 @@ def __init__(
         activation_checkpoint_params: Optional[ActivationCheckpointParams] = None,
         training: bool = True,
         enable_compiled_autograd: bool = False,
+        loss_backward_retain_graph: Optional[bool] = None,
     ) -> None:
         super().__init__(
             module=module,
@@ -526,6 +531,7 @@ def __init__(
 
         self.enable_compiled_autograd = enable_compiled_autograd
         self.training = training
+        self.loss_backward_retain_graph = loss_backward_retain_graph
 
         self.optimizer: Optional[torch.optim.Optimizer] = None
         self.lr_scheduler: Optional[TLRScheduler] = None
@@ -625,7 +631,7 @@ def maybe_enable_compiled_autograd(
                 with get_timing_context(
                     state, f"{self.__class__.__name__}.backward"
                 ):
-                    loss.backward()
+                    loss.backward(retain_graph=self.loss_backward_retain_graph)
 
         total_grad_norm = None
         if should_update_weights:
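For reference, the behavior being exposed is PyTorch's own `retain_graph` flag, which the new argument forwards as-is. A short standalone illustration (plain PyTorch, not torchtnt-specific):

```python
import torch

x = torch.randn(4, requires_grad=True)
loss = (x * 2).sum()

loss.backward(retain_graph=True)  # keep the autograd graph alive after this backward
loss.backward()  # second backward succeeds only because the graph was retained
```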
