Add cast input argument (#1175)
Co-authored-by: Jie Wang <[email protected]>
whbldhwj authored Apr 5, 2024
1 parent 5faca97 commit 7bcbc80
Showing 1 changed file with 3 additions and 1 deletion.
fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -370,6 +370,7 @@ def __init__(
         gradient_predivide_factor: Optional[float] = None,
         limit_all_gather_events: bool = False,
         limit_reduce_scatter_events: bool = False,
+        cast_input: bool = True,
     ):
         try:
             import torch._C
@@ -420,6 +421,7 @@ def __init__(
         self.reshard_after_forward = self._orig_reshard_after_forward = reshard_after_forward
         self.disable_reshard_on_root = disable_reshard_on_root
         self.mixed_precision = mixed_precision
+        self.cast_input = cast_input
         self.fp32_reduce_scatter = fp32_reduce_scatter
         self.flatten_parameters = flatten_parameters
         self.move_params_to_cpu = move_params_to_cpu or cpu_offload
@@ -1431,7 +1433,7 @@ def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
         # For root and mixed precision, we convert the input to FP16 (no_grad is needed for
         # the conversion).
         is_bf16 = self.compute_dtype == torch.bfloat16
-        if self._is_root and self.mixed_precision:
+        if self._is_root and self.mixed_precision and self.cast_input:
             args, kwargs = cast_floats_to_right_precision(True, True, is_bf16, *args, **kwargs)

         if self not in self._fsdp_forward_ordering:
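For context, the new cast_input flag controls whether the root FSDP instance casts floating-point inputs to the compute dtype in forward() when mixed precision is enabled. Below is a minimal usage sketch, not part of the commit itself; the model and input are hypothetical, and it assumes torch.distributed is already initialized and a CUDA device is available.

# Minimal sketch of the new flag (hypothetical model/input; assumes
# torch.distributed is initialized and a CUDA device is available).
import torch
import torch.nn as nn
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

model = nn.Linear(1024, 1024).cuda()

# Default (cast_input=True): with mixed_precision=True, the root instance casts
# floating-point inputs to FP16 (or BF16 when compute_dtype=torch.bfloat16).
# cast_input=False leaves inputs untouched, e.g. when they are already in the
# desired precision or the caller manages casting externally.
fsdp_model = FSDP(model, mixed_precision=True, cast_input=False)

x = torch.randn(8, 1024, device="cuda", dtype=torch.float16)
out = fsdp_model(x)  # forward() does not re-cast x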
