Skip to content

Commit

Permalink
Merge remote-tracking branch 'gk/whisper-batch-prompt'
Browse files Browse the repository at this point in the history
  • Loading branch information
MahmoudAshraf97 committed Sep 18, 2024
2 parents cb16c8e + 9c2c06d commit 25ef314
Show file tree
Hide file tree
Showing 31 changed files with 637 additions and 168 deletions.
3 changes: 3 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,8 @@ set(SOURCES
src/ops/mul.cc
src/ops/multinomial.cc
src/ops/multinomial_cpu.cc
src/ops/position_encodings_add.cc
src/ops/position_encodings_add_cpu.cc
src/ops/quantize.cc
src/ops/quantize_cpu.cc
src/ops/relu.cc
Expand Down Expand Up @@ -569,6 +571,7 @@ if (WITH_CUDA)
src/ops/layer_norm_gpu.cu
src/ops/mean_gpu.cu
src/ops/multinomial_gpu.cu
src/ops/position_encodings_add_gpu.cu
src/ops/rms_norm_gpu.cu
src/ops/rotary_gpu.cu
src/ops/softmax_gpu.cu
Expand Down
1 change: 1 addition & 0 deletions include/ctranslate2/layers/attention.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ namespace ctranslate2 {
virtual void operator()(const StorageView& queries,
const StorageView& values,
const StorageView* values_lengths,
const StorageView* values_offsets,
StorageView& output,
StorageView* cached_keys = nullptr,
StorageView* cached_values = nullptr,
Expand Down
5 changes: 4 additions & 1 deletion include/ctranslate2/layers/attention_layer.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,10 @@ namespace ctranslate2 {
const dim_t num_heads,
const dim_t num_queries,
const bool mask_future = false,
const bool multi_query = false);
const bool multi_query = false,
const dim_t step = 0,
const StorageView* offsets = nullptr,
StorageView* values_offsets = nullptr);

protected:
const bool _tensor_parallel;
Expand Down
8 changes: 6 additions & 2 deletions include/ctranslate2/layers/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,14 @@ namespace ctranslate2 {
// Base class for position encoders.
class PositionEncoder : public Layer {
public:
void operator()(StorageView& input, dim_t index = 0);
void operator()(const StorageView& input, StorageView& output, dim_t index = 0);
void operator()(const StorageView& input,
StorageView& output,
dim_t step = 0,
const StorageView* offsets = nullptr);
protected:
virtual const StorageView& get_position_encoding(dim_t max_time) = 0;
private:
ops::PositionEncodingsAdd _add_op;
};

// Concrete position encoder loading encoding vectors from the model.
Expand Down
1 change: 1 addition & 0 deletions include/ctranslate2/layers/transformer.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ namespace ctranslate2 {

void operator()(const StorageView& input,
const StorageView* input_lengths,
const StorageView* input_offsets,
const StorageView* memory,
const StorageView* memory_lengths,
StorageView* cached_self_attn_keys,
Expand Down
2 changes: 1 addition & 1 deletion include/ctranslate2/models/whisper.h
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ namespace ctranslate2 {

std::vector<WhisperGenerationResult>
generate(StorageView features,
const std::vector<std::vector<size_t>>& prompts,
std::vector<std::vector<size_t>> prompts,
const WhisperOptions& options);

std::vector<std::vector<std::pair<std::string, float>>>
Expand Down
1 change: 1 addition & 0 deletions include/ctranslate2/ops/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,4 @@
#include "awq/gemv.h"
#include "awq/dequantize_awq.h"
#include "sum.h"
#include "position_encodings_add.h"
26 changes: 26 additions & 0 deletions include/ctranslate2/ops/position_encodings_add.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#pragma once

#include "op.h"

namespace ctranslate2 {
namespace ops {

// Operator that combines an input tensor with a table of position encodings
// and writes the result to a separate output tensor.
// NOTE(review): judging by the name, the combination is an element-wise
// addition; the implementation is not visible from this header - confirm.
class PositionEncodingsAdd : public Op {
public:
// Applies the operator.
//   input:     tensor to which the encodings are applied.
//   encodings: position encoding table.
//   output:    destination tensor.
//   offsets:   optional per-example offsets (nullptr when unused).
//              NOTE(review): presumably one start offset per batch entry,
//              used to shift the encoding position - verify against the
//              implementation.
//   step:      decoding step, defaults to 0.
void operator()(const StorageView& input,
const StorageView& encodings,
StorageView& output,
const StorageView* offsets = nullptr,
const dim_t step = 0) const;

private:
// Device- and type-specialized implementation behind operator().
// Note the parameter order differs from operator(): step and offsets come
// first here.
template <Device D, typename T>
void compute(const dim_t step,
const StorageView* offsets,
const StorageView& input,
const StorageView& encodings,
StorageView& output) const;
};

}
}
14 changes: 12 additions & 2 deletions include/ctranslate2/ops/softmax.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,21 @@ namespace ctranslate2 {
void operator()(StorageView& x) const;
void operator()(const StorageView& x, StorageView& y) const override;
void operator()(const StorageView& x, const StorageView& lengths, StorageView& y) const;
void operator()(const StorageView& x, const StorageView* lengths, StorageView& y) const;
void operator()(const StorageView& x,
const StorageView& lengths,
const StorageView& offsets,
StorageView& y) const;
void operator()(const StorageView& x,
const StorageView* lengths,
const StorageView* offsets,
StorageView& y) const;

private:
template <Device D, typename T>
void compute(const StorageView& input, const StorageView* lengths, StorageView& output) const;
void compute(const StorageView& input,
const StorageView* lengths,
const StorageView* offsets,
StorageView& output) const;

bool _log;
};
Expand Down
10 changes: 10 additions & 0 deletions include/ctranslate2/padder.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,23 @@ namespace ctranslate2 {
const dim_t max_time = -1,
const dim_t pad_batch_to_multiple = 1);

Padder(const StorageView& lengths,
const StorageView* offsets,
const dim_t max_time = -1,
const dim_t pad_batch_to_multiple = 1);

// Merge batch and time dimensions and remove padding.
void remove_padding(StorageView& x) const;

// Split first dimension into batch and time dimensions and add padding.
void add_padding(StorageView& x) const;

private:
void initialize(const StorageView& lengths,
const StorageView* offsets,
const dim_t max_time,
const dim_t pad_batch_to_multiple);

dim_t _batch_size;
dim_t _max_time;
StorageView _padded_to_flat;
Expand Down
17 changes: 10 additions & 7 deletions include/ctranslate2/primitives.h
Original file line number Diff line number Diff line change
Expand Up @@ -142,13 +142,16 @@ namespace ctranslate2 {
dim_t length,
dim_t vocabulary_size);

static void prepare_length_mask(const int32_t* lengths,
dim_t batch_size,
dim_t num_heads,
dim_t num_queries,
bool mask_future,
bool multi_query,
int32_t* mask);
static void prepare_mha_values_mask(const int32_t* lengths,
const int32_t* offsets,
dim_t batch_size,
dim_t num_heads,
dim_t num_queries,
bool mask_future,
bool multi_query,
dim_t step,
int32_t* values_lengths,
int32_t* values_offsets);

template <typename T>
static void transpose_2d(const T* a, const dim_t* dims, T* b);
Expand Down
26 changes: 26 additions & 0 deletions python/tests/test_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -751,6 +751,32 @@ def teardown_class(cls):
pytest.approx(0.062380101531744, abs=1e-3),
],
),
(
"openai/whisper-tiny.en",
[
["<|startoftranscript|>"],
[
"<|startofprev|>",
"ĠAnd",
"Ġthen",
"Ġthe",
"ĠPresident",
"Ġshouted",
":",
"<|startoftranscript|>",
],
],
[
" Mr. Quilter is the apostle of the middle classes, and we are glad"
" to welcome his gospel.",
" And so my fellow Americans ask not what your country can do for you,"
" ask what you can do for your country.",
],
[
pytest.approx(0.02644546702504158, abs=1e-4),
pytest.approx(0.008309835568070412, abs=1e-3),
],
),
],
)
def test_transformers_whisper(
Expand Down
26 changes: 16 additions & 10 deletions src/cpu/kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,7 @@ namespace ctranslate2 {
template<>
void softmax<TARGET_ISA>(const float* input,
const int32_t* lengths,
const int32_t* offsets,
float* output,
dim_t batch_size,
dim_t depth,
Expand All @@ -410,24 +411,29 @@ namespace ctranslate2 {

parallel_for(0, batch_size, 1, [&](dim_t begin, dim_t end) {
for (dim_t i = begin; i < end; ++i) {
const dim_t start = offsets ? offsets[i] : 0;
const dim_t size = lengths ? lengths[i] : depth - start;

const dim_t offset = i * depth;
const float* x = input + offset;
float* y = output + offset;

dim_t size = depth;
if (lengths) {
size = lengths[i];
// Directly set 0 in output for out of range positions.

// Directly set 0 in output for out of range positions.
for (dim_t j = size; j < depth; ++j) {
if (size <= 0) {
for (dim_t j = 0; j < depth; ++j)
y[j] = 0;
}

if (size == 0) {
continue;
}
continue;
}

for (dim_t j = 0; j < start; ++j)
y[j] = 0;
for (dim_t j = start + size; j < depth; ++j)
y[j] = 0;

x += start;
y += start;

const auto x_max = reduce_max<TARGET_ISA>(x, size);
const auto vec_x_max = VecType::load(x_max);

Expand Down
1 change: 1 addition & 0 deletions src/cpu/kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ namespace ctranslate2 {
template <CpuIsa ISA>
void softmax(const float* input,
const int32_t* lengths,
const int32_t* offsets,
float* output,
dim_t batch_size,
dim_t depth,
Expand Down
38 changes: 24 additions & 14 deletions src/cpu/primitives.cc
Original file line number Diff line number Diff line change
Expand Up @@ -414,21 +414,31 @@ namespace ctranslate2 {
}

template<>
void primitives<Device::CPU>::prepare_length_mask(const int32_t* lengths,
dim_t batch_size,
dim_t num_heads,
dim_t num_queries,
bool mask_future,
bool multi_query,
int32_t* mask) {
void primitives<Device::CPU>::prepare_mha_values_mask(const int32_t* lengths,
const int32_t* offsets,
dim_t batch_size,
dim_t num_heads,
dim_t num_queries,
bool mask_future,
bool multi_query,
dim_t step,
int32_t* values_lengths,
int32_t* values_offsets) {
for (dim_t b = 0; b < batch_size; ++b) {
const auto length = lengths[b];
auto* batch_mask = mask + b * num_heads * num_queries;
for (dim_t i = 0; i < num_heads * num_queries; ++i) {
batch_mask[i] = (mask_future
? std::min(length,
int32_t((multi_query ? i / num_heads : i % num_queries) + 1))
: length);
const auto offset = offsets ? offsets[b] : 0;
const auto length = lengths[b] + int32_t(step) - offset;
const auto batch_offset = b * num_heads * num_queries;

for (dim_t i = batch_offset; i < batch_offset + num_heads * num_queries; ++i) {
if (mask_future) {
const int32_t time = step + (multi_query ? i / num_heads : i % num_queries);
values_lengths[i] = time < offset ? 0 : std::min(time - offset + 1, length);
} else {
values_lengths[i] = length;
}

if (values_offsets)
values_offsets[i] = offset;
}
}
}
Expand Down
78 changes: 50 additions & 28 deletions src/cuda/primitives.cu
Original file line number Diff line number Diff line change
Expand Up @@ -283,36 +283,58 @@ namespace ctranslate2 {
vocabulary_size);
}

__global__ void prepare_length_mask_kernel(const int32_t* lengths,
cuda::index_t num_heads,
cuda::index_t num_queries,
bool mask_future,
bool multi_query,
int32_t* mask) {
const auto length = lengths[blockIdx.x];
mask += blockIdx.x * num_heads * num_queries;
for (cuda::index_t i = threadIdx.x; i < num_heads * num_queries; i += blockDim.x)
mask[i] = (mask_future
? min(length, int32_t((multi_query ? i / num_heads : i % num_queries) + 1))
: length);
}

template<>
void primitives<Device::CUDA>::prepare_length_mask(const int32_t* lengths,
dim_t batch_size,
dim_t num_heads,
dim_t num_queries,
bool mask_future,
bool multi_query,
int32_t* mask) {
// Fills the per-row attention value lengths (and optionally offsets) used to
// mask the values in multi-head attention.
//
// Launch layout: blockIdx.x selects the batch entry (it indexes `lengths` and
// `offsets` directly), and the threads of the block stride over the
// num_heads * num_queries rows belonging to that entry. The kernel does not
// loop over batch entries, so the grid must contain one block per entry.
//
//   lengths:        per-batch-entry lengths.
//   offsets:        optional per-batch-entry start offsets (nullptr means 0).
//   num_heads:      number of attention heads.
//   num_queries:    number of query positions per head.
//   mask_future:    when true, each query may only see values up to its own
//                   position.
//   multi_query:    selects how a row index maps to a query position:
//                   row / num_heads when true, row % num_queries otherwise.
//   step:           decoding step added to the length and to the query
//                   position.
//   values_lengths: output, one entry per (batch, row).
//   values_offsets: optional output, one entry per (batch, row); skipped when
//                   nullptr.
__global__ void prepare_mha_values_mask_kernel(const int32_t* lengths,
                                               const int32_t* offsets,
                                               cuda::index_t num_heads,
                                               cuda::index_t num_queries,
                                               bool mask_future,
                                               bool multi_query,
                                               cuda::index_t step,
                                               int32_t* values_lengths,
                                               int32_t* values_offsets) {
  const int32_t offset = offsets ? offsets[blockIdx.x] : 0;
  // Keep the arithmetic in signed 32-bit so that min() below compares signed
  // values regardless of the signedness of cuda::index_t.
  const int32_t current_step = static_cast<int32_t>(step);
  const int32_t length = lengths[blockIdx.x] + current_step - offset;

  const cuda::index_t rows_per_batch = num_heads * num_queries;
  const cuda::index_t batch_offset = blockIdx.x * rows_per_batch;

  for (cuda::index_t row = threadIdx.x; row < rows_per_batch; row += blockDim.x) {
    const cuda::index_t i = batch_offset + row;

    if (mask_future) {
      // The query position must be derived from the row index local to this
      // batch entry: dividing the global index by num_heads would add
      // blockIdx.x * num_queries to the position for every entry but the first.
      const int32_t query = static_cast<int32_t>(multi_query
                                                 ? row / num_heads
                                                 : row % num_queries);
      const int32_t time = current_step + query;
      values_lengths[i] = time < offset ? 0 : min(time - offset + 1, length);
    } else {
      values_lengths[i] = length;
    }

    if (values_offsets)
      values_offsets[i] = offset;
  }
}

template<>
void primitives<Device::CUDA>::prepare_mha_values_mask(const int32_t* lengths,
const int32_t* offsets,
dim_t batch_size,
dim_t num_heads,
dim_t num_queries,
bool mask_future,
bool multi_query,
dim_t step,
int32_t* values_lengths,
int32_t* values_offsets) {
const dim_t blocks = std::min(batch_size, cuda::max_blocks);
const dim_t threads = std::min(num_heads * num_queries, cuda::max_threads);
prepare_length_mask_kernel<<<blocks, threads, 0, cuda::get_cuda_stream()>>>(lengths,
num_heads,
num_queries,
mask_future,
multi_query,
mask);
prepare_mha_values_mask_kernel<<<blocks, threads, 0, cuda::get_cuda_stream()>>>(
lengths,
offsets,
num_heads,
num_queries,
mask_future,
multi_query,
step,
values_lengths,
values_offsets);
}

template<>
Expand Down
Loading

0 comments on commit 25ef314

Please sign in to comment.