Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support multi-query attention for encoder-decoder models #1339

Merged
merged 2 commits into from
Jul 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions python/ctranslate2/converters/opennmt_py.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def _get_model_spec_seq2seq(
alignment_heads=alignment_heads,
num_source_embeddings=num_source_embeddings,
embeddings_merge=_SUPPORTED_FEATURES_MERGE[feat_merge],
multi_query_attention=getattr(opt, "multiquery", False),
)

model_spec.config.decoder_start_token = getattr(opt, "decoder_start_token", "<s>")
Expand Down Expand Up @@ -115,6 +116,7 @@ def _get_model_spec_lm(opt, variables, src_vocabs, tgt_vocabs, num_source_embedd
rms_norm=opt.layer_norm == "rms",
rotary_dim=rotary_dim,
rotary_interleave=True,
multi_query_attention=getattr(opt, "multiquery", False),
)

set_transformer_decoder(
Expand Down
22 changes: 21 additions & 1 deletion python/ctranslate2/specs/transformer_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def __init__(
relative_attention_bias: bool = False,
ffn_glu: bool = False,
rms_norm: bool = False,
multi_query_attention: bool = False,
):
"""Initializes a Transformer encoder specification.
Expand All @@ -42,6 +43,7 @@ def __init__(
ffn_glu: Use gated linear units in the FFN layers as described in
https://arxiv.org/abs/2002.05202.
rms_norm: Use the root mean square layer normalization.
multi_query_attention: Use multi-query attention.
"""
self.num_heads = np.dtype("int16").type(num_heads)
self.pre_norm = pre_norm
Expand All @@ -63,6 +65,7 @@ def __init__(
relative_attention_bias=relative_attention_bias,
ffn_glu=ffn_glu,
rms_norm=rms_norm,
num_heads_kv=1 if multi_query_attention else None,
)
for _ in range(num_layers)
]
Expand Down Expand Up @@ -141,6 +144,12 @@ def __init__(
)
num_heads_kv = 1

if with_encoder_attention and num_heads_kv not in (None, 1, num_heads):
raise ValueError(
"num_heads_kv=%d is not supported in the cross-attention layers"
% num_heads_kv
)

self.num_heads = np.dtype("int16").type(num_heads)
self.pre_norm = pre_norm
self.activation = np.dtype("int8").type(activation)
Expand Down Expand Up @@ -192,12 +201,14 @@ def __init__(
relative_attention_bias=False,
ffn_glu=False,
rms_norm=False,
num_heads_kv=None,
):
self.self_attention = attention_spec.MultiHeadAttentionSpec(
self_attention=True,
relative_position=relative_position,
relative_attention_bias=relative_attention_bias,
rms_norm=rms_norm,
num_heads_kv=num_heads_kv,
)
self.ffn = FeedForwardSpec(glu=ffn_glu, rms_norm=rms_norm)

Expand Down Expand Up @@ -225,8 +236,13 @@ def __init__(
rotary_interleave=rotary_interleave,
num_heads_kv=num_heads_kv,
)

if with_encoder_attention:
self.attention = attention_spec.MultiHeadAttentionSpec(rms_norm=rms_norm)
self.attention = attention_spec.MultiHeadAttentionSpec(
rms_norm=rms_norm,
num_heads_kv=num_heads_kv,
)

self.ffn = FeedForwardSpec(glu=ffn_glu, rms_norm=rms_norm)

if parallel_residual:
Expand Down Expand Up @@ -309,6 +325,7 @@ def from_config(
relative_attention_bias: bool = False,
ffn_glu: bool = False,
rms_norm: bool = False,
multi_query_attention: bool = False,
):
"""Creates a Transformer model specification.
Expand All @@ -332,6 +349,7 @@ def from_config(
ffn_glu: Use gated linear units in the FFN layer as described in
https://arxiv.org/abs/2002.05202.
rms_norm: Use the root mean square layer normalization.
multi_query_attention: Use multi-query attention.
"""
if isinstance(num_layers, (list, tuple)):
num_encoder_layers, num_decoder_layers = num_layers
Expand All @@ -351,6 +369,7 @@ def from_config(
relative_attention_bias=relative_attention_bias,
ffn_glu=ffn_glu,
rms_norm=rms_norm,
multi_query_attention=multi_query_attention,
)

decoder = TransformerDecoderSpec(
Expand All @@ -366,6 +385,7 @@ def from_config(
alignment_heads=alignment_heads,
ffn_glu=ffn_glu,
rms_norm=rms_norm,
multi_query_attention=multi_query_attention,
)

return cls(encoder, decoder)
Expand Down
20 changes: 17 additions & 3 deletions src/layers/attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -424,8 +424,15 @@ namespace ctranslate2 {

if (cached_keys == nullptr || cached_keys->empty()) {
_linear[1](values, fused_proj);
split_heads(fused_proj, 2 * _num_heads, values_padder);
ops::Split(1)(fused_proj, keys_proj, values_proj);

if (_num_heads_kv == 1) {
if (values_padder)
values_padder->add_padding(fused_proj);
ops::Split(2, {_d_head, _d_head})(fused_proj, keys_proj, values_proj);
} else {
split_heads(fused_proj, 2 * _num_heads, values_padder);
ops::Split(1)(fused_proj, keys_proj, values_proj);
}

if (cached_keys != nullptr) {
*cached_keys = std::move(keys_proj);
Expand All @@ -435,7 +442,14 @@ namespace ctranslate2 {

if (queries_proj.dim(1) == 1 && cached_keys)
beam_size = queries_proj.dim(0) / cached_keys->dim(0);
split_heads(queries_proj, _num_heads, queries_padder, beam_size);

if (_num_heads_kv == 1) {
if (queries_padder)
queries_padder->add_padding(queries_proj);
queries_proj.reshape({queries_proj.dim(0) / beam_size, -1, _d_head});
} else {
split_heads(queries_proj, _num_heads, queries_padder, beam_size);
}

} else {

Expand Down