LLaMA Model Optimization #18021

Merged
merged 39 commits on Oct 23, 2023
Commits
e74b899
Initial fusions and kernel changes for LLaMA
kunal-vaishnavi Aug 30, 2023
228de8c
Add rotary embeddings for LLaMA
kunal-vaishnavi Sep 10, 2023
dc16e16
Change input shapes and types for fused model
kunal-vaishnavi Sep 11, 2023
816f7e9
Add present kv to multi-head attention
kunal-vaishnavi Sep 15, 2023
5ce8e5a
Merge branch 'main' into kvaishnavi/llama
kunal-vaishnavi Sep 15, 2023
6669899
Update benchmark scripts
kunal-vaishnavi Sep 16, 2023
ed61ae4
Update inputs for optimized model
kunal-vaishnavi Sep 19, 2023
cdbd466
Merge branch 'main' into kvaishnavi/llama
kunal-vaishnavi Sep 20, 2023
becbd30
Add interleaved and non-interleaved rotary embeddings
kunal-vaishnavi Sep 29, 2023
eece5e8
Update rotary embeddings and export scripts
kunal-vaishnavi Oct 3, 2023
55d0554
Fix attention mask for HF version
kunal-vaishnavi Oct 4, 2023
37e6b5f
Modify rotary embeddings fusion for merged HF model
kunal-vaishnavi Oct 6, 2023
909f8e7
Add optimization passes after conversion
kunal-vaishnavi Oct 6, 2023
43f459b
Fix adding GQA to optimized model
kunal-vaishnavi Oct 7, 2023
4e2bf41
Add CPU implementation for rotary embeddings
kunal-vaishnavi Oct 7, 2023
2210c47
Add test cases
kunal-vaishnavi Oct 15, 2023
6f154e3
Clean up test cases
kunal-vaishnavi Oct 15, 2023
822c2e6
Fix initializer data in test case
kunal-vaishnavi Oct 15, 2023
cdf5536
Add merged export
kunal-vaishnavi Oct 16, 2023
52f5994
Remove logger warning
kunal-vaishnavi Oct 16, 2023
0d17656
Update docs
kunal-vaishnavi Oct 17, 2023
bcb5a32
Enable buffer sharing and int4 quantization
kunal-vaishnavi Oct 18, 2023
8ae9188
Fix inputs for buffer sharing
kunal-vaishnavi Oct 18, 2023
143d805
Remove extra print
kunal-vaishnavi Oct 18, 2023
f2b4644
Clean up code
kunal-vaishnavi Oct 18, 2023
d7bb72c
Merge branch 'main' into kvaishnavi/llama
kunal-vaishnavi Oct 18, 2023
8968bb3
Address PR feedback
kunal-vaishnavi Oct 19, 2023
84f7cc0
Add changes suggested by linters
kunal-vaishnavi Oct 19, 2023
99ec341
Fix min CUDA architecture
kunal-vaishnavi Oct 19, 2023
b76e2c2
Add graph input for GQA
kunal-vaishnavi Oct 19, 2023
edafef5
Fix GQA parity issue
kunal-vaishnavi Oct 20, 2023
7b82912
Add changes suggested by linter
kunal-vaishnavi Oct 20, 2023
a891398
Remove unreferenced parameter
kunal-vaishnavi Oct 20, 2023
716b725
Change rotary embedding test threshold
kunal-vaishnavi Oct 20, 2023
6b8698d
Add int4 CPU support
kunal-vaishnavi Oct 20, 2023
cc0199b
Add changes suggested by linters
kunal-vaishnavi Oct 20, 2023
e38ecb3
Merge branch 'main' into kvaishnavi/llama
kunal-vaishnavi Oct 21, 2023
e69c23b
Fix linter issue
kunal-vaishnavi Oct 21, 2023
d14d5bd
Fix CodeQL error
kunal-vaishnavi Oct 21, 2023
51 changes: 50 additions & 1 deletion docs/ContribOperators.md
@@ -90,6 +90,7 @@ Do not modify directly.*
* <a href="#com.microsoft.RemovePadding">com.microsoft.RemovePadding</a>
* <a href="#com.microsoft.RestorePadding">com.microsoft.RestorePadding</a>
* <a href="#com.microsoft.Rfft">com.microsoft.Rfft</a>
* <a href="#com.microsoft.RotaryEmbedding">com.microsoft.RotaryEmbedding</a>
* <a href="#com.microsoft.SampleOp">com.microsoft.SampleOp</a>
* <a href="#com.microsoft.Sampling">com.microsoft.Sampling</a>
* <a href="#com.microsoft.SkipLayerNormalization">com.microsoft.SkipLayerNormalization</a>
@@ -2834,7 +2835,7 @@ This version of the operator has been available since version 1 of the 'com.micr
<dt><tt>bias</tt> (optional) : T</dt>
<dd>Bias tensor with shape (hidden_size + hidden_size + v_hidden_size) from input projection</dd>
<dt><tt>key_padding_mask</tt> (optional) : M</dt>
-<dd>Key padding mask with shape (batch_size) or (3 * batch_size + 2) or (batch_size, kv_sequence_length)</dd>
+<dd>Key padding mask with shape (batch_size), (3 * batch_size + 2), (batch_size, kv_sequence_length), (batch_size, total_sequence_length), or (batch_size, sequence_length, total_sequence_length)</dd>
<dt><tt>relative_position_bias</tt> (optional) : T</dt>
<dd>relative position bias: addition to QxK' with shape (batch_size, num_heads, sequence_length, total_sequence_length) or (1, num_heads, sequence_length, total_sequence_length)</dd>
<dt><tt>past_key</tt> (optional) : T</dt>
@@ -4796,6 +4797,54 @@ This version of the operator has been available since version 1 of the 'com.micr
</dl>


### <a name="com.microsoft.RotaryEmbedding"></a><a name="com.microsoft.rotaryembedding">**com.microsoft.RotaryEmbedding**</a>

RotaryEmbedding is the implementation of rotary positional embeddings (RoPE). Each position is represented as a rotation matrix
that is multiplied with the query and key before the inner product of query and key is taken.
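As a hedged sketch of the non-interleaved rotation described here (NumPy, with the cache shapes from the Inputs section; the function name and layout are illustrative, not the ORT kernel):

```python
import numpy as np

# Reference sketch of non-interleaved RoPE. Shapes follow the operator spec:
# x: (batch, seq, hidden), position_ids: (batch, seq) int64,
# cos_cache / sin_cache: (max_sequence_length, head_size // 2).
def rotary_embedding(x, position_ids, cos_cache, sin_cache, num_heads):
    batch, seq, hidden = x.shape
    head_size = hidden // num_heads
    half = head_size // 2
    # Split hidden into heads: (batch, seq, num_heads, head_size)
    xh = x.reshape(batch, seq, num_heads, head_size)
    # Gather per-position angles from the caches: (batch, seq, half)
    cos = cos_cache[position_ids]
    sin = sin_cache[position_ids]
    # Tile to full head width and broadcast over heads: (batch, seq, 1, head_size)
    cos = np.concatenate([cos, cos], axis=-1)[:, :, None, :]
    sin = np.concatenate([sin, sin], axis=-1)[:, :, None, :]
    # Rotate each (i, i + half) pair as a 2D rotation
    x1, x2 = xh[..., :half], xh[..., half:]
    rotated = np.concatenate([-x2, x1], axis=-1)
    return (xh * cos + rotated * sin).reshape(batch, seq, hidden)
```

Because every output pair is a plane rotation by a cached angle, the transform preserves vector norms and leaves position 0 (angle 0) unchanged.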

#### Version

This version of the operator has been available since version 1 of the 'com.microsoft' operator set.

#### Attributes

<dl>
<dt><tt>interleaved</tt> : int</dt>
<dd>Rotate using interleaved pattern. Default value is 0 (False).</dd>
<dt><tt>scale</tt> : float</dt>
<dd>Custom scale will be used if specified. Default value is 1.0.</dd>
</dl>

#### Inputs

<dl>
<dt><tt>input</tt> : T</dt>
<dd>3D tensor with shape (batch_size, sequence_length, hidden_size)</dd>
<dt><tt>position_ids</tt> : M</dt>
<dd>1D tensor with shape (1) or 2D tensor with shape (batch_size, sequence_length)</dd>
<dt><tt>cos_cache</tt> : T</dt>
<dd>2D tensor with shape (max_sequence_length, head_size / 2).</dd>
<dt><tt>sin_cache</tt> : T</dt>
<dd>2D tensor with shape (max_sequence_length, head_size / 2).</dd>
</dl>

#### Outputs

<dl>
<dt><tt>output</tt> : T</dt>
<dd>3D tensor with shape (batch_size, sequence_length, hidden_size)</dd>
</dl>

#### Type Constraints

<dl>
<dt><tt>T</tt> : tensor(float), tensor(float16)</dt>
<dd>Constrain input and output types to float tensors.</dd>
<dt><tt>M</tt> : tensor(int64)</dt>
<dd>Constrain input and output types to integer tensors.</dd>
</dl>
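The shape contract above can be summarized with a small validation sketch (illustrative only; the real checks live in `rotary_embedding_helper.h`, whose code is not shown here):

```python
# Validate RotaryEmbedding input shapes per the spec above and derive the
# implied parameters. Names are assumptions, not ORT API.
def check_rotary_inputs(input_shape, position_ids_shape, cos_shape, sin_shape):
    if len(input_shape) != 3:
        raise ValueError("input must be (batch_size, sequence_length, hidden_size)")
    batch_size, sequence_length, hidden_size = input_shape
    if len(cos_shape) != 2 or cos_shape != sin_shape:
        raise ValueError(
            "cos_cache and sin_cache must both be (max_sequence_length, head_size / 2)")
    max_sequence_length, half_head_size = cos_shape
    head_size = 2 * half_head_size
    if hidden_size % head_size != 0:
        raise ValueError("hidden_size must be a multiple of head_size")
    if position_ids_shape not in ((1,), (batch_size, sequence_length)):
        raise ValueError("position_ids must be (1) or (batch_size, sequence_length)")
    return {
        "batch_size": batch_size,
        "sequence_length": sequence_length,
        "num_heads": hidden_size // head_size,
        "max_sequence_length": max_sequence_length,
    }
```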


### <a name="com.microsoft.SampleOp"></a><a name="com.microsoft.sampleop">**com.microsoft.SampleOp**</a>

Sample echo operator.
3 changes: 3 additions & 0 deletions docs/OperatorKernels.md
@@ -477,9 +477,11 @@ Do not modify directly.*
|QuantizeLinear|*in* x:**T1**<br> *in* y_scale:**T1**<br> *in* y_zero_point:**T2**<br> *out* y:**T2**|1+|**T1** = tensor(float)<br/> **T2** = tensor(int16), tensor(int8), tensor(uint16), tensor(uint8)|
|QuickGelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|Range|*in* start:**T**<br> *in* limit:**T**<br> *in* delta:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64)|
|RotaryEmbedding|*in* input:**T**<br> *in* position_ids:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *out* output:**T**|1+|**M** = tensor(int64)<br/> **T** = tensor(float)|
|SampleOp|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|Sampling|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *in* presence_mask:**I**<br> *in* seed:**I**<br> *out* sequences:**I**<br> *out* filtered_logits:**T**|1+|**T** = tensor(float)|
|SkipLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_bias_sum:**T**|1+|**T** = tensor(double), tensor(float)|
|SkipSimplifiedLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_bias_sum:**T**|1+|**T** = tensor(double), tensor(float)|
|SparseToDenseMatMul|*in* A:**T**<br> *in* B:**T1**<br> *out* Y:**T1**|1+|**T** = sparse_tensor(double), sparse_tensor(float), sparse_tensor(int32), sparse_tensor(int64), sparse_tensor(uint32), sparse_tensor(uint64)<br/> **T1** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
|Tokenizer|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(string)|
|TransposeMatMul|*in* A:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
@@ -866,6 +868,7 @@ Do not modify directly.*
|RemovePadding|*in* input:**T**<br> *in* sequence_token_count:**M**<br> *out* output:**T**<br> *out* token_offset:**M**<br> *out* cumulated_seq_len:**M**<br> *out* max_seq_len:**M**|1+|**T** = tensor(float), tensor(float16)|
|RestorePadding|*in* input:**T**<br> *in* token_offset:**M**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|Rfft|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|RotaryEmbedding|*in* input:**T**<br> *in* position_ids:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *out* output:**T**|1+|**M** = tensor(int64)<br/> **T** = tensor(float), tensor(float16)|
|Sampling|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *in* presence_mask:**I**<br> *in* seed:**I**<br> *out* sequences:**I**<br> *out* filtered_logits:**T**|1+|**T** = tensor(float), tensor(float16)|
|SkipLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)|
|SkipSimplifiedLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)|
1 change: 0 additions & 1 deletion onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc
@@ -16,7 +16,6 @@

#include <unsupported/Eigen/SpecialFunctions>
#include <vector>
-#include <iostream>

using onnxruntime::concurrency::ThreadPool;

14 changes: 11 additions & 3 deletions onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h
@@ -206,6 +206,7 @@ Status CheckInputs(const T* query,
}
}

int total_sequence_length = past_sequence_length + kv_sequence_length;
AttentionMaskType mask_type = AttentionMaskType::MASK_NONE;
if (key_padding_mask != nullptr) {
mask_type = AttentionMaskType::MASK_UNKNOWN;
@@ -216,13 +217,21 @@
} else if (mask_dims[0] == static_cast<int64_t>(3) * static_cast<int64_t>(batch_size) + static_cast<int64_t>(2)) {
mask_type = AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START;
}
-      } else if (mask_dims.size() == 2 && mask_dims[0] == static_cast<int64_t>(batch_size) && mask_dims[1] == static_cast<int64_t>(kv_sequence_length)) {
+      } else if (mask_dims.size() == 2 && mask_dims[0] == static_cast<int64_t>(batch_size) &&
+                 mask_dims[1] == static_cast<int64_t>(kv_sequence_length)) {
mask_type = AttentionMaskType::MASK_2D_KEY_PADDING;
} else if (mask_dims.size() == 2 && mask_dims[0] == static_cast<int64_t>(batch_size) &&
mask_dims[1] == static_cast<int64_t>(total_sequence_length)) {
mask_type = AttentionMaskType::MASK_2D_KEY_PADDING;
} else if (mask_dims.size() == 3 && mask_dims[0] == static_cast<int64_t>(batch_size) &&
mask_dims[1] == static_cast<int64_t>(sequence_length) &&
mask_dims[2] == static_cast<int64_t>(total_sequence_length)) {
mask_type = AttentionMaskType::MASK_3D_ATTENTION;
}

if (mask_type == AttentionMaskType::MASK_UNKNOWN) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
-                           "Input 'key_padding_mask' shape shall be (batch_size) or (batch_size, kv_sequence_length)");
+                           "Input 'key_padding_mask' shape shall be 1D, 2D, or 3D");
}
}

@@ -257,7 +266,6 @@ Status CheckInputs(const T* query,
}
}

int total_sequence_length = past_sequence_length + kv_sequence_length;
bool broadcast_res_pos_bias = false;
if (relative_position_bias != nullptr) {
const auto& relative_position_bias_dims = relative_position_bias->Shape().GetDims();
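The mask shapes accepted after this change can be summarized in a short sketch (mask-type names mirror the `AttentionMaskType` enum; this is an illustration, not ORT code):

```python
# Classify a key_padding_mask shape the way the updated CheckInputs does.
def classify_key_padding_mask(mask_dims, batch_size, sequence_length,
                              kv_sequence_length, past_sequence_length):
    total_sequence_length = past_sequence_length + kv_sequence_length
    dims = tuple(mask_dims)
    if dims == (batch_size,):
        return "MASK_1D_KEY_SEQ_LEN"
    if dims == (3 * batch_size + 2,):
        return "MASK_1D_KEY_SEQ_LEN_START"
    # New in this PR: (batch_size, total_sequence_length) also counts as 2D key padding
    if dims in ((batch_size, kv_sequence_length),
                (batch_size, total_sequence_length)):
        return "MASK_2D_KEY_PADDING"
    # New in this PR: full 3D attention mask
    if dims == (batch_size, sequence_length, total_sequence_length):
        return "MASK_3D_ATTENTION"
    return "MASK_UNKNOWN"
```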
115 changes: 115 additions & 0 deletions onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc
@@ -0,0 +1,115 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "contrib_ops/cpu/bert/rotary_embedding.h"
#include "contrib_ops/cpu/bert/rotary_embedding_helper.h"

#include "core/platform/threadpool.h"

using onnxruntime::concurrency::ThreadPool;
using namespace onnxruntime::contrib::rotary_embedding_helper;

// cpplint warning (rotary_embedding.cc:10, GitHub Actions / Lint C++): Do not use namespace using-directives. Use using-declarations instead. [build/namespaces] [5]

namespace onnxruntime {
namespace contrib {

// These ops are internal-only, so register outside of onnx
ONNX_OPERATOR_TYPED_KERNEL_EX(
RotaryEmbedding,
kMSDomain,
1,
float,
kCpuExecutionProvider,
KernelDefBuilder()
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
.TypeConstraint("M", DataTypeImpl::GetTensorType<int64_t>()),
RotaryEmbedding<float>);

template <typename T>
RotaryEmbedding<T>::RotaryEmbedding(const OpKernelInfo& info) : OpKernel(info) {
scale = info.GetAttrOrDefault<float>("scale", 1.0);
interleaved = (info.GetAttrOrDefault<int64_t>("interleaved", 0) == 1);
}

template <typename T>
Status RotaryEmbedding<T>::Compute(OpKernelContext* context) const {
const Tensor* input = context->Input<Tensor>(0);
const Tensor* position_ids = context->Input<Tensor>(1);
const Tensor* cos_cache = context->Input<Tensor>(2);
const Tensor* sin_cache = context->Input<Tensor>(3);

RotaryParameters parameters = {};
ORT_RETURN_IF_ERROR(rotary_embedding_helper::CheckInputs<Tensor>(input,
position_ids,
cos_cache,
sin_cache,
&parameters));

Tensor* output = context->Output(0, input->Shape());

if (parameters.sequence_length > parameters.max_sequence_length) {
// Launch update_cos_sin_cache kernel with scale
ORT_NOT_IMPLEMENTED("Updating cos_cache and sin_cache in RotaryEmbedding is not currently supported");
}

const T* input_src = input->Data<T>();
const int64_t* pos_ids_data = position_ids->Data<int64_t>();
const T* cos_cache_data = cos_cache->Data<T>();
const T* sin_cache_data = sin_cache->Data<T>();
T* output_dest = output->MutableData<T>();

const int batch_size = parameters.batch_size;
const int sequence_length = parameters.sequence_length;
const int num_heads = parameters.num_heads;
const int head_size = parameters.head_size;
const int position_ids_format = parameters.position_ids_format;
const int half_head_size = head_size / 2;

AllocatorPtr allocator;
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
auto* tp = context->GetOperatorThreadPool();

const int loop_len = batch_size * sequence_length * num_heads;
const double cost = static_cast<double>(head_size);
ThreadPool::TryParallelFor(tp, loop_len, cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
for (std::ptrdiff_t ptr = begin; ptr != end; ++ptr) {
const int b = static_cast<int>((ptr / num_heads) / sequence_length);
const int s = static_cast<int>((ptr / num_heads) % sequence_length);
const int n = static_cast<int>(ptr % num_heads);

const int block_offset = b * sequence_length * num_heads + s * num_heads + n;
const int data_offset = block_offset * head_size;

const T* input_data = input_src + data_offset;
T* output_data = output_dest + data_offset;

// Cache is (M, H/2)
const int position_id = (position_ids_format == 0)
? static_cast<int>(pos_ids_data[0]) + s
: static_cast<int>(pos_ids_data[b * sequence_length + s]);
const int cache_offset = position_id * half_head_size;
const T* cos_data = cos_cache_data + cache_offset;
const T* sin_data = sin_cache_data + cache_offset;

int cache_idx = 0;
T sign = 0;
int j = 0;
for (int i = 0; i < head_size; i++) {
if (interleaved) {
cache_idx = (i / 2) % half_head_size;
sign = (i % 2 == 0) ? static_cast<T>(-1) : static_cast<T>(1);
j = (i % 2 == 0) ? i + 1 : i - 1; // i - sign
} else {
cache_idx = i % half_head_size;
sign = (i < half_head_size) ? static_cast<T>(-1) : static_cast<T>(1);
j = (i + half_head_size) % head_size;
}
output_data[i] = input_data[i] * cos_data[cache_idx] + sign * input_data[j] * sin_data[cache_idx];
}
}
});

return Status::OK();
}

} // namespace contrib
} // namespace onnxruntime
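The per-element loop in `Compute` above can be mirrored in NumPy to make the interleaved and non-interleaved patterns concrete (an illustrative sketch, not part of this PR):

```python
import numpy as np

# NumPy mirror of the element-wise rotation loop in RotaryEmbedding<T>::Compute.
def rotate_head(x, cos, sin, interleaved):
    # x: (head_size,); cos, sin: (head_size // 2,) for one position.
    head_size = x.shape[0]
    half = head_size // 2
    out = np.empty_like(x)
    for i in range(head_size):
        if interleaved:
            cache_idx = (i // 2) % half
            sign = -1.0 if i % 2 == 0 else 1.0
            j = i + 1 if i % 2 == 0 else i - 1  # partner within the (even, odd) pair
        else:
            cache_idx = i % half
            sign = -1.0 if i < half else 1.0
            j = (i + half) % head_size  # partner in the other half
        out[i] = x[i] * cos[cache_idx] + sign * x[j] * sin[cache_idx]
    return out
```

In both modes each output pair is a 2D rotation by the cached angle, so the head's norm is preserved, and zero angles (cos = 1, sin = 0) reproduce the input.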
23 changes: 23 additions & 0 deletions onnxruntime/contrib_ops/cpu/bert/rotary_embedding.h
@@ -0,0 +1,23 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include "core/common/common.h"
#include "core/framework/op_kernel.h"

namespace onnxruntime {
namespace contrib {

template <typename T>
class RotaryEmbedding final : public OpKernel {
public:
RotaryEmbedding(const OpKernelInfo& info);
Status Compute(OpKernelContext* context) const override;

protected:
float scale;
bool interleaved;
};

} // namespace contrib
} // namespace onnxruntime