LLaMA Model Optimization #18021

Merged: 39 commits merged on Oct 23, 2023
Changes from 30 commits

Commits (39)
e74b899
Initial fusions and kernel changes for LLaMA
kunal-vaishnavi Aug 30, 2023
228de8c
Add rotary embeddings for LLaMA
kunal-vaishnavi Sep 10, 2023
dc16e16
Change input shapes and types for fused model
kunal-vaishnavi Sep 11, 2023
816f7e9
Add present kv to multi-head attention
kunal-vaishnavi Sep 15, 2023
5ce8e5a
Merge branch 'main' into kvaishnavi/llama
kunal-vaishnavi Sep 15, 2023
6669899
Update benchmark scripts
kunal-vaishnavi Sep 16, 2023
ed61ae4
Update inputs for optimized model
kunal-vaishnavi Sep 19, 2023
cdbd466
Merge branch 'main' into kvaishnavi/llama
kunal-vaishnavi Sep 20, 2023
becbd30
Add interleaved and non-interleaved rotary embeddings
kunal-vaishnavi Sep 29, 2023
eece5e8
Update rotary embeddings and export scripts
kunal-vaishnavi Oct 3, 2023
55d0554
Fix attention mask for HF version
kunal-vaishnavi Oct 4, 2023
37e6b5f
Modify rotary embeddings fusion for merged HF model
kunal-vaishnavi Oct 6, 2023
909f8e7
Add optimization passes after conversion
kunal-vaishnavi Oct 6, 2023
43f459b
Fix adding GQA to optimized model
kunal-vaishnavi Oct 7, 2023
4e2bf41
Add CPU implementation for rotary embeddings
kunal-vaishnavi Oct 7, 2023
2210c47
Add test cases
kunal-vaishnavi Oct 15, 2023
6f154e3
Clean up test cases
kunal-vaishnavi Oct 15, 2023
822c2e6
Fix initializer data in test case
kunal-vaishnavi Oct 15, 2023
cdf5536
Add merged export
kunal-vaishnavi Oct 16, 2023
52f5994
Remove logger warning
kunal-vaishnavi Oct 16, 2023
0d17656
Update docs
kunal-vaishnavi Oct 17, 2023
bcb5a32
Enable buffer sharing and int4 quantization
kunal-vaishnavi Oct 18, 2023
8ae9188
Fix inputs for buffer sharing
kunal-vaishnavi Oct 18, 2023
143d805
Remove extra print
kunal-vaishnavi Oct 18, 2023
f2b4644
Clean up code
kunal-vaishnavi Oct 18, 2023
d7bb72c
Merge branch 'main' into kvaishnavi/llama
kunal-vaishnavi Oct 18, 2023
8968bb3
Address PR feedback
kunal-vaishnavi Oct 19, 2023
84f7cc0
Add changes suggested by linters
kunal-vaishnavi Oct 19, 2023
99ec341
Fix min CUDA architecture
kunal-vaishnavi Oct 19, 2023
b76e2c2
Add graph input for GQA
kunal-vaishnavi Oct 19, 2023
edafef5
Fix GQA parity issue
kunal-vaishnavi Oct 20, 2023
7b82912
Add changes suggested by linter
kunal-vaishnavi Oct 20, 2023
a891398
Remove unreferenced parameter
kunal-vaishnavi Oct 20, 2023
716b725
Change rotary embedding test threshold
kunal-vaishnavi Oct 20, 2023
6b8698d
Add int4 CPU support
kunal-vaishnavi Oct 20, 2023
cc0199b
Add changes suggested by linters
kunal-vaishnavi Oct 20, 2023
e38ecb3
Merge branch 'main' into kvaishnavi/llama
kunal-vaishnavi Oct 21, 2023
e69c23b
Fix linter issue
kunal-vaishnavi Oct 21, 2023
d14d5bd
Fix CodeQL error
kunal-vaishnavi Oct 21, 2023
1 change: 0 additions & 1 deletion onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc
@@ -16,7 +16,6 @@

#include <unsupported/Eigen/SpecialFunctions>
#include <vector>
#include <iostream>

using onnxruntime::concurrency::ThreadPool;

8 changes: 6 additions & 2 deletions onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h
@@ -206,6 +206,7 @@
}
}

int total_sequence_length = past_sequence_length + kv_sequence_length;
AttentionMaskType mask_type = AttentionMaskType::MASK_NONE;
if (key_padding_mask != nullptr) {
mask_type = AttentionMaskType::MASK_UNKNOWN;
@@ -218,11 +219,15 @@
}
} else if (mask_dims.size() == 2 && mask_dims[0] == static_cast<int64_t>(batch_size) && mask_dims[1] == static_cast<int64_t>(kv_sequence_length)) {
mask_type = AttentionMaskType::MASK_2D_KEY_PADDING;
} else if (mask_dims.size() == 2 && mask_dims[0] == static_cast<int64_t>(batch_size) && mask_dims[1] == static_cast<int64_t>(total_sequence_length)) {
mask_type = AttentionMaskType::MASK_2D_KEY_PADDING;
} else if (mask_dims.size() == 3 && mask_dims[0] == static_cast<int64_t>(batch_size) && mask_dims[1] == static_cast<int64_t>(sequence_length) && mask_dims[2] == static_cast<int64_t>(total_sequence_length)) {
mask_type = AttentionMaskType::MASK_3D_ATTENTION;
}

if (mask_type == AttentionMaskType::MASK_UNKNOWN) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'key_padding_mask' shape shall be (batch_size) or (batch_size, kv_sequence_length)");
"Input 'key_padding_mask' shape shall be 1D, 2D, or 3D");
}
}

@@ -257,7 +262,6 @@
}
}

int total_sequence_length = past_sequence_length + kv_sequence_length;
bool broadcast_res_pos_bias = false;
if (relative_position_bias != nullptr) {
const auto& relative_position_bias_dims = relative_position_bias->Shape().GetDims();
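For reference, the new mask handling above accepts these key_padding_mask layouts: (batch_size, kv_sequence_length) and (batch_size, total_sequence_length) both map to MASK_2D_KEY_PADDING, and (batch_size, sequence_length, total_sequence_length) maps to MASK_3D_ATTENTION, where total_sequence_length = past_sequence_length + kv_sequence_length. A minimal standalone sketch of that dispatch follows (the ClassifyMask helper and the example sizes are illustrative, not part of this PR):

#include <cstdint>
#include <iostream>
#include <vector>

enum class MaskType { None, Unknown, KeyPadding2D, Attention3D };

// Classify a key_padding_mask shape the way the branches above do.
MaskType ClassifyMask(const std::vector<int64_t>& dims, int64_t batch_size,
                      int64_t sequence_length, int64_t kv_sequence_length,
                      int64_t past_sequence_length) {
  const int64_t total_sequence_length = past_sequence_length + kv_sequence_length;
  if (dims.size() == 2 && dims[0] == batch_size &&
      (dims[1] == kv_sequence_length || dims[1] == total_sequence_length)) {
    return MaskType::KeyPadding2D;  // matches the two 2D branches above
  }
  if (dims.size() == 3 && dims[0] == batch_size && dims[1] == sequence_length &&
      dims[2] == total_sequence_length) {
    return MaskType::Attention3D;   // matches the new 3D branch above
  }
  return MaskType::Unknown;         // rejected with INVALID_ARGUMENT
}

int main() {
  // Decoding step: batch=2, sequence_length=1, past=7, kv=1 -> total=8.
  std::cout << static_cast<int>(ClassifyMask({2, 8}, 2, 1, 1, 7)) << "\n";     // KeyPadding2D
  std::cout << static_cast<int>(ClassifyMask({2, 1, 8}, 2, 1, 1, 7)) << "\n";  // Attention3D
  return 0;
}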
113 changes: 113 additions & 0 deletions onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc
@@ -0,0 +1,113 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "rotary_embedding.h"
#include "rotary_embedding_helper.h"

#include "core/platform/threadpool.h"

using onnxruntime::concurrency::ThreadPool;
using namespace onnxruntime::contrib::rotary_embedding_helper;

namespace onnxruntime {
namespace contrib {

// These ops are internal-only, so register outside of onnx
ONNX_OPERATOR_TYPED_KERNEL_EX(
RotaryEmbedding,
kMSDomain,
1,
float,
kCpuExecutionProvider,
KernelDefBuilder()
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
.TypeConstraint("M", DataTypeImpl::GetTensorType<int64_t>()),
RotaryEmbedding<float>);

template <typename T>
RotaryEmbedding<T>::RotaryEmbedding(const OpKernelInfo& info) : OpKernel(info) {
scale = info.GetAttrOrDefault<float>("scale", 1.0);
interleaved = (info.GetAttrOrDefault<int64_t>("interleaved", 0) == 1);
}

template <typename T>
Status RotaryEmbedding<T>::Compute(OpKernelContext* context) const {
const Tensor* input = context->Input<Tensor>(0);
const Tensor* position_ids = context->Input<Tensor>(1);
const Tensor* cos_cache = context->Input<Tensor>(2);
const Tensor* sin_cache = context->Input<Tensor>(3);

RotaryParameters parameters = {};
ORT_RETURN_IF_ERROR(rotary_embedding_helper::CheckInputs<Tensor>(input,
position_ids,
cos_cache,
sin_cache,
&parameters));

Tensor* output = context->Output(0, input->Shape());

if (parameters.sequence_length > parameters.max_sequence_length) {
// Launch update_cos_sin_cache kernel with scale
ORT_NOT_IMPLEMENTED("Updating cos_cache and sin_cache in RotaryEmbedding is not currently supported");
}

const T* input_src = input->Data<T>();
const int64_t* pos_ids_data = position_ids->Data<int64_t>();
const T* cos_cache_data = cos_cache->Data<T>();
const T* sin_cache_data = sin_cache->Data<T>();
T* output_dest = output->MutableData<T>();

const int batch_size = parameters.batch_size;
const int sequence_length = parameters.sequence_length;
const int num_heads = parameters.num_heads;
const int head_size = parameters.head_size;
const int position_ids_format = parameters.position_ids_format;
const int half_head_size = head_size / 2;

AllocatorPtr allocator;
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
auto* tp = context->GetOperatorThreadPool();

const int loop_len = batch_size * sequence_length * num_heads;
const double cost = static_cast<double>(head_size);
ThreadPool::TryParallelFor(tp, loop_len, cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
for (std::ptrdiff_t ptr = begin; ptr != end; ++ptr) {
const int b = static_cast<int>((ptr / num_heads) / sequence_length);
const int s = static_cast<int>((ptr / num_heads) % sequence_length);
const int n = static_cast<int>(ptr % num_heads);

const int block_offset = b * sequence_length * num_heads + s * num_heads + n;
const int data_offset = block_offset * head_size;

const T* input_data = input_src + data_offset;
T* output_data = output_dest + data_offset;

// Cache is (M, H/2)
const int position_id = (position_ids_format == 0) ? static_cast<int>(pos_ids_data[0]) : static_cast<int>(pos_ids_data[b * sequence_length + s]);
const int cache_offset = (position_ids_format == 0) ? (position_id + s) * half_head_size : position_id * half_head_size;
const T* cos_data = cos_cache_data + cache_offset;
const T* sin_data = sin_cache_data + cache_offset;

int cache_idx = 0;
T sign = 0;
int j = 0;
for (int i = 0; i < head_size; i++) {
if (interleaved) {
cache_idx = (i / 2) % half_head_size;
sign = (i % 2 == 0) ? static_cast<T>(-1) : static_cast<T>(1);
j = (i % 2 == 0) ? i + 1 : i - 1; // i - sign
} else {
cache_idx = i % half_head_size;
sign = (i < half_head_size) ? static_cast<T>(-1) : static_cast<T>(1);
j = (i + half_head_size) % head_size;
}
output_data[i] = input_data[i] * cos_data[cache_idx] + sign * input_data[j] * sin_data[cache_idx];
}
}
});

return Status::OK();
}

} // namespace contrib
} // namespace onnxruntime
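The kernel above rotates each pair of elements in every head vector (adjacent pairs when interleaved, first-half/second-half pairs otherwise) by the angle stored for its position in the cos/sin cache. A minimal standalone sketch of that per-vector rotation (the ApplyRotary helper and the inline cache construction with base 10000 are illustrative; the operator itself consumes precomputed cos_cache/sin_cache inputs):

#include <cmath>
#include <cstdio>
#include <vector>

// Rotate one head vector of length head_size, mirroring the inner loop above.
void ApplyRotary(const float* input, const float* cos_data, const float* sin_data,
                 int head_size, bool interleaved, float* output) {
  const int half = head_size / 2;
  for (int i = 0; i < head_size; i++) {
    int cache_idx, j;
    float sign;
    if (interleaved) {            // rotated pairs are (x0, x1), (x2, x3), ...
      cache_idx = (i / 2) % half;
      sign = (i % 2 == 0) ? -1.0f : 1.0f;
      j = (i % 2 == 0) ? i + 1 : i - 1;
    } else {                      // rotated pairs are (x_i, x_{i + half})
      cache_idx = i % half;
      sign = (i < half) ? -1.0f : 1.0f;
      j = (i + half) % head_size;
    }
    output[i] = input[i] * cos_data[cache_idx] + sign * input[j] * sin_data[cache_idx];
  }
}

int main() {
  const int head_size = 4, position = 3;
  std::vector<float> x = {1.0f, 0.0f, 0.0f, 1.0f}, y(head_size);
  // Build cos/sin for one position with the conventional base of 10000
  // (the operator receives these as precomputed cos_cache/sin_cache inputs).
  std::vector<float> cos_p(head_size / 2), sin_p(head_size / 2);
  for (int k = 0; k < head_size / 2; k++) {
    const float theta = position * std::pow(10000.0f, -2.0f * k / head_size);
    cos_p[k] = std::cos(theta);
    sin_p[k] = std::sin(theta);
  }
  ApplyRotary(x.data(), cos_p.data(), sin_p.data(), head_size, /*interleaved=*/false, y.data());
  for (float v : y) std::printf("%f\n", v);
  return 0;
}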
23 changes: 23 additions & 0 deletions onnxruntime/contrib_ops/cpu/bert/rotary_embedding.h
@@ -0,0 +1,23 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include "core/common/common.h"
#include "core/framework/op_kernel.h"

namespace onnxruntime {
namespace contrib {

template <typename T>
class RotaryEmbedding final : public OpKernel {
public:
RotaryEmbedding(const OpKernelInfo& info);
Status Compute(OpKernelContext* context) const override;

protected:
float scale;
bool interleaved;
};

} // namespace contrib
} // namespace onnxruntime
120 changes: 120 additions & 0 deletions onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h
@@ -0,0 +1,120 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include "core/common/common.h"
#include "core/providers/common.h"

namespace onnxruntime {
namespace contrib {
namespace rotary_embedding_helper {

// Parameters deduced from node attributes and inputs/outputs.
struct RotaryParameters {
int batch_size; // Batch size used by input
int sequence_length; // Sequence length used by input
int hidden_size; // Hidden size used by input
int head_size; // Head size; equal to cos/sin cache dimension 1 * 2
int num_heads; // num_heads = hidden_size / head_size
int max_sequence_length; // Sequence length used by cos/sin cache
int position_ids_format; // Format of position ids - 0 is (1), 1 is (batch_size, sequence_length)
};

template <typename T>
Status CheckInputs(const T* input,
const T* position_ids,
const T* cos_cache,
const T* sin_cache,
void* parameters) {
// input : (batch_size, sequence_length, hidden_size)
// position ids : (1) or (batch_size, sequence_length)
// cos cache : (max_sequence_length, head_size / 2)
// sin cache : (max_sequence_length, head_size / 2)

// Check input
const auto& input_dims = input->Shape().GetDims();
if (input_dims.size() != 3) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'x' is expected to have 3 dimensions, got ",
input_dims.size());
}
// Check position_ids
const auto& position_ids_dims = position_ids->Shape().GetDims();
if (!onnxruntime::IsScalarOr1ElementVector(position_ids) && position_ids_dims.size() != 2) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'position_ids' is expected to have 0, 1, or 2 dimensions, got ",
position_ids_dims.size());
}
// Check cos_cache and sin_cache
const auto& cos_cache_dims = cos_cache->Shape().GetDims();
if (cos_cache_dims.size() != 2) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' is expected to have 2 dimensions, got ",
cos_cache_dims.size());
}
const auto& sin_cache_dims = sin_cache->Shape().GetDims();
if (sin_cache_dims.size() != 2) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'sin_cache' is expected to have 2 dimensions, got ",
sin_cache_dims.size());
}
if (cos_cache_dims[0] != sin_cache_dims[0] || cos_cache_dims[1] != sin_cache_dims[1]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'cos_cache' and 'sin_cache' are expected to have the same shape");
}

// Get attributes from inputs
int batch_size = static_cast<int>(input_dims[0]);
int sequence_length = static_cast<int>(input_dims[1]);
int hidden_size = static_cast<int>(input_dims[2]);
int max_sequence_length = static_cast<int>(cos_cache_dims[0]);
int head_size = static_cast<int>(cos_cache_dims[1]) * 2;
int num_heads = hidden_size / head_size;
int position_ids_format = -1;

// Check position_ids input shapes
if (!onnxruntime::IsScalarOr1ElementVector(position_ids)) {
if (batch_size != static_cast<int>(position_ids_dims[0])) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'position_ids' dimension 0 should be of size batch_size, got ",
position_ids_dims[0]);
}
if (sequence_length != static_cast<int>(position_ids_dims[1])) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'position_ids' dimension 1 should be of size sequence_length, got ",
position_ids_dims[1]);
}
position_ids_format = 1;
} else {
position_ids_format = 0;
}
// Check cos_cache input shapes
if (max_sequence_length != static_cast<int>(cos_cache_dims[0])) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' dimension 0 should be same as max_sequence_length, got ",
cos_cache_dims[0]);
}
if ((head_size / 2) != static_cast<int>(cos_cache_dims[1])) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' dimension 1 should be same as head_size / 2, got ",
cos_cache_dims[1]);
}
// Check sin_cache input shapes
if (max_sequence_length != static_cast<int>(sin_cache_dims[0])) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'sin_cache' dimension 0 should be same as max_sequence_length, got ",
sin_cache_dims[0]);
}
if ((head_size / 2) != static_cast<int>(sin_cache_dims[1])) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'sin_cache' dimension 1 should be same as head_size / 2, got ",
sin_cache_dims[1]);
}

// Set rotary parameters
if (parameters != nullptr) {
RotaryParameters* output_parameters = reinterpret_cast<RotaryParameters*>(parameters);
output_parameters->batch_size = batch_size;
output_parameters->sequence_length = sequence_length;
output_parameters->hidden_size = hidden_size;
output_parameters->head_size = head_size;
output_parameters->num_heads = num_heads;
output_parameters->max_sequence_length = max_sequence_length;
output_parameters->position_ids_format = position_ids_format;
}

return Status::OK();
}

} // namespace rotary_embedding_helper
} // namespace contrib
} // namespace onnxruntime
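CheckInputs enforces this shape contract: input is (batch_size, sequence_length, hidden_size); position_ids is either a single element (position_ids_format 0, used by the kernel as a start offset added to the token index) or a (batch_size, sequence_length) tensor (format 1); cos_cache and sin_cache are both (max_sequence_length, head_size / 2), from which head_size and num_heads are derived. A small sketch of that derivation with illustrative numbers (the concrete sizes below are examples, not from this PR):

#include <cstdio>

int main() {
  // input:     (batch_size, sequence_length, hidden_size) = (2, 8, 4096)
  // cos_cache: (max_sequence_length, head_size / 2)       = (2048, 64)
  const int batch_size = 2, sequence_length = 8, hidden_size = 4096;
  const int max_sequence_length = 2048, cache_dim1 = 64;

  const int head_size = cache_dim1 * 2;           // 128
  const int num_heads = hidden_size / head_size;  // 32
  // position_ids_format is 0 when position_ids has a single element and 1 when
  // it is a full (batch_size, sequence_length) tensor of positions.
  std::printf("batch=%d seq=%d head_size=%d num_heads=%d max_seq=%d\n",
              batch_size, sequence_length, head_size, num_heads, max_sequence_length);
  return 0;
}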
6 changes: 6 additions & 0 deletions onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
@@ -20,6 +20,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FusedGemm);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GreedySearch);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, MultiHeadAttention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, RotaryEmbedding);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, Sampling);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, AttnLSTM);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, string, Tokenizer);
@@ -124,6 +125,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, double, SimplifiedLayerNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SkipLayerNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, double, SkipLayerNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SkipSimplifiedLayerNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, double, SkipSimplifiedLayerNormalization);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Inverse);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Trilu);

@@ -253,6 +256,7 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FusedGemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GreedySearch)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, MultiHeadAttention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, RotaryEmbedding)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, Sampling)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, AttnLSTM)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, string, Tokenizer)>,
@@ -299,6 +303,8 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, double, SimplifiedLayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SkipLayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, double, SkipLayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SkipSimplifiedLayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, double, SkipSimplifiedLayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Inverse)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Trilu)>,

29 changes: 22 additions & 7 deletions onnxruntime/contrib_ops/cpu/skip_layer_norm.cc
@@ -20,20 +20,29 @@ namespace contrib {
kCpuExecutionProvider, \
KernelDefBuilder() \
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
SkipLayerNorm<T>);
SkipLayerNorm<T, false>); \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
SkipSimplifiedLayerNormalization, \
kMSDomain, \
1, \
T, \
kCpuExecutionProvider, \
KernelDefBuilder() \
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
SkipLayerNorm<T, true>);

REGISTER_KERNEL_TYPED(float)
REGISTER_KERNEL_TYPED(double)

template <typename T>
SkipLayerNorm<T>::SkipLayerNorm(const OpKernelInfo& op_kernel_info)
template <typename T, bool simplified>
SkipLayerNorm<T, simplified>::SkipLayerNorm(const OpKernelInfo& op_kernel_info)
: OpKernel(op_kernel_info) {
ORT_ENFORCE(op_kernel_info.GetAttr<float>("epsilon", &epsilon_).IsOK());
ORT_ENFORCE(epsilon_ >= 0);
}

template <typename T>
Status SkipLayerNorm<T>::Compute(OpKernelContext* p_ctx) const {
template <typename T, bool simplified>
Status SkipLayerNorm<T, simplified>::Compute(OpKernelContext* p_ctx) const {
const Tensor* input = p_ctx->Input<Tensor>(0);
const Tensor* skip = p_ctx->Input<Tensor>(1);
const Tensor* gamma = p_ctx->Input<Tensor>(2);
@@ -102,10 +111,16 @@ Status SkipLayerNorm<T>::Compute(OpKernelContext* p_ctx) const {
}

mean = mean / hidden_size;
mean_square = sqrt(mean_square / hidden_size - mean * mean + epsilon_);
if (simplified) {
mean_square = sqrt(mean_square / hidden_size + epsilon_);
} else {
mean_square = sqrt(mean_square / hidden_size - mean * mean + epsilon_);
}

for (int64_t h = 0; h < hidden_size; h++) {
if (nullptr == beta_data) {
if (simplified) {
p_output[h] = p_output[h] / mean_square * gamma_data[h];
} else if (nullptr == beta_data) {
p_output[h] = (p_output[h] - mean) / mean_square * gamma_data[h];
} else {
p_output[h] = (p_output[h] - mean) / mean_square * gamma_data[h] + beta_data[h];
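The simplified path added here skips the mean subtraction and the beta bias, i.e. it normalizes input + skip by its root mean square (RMSNorm-style normalization, which is what LLaMA uses). A minimal standalone sketch of both variants over one hidden vector (the Normalize helper and the example values are illustrative, not part of this PR):

#include <cmath>
#include <cstdio>
#include <vector>

// Normalize one hidden vector after the input + skip (+ bias) sum, in both the
// regular (LayerNorm) and simplified (RMSNorm-style) variants shown above.
void Normalize(std::vector<float>& x, const std::vector<float>& gamma,
               const std::vector<float>& beta, float epsilon, bool simplified) {
  const float hidden_size = static_cast<float>(x.size());
  float mean = 0.0f, mean_square = 0.0f;
  for (float v : x) {
    mean += v;
    mean_square += v * v;
  }
  mean /= hidden_size;
  const float denom = simplified
                          ? std::sqrt(mean_square / hidden_size + epsilon)
                          : std::sqrt(mean_square / hidden_size - mean * mean + epsilon);
  for (size_t h = 0; h < x.size(); h++) {
    x[h] = simplified ? x[h] / denom * gamma[h]                        // no mean, no beta
                      : (x[h] - mean) / denom * gamma[h] + beta[h];
  }
}

int main() {
  std::vector<float> x = {1.0f, 2.0f, 3.0f, 4.0f};
  std::vector<float> gamma(4, 1.0f), beta(4, 0.0f);
  Normalize(x, gamma, beta, 1e-6f, /*simplified=*/true);
  for (float v : x) std::printf("%f\n", v);
  return 0;
}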
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/cpu/skip_layer_norm.h
@@ -10,7 +10,7 @@
namespace onnxruntime {
namespace contrib {

template <typename T>
template <typename T, bool simplified>
class SkipLayerNorm final : public OpKernel {
public:
SkipLayerNorm(const OpKernelInfo& op_kernel_info);