
refactor(ivy, torch-frontend): Extends ivy.multi_head_attention and shortens torch_frontend.multi_head_attention_forward #23131

Merged · 84 commits · Sep 27, 2023
Commits
e28092f
rough refactor
AnnaTz Sep 6, 2023
3fabf90
fixed the call of ivy.multi_head_attention
AnnaTz Sep 7, 2023
6286431
refactored/extended tests
AnnaTz Sep 7, 2023
2dd2478
added the new args to the instances
AnnaTz Sep 7, 2023
2f88088
corrections in the compositional implementation
AnnaTz Sep 7, 2023
4d979ee
fixed bug in dtype sampling
AnnaTz Sep 7, 2023
236377b
Merge branch 'main' into mha
AnnaTz Sep 8, 2023
116c891
Merge branch 'main' into mha
AnnaTz Sep 13, 2023
c88852d
corrected pre_embed_dim vs embed_dim
AnnaTz Sep 13, 2023
b944063
avoid overflows
AnnaTz Sep 13, 2023
8794146
fix pre-embedding bug
AnnaTz Sep 13, 2023
b37e5bd
torch doesn't support self-attention
AnnaTz Sep 13, 2023
65ba5b7
Merge branch 'main' into mha
AnnaTz Sep 14, 2023
d637691
avoid overflow
AnnaTz Sep 14, 2023
a1076d0
corrected argument order
AnnaTz Sep 14, 2023
9d38d95
values won't be the same when training due to random dropout
AnnaTz Sep 14, 2023
db6a137
updated array ranges in the test
AnnaTz Sep 14, 2023
cd63554
fixed some bugs of embedding vs pre-embedding dims
AnnaTz Sep 14, 2023
03acf9d
forgot to remove code when moving it
AnnaTz Sep 14, 2023
9192fa0
wrong dim expanded in torch backend
AnnaTz Sep 15, 2023
71a4707
shortened the partial mixed handler check
AnnaTz Sep 15, 2023
9548879
native torch always returns a tuple
AnnaTz Sep 15, 2023
83c3671
fixed attention mask bugs
AnnaTz Sep 15, 2023
6cbb35d
have same_pre_embed_dim affect all cases in the strategy, so that the…
AnnaTz Sep 15, 2023
e8a36e0
fixed shape mismatches in return
AnnaTz Sep 15, 2023
7e449c5
ivy.where is inconsistent for non bool inputs
AnnaTz Sep 15, 2023
495b820
native torch doesn't return the correct weights when causal is applied
AnnaTz Sep 15, 2023
95c79d0
fixed the return dtypes of test strategy
AnnaTz Sep 15, 2023
8648f1e
native torch can't have None out_proj_weights
AnnaTz Sep 15, 2023
45f1608
fixed key_padding_mask
AnnaTz Sep 15, 2023
629fee1
fixed dimension order in torch backend
AnnaTz Sep 15, 2023
188404c
fixing bias_k/bias_v
AnnaTz Sep 15, 2023
ccb93a3
fixed mask dtypes
AnnaTz Sep 15, 2023
c96a5fb
fixed bias_k/bias_v
AnnaTz Sep 15, 2023
e41ded9
fixed zero attention
AnnaTz Sep 15, 2023
6791b4a
fixed return shape mismatches
AnnaTz Sep 15, 2023
b88276d
fixed static_k/static_v
AnnaTz Sep 15, 2023
5a46d04
Merge branch 'main' into mha
AnnaTz Sep 18, 2023
4e176c9
reshape expects integers
AnnaTz Sep 18, 2023
9ee4b00
leave a todo for the problematic cases and exclude them for now
AnnaTz Sep 18, 2023
1c5c7a2
fixed typo
AnnaTz Sep 18, 2023
ed96ca5
refactored torch helper to be able to re-use it in the frontend
AnnaTz Sep 18, 2023
4c6508a
completed the frontend function and test
AnnaTz Sep 18, 2023
c7e5279
fixed dtype inconsistency
AnnaTz Sep 19, 2023
123fcb6
removed faulty case from the test strategy
AnnaTz Sep 19, 2023
58aa54d
Merge branch 'main' into mha
AnnaTz Sep 19, 2023
51b0b54
fixed wrong dimension reference in frontend
AnnaTz Sep 19, 2023
6d9bb23
fixed scale bug in the frontend
AnnaTz Sep 19, 2023
0953097
fixed key_padding_mask in ivy function
AnnaTz Sep 19, 2023
a702baf
fixed kwarg errors in the torch backend
AnnaTz Sep 19, 2023
ba3af23
fixed a couple of left over bugs in the frontend
AnnaTz Sep 19, 2023
c9b55e0
didn't need to batch in the unbatched case in the torch backend
AnnaTz Sep 19, 2023
de3346c
similar fix for the unbatched case in the frontend
AnnaTz Sep 19, 2023
c680849
fixed list error; list of an array was separating the rows
AnnaTz Sep 19, 2023
4c25ac4
mask can be either bool or same type as query
AnnaTz Sep 19, 2023
edd2f39
fixed wrong dimension reference in the frontend
AnnaTz Sep 19, 2023
71fc93c
Merge branch 'main' into mha
AnnaTz Sep 20, 2023
c582622
fixed wrong dimension reference in the frontend
AnnaTz Sep 20, 2023
fc31f2e
force original mask dtypes
AnnaTz Sep 20, 2023
4b4e8a2
new detailed docstring
AnnaTz Sep 20, 2023
859bbdd
tolerance
AnnaTz Sep 20, 2023
a5ce167
finally fixed the dimension referencing issue in the static case
AnnaTz Sep 20, 2023
1e9daf2
previous commit continued
AnnaTz Sep 20, 2023
08d42a6
fixed bug of wrong dtypes passed to test_frontend_function
AnnaTz Sep 20, 2023
acb5abb
Merge branch 'main' into mha
AnnaTz Sep 21, 2023
baf0e21
fixed attention mask, numeric masks are added to the weights
AnnaTz Sep 21, 2023
7952491
emb_dim should be integer
AnnaTz Sep 21, 2023
2f00897
refactored the shaping of the matrix multiplications
AnnaTz Sep 21, 2023
bf56f48
similar attention mask fix for torch backend
AnnaTz Sep 21, 2023
9308c66
make scale positive
AnnaTz Sep 21, 2023
a2d2a28
mask refactor
AnnaTz Sep 25, 2023
7907596
key_padding_mask needs to be converted to float like attention_mask
AnnaTz Sep 25, 2023
c2e594f
corrected case of is_causal and need_weights in the frontend
AnnaTz Sep 25, 2023
827086e
fixed is_causal + more masking bugs
AnnaTz Sep 26, 2023
fda3ee1
🤖 Lint code
ivy-branch Sep 26, 2023
dca30d6
added argument batch_first
AnnaTz Sep 27, 2023
b350187
Merge branch 'main' into mha
AnnaTz Sep 27, 2023
72380d1
🤖 Lint code
ivy-branch Sep 27, 2023
803e2a4
test(backend handler): added tests for inplace update warnings. (#26067)
Madjid-CH Sep 27, 2023
8da76fe
fix(ivy): Added missing wrapper import
AnnaTz Sep 27, 2023
c2118d6
minor changes
AnnaTz Sep 27, 2023
47be090
refactored & fixed static_k/static_v
AnnaTz Sep 27, 2023
8f5ee33
renamed variable
AnnaTz Sep 27, 2023
c99bdbd
fixed bug in testing strategy
AnnaTz Sep 27, 2023
12 changes: 12 additions & 0 deletions ivy/data_classes/array/layers.py
@@ -409,6 +409,12 @@ def multi_head_attention(
in_proj_bias: Optional[Union[ivy.Array, ivy.NativeArray]] = None,
out_proj_bias: Optional[Union[ivy.Array, ivy.NativeArray]] = None,
is_causal: bool = False,
key_padding_mask: Optional[Union[ivy.Array, ivy.NativeArray]] = None,
bias_k: Optional[Union[ivy.Array, ivy.NativeArray]] = None,
bias_v: Optional[Union[ivy.Array, ivy.NativeArray]] = None,
static_k: Optional[Union[ivy.Array, ivy.NativeArray]] = None,
static_v: Optional[Union[ivy.Array, ivy.NativeArray]] = None,
add_zero_attn: bool = False,
return_attention_weights: bool = False,
average_attention_weights: bool = True,
dropout: float = 0.0,
@@ -430,6 +436,12 @@ def multi_head_attention(
in_proj_bias=in_proj_bias,
out_proj_bias=out_proj_bias,
is_causal=is_causal,
key_padding_mask=key_padding_mask,
bias_k=bias_k,
bias_v=bias_v,
static_k=static_k,
static_v=static_v,
add_zero_attn=add_zero_attn,
return_attention_weights=return_attention_weights,
average_attention_weights=average_attention_weights,
dropout=dropout,
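A minimal usage sketch of the extended ivy.multi_head_attention signature (not part of the diff); the numpy backend choice, the shapes, and the expected shapes in the comments are illustrative assumptions.

import ivy

ivy.set_backend("numpy")  # any installed backend; numpy is only an assumption here

batch, seq, embed_dim, num_heads = 2, 5, 16, 4
query = ivy.random_normal(shape=(batch, seq, embed_dim))

# Self-attention: key/value default to query when omitted.
output, weights = ivy.multi_head_attention(
    query,
    num_heads=num_heads,
    # new in this PR: ignore the last key position of every sequence in the batch
    key_padding_mask=ivy.array([[False] * (seq - 1) + [True]] * batch),
    return_attention_weights=True,
)
print(output.shape)   # expected: (2, 5, 16)
print(weights.shape)  # expected: (2, 5, 5), averaged over heads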
28 changes: 28 additions & 0 deletions ivy/data_classes/container/layers.py
@@ -1055,6 +1055,14 @@ def _static_multi_head_attention(
Union[ivy.Array, ivy.NativeArray, ivy.Container]
] = None,
is_causal: Union[bool, ivy.Container] = False,
key_padding_mask: Optional[
Union[ivy.Array, ivy.NativeArray, ivy.Container]
] = None,
bias_k: Optional[Union[ivy.Array, ivy.NativeArray, ivy.Container]] = None,
bias_v: Optional[Union[ivy.Array, ivy.NativeArray, ivy.Container]] = None,
static_k: Optional[Union[ivy.Array, ivy.NativeArray, ivy.Container]] = None,
static_v: Optional[Union[ivy.Array, ivy.NativeArray, ivy.Container]] = None,
add_zero_attn: Union[bool, ivy.Container] = False,
return_attention_weights: Union[bool, ivy.Container] = False,
average_attention_weights: Union[bool, ivy.Container] = True,
dropout: Union[float, ivy.Container] = 0.0,
@@ -1081,6 +1089,12 @@ def _static_multi_head_attention(
in_proj_bias=in_proj_bias,
out_proj_bias=out_proj_bias,
is_causal=is_causal,
key_padding_mask=key_padding_mask,
bias_k=bias_k,
bias_v=bias_v,
static_k=static_k,
static_v=static_v,
add_zero_attn=add_zero_attn,
return_attention_weights=return_attention_weights,
average_attention_weights=average_attention_weights,
dropout=dropout,
@@ -1123,6 +1137,14 @@ def multi_head_attention(
Union[ivy.Array, ivy.NativeArray, ivy.Container]
] = None,
is_causal: Union[bool, ivy.Container] = False,
key_padding_mask: Optional[
Union[ivy.Array, ivy.NativeArray, ivy.Container]
] = None,
bias_k: Optional[Union[ivy.Array, ivy.NativeArray, ivy.Container]] = None,
bias_v: Optional[Union[ivy.Array, ivy.NativeArray, ivy.Container]] = None,
static_k: Optional[Union[ivy.Array, ivy.NativeArray, ivy.Container]] = None,
static_v: Optional[Union[ivy.Array, ivy.NativeArray, ivy.Container]] = None,
add_zero_attn: Union[bool, ivy.Container] = False,
return_attention_weights: Union[bool, ivy.Container] = False,
average_attention_weights: Union[bool, ivy.Container] = True,
dropout: Union[float, ivy.Container] = 0.0,
@@ -1148,6 +1170,12 @@ def multi_head_attention(
in_proj_bias=in_proj_bias,
out_proj_bias=out_proj_bias,
is_causal=is_causal,
key_padding_mask=key_padding_mask,
bias_k=bias_k,
bias_v=bias_v,
static_k=static_k,
static_v=static_v,
add_zero_attn=add_zero_attn,
return_attention_weights=return_attention_weights,
average_attention_weights=average_attention_weights,
dropout=dropout,
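The container methods thread the same new keyword arguments through to ivy.multi_head_attention, so container inputs are handled leaf by leaf. A small sketch under that assumption (the nestable behaviour, backend, and shapes are illustrative, not taken from the PR):

import ivy

ivy.set_backend("numpy")

queries = ivy.Container(
    a=ivy.random_normal(shape=(2, 5, 16)),
    b=ivy.random_normal(shape=(2, 7, 16)),
)
# The extended keyword arguments are forwarded to ivy.multi_head_attention per leaf.
outputs = ivy.multi_head_attention(queries, num_heads=4, dropout=0.0)
print(outputs.a.shape, outputs.b.shape)  # expected: (2, 5, 16) (2, 7, 16)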
1 change: 1 addition & 0 deletions ivy/functional/backends/paddle/statistical.py
@@ -9,6 +9,7 @@
from ivy.func_wrapper import (
with_supported_dtypes,
with_unsupported_dtypes,
with_supported_device_and_dtypes,
)
import ivy.functional.backends.paddle as paddle_backend
from ivy.utils.einsum_parser import legalise_einsum_expr
120 changes: 119 additions & 1 deletion ivy/functional/backends/torch/layers.py
@@ -6,11 +6,129 @@

# local
import ivy
from ivy.func_wrapper import with_unsupported_dtypes
from ivy.func_wrapper import with_unsupported_dtypes, with_supported_dtypes
from . import backend_version
from ivy.functional.ivy.layers import _handle_padding, _deconv_length


@with_supported_dtypes(
{"2.0.1 and below": ("float32", "float64", "complex")},
backend_version,
)
def multi_head_attention(
query: torch.Tensor,
/,
*,
key: torch.Tensor = None,
value: torch.Tensor = None,
batch_first: bool = True,
num_heads: Optional[int] = 8,
scale: Optional[float] = None,
attention_mask: torch.Tensor = None,
in_proj_weights: torch.Tensor = None,
q_proj_weights: torch.Tensor = None,
k_proj_weights: torch.Tensor = None,
v_proj_weights: torch.Tensor = None,
out_proj_weights: torch.Tensor = None,
in_proj_bias: torch.Tensor = None,
out_proj_bias: torch.Tensor = None,
is_causal: Optional[bool] = False,
key_padding_mask: Optional[torch.Tensor] = None,
bias_k: Optional[torch.Tensor] = None,
bias_v: Optional[torch.Tensor] = None,
static_k: Optional[torch.Tensor] = None,
static_v: Optional[torch.Tensor] = None,
add_zero_attn: bool = False,
return_attention_weights: Optional[bool] = False,
average_attention_weights: Optional[bool] = True,
dropout: Optional[float] = 0.0,
training: Optional[bool] = False,
out: torch.Tensor = None,
) -> torch.Tensor:
if key is None and value is None:
key = value = query
emb_dim = _get_embed_dim(
in_proj_weights,
q_proj_weights,
k_proj_weights,
v_proj_weights,
query,
)[1]
num_dims = query.ndim
if num_dims == 3 and batch_first:
query, key, value = [torch.swapaxes(x, 0, 1) for x in [query, key, value]]
ret = torch.nn.functional.multi_head_attention_forward(
query,
key,
value,
emb_dim,
num_heads,
in_proj_weights,
in_proj_bias,
bias_k,
bias_v,
add_zero_attn,
dropout,
out_proj_weights,
out_proj_bias,
training=training,
key_padding_mask=key_padding_mask,
need_weights=return_attention_weights,
attn_mask=attention_mask,
use_separate_proj_weight=not ivy.exists(in_proj_weights),
q_proj_weight=q_proj_weights,
k_proj_weight=k_proj_weights,
v_proj_weight=v_proj_weights,
static_k=static_k,
static_v=static_v,
average_attn_weights=average_attention_weights,
is_causal=is_causal,
)
ret = list(ret) if isinstance(ret, tuple) else [ret]
if num_dims == 3 and batch_first:
ret[0] = ret[0].swapaxes(0, 1)
if return_attention_weights:
return tuple(ret)
return ret[0]


multi_head_attention.partial_mixed_handler = (
lambda *args, scale=None, out_proj_weights=None, is_causal=False, attention_mask=None, return_attention_weights=False, in_proj_weights=None, q_proj_weights=None, k_proj_weights=None, v_proj_weights=None, **kwargs: not ivy.exists(
scale
)
and ivy.exists(out_proj_weights)
and (not is_causal or ivy.exists(attention_mask))
and (not is_causal or not return_attention_weights)
and (
ivy.exists(in_proj_weights)
or all(
[ivy.exists(x) for x in [q_proj_weights, k_proj_weights, v_proj_weights]]
)
)
and len(
set(
_get_embed_dim(
in_proj_weights, q_proj_weights, k_proj_weights, v_proj_weights, args[0]
)
)
)
== 1
)


def _get_embed_dim(
in_proj_weights, q_proj_weights, k_proj_weights, v_proj_weights, query
):
pre_embed_dim = query.shape[-1]
if ivy.exists(in_proj_weights):
embed_dim = in_proj_weights.shape[0] / 3
elif all([ivy.exists(x) for x in [q_proj_weights, k_proj_weights, v_proj_weights]]):
embed_dim = q_proj_weights.shape[0]
else:
embed_dim = None
return pre_embed_dim, embed_dim


@with_unsupported_dtypes(
{"2.0.1 and below": ("float16", "bfloat16", "complex")},
backend_version,
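The partial_mixed_handler lambda above decides when a call can be routed to the native torch.nn.functional.multi_head_attention_forward kernel and when ivy must fall back to its compositional implementation in ivy/functional/ivy/layers.py. The sketch below is a readable paraphrase of those conditions (illustrative, not code from the PR); it reuses the _get_embed_dim helper defined above, and the function name is hypothetical.

def _use_native_torch_mha(
    query,
    *,
    scale=None,
    out_proj_weights=None,
    is_causal=False,
    attention_mask=None,
    return_attention_weights=False,
    in_proj_weights=None,
    q_proj_weights=None,
    k_proj_weights=None,
    v_proj_weights=None,
    **kwargs,
):
    separate = (q_proj_weights, k_proj_weights, v_proj_weights)
    pre_embed_dim, embed_dim = _get_embed_dim(
        in_proj_weights, q_proj_weights, k_proj_weights, v_proj_weights, query
    )
    return (
        scale is None  # the native kernel exposes no custom scale argument
        and out_proj_weights is not None  # native torch requires output projection weights
        and (not is_causal or attention_mask is not None)  # causal path still needs an explicit mask
        # native torch does not return the expected weights when causal masking is applied
        and (not is_causal or not return_attention_weights)
        # either the fused weights or all three separate projection weights must be given
        and (in_proj_weights is not None or all(w is not None for w in separate))
        # the query's last dim must already equal the embedding dim
        and pre_embed_dim == embed_dim
    )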
@@ -266,143 +266,36 @@ def multi_head_attention_forward(
average_attn_weights=True,
is_causal=False,
):
# q/k/v shape: (seq_len, batch_size, embed_dim)
seq_len, batch_size, embed_dim = query.shape
embed_dim = query.shape[-1]
assert (
embed_dim == embed_dim_to_check
), f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
assert key.shape == value.shape

head_dim = embed_dim // num_heads
assert head_dim * num_heads == embed_dim, "embed_dim needs to be divisible by heads"
scale = ivy.sqrt(head_dim)

if use_separate_proj_weight:
assert key.shape[:2] == value.shape[:2], (
f"key's sequence and batch dims {key.shape[:2]} do not match value's"
f" {value.shape[:2]}"
)
else:
assert (
key.shape == value.shape
), f"key shape {key.shape} does not match value shape {value.shape}"

if is_causal and key_padding_mask is None and not need_weights:
mask = ivy.tril(ivy.ones((seq_len, seq_len), dtype=query.dtype), k=0)
attn_mask = ivy.zeros((seq_len, seq_len), dtype=query.dtype)
attn_mask = ivy.where(mask == 0.0, float("-inf"), 0)

if in_proj_bias is None:
q_bias, k_bias, v_bias = None, None, None
else:
q_bias, k_bias, v_bias = ivy.split(in_proj_bias, num_or_size_splits=3)

if not use_separate_proj_weight:
q_proj_weight, k_proj_weight, v_proj_weight = ivy.split(
in_proj_weight, num_or_size_splits=3
)

q = ivy.linear(query, q_proj_weight, bias=q_bias)
k = ivy.linear(key, k_proj_weight, bias=k_bias)
v = ivy.linear(value, v_proj_weight, bias=v_bias)

if bias_k is not None and bias_v is not None:
assert static_k is None, "bias cannot be added to static key."
assert static_v is None, "bias cannot be added to static value."
k = ivy.concat([k, ivy.tile(bias_k, (1, batch_size, 1))])
v = ivy.concat([v, ivy.tile(bias_v, (1, batch_size, 1))])
if attn_mask is not None:
attn_mask = ivy.concat(
[attn_mask, ivy.zeros((attn_mask.shape[0], 1), dtype=attn_mask.dtype)],
axis=1,
)
if key_padding_mask is not None:
key_padding_mask = ivy.concat(
[
key_padding_mask,
ivy.zeros(
(key_padding_mask.shape[0], 1), dtype=key_padding_mask.dtype
).bool(),
],
axis=1,
)

q = ivy.swapaxes(q.reshape((q.shape[0], batch_size * num_heads, head_dim)), 0, 1)

if static_k is None:
k = ivy.swapaxes(
k.reshape((k.shape[0], batch_size * num_heads, head_dim)), 0, 1
)
else:
assert static_k.shape[0] == batch_size * num_heads, (
f"expecting static_k.shape[0] of {batch_size * num_heads}, but got"
f" {static_k.shape[0]}"
)
assert (
static_k.shape[2] == head_dim
), f"expecting static_k.shape[2] of {head_dim}, but got {static_k.shape[2]}"
k = static_k

if static_v is None:
v = ivy.swapaxes(
v.reshape((v.shape[0], batch_size * num_heads, head_dim)), 0, 1
)
else:
assert static_v.shape[0] == batch_size * num_heads, (
f"expecting static_v.shape[0] of {batch_size * num_heads}, but got"
f" {static_v.shape[0]}"
)
assert (
static_v.shape[2] == head_dim
), f"expecting static_v.shape[2] of {head_dim}, but got {static_v.shape[2]}"
v = static_v

# TODO add_zero_attn doesn't work for all cases
# fix this and add test cases (by changing to add_zero_attn=st.booleans())
if add_zero_attn:
zero_attn_shape = (batch_size * num_heads, 1, head_dim)
k = ivy.concat([k, ivy.zeros(zero_attn_shape, dtype=k.dtype)], axis=1)
v = ivy.concat([v, ivy.zeros(zero_attn_shape, dtype=v.dtype)], axis=1)
if attn_mask is not None:
attn_mask = ivy.pad(attn_mask, [(0, 0), (0, 1)])
if key_padding_mask is not None:
key_padding_mask = ivy.pad(key_padding_mask, [(0, 0), (0, 1)])

src_len = k.shape[1]
attn_weights = ivy.matmul(q, ivy.swapaxes(k, 1, 2))
assert list(attn_weights.shape) == [batch_size * num_heads, seq_len, src_len]

attn_weights = attn_weights / scale

if attn_mask is not None:
attn_mask = ivy.expand_dims(attn_mask, axis=0)
attn_weights += attn_mask

if key_padding_mask is not None:
key_padding_mask = ivy.expand_dims(
ivy.expand_dims(key_padding_mask, axis=1), axis=2
)
attn_weights = attn_weights.reshape((batch_size, num_heads, seq_len, src_len))
attn_weights = ivy.where(key_padding_mask < 0.0, float("-inf"), attn_weights)
attn_weights = attn_weights.reshape((batch_size * num_heads, seq_len, src_len))

attn_weights = ivy.softmax(attn_weights, axis=-1)
attn_weights = ivy.dropout(attn_weights, dropout_p, training=training)

attn_output = ivy.matmul(attn_weights, v)
assert list(attn_output.shape) == [batch_size * num_heads, seq_len, head_dim]
attn_output = ivy.swapaxes(attn_output, 0, 1).reshape(
(seq_len, batch_size, embed_dim)
return ivy.multi_head_attention(
query,
key=key,
value=value,
batch_first=False,
num_heads=num_heads,
attention_mask=attn_mask,
in_proj_weights=in_proj_weight if not use_separate_proj_weight else None,
q_proj_weights=q_proj_weight,
k_proj_weights=k_proj_weight,
v_proj_weights=v_proj_weight,
out_proj_weights=out_proj_weight,
in_proj_bias=in_proj_bias,
out_proj_bias=out_proj_bias,
is_causal=is_causal and not (need_weights or key_padding_mask is not None),
key_padding_mask=key_padding_mask,
bias_k=bias_k,
bias_v=bias_v,
static_k=static_k,
static_v=static_v,
add_zero_attn=add_zero_attn,
return_attention_weights=need_weights,
average_attention_weights=average_attn_weights,
dropout=dropout_p,
training=training,
)
attn_output = ivy.linear(attn_output, out_proj_weight, bias=out_proj_bias)

if need_weights:
attn_weights = attn_weights.reshape((batch_size, num_heads, seq_len, src_len))
if average_attn_weights:
attn_weights = ivy.sum(attn_weights, axis=1) / num_heads
return (attn_output, attn_weights)
else:
return (attn_output,)


@to_ivy_arrays_and_back
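A rough parity sketch between native torch and the slimmed-down frontend (illustrative only; the frontend import path, the ivy_array attribute on frontend tensors, and the tolerance are assumptions, not taken from the PR). The PR's updated hypothesis tests run the equivalent comparison across many argument combinations.

import numpy as np
import torch
import ivy
from ivy.functional.frontends.torch.nn import functional as ivy_torch_nnf

ivy.set_backend("torch")

L, N, E, heads = 4, 2, 8, 2
rng = np.random.default_rng(0)
q = rng.standard_normal((L, N, E), dtype=np.float32)      # (seq_len, batch, embed_dim)
in_w = rng.standard_normal((3 * E, E), dtype=np.float32)  # fused q/k/v projection
out_w = rng.standard_normal((E, E), dtype=np.float32)
in_b = np.zeros(3 * E, dtype=np.float32)
out_b = np.zeros(E, dtype=np.float32)

kwargs = dict(training=False, need_weights=True, average_attn_weights=True)
ref_out, ref_w = torch.nn.functional.multi_head_attention_forward(
    torch.tensor(q), torch.tensor(q), torch.tensor(q), E, heads,
    torch.tensor(in_w), torch.tensor(in_b), None, None, False, 0.0,
    torch.tensor(out_w), torch.tensor(out_b), **kwargs,
)
ivy_out, ivy_w = ivy_torch_nnf.multi_head_attention_forward(
    ivy.array(q), ivy.array(q), ivy.array(q), E, heads,
    ivy.array(in_w), ivy.array(in_b), None, None, False, 0.0,
    ivy.array(out_w), ivy.array(out_b), **kwargs,
)
print(np.allclose(ref_out.numpy(), ivy.to_numpy(ivy_out.ivy_array), atol=1e-5))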