diff --git a/CMakeLists.txt b/CMakeLists.txt
index 62fc33640..d672a4f95 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -171,6 +171,8 @@ set(SOURCES
   src/ops/mul.cc
   src/ops/multinomial.cc
   src/ops/multinomial_cpu.cc
+  src/ops/position_encodings_add.cc
+  src/ops/position_encodings_add_cpu.cc
   src/ops/quantize.cc
   src/ops/quantize_cpu.cc
   src/ops/relu.cc
@@ -569,6 +571,7 @@ if (WITH_CUDA)
     src/ops/layer_norm_gpu.cu
     src/ops/mean_gpu.cu
     src/ops/multinomial_gpu.cu
+    src/ops/position_encodings_add_gpu.cu
     src/ops/rms_norm_gpu.cu
     src/ops/rotary_gpu.cu
     src/ops/softmax_gpu.cu
diff --git a/include/ctranslate2/layers/attention.h b/include/ctranslate2/layers/attention.h
index 5778a028c..f4209a3db 100644
--- a/include/ctranslate2/layers/attention.h
+++ b/include/ctranslate2/layers/attention.h
@@ -33,6 +33,7 @@ namespace ctranslate2 {
       virtual void operator()(const StorageView& queries,
                               const StorageView& values,
                               const StorageView* values_lengths,
+                              const StorageView* values_offsets,
                               StorageView& output,
                               StorageView* cached_keys = nullptr,
                               StorageView* cached_values = nullptr,
diff --git a/include/ctranslate2/layers/attention_layer.h b/include/ctranslate2/layers/attention_layer.h
index e55ecc5de..8c43559f0 100644
--- a/include/ctranslate2/layers/attention_layer.h
+++ b/include/ctranslate2/layers/attention_layer.h
@@ -49,7 +49,10 @@ namespace ctranslate2 {
                                             const dim_t num_heads,
                                             const dim_t num_queries,
                                             const bool mask_future = false,
-                                            const bool multi_query = false);
+                                            const bool multi_query = false,
+                                            const dim_t step = 0,
+                                            const StorageView* offsets = nullptr,
+                                            StorageView* values_offsets = nullptr);

    protected:
      const bool _tensor_parallel;
diff --git a/include/ctranslate2/layers/common.h b/include/ctranslate2/layers/common.h
index 137b926d3..d56ceb7c2 100644
--- a/include/ctranslate2/layers/common.h
+++ b/include/ctranslate2/layers/common.h
@@ -89,10 +89,14 @@ namespace ctranslate2 {
     // Base class for position encoders.
     class PositionEncoder : public Layer {
     public:
-      void operator()(StorageView& input, dim_t index = 0);
-      void operator()(const StorageView& input, StorageView& output, dim_t index = 0);
+      void operator()(const StorageView& input,
+                      StorageView& output,
+                      dim_t step = 0,
+                      const StorageView* offsets = nullptr);
     protected:
       virtual const StorageView& get_position_encoding(dim_t max_time) = 0;
+    private:
+      ops::PositionEncodingsAdd _add_op;
     };

     // Concrete position encoder loading encoding vectors from the model.
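The new op added above computes, per batch entry b and timestep t, output[b, t, :] = input[b, t, :] + encodings[step + t - offsets[b], :], so left-padded positions (negative row index) receive no encoding. A minimal sketch of these semantics; position_encodings_add_ref is an illustrative helper, not code from the patch:

```cpp
#include <cstdint>
#include <vector>

// Reference semantics of PositionEncodingsAdd (a sketch, not the library kernel).
void position_encodings_add_ref(const std::vector<float>& input,      // [batch, time, depth]
                                const std::vector<float>& encodings,  // [max_time, depth]
                                std::vector<float>& output,           // [batch, time, depth]
                                const std::vector<int32_t>& offsets,  // [batch]; empty = no padding
                                int64_t batch, int64_t time, int64_t depth, int64_t step) {
  output = input;
  for (int64_t b = 0; b < batch; ++b) {
    for (int64_t t = 0; t < time; ++t) {
      const int64_t position = step + t - (offsets.empty() ? 0 : offsets[b]);
      if (position < 0)
        continue;  // Left-padding row: no encoding is added (the real kernels skip it too).
      for (int64_t d = 0; d < depth; ++d)
        output[(b * time + t) * depth + d] += encodings[position * depth + d];
    }
  }
}
```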
diff --git a/include/ctranslate2/layers/transformer.h b/include/ctranslate2/layers/transformer.h
index 017f9b675..9fc0ba4b8 100644
--- a/include/ctranslate2/layers/transformer.h
+++ b/include/ctranslate2/layers/transformer.h
@@ -84,6 +84,7 @@ namespace ctranslate2 {

       void operator()(const StorageView& input,
                       const StorageView* input_lengths,
+                      const StorageView* input_offsets,
                       const StorageView* memory,
                       const StorageView* memory_lengths,
                       StorageView* cached_self_attn_keys,
diff --git a/include/ctranslate2/models/whisper.h b/include/ctranslate2/models/whisper.h
index e9818cc4e..f88fbd710 100644
--- a/include/ctranslate2/models/whisper.h
+++ b/include/ctranslate2/models/whisper.h
@@ -126,7 +126,7 @@ namespace ctranslate2 {

       std::vector<WhisperGenerationResult>
       generate(StorageView features,
-               const std::vector<std::vector<size_t>>& prompts,
+               std::vector<std::vector<size_t>> prompts,
                const WhisperOptions& options);

       std::vector<std::vector<std::pair<std::string, float>>>
diff --git a/include/ctranslate2/ops/ops.h b/include/ctranslate2/ops/ops.h
index 2a735e394..39c26206d 100644
--- a/include/ctranslate2/ops/ops.h
+++ b/include/ctranslate2/ops/ops.h
@@ -44,3 +44,4 @@
 #include "awq/gemv.h"
 #include "awq/dequantize_awq.h"
 #include "sum.h"
+#include "position_encodings_add.h"
diff --git a/include/ctranslate2/ops/position_encodings_add.h b/include/ctranslate2/ops/position_encodings_add.h
new file mode 100644
index 000000000..de255690f
--- /dev/null
+++ b/include/ctranslate2/ops/position_encodings_add.h
@@ -0,0 +1,26 @@
+#pragma once
+
+#include "op.h"
+
+namespace ctranslate2 {
+  namespace ops {
+
+    class PositionEncodingsAdd : public Op {
+    public:
+      void operator()(const StorageView& input,
+                      const StorageView& encodings,
+                      StorageView& output,
+                      const StorageView* offsets = nullptr,
+                      const dim_t step = 0) const;
+
+    private:
+      template <Device D, typename T>
+      void compute(const dim_t step,
+                   const StorageView* offsets,
+                   const StorageView& input,
+                   const StorageView& encodings,
+                   StorageView& output) const;
+    };
+
+  }
+}
diff --git a/include/ctranslate2/ops/softmax.h b/include/ctranslate2/ops/softmax.h
index e9b1c08d1..c515677e6 100644
--- a/include/ctranslate2/ops/softmax.h
+++ b/include/ctranslate2/ops/softmax.h
@@ -13,11 +13,21 @@ namespace ctranslate2 {
       void operator()(StorageView& x) const;
       void operator()(const StorageView& x, StorageView& y) const override;
       void operator()(const StorageView& x, const StorageView& lengths, StorageView& y) const;
-      void operator()(const StorageView& x, const StorageView* lengths, StorageView& y) const;
+      void operator()(const StorageView& x,
+                      const StorageView& lengths,
+                      const StorageView& offsets,
+                      StorageView& y) const;
+      void operator()(const StorageView& x,
+                      const StorageView* lengths,
+                      const StorageView* offsets,
+                      StorageView& y) const;

     private:
       template <Device D, typename T>
-      void compute(const StorageView& input, const StorageView* lengths, StorageView& output) const;
+      void compute(const StorageView& input,
+                   const StorageView* lengths,
+                   const StorageView* offsets,
+                   StorageView& output) const;

       bool _log;
     };
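The extended SoftMax normalizes each row over the window [offsets[i], offsets[i] + lengths[i]) and writes exact zeros everywhere outside it. A minimal reference sketch of that contract, assuming row-major [batch_size, depth] float input; masked_softmax_ref is illustrative, not the library kernel:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>

void masked_softmax_ref(const float* x, const int32_t* lengths, const int32_t* offsets,
                        float* y, int64_t batch_size, int64_t depth) {
  for (int64_t i = 0; i < batch_size; ++i) {
    const int64_t start = offsets ? offsets[i] : 0;
    const int64_t size = lengths ? lengths[i] : depth - start;
    const float* row_x = x + i * depth;
    float* row_y = y + i * depth;
    for (int64_t j = 0; j < depth; ++j)
      row_y[j] = 0.f;  // Out-of-window positions are exactly 0 in the output.
    if (size <= 0)
      continue;
    float max = row_x[start];
    for (int64_t j = start; j < start + size; ++j)
      max = std::max(max, row_x[j]);
    float sum = 0.f;
    for (int64_t j = start; j < start + size; ++j)
      sum += (row_y[j] = std::exp(row_x[j] - max));
    for (int64_t j = start; j < start + size; ++j)
      row_y[j] /= sum;
  }
}
```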
diff --git a/include/ctranslate2/padder.h b/include/ctranslate2/padder.h
index ebcc02baf..e19f8508e 100644
--- a/include/ctranslate2/padder.h
+++ b/include/ctranslate2/padder.h
@@ -19,6 +19,11 @@ namespace ctranslate2 {
            const dim_t max_time = -1,
            const dim_t pad_batch_to_multiple = 1);

+    Padder(const StorageView& lengths,
+           const StorageView* offsets,
+           const dim_t max_time = -1,
+           const dim_t pad_batch_to_multiple = 1);
+
     // Merge batch and time dimensions and remove padding.
     void remove_padding(StorageView& x) const;

@@ -26,6 +31,11 @@ namespace ctranslate2 {
     void add_padding(StorageView& x) const;

   private:
+    void initialize(const StorageView& lengths,
+                    const StorageView* offsets,
+                    const dim_t max_time,
+                    const dim_t pad_batch_to_multiple);
+
     dim_t _batch_size;
     dim_t _max_time;
     StorageView _padded_to_flat;
diff --git a/include/ctranslate2/primitives.h b/include/ctranslate2/primitives.h
index 571121554..a292b941f 100644
--- a/include/ctranslate2/primitives.h
+++ b/include/ctranslate2/primitives.h
@@ -142,13 +142,16 @@ namespace ctranslate2 {
                                      dim_t length,
                                      dim_t vocabulary_size);

-    static void prepare_length_mask(const int32_t* lengths,
-                                    dim_t batch_size,
-                                    dim_t num_heads,
-                                    dim_t num_queries,
-                                    bool mask_future,
-                                    bool multi_query,
-                                    int32_t* mask);
+    static void prepare_mha_values_mask(const int32_t* lengths,
+                                        const int32_t* offsets,
+                                        dim_t batch_size,
+                                        dim_t num_heads,
+                                        dim_t num_queries,
+                                        bool mask_future,
+                                        bool multi_query,
+                                        dim_t step,
+                                        int32_t* values_lengths,
+                                        int32_t* values_offsets);

    template <typename T>
    static void transpose_2d(const T* a, const dim_t* dims, T* b);
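prepare_mha_values_mask turns per-batch lengths and left-padding offsets into one (values_offset, values_length) pair per attention row, which the windowed softmax above consumes. A sketch of the per-row arithmetic, where q is the query's index within its batch entry; mha_values_mask_entry is a hypothetical helper mirroring the primitive:

```cpp
#include <algorithm>
#include <cstdint>

void mha_values_mask_entry(int32_t length, int32_t offset, int32_t step, int32_t q,
                           bool mask_future, int32_t& values_length, int32_t& values_offset) {
  const int32_t window = length + step - offset;  // Keys available after left padding.
  if (mask_future) {
    const int32_t time = step + q;                // Absolute position of this query.
    values_length = time < offset ? 0 : std::min(time - offset + 1, window);
  } else {
    values_length = window;
  }
  values_offset = offset;  // Attention always starts after the left padding.
}
```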
diff --git a/python/tests/test_transformers.py b/python/tests/test_transformers.py
index 1fed8196d..3abf2c2b0 100644
--- a/python/tests/test_transformers.py
+++ b/python/tests/test_transformers.py
@@ -751,6 +751,32 @@ def teardown_class(cls):
                 pytest.approx(0.062380101531744, abs=1e-3),
             ],
         ),
+        (
+            "openai/whisper-tiny.en",
+            [
+                ["<|startoftranscript|>"],
+                [
+                    "<|startofprev|>",
+                    "ĠAnd",
+                    "Ġthen",
+                    "Ġthe",
+                    "ĠPresident",
+                    "Ġshouted",
+                    ":",
+                    "<|startoftranscript|>",
+                ],
+            ],
+            [
+                " Mr. Quilter is the apostle of the middle classes, and we are glad"
+                " to welcome his gospel.",
+                " And so my fellow Americans ask not what your country can do for you,"
+                " ask what you can do for your country.",
+            ],
+            [
+                pytest.approx(0.02644546702504158, abs=1e-4),
+                pytest.approx(0.008309835568070412, abs=1e-3),
+            ],
+        ),
     ],
 )
 def test_transformers_whisper(
diff --git a/src/cpu/kernels.cc b/src/cpu/kernels.cc
index c1f48553d..bc7ea7c5a 100644
--- a/src/cpu/kernels.cc
+++ b/src/cpu/kernels.cc
@@ -402,6 +402,7 @@ namespace ctranslate2 {
     template<>
     void softmax<TARGET_ISA>(const float* input,
                              const int32_t* lengths,
+                             const int32_t* offsets,
                              float* output,
                              dim_t batch_size,
                              dim_t depth,
@@ -410,24 +411,29 @@ namespace ctranslate2 {
       parallel_for(0, batch_size, 1, [&](dim_t begin, dim_t end) {
         for (dim_t i = begin; i < end; ++i) {
+          const dim_t start = offsets ? offsets[i] : 0;
+          const dim_t size = lengths ? lengths[i] : depth - start;
+
           const dim_t offset = i * depth;
           const float* x = input + offset;
           float* y = output + offset;

-          dim_t size = depth;
-          if (lengths) {
-            size = lengths[i];
+          // Directly set 0 in output for out of range positions.

-            // Directly set 0 in output for out of range positions.
-            for (dim_t j = size; j < depth; ++j) {
+          if (size <= 0) {
+            for (dim_t j = 0; j < depth; ++j)
               y[j] = 0;
-            }
-
-            if (size == 0) {
-              continue;
-            }
+            continue;
           }

+          for (dim_t j = 0; j < start; ++j)
+            y[j] = 0;
+          for (dim_t j = start + size; j < depth; ++j)
+            y[j] = 0;
+
+          x += start;
+          y += start;
+
           const auto x_max = reduce_max(x, size);
           const auto vec_x_max = VecType::load(x_max);
diff --git a/src/cpu/kernels.h b/src/cpu/kernels.h
index 16296fc36..94fccecc2 100644
--- a/src/cpu/kernels.h
+++ b/src/cpu/kernels.h
@@ -67,6 +67,7 @@ namespace ctranslate2 {
     template <CpuIsa ISA>
     void softmax(const float* input,
                  const int32_t* lengths,
+                 const int32_t* offsets,
                  float* output,
                  dim_t batch_size,
                  dim_t depth,
diff --git a/src/cpu/primitives.cc b/src/cpu/primitives.cc
index 5e0fd2999..4c7f7906f 100644
--- a/src/cpu/primitives.cc
+++ b/src/cpu/primitives.cc
@@ -414,21 +414,31 @@ namespace ctranslate2 {
     }

     template<>
-    void primitives<Device::CPU>::prepare_length_mask(const int32_t* lengths,
-                                                      dim_t batch_size,
-                                                      dim_t num_heads,
-                                                      dim_t num_queries,
-                                                      bool mask_future,
-                                                      bool multi_query,
-                                                      int32_t* mask) {
+    void primitives<Device::CPU>::prepare_mha_values_mask(const int32_t* lengths,
+                                                          const int32_t* offsets,
+                                                          dim_t batch_size,
+                                                          dim_t num_heads,
+                                                          dim_t num_queries,
+                                                          bool mask_future,
+                                                          bool multi_query,
+                                                          dim_t step,
+                                                          int32_t* values_lengths,
+                                                          int32_t* values_offsets) {
       for (dim_t b = 0; b < batch_size; ++b) {
-        const auto length = lengths[b];
-        auto* batch_mask = mask + b * num_heads * num_queries;
-        for (dim_t i = 0; i < num_heads * num_queries; ++i) {
-          batch_mask[i] = (mask_future
-                           ? std::min(length,
-                                      int32_t((multi_query ? i / num_heads : i % num_queries) + 1))
-                           : length);
+        const auto offset = offsets ? offsets[b] : 0;
+        const auto length = lengths[b] + int32_t(step) - offset;
+        const auto batch_offset = b * num_heads * num_queries;
+
+        for (dim_t i = batch_offset; i < batch_offset + num_heads * num_queries; ++i) {
+          if (mask_future) {
+            const int32_t time = step + (multi_query ? i / num_heads : i % num_queries);
+            values_lengths[i] = time < offset ? 0 : std::min(time - offset + 1, length);
+          } else {
+            values_lengths[i] = length;
+          }
+
+          if (values_offsets)
+            values_offsets[i] = offset;
         }
       }
     }
diff --git a/src/cuda/primitives.cu b/src/cuda/primitives.cu
index 9915bb12c..9b30499ff 100644
--- a/src/cuda/primitives.cu
+++ b/src/cuda/primitives.cu
@@ -283,36 +283,58 @@ namespace ctranslate2 {
                                  vocabulary_size);
     }

-    __global__ void prepare_length_mask_kernel(const int32_t* lengths,
-                                               cuda::index_t num_heads,
-                                               cuda::index_t num_queries,
-                                               bool mask_future,
-                                               bool multi_query,
-                                               int32_t* mask) {
-      const auto length = lengths[blockIdx.x];
-      mask += blockIdx.x * num_heads * num_queries;
-      for (cuda::index_t i = threadIdx.x; i < num_heads * num_queries; i += blockDim.x)
-        mask[i] = (mask_future
-                   ? min(length, int32_t((multi_query ? i / num_heads : i % num_queries) + 1))
-                   : length);
-    }
-
-    template<>
-    void primitives<Device::CUDA>::prepare_length_mask(const int32_t* lengths,
-                                                       dim_t batch_size,
-                                                       dim_t num_heads,
-                                                       dim_t num_queries,
-                                                       bool mask_future,
-                                                       bool multi_query,
-                                                       int32_t* mask) {
+    __global__ void prepare_mha_values_mask_kernel(const int32_t* lengths,
+                                                   const int32_t* offsets,
+                                                   cuda::index_t num_heads,
+                                                   cuda::index_t num_queries,
+                                                   bool mask_future,
+                                                   bool multi_query,
+                                                   cuda::index_t step,
+                                                   int32_t* values_lengths,
+                                                   int32_t* values_offsets) {
+      const auto offset = offsets ? offsets[blockIdx.x] : 0;
+      const auto length = lengths[blockIdx.x] + step - offset;
+      const auto batch_offset = blockIdx.x * num_heads * num_queries;
+
+      for (cuda::index_t i = batch_offset + threadIdx.x;
+           i < batch_offset + num_heads * num_queries;
+           i += blockDim.x) {
+
+        if (mask_future) {
+          const int32_t time = step + (multi_query ? i / num_heads : i % num_queries);
+          values_lengths[i] = time < offset ? 0 : min(time - offset + 1, length);
+        } else {
+          values_lengths[i] = length;
+        }
+
+        if (values_offsets)
+          values_offsets[i] = offset;
+      }
+    }
+
+    template<>
+    void primitives<Device::CUDA>::prepare_mha_values_mask(const int32_t* lengths,
+                                                           const int32_t* offsets,
+                                                           dim_t batch_size,
+                                                           dim_t num_heads,
+                                                           dim_t num_queries,
+                                                           bool mask_future,
+                                                           bool multi_query,
+                                                           dim_t step,
+                                                           int32_t* values_lengths,
+                                                           int32_t* values_offsets) {
       const dim_t blocks = std::min(batch_size, cuda::max_blocks);
       const dim_t threads = std::min(num_heads * num_queries, cuda::max_threads);
-      prepare_length_mask_kernel<<<blocks, threads, 0, cuda::get_cuda_stream()>>>(lengths,
-                                                                                  num_heads,
-                                                                                  num_queries,
-                                                                                  mask_future,
-                                                                                  multi_query,
-                                                                                  mask);
+      prepare_mha_values_mask_kernel<<<blocks, threads, 0, cuda::get_cuda_stream()>>>(
+        lengths,
+        offsets,
+        num_heads,
+        num_queries,
+        mask_future,
+        multi_query,
+        step,
+        values_lengths,
+        values_offsets);
     }

     template<>
diff --git a/src/layers/attention.cc b/src/layers/attention.cc
index 6ad344410..0632be8b0 100644
--- a/src/layers/attention.cc
+++ b/src/layers/attention.cc
@@ -181,6 +181,7 @@ namespace ctranslate2 {
                                   const StorageView& keys,
                                   const StorageView& values,
                                   const StorageView* values_lengths,
+                                  const StorageView* values_offsets,
                                   const StorageView* relative_position_keys,
                                   const StorageView* relative_asymmetric_position_keys,
                                   const StorageView* relative_position_values,
@@ -267,7 +268,7 @@ namespace ctranslate2 {
         alibi->apply(output, queries_scale);

       StorageView attn(values.dtype(), values.device());
-      ops::SoftMax()(output, values_lengths, attn);
+      ops::SoftMax()(output, values_lengths, values_offsets, attn);

       if (attention && !return_normalized_attention)
         save_attention(*attention, std::move(output), beam_size);
@@ -337,6 +338,7 @@ namespace ctranslate2 {
     void MultiHeadAttention::operator()(const StorageView& queries,
                                         const StorageView& values,
                                         const StorageView* values_lengths,
+                                        const StorageView* values_offsets,
                                         StorageView& output,
                                         StorageView* cached_keys,
                                         StorageView* cached_values,
@@ -472,6 +474,7 @@ namespace ctranslate2 {
                      keys_proj,
                      values_proj,
                      values_lengths,
+                     values_offsets,
                      _relative_position_keys,
                      _relative_asymmetric_position_keys,
                      _relative_position_values,
diff --git a/src/layers/attention_layer.cc b/src/layers/attention_layer.cc
index c9ae67409..eeeb9b182 100644
--- a/src/layers/attention_layer.cc
+++ b/src/layers/attention_layer.cc
@@ -153,7 +153,10 @@ namespace ctranslate2 {
                                                   const dim_t num_heads,
                                                   const dim_t num_queries,
                                                   const bool mask_future,
-                                                  const bool multi_query) {
+                                                  const bool multi_query,
+                                                  const dim_t step,
+                                                  const StorageView* offsets,
+                                                  StorageView* values_offsets) {
       const Device device = lengths.device();
       const dim_t batch_size = lengths.size();
       StorageView mask(lengths.dtype(), device);
@@ -163,13 +166,26 @@ namespace ctranslate2 {
       else
         mask.resize({batch_size, num_heads, num_queries});

-      DEVICE_DISPATCH(device, (primitives<D>::prepare_length_mask(lengths.data<int32_t>(),
-                                                                  batch_size,
-                                                                  num_heads,
-                                                                  num_queries,
-                                                                  mask_future,
-                                                                  multi_query,
-                                                                  mask.data<int32_t>())));
+      if (offsets) {
+        if (!values_offsets)
+          throw std::runtime_error("Missing values_offsets output");
+        values_offsets->resize_as(mask);
+      }
+
+      DEVICE_DISPATCH(
+        device,
+        (primitives<D>::prepare_mha_values_mask(
+          lengths.data<int32_t>(),
+          offsets ? offsets->data<int32_t>() : nullptr,
+          batch_size,
+          num_heads,
+          num_queries,
+          mask_future,
+          multi_query,
+          step,
+          mask.data<int32_t>(),
+          values_offsets ? values_offsets->data<int32_t>() : nullptr)));
+
       return mask;
     }
diff --git a/src/layers/common.cc b/src/layers/common.cc
index c6d1cd0b5..3484e6c9c 100644
--- a/src/layers/common.cc
+++ b/src/layers/common.cc
@@ -4,7 +4,6 @@
 #include "ctranslate2/ops/activation.h"

 #include "cpu/backend.h"
-#include "dispatch.h"

 namespace ctranslate2 {
   namespace layers {
@@ -148,34 +147,13 @@ namespace ctranslate2 {
     }

-    void PositionEncoder::operator()(StorageView& input, dim_t index) {
+    void PositionEncoder::operator()(const StorageView& input,
+                                     StorageView& output,
+                                     dim_t step,
+                                     const StorageView* offsets) {
       const dim_t time = input.dim(1);
-      const dim_t depth = input.dim(-1);
-      const dim_t max_time = time + index;
-      const StorageView& encodings = get_position_encoding(max_time);
-      const dim_t num_encodings = encodings.dim(0);
-
-      if (max_time > num_encodings)
-        throw std::runtime_error("No position encodings are defined for positions >= "
-                                 + std::to_string(num_encodings)
-                                 + ", but got position "
-                                 + std::to_string(max_time - 1));
-      if (depth != encodings.dim(1))
-        throw std::invalid_argument("Shape mismatch: position encodings have depth "
-                                    + std::to_string(encodings.dim(1))
-                                    + ", but the input has depth "
-                                    + std::to_string(depth));
-
-      DEVICE_AND_TYPE_DISPATCH(input.device(), input.dtype(),
-                               primitives<D>::add_batch_broadcast(encodings.data<T>() + index * depth,
-                                                                  input.data<T>(),
-                                                                  time * depth,
-                                                                  input.size()));
-    }
-
-    void PositionEncoder::operator()(const StorageView& input, StorageView& output, dim_t index) {
-      output = input;
-      operator()(output, index);
+      const StorageView& encodings = get_position_encoding(step + time);
+      _add_op(input, encodings, output, offsets, step);
     }
diff --git a/src/layers/transformer.cc b/src/layers/transformer.cc
index 5ac5bfa36..a1b0cd0c5 100644
--- a/src/layers/transformer.cc
+++ b/src/layers/transformer.cc
@@ -84,6 +84,7 @@ namespace ctranslate2 {
       (*_self_attention)(input,
                          input,
                          lengths,
+                         nullptr,
                          context,
                          nullptr,
                          nullptr,
@@ -135,6 +136,7 @@ namespace ctranslate2 {

     void TransformerDecoderLayer::operator()(const StorageView& input,
                                              const StorageView* input_length,
+                                             const StorageView* input_offsets,
                                              const StorageView* memory,
                                              const StorageView* memory_lengths,
                                              StorageView* cached_self_attn_keys,
@@ -204,6 +206,7 @@ namespace ctranslate2 {
         (*_self_attention)(hidden,
                            hidden,
                            input_length,
+                           input_offsets,
                            attn,
                            cached_self_attn_keys,
                            cached_self_attn_values,
@@ -228,6 +231,7 @@ namespace ctranslate2 {
       (*_self_attention)(input,
                         input,
                         input_length,
+                        input_offsets,
                         output,
                         cached_self_attn_keys,
                         cached_self_attn_values,
@@ -243,6 +247,7 @@ namespace ctranslate2 {
         (*_encoder_attention)(output,
                               *memory,
                               memory_lengths,
+                              nullptr,
                               context,
                               cached_attn_keys,
                               cached_attn_values,
@@ -326,7 +331,7 @@ namespace ctranslate2 {
       if (_embeddings_scale)
         ops::Mul()(input, *_embeddings_scale, input);

       if (_position_encoder)
-        (*_position_encoder)(input);
+        (*_position_encoder)(input, input);

       if (_layernorm_embedding)
         (*_layernorm_embedding)(input, input);
@@ -522,6 +527,15 @@ namespace ctranslate2 {
       const Device device = ids.device();
       const bool is_sequence = ids.rank() > 1;

+      const StorageView* left_padding = nullptr;
+
+      {
+        const auto it = state.find("offsets");
+        if (it != state.end()) {
+          left_padding = &(it->second);
+        }
+      }
+
       StorageView layer_in(dtype, device);
       StorageView layer_out(dtype, device);
@@ -537,7 +551,7 @@ namespace ctranslate2 {
       if (layer_in.rank() == 2)
         layer_in.expand_dims(1);

       if (_position_encoder)
-        (*_position_encoder)(layer_in, std::max(step, dim_t(0)));
+        (*_position_encoder)(layer_in, layer_in, std::max(step, dim_t(0)), left_padding);

       if (_layernorm_embedding)
         (*_layernorm_embedding)(layer_in, layer_in);
@@ -555,34 +569,38 @@ namespace ctranslate2 {
       std::unique_ptr<StorageView> input_lengths;
       std::unique_ptr<StorageView> input_lengths_mask;

-      if (is_sequence && !lengths) {
+      if ((is_sequence || left_padding) && !lengths) {
         input_lengths = std::make_unique<StorageView>(Shape{ids.dim(0)}, int32_t(max_time), device);
         lengths = input_lengths.get();
       }

       bool multi_query = _layers.front()->get_self_attention().multi_query();

+      std::unique_ptr<StorageView> left_padding_self_attn;
+
       if (lengths) {
         if (allow_padding_removal) {
-          input_padder = std::make_unique<Padder>(*lengths, max_time);
+          input_padder = std::make_unique<Padder>(*lengths, left_padding, max_time);
           input_padder->remove_padding(layer_in);
         }

+        if (left_padding) {
+          left_padding_self_attn = std::make_unique<StorageView>(left_padding->dtype(),
+                                                                 left_padding->device());
+        }
+
         dim_t num_heads = _num_heads;
         if (_tensor_parallel) {
           num_heads = SAFE_DIVIDE(num_heads, ScopedMPISetter::getNRanks());
         }

-        StorageView lengths_mask = layers::MultiHeadAttention::prepare_length_mask(
+        StorageView lengths_mask = layers::MultiHeadAttention::prepare_values_mask(
           *lengths,
           num_heads,
           max_time,
           /*mask_future=*/true,
-          multi_query);
-
-        if (step > 0)
-          ops::Add()(lengths_mask, StorageView(int32_t(step)), lengths_mask);
+          multi_query,
+          std::max(step, dim_t(0)),
+          left_padding,
+          left_padding_self_attn.get());

         input_lengths_mask = std::make_unique<StorageView>(std::move(lengths_mask));
       }
@@ -685,6 +703,7 @@ namespace ctranslate2 {

         (*_layers[l])(*layer_in_chunk,
                       input_lengths_mask.get(),
+                      left_padding_self_attn.get(),
                       memory,
                       memory_lengths_mask.get(),
                       cached_self_attn_keys,
diff --git a/src/layers/whisper.cc b/src/layers/whisper.cc
index aa5dece77..d74373b23 100644
--- a/src/layers/whisper.cc
+++ b/src/layers/whisper.cc
@@ -50,7 +50,7 @@ namespace ctranslate2 {
       _gelu(output, output);
       _transpose(output, input);

-      _position_embedding(input);
+      _position_embedding(input, input);

       for (const auto& layer : _layers) {
         (*layer)(input, nullptr, output);
diff --git a/src/models/whisper.cc b/src/models/whisper.cc
index 7cdf2dc5b..ada22fef4 100644
--- a/src/models/whisper.cc
+++ b/src/models/whisper.cc
@@ -150,40 +150,61 @@ namespace ctranslate2 {

     static size_t get_prompt_length(const std::vector<size_t>& prompt,
                                     const size_t sot_id,
-                                    const size_t no_timestamps_id) {
-      size_t index = get_sot_index(prompt, sot_id);
+                                    const size_t no_timestamps_id,
+                                    const size_t* sot_index = nullptr) {
+      size_t index = sot_index ? *sot_index : get_sot_index(prompt, sot_id);
       while (index < prompt.size() && prompt[index] >= sot_id && prompt[index] <= no_timestamps_id)
         index++;
       return index;
     }

-    static void check_prompts(const std::vector<std::vector<size_t>>& prompts,
+    // Add left padding to align prompts on the <|startoftranscript|> token.
+    static void align_prompts(std::vector<std::vector<size_t>>& prompts,
                               const size_t sot_id,
+                              const size_t pad_id,
                               const size_t no_timestamps_id,
+                              std::vector<int32_t>& offsets,
                               size_t& sot_index,
                               size_t& prompt_length) {
-      bool first = true;
+      const size_t batch_size = prompts.size();

-      for (const auto& prompt : prompts) {
-        const auto batch_sot_index = get_sot_index(prompt, sot_id);
-        const auto batch_prompt_length = get_prompt_length(prompt, sot_id, no_timestamps_id);
+      std::vector<size_t> sot_indices;
+      sot_indices.reserve(batch_size);
+      sot_index = 0;

-        if (first) {
-          sot_index = batch_sot_index;
+      for (size_t i = 0; i < batch_size; ++i) {
+        const auto batch_sot_index = get_sot_index(prompts[i], sot_id);
+        sot_indices.push_back(batch_sot_index);
+        sot_index = std::max(sot_index, batch_sot_index);
+      }
+
+      bool no_padding = true;
+      offsets.reserve(batch_size);
+
+      for (size_t i = 0; i < batch_size; ++i) {
+        const auto offset = sot_index - sot_indices[i];
+        offsets.push_back(offset);
+
+        if (offset > 0) {
+          prompts[i].insert(prompts[i].begin(), offset, pad_id);
+          no_padding = false;
+        }
+
+        const auto batch_prompt_length = get_prompt_length(prompts[i],
+                                                           sot_id,
+                                                           no_timestamps_id,
+                                                           &sot_index);
+
+        if (i == 0) {
           prompt_length = batch_prompt_length;
-        } else if (batch_sot_index != sot_index) {
-          throw std::invalid_argument("The generate method currently requires the "
-                                      "<|startoftranscript|> token to be at the same position "
-                                      "in all batches. To work around this limitation, "
-                                      "simply adapt the number of previous text tokens in each "
-                                      "batch.");
         } else if (batch_prompt_length != prompt_length) {
           throw std::invalid_argument("The generate method currently requires each batch to have "
                                       "the same number of task tokens after <|startoftranscript|>.");
         }
-
-        first = false;
       }
+
+      if (no_padding)
+        offsets.clear();
     }

     class ApplyTimestampRules;

@@ -228,7 +249,7 @@ namespace ctranslate2 {
     std::vector<WhisperGenerationResult>
     WhisperReplica::generate(StorageView features,
-                             const std::vector<std::vector<size_t>>& prompts,
+                             std::vector<std::vector<size_t>> prompts,
                              const WhisperOptions& options) {
       PROFILE("WhisperReplica::generate");
       if (prompts.empty())
@@ -238,9 +259,10 @@ namespace ctranslate2 {
       const cuda::UseTrueFp16GemmInScope use_true_fp16_gemm(false);
 #endif

+      std::vector<int32_t> offsets;
       size_t sot_index = 0;
       size_t prompt_length = 0;  // Length of the prompt before the text tokens.
-      check_prompts(prompts, _sot_id, _no_timestamps_id, sot_index, prompt_length);
+      align_prompts(prompts, _sot_id, _eot_id, _no_timestamps_id, offsets, sot_index, prompt_length);

       const auto& vocabulary = _model->get_vocabulary();
       const auto scoped_device_setter = _model->get_scoped_device_setter();
@@ -248,6 +270,9 @@ namespace ctranslate2 {
       layers::DecoderState state = _decoder->initial_state();
       state.emplace("memory", maybe_encode(std::move(features)));

+      if (!offsets.empty())
+        state.emplace("offsets", StorageView({dim_t(offsets.size())}, offsets, _decoder->device()));
+
       _decoder->update_output_layer(_model->preferred_size_multiple());

       const bool sot_is_start_token = (sot_index == prompt_length - 1);
@@ -511,7 +536,7 @@ namespace ctranslate2 {
                                          std::vector<int32_t>(num_frames.begin(), num_frames.end()),
                                          device);
       const StorageView frame_sizes_mask(
-        layers::MultiHeadAttention::prepare_length_mask(frame_sizes,
+        layers::MultiHeadAttention::prepare_values_mask(frame_sizes,
                                                         attention_weights.dim(1),
                                                         attention_weights.dim(2)));
@@ -685,7 +710,7 @@ namespace ctranslate2 {
              prompts = std::move(prompts),
              options = std::move(options)]
             (WhisperReplica& replica) mutable {
-              return replica.generate(std::move(features), prompts, options);
+              return replica.generate(std::move(features), std::move(prompts), options);
             },
             batch_size);
     }
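align_prompts left-pads every prompt so that <|startoftranscript|> sits at the same index across the batch (the patch passes _eot_id as the padding token) and records each pad count in offsets. A standalone sketch of just the alignment step; align_on_token is a hypothetical helper, not the member function above:

```cpp
// prompt 0: [sot]               sot at index 0        after: [pad, pad, pad, sot]  offset = 3
// prompt 1: [prev, a, b, sot]   sot at index 3        after: [prev, a, b, sot]     offset = 0
#include <algorithm>
#include <cstdint>
#include <vector>

void align_on_token(std::vector<std::vector<size_t>>& prompts, size_t sot_id, size_t pad_id,
                    std::vector<int32_t>& offsets) {
  size_t max_index = 0;
  std::vector<size_t> indices;
  for (const auto& prompt : prompts) {
    indices.push_back(std::find(prompt.begin(), prompt.end(), sot_id) - prompt.begin());
    max_index = std::max(max_index, indices.back());
  }
  for (size_t i = 0; i < prompts.size(); ++i) {
    const size_t offset = max_index - indices[i];
    prompts[i].insert(prompts[i].begin(), offset, pad_id);  // Left padding only.
    offsets.push_back(int32_t(offset));
  }
}
```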
diff --git a/src/ops/position_encodings_add.cc b/src/ops/position_encodings_add.cc
new file mode 100644
index 000000000..280f3f05f
--- /dev/null
+++ b/src/ops/position_encodings_add.cc
@@ -0,0 +1,50 @@
+#include "ctranslate2/ops/position_encodings_add.h"
+
+#include "dispatch.h"
+
+namespace ctranslate2 {
+  namespace ops {
+
+    void PositionEncodingsAdd::operator()(const StorageView& input,
+                                          const StorageView& encodings,
+                                          StorageView& output,
+                                          const StorageView* offsets,
+                                          const dim_t step) const {
+      PROFILE("PositionEncodingsAdd");
+
+      const dim_t time = input.dim(1);
+      const dim_t depth = input.dim(2);
+      const dim_t max_time = time + step;
+
+      if (max_time > encodings.dim(0))
+        throw std::runtime_error("No position encodings are defined for positions >= "
+                                 + std::to_string(encodings.dim(0))
+                                 + ", but got position "
+                                 + std::to_string(max_time - 1));
+
+      if (depth != encodings.dim(1))
+        throw std::invalid_argument("Shape mismatch: position encodings have depth "
+                                    + std::to_string(encodings.dim(1))
+                                    + ", but the input has depth "
+                                    + std::to_string(depth));
+
+      output.resize_as(input);
+
+      if (offsets) {
+        DEVICE_AND_FLOAT_DISPATCH(
+          "PositionEncodingsAdd", input.device(), input.dtype(),
+          (compute<D, T>(step, offsets, input, encodings, output)));
+
+      } else {
+        DEVICE_AND_FLOAT_DISPATCH(
+          "PositionEncodingsAdd", input.device(), input.dtype(),
+          (primitives<D>::add_batch_broadcast(encodings.data<T>() + step * depth,
+                                              input.data<T>(),
+                                              output.data<T>(),
+                                              time * depth,
+                                              input.size())));
+      }
+    }
+
+  }
+}
diff --git a/src/ops/position_encodings_add_cpu.cc b/src/ops/position_encodings_add_cpu.cc
new file mode 100644
index 000000000..ef03b6dde
--- /dev/null
+++ b/src/ops/position_encodings_add_cpu.cc
@@ -0,0 +1,48 @@
+#include "ctranslate2/ops/position_encodings_add.h"
+
+#include "cpu/parallel.h"
+
+namespace ctranslate2 {
+  namespace ops {
+
+    template <Device D, typename T>
+    void PositionEncodingsAdd::compute(const dim_t step,
+                                       const StorageView* offsets,
+                                       const StorageView& input,
+                                       const StorageView& encodings,
+                                       StorageView& output) const {
+      const dim_t batch_size = input.dim(0);
+      const dim_t time = input.dim(1);
+      const dim_t depth = input.dim(2);
+
+      cpu::parallel_for(0, batch_size * time, 1, [&](const dim_t begin, const dim_t end) {
+        for (dim_t i = begin; i < end; ++i) {
+          const dim_t b = i / time;
+          const dim_t t = i % time;
+
+          const dim_t offset = offsets ? offsets->at<int32_t>(b) : 0;
+          const dim_t encoding_offset = t - offset + step;
+
+          if (encoding_offset < 0)
+            continue;
+
+          primitives<D>::add(encodings.index<T>({encoding_offset, 0}),
+                             input.index<T>({b, t, 0}),
+                             output.index<T>({b, t, 0}),
+                             depth);
+        }
+      });
+    }
+
+#define DECLARE_IMPL(T)                                                 \
+    template void                                                       \
+    PositionEncodingsAdd::compute<Device::CPU, T>(const dim_t,          \
+                                                  const StorageView*,   \
+                                                  const StorageView&,   \
+                                                  const StorageView&,   \
+                                                  StorageView&) const;
+
+    DECLARE_IMPL(float)
+
+  }
+}
diff --git a/src/ops/position_encodings_add_gpu.cu b/src/ops/position_encodings_add_gpu.cu
new file mode 100644
index 000000000..e5ad79194
--- /dev/null
+++ b/src/ops/position_encodings_add_gpu.cu
@@ -0,0 +1,73 @@
+#include "ctranslate2/ops/position_encodings_add.h"
+
+#include "type_dispatch.h"
+#include "cuda/helpers.h"
+
+namespace ctranslate2 {
+  namespace ops {
+
+    template <typename T, typename AddFunc>
+    __global__ void position_encodings_add_kernel(const T* input,
+                                                  const T* encodings,
+                                                  T* output,
+                                                  const int32_t* offsets,
+                                                  cuda::index_t step,
+                                                  cuda::index_t max_time,
+                                                  cuda::index_t depth,
+                                                  const AddFunc add_func) {
+      const cuda::index_t batch = blockIdx.x / max_time;
+      const cuda::index_t time = blockIdx.x % max_time;
+
+      const int32_t offset = offsets ? offsets[batch] : 0;
+      const int32_t encoding_offset = time - offset + step;
+
+      if (encoding_offset < 0)
+        return;
+
+      input += blockIdx.x * depth;
+      output += blockIdx.x * depth;
+      encodings += encoding_offset * depth;
+
+      for (cuda::index_t i = threadIdx.x; i < depth; i += blockDim.x) {
+        output[i] = add_func(input[i], encodings[i]);
+      }
+    }
+
+    template <Device D, typename T>
+    void PositionEncodingsAdd::compute(const dim_t step,
+                                       const StorageView* offsets,
+                                       const StorageView& input,
+                                       const StorageView& encodings,
+                                       StorageView& output) const {
+      const dim_t batch_size = input.dim(0);
+      const dim_t time = input.dim(1);
+      const dim_t depth = input.dim(2);
+
+      const dim_t blocks = std::min(batch_size * time, cuda::max_blocks);
+      const dim_t threads = std::min(depth, cuda::max_threads);
+
+      position_encodings_add_kernel<<<blocks, threads, 0, cuda::get_cuda_stream()>>>(
+        cuda::device_cast(input.data<T>()),
+        cuda::device_cast(encodings.data<T>()),
+        cuda::device_cast(output.data<T>()),
+        offsets ? offsets->data<int32_t>() : nullptr,
+        step,
+        time,
+        depth,
+        cuda::plus<cuda::device_type<T>>());
+    }
+
+#define DECLARE_IMPL(T)                                                 \
+    template void                                                       \
+    PositionEncodingsAdd::compute<Device::CUDA, T>(const dim_t,         \
+                                                   const StorageView*,  \
+                                                   const StorageView&,  \
+                                                   const StorageView&,  \
+                                                   StorageView&) const;
+
+    DECLARE_IMPL(float)
+    DECLARE_IMPL(float16_t)
+    DECLARE_IMPL(bfloat16_t)
+
+  }
+}
diff --git a/src/ops/softmax.cc b/src/ops/softmax.cc
index 638e02ae5..eb4052f1e 100644
--- a/src/ops/softmax.cc
+++ b/src/ops/softmax.cc
@@ -14,36 +14,51 @@ namespace ctranslate2 {
     }

     void SoftMax::operator()(StorageView& x) const {
-      operator()(x, nullptr, x);
+      operator()(x, nullptr, nullptr, x);
     }

     void SoftMax::operator()(const StorageView& x, StorageView& y) const {
-      operator()(x, nullptr, y);
+      operator()(x, nullptr, nullptr, y);
     }

     void SoftMax::operator()(const StorageView& x, const StorageView& lengths, StorageView& y) const {
-      operator()(x, &lengths, y);
+      operator()(x, &lengths, nullptr, y);
     }

-    void SoftMax::operator()(const StorageView& x, const StorageView* lengths, StorageView& y) const {
+    void SoftMax::operator()(const StorageView& x,
+                             const StorageView& lengths,
+                             const StorageView& offsets,
+                             StorageView& y) const {
+      operator()(x, &lengths, &offsets, y);
+    }
+
+    void SoftMax::operator()(const StorageView& x,
+                             const StorageView* lengths,
+                             const StorageView* offsets,
+                             StorageView& y) const {
       PROFILE(_log ? "LogSoftMax" : "SoftMax");
       y.resize_as(x);

       const dim_t depth = x.dim(-1);
+      const dim_t batch_size = x.size() / depth;

       if (depth == 0)
         return;

-      if (lengths) {
-        const dim_t batch_size = x.size() / depth;
-        if (lengths->size() != batch_size)
-          throw std::invalid_argument("Length mask has size "
-                                      + std::to_string(lengths->size())
-                                      + " which is different than the current batch size "
-                                      + std::to_string(batch_size));
-      }
+      if (lengths && lengths->size() != batch_size)
+        throw std::invalid_argument("Length mask has size "
+                                    + std::to_string(lengths->size())
+                                    + " which is different than the current batch size "
+                                    + std::to_string(batch_size));
+
+      if (offsets && offsets->size() != batch_size)
+        throw std::invalid_argument("Offsets input has size "
+                                    + std::to_string(offsets->size())
+                                    + " which is different than the current batch size "
+                                    + std::to_string(batch_size));

-      DEVICE_AND_FLOAT_DISPATCH("SoftMax", x.device(), x.dtype(), (compute<D, T>(x, lengths, y)));
+      DEVICE_AND_FLOAT_DISPATCH("SoftMax", x.device(), x.dtype(),
+                                (compute<D, T>(x, lengths, offsets, y)));
     }

   }
diff --git a/src/ops/softmax_cpu.cc b/src/ops/softmax_cpu.cc
index be51630b7..77f3d20b2 100644
--- a/src/ops/softmax_cpu.cc
+++ b/src/ops/softmax_cpu.cc
@@ -8,12 +8,14 @@ namespace ctranslate2 {
     template <Device D, typename T>
     void SoftMax::compute(const StorageView& input,
                           const StorageView* lengths,
+                          const StorageView* offsets,
                           StorageView& output) const {
       const dim_t depth = input.dim(-1);
       const dim_t batch_size = input.size() / depth;
       CPU_ISA_DISPATCH((cpu::softmax<ISA>(input.data<float>(),
                                           lengths ? lengths->data<int32_t>() : nullptr,
+                                          offsets ? offsets->data<int32_t>() : nullptr,
                                           output.data<float>(),
                                           batch_size,
                                           depth,
@@ -24,6 +26,7 @@ namespace ctranslate2 {
   template void                                                 \
   SoftMax::compute<Device::CPU, T>(const StorageView& input,    \
                                    const StorageView* lengths,  \
+                                   const StorageView* offsets,  \
                                    StorageView& output) const;

   DECLARE_IMPL(float)
diff --git a/src/ops/softmax_gpu.cu b/src/ops/softmax_gpu.cu
index abee00f71..34e570bcc 100644
--- a/src/ops/softmax_gpu.cu
+++ b/src/ops/softmax_gpu.cu
@@ -13,11 +13,13 @@ namespace ctranslate2 {
                         const dim_t rows,
                         const dim_t cols,
                         const int32_t* lengths,
+                        const int32_t* offsets,
                         T* y);

     template <Device D, typename T>
     void SoftMax::compute(const StorageView& input,
                           const StorageView* lengths,
+                          const StorageView* offsets,
                           StorageView& output) const {
       const dim_t depth = input.dim(-1);
       const dim_t batch_size = input.size() / depth;
@@ -27,6 +29,7 @@ namespace ctranslate2 {
                      batch_size,
                      depth,
                      lengths ? lengths->data<int32_t>() : nullptr,
+                     offsets ? offsets->data<int32_t>() : nullptr,
                      output.data<float>());
     }

@@ -34,6 +37,7 @@ namespace ctranslate2 {
   template void                                                  \
   SoftMax::compute<Device::CUDA, T>(const StorageView& input,    \
                                     const StorageView* lengths,  \
+                                    const StorageView* offsets,  \
                                     StorageView& output) const;

   DECLARE_IMPL(float)
@@ -197,7 +201,8 @@ namespace at {
     cunn_SoftMaxForward(outscalar_t *output,
                         const scalar_t *input,
                         const index_t classes,
-                        const length_t *lengths)
+                        const length_t *lengths,
+                        const length_t *offsets)
     {
       extern __shared__ unsigned char smem[];
       auto sdata = reinterpret_cast<accscalar_t*>(smem);
@@ -207,15 +212,27 @@ namespace at {
       input += row * classes;
       output += row * classes;

-      index_t size = classes;
-      if (lengths)
-      {
-        // Directly set 0 in output for out of range positions.
-        size = lengths[row];
-        for (index_t i = size + threadIdx.x; i < classes; i += blockDim.x)
+      const index_t start = offsets ? offsets[row] : 0;
+      const index_t size = lengths ? lengths[row] : classes - start;
+      const index_t end = start + size;
+
+      if (size <= 0) {
+        for (index_t i = threadIdx.x; i < classes; i += blockDim.x)
           output[i] = 0.f;
+        return;
+      }
+
+      if (start > 0 || end < classes) {
+        // Directly set 0 in output for out of range positions.
+        for (index_t i = threadIdx.x; i < classes; i += blockDim.x) {
+          if (i < start || i >= end)
+            output[i] = 0.f;
+        }
       }

+      input += start;
+      output += start;
+
       // find the max
       accscalar_t threadMax = ctranslate2::cuda::ilp_reduce(
         input, size, MaxFloat<scalar_t, accscalar_t>(), -max_float);
@@ -245,6 +262,7 @@ namespace ctranslate2 {
                                     const dim_t rows,
                                     const dim_t cols,
                                     const int32_t* lengths,
+                                    const int32_t* offsets,
                                     T* y) {
       const dim3 grid(rows);
       const dim3 block(cuda::get_block_size(cols));
@@ -252,7 +270,8 @@ namespace ctranslate2 {
         <<<grid, block, block.x * sizeof(accscalar_t), stream>>>(y,
                                                                  x,
                                                                  cols,
-                                                                 lengths);
+                                                                 lengths,
+                                                                 offsets);
     }

     template <typename T>
@@ -262,13 +281,14 @@ namespace ctranslate2 {
     void softmax_kernel(cudaStream_t stream,
                         const bool log_softmax,
                         const T* x,
                         const dim_t rows,
                         const dim_t cols,
                         const int32_t* lengths,
+                        const int32_t* offsets,
                         T* y) {
       if (log_softmax)
         softmax_kernel_impl<cuda::device_type<T>, at::native::LogSoftMaxForwardEpilogue>(
-          stream, cuda::device_cast(x), rows, cols, lengths, cuda::device_cast(y));
+          stream, cuda::device_cast(x), rows, cols, lengths, offsets, cuda::device_cast(y));
       else
         softmax_kernel_impl<cuda::device_type<T>, at::native::SoftMaxForwardEpilogue>(
-          stream, cuda::device_cast(x), rows, cols, lengths, cuda::device_cast(y));
+          stream, cuda::device_cast(x), rows, cols, lengths, offsets, cuda::device_cast(y));
     }

   }
diff --git a/src/padder.cc b/src/padder.cc
index 6ce37daf2..4bf1692f6 100644
--- a/src/padder.cc
+++ b/src/padder.cc
@@ -6,9 +6,28 @@ namespace ctranslate2 {

   Padder::Padder(const StorageView& lengths,
                  const dim_t max_time,
-                 const dim_t pad_batch_to_multiple)
-    : _batch_size(lengths.size()) {
-    const std::vector<int32_t> lengths_vec = lengths.to_vector<int32_t>();
+                 const dim_t pad_batch_to_multiple) {
+    initialize(lengths, nullptr, max_time, pad_batch_to_multiple);
+  }
+
+  Padder::Padder(const StorageView& lengths,
+                 const StorageView* offsets,
+                 const dim_t max_time,
+                 const dim_t pad_batch_to_multiple) {
+    initialize(lengths, offsets, max_time, pad_batch_to_multiple);
+  }
+
+  void Padder::initialize(const StorageView& lengths,
+                          const StorageView* offsets,
+                          const dim_t max_time,
+                          const dim_t pad_batch_to_multiple) {
+    _batch_size = lengths.size();
+
+    std::vector<int32_t> lengths_vec = lengths.to_vector<int32_t>();
+    std::vector<int32_t> offsets_vec(_batch_size, 0);
+    if (offsets)
+      offsets_vec = offsets->to_vector<int32_t>();
+
     if (max_time < 0)
       _max_time = *std::max_element(lengths_vec.begin(), lengths_vec.end());
     else
@@ -32,13 +51,22 @@ namespace ctranslate2 {

     for (dim_t i = 0; i < _batch_size; ++i) {
       const dim_t length = lengths_vec[i];
-      for (dim_t t = 0; t < length; ++t) {
+      const dim_t start = offsets_vec[i];
+      const dim_t end = start + length;
+
+      for (dim_t t = 0; t < start; ++t) {
+        flat_to_padded.push_back(flat_offset);
+      }
+
+      for (dim_t t = start; t < end; ++t) {
         padded_to_flat.push_back(padded_offset + t);
-        flat_to_padded.push_back(flat_offset + t);
+        flat_to_padded.push_back(flat_offset + t - start);
       }
-      for (dim_t t = length; t < _max_time; ++t) {
-        flat_to_padded.push_back(flat_offset + length - 1);
+
+      for (dim_t t = end; t < _max_time; ++t) {
+        flat_to_padded.push_back(flat_offset + end - 1);
       }
+
       padded_offset += _max_time;
       flat_offset += length;
     }
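With offsets, the Padder keeps the window [offsets[i], offsets[i] + lengths[i]) of each row when flattening, which is exactly what the PadderOffsets test below exercises. A sketch of remove_padding's gather semantics; remove_padding_ref is illustrative, not the class method:

```cpp
#include <cstdint>
#include <vector>

// For lengths {2, 3, 1}, offsets {2, 1, 3} and max_time 4, the windows
// [2, 4), [1, 4) and [3, 4) select {2, 3, 5, 6, 7, 11} from the test input.
std::vector<int32_t> remove_padding_ref(const std::vector<int32_t>& x,  // [batch x max_time]
                                        const std::vector<int32_t>& lengths,
                                        const std::vector<int32_t>& offsets,
                                        int32_t max_time) {
  std::vector<int32_t> flat;
  for (size_t b = 0; b < lengths.size(); ++b)
    for (int32_t t = offsets[b]; t < offsets[b] + lengths[b]; ++t)
      flat.push_back(x[b * max_time + t]);
  return flat;
}
```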
diff --git a/tests/layers_test.cc b/tests/layers_test.cc
index cbbaa2d72..7b20ba2ae 100644
--- a/tests/layers_test.cc
+++ b/tests/layers_test.cc
@@ -206,6 +206,23 @@ TEST(LayerTest, Padder) {
   expect_storage_eq(x, w_padding);
 }

+TEST(LayerTest, PadderOffsets) {
+  const StorageView lengths({3}, std::vector<int32_t>{2, 3, 1});
+  const StorageView offsets({3}, std::vector<int32_t>{2, 1, 3});
+  const Padder padder(lengths, &offsets, /*max_time=*/4);
+
+  StorageView x({3, 4}, std::vector<float>{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11});
+  const StorageView wo_padding({6}, std::vector<float>{2, 3, 5, 6, 7, 11});
+  const StorageView w_padding({3, 4}, std::vector<float>{2, 2, 2, 3, 5, 5, 6, 7, 11, 11, 11, 11});
+
+  padder.remove_padding(x);
+  ASSERT_EQ(x.rank(), 1);
+  expect_storage_eq(x, wo_padding);
+  padder.add_padding(x);
+  ASSERT_EQ(x.rank(), 2);
+  expect_storage_eq(x, w_padding);
+}
+
 TEST(LayerTest, PadderToMultiple) {
   const StorageView lengths({3}, std::vector<int32_t>{2, 3, 1});
   const Padder padder(lengths, /*max_time=*/4, /*pad_batch_to_multiple=*/8);
@@ -233,6 +250,36 @@ TEST(LayerTest, PadderIgnore) {
   expect_storage_eq(x, original);
 }

+TEST_P(LayerDeviceFPTest, PositionEncoderOffset) {
+  const Device device = GetParam().device;
+  const DataType dtype = GetParam().dtype;
+  const float error = GetParam().error;
+
+  layers::SinusoidalPositionEncoder position_encoder(4, dtype, device);
+
+  StorageView offsets({2}, std::vector<int32_t>{3, 1}, device);
+  dim_t step = 5;
+
+  StorageView expected_encodings(dtype, device);
+
+  {
+    StorageView zero({2, 5, 4}, 0.f, device);
+    StorageView encodings(dtype, device);
+    position_encoder(zero.to(dtype), encodings);
+
+    StorageView position_ids({2, 1}, std::vector<int32_t>{2, 4}, device);
+    ops::Gather(/*axis=*/1, /*batch_dims=*/1)(encodings, position_ids, expected_encodings);
+  }
+
+  {
+    StorageView zero({2, 1, 4}, 0.f, device);
+    StorageView encodings(dtype, device);
+    position_encoder(zero.to(dtype), encodings, step, &offsets);
+
+    expect_storage_eq(encodings.to_float32(), expected_encodings.to_float32(), error);
+  }
+}
+
 TEST(LayerTest, PositionEncoderNoSharedState) {
   // Test case for issue: http://forum.opennmt.net/t/ctranslate2-c-api-returns-strange-results-when-initializing-2-models/3208
   layers::SinusoidalPositionEncoder position_encoder_1(4);
@@ -243,7 +290,7 @@ TEST(LayerTest, PositionEncoderNoSharedState) {
       {1, 1, 4}, std::vector<float>{0.1, -2.3, 0.5, 1.2});
     StorageView expected(
       {1, 1, 4}, std::vector<float>{0.941471, -2.2999, 1.0403, 2.2});
-    position_encoder_1(input);
+    position_encoder_1(input, input);
     expect_storage_eq(input, expected, 1e-5);
   }

@@ -252,7 +299,7 @@ TEST(LayerTest, PositionEncoderNoSharedState) {
       {1, 1, 6}, std::vector<float>{-0.2, -1.3, 0.1, -0.6, 2.0, 1.1});
     StorageView expected(
       {1, 1, 6}, std::vector<float>{0.641471, -1.29, 0.1001, -0.0596977, 2.99995, 2.1});
-    position_encoder_2(input);
+    position_encoder_2(input, input);
     expect_storage_eq(input, expected, 1e-5);
   }
 }
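A quick numeric check of the expected values in the MaskedSoftMaxLeftPadding test below: row 0 keeps the window [1, 4), so the softmax runs over {-0.2, 3.0, 1.2} and positions 0 and 4 stay 0. A standalone sketch of that arithmetic:

```cpp
#include <cmath>
#include <cstdio>

int main() {
  const double x[] = {-0.2, 3.0, 1.2};  // Window values of row 0.
  double sum = 0;
  for (double v : x)
    sum += std::exp(v);
  // Prints ~0.033797 0.829145 0.137056, matching the expected tensor below.
  for (double v : x)
    std::printf("%f ", std::exp(v) / sum);
}
```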
diff --git a/tests/ops_test.cc b/tests/ops_test.cc
index c9369fa67..35ec8598a 100644
--- a/tests/ops_test.cc
+++ b/tests/ops_test.cc
@@ -661,6 +661,23 @@ TEST_P(OpDeviceFPTest, MaskedSoftMax) {
   expect_storage_eq(y.to_float32(), expected, error);
 }

+TEST_P(OpDeviceFPTest, MaskedSoftMaxLeftPadding) {
+  const Device device = GetParam().device;
+  const DataType dtype = GetParam().dtype;
+  const float error = GetParam().error;
+  StorageView x({2, 5}, std::vector<float>{
+      0.0, -0.2, 3.0, 1.2, -1.1,
+      4.6, 3.3, 0.2, -1.6, 1.0}, device);
+  StorageView lengths({2}, std::vector<int32_t>{3, 4}, device);
+  StorageView offsets({2}, std::vector<int32_t>{1, 0}, device);
+  StorageView expected({2, 5}, std::vector<float>{
+      0, 0.033797, 0.829145, 0.137056, 0,
+      0.777098, 0.211783, 0.009540, 0.001577, 0}, device);
+  StorageView y(dtype, device);
+  ops::SoftMax()(x.to(dtype), lengths, offsets, y);
+  expect_storage_eq(y.to_float32(), expected, error);
+}
+
 TEST_P(OpDeviceFPTest, MaskedSoftMaxTriangular) {
   const Device device = GetParam().device;
   const DataType dtype = GetParam().dtype;
   const float error = GetParam().error;
   StorageView x({2, 2, 3, 3}, std::vector<float>{
@@ -680,7 +697,7 @@ TEST_P(OpDeviceFPTest, MaskedSoftMaxTriangular) {
       0.8421174, 0.9135181, 0.77135813
   }, device);
   StorageView lengths({2}, std::vector<int32_t>{3, 2}, device);
-  StorageView mask = layers::MultiHeadAttention::prepare_length_mask(lengths, 2, 3, true);
+  StorageView mask = layers::MultiHeadAttention::prepare_values_mask(lengths, 2, 3, true);
   StorageView expected({2, 2, 3, 3}, std::vector<float>{
       1, 0, 0,
       0.28861094, 0.71138906, 0,