From 5ebcb6f947d9cb4c81561f06f2e418f10bbb2f26 Mon Sep 17 00:00:00 2001
From: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com>
Date: Wed, 26 Jul 2023 10:26:52 +0200
Subject: [PATCH] Update dropout in bark self attention

Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com>
---
 optimum/bettertransformer/models/attention.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/optimum/bettertransformer/models/attention.py b/optimum/bettertransformer/models/attention.py
index 111460e478..654ec75763 100644
--- a/optimum/bettertransformer/models/attention.py
+++ b/optimum/bettertransformer/models/attention.py
@@ -107,7 +107,7 @@ def bark_wrapped_scaled_dot_product(
     is_causal = self.is_causal and query.shape[2] != 1
 
     sdpa_result = torch.nn.functional.scaled_dot_product_attention(
-        query, key, value, attn_mask=None, dropout_p=self.dropout, is_causal=is_causal
+        query, key, value, attn_mask=None, dropout_p=self.dropout if self.training else 0., is_causal=is_causal
     )
 
     return sdpa_result, None
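
The motivation for the one-line change: torch.nn.functional.scaled_dot_product_attention
is a stateless functional call, so unlike an nn.Dropout submodule it does not react to
model.eval(); passing self.dropout unconditionally kept dropout active at inference time.
Below is a minimal sketch of the patched behavior, assuming a hypothetical stand-in
module named TinySelfAttention (not the actual optimum class):

import torch
import torch.nn as nn
import torch.nn.functional as F


class TinySelfAttention(nn.Module):
    # Hypothetical stand-in illustrating the patched logic, not optimum's real class.
    def __init__(self, dropout: float = 0.1, is_causal: bool = True):
        super().__init__()
        self.dropout = dropout          # stored as a plain float, not an nn.Dropout
        self.is_causal = is_causal

    def forward(self, query, key, value):
        # As in the patch: skip causal masking for single-token (incremental) decoding.
        is_causal = self.is_causal and query.shape[2] != 1
        # scaled_dot_product_attention is a stateless functional op, so it cannot
        # check self.training on its own; dropout must be disabled here by hand.
        return F.scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=None,
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=is_causal,
        )


attn = TinySelfAttention(dropout=0.5)
q = k = v = torch.randn(1, 4, 8, 16)  # (batch, heads, seq_len, head_dim)

attn.eval()
out1, out2 = attn(q, k, v), attn(q, k, v)
assert torch.equal(out1, out2)  # dropout_p == 0.0 in eval mode: deterministic output

attn.train()
# With dropout_p == 0.5 the same call is now stochastic, as intended for training.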