diff --git a/optimum/exporters/onnx/model_patcher.py b/optimum/exporters/onnx/model_patcher.py index 5e720d0cd7..1f873e4e71 100644 --- a/optimum/exporters/onnx/model_patcher.py +++ b/optimum/exporters/onnx/model_patcher.py @@ -277,7 +277,8 @@ def __init__( def _unmask_unattended_patched( - expanded_mask: torch.Tensor, attention_mask: torch.Tensor, unmasked_value: Union[bool, float] + expanded_mask: torch.Tensor, + min_dtype: float, ): return expanded_mask