diff --git a/include/ctranslate2/layers/attention.h b/include/ctranslate2/layers/attention.h index b9331b484..f4209a3db 100644 --- a/include/ctranslate2/layers/attention.h +++ b/include/ctranslate2/layers/attention.h @@ -30,7 +30,7 @@ namespace ctranslate2 { Alibi* alibi = nullptr); DataType output_type() const override; dim_t output_size() const override; - void operator()(const StorageView& queries, + virtual void operator()(const StorageView& queries, const StorageView& values, const StorageView* values_lengths, const StorageView* values_offsets, diff --git a/include/ctranslate2/layers/attention_layer.h b/include/ctranslate2/layers/attention_layer.h index 8c43559f0..6ad9cffa9 100644 --- a/include/ctranslate2/layers/attention_layer.h +++ b/include/ctranslate2/layers/attention_layer.h @@ -29,6 +29,7 @@ namespace ctranslate2 { virtual void operator()(const StorageView& queries, const StorageView& values, const StorageView* values_lengths, + const StorageView* values_offsets, StorageView& output, StorageView* cached_keys = nullptr, StorageView* cached_values = nullptr, diff --git a/include/ctranslate2/layers/flash_attention.h b/include/ctranslate2/layers/flash_attention.h index 3aa1db68f..b1812f2bc 100644 --- a/include/ctranslate2/layers/flash_attention.h +++ b/include/ctranslate2/layers/flash_attention.h @@ -21,6 +21,7 @@ namespace ctranslate2 { void operator()(const StorageView& queries, const StorageView& values, const StorageView* values_lengths, + const StorageView* values_offsets, StorageView& output, StorageView* cached_keys = nullptr, StorageView* cached_values = nullptr, diff --git a/src/layers/transformer.cc b/src/layers/transformer.cc index a1b0cd0c5..cc5b6512e 100644 --- a/src/layers/transformer.cc +++ b/src/layers/transformer.cc @@ -165,6 +165,7 @@ namespace ctranslate2 { (*_self_attention)(hidden, hidden, input_length, + input_offsets, context, cached_self_attn_keys, cached_self_attn_values, @@ -576,6 +577,8 @@ namespace ctranslate2 { std::unique_ptr left_padding_self_attn; + bool multi_query = _layers.front()->get_self_attention().multi_query(); + if (lengths) { if (allow_padding_removal) { input_padder = std::make_unique(*lengths, left_padding, max_time); @@ -592,7 +595,7 @@ namespace ctranslate2 { num_heads = SAFE_DIVIDE(num_heads, ScopedMPISetter::getNRanks()); } - StorageView lengths_mask = layers::MultiHeadAttention::prepare_values_mask( + StorageView lengths_mask = layers::MultiHeadAttention::prepare_length_mask( *lengths, num_heads, max_time, diff --git a/src/models/whisper.cc b/src/models/whisper.cc index ada22fef4..3c5f1a7cd 100644 --- a/src/models/whisper.cc +++ b/src/models/whisper.cc @@ -536,7 +536,7 @@ namespace ctranslate2 { std::vector(num_frames.begin(), num_frames.end()), device); const StorageView frame_sizes_mask( - layers::MultiHeadAttention::prepare_values_mask(frame_sizes, + layers::MultiHeadAttention::prepare_length_mask(frame_sizes, attention_weights.dim(1), attention_weights.dim(2))); diff --git a/tests/ops_test.cc b/tests/ops_test.cc index 35ec8598a..2af7e2cd8 100644 --- a/tests/ops_test.cc +++ b/tests/ops_test.cc @@ -697,7 +697,7 @@ TEST_P(OpDeviceFPTest, MaskedSoftMaxTriangular) { 0.8421174, 0.9135181, 0.77135813 }, device); StorageView lengths({2}, std::vector{3, 2}, device); - StorageView mask = layers::MultiHeadAttention::prepare_values_mask(lengths, 2, 3, true); + StorageView mask = layers::MultiHeadAttention::prepare_length_mask(lengths, 2, 3, true); StorageView expected({2, 2, 3, 3}, std::vector{ 1, 0, 0, 0.28861094, 0.71138906, 0,