add loss_backward_retain_graph to __init__()
Summary: The Mask2Former (M2F) Executorch QAT model has its 5 top-level submodules prepared separately (https://fburl.com/code/44qk8qu3), because the model graph differs during a) QAT training, b) QAT evaluation, and c) ET model export.

- We empirically find that to train such an ET QAT model, we need to call **loss.backward(retain_graph=True)** in the train step; otherwise the training step fails as in P1447579952.
- We therefore add a new **loss_backward_retain_graph** argument to AutoUnit.__init__() to give users control over the **retain_graph** kwarg (see the sketch after this message).
- Note this change is backward-compatible.

Differential Revision: D58901158
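A minimal sketch of the pattern this diff describes: a constructor flag that is threaded through to the `retain_graph` kwarg of `loss.backward()`. The class `SketchAutoUnit`, its optimizer setup, and its `train_step` signature are illustrative stand-ins, not the actual torchtnt AutoUnit internals; only the flag-forwarding idea comes from the diff summary.

```python
from typing import Optional

import torch


class SketchAutoUnit:
    """Illustrative stand-in for an AutoUnit-style training wrapper."""

    def __init__(
        self,
        module: torch.nn.Module,
        # New kwarg: defaults to None, which matches loss.backward()'s own
        # default, so existing callers see no behavior change.
        loss_backward_retain_graph: Optional[bool] = None,
    ) -> None:
        self.module = module
        self.optimizer = torch.optim.SGD(module.parameters(), lr=0.01)
        self.loss_backward_retain_graph = loss_backward_retain_graph

    def train_step(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        outputs = self.module(inputs)
        loss = torch.nn.functional.mse_loss(outputs, targets)
        # Forward the user-supplied flag to backward(). retain_graph=True keeps
        # the autograd graph's intermediate buffers alive, which is needed when
        # a graph shared across separately prepared submodules must be
        # backpropagated through more than once.
        loss.backward(retain_graph=self.loss_backward_retain_graph)
        self.optimizer.step()
        self.optimizer.zero_grad()
        return loss
```

Defaulting the flag to `None` rather than `False` is what keeps the change backward-compatible: passing `retain_graph=None` is exactly what `loss.backward()` does when the kwarg is omitted.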