Fix schemas for skip norms

shubhambhokare1 committed Oct 17, 2024
1 parent e2761bf commit 02579ae
Showing 36 changed files with 476 additions and 222 deletions.
75 changes: 51 additions & 24 deletions docs/Changelog.md
@@ -29389,7 +29389,20 @@
This version of the operator has been available since version 23 of the default

### <a name="SkipLayerNormalization-23"></a>**SkipLayerNormalization-23**</a>

Applies LayerNormalization to an expanded skip connection, as described in the paper https://arxiv.org/pdf/2105.07205v1. The expanded skip connection is defined as follows:
```
xSkip = (scaling_factor * input) + F(input) + Bias
```
where:
- F(input) is the output of a particular layer.
- scaling_factor is a modulating scalar that adjusts the importance of the skip connection.
- Bias is a bias term added to the output of the skip connection.

LayerNorm is then applied to xSkip as follows:
```
output = LayerNormalization(xSkip)
```

#### Version

@@ -29398,36 +29411,36 @@
This version of the operator has been available since version 23 of the default
#### Attributes

<dl>
<dt><tt>axis</tt> : int (default is -1)</dt>
<dd>The dimension for layer normalization. If rank(X) is r, axis' allowed range is [-r, r). Negative value means counting dimensions from the back.</dd>
<dt><tt>epsilon</tt> : float (default is 1e-05)</dt>
<dd>The epsilon value to use to avoid division by zero.</dd>
<dt><tt>scaling_factor</tt> : int (default is 1)</dt>
<dd>Modulating scalar by which the skip input is multiplied.</dd>
</dl>
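As a small illustration of the `axis` attribute (an editor's sketch, not part of the spec text): with the default `axis=-1`, the normalization statistics are computed over the last dimension, independently for each leading index.

```python
import numpy as np

x = np.arange(12, dtype=np.float32).reshape(3, 4)
axis = -1  # default: normalize over the last dimension
mean = x.mean(axis=axis, keepdims=True)
var = x.var(axis=axis, keepdims=True)
normalized = (x - mean) / np.sqrt(var + 1e-5)
# Each row of `normalized` now has approximately zero mean and unit variance.
```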

#### Inputs (3 - 5)

<dl>
<dt><tt>X</tt> : T</dt>
<dd>The output of the layer for which the skip connection is being created. In general, the shape is (N, C, D1, D2, ... , Dn) for n-dimensional data, where D1 to Dn are the spatial dimension sizes, N is the batch size, and C is the number of channels.</dd>
<dt><tt>S</tt> : T</dt>
<dd>Skip input with same shape as X. This is the input to the layer for which the skip connection is being created.</dd>
<dt><tt>gamma</tt> : T</dt>
<dd>1D tensor representing scale input of layer normalization with shape of the spatial dimension along which layer normalization is applied.</dd>
<dt><tt>beta</tt> (optional) : T</dt>
<dd>1D tensor representing bias input of layer normalization with shape of the spatial dimension along which layer normalization is applied.</dd>
<dt><tt>B</tt> (optional) : T</dt>
<dd>1D bias tensor for the skip connection with shape of the spatial dimension along which layer normalization is applied.</dd>
</dl>

#### Outputs (1 - 2)

<dl>
<dt><tt>Y</tt> : T</dt>
<dd>Output tensor with same shape as X.</dd>
<dt><tt>InputSkipBiasSum</tt> (optional) : T</dt>
<dd>Sum of the input and skip inputs (and bias if it exists). Same shape as X.</dd>
</dl>

#### Type Constraints
@@ -29441,7 +29454,19 @@
This version of the operator has been available since version 23 of the default

### <a name="SkipRMSNormalization-23"></a>**SkipRMSNormalization-23**</a>

Applies RMSNormalization to an expanded skip connection, similarly to SkipLayerNormalization. The expanded skip connection is defined as follows:
```
xSkip = (scaling_factor * input) + F(input) + Bias
```
where:
- F(input) is the output of a particular layer.
- scaling_factor is a modulating scalar that adjusts the importance of the skip connection.
- Bias is a bias term added to the output of the skip connection.

RMSNorm is then applied to xSkip as follows:
```
output = RMSNormalization(xSkip)
```

#### Version

@@ -29450,32 +29475,34 @@
This version of the operator has been available since version 23 of the default
#### Attributes

<dl>
<dt><tt>axis</tt> : int (default is -1)</dt>
<dd>The dimension for rms normalization. If rank(X) is r, axis' allowed range is [-r, r). Negative value means counting dimensions from the back.</dd>
<dt><tt>epsilon</tt> : float (default is 1e-05)</dt>
<dd>The epsilon value to use to avoid division by zero.</dd>
<dt><tt>scaling_factor</tt> : int (default is 1)</dt>
<dd>Modulating scalar by which the skip input is multiplied.</dd>
</dl>

#### Inputs (3 - 4)

<dl>
<dt><tt>X</tt> : T</dt>
<dd>The output of the layer for which the skip connection is being created. In general, the shape is (N, C, D1, D2, ... , Dn) for n-dimensional data, where D1 to Dn are the spatial dimension sizes, N is the batch size, and C is the number of channels.</dd>
<dt><tt>S</tt> : T</dt>
<dd>Skip input with same shape as X. This is the input to the layer for which the skip connection is being created.</dd>
<dt><tt>gamma</tt> : T</dt>
<dd>1D tensor representing scale input of rms normalization with shape of the spatial dimension along which rms normalization is applied.</dd>
<dt><tt>B</tt> (optional) : T</dt>
<dd>1D bias tensor for the skip connection with shape of the spatial dimension along which rms normalization is applied.</dd>
</dl>

#### Outputs (1 - 2)

<dl>
<dt><tt>Y</tt> : T</dt>
<dd>Output tensor with same shape as X.</dd>
<dt><tt>InputSkipBiasSum</tt> (optional) : T</dt>
<dd>Sum of the input and skip inputs (and bias if it exists). Same shape as X.</dd>
</dl>

#### Type Constraints
131 changes: 98 additions & 33 deletions docs/Operators.md
@@ -30031,7 +30031,20 @@
expect(node, inputs=[x], outputs=[y], name="test_size")

### <a name="SkipLayerNormalization"></a><a name="skiplayernormalization">**SkipLayerNormalization**</a>

Applies LayerNormalization to an expanded skip connection, as described in the paper https://arxiv.org/pdf/2105.07205v1. The expanded skip connection is defined as follows:
```
xSkip = (scaling_factor * input) + F(input) + Bias
```
where:
- F(input) is the output of a particular layer.
- scaling_factor is a modulating scalar that adjusts the importance of the skip connection.
- Bias is a bias term added to the output of the skip connection.

LayerNorm is then applied to xSkip as follows:
```
output = LayerNormalization(xSkip)
```

#### Version

@@ -30040,36 +30053,36 @@
This version of the operator has been available since version 23 of the default
#### Attributes

<dl>
<dt><tt>axis</tt> : int (default is -1)</dt>
<dd>The dimension for layer normalization. If rank(X) is r, axis' allowed range is [-r, r). Negative value means counting dimensions from the back.</dd>
<dt><tt>epsilon</tt> : float (default is 1e-05)</dt>
<dd>The epsilon value to use to avoid division by zero.</dd>
<dt><tt>scaling_factor</tt> : int (default is 1)</dt>
<dd>Modulating scalar by which the skip input is multiplied.</dd>
</dl>

#### Inputs (3 - 5)

<dl>
<dt><tt>X</tt> : T</dt>
<dd>The output of the layer for which the skip connection is being created. In general, the shape is (N, C, D1, D2, ... , Dn) for n-dimensional data, where D1 to Dn are the spatial dimension sizes, N is the batch size, and C is the number of channels.</dd>
<dt><tt>S</tt> : T</dt>
<dd>Skip input with same shape as X. This is the input to the layer for which the skip connection is being created.</dd>
<dt><tt>gamma</tt> : T</dt>
<dd>1D tensor representing scale input of layer normalization with shape of the spatial dimension along which layer normalization is applied.</dd>
<dt><tt>beta</tt> (optional) : T</dt>
<dd>1D tensor representing bias input of layer normalization with shape of the spatial dimension along which layer normalization is applied.</dd>
<dt><tt>B</tt> (optional) : T</dt>
<dd>1D bias tensor for the skip connection with shape of the spatial dimension along which layer normalization is applied.</dd>
</dl>

#### Outputs (1 - 2)

<dl>
<dt><tt>Y</tt> : T</dt>
<dd>Output tensor with same shape as X.</dd>
<dt><tt>InputSkipBiasSum</tt> (optional) : T</dt>
<dd>Sum of the input and skip inputs (and bias if it exists). Same shape as X.</dd>
</dl>

#### Type Constraints
@@ -30093,18 +30106,20 @@
skip = np.random.randn(4, 2).astype(np.float32)
gamma = np.random.randn(2).astype(np.float32)
beta = np.random.randn(2).astype(np.float32)
bias = np.random.randn(2).astype(np.float32)
y, input_skip_bias_sum = _skip_layer_normalization(x, skip, gamma, beta, bias)
y = y.astype(np.float32)
input_skip_bias_sum = input_skip_bias_sum.astype(np.float32)

node = onnx.helper.make_node(
"SkipLayerNormalization",
inputs=["x", "skip", "gamma", "beta", "bias"],
outputs=["y", "input_skip_bias_sum"],
)

expect(
node,
inputs=[x, skip, gamma, beta, bias],
outputs=[y, input_skip_bias_sum],
name="test_skip_layer_normalization_2d_example",
)
```
@@ -30121,18 +30136,20 @@
skip = np.random.randn(3, 4, 2).astype(np.float32)
gamma = np.random.randn(2).astype(np.float32)
beta = np.random.randn(2).astype(np.float32)
bias = np.random.randn(2).astype(np.float32)
y, input_skip_bias_sum = _skip_layer_normalization(x, skip, gamma, beta, bias)
y = y.astype(np.float32)
input_skip_bias_sum = input_skip_bias_sum.astype(np.float32)

node = onnx.helper.make_node(
"SkipLayerNormalization",
inputs=["x", "skip", "gamma", "beta", "bias"],
outputs=["y", "input_skip_bias_sum"],
)

expect(
node,
inputs=[x, skip, gamma, beta, bias],
outputs=[y, input_skip_bias_sum],
name="test_skip_layer_normalization_3d_example",
)
```
@@ -30150,29 +30167,75 @@
gamma = np.random.randn(2).astype(np.float32)
beta = np.random.randn(2).astype(np.float32)
bias = np.random.randn(2).astype(np.float32)
epsilon = 1e-2
y, input_skip_bias_sum = _skip_layer_normalization(x, skip, gamma, beta, bias, epsilon=epsilon)
y = y.astype(np.float32)
input_skip_bias_sum = input_skip_bias_sum.astype(np.float32)

node = onnx.helper.make_node(
"SkipLayerNormalization",
inputs=["x", "skip", "gamma", "beta", "bias"],
outputs=["y", "input_skip_bias_sum"],
epsilon=epsilon,
)

expect(
node,
inputs=[x, skip, gamma, beta, bias],
outputs=[y, input_skip_bias_sum],
name="test_skip_layer_normalization_epsilon_example",
)
```

</details>


<details>
<summary>scaling_factor</summary>

```python
x = np.random.randn(3, 4, 2).astype(np.float32)
skip = np.random.randn(3, 4, 2).astype(np.float32)
gamma = np.random.randn(2).astype(np.float32)
beta = np.random.randn(2).astype(np.float32)
bias = np.random.randn(2).astype(np.float32)
scaling_factor = 3
y, input_skip_bias_sum = _skip_layer_normalization(x, skip, gamma, beta, bias, scaling_factor=scaling_factor)
y = y.astype(np.float32)
input_skip_bias_sum = input_skip_bias_sum.astype(np.float32)

node = onnx.helper.make_node(
"SkipLayerNormalization",
inputs=["x", "skip", "gamma", "beta", "bias"],
outputs=["y", "input_skip_bias_sum"],
scaling_factor=scaling_factor,
)

expect(
node,
inputs=[x, skip, gamma, beta, bias],
outputs=[y, input_skip_bias_sum],
name="test_skip_layer_normalization_scaling_factor_example",
)
```

</details>


### <a name="SkipRMSNormalization"></a><a name="skiprmsnormalization">**SkipRMSNormalization**</a>

Applies RMSNormalization to an expanded skip connection, similarly to SkipLayerNormalization. The expanded skip connection is defined as follows:
```
xSkip = (scaling_factor * input) + F(input) + Bias
```
where:
- F(input) is the output of a particular layer.
- scaling_factor is a modulating scalar that adjusts the importance of the skip connection.
- Bias is a bias term added to the output of the skip connection.

RMSNorm is then applied to xSkip as follows:
```
output = RMSNormalization(xSkip)
```

#### Version

@@ -30181,32 +30244,34 @@
This version of the operator has been available since version 23 of the default
#### Attributes

<dl>
<dt><tt>axis</tt> : int (default is -1)</dt>
<dd>The dimension for rms normalization. If rank(X) is r, axis' allowed range is [-r, r). Negative value means counting dimensions from the back.</dd>
<dt><tt>epsilon</tt> : float (default is 1e-05)</dt>
<dd>The epsilon value to use to avoid division by zero.</dd>
<dt><tt>scaling_factor</tt> : int (default is 1)</dt>
<dd>Modulating scalar by which the skip input is multiplied.</dd>
</dl>

#### Inputs (3 - 4)

<dl>
<dt><tt>X</tt> : T</dt>
<dd>The output of the layer for which the skip connection is being created. In general, the shape is (N, C, D1, D2, ... , Dn) for n-dimensional data, where D1 to Dn are the spatial dimension sizes, N is the batch size, and C is the number of channels.</dd>
<dt><tt>S</tt> : T</dt>
<dd>Skip input with same shape as X. This is the input to the layer for which the skip connection is being created.</dd>
<dt><tt>gamma</tt> : T</dt>
<dd>1D tensor representing scale input of rms normalization with shape of the spatial dimension along which rms normalization is applied.</dd>
<dt><tt>B</tt> (optional) : T</dt>
<dd>1D bias tensor for the skip connection with shape of the spatial dimension along which rms normalization is applied.</dd>
</dl>

#### Outputs (1 - 2)

<dl>
<dt><tt>Y</tt> : T</dt>
<dd>Output tensor with same shape as X.</dd>
<dt><tt>InputSkipBiasSum</tt> (optional) : T</dt>
<dd>Sum of the input and skip inputs (and bias if it exists). Same shape as X.</dd>
</dl>

#### Type Constraints
