Commit df5491d
..
MahmoudAshraf97 committed Sep 18, 2024
1 parent 62eaf33
Showing 6 changed files with 9 additions and 4 deletions.

2 changes: 1 addition & 1 deletion include/ctranslate2/layers/attention.h

@@ -30,7 +30,7 @@ namespace ctranslate2 {
                         Alibi* alibi = nullptr);
      DataType output_type() const override;
      dim_t output_size() const override;
-     void operator()(const StorageView& queries,
+     virtual void operator()(const StorageView& queries,
                      const StorageView& values,
                      const StorageView* values_lengths,
                      const StorageView* values_offsets,
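
Why the one-word change above matters: declaring operator() virtual lets code that holds the attention layer through a base pointer or reference dispatch to whichever concrete implementation is installed (e.g. the flash-attention variant below). A minimal sketch of that dispatch pattern, using illustrative names rather than the real CTranslate2 types:

#include <iostream>
#include <memory>

struct Attention {
  // Without "virtual", a call through a base pointer would always run the
  // base implementation, even when a derived layer is installed.
  virtual void operator()(const float* queries) const {
    std::cout << "standard attention\n";
  }
  virtual ~Attention() = default;
};

struct FlashAttention : Attention {
  void operator()(const float* queries) const override {
    std::cout << "flash attention\n";
  }
};

int main() {
  std::unique_ptr<Attention> layer = std::make_unique<FlashAttention>();
  float queries[4] = {0};
  (*layer)(queries);  // prints "flash attention" only because operator() is virtual
}
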
1 change: 1 addition & 0 deletions include/ctranslate2/layers/attention_layer.h

@@ -29,6 +29,7 @@ namespace ctranslate2 {
      virtual void operator()(const StorageView& queries,
                              const StorageView& values,
                              const StorageView* values_lengths,
+                             const StorageView* values_offsets,
                              StorageView& output,
                              StorageView* cached_keys = nullptr,
                              StorageView* cached_values = nullptr,
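
The new values_offsets parameter travels alongside values_lengths through the whole attention interface. The diff itself does not say what it holds; judging by the name, one plausible reading is a per-sequence start position in a flattened (padding-removed) batch, i.e. an exclusive prefix sum of the lengths. A self-contained sketch of that relationship, illustrative only and not a CTranslate2 API:

#include <cstdint>
#include <vector>

std::vector<int32_t> lengths_to_offsets(const std::vector<int32_t>& lengths) {
  std::vector<int32_t> offsets(lengths.size());
  int32_t running = 0;
  for (size_t i = 0; i < lengths.size(); ++i) {
    offsets[i] = running;   // where sequence i starts in the flattened batch
    running += lengths[i];
  }
  return offsets;
}

// Example: lengths {3, 2} give offsets {0, 3}: sequence 0 starts at position 0
// and sequence 1 at position 3 of the concatenated time dimension.
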
1 change: 1 addition & 0 deletions include/ctranslate2/layers/flash_attention.h

@@ -21,6 +21,7 @@ namespace ctranslate2 {
      void operator()(const StorageView& queries,
                      const StorageView& values,
                      const StorageView* values_lengths,
+                     const StorageView* values_offsets,
                      StorageView& output,
                      StorageView* cached_keys = nullptr,
                      StorageView* cached_values = nullptr,
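
The flash-attention header picks up the same parameter, keeping its operator() signature in lockstep with the base declaration in attention_layer.h. The practical benefit is that override turns signature drift into a compile error instead of a silent new overload; a tiny illustration with hypothetical types:

struct Base {
  virtual void run(const int* lengths, const int* offsets) {}
  virtual ~Base() = default;
};

struct Flash : Base {
  // void run(const int* lengths) override;  // error: does not override
  void run(const int* lengths, const int* offsets) override {}  // matches exactly
};
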
5 changes: 4 additions & 1 deletion src/layers/transformer.cc

@@ -165,6 +165,7 @@ namespace ctranslate2 {
      (*_self_attention)(hidden,
                         hidden,
                         input_length,
+                        input_offsets,
                         context,
                         cached_self_attn_keys,
                         cached_self_attn_values,

@@ -576,6 +577,8 @@

      std::unique_ptr<StorageView> left_padding_self_attn;

+     bool multi_query = _layers.front()->get_self_attention().multi_query();
+
      if (lengths) {
        if (allow_padding_removal) {
          input_padder = std::make_unique<Padder>(*lengths, left_padding, max_time);

@@ -592,7 +595,7 @@
        num_heads = SAFE_DIVIDE(num_heads, ScopedMPISetter::getNRanks());
      }

-     StorageView lengths_mask = layers::MultiHeadAttention::prepare_values_mask(
+     StorageView lengths_mask = layers::MultiHeadAttention::prepare_length_mask(
          *lengths,
          num_heads,
          max_time,
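
The hoisted multi_query flag suggests the mask prepared below depends on whether the layers use multi-query attention, where one key/value head is shared by all query heads (the prepare_length_mask call is truncated here, so this is an inference from the names). A hedged sketch of the defining head-count consequence; kv_head_count is an illustrative helper, not library code:

#include <cstdint>

int32_t kv_head_count(int32_t num_heads, bool multi_query) {
  return multi_query ? 1 : num_heads;  // MQA: one shared K/V head
}

// Example: a K/V cache shaped [batch, kv_heads, time, depth] holds
// kv_head_count(8, true) == 1 head under multi-query attention instead of 8.
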
2 changes: 1 addition & 1 deletion src/models/whisper.cc

@@ -536,7 +536,7 @@ namespace ctranslate2 {
                                std::vector<int32_t>(num_frames.begin(), num_frames.end()),
                                device);
          const StorageView frame_sizes_mask(
-           layers::MultiHeadAttention::prepare_values_mask(frame_sizes,
+           layers::MultiHeadAttention::prepare_length_mask(frame_sizes,
                                                            attention_weights.dim(1),
                                                            attention_weights.dim(2)));
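
Here the mask built from per-sample frame counts is shaped to match the attention-weight tensor (dims 1 and 2 above). One plausible way such a mask is consumed, sketched with plain vectors instead of StorageView and not taken from the Whisper model code, is to invalidate weights past each sample's frame count before any normalization:

#include <cstdint>
#include <limits>
#include <vector>

// weights: [batch, heads, frames] flattened row-major.
void mask_frames(std::vector<float>& weights,
                 const std::vector<int32_t>& frame_sizes,
                 int32_t heads, int32_t frames) {
  for (size_t b = 0; b < frame_sizes.size(); ++b)
    for (int32_t h = 0; h < heads; ++h)
      for (int32_t f = frame_sizes[b]; f < frames; ++f)  // past the valid frames
        weights[(b * heads + h) * frames + f] = -std::numeric_limits<float>::infinity();
}
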
2 changes: 1 addition & 1 deletion tests/ops_test.cc

@@ -697,7 +697,7 @@ TEST_P(OpDeviceFPTest, MaskedSoftMaxTriangular) {
      0.8421174, 0.9135181, 0.77135813
    }, device);
    StorageView lengths({2}, std::vector<int32_t>{3, 2}, device);
-   StorageView mask = layers::MultiHeadAttention::prepare_values_mask(lengths, 2, 3, true);
+   StorageView mask = layers::MultiHeadAttention::prepare_length_mask(lengths, 2, 3, true);
    StorageView expected({2, 2, 3, 3}, std::vector<float>{
      1, 0, 0,
      0.28861094, 0.71138906, 0,
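
The test pins down the semantics: prepare_length_mask(lengths, 2, 3, true) with lengths {3, 2} produces, for each of the 2 batch items and 2 heads, a 3x3 causal mask additionally clipped to the sequence length, so the expected tensor has shape {2, 2, 3, 3} ([batch, heads, queries, keys]) and row t of the masked softmax normalizes over the first min(t + 1, length) key positions; e.g. row 1 of the first item normalizes two logits, giving [0.28861094, 0.71138906, 0], which sums to 1. A standalone sketch of that mask rule, inferred from the test rather than copied from the library:

#include <cstdint>
#include <vector>

std::vector<int8_t> make_length_mask(const std::vector<int32_t>& lengths,
                                     int32_t num_heads, int32_t max_time,
                                     bool triangular) {
  const size_t batch = lengths.size();
  std::vector<int8_t> mask(batch * num_heads * max_time * max_time, 0);
  for (size_t b = 0; b < batch; ++b)
    for (int32_t h = 0; h < num_heads; ++h)
      for (int32_t t = 0; t < max_time; ++t)      // query position
        for (int32_t s = 0; s < max_time; ++s) {  // key position
          const bool in_length = s < lengths[b];
          const bool causal_ok = !triangular || s <= t;
          if (in_length && causal_ok)
            mask[((b * num_heads + h) * max_time + t) * max_time + s] = 1;
        }
  return mask;
}
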
