Commit
avoid double reshape logic
shubhambhokare1 committed Oct 24, 2024
1 parent c5ef395 commit 6d55f6a
Showing 80 changed files with 36 additions and 77 deletions.
11 changes: 3 additions & 8 deletions docs/Changelog.md
Expand Up @@ -28979,8 +28979,7 @@ This version of the operator has been available since version 23 of the default
### <a name="RMSNormalization-23"></a>**RMSNormalization-23**</a>

This is RMS normalization, defined in ONNX as a function, as described in the paper https://arxiv.org/pdf/1910.07467.
The overall computation can be split into two stages. The first stage is standardization, which makes the
normalized elements have zero mean and unit variances. The root mean squared norm is taken over the last D dimensions,
The overall computation can be split into two stages. The root mean squared norm is taken over the last D dimensions,
where D is the dimension of normalized_shape. For example, if normalized_shape is (3, 5) (a 2-dimensional shape),
the rms norm is computed over the last 2 dimensions of the input. The computation required by standardization can be
described by the following equations.
Expand All @@ -28993,7 +28992,7 @@ This version of the operator has been available since version 23 of the default
Normalized = Div(X, SqrtRMS)
```
where `normalized_axes` is `[axis, ..., rank of X - 1]`. The variable `RMS` stands for root mean square.
The second stage then scales and shifts the outcome of the first stage using:
The second stage then scales the outcome of the first stage using:
```
Y= Mul(Normalized, Scale)
```
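For illustration, a minimal NumPy sketch of the two stages exactly as written above (the function name and argument defaults here are assumptions, not part of the spec text):
```
import numpy as np

def rms_normalization_sketch(X, scale, axis=-1, epsilon=1e-5):
    # Stage one: root mean squared norm over the trailing axes [axis, ..., rank - 1].
    rank = X.ndim
    axis = axis if axis >= 0 else axis + rank
    reduce_axes = tuple(range(axis, rank))
    x_squared_mean = np.mean(np.power(X, 2), axis=reduce_axes, keepdims=True)
    rms = np.sqrt(x_squared_mean)          # RMS
    sqrt_rms = np.sqrt(rms + epsilon)      # SqrtRMS (epsilon avoids division by zero)
    normalized = X / sqrt_rms              # Normalized = Div(X, SqrtRMS)
    # Stage two: scale the normalized outcome.
    return normalized * scale              # Y = Mul(Normalized, Scale)
```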
Expand All @@ -29015,8 +29014,6 @@ This version of the operator has been available since version 23 of the default
<dd>The first normalization dimension: normalization will be performed along dimensions axis : rank(inputs).</dd>
<dt><tt>epsilon</tt> : float (default is 1e-05)</dt>
<dd>The epsilon value to use to avoid division by zero.</dd>
<dt><tt>stash_type</tt> : int (default is 1)</dt>
<dd>Type of Mean and InvStdDev. This also specifies stage one's computation precision.</dd>
</dl>

#### Inputs
Expand All @@ -29025,7 +29022,7 @@ This version of the operator has been available since version 23 of the default
<dt><tt>X</tt> : T</dt>
<dd>The input tensor to be normalized. In general, the shape is (N, C, D1, D2, ... , Dn) for n-dimensional data, where D1 to Dn are the spatial dimension sizes, N is the batch size, and C is the number of channels. The root mean squared norm is taken over the last D dimensions, where D is determined by the axis attribute.</dd>
<dt><tt>scale</tt> : V</dt>
<dd>Scale tensor.</dd>
<dd>Scale tensor. Shape is the normalized shape ([axis, ..., Dn]) or a scalar (which will be broadcast to the normalized shape).</dd>
</dl>

#### Outputs
Expand All @@ -29040,8 +29037,6 @@ This version of the operator has been available since version 23 of the default
<dl>
<dt><tt>T</tt> : tensor(float16), tensor(float), tensor(double), tensor(bfloat16)</dt>
<dd>Constrain input X type to float tensors.</dd>
<dt><tt>U</tt> : tensor(float), tensor(double)</dt>
<dd>Constrain mean and inv_std_var to be float tensors.</dd>
<dt><tt>V</tt> : tensor(float16), tensor(float), tensor(double), tensor(bfloat16)</dt>
<dd>Constrain output Y and scale type to float tensors.</dd>
</dl>
11 changes: 3 additions & 8 deletions docs/Operators.md
Expand Up @@ -21104,8 +21104,7 @@ expect(
### <a name="RMSNormalization"></a><a name="rmsnormalization">**RMSNormalization**</a>

This is RMS normalization, defined in ONNX as a function, as described in the paper https://arxiv.org/pdf/1910.07467.
The overall computation can be split into two stages. The first stage is standardization, which makes the
normalized elements have zero mean and unit variances. The root mean squared norm is taken over the last D dimensions,
The overall computation can be split into two stages. The root mean squared norm is taken over the last D dimensions,
where D is the dimension of normalized_shape. For example, if normalized_shape is (3, 5) (a 2-dimensional shape),
the rms norm is computed over the last 2 dimensions of the input. The computation required by standardization can be
described by the following equations.
Expand All @@ -21118,7 +21117,7 @@ expect(
Normalized = Div(X, SqrtRMS)
```
where `normalized_axes` is `[axis, ..., rank of X - 1]`. The variable `RMS` stands for root mean square.
The second stage then scales and shifts the outcome of the first stage using:
The second stage then scales the outcome of the first stage using:
```
Y= Mul(Normalized, Scale)
```
Expand All @@ -21140,8 +21139,6 @@ This version of the operator has been available since version 23 of the default
<dd>The first normalization dimension: normalization will be performed along dimensions axis : rank(inputs).</dd>
<dt><tt>epsilon</tt> : float (default is 1e-05)</dt>
<dd>The epsilon value to use to avoid division by zero.</dd>
<dt><tt>stash_type</tt> : int (default is 1)</dt>
<dd>Type of Mean and InvStdDev. This also specifies stage one's computation precision.</dd>
</dl>

#### Inputs
Expand All @@ -21150,7 +21147,7 @@ This version of the operator has been available since version 23 of the default
<dt><tt>X</tt> : T</dt>
<dd>The input tensor to be normalized. In general, the shape is (N, C, D1, D2, ... , Dn) for n-dimensional data, where D1 to Dn are the spatial dimension sizes, N is the batch size, and C is the number of channels. The root mean squared norm is taken over the last D dimensions, where D is determined by the axis attribute.</dd>
<dt><tt>scale</tt> : V</dt>
<dd>Scale tensor.</dd>
<dd>Scale tensor. Shape is the normalized shape ([axis, ..., Dn]) or a scalar (which will be broadcast to the normalized shape).</dd>
</dl>
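As a rough sketch of how these inputs line up (the shapes, tensor names, and use of onnx.helper below are illustrative assumptions, not taken from this operator page): for an X of shape (2, 3, 5) with axis = 1, scale may have shape (3, 5) or be a scalar.
```
import numpy as np
from onnx import helper

# Illustrative node: normalize over the trailing dimensions starting at axis 1.
node = helper.make_node(
    "RMSNormalization",
    inputs=["X", "scale"],
    outputs=["Y"],
    axis=1,
    epsilon=1e-5,
)

X = np.random.randn(2, 3, 5).astype(np.float32)
scale = np.ones((3, 5), dtype=np.float32)  # a scalar scale would also broadcast
```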

#### Outputs
Expand All @@ -21165,8 +21162,6 @@ This version of the operator has been available since version 23 of the default
<dl>
<dt><tt>T</tt> : tensor(float16), tensor(float), tensor(double), tensor(bfloat16)</dt>
<dd>Constrain input X type to float tensors.</dd>
<dt><tt>U</tt> : tensor(float), tensor(double)</dt>
<dd>Constrain mean and inv_std_var to be float tensors.</dd>
<dt><tt>V</tt> : tensor(float16), tensor(float), tensor(double), tensor(bfloat16)</dt>
<dd>Constrain output Y and scale type to float tensors.</dd>
</dl>
22 changes: 7 additions & 15 deletions onnx/backend/test/case/node/rmsnormalization.py
Expand Up @@ -20,26 +20,18 @@ def _rms_normalization(X, W, axis=-1, epsilon=1e-5): # type: ignore
# which means the last axis.
axis = axis + rank

# Parameters used to convert N-D tensor RMS normalization to equivalent 2-D matrix operations.
row_number = np.prod(shape[:axis]).astype(np.int64)
col_number = np.prod(shape[axis:]).astype(np.int64)

# After reshaping input tensor X into a matrix,
# RMS normalization is equivalent to conducting
# standardization on each column vector (s.t. each
# column has zero mean and unit variance).
x_mat = np.reshape(X, (row_number, col_number))
# This computes RMS for every x_mat's column.
x_squared = np.power(x_mat, 2)
x_squared_mean = np.sum(x_squared, axis=1, keepdims=True) / col_number
x_squared = np.power(X, 2)
x_squared_mean = np.mean(x_squared, axis=tuple(range(axis, len(shape))), keepdims=True)
rms = np.sqrt(x_squared_mean)
# epsilon adjustment to avoid divide-by-zero.
rms_plus_epsilon = rms + epsilon
rms_plus_epsilon_sqrt = np.sqrt(rms_plus_epsilon)
rms_reciprocal = np.reciprocal(rms_plus_epsilon_sqrt)
# Standardization step. y_mat is zero-mean and unit-variance.
y_mat = x_mat * rms_reciprocal
# Apply affine transform on normalization outcome. W is linear coefficient.
Y = np.reshape(y_mat, shape) * W

y_mat = X * rms_reciprocal
# W is linear coefficient.
Y = y_mat * W

return Y
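A quick sanity-check call of the reference helper above; this snippet is illustrative only, assumes `_rms_normalization` is in scope, and is not part of the diff:
```
import numpy as np

# Normalize a (2, 3, 4) tensor over its last two axes (axis=1) with a matching scale.
X = np.random.randn(2, 3, 4).astype(np.float32)
W = np.ones((3, 4), dtype=np.float32)
Y = _rms_normalization(X, W, axis=1)
assert Y.shape == X.shape
```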

Binary files not shown (the remaining changed files are binary; their contents are not rendered here).
51 changes: 13 additions & 38 deletions onnx/defs/nn/defs.cc
Expand Up @@ -2835,8 +2835,7 @@ ONNX_OPERATOR_SET_SCHEMA(

static const char* RMSNormalization_ver23_doc = R"DOC(
This is RMS normalization, defined in ONNX as a function, as described in the paper https://arxiv.org/pdf/1910.07467.
The overall computation can be split into two stages. The first stage is standardization, which makes the
normalized elements have zero mean and unit variances. The root mean squared norm is taken over the last D dimensions,
The overall computation can be split into two stages. The root mean squared norm is taken over the last D dimensions,
where D is the dimension of normalized_shape. For example, if normalized_shape is (3, 5) (a 2-dimensional shape),
the rms norm is computed over the last 2 dimensions of the input. The computation required by standardization can be
described by the following equations.
Expand All @@ -2849,7 +2848,7 @@ static const char* RMSNormalization_ver23_doc = R"DOC(
Normalized = Div(X, SqrtRMS)
```
where `normalized_axes` is `[axis, ..., rank of X - 1]`. The variable `RMS` stands for root mean square.
The second stage then scales and shifts the outcome of the first stage using:
The second stage then scales the outcome of the first stage using:
```
Y= Mul(Normalized, Scale)
```
Expand All @@ -2874,11 +2873,6 @@ ONNX_OPERATOR_SET_SCHEMA(
"epsilon",
"The epsilon value to use to avoid division by zero.",
AttributeProto::FLOAT, 1e-5f)
.Attr(
"stash_type",
"Type of Mean and InvStdDev. This also specifies stage one's computation precision.",
AttributeProto::INT,
static_cast<int64_t>(ONNX_NAMESPACE::TensorProto_DataType_FLOAT))
.AllowUncheckedAttributes()
.Input(0,
"X",
Expand All @@ -2889,7 +2883,8 @@ ONNX_OPERATOR_SET_SCHEMA(
"T")
.Input(1,
"scale",
"Scale tensor.",
"Scale tensor. Shape is the normalized shape ([axis, .., Dn]) or a scalar (which will be broadcasted to "
"the normalized shape.",
"V")
.Output(0,
"Y",
Expand All @@ -2899,10 +2894,6 @@ ONNX_OPERATOR_SET_SCHEMA(
"T",
{"tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"},
"Constrain input X type to float tensors.")
.TypeConstraint(
"U",
{"tensor(float)", "tensor(double)"},
"Constrain mean and inv_std_var to be float tensors.")
.TypeConstraint(
"V",
{"tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"},
Expand Down Expand Up @@ -2960,44 +2951,28 @@ ONNX_OPERATOR_SET_SCHEMA(
tp.add_dims(1);
return tp;
};
// The treatment of "axis" is different in "RMSNormalization" and in Reduction operations.
// This complicates the function definition, requiring reshaping inputs/outputs.
// Input X shape: [d[0], ..., d[axis-1], d[axis], ..., d[rank-1]]
// This is treated as a 2D shape [d[0] * ... * d[axis-1], d[axis] * ... * d[rank-1]]
// Normalization is applied to the second dimension.
// Output Y has same shape as X
// Outputs InvStdDev have shape: [d[0], ..., d[axis-1], 1, ..., 1]

FunctionBuilder builder(functionProto);
builder.Const("FloatEpsilon", ToTensor<float>(epsilon))
.Add("Epsilon = Cast (FloatEpsilon)", "to", U)
.Add("XShape = Shape (X)") // shape of input tensor: 1D tensor
.Add("Rank = Size (XShape)") // rank of input tensor: scalar
.Add("Zero1D = Constant()", "value", mktensor(0)) // [0] : 1D tensor
.Add("Axis1D = Constant()", "value", mktensor(axis)) // [axis] : 1D tensor
.Add("PrefixShape = Slice (XShape, Zero1D, Axis1D)") // [d[0], ..., d[axis-1]]
.Add(
axis >= 0 // number of axes that are reduced =
? "NumReducedAxes = Sub (Rank, Axis1D)" // [rank - axis]: 1D tensor
: "NumReducedAxes = Neg (Axis1D)") // [-axis] : 1D tensor
.Add(
"SuffixShape = ConstantOfShape (NumReducedAxes)",
"value",
mktensor(1)) // [1, ..., 1] for reduced axes
.Add("ReducedShape = Concat <axis = 0> (PrefixShape, SuffixShape)") // [d[0], ..., d[axis-1], 1, ..., 1]
.Add("X2D = Flatten (X)", "axis", axis)
.Add("XU = Cast (X2D)", "to", U);
builder.Add("Axes_1 = Constant()", "value", mktensor(1))
.Add("XSquared = Mul (XU, XU)")
.Add("XSquaredMean = ReduceMean (XSquared, Axes_1)")
? "PosAxis1D = Identity (Axis1D)" // [axis]: 1D tensor
: "PosAxis1D = Sub (Rank, Axis1D)") // [rank - axis] : 1D tensor
.Add("ReduceAxes = Range(PosAxis1D, Rank)")
.Add("XU = Cast (X)", "to", U);
builder.Add("XSquared = Mul (XU, XU)")
.Add("XSquaredMean = ReduceMean (XSquared, ReduceAxes)")
.Add("RMS = Sqrt (XSquaredMean)")
.Add("RMSPlusEpsilon = Add (RMS, Epsilon)")
.Add("SqrtRMS = Sqrt (RMSPlusEpsilon)")
.Add("Normalized = Div (XU, SqrtRMS)")
.Add("NormalizedT = Cast (Normalized)", "to", T)
.Add("Scale2D = Flatten <axis = 0> (Scale)")
.Add("Scaled = Mul (NormalizedT, Scale2D)");
.Add("NormalizedT = Cast (Normalized)", "to", T);
builder.Add("Y = Mul (NormalizedT, Scale)");

builder.Add("Y = Reshape (Scaled, XShape)");
schema.BuildFunction(functionProto);
return true;
}));
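To make the avoided double reshape concrete, here is a rough NumPy sketch (my own assumption, not code from defs.cc) contrasting the old flatten-and-reshape reduction with the new reduce-over-trailing-axes reduction used by the function body above:
```
import numpy as np

def mean_square_old(X, axis):
    # Old path: flatten to 2-D, reduce along columns, then reshape back.
    rows = int(np.prod(X.shape[:axis]))
    cols = int(np.prod(X.shape[axis:]))
    x2d = X.reshape(rows, cols)
    mean_sq = np.mean(np.square(x2d), axis=1, keepdims=True)
    return mean_sq.reshape(X.shape[:axis] + (1,) * (X.ndim - axis))

def mean_square_new(X, axis):
    # New path: reduce directly over the trailing axes, no reshapes needed.
    reduce_axes = tuple(range(axis, X.ndim))
    return np.mean(np.square(X), axis=reduce_axes, keepdims=True)

X = np.random.randn(2, 3, 5)
assert np.allclose(mean_square_old(X, 1), mean_square_new(X, 1))
```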