diff --git a/optimum/bettertransformer/models/attention.py b/optimum/bettertransformer/models/attention.py
index 111460e478..654ec75763 100644
--- a/optimum/bettertransformer/models/attention.py
+++ b/optimum/bettertransformer/models/attention.py
@@ -107,7 +107,7 @@ def bark_wrapped_scaled_dot_product(
     is_causal = self.is_causal and query.shape[2] != 1

     sdpa_result = torch.nn.functional.scaled_dot_product_attention(
-        query, key, value, attn_mask=None, dropout_p=self.dropout, is_causal=is_causal
+        query, key, value, attn_mask=None, dropout_p=self.dropout if self.training else 0., is_causal=is_causal
     )

     return sdpa_result, None
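
Context for the change (this note and the sketch below are illustrative, not part of the patch): torch.nn.functional.scaled_dot_product_attention is a stateless function with no notion of train/eval mode, so a nonzero dropout_p is applied on every call, including at inference time, unless the caller gates it on the module's self.training flag, which is exactly what the patched line does. The AttnStub module below is a hypothetical stand-in for the Bark attention module, written only to demonstrate the behavior.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
q = k = v = torch.randn(1, 2, 4, 8)  # (batch, heads, seq_len, head_dim)

# Deterministic reference: no dropout applied.
ref = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)

# Passing dropout_p unconditionally injects randomness even when the
# surrounding model is supposedly in eval mode.
noisy = F.scaled_dot_product_attention(q, k, v, dropout_p=0.5)
print(torch.allclose(ref, noisy))  # False: dropout was applied

class AttnStub(torch.nn.Module):
    """Hypothetical stand-in for the Bark attention module."""

    def __init__(self, dropout=0.5):
        super().__init__()
        self.dropout = dropout

    def forward(self, q, k, v):
        # Mirrors the patched call site: dropout only while training.
        return F.scaled_dot_product_attention(
            q, k, v, attn_mask=None,
            dropout_p=self.dropout if self.training else 0.0,
        )

stub = AttnStub().eval()  # .eval() sets self.training = False
print(torch.allclose(ref, stub(q, k, v)))  # True: no dropout in eval mode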