Fix type annotation to use base GradScaler as return type (#697)
Summary: This diff fixes a type annotation issue in the `auto_unit` and `utils.precision` modules of the `torchtnt` library. The `grad_scaler` attribute was previously annotated as `Optional[GradScaler]`, which referred to the CUDA-specific `GradScaler`. It should instead be `Optional[torch.amp.GradScaler]`, since `get_grad_scaler_from_precision` returns either the `ShardedGradScaler` or the CUDA `GradScaler`. This change fixes the pyre error caused by the narrower annotation.
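As a rough illustration of why the base-class annotation is needed, here is a minimal sketch of a scaler-selection helper modeled on the signature of `get_grad_scaler_from_precision` shown in the diff below. The helper name `pick_grad_scaler` and its body are assumptions for illustration only, not the torchtnt implementation, and it assumes a PyTorch version that exposes `torch.amp.GradScaler`:

```python
# Minimal sketch (assumed, not the torchtnt implementation): a helper that can
# return either scaler subtype, so only the base torch.amp.GradScaler is a
# valid return annotation.
from typing import Optional

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler


def pick_grad_scaler(
    precision: torch.dtype, module: torch.nn.Module
) -> Optional[torch.amp.GradScaler]:
    # Loss scaling is only needed for fp16 training.
    if precision != torch.float16:
        return None
    # FSDP-wrapped modules need the sharded scaler; otherwise use the CUDA one.
    if isinstance(module, FSDP):
        return ShardedGradScaler()  # subclass of the base GradScaler
    return torch.cuda.amp.GradScaler()
```

Annotating the result with the CUDA-specific `GradScaler` would make a type checker such as pyre flag the `ShardedGradScaler` branch; the common base class accepts both return values.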

Test Plan: `pyre -l torchtnt`

Differential Revision: D53316506

Pulled By: johnhenning

fbshipit-source-id: 3fe088760b6faee9cd0afc6b94396fa2460f95e3
johnhenning authored and facebook-github-bot committed Feb 2, 2024
1 parent 521984f commit 5b34c1c
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion torchtnt/framework/auto_unit.py
@@ -486,7 +486,7 @@ def __init__(
activation_checkpoint_params=activation_checkpoint_params,
)

-    self.grad_scaler: Optional[GradScaler] = None
+    self.grad_scaler: Optional[torch.amp.GradScaler] = None
if self.precision:
self.grad_scaler = get_grad_scaler_from_precision(
self.precision,
2 changes: 1 addition & 1 deletion torchtnt/utils/precision.py
@@ -38,7 +38,7 @@ def convert_precision_str_to_dtype(precision: str) -> Optional[torch.dtype]:

def get_grad_scaler_from_precision(
precision: torch.dtype, module: torch.nn.Module
-) -> Optional[GradScaler]:
+) -> Optional[torch.amp.GradScaler]:
"""
Returns the correct grad scaler to use based on the precision and whether
or not the model is FSDP.
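
Regardless of which subtype `get_grad_scaler_from_precision` returns, callers only rely on the shared `GradScaler` interface, which is what the widened return annotation conveys. The following fp16 training step is an illustrative usage example only (not part of this commit) and assumes a CUDA device:

```python
# Illustrative usage only (not from this commit): the caller depends solely on
# the base GradScaler interface, whichever scaler subtype was selected.
import torch

model = torch.nn.Linear(8, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# Under FSDP this could just as well be a ShardedGradScaler.
scaler: torch.amp.GradScaler = torch.cuda.amp.GradScaler()

inputs = torch.randn(4, 8, device="cuda")
targets = torch.randn(4, 1, device="cuda")

with torch.autocast(device_type="cuda", dtype=torch.float16):
    loss = torch.nn.functional.mse_loss(model(inputs), targets)

scaler.scale(loss).backward()  # scale the loss to avoid fp16 gradient underflow
scaler.step(optimizer)         # unscales grads, skips the step on inf/nan
scaler.update()                # adjust the scale factor for the next iteration
```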
