From dac2f6252745c3f3f9206a24b13fbd35983c76c5 Mon Sep 17 00:00:00 2001
From: Guillaume Klein
Date: Wed, 12 Jul 2023 17:05:57 +0200
Subject: [PATCH] Support multi-query attention for encoder-decoder models (#1339)

* Support multi-query attention for encoder-decoder models

* Fix default value when multiquery option is missing
---
 python/ctranslate2/converters/opennmt_py.py  |  2 ++
 python/ctranslate2/specs/transformer_spec.py | 22 +++++++++++++++++++-
 src/layers/attention.cc                      | 20 +++++++++++++++---
 3 files changed, 40 insertions(+), 4 deletions(-)

diff --git a/python/ctranslate2/converters/opennmt_py.py b/python/ctranslate2/converters/opennmt_py.py
index 951b05f19..8143977e0 100644
--- a/python/ctranslate2/converters/opennmt_py.py
+++ b/python/ctranslate2/converters/opennmt_py.py
@@ -82,6 +82,7 @@ def _get_model_spec_seq2seq(
         alignment_heads=alignment_heads,
         num_source_embeddings=num_source_embeddings,
         embeddings_merge=_SUPPORTED_FEATURES_MERGE[feat_merge],
+        multi_query_attention=getattr(opt, "multiquery", False),
     )
 
     model_spec.config.decoder_start_token = getattr(opt, "decoder_start_token", "")
@@ -115,6 +116,7 @@ def _get_model_spec_lm(opt, variables, src_vocabs, tgt_vocabs, num_source_embedd
         rms_norm=opt.layer_norm == "rms",
         rotary_dim=rotary_dim,
         rotary_interleave=True,
+        multi_query_attention=getattr(opt, "multiquery", False),
     )
 
     set_transformer_decoder(
diff --git a/python/ctranslate2/specs/transformer_spec.py b/python/ctranslate2/specs/transformer_spec.py
index f1d6ec017..242854af2 100644
--- a/python/ctranslate2/specs/transformer_spec.py
+++ b/python/ctranslate2/specs/transformer_spec.py
@@ -22,6 +22,7 @@ def __init__(
         relative_attention_bias: bool = False,
         ffn_glu: bool = False,
         rms_norm: bool = False,
+        multi_query_attention: bool = False,
     ):
         """Initializes a Transformer encoder specification.
 
@@ -42,6 +43,7 @@ def __init__(
             ffn_glu: Use gated linear units in the FFN layers as described in
               https://arxiv.org/abs/2002.05202.
             rms_norm: Use the root mean square layer normalization.
+            multi_query_attention: Use multi-query attention.
         """
         self.num_heads = np.dtype("int16").type(num_heads)
         self.pre_norm = pre_norm
@@ -63,6 +65,7 @@ def __init__(
                 relative_attention_bias=relative_attention_bias,
                 ffn_glu=ffn_glu,
                 rms_norm=rms_norm,
+                num_heads_kv=1 if multi_query_attention else None,
             )
             for _ in range(num_layers)
         ]
@@ -141,6 +144,12 @@ def __init__(
                 )
             num_heads_kv = 1
 
+        if with_encoder_attention and num_heads_kv not in (None, 1, num_heads):
+            raise ValueError(
+                "num_heads_kv=%d is not supported in the cross-attention layers"
+                % num_heads_kv
+            )
+
         self.num_heads = np.dtype("int16").type(num_heads)
         self.pre_norm = pre_norm
         self.activation = np.dtype("int8").type(activation)
@@ -192,12 +201,14 @@ def __init__(
         relative_attention_bias=False,
         ffn_glu=False,
         rms_norm=False,
+        num_heads_kv=None,
     ):
         self.self_attention = attention_spec.MultiHeadAttentionSpec(
             self_attention=True,
             relative_position=relative_position,
             relative_attention_bias=relative_attention_bias,
             rms_norm=rms_norm,
+            num_heads_kv=num_heads_kv,
         )
         self.ffn = FeedForwardSpec(glu=ffn_glu, rms_norm=rms_norm)
 
@@ -225,8 +236,13 @@ def __init__(
             rotary_interleave=rotary_interleave,
             num_heads_kv=num_heads_kv,
         )
+
         if with_encoder_attention:
-            self.attention = attention_spec.MultiHeadAttentionSpec(rms_norm=rms_norm)
+            self.attention = attention_spec.MultiHeadAttentionSpec(
+                rms_norm=rms_norm,
+                num_heads_kv=num_heads_kv,
+            )
+
         self.ffn = FeedForwardSpec(glu=ffn_glu, rms_norm=rms_norm)
 
         if parallel_residual:
@@ -309,6 +325,7 @@ def from_config(
         relative_attention_bias: bool = False,
         ffn_glu: bool = False,
         rms_norm: bool = False,
+        multi_query_attention: bool = False,
     ):
         """Creates a Transformer model specification.
 
@@ -332,6 +349,7 @@ def from_config(
             ffn_glu: Use gated linear units in the FFN layer as described in
               https://arxiv.org/abs/2002.05202.
             rms_norm: Use the root mean square layer normalization.
+            multi_query_attention: Use multi-query attention.
         """
         if isinstance(num_layers, (list, tuple)):
             num_encoder_layers, num_decoder_layers = num_layers
@@ -351,6 +369,7 @@ def from_config(
             relative_attention_bias=relative_attention_bias,
             ffn_glu=ffn_glu,
             rms_norm=rms_norm,
+            multi_query_attention=multi_query_attention,
         )
 
         decoder = TransformerDecoderSpec(
@@ -366,6 +385,7 @@ def from_config(
             alignment_heads=alignment_heads,
             ffn_glu=ffn_glu,
             rms_norm=rms_norm,
+            multi_query_attention=multi_query_attention,
         )
 
         return cls(encoder, decoder)
diff --git a/src/layers/attention.cc b/src/layers/attention.cc
index c06a69b4f..4b057b535 100644
--- a/src/layers/attention.cc
+++ b/src/layers/attention.cc
@@ -437,8 +437,15 @@ namespace ctranslate2 {
 
         if (cached_keys == nullptr || cached_keys->empty()) {
           _linear[1](values, fused_proj);
-          split_heads(fused_proj, 2 * _num_heads, values_padder);
-          ops::Split(1)(fused_proj, keys_proj, values_proj);
+
+          if (_num_heads_kv == 1) {
+            if (values_padder)
+              values_padder->add_padding(fused_proj);
+            ops::Split(2, {_d_head, _d_head})(fused_proj, keys_proj, values_proj);
+          } else {
+            split_heads(fused_proj, 2 * _num_heads, values_padder);
+            ops::Split(1)(fused_proj, keys_proj, values_proj);
+          }
 
           if (cached_keys != nullptr) {
             *cached_keys = std::move(keys_proj);
@@ -448,7 +455,14 @@
 
         if (queries_proj.dim(1) == 1 && cached_keys)
           beam_size = queries_proj.dim(0) / cached_keys->dim(0);
-        split_heads(queries_proj, _num_heads, queries_padder, beam_size);
+
+        if (_num_heads_kv == 1) {
+          if (queries_padder)
+            queries_padder->add_padding(queries_proj);
+          queries_proj.reshape({queries_proj.dim(0) / beam_size, -1, _d_head});
+        } else {
+          split_heads(queries_proj, _num_heads, queries_padder, beam_size);
+        }
 
       } else {