From dfe5ea0aa697c84204711d8e97616060338f792b Mon Sep 17 00:00:00 2001 From: AnnaTz <111577222+AnnaTz@users.noreply.github.com> Date: Wed, 27 Sep 2023 18:26:28 +0100 Subject: [PATCH] refactor(ivy, torch-frontend): Extends ivy.multi_head_attention and shortens torch_frontend.multi_head_attention_forward (#23131) --- ivy/data_classes/array/layers.py | 12 + ivy/data_classes/container/layers.py | 28 ++ ivy/functional/backends/torch/layers.py | 120 +++++++- .../non_linear_activation_functions.py | 159 ++-------- ivy/functional/ivy/layers.py | 216 ++++++++++---- .../test_non_linear_activation_functions.py | 263 ++++------------- .../test_functional/test_nn/test_layers.py | 273 +++++++++++++----- 7 files changed, 590 insertions(+), 481 deletions(-) diff --git a/ivy/data_classes/array/layers.py b/ivy/data_classes/array/layers.py index 550a225b7e798..e9169b2ee0cbd 100644 --- a/ivy/data_classes/array/layers.py +++ b/ivy/data_classes/array/layers.py @@ -409,6 +409,12 @@ def multi_head_attention( in_proj_bias: Optional[Union[ivy.Array, ivy.NativeArray]] = None, out_proj_bias: Optional[Union[ivy.Array, ivy.NativeArray]] = None, is_causal: bool = False, + key_padding_mask: Optional[Union[ivy.Array, ivy.NativeArray]] = None, + bias_k: Optional[Union[ivy.Array, ivy.NativeArray]] = None, + bias_v: Optional[Union[ivy.Array, ivy.NativeArray]] = None, + static_k: Optional[Union[ivy.Array, ivy.NativeArray]] = None, + static_v: Optional[Union[ivy.Array, ivy.NativeArray]] = None, + add_zero_attn: bool = False, return_attention_weights: bool = False, average_attention_weights: bool = True, dropout: float = 0.0, @@ -430,6 +436,12 @@ def multi_head_attention( in_proj_bias=in_proj_bias, out_proj_bias=out_proj_bias, is_causal=is_causal, + key_padding_mask=key_padding_mask, + bias_k=bias_k, + bias_v=bias_v, + static_k=static_k, + static_v=static_v, + add_zero_attn=add_zero_attn, return_attention_weights=return_attention_weights, average_attention_weights=average_attention_weights, dropout=dropout, diff --git a/ivy/data_classes/container/layers.py b/ivy/data_classes/container/layers.py index b56ed5c84a9d6..d071b6fa7be89 100644 --- a/ivy/data_classes/container/layers.py +++ b/ivy/data_classes/container/layers.py @@ -1055,6 +1055,14 @@ def _static_multi_head_attention( Union[ivy.Array, ivy.NativeArray, ivy.Container] ] = None, is_causal: Union[bool, ivy.Container] = False, + key_padding_mask: Optional[ + Union[ivy.Array, ivy.NativeArray, ivy.Container] + ] = None, + bias_k: Optional[Union[ivy.Array, ivy.NativeArray, ivy.Container]] = None, + bias_v: Optional[Union[ivy.Array, ivy.NativeArray, ivy.Container]] = None, + static_k: Optional[Union[ivy.Array, ivy.NativeArray, ivy.Container]] = None, + static_v: Optional[Union[ivy.Array, ivy.NativeArray, ivy.Container]] = None, + add_zero_attn: Union[bool, ivy.Container] = False, return_attention_weights: Union[bool, ivy.Container] = False, average_attention_weights: Union[bool, ivy.Container] = True, dropout: Union[float, ivy.Container] = 0.0, @@ -1081,6 +1089,12 @@ def _static_multi_head_attention( in_proj_bias=in_proj_bias, out_proj_bias=out_proj_bias, is_causal=is_causal, + key_padding_mask=key_padding_mask, + bias_k=bias_k, + bias_v=bias_v, + static_k=static_k, + static_v=static_v, + add_zero_attn=add_zero_attn, return_attention_weights=return_attention_weights, average_attention_weights=average_attention_weights, dropout=dropout, @@ -1123,6 +1137,14 @@ def multi_head_attention( Union[ivy.Array, ivy.NativeArray, ivy.Container] ] = None, is_causal: Union[bool, ivy.Container] = False, + key_padding_mask: Optional[ + Union[ivy.Array, ivy.NativeArray, ivy.Container] + ] = None, + bias_k: Optional[Union[ivy.Array, ivy.NativeArray, ivy.Container]] = None, + bias_v: Optional[Union[ivy.Array, ivy.NativeArray, ivy.Container]] = None, + static_k: Optional[Union[ivy.Array, ivy.NativeArray, ivy.Container]] = None, + static_v: Optional[Union[ivy.Array, ivy.NativeArray, ivy.Container]] = None, + add_zero_attn: Union[bool, ivy.Container] = False, return_attention_weights: Union[bool, ivy.Container] = False, average_attention_weights: Union[bool, ivy.Container] = True, dropout: Union[float, ivy.Container] = 0.0, @@ -1148,6 +1170,12 @@ def multi_head_attention( in_proj_bias=in_proj_bias, out_proj_bias=out_proj_bias, is_causal=is_causal, + key_padding_mask=key_padding_mask, + bias_k=bias_k, + bias_v=bias_v, + static_k=static_k, + static_v=static_v, + add_zero_attn=add_zero_attn, return_attention_weights=return_attention_weights, average_attention_weights=average_attention_weights, dropout=dropout, diff --git a/ivy/functional/backends/torch/layers.py b/ivy/functional/backends/torch/layers.py index 8bb277fd2bb44..f447ebd3c151e 100644 --- a/ivy/functional/backends/torch/layers.py +++ b/ivy/functional/backends/torch/layers.py @@ -6,11 +6,129 @@ # local import ivy -from ivy.func_wrapper import with_unsupported_dtypes +from ivy.func_wrapper import with_unsupported_dtypes, with_supported_dtypes from . import backend_version from ivy.functional.ivy.layers import _handle_padding, _deconv_length +@with_supported_dtypes( + {"2.0.1 and below": ("float32", "float64", "complex")}, + backend_version, +) +def multi_head_attention( + query: torch.Tensor, + /, + *, + key: torch.Tensor = None, + value: torch.Tensor = None, + batch_first: bool = True, + num_heads: Optional[int] = 8, + scale: Optional[float] = None, + attention_mask: torch.Tensor = None, + in_proj_weights: torch.Tensor = None, + q_proj_weights: torch.Tensor = None, + k_proj_weights: torch.Tensor = None, + v_proj_weights: torch.Tensor = None, + out_proj_weights: torch.Tensor = None, + in_proj_bias: torch.Tensor = None, + out_proj_bias: torch.Tensor = None, + is_causal: Optional[bool] = False, + key_padding_mask: Optional[torch.Tensor] = None, + bias_k: Optional[torch.Tensor] = None, + bias_v: Optional[torch.Tensor] = None, + static_k: Optional[torch.Tensor] = None, + static_v: Optional[torch.Tensor] = None, + add_zero_attn: bool = False, + return_attention_weights: Optional[bool] = False, + average_attention_weights: Optional[bool] = True, + dropout: Optional[float] = 0.0, + training: Optional[bool] = False, + out: torch.Tensor = None, +) -> torch.Tensor: + if key is None and value is None: + key = value = query + emb_dim = _get_embed_dim( + in_proj_weights, + q_proj_weights, + k_proj_weights, + v_proj_weights, + query, + )[1] + num_dims = query.ndim + if num_dims == 3 and batch_first: + query, key, value = [torch.swapaxes(x, 0, 1) for x in [query, key, value]] + ret = torch.nn.functional.multi_head_attention_forward( + query, + key, + value, + emb_dim, + num_heads, + in_proj_weights, + in_proj_bias, + bias_k, + bias_v, + add_zero_attn, + dropout, + out_proj_weights, + out_proj_bias, + training=training, + key_padding_mask=key_padding_mask, + need_weights=return_attention_weights, + attn_mask=attention_mask, + use_separate_proj_weight=not ivy.exists(in_proj_weights), + q_proj_weight=q_proj_weights, + k_proj_weight=k_proj_weights, + v_proj_weight=v_proj_weights, + static_k=static_k, + static_v=static_v, + average_attn_weights=average_attention_weights, + is_causal=is_causal, + ) + ret = list(ret) if isinstance(ret, tuple) else [ret] + if num_dims == 3 and batch_first: + ret[0] = ret[0].swapaxes(0, 1) + if return_attention_weights: + return tuple(ret) + return ret[0] + + +multi_head_attention.partial_mixed_handler = ( + lambda *args, scale=None, out_proj_weights=None, is_causal=False, attention_mask=None, return_attention_weights=False, in_proj_weights=None, q_proj_weights=None, k_proj_weights=None, v_proj_weights=None, **kwargs: not ivy.exists( + scale + ) + and ivy.exists(out_proj_weights) + and (not is_causal or ivy.exists(attention_mask)) + and (not is_causal or not return_attention_weights) + and ( + ivy.exists(in_proj_weights) + or all( + [ivy.exists(x) for x in [q_proj_weights, k_proj_weights, v_proj_weights]] + ) + ) + and len( + set( + _get_embed_dim( + in_proj_weights, q_proj_weights, k_proj_weights, v_proj_weights, args[0] + ) + ) + ) + == 1 +) + + +def _get_embed_dim( + in_proj_weights, q_proj_weights, k_proj_weights, v_proj_weights, query +): + pre_embed_dim = query.shape[-1] + if ivy.exists(in_proj_weights): + embed_dim = in_proj_weights.shape[0] / 3 + elif all([ivy.exists(x) for x in [q_proj_weights, k_proj_weights, v_proj_weights]]): + embed_dim = q_proj_weights.shape[0] + else: + embed_dim = None + return pre_embed_dim, embed_dim + + @with_unsupported_dtypes( {"2.0.1 and below": ("float16", "bfloat16", "complex")}, backend_version, diff --git a/ivy/functional/frontends/torch/nn/functional/non_linear_activation_functions.py b/ivy/functional/frontends/torch/nn/functional/non_linear_activation_functions.py index 03d62ea8eacb8..317032daab47b 100644 --- a/ivy/functional/frontends/torch/nn/functional/non_linear_activation_functions.py +++ b/ivy/functional/frontends/torch/nn/functional/non_linear_activation_functions.py @@ -266,143 +266,36 @@ def multi_head_attention_forward( average_attn_weights=True, is_causal=False, ): - # q/k/v shape: (seq_len, batch_size, embed_dim) - seq_len, batch_size, embed_dim = query.shape + embed_dim = query.shape[-1] assert ( embed_dim == embed_dim_to_check ), f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}" - assert key.shape == value.shape - - head_dim = embed_dim // num_heads - assert head_dim * num_heads == embed_dim, "embed_dim needs to be divisible by heads" - scale = ivy.sqrt(head_dim) - - if use_separate_proj_weight: - assert key.shape[:2] == value.shape[:2], ( - f"key's sequence and batch dims {key.shape[:2]} do not match value's" - f" {value.shape[:2]}" - ) - else: - assert ( - key.shape == value.shape - ), f"key shape {key.shape} does not match value shape {value.shape}" - - if is_causal and key_padding_mask is None and not need_weights: - mask = ivy.tril(ivy.ones((seq_len, seq_len), dtype=query.dtype), k=0) - attn_mask = ivy.zeros((seq_len, seq_len), dtype=query.dtype) - attn_mask = ivy.where(mask == 0.0, float("-inf"), 0) - - if in_proj_bias is None: - q_bias, k_bias, v_bias = None, None, None - else: - q_bias, k_bias, v_bias = ivy.split(in_proj_bias, num_or_size_splits=3) - - if not use_separate_proj_weight: - q_proj_weight, k_proj_weight, v_proj_weight = ivy.split( - in_proj_weight, num_or_size_splits=3 - ) - - q = ivy.linear(query, q_proj_weight, bias=q_bias) - k = ivy.linear(key, k_proj_weight, bias=k_bias) - v = ivy.linear(value, v_proj_weight, bias=v_bias) - - if bias_k is not None and bias_v is not None: - assert static_k is None, "bias cannot be added to static key." - assert static_v is None, "bias cannot be added to static value." - k = ivy.concat([k, ivy.tile(bias_k, (1, batch_size, 1))]) - v = ivy.concat([v, ivy.tile(bias_v, (1, batch_size, 1))]) - if attn_mask is not None: - attn_mask = ivy.concat( - [attn_mask, ivy.zeros((attn_mask.shape[0], 1), dtype=attn_mask.dtype)], - axis=1, - ) - if key_padding_mask is not None: - key_padding_mask = ivy.concat( - [ - key_padding_mask, - ivy.zeros( - (key_padding_mask.shape[0], 1), dtype=key_padding_mask.dtype - ).bool(), - ], - axis=1, - ) - - q = ivy.swapaxes(q.reshape((q.shape[0], batch_size * num_heads, head_dim)), 0, 1) - - if static_k is None: - k = ivy.swapaxes( - k.reshape((k.shape[0], batch_size * num_heads, head_dim)), 0, 1 - ) - else: - assert static_k.shape[0] == batch_size * num_heads, ( - f"expecting static_k.shape[0] of {batch_size * num_heads}, but got" - f" {static_k.shape[0]}" - ) - assert ( - static_k.shape[2] == head_dim - ), f"expecting static_k.shape[2] of {head_dim}, but got {static_k.shape[2]}" - k = static_k - - if static_v is None: - v = ivy.swapaxes( - v.reshape((v.shape[0], batch_size * num_heads, head_dim)), 0, 1 - ) - else: - assert static_v.shape[0] == batch_size * num_heads, ( - f"expecting static_v.shape[0] of {batch_size * num_heads}, but got" - f" {static_v.shape[0]}" - ) - assert ( - static_v.shape[2] == head_dim - ), f"expecting static_v.shape[2] of {head_dim}, but got {static_v.shape[2]}" - v = static_v - - # TODO add_zero_attn doesn't work for all cases - # fix this and add test cases (by changing to add_zero_attn=st.booleans()) - if add_zero_attn: - zero_attn_shape = (batch_size * num_heads, 1, head_dim) - k = ivy.concat([k, ivy.zeros(zero_attn_shape, dtype=k.dtype)], axis=1) - v = ivy.concat([v, ivy.zeros(zero_attn_shape, dtype=v.dtype)], axis=1) - if attn_mask is not None: - attn_mask = ivy.pad(attn_mask, [(0, 0), (0, 1)]) - if key_padding_mask is not None: - key_padding_mask = ivy.pad(key_padding_mask, [(0, 0), (0, 1)]) - - src_len = k.shape[1] - attn_weights = ivy.matmul(q, ivy.swapaxes(k, 1, 2)) - assert list(attn_weights.shape) == [batch_size * num_heads, seq_len, src_len] - - attn_weights = attn_weights / scale - - if attn_mask is not None: - attn_mask = ivy.expand_dims(attn_mask, axis=0) - attn_weights += attn_mask - - if key_padding_mask is not None: - key_padding_mask = ivy.expand_dims( - ivy.expand_dims(key_padding_mask, axis=1), axis=2 - ) - attn_weights = attn_weights.reshape((batch_size, num_heads, seq_len, src_len)) - attn_weights = ivy.where(key_padding_mask < 0.0, float("-inf"), attn_weights) - attn_weights = attn_weights.reshape((batch_size * num_heads, seq_len, src_len)) - - attn_weights = ivy.softmax(attn_weights, axis=-1) - attn_weights = ivy.dropout(attn_weights, dropout_p, training=training) - - attn_output = ivy.matmul(attn_weights, v) - assert list(attn_output.shape) == [batch_size * num_heads, seq_len, head_dim] - attn_output = ivy.swapaxes(attn_output, 0, 1).reshape( - (seq_len, batch_size, embed_dim) + return ivy.multi_head_attention( + query, + key=key, + value=value, + batch_first=False, + num_heads=num_heads, + attention_mask=attn_mask, + in_proj_weights=in_proj_weight if not use_separate_proj_weight else None, + q_proj_weights=q_proj_weight, + k_proj_weights=k_proj_weight, + v_proj_weights=v_proj_weight, + out_proj_weights=out_proj_weight, + in_proj_bias=in_proj_bias, + out_proj_bias=out_proj_bias, + is_causal=is_causal and not (need_weights or key_padding_mask is not None), + key_padding_mask=key_padding_mask, + bias_k=bias_k, + bias_v=bias_v, + static_k=static_k, + static_v=static_v, + add_zero_attn=add_zero_attn, + return_attention_weights=need_weights, + average_attention_weights=average_attn_weights, + dropout=dropout_p, + training=training, ) - attn_output = ivy.linear(attn_output, out_proj_weight, bias=out_proj_bias) - - if need_weights: - attn_weights = attn_weights.reshape((batch_size, num_heads, seq_len, src_len)) - if average_attn_weights: - attn_weights = ivy.sum(attn_weights, axis=1) / num_heads - return (attn_output, attn_weights) - else: - return (attn_output,) @to_ivy_arrays_and_back diff --git a/ivy/functional/ivy/layers.py b/ivy/functional/ivy/layers.py index e9b942da7ebab..2a428fbfd91a5 100644 --- a/ivy/functional/ivy/layers.py +++ b/ivy/functional/ivy/layers.py @@ -705,7 +705,7 @@ def scaled_dot_product_attention( @handle_exceptions @handle_nestable @handle_out_argument -# @handle_array_like_without_promotion +@handle_partial_mixed_function @inputs_to_ivy_arrays @handle_array_function def multi_head_attention( @@ -714,6 +714,7 @@ def multi_head_attention( *, key: Optional[Union[ivy.Array, ivy.NativeArray]] = None, value: Optional[Union[ivy.Array, ivy.NativeArray]] = None, + batch_first: bool = True, num_heads: int = 8, scale: Optional[float] = None, attention_mask: Optional[Union[ivy.Array, ivy.NativeArray]] = None, @@ -725,6 +726,12 @@ def multi_head_attention( in_proj_bias: Optional[Union[ivy.Array, ivy.NativeArray]] = None, out_proj_bias: Optional[Union[ivy.Array, ivy.NativeArray]] = None, is_causal: bool = False, + key_padding_mask: Optional[Union[ivy.Array, ivy.NativeArray]] = None, + bias_k: Optional[Union[ivy.Array, ivy.NativeArray]] = None, + bias_v: Optional[Union[ivy.Array, ivy.NativeArray]] = None, + static_k: Optional[Union[ivy.Array, ivy.NativeArray]] = None, + static_v: Optional[Union[ivy.Array, ivy.NativeArray]] = None, + add_zero_attn: bool = False, return_attention_weights: bool = False, average_attention_weights: bool = True, dropout: float = 0.0, @@ -743,50 +750,69 @@ def multi_head_attention( value_dim)`. Then, the query and key tensors are dot-producted and scaled. These are softmaxed to obtain attention probabilities. The value tensors are then interpolated by these probabilities, then concatenated back to a single tensor. Finally, the - result tensor with the last dimension as value_dim can take an linear projection and + result tensor with the last dimension as value_dim can take a linear projection and return. Parameters ---------- query - query embeddings *[batch_shape,num_queries,query_dim]*. + The query embeddings. Shape: `(L, Q)` or `(N, L, Q)`, where L is the number of + queries, N is the batch size, Q is the query embedding dimension. key - key embeddings *[batch_shape,num_queries,key_dim]*. + The key embeddings. Shape: `(S, K)` or `(N, S, K)`, where S is the number of + keys, N is the batch size, K is the key embedding dimension. value - value embeddings *[batch_shape,num_queries,value_dim]*. + The value embeddings. Shape `(S, V)` or `(N, S, V)`, where S is the number of + keys, N is the batch size, V is the value embedding dimension. + batch_first + If False, `query`, `key` and `value` will have shapes `(L, N, Q)`, `(S, N, K)` + and `(S, N, V)` respectively (if batched). num_heads The number of attention heads to use. scale The value by which to scale the query-key similarity measure before softmax. attention_mask - The mask to apply to the query-key values. Default is ``None``. - *[batch_shape,num_queries,num_keys]*. + The mask to apply to the query-key values. Shape: `(L, S)` or + `(N*num_heads, L, S)`. in_proj_weights - The weights used to project query, key and value *[3*E, E]. + The weights used to project query, key and value. Shape: `(3*E, E')`, where E + is the new embedding dimension and E' is the input embedding dimension, i.e. + `E' = Q = K = V`. q_proj_weights - The weights used to project query if in_proj_weights is None *[new_E, E]. + The weights used to project query if `in_proj_weights` is None. Shape: `(E, Q)`. k_proj_weights - The weights used to project key if in_proj_weights is None *[new_E, E]. + The weights used to project key if `in_proj_weights` is None. Shape: `(E, K)`. v_proj_weights - The weights used to project value if in_proj_weights is None *[new_E, E]. + The weights used to project value if `in_proj_weights` is None. Shape: `(E, V)`. out_proj_weights - The weights used to project the output. + The weights used to project the attention output. Shape: `(O, E)`, where O is + the output embedding dimension. in_proj_bias - The bias used when projecting with query, key and value. + The bias used when projecting query, key and value. Shape: `(3*E,)`. out_proj_bias - The bias used when projecting the output. + The bias used when projecting the output. Shape: `(O,)`. is_causal - If True, Uses a causal attention mask and ignores provided attention_mask. + If True, use a causal attention mask and ignore the provided `attention_mask`. + key_padding_mask + A binary mask to apply to the key sequence. Shape: `(S,)` or `(N, S)`. + bias_k + An additional bias added to the key sequence. Shape: `(E,)`. + bias_v + An additional bias added to the value sequence. Shape: `(E,)`. + static_k + A static key to be used in the attention operators. Shape: `(N*num_heads, S, E//num_heads)`. + static_v + A static value to be used in the attention operators. Shape: `(N*num_heads, S, E//num_heads)`. + add_zero_attn + A boolean flag indicating whether to add a batch of zeros to key and value. return_attention_weights - If True, returns attention_weights alongside the output - as a tuple (output, attenion_weights). Defaults to `False`. + If True, return the attention weights alongside the attention output. average_attention_weights - If true, indicates that the returned ``attention_weights`` should be averaged - across heads. Otherwise, ``attention_weights`` are provided separately per head. - Note that this flag only has an effect when ``return_attention_weights=True``. - Default: ``True`` (i.e. average weights across heads) + If True, the returned attention weights will be averaged across heads. + Otherwise, the attention weights will be provided separately per head. + Note that this flag only has an effect when `return_attention_weights=True`. dropout - Specifies the dropout probablity, dropout is applied to attention_weights. + Specifies the dropout probability. Dropout is applied on the attention weights. training If True, dropout is used, otherwise dropout is not activated. out @@ -796,9 +822,11 @@ def multi_head_attention( Returns ------- ret - The output following application of multi-head attention. - *[batch_shape,num_queries,out_feat_dim]* if input is batched - otherwise *[num_queries, out_feat_dim] + The output following the application of multi-head attention. Either `output` + or `(output, attention_weights)`. `output` will have shape `(L, E)` if the + inputs were unbatched or `(N, L, E)` otherwise, and `attention_weights` will + have shape `(L, S)` or `(N, L, S)` respectively. If `batch_first` is False and + the inputs were batched, the `output` will have shape `(L, N, E)`. Both the description and the type hints above assumes an array input for simplicity, but this function is *nestable*, and therefore also accepts :class:`ivy.Container` @@ -814,8 +842,13 @@ def multi_head_attention( key = value = query if num_dims == 2: query, key, value = [ivy.expand_dims(x, axis=0) for x in [query, key, value]] + elif not batch_first: + query, key, value = [ivy.swapaxes(x, 0, 1) for x in [query, key, value]] + + # project query, key and value if ivy.exists(in_proj_weights): q, k, v = _in_projection(query, key, value, w=in_proj_weights, b=in_proj_bias) + emb_dim = int(in_proj_weights.shape[0] / 3) elif all([ivy.exists(x) for x in [q_proj_weights, k_proj_weights, v_proj_weights]]): if ivy.exists(in_proj_bias): b_q, b_k, b_v = ivy.split(in_proj_bias, num_or_size_splits=3) @@ -826,61 +859,130 @@ def multi_head_attention( ivy.linear(key, k_proj_weights, bias=b_k), ivy.linear(value, v_proj_weights, bias=b_v), ) + emb_dim = q_proj_weights.shape[0] else: q, k, v = query, key, value - batch_size, q_seq_length, emb_dim = q.shape[0], q.shape[1], q.shape[-1] - k_seq_length = k.shape[1] + if ivy.exists(out_proj_weights): + emb_dim = out_proj_weights.shape[-1] + else: + emb_dim = q.shape[-1] + + num_batches, num_queries = query.shape[:2] ivy.assertions.check_true( emb_dim % num_heads == 0, "features must be divisible by number of heads" ) - dims_per_head = emb_dim // num_heads - # isolate heads - q = q.reshape((batch_size, q_seq_length, num_heads, dims_per_head)).permute_dims( - (0, 2, 1, 3) - ) - k = k.reshape((batch_size, k_seq_length, num_heads, dims_per_head)).permute_dims( - (0, 2, 3, 1) - ) - v = v.reshape((batch_size, k_seq_length, num_heads, dims_per_head)).permute_dims( - (0, 2, 1, 3) - ) - # perform bmm - attn_scores = ivy.matmul(q, k) - # scale - scale = 1 / (dims_per_head**0.5) if not scale else scale + head_dim = emb_dim // num_heads + + # apply extra bias + if bias_k is not None and bias_v is not None: + ivy.assertions.check_true( + not (ivy.exists(static_k) or ivy.exists(static_v)), + "bias cannot be added to static key or value", + ) + k = ivy.concat([k, ivy.tile(bias_k, (num_batches, 1, 1))], axis=1) + v = ivy.concat([v, ivy.tile(bias_v, (num_batches, 1, 1))], axis=1) + + num_keys = k.shape[1] + + # reshape q, k, v for efficient matrix multiplication + q = ivy.swapaxes(q.reshape((num_queries, num_batches * num_heads, head_dim)), 0, 1) + if static_k is None: + k = ivy.swapaxes(k.reshape((num_keys, num_batches * num_heads, head_dim)), 0, 1) + else: + k = static_k + if static_v is None: + v = ivy.swapaxes(v.reshape((num_keys, num_batches * num_heads, head_dim)), 0, 1) + else: + v = static_v + + # add extra batch of zeros to k, v + if add_zero_attn: + zero_attn_shape = (num_batches * num_heads, 1, head_dim) + k = ivy.concat([k, ivy.zeros(zero_attn_shape, dtype=k.dtype)], axis=1) + v = ivy.concat([v, ivy.zeros(zero_attn_shape, dtype=v.dtype)], axis=1) + num_keys = k.shape[1] + + # get attention scores + attn_scores = ivy.matmul(q, ivy.swapaxes(k, 1, 2)) + scale = 1 / (head_dim**0.5) if not scale else scale attn_scores *= scale - # apply attention mask - if ivy.exists(attention_mask) or is_causal: + + # mask the attention scores + if ivy.exists(attention_mask): + assert attention_mask.dtype in [query.dtype, ivy.bool], ( + "was expecting attention_mask of type bool or the same as the input's, but" + f" got {attention_mask.dtype}" + ) if is_causal: - # create causal mask - attention_mask = ivy.tril(ivy.ones((q_seq_length, k_seq_length))) - attention_mask = attention_mask.astype("bool") - attn_scores = ivy.where(attention_mask, attn_scores, -ivy.inf) - # perform softmax + mask = ivy.triu(ivy.ones((num_queries, num_keys)), k=1) + attention_mask = ivy.where(mask, float("-inf"), 0) + elif ivy.is_bool_dtype(attention_mask): + attention_mask = ivy.where(attention_mask, float("-inf"), 0) + if attention_mask.ndim == 2: + attention_mask = ivy.tile(attention_mask, (num_batches * num_heads, 1, 1)) + if key_padding_mask is not None: + assert ivy.is_bool_dtype(key_padding_mask), ( + "was expecting key_padding_mask of type bool, but got" + f" {key_padding_mask.dtype}" + ) + key_padding_mask = ivy.where(key_padding_mask, float("-inf"), 0) + if num_dims == 2: + key_padding_mask = ivy.expand_dims(key_padding_mask, axis=0) + key_padding_mask = ivy.tile( + key_padding_mask, (num_batches * num_heads, num_queries, 1) + ) + if attention_mask is None: + attention_mask = key_padding_mask + else: + attention_mask += key_padding_mask + if ivy.exists(attention_mask): + if bias_k is not None and bias_v is not None and not is_causal: + attention_mask = ivy.pad(attention_mask, [(0, 0), (0, 0), (0, 1)]) + if add_zero_attn and not is_causal: + attention_mask = ivy.pad(attention_mask, [(0, 0), (0, 0), (0, 1)]) + attn_scores += attention_mask.astype(query.dtype) + + # get attention weights attn_weights = ivy.softmax(attn_scores, axis=-1) - # perform dropout attn_weights = ivy.dropout(attn_weights, dropout, training=training) - # bmm with values + + # get attention output attention_out = ivy.matmul(attn_weights, v) - attention_out = attention_out.permute_dims((0, 2, 1, 3)).reshape( - (batch_size, q_seq_length, -1) + attention_out = ivy.swapaxes(attention_out, 0, 1).reshape( + (num_batches, num_queries, emb_dim) ) - # proj out if out_proj_weight exists if ivy.exists(out_proj_weights): attention_out = ivy.linear(attention_out, out_proj_weights, bias=out_proj_bias) - # if input was unbatched, unbatchify the output + if num_dims == 2: attention_out = attention_out.squeeze(axis=0) + elif not batch_first: + attention_out = attention_out.swapaxes(0, 1) if return_attention_weights: + attn_weights = attn_weights.reshape( + (num_batches, num_heads, num_queries, num_keys) + ) if average_attention_weights: attn_weights = attn_weights.mean(axis=1) - if num_dims == 2: - attn_weights = attn_weights.squeeze(axis=0) + if num_dims == 2: + attn_weights = attn_weights.squeeze(axis=0) return attention_out, attn_weights else: return attention_out +multi_head_attention.mixed_backend_wrappers = { + "to_add": ( + "handle_backend_invalid", + "handle_out_argument", + "inputs_to_native_arrays", + "outputs_to_ivy_arrays", + "handle_device_shifting", + ), + "to_skip": ("inputs_to_ivy_arrays", "handle_partial_mixed_function"), +} + + # Convolutions # diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_nn/test_functional/test_non_linear_activation_functions.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_nn/test_functional/test_non_linear_activation_functions.py index 4c4614c137fc3..010ced1773e0f 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_nn/test_functional/test_non_linear_activation_functions.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_nn/test_functional/test_non_linear_activation_functions.py @@ -1,11 +1,12 @@ # global import ivy from hypothesis import assume, strategies as st -import random # local import ivy_tests.test_ivy.helpers as helpers +from ivy.functional.backends.torch.layers import _get_embed_dim from ivy_tests.test_ivy.helpers import handle_frontend_test +from ivy_tests.test_ivy.test_functional.test_nn.test_layers import _mha_helper # --- Helpers --- # @@ -97,170 +98,6 @@ def _x_and_scaled_attention(draw, dtypes): return dtype, query, key, value, mask -@st.composite -def mha_forward_args(draw, dtypes): - dtype = draw(dtypes) - embed_dim = draw(helpers.ints(min_value=2, max_value=4)) - batch_size = draw(helpers.ints(min_value=1, max_value=2)) * 3 - seq_len = draw(helpers.ints(min_value=2, max_value=4)) - shape = ( - seq_len, - batch_size, - embed_dim, - ) - - heads = draw(helpers.ints(min_value=1, max_value=4)) - head_dim = embed_dim // heads - if head_dim * heads != embed_dim: - heads = 1 - head_dim = embed_dim - - if dtype[0] == "float32": - is_causal = False - else: - is_causal = draw(helpers.array_bools(size=1))[0] - - q = draw( - helpers.array_values(dtype=dtype[0], shape=shape, min_value=0.1, max_value=1) - ) - k = draw( - helpers.array_values(dtype=dtype[0], shape=shape, min_value=0.1, max_value=1) - ) - v = draw( - helpers.array_values(dtype=dtype[0], shape=shape, min_value=0.1, max_value=1) - ) - in_proj_weight = draw( - helpers.array_values( - dtype=dtype[0], - min_value=0.1, - max_value=1, - shape=(embed_dim * 3, embed_dim), - ) - ) - in_proj_bias = draw( - helpers.array_values( - dtype=dtype[0], - min_value=0.1, - max_value=1, - shape=(embed_dim * 3,), - ) - ) - - if random.randint(0, 1) == 0: - use_separate_proj_weight = True - q_proj_weight = draw( - helpers.array_values( - dtype=dtype[0], - min_value=0.1, - max_value=1, - shape=(embed_dim, embed_dim), - ) - ) - k_proj_weight = draw( - helpers.array_values( - dtype=dtype[0], - min_value=0.1, - max_value=1, - shape=(embed_dim, embed_dim), - ) - ) - v_proj_weight = draw( - helpers.array_values( - dtype=dtype[0], - min_value=0.1, - max_value=1, - shape=(embed_dim, embed_dim), - ) - ) - else: - use_separate_proj_weight = False - q_proj_weight = None - k_proj_weight = None - v_proj_weight = None - - out_proj_weight = draw( - helpers.array_values( - dtype=dtype[0], - min_value=0.1, - max_value=1, - shape=(embed_dim, embed_dim), - ) - ) - out_proj_bias = draw( - helpers.array_values( - dtype=dtype[0], - min_value=0.1, - max_value=1, - shape=(embed_dim,), - ) - ) - bias_k = random.choice( - [ - draw( - helpers.array_values( - dtype=dtype[0], - min_value=0.1, - max_value=1, - shape=(embed_dim,), - ) - ), - None, - ] - ) - bias_v = bias_k - - if bias_k is None: - static_k = random.choice( - [ - draw( - helpers.array_values( - dtype=dtype[0], - min_value=0.1, - max_value=1, - shape=(batch_size * heads, seq_len, head_dim), - ) - ), - None, - ] - ) - static_v = static_k - else: - static_k = None - static_v = None - - attn_mask = ivy.ones((seq_len, seq_len), dtype=dtype[0]) - key_padding_mask = random.choice( - [ - ivy.random_normal(shape=(seq_len, seq_len), dtype=dtype[0]) > 0, - None, - ] - ) - - return ( - dtype, - q, - k, - v, - heads, - use_separate_proj_weight, - embed_dim, - in_proj_weight, - in_proj_bias, - out_proj_weight, - out_proj_bias, - q_proj_weight, - k_proj_weight, - v_proj_weight, - bias_k, - bias_v, - static_k, - static_v, - attn_mask, - key_padding_mask, - is_causal, - ) - - # --- Main --- # # ------------ # @@ -852,14 +689,11 @@ def test_torch_mish( # multi_head_attention_forward @handle_frontend_test( fn_tree="torch.nn.functional.multi_head_attention_forward", - dtype_mha_args=mha_forward_args( - dtypes=helpers.get_dtypes("valid"), + dtype_mha_args=_mha_helper(same_pre_embed_dim=True, batch_second=True).filter( + lambda args: args[10] is not None + and (not args[22] or args[5] is not None) + and len(set(_get_embed_dim(*args[6:10], args[1]))) == 1 ), - add_zero_attn=st.just(False), - dropout_p=st.sampled_from([0.0, 0.1, 0.2]), - training=st.booleans(), - need_weights=st.booleans(), - average_attn_weights=st.booleans(), test_with_out=st.just(False), ) def test_torch_multi_head_attention_forward( @@ -869,11 +703,6 @@ def test_torch_multi_head_attention_forward( frontend, test_flags, dtype_mha_args, - add_zero_attn, - dropout_p, - training, - need_weights, - average_attn_weights, backend_fw, ): ( @@ -882,57 +711,69 @@ def test_torch_multi_head_attention_forward( k, v, heads, - use_separate_proj_weight, - embed_dim, + attn_mask, in_proj_weight, - in_proj_bias, - out_proj_weight, - out_proj_bias, q_proj_weight, k_proj_weight, v_proj_weight, + out_proj_weight, + in_proj_bias, + out_proj_bias, + key_padding_mask, bias_k, bias_v, static_k, static_v, - attn_mask, - key_padding_mask, + _, + add_zero_attn, + dropout_p, + training, is_causal, + need_weights, + average_attn_weights, + batch_first, ) = dtype_mha_args - + if k is None and v is None: + k = v = q + # re-order the dtypes to match the order of the frontend arguments, not the order + # of ivy.multi_head_attention's arguments given by _mha_helper + kwargs = { + "query": q, + "key": k, + "value": v, + "embed_dim_to_check": q.shape[-1], + "num_heads": heads, + "in_proj_weight": in_proj_weight, + "in_proj_bias": in_proj_bias, + "bias_k": bias_k, + "bias_v": bias_v, + "add_zero_attn": add_zero_attn, + "dropout_p": dropout_p, + "out_proj_weight": out_proj_weight, + "out_proj_bias": out_proj_bias, + "training": training, + "key_padding_mask": key_padding_mask, + "need_weights": need_weights, + "attn_mask": attn_mask, + "use_separate_proj_weight": in_proj_weight is None, + "q_proj_weight": q_proj_weight, + "k_proj_weight": k_proj_weight, + "v_proj_weight": v_proj_weight, + "static_k": static_k, + "static_v": static_v, + "average_attn_weights": average_attn_weights, + "is_causal": is_causal, + } helpers.test_frontend_function( - input_dtypes=dtype, + input_dtypes=[str(r.dtype) for r in kwargs.values() if ivy.is_array(r)], backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, + atol=1e-03, on_device=on_device, test_values=not training or dropout_p == 0.0, - query=q, - key=k, - value=v, - embed_dim_to_check=embed_dim, - num_heads=heads, - in_proj_weight=in_proj_weight, - in_proj_bias=in_proj_bias, - bias_k=bias_k, - bias_v=bias_v, - add_zero_attn=add_zero_attn, - dropout_p=dropout_p, - out_proj_weight=out_proj_weight, - out_proj_bias=out_proj_bias, - training=training, - key_padding_mask=key_padding_mask, - need_weights=need_weights, - attn_mask=attn_mask, - use_separate_proj_weight=use_separate_proj_weight, - q_proj_weight=q_proj_weight, - k_proj_weight=k_proj_weight, - v_proj_weight=v_proj_weight, - static_k=static_k, - static_v=static_v, - average_attn_weights=average_attn_weights, - is_causal=is_causal, + **kwargs, ) diff --git a/ivy_tests/test_ivy/test_functional/test_nn/test_layers.py b/ivy_tests/test_ivy/test_functional/test_nn/test_layers.py index cb1bd5af5df55..d68219a6323e3 100644 --- a/ivy_tests/test_ivy/test_functional/test_nn/test_layers.py +++ b/ivy_tests/test_ivy/test_functional/test_nn/test_layers.py @@ -68,56 +68,55 @@ def _dropout_helper(draw): @st.composite -def _mha_helper(draw): +def _mha_helper(draw, same_pre_embed_dim=False, batch_second=False): _qkv_same_dim = draw(st.booleans()) _self_attention = draw(st.booleans()) + _same_pre_embed_dim = _self_attention or same_pre_embed_dim or draw(st.booleans()) + batch_first = draw(st.booleans()) and not batch_second num_heads = draw(helpers.ints(min_value=1, max_value=3)) _embed_dim = draw(helpers.ints(min_value=4, max_value=16)) * num_heads + _batch_dim = draw(st.sampled_from([(), (1,)])) + _num_batches = _batch_dim[0] if len(_batch_dim) else 1 + dtype = draw(helpers.get_dtypes("valid", full=False)) _num_queries = draw(helpers.ints(min_value=2, max_value=8)) _num_keys = draw(helpers.ints(min_value=2, max_value=8)) - _batch_dim = draw(st.sampled_from([(), (1,)])) - dtype = draw(helpers.get_dtypes("float", full=False, prune_function=False)) - in_proj_bias = None in_proj_weights = None q_proj_weights = None k_proj_weights = None v_proj_weights = None - _mask_shape = ( - _num_queries, - _num_queries if _self_attention and _qkv_same_dim else _num_keys, - ) - if _qkv_same_dim: - _pre_embed_dim = draw(helpers.ints(min_value=4, max_value=16)) - _q_shape = _batch_dim + (_num_queries, _pre_embed_dim) - _kv_shape = _batch_dim + (_num_keys, _pre_embed_dim) + if _qkv_same_dim: + if _same_pre_embed_dim: + _pre_embed_dim = _embed_dim + else: + _pre_embed_dim = draw(helpers.ints(min_value=4, max_value=16)) q = draw( helpers.array_values( - shape=_q_shape, + shape=(*_batch_dim, _num_queries, _pre_embed_dim), dtype=dtype[0], - large_abs_safety_factor=7, - small_abs_safety_factor=7, - safety_factor_scale="linear", + max_value=1000, + min_value=-1000, + abs_smallest_val=1e-06, ) ) k = draw( helpers.array_values( - shape=_kv_shape, + shape=(*_batch_dim, _num_keys, _pre_embed_dim), dtype=dtype[0], - large_abs_safety_factor=7, - small_abs_safety_factor=7, - safety_factor_scale="linear", + max_value=1000, + min_value=-1000, + abs_smallest_val=1e-06, ) if not _self_attention else st.none() ) v = draw( helpers.array_values( - shape=_kv_shape, + shape=(*_batch_dim, _num_keys, _pre_embed_dim), dtype=dtype[0], - large_abs_safety_factor=7, - small_abs_safety_factor=7, - safety_factor_scale="linear", + max_value=1000, + min_value=-1000, + abs_smallest_val=1e-06, ) if not _self_attention else st.none() @@ -126,102 +125,191 @@ def _mha_helper(draw): helpers.array_values( dtype=dtype[0], shape=(3 * _embed_dim, _pre_embed_dim), - min_value=0, + min_value=-10, max_value=10, ) - if _pre_embed_dim != _embed_dim + if not _same_pre_embed_dim or draw(st.booleans()) else st.none() ) else: - _q_dim = draw(helpers.ints(min_value=2, max_value=8)) + if not same_pre_embed_dim: + _q_dim = draw(helpers.ints(min_value=2, max_value=8)) + else: + _q_dim = _embed_dim _k_dim = draw(helpers.ints(min_value=2, max_value=8)) _v_dim = draw(helpers.ints(min_value=2, max_value=8)) - _q_shape = _batch_dim + (_num_queries, _q_dim) - _k_shape = _batch_dim + (_num_keys, _k_dim) - _v_shape = _batch_dim + (_num_keys, _v_dim) q = draw( helpers.array_values( - shape=_q_shape, + shape=(*_batch_dim, _num_queries, _q_dim), dtype=dtype[0], - large_abs_safety_factor=7, - small_abs_safety_factor=7, - safety_factor_scale="linear", + max_value=1000, + min_value=-1000, + abs_smallest_val=1e-06, ) ) k = draw( helpers.array_values( - shape=_k_shape, + shape=(*_batch_dim, _num_keys, _k_dim), dtype=dtype[0], - large_abs_safety_factor=7, - small_abs_safety_factor=7, - safety_factor_scale="linear", + max_value=1000, + min_value=-1000, + abs_smallest_val=1e-06, ) ) v = draw( helpers.array_values( - shape=_v_shape, + shape=(*_batch_dim, _num_keys, _v_dim), dtype=dtype[0], - large_abs_safety_factor=7, - small_abs_safety_factor=7, - safety_factor_scale="linear", + max_value=1000, + min_value=-1000, + abs_smallest_val=1e-06, ) ) q_proj_weights = draw( helpers.array_values( dtype=dtype[0], shape=(_embed_dim, _q_dim), - min_value=0, - max_value=2, + min_value=-5, + max_value=5, ) ) k_proj_weights = draw( helpers.array_values( dtype=dtype[0], shape=(_embed_dim, _k_dim), - min_value=0, - max_value=2, + min_value=-5, + max_value=5, ) ) v_proj_weights = draw( helpers.array_values( dtype=dtype[0], shape=(_embed_dim, _v_dim), - min_value=0, - max_value=2, + min_value=-5, + max_value=5, ) ) - in_proj_bias = draw( - helpers.array_values( - dtype=dtype[0], shape=(3 * _embed_dim), min_value=0, max_value=10 + st.one_of( + helpers.array_values( + dtype=dtype[0], + shape=(3 * _embed_dim,), + min_value=-10, + max_value=10, + ), + st.none(), ) - | st.none() ) + _out_dim = draw(helpers.ints(min_value=4, max_value=16)) out_proj_weights = draw( - helpers.array_values( - dtype=dtype[0], - shape=(_out_dim, _embed_dim), - min_value=0, - max_value=2, + st.one_of( + helpers.array_values( + dtype=dtype[0], + shape=(_out_dim, _embed_dim), + min_value=-5, + max_value=5, + ), + st.none(), ) - | st.none() ) out_proj_bias = draw( - helpers.array_values( - dtype=dtype[0], shape=(_out_dim), min_value=0, max_value=10 + st.one_of( + helpers.array_values( + dtype=dtype[0], + shape=(_out_dim,), + min_value=-10, + max_value=10, + ), + st.none(), + ) + ) + + if _self_attention and _qkv_same_dim: + _num_keys = _num_queries + _static_shape = (_num_batches * num_heads, _num_keys, int(_embed_dim // num_heads)) + static_k = draw( + st.one_of( + helpers.array_values( + shape=_static_shape, + dtype=dtype[0], + max_value=1000, + min_value=-1000, + abs_smallest_val=1e-06, + ), + st.none(), + ) + ) + static_v = draw( + st.one_of( + helpers.array_values( + shape=_static_shape, + dtype=dtype[0], + max_value=1000, + min_value=-1000, + abs_smallest_val=1e-06, + ), + st.none(), ) - | st.none() ) + + _mask_shape = (_num_queries, _num_keys) + if len(_batch_dim) and draw(st.booleans()): + _mask_shape = (_num_batches * num_heads, *_mask_shape) attention_mask = draw( + st.one_of( + helpers.array_values( + dtype=draw(st.sampled_from(["bool", dtype[0]])), + allow_inf=True, + shape=_mask_shape, + ), + st.none(), + ) + ) + + key_padding_mask = draw( + st.one_of( + helpers.array_values( + dtype="bool", + shape=(*_batch_dim, _num_keys), + ), + st.none(), + ) + ) + + _extra_bias = ( + (not _qkv_same_dim or _pre_embed_dim == _embed_dim) + and static_k is None + and static_v is None + and draw(st.booleans()) + ) + bias_k = draw( helpers.array_values( - dtype="bool", - shape=_mask_shape, + dtype=dtype[0], shape=(_embed_dim,), min_value=-10, max_value=10 ) - | st.none() + if _extra_bias + else st.none() ) - return ( - dtype, + bias_v = draw( + helpers.array_values( + dtype=dtype[0], shape=(_embed_dim,), min_value=-10, max_value=10 + ) + if _extra_bias + else st.none() + ) + + scale = draw(st.one_of(st.floats(min_value=0.001), st.none())) + add_zero_attn = draw(st.booleans()) + dropout = draw(st.floats(min_value=0, max_value=0.99)) + training = draw(st.booleans()) + is_causal = draw(st.booleans()) + return_attention_weights = draw(st.booleans()) + average_attention_weights = draw(st.booleans()) + + if len(q.shape) == 3 and not batch_first: + q, k, v = [np.swapaxes(x, 0, 1) if x is not None else x for x in [q, k, v]] + + ret = ( q, k, v, @@ -234,7 +322,22 @@ def _mha_helper(draw): out_proj_weights, in_proj_bias, out_proj_bias, + key_padding_mask, + bias_k, + bias_v, + static_k, + static_v, + scale, + add_zero_attn, + dropout, + training, + is_causal, + return_attention_weights, + average_attention_weights, + batch_first, ) + ret_dtypes = [str(r.dtype) for r in ret if ivy.is_array(r)] + return ret_dtypes, *ret @st.composite @@ -1274,23 +1377,14 @@ def test_lstm_update(*, dtype_lstm, test_flags, backend_fw, fn_name, on_device): @handle_test( fn_tree="functional.ivy.multi_head_attention", dtype_mha=_mha_helper(), - scale=st.one_of(st.floats(), st.none()), - dropout=st.floats(min_value=0, max_value=0.99), - training=st.just(False), # st.booleans(), disabled until proper testing is used - is_causal=st.booleans(), - return_attention_weights=st.booleans(), - average_attention_weights=st.booleans(), - ground_truth_backend="jax", + ground_truth_backend="numpy", + # ToDo: fix the gradients and the container methods + test_gradients=st.just(False), + container_flags=st.just([False]), ) def test_multi_head_attention( *, dtype_mha, - scale, - dropout, - training, - is_causal, - return_attention_weights, - average_attention_weights, test_flags, backend_fw, fn_name, @@ -1310,6 +1404,19 @@ def test_multi_head_attention( out_proj_weights, in_proj_bias, out_proj_bias, + key_padding_mask, + bias_k, + bias_v, + static_k, + static_v, + scale, + add_zero_attn, + dropout, + training, + is_causal, + return_attention_weights, + average_attention_weights, + batch_first, ) = dtype_mha helpers.test_function( input_dtypes=dtype, @@ -1317,11 +1424,13 @@ def test_multi_head_attention( backend_to_test=backend_fw, fn_name=fn_name, on_device=on_device, + test_values=(dropout == 0), atol_=1e-02, rtol_=1e-02, query=q, key=k, value=v, + batch_first=batch_first, num_heads=num_heads, scale=scale, attention_mask=attention_mask, @@ -1333,6 +1442,12 @@ def test_multi_head_attention( in_proj_bias=in_proj_bias, out_proj_bias=out_proj_bias, is_causal=is_causal, + key_padding_mask=key_padding_mask, + bias_k=bias_k, + bias_v=bias_v, + static_k=static_k, + static_v=static_v, + add_zero_attn=add_zero_attn, return_attention_weights=return_attention_weights, average_attention_weights=average_attention_weights, dropout=dropout,