Commit

Add rotary embeddings schema
shubhambhokare1 committed Oct 15, 2024
1 parent 2b504f8 commit f9f676e
Showing 5 changed files with 228 additions and 1 deletion.
62 changes: 62 additions & 0 deletions docs/Changelog.md
@@ -29026,6 +29026,68 @@ This version of the operator has been available since version 23 of the default
<dd>Constrain input and output types to all tensor types.</dd>
</dl>

### <a name="RotaryEmbedding-23"></a>**RotaryEmbedding-23**</a>

RotaryEmbedding implements rotary positional embeddings (RoPE) as described in the paper https://arxiv.org/pdf/2104.09864.
Positions are encoded as rotation matrices that are multiplied with the query and key tensors
before their inner product is taken.

Rotary embeddings are defined by the following functions:

```python
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rope(x, cos, sin, position_ids):
    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    x_embed = (x * cos) + (rotate_half(x) * sin)
    return x_embed
```
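The PyTorch-style pseudocode above can be mirrored in NumPy. The sketch below is illustrative, not the normative operator definition: it takes `cos`/`sin` directly as 2D `[seq_len, dim]` lookup tables (rather than the pre-broadcast tensors the pseudocode squeezes), and the angle tables in the demo assume the RoPE paper's base of 10000.

```python
import numpy as np

def rotate_half(x):
    """Rotate half the hidden dims: (x1, x2) -> (-x2, x1)."""
    half = x.shape[-1] // 2
    return np.concatenate((-x[..., half:], x[..., :half]), axis=-1)

def apply_rope(x, cos, sin, position_ids):
    """x: [bs, num_heads, seq_len, dim]; cos/sin: [seq_len, dim] lookup tables."""
    cos = cos[position_ids][:, np.newaxis, :, :]  # [bs, 1, seq_len, dim]
    sin = sin[position_ids][:, np.newaxis, :, :]  # [bs, 1, seq_len, dim]
    return (x * cos) + (rotate_half(x) * sin)

# Tiny demo: dim=4, three positions, angle tables duplicated across both halves.
rng = np.random.default_rng(0)
x = rng.standard_normal((2, 1, 3, 4))
inv_freq = 1.0 / (10000.0 ** (np.arange(2) / 2.0))   # [dim/2]
angles = np.arange(3)[:, None] * inv_freq[None, :]   # [seq_len, dim/2]
cos = np.concatenate((np.cos(angles), np.cos(angles)), axis=-1)
sin = np.concatenate((np.sin(angles), np.sin(angles)), axis=-1)
position_ids = np.tile(np.arange(3), (2, 1))         # [bs, seq_len]
out = apply_rope(x, cos, sin, position_ids)
```

Because each pair of components undergoes a plane rotation, the per-token L2 norm is preserved, and position 0 (angle 0) leaves the input unchanged.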

#### Version

This version of the operator has been available since version 23 of the default ONNX operator set.

#### Attributes

<dl>
<dt><tt>interleaved</tt> : int</dt>
<dd>Rotate using interleaved pattern. Default value is 0 (False).</dd>
</dl>

#### Inputs

<dl>
<dt><tt>input</tt> : T</dt>
<dd>3D tensor with shape (batch_size, sequence_length, hidden_size) or 4D tensor with shape (batch_size, num_heads, sequence_length, head_size)</dd>
<dt><tt>position_ids</tt> : M</dt>
<dd>1D tensor with shape (1) or 2D tensor with shape (batch_size, sequence_length)</dd>
<dt><tt>cos_cache</tt> : T</dt>
<dd>2D tensor with shape (max_sequence_length, head_size / 2) or (max_sequence_length, rotary_embedding_dim / 2)</dd>
<dt><tt>sin_cache</tt> : T</dt>
<dd>2D tensor with shape (max_sequence_length, head_size / 2) or (max_sequence_length, rotary_embedding_dim / 2)</dd>
</dl>
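The `cos_cache` and `sin_cache` inputs are precomputed outside the operator; the schema only fixes their shapes. A sketch of how such caches are typically built, assuming the RoPE paper's default base of 10000 (the schema itself imposes no particular base or frequency schedule):

```python
import numpy as np

def make_rope_caches(max_sequence_length, head_size, base=10000.0):
    """Build cos/sin caches of shape (max_sequence_length, head_size / 2)."""
    half = head_size // 2
    inv_freq = 1.0 / (base ** (np.arange(half) * 2.0 / head_size))       # [half]
    angles = np.arange(max_sequence_length)[:, None] * inv_freq[None, :]  # [max_seq, half]
    return np.cos(angles), np.sin(angles)

cos_cache, sin_cache = make_rope_caches(max_sequence_length=8, head_size=6)
```

Row 0 of both caches corresponds to position 0: all-ones cosines and all-zero sines, i.e. the identity rotation.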

#### Outputs

<dl>
<dt><tt>output</tt> : T</dt>
<dd>Tensor with the same shape as the input.</dd>
</dl>

#### Type Constraints

<dl>
<dt><tt>T</tt> : tensor(float), tensor(float16), tensor(bfloat16)</dt>
<dd>Constrain input and output types to float tensors.</dd>
<dt><tt>M</tt> : tensor(int64)</dt>
<dd>Constrain position_ids to int64 tensors.</dd>
</dl>

### <a name="Scan-23"></a>**Scan-23**</a>

Scan can be used to iterate over one or more scan_input tensors,
64 changes: 64 additions & 0 deletions docs/Operators.md
@@ -197,6 +197,7 @@ For an operator input/output's differentiability, it can be differentiable,
|<a href="#ReduceLogSumExp">ReduceLogSumExp</a>|<a href="Changelog.md#ReduceLogSumExp-18">18</a>, <a href="Changelog.md#ReduceLogSumExp-13">13</a>, <a href="Changelog.md#ReduceLogSumExp-11">11</a>, <a href="Changelog.md#ReduceLogSumExp-1">1</a>|18|
|<a href="#ReduceSumSquare">ReduceSumSquare</a>|<a href="Changelog.md#ReduceSumSquare-18">18</a>, <a href="Changelog.md#ReduceSumSquare-13">13</a>, <a href="Changelog.md#ReduceSumSquare-11">11</a>, <a href="Changelog.md#ReduceSumSquare-1">1</a>|18|
|<a href="#Relu">Relu</a>|<a href="Changelog.md#Relu-14">14</a>, <a href="Changelog.md#Relu-13">13</a>, <a href="Changelog.md#Relu-6">6</a>, <a href="Changelog.md#Relu-1">1</a>|18|
|<a href="#RotaryEmbedding">RotaryEmbedding</a>|<a href="Changelog.md#RotaryEmbedding-23">23</a>|23|
|<a href="#Selu">Selu</a>|<a href="Changelog.md#Selu-22">22</a>, <a href="Changelog.md#Selu-6">6</a>, <a href="Changelog.md#Selu-1">1</a>|18|
|<a href="#SequenceMap">SequenceMap</a>|<a href="Changelog.md#SequenceMap-17">17</a>|17|
|<a href="#Shrink">Shrink</a>|<a href="Changelog.md#Shrink-9">9</a>|18|
@@ -27310,6 +27311,69 @@ expect(
</details>


### <a name="RotaryEmbedding"></a><a name="rotaryembedding">**RotaryEmbedding**</a>

RotaryEmbedding implements rotary positional embeddings (RoPE) as described in the paper https://arxiv.org/pdf/2104.09864.
Positions are encoded as rotation matrices that are multiplied with the query and key tensors
before their inner product is taken.

Rotary embeddings are defined by the following functions:

```python
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rope(x, cos, sin, position_ids):
    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    x_embed = (x * cos) + (rotate_half(x) * sin)
    return x_embed
```

#### Version

This version of the operator has been available since version 23 of the default ONNX operator set.

#### Attributes

<dl>
<dt><tt>interleaved</tt> : int</dt>
<dd>Rotate using interleaved pattern. Default value is 0 (False).</dd>
</dl>

#### Inputs

<dl>
<dt><tt>input</tt> : T</dt>
<dd>3D tensor with shape (batch_size, sequence_length, hidden_size) or 4D tensor with shape (batch_size, num_heads, sequence_length, head_size)</dd>
<dt><tt>position_ids</tt> : M</dt>
<dd>1D tensor with shape (1) or 2D tensor with shape (batch_size, sequence_length)</dd>
<dt><tt>cos_cache</tt> : T</dt>
<dd>2D tensor with shape (max_sequence_length, head_size / 2) or (max_sequence_length, rotary_embedding_dim / 2)</dd>
<dt><tt>sin_cache</tt> : T</dt>
<dd>2D tensor with shape (max_sequence_length, head_size / 2) or (max_sequence_length, rotary_embedding_dim / 2)</dd>
</dl>

#### Outputs

<dl>
<dt><tt>output</tt> : T</dt>
<dd>Tensor with the same shape as the input.</dd>
</dl>

#### Type Constraints

<dl>
<dt><tt>T</tt> : tensor(float), tensor(float16), tensor(bfloat16)</dt>
<dd>Constrain input and output types to float tensors.</dd>
<dt><tt>M</tt> : tensor(int64)</dt>
<dd>Constrain position_ids to int64 tensors.</dd>
</dl>
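A NumPy reference of the full operator semantics for 4D input helps clarify the `interleaved` attribute. This is a sketch under stated assumptions: the exact interleaved (adjacent even/odd) pairing below follows common RoPE implementations and is not spelled out by the schema text above.

```python
import numpy as np

def rotary_embedding_ref(x, cos_cache, sin_cache, position_ids, interleaved=0):
    """Reference for 4D input [bs, num_heads, seq_len, head_size];
    caches have shape [max_sequence_length, head_size / 2]."""
    cos = cos_cache[position_ids][:, np.newaxis, :, :]  # [bs, 1, seq_len, half]
    sin = sin_cache[position_ids][:, np.newaxis, :, :]
    half = x.shape[-1] // 2
    if interleaved:
        x1, x2 = x[..., 0::2], x[..., 1::2]    # adjacent even/odd pairs
    else:
        x1, x2 = x[..., :half], x[..., half:]  # first half / second half
    r1 = x1 * cos - x2 * sin                   # 2D rotation of each pair
    r2 = x2 * cos + x1 * sin
    out = np.empty_like(x)
    if interleaved:
        out[..., 0::2], out[..., 1::2] = r1, r2
    else:
        out[..., :half], out[..., half:] = r1, r2
    return out

rng = np.random.default_rng(1)
x = rng.standard_normal((2, 3, 4, 8))
inv_freq = 1.0 / (10000.0 ** (np.arange(4) / 4.0))
angles = np.arange(16)[:, None] * inv_freq[None, :]  # max_sequence_length = 16
cos_cache, sin_cache = np.cos(angles), np.sin(angles)
position_ids = np.tile(np.arange(4), (2, 1))
out_half = rotary_embedding_ref(x, cos_cache, sin_cache, position_ids, interleaved=0)
out_inter = rotary_embedding_ref(x, cos_cache, sin_cache, position_ids, interleaved=1)
```

Both modes are pure rotations, so the output has the input's shape, per-token norms are unchanged, and position 0 is the identity.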


### <a name="Round"></a><a name="round">**Round**</a>

Round takes one input Tensor and rounds the values, element-wise, meaning
5 changes: 4 additions & 1 deletion docs/TestCoverage.md
@@ -6,7 +6,7 @@
* [Overall Test Coverage](#overall-test-coverage)
# Node Test Coverage
## Summary
Node tests have covered 179/192 (93.23%, 5 generators excluded) common operators.
Node tests have covered 179/193 (92.75%, 5 generators excluded) common operators.

Node tests have covered 0/0 (N/A) experimental operators.

@@ -24737,6 +24737,9 @@ expect(node, inputs=[x, y], outputs=[z], name="test_xor_bcast4v4d")
### RandomUniformLike (random generator operator)


### RotaryEmbedding (call for test cases)


### SequenceAt (call for test cases)


96 changes: 96 additions & 0 deletions onnx/defs/nn/defs.cc
@@ -2832,4 +2832,100 @@ ONNX_OPERATOR_SET_SCHEMA(
schema.BuildFunction(functionProto);
return true;
}));

static const char* RotaryEmbedding_ver23_doc = R"DOC(
RotaryEmbedding implements rotary positional embeddings (RoPE) as described in the paper https://arxiv.org/pdf/2104.09864.
Positions are encoded as rotation matrices that are multiplied with the query and key tensors
before their inner product is taken.

Rotary embeddings are defined by the following functions:

def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rope(x, cos, sin, position_ids):
    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    x_embed = (x * cos) + (rotate_half(x) * sin)
    return x_embed
)DOC";

ONNX_OPERATOR_SET_SCHEMA(
RotaryEmbedding,
23,
OpSchema()
.SetDoc(RotaryEmbedding_ver23_doc)
.Attr("interleaved",
"Rotate using interleaved pattern. Default value is 0 (False).",
AttributeProto::INT,
OPTIONAL_VALUE)
.Input(0,
"input",
"3D tensor with shape (batch_size, sequence_length, hidden_size) or 4D tensor with shape (batch_size, num_heads, sequence_length, head_size)",
"T")
.Input(1,
"position_ids",
"1D tensor with shape (1) or 2D tensor with shape (batch_size, sequence_length)",
"M")
.Input(2,
"cos_cache",
"2D tensor with shape (max_sequence_length, head_size / 2) or (max_sequence_length, rotary_embedding_dim / 2)",
"T")
.Input(3,
"sin_cache",
"2D tensor with shape (max_sequence_length, head_size / 2) or (max_sequence_length, rotary_embedding_dim / 2)",
"T")
.Output(0,
"output",
"Tensor with the same shape as the input.",
"T")
.TypeConstraint("T", {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"}, "Constrain input and output types to float tensors.")
.TypeConstraint("M", {"tensor(int64)"}, "Constrain position_ids to int64 tensors.")
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
propagateElemTypeFromInputToOutput(ctx, 0, 0);
propagateShapeFromInputToOutput(ctx, 0, 0);
})
.SetContextDependentFunctionBodyBuilder(
[](const FunctionBodyBuildContext& ctx, const OpSchema& schema, FunctionProto& functionProto) {

auto mktensor = [](int64_t val) -> ONNX_NAMESPACE::TensorProto {
auto tp = ONNX_NAMESPACE::ToTensor(std::vector<int64_t>{val});
tp.add_dims(1);
return tp;
};

FunctionBuilder builder(functionProto);
builder.Add("SqueezeDims = Constant <value_ints = [0, 1]> ()")
.Add("CosCacheSqueezed = Squeeze(cos_cache, SqueezeDims)")
.Add("SinCacheSqueezed = Squeeze(sin_cache, SqueezeDims)")
.Add("UnsqueezeDims = Constant <value_ints = [0]> ()")
.Add("CosCacheGather = Gather(CosCacheSqueezed, position_ids)")
.Add("CosCacheUnsqueezed = Unsqueeze(CosCacheGather, UnsqueezeDims)")
.Add("SinCacheGather = Gather(SinCacheSqueezed, position_ids)")
.Add("SinCacheUnsqueezed = Unsqueeze(SinCacheGather, UnsqueezeDims)");

builder.Add("Shape = Shape (input)") // shape of input tensor: 1D tensor
.Add("Zero1D = Constant()", "value", mktensor(0)) // [0] : 1D tensor
.Add("Two1D = Constant()", "value", mktensor(2)) // [2] : 1D tensor
.Add("Axis1D = Constant()", "value", mktensor(-1)) // [-1] : rotate along the last axis
.Add("RotateEmbedDim = Gather(Shape, Axis1D)") // last-dim size: 1D tensor
.Add("RotateEmbedDimHalf = Div(RotateEmbedDim, Two1D)")
.Add("InputFirstHalf = Slice (input, Zero1D, RotateEmbedDimHalf, Axis1D)")
.Add("InputSecondHalf = Slice (input, RotateEmbedDimHalf, RotateEmbedDim, Axis1D)")
.Add("NegInputSecondHalf = Neg(InputSecondHalf)")
.Add("ConcatInput = Concat <axis = -1> (NegInputSecondHalf, InputFirstHalf)");

builder.Add("CosMultiplied = Mul(input, CosCacheUnsqueezed)")
.Add("SinMultiplied = Mul(ConcatInput, SinCacheUnsqueezed)")
.Add("output = Add(CosMultiplied, SinMultiplied)");

schema.BuildFunction(functionProto);
return true;
}));
} // namespace ONNX_NAMESPACE
2 changes: 2 additions & 0 deletions onnx/defs/operator_sets.h
@@ -1303,6 +1303,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 23, Loop);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 23, Pad);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 23, QuantizeLinear);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 23, Reshape);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 23, RotaryEmbedding);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 23, Scan);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 23, Shape);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 23, Size);
@@ -1326,6 +1327,7 @@ class OpSet_Onnx_ver23 {
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 23, Pad)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 23, QuantizeLinear)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 23, Reshape)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 23, RotaryEmbedding)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 23, Scan)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 23, Shape)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 23, Size)>());
