Fix type annotation to use base GradScaler as return type (#697)
Summary: This diff fixes a type annotation issue in the `auto_unit` and `utils.precision` modules of the `torchtnt` library. The `grad_scaler` attribute was previously annotated as `Optional[GradScaler]`, where `GradScaler` referred to the CUDA-specific scaler. It should instead be the general `Optional[torch.amp.GradScaler]`, since `get_grad_scaler_from_precision` returns either the `ShardedGradScaler` or the CUDA `GradScaler`. This change fixes the pyre error caused by the misannotation.

Test Plan: pyre -l torchtnt

Differential Revision: D53316506

Pulled By: johnhenning

fbshipit-source-id: 3fe088760b6faee9cd0afc6b94396fa2460f95e3
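The reasoning behind the fix can be sketched with toy stand-in classes (assumptions: the class names below mirror the relationship between `torch.amp.GradScaler`, the CUDA `GradScaler`, and FSDP's `ShardedGradScaler`, but they are local stand-ins, not the real library types, and `get_grad_scaler_from_precision` here is a simplified signature):

```python
from typing import Optional

# Toy stand-ins modeling the class hierarchy (not the real torch classes).
class GradScaler:
    """Stand-in for the device-agnostic base torch.amp.GradScaler."""

class CudaGradScaler(GradScaler):
    """Stand-in for the CUDA-specific torch.cuda.amp.GradScaler."""

class ShardedGradScaler(GradScaler):
    """Stand-in for FSDP's ShardedGradScaler."""

def get_grad_scaler_from_precision(sharded: bool) -> Optional[GradScaler]:
    # The function can return either subclass, so the only annotation that
    # soundly covers both return values is the base GradScaler. Annotating
    # the attribute with the CUDA-specific subclass is what triggered the
    # pyre error this commit fixes.
    return ShardedGradScaler() if sharded else CudaGradScaler()
```

Because both concrete scalers subtype the base class, a `grad_scaler` attribute typed as the base `Optional[GradScaler]` accepts either return value, which is the shape of the fix.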