diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py index fb4e2c5c5..13f1e5577 100644 --- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py +++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py @@ -375,6 +375,7 @@ def __init__( gradient_predivide_factor: Optional[float] = None, limit_all_gather_events: bool = False, limit_reduce_scatter_events: bool = False, + cast_input: bool = True, optimize_backward_concat: bool = False, ): try: @@ -426,6 +427,7 @@ def __init__( self.reshard_after_forward = self._orig_reshard_after_forward = reshard_after_forward self.disable_reshard_on_root = disable_reshard_on_root self.mixed_precision = mixed_precision + self.cast_input = cast_input self.fp32_reduce_scatter = fp32_reduce_scatter self.flatten_parameters = flatten_parameters self.move_params_to_cpu = move_params_to_cpu or cpu_offload @@ -1450,7 +1452,7 @@ def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor: # For root and mixed precision, we convert the input to FP16 (no_grad is needed for # the conversion). is_bf16 = self.compute_dtype == torch.bfloat16 - if self._is_root and self.mixed_precision: + if self._is_root and self.mixed_precision and self.cast_input: args, kwargs = cast_floats_to_right_precision(True, True, is_bf16, *args, **kwargs) if self not in self._fsdp_forward_ordering: