Support multi-query attention for encoder-decoder models (#1339)
* Support multi-query attention for encoder-decoder models

* Fix default value when multiquery option is missing
guillaumekln authored Jul 12, 2023
1 parent 662c421 · commit dac2f62
Showing 3 changed files with 40 additions and 4 deletions.
2 changes: 2 additions & 0 deletions python/ctranslate2/converters/opennmt_py.py
@@ -82,6 +82,7 @@ def _get_model_spec_seq2seq(
alignment_heads=alignment_heads,
num_source_embeddings=num_source_embeddings,
embeddings_merge=_SUPPORTED_FEATURES_MERGE[feat_merge],
multi_query_attention=getattr(opt, "multiquery", False),
)

model_spec.config.decoder_start_token = getattr(opt, "decoder_start_token", "<s>")
@@ -115,6 +116,7 @@ def _get_model_spec_lm(opt, variables, src_vocabs, tgt_vocabs, num_source_embedd
rms_norm=opt.layer_norm == "rms",
rotary_dim=rotary_dim,
rotary_interleave=True,
multi_query_attention=getattr(opt, "multiquery", False),
)

set_transformer_decoder(
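Both converter hunks read the training option with `getattr(opt, "multiquery", False)`, so checkpoints produced before OpenNMT-py exposed the `multiquery` option still convert with regular multi-head attention; this is what the follow-up message "Fix default value when multiquery option is missing" refers to. A minimal sketch of that defaulting behaviour, with `SimpleNamespace` standing in for the options object stored in a checkpoint:

```python
# Minimal sketch, not the converter itself: SimpleNamespace stands in for the
# argparse options object stored in an OpenNMT-py checkpoint.
from types import SimpleNamespace

old_opt = SimpleNamespace(layers=6)                   # trained before the option existed
new_opt = SimpleNamespace(layers=6, multiquery=True)  # trained with multi-query attention

# getattr with a default avoids an AttributeError on old checkpoints and
# falls back to regular multi-head attention.
assert getattr(old_opt, "multiquery", False) is False
assert getattr(new_opt, "multiquery", False) is True
```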
22 changes: 21 additions & 1 deletion python/ctranslate2/specs/transformer_spec.py
@@ -22,6 +22,7 @@ def __init__(
relative_attention_bias: bool = False,
ffn_glu: bool = False,
rms_norm: bool = False,
multi_query_attention: bool = False,
):
"""Initializes a Transformer encoder specification.
@@ -42,6 +43,7 @@
ffn_glu: Use gated linear units in the FFN layers as described in
https://arxiv.org/abs/2002.05202.
rms_norm: Use the root mean square layer normalization.
multi_query_attention: Use multi-query attention.
"""
self.num_heads = np.dtype("int16").type(num_heads)
self.pre_norm = pre_norm
@@ -63,6 +65,7 @@ def __init__(
relative_attention_bias=relative_attention_bias,
ffn_glu=ffn_glu,
rms_norm=rms_norm,
num_heads_kv=1 if multi_query_attention else None,
)
for _ in range(num_layers)
]
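In the encoder specification, `multi_query_attention=True` is expressed as `num_heads_kv=1` on every layer: all query heads share a single key/value head, while `None` keeps one key/value head per query head. The shape consequence can be sketched with NumPy (illustrative shapes only, not the actual weight layout of the spec):

```python
import numpy as np

d_model, num_heads = 512, 8
d_head = d_model // num_heads

# Query projection is unchanged in both cases.
q_proj = np.zeros((d_model, num_heads * d_head))           # (512, 512)

# Key/value projections: regular multi-head vs. multi-query (one shared head).
kv_proj_mha = np.zeros((d_model, 2 * num_heads * d_head))  # (512, 1024)
kv_proj_mqa = np.zeros((d_model, 2 * 1 * d_head))          # (512, 128)

print(kv_proj_mha.shape, kv_proj_mqa.shape)
```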
@@ -141,6 +144,12 @@ def __init__(
)
num_heads_kv = 1

if with_encoder_attention and num_heads_kv not in (None, 1, num_heads):
raise ValueError(
"num_heads_kv=%d is not supported in the cross-attention layers"
% num_heads_kv
)

self.num_heads = np.dtype("int16").type(num_heads)
self.pre_norm = pre_norm
self.activation = np.dtype("int8").type(activation)
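The new check only applies when the decoder has cross-attention: the cross-attention layers accept `None` (regular multi-head), `1` (multi-query), or `num_heads`, but not an intermediate grouped-query value. A standalone restatement of the check, using an assumed helper name:

```python
# Hypothetical helper mirroring the validation above.
def check_cross_attention_heads(num_heads, num_heads_kv, with_encoder_attention):
    if with_encoder_attention and num_heads_kv not in (None, 1, num_heads):
        raise ValueError(
            "num_heads_kv=%d is not supported in the cross-attention layers"
            % num_heads_kv
        )

check_cross_attention_heads(8, None, True)  # OK: regular multi-head attention
check_cross_attention_heads(8, 1, True)     # OK: multi-query attention
check_cross_attention_heads(8, 4, False)    # OK: no cross-attention, check is skipped
# check_cross_attention_heads(8, 4, True)   # raises: grouped-query cross-attention
```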
@@ -192,12 +201,14 @@ def __init__(
relative_attention_bias=False,
ffn_glu=False,
rms_norm=False,
num_heads_kv=None,
):
self.self_attention = attention_spec.MultiHeadAttentionSpec(
self_attention=True,
relative_position=relative_position,
relative_attention_bias=relative_attention_bias,
rms_norm=rms_norm,
num_heads_kv=num_heads_kv,
)
self.ffn = FeedForwardSpec(glu=ffn_glu, rms_norm=rms_norm)

@@ -225,8 +236,13 @@ def __init__(
rotary_interleave=rotary_interleave,
num_heads_kv=num_heads_kv,
)

if with_encoder_attention:
self.attention = attention_spec.MultiHeadAttentionSpec(rms_norm=rms_norm)
self.attention = attention_spec.MultiHeadAttentionSpec(
rms_norm=rms_norm,
num_heads_kv=num_heads_kv,
)

self.ffn = FeedForwardSpec(glu=ffn_glu, rms_norm=rms_norm)

if parallel_residual:
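In the decoder layer specification, the single-line construction of the cross-attention block (`MultiHeadAttentionSpec(rms_norm=rms_norm)`) is replaced by a call that also forwards `num_heads_kv`, so the encoder-decoder attention uses the same single shared key/value head as the decoder self-attention when multi-query attention is enabled. Conceptually (assumed simplified names, not the real spec classes):

```python
# Conceptual wiring only: both attention blocks of a decoder layer now receive
# the same num_heads_kv value (1 for multi-query, None for regular multi-head).
class DecoderLayerSketch:
    def __init__(self, num_heads_kv=None, with_encoder_attention=True):
        self.self_attention = {"self_attention": True, "num_heads_kv": num_heads_kv}
        if with_encoder_attention:
            self.attention = {"self_attention": False, "num_heads_kv": num_heads_kv}

layer = DecoderLayerSketch(num_heads_kv=1)
print(layer.self_attention, layer.attention)
```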
@@ -309,6 +325,7 @@ def from_config(
relative_attention_bias: bool = False,
ffn_glu: bool = False,
rms_norm: bool = False,
multi_query_attention: bool = False,
):
"""Creates a Transformer model specification.
@@ -332,6 +349,7 @@
ffn_glu: Use gated linear units in the FFN layer as described in
https://arxiv.org/abs/2002.05202.
rms_norm: Use the root mean square layer normalization.
multi_query_attention: Use multi-query attention.
"""
if isinstance(num_layers, (list, tuple)):
num_encoder_layers, num_decoder_layers = num_layers
@@ -351,6 +369,7 @@
relative_attention_bias=relative_attention_bias,
ffn_glu=ffn_glu,
rms_norm=rms_norm,
multi_query_attention=multi_query_attention,
)

decoder = TransformerDecoderSpec(
@@ -366,6 +385,7 @@
alignment_heads=alignment_heads,
ffn_glu=ffn_glu,
rms_norm=rms_norm,
multi_query_attention=multi_query_attention,
)

return cls(encoder, decoder)
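With the new keyword, an encoder-decoder spec with multi-query attention can be created through `from_config`. A hedged usage sketch (layer and head counts are examples only):

```python
import ctranslate2

# Symmetric 6-layer encoder/decoder, 8 query heads, a single shared KV head.
spec = ctranslate2.specs.TransformerSpec.from_config(
    num_layers=6,
    num_heads=8,
    multi_query_attention=True,
)

# num_layers may also be a (num_encoder_layers, num_decoder_layers) pair.
asym_spec = ctranslate2.specs.TransformerSpec.from_config(
    (12, 6), 16, multi_query_attention=True
)
```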
20 changes: 17 additions & 3 deletions src/layers/attention.cc
@@ -437,8 +437,15 @@ namespace ctranslate2 {

if (cached_keys == nullptr || cached_keys->empty()) {
_linear[1](values, fused_proj);
split_heads(fused_proj, 2 * _num_heads, values_padder);
ops::Split(1)(fused_proj, keys_proj, values_proj);

if (_num_heads_kv == 1) {
if (values_padder)
values_padder->add_padding(fused_proj);
ops::Split(2, {_d_head, _d_head})(fused_proj, keys_proj, values_proj);
} else {
split_heads(fused_proj, 2 * _num_heads, values_padder);
ops::Split(1)(fused_proj, keys_proj, values_proj);
}

if (cached_keys != nullptr) {
*cached_keys = std::move(keys_proj);
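On the C++ side, the key/value path now branches on `_num_heads_kv == 1`: instead of reshaping the fused projection into `2 * num_heads` per-head slices (the original `split_heads` path, kept in the `else` branch), the projection is kept as a single shared head and split along the depth axis into a key and a value of `d_head` each. The shape logic, sketched in NumPy rather than on `StorageView` tensors:

```python
import numpy as np

batch, time_k, num_heads, d_head = 2, 7, 8, 64

# Multi-query: fused K/V projection has depth 2 * d_head and is split in two.
fused_mqa = np.zeros((batch, time_k, 2 * d_head))
keys_mqa, values_mqa = np.split(fused_mqa, 2, axis=2)    # each (2, 7, 64)

# Regular multi-head: depth 2 * num_heads * d_head, reshaped into heads first,
# then split into keys and values along the head axis.
fused_mha = np.zeros((batch, time_k, 2 * num_heads * d_head))
heads = fused_mha.reshape(batch, time_k, 2 * num_heads, d_head).transpose(0, 2, 1, 3)
keys_mha, values_mha = np.split(heads, 2, axis=1)        # each (2, 8, 7, 64)

print(keys_mqa.shape, keys_mha.shape)
```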
Expand All @@ -448,7 +455,14 @@ namespace ctranslate2 {

if (queries_proj.dim(1) == 1 && cached_keys)
beam_size = queries_proj.dim(0) / cached_keys->dim(0);
split_heads(queries_proj, _num_heads, queries_padder, beam_size);

if (_num_heads_kv == 1) {
if (queries_padder)
queries_padder->add_padding(queries_proj);
queries_proj.reshape({queries_proj.dim(0) / beam_size, -1, _d_head});
} else {
split_heads(queries_proj, _num_heads, queries_padder, beam_size);
}

} else {

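The query path gets the same treatment: with a single key/value head the queries are not split into per-head slices but reshaped to `(batch, beam * time * num_heads, d_head)`, so one batched matmul against the shared key covers every query head and every beam hypothesis. An illustrative NumPy sketch of the resulting shapes:

```python
import numpy as np

batch, beam, time_q, time_k, num_heads, d_head = 2, 4, 1, 7, 8, 64

queries = np.zeros((batch * beam, time_q, num_heads * d_head))
# Fold beam, time and heads into one axis, mirroring
# queries_proj.reshape({dim(0) / beam_size, -1, d_head}) above.
queries = queries.reshape(batch, -1, d_head)      # (2, 32, 64)

keys = np.zeros((batch, time_k, d_head))          # the single shared key head
scores = queries @ keys.transpose(0, 2, 1)        # (2, 32, 7)
print(scores.shape)
```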
