[DML EP] Add NCHW and float16 gamma/beta support for GroupNorm (#16814)
This removes transposes that are not needed in the DML kernel. To keep
backward compatibility, the default behavior is NHWC when the attribute
is not set.
PatriceVignola authored Jul 26, 2023
1 parent 39fca22 commit 6499301
Showing 13 changed files with 587 additions and 121 deletions.
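
By way of illustration (not part of this commit), a minimal Python sketch of a model that opts into the new NCHW path via the `channels_last` attribute; all tensor names and shapes are hypothetical:

```python
# Hypothetical example: a com.microsoft GroupNorm node using the new
# channels_last=0 (NCHW) path with float16 gamma/beta.
from onnx import TensorProto, helper

node = helper.make_node(
    "GroupNorm",
    inputs=["X", "gamma", "beta"],
    outputs=["Y"],
    domain="com.microsoft",
    activation=0,     # required: 0 = None, 1 = Swish
    groups=32,        # required
    epsilon=1e-5,
    channels_last=0,  # new attribute; defaults to 1 (NHWC) for backward compatibility
)

graph = helper.make_graph(
    [node],
    "group_norm_nchw",
    # X is (N x C x H x W) when channels_last is 0; gamma/beta are 1-D of length C.
    [helper.make_tensor_value_info("X", TensorProto.FLOAT16, [2, 320, 64, 64]),
     helper.make_tensor_value_info("gamma", TensorProto.FLOAT16, [320]),
     helper.make_tensor_value_info("beta", TensorProto.FLOAT16, [320])],
    [helper.make_tensor_value_info("Y", TensorProto.FLOAT16, [2, 320, 64, 64])],
)
model = helper.make_model(
    graph,
    opset_imports=[helper.make_opsetid("", 17), helper.make_opsetid("com.microsoft", 1)],
)
```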
6 changes: 4 additions & 2 deletions docs/ContribOperators.md
@@ -2176,6 +2176,8 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
<dl>
<dt><tt>activation</tt> : int (required)</dt>
<dd>Activation after group normalization: 0 for None, 1 for Swish</dd>
<dt><tt>channels_last</tt> : int</dt>
<dd>1 if the input and output are in the NHWC layout, 0 if they are in the NCHW layout. Defaults to 1.</dd>
<dt><tt>epsilon</tt> : float</dt>
<dd>The epsilon value to use to avoid division by zero</dd>
<dt><tt>groups</tt> : int (required)</dt>
@@ -2186,7 +2188,7 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.

<dl>
<dt><tt>X</tt> : T</dt>
<dd>Input data tensor. Dimensions are (N x H x W x C), where N is the batch size, C is the number of channels, and H and W are the height and width of the data</dd>
<dd>Input data tensor. Dimensions are (N x H x W x C) when channels_last is 1 or (N x C x H x W) otherwise, where N is the batch size, C is the number of channels, and H and W are the height and width of the data</dd>
<dt><tt>gamma</tt> : M</dt>
<dd>1D gamma tensor for normalization with shape (C), where C is number of channels</dd>
<dt><tt>beta</tt> : M</dt>
@@ -2205,7 +2207,7 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
<dl>
<dt><tt>T</tt> : tensor(float16), tensor(float)</dt>
<dd>Constrain input X and output Y types to float tensors.</dd>
<dt><tt>M</tt> : tensor(float)</dt>
<dt><tt>M</tt> : tensor(float16), tensor(float)</dt>
<dd>Constrain gamma and beta to float tensors.</dd>
</dl>
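
To make the layout semantics above concrete, here is a rough NumPy reference of the computation this document describes (an illustrative sketch, not taken from the kernels; `activation=1` is Swish, i.e. `y * sigmoid(y)`):

```python
# Assumed GroupNorm semantics, supporting both layouts via channels_last.
import numpy as np

def group_norm_ref(x, gamma, beta, groups, epsilon=1e-5, channels_last=True, swish=False):
    if channels_last:
        x = np.transpose(x, (0, 3, 1, 2))          # (N, H, W, C) -> (N, C, H, W)
    n, c, h, w = x.shape
    g = x.reshape(n, groups, c // groups, h * w)   # split channels into groups
    mean = g.mean(axis=(2, 3), keepdims=True)      # per-(batch, group) statistics
    var = g.var(axis=(2, 3), keepdims=True)
    y = ((g - mean) / np.sqrt(var + epsilon)).reshape(n, c, h, w)
    y = y * gamma.reshape(1, c, 1, 1) + beta.reshape(1, c, 1, 1)
    if swish:                                      # activation == 1
        y = y * (1.0 / (1.0 + np.exp(-y)))
    if channels_last:
        y = np.transpose(y, (0, 2, 3, 1))          # back to (N, H, W, C)
    return y
```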

7 changes: 7 additions & 0 deletions onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc
@@ -66,6 +66,8 @@ GroupNorm::GroupNorm(const OpKernelInfo& op_info) : CudaKernel(op_info) {
ORT_ENFORCE(op_info.GetAttr("activation", &activation).IsOK());
ORT_ENFORCE(activation == 0 || activation == 1); // 0 is None, 1 is Swish
use_swish_activation_ = (activation == 1);

channels_last_ = (op_info.GetAttrOrDefault<int64_t>("channels_last", static_cast<int64_t>(1)) != 0);
}

Status GroupNorm::ComputeInternal(OpKernelContext* context) const {
@@ -74,6 +76,11 @@ Status GroupNorm::ComputeInternal(OpKernelContext* context) const {
const Tensor* beta = context->Input<Tensor>(2);
Tensor* output = context->Output(0, input->Shape());

if (!channels_last_) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"only the channels_last layout is supported");
}

const auto& input_dims = input->Shape().GetDims();
if (input_dims.size() != 4) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
1 change: 1 addition & 0 deletions onnxruntime/contrib_ops/cuda/diffusion/group_norm.h
@@ -20,6 +20,7 @@ class GroupNorm final : public CudaKernel {
bool use_swish_activation_;
float epsilon_;
int num_groups_;
bool channels_last_;
};

} // namespace cuda
7 changes: 7 additions & 0 deletions onnxruntime/contrib_ops/rocm/diffusion/group_norm.cc
@@ -68,6 +68,8 @@ GroupNorm::GroupNorm(const OpKernelInfo& op_info) : RocmKernel(op_info) {
ORT_ENFORCE(op_info.GetAttr("activation", &activation).IsOK());
ORT_ENFORCE(activation == 0 || activation == 1); // 0 is None, 1 is Swish
use_swish_activation_ = (activation == 1);

channels_last_ = (op_info.GetAttrOrDefault<int64_t>("channels_last", static_cast<int64_t>(1)) != 0);
}

Status GroupNorm::ComputeInternal(OpKernelContext* context) const {
@@ -76,6 +78,11 @@ Status GroupNorm::ComputeInternal(OpKernelContext* context) const {
const Tensor* beta = context->Input<Tensor>(2);
Tensor* output = context->Output(0, input->Shape());

if (!channels_last_) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"only the channels_last layout is supported");
}

const auto& input_dims = input->Shape().GetDims();
if (input_dims.size() != 4) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
8 changes: 6 additions & 2 deletions onnxruntime/core/graph/contrib_ops/diffusion_defs.cc
@@ -44,9 +44,13 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
.Attr("activation",
"Activation after group normalization: 0 for None, 1 for Swish",
AttributeProto::INT)
.Attr("channels_last",
"1 if the input and output are in the NHWC layout, 0 if it is in the NCHW layout. Defaults to 1.",
AttributeProto::INT,
static_cast<int64_t>(1))
.Input(0,
"X",
"Input data tensor. Dimensions are (N x H x W x C), where N is the batch size, C is the number of channels, and H and W are the height and width of the data",
"Input data tensor. Dimensions are (N x H x W x C) when channels_last is 1 or (N x C x H x W) otherwise, where N is the batch size, C is the number of channels, and H and W are the height and width of the data",
"T")
.Input(1,
"gamma",
@@ -61,7 +65,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
"The output tensor of the same shape as X",
"T")
.TypeConstraint("T", {"tensor(float16)", "tensor(float)"}, "Constrain input X and output Y types to float tensors.")
.TypeConstraint("M", {"tensor(float)"}, "Constrain gamma and beta to float tensors.")
.TypeConstraint("M", {"tensor(float16)", "tensor(float)"}, "Constrain gamma and beta to float tensors.")
.TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput));

constexpr const char* BiasSplitGelu_ver1_doc = R"DOC(
@@ -19,6 +19,7 @@ class DmlOperatorGroupNorm : public DmlOperator
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();

const bool channelsLast = kernelCreationContext.GetOptionalAttribute<bool>(AttrName::ChannelsLast, true);
const float epsilon = kernelCreationContext.GetOptionalAttribute<float>(AttrName::Epsilon, DefaultEpsilon);
const bool activation = gsl::narrow_cast<bool>(kernelCreationContext.GetAttribute<int64_t>(AttrName::Activation));
const uint32_t groups = gsl::narrow_cast<uint32_t>(kernelCreationContext.GetAttribute<int64_t>(AttrName::Groups));
@@ -36,19 +37,32 @@

// Data is in NHWC format when channelsLast is true, NCHW otherwise
const uint32_t batch = inputTensorShape[0];
const uint32_t height = inputTensorShape[1];
const uint32_t width = inputTensorShape[2];
const uint32_t channels = inputTensorShape[3];
const uint32_t height = channelsLast ? inputTensorShape[1] : inputTensorShape[2];
const uint32_t width = channelsLast ? inputTensorShape[2] : inputTensorShape[3];
const uint32_t channels = channelsLast ? inputTensorShape[3] : inputTensorShape[1];
ML_CHECK_VALID_ARGUMENT(gammaTensorShape[0] == channels);
ML_CHECK_VALID_ARGUMENT(betaTensorShape[0] == channels);
ML_CHECK_VALID_ARGUMENT(channels % groups == 0);
ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[1].GetDmlDataType() == m_inputTensorDescs[2].GetDmlDataType());
const uint32_t channelsPerGroup = channels / groups;

// 1. Reshape the input from [batch, height, width, channels] to [batch, height * width, groups, channelsPerGroup]
// 2. Stride the reshaped input from [batch, height * width, groups, channelsPerGroup] to [batch, groups, channelsPerGroup, height * width]
const std::array<uint32_t, 4> inputShape = {batch, groups, channelsPerGroup, height * width};
const std::array<uint32_t, 4> inputStrides = {channelsPerGroup * height * width * groups, channelsPerGroup, 1, groups * channelsPerGroup};
std::array<uint32_t, 4> inputShape;
std::array<uint32_t, 4> inputStrides;

if (channelsLast)
{
// 1. Reshape the input from [batch, height, width, channels] to [batch, height * width, groups, channelsPerGroup]
// 2. Stride the reshaped input from [batch, height * width, groups, channelsPerGroup] to [batch, groups, channelsPerGroup, height * width]
inputShape = {batch, groups, channelsPerGroup, height * width};
inputStrides = {channelsPerGroup * height * width * groups, channelsPerGroup, 1, groups * channelsPerGroup};
}
else
{
// Reshape the input from [batch, channels, height, width] to [batch, groups, channelsPerGroup, height * width]
inputShape = {batch, groups, channelsPerGroup, height * width};
inputStrides = {groups * channelsPerGroup * height * width, channelsPerGroup * height * width, height * width, 1};
}

TensorDesc inputTensorDesc = TensorDesc(m_inputTensorDescs[0].GetDmlDataType(), inputShape, inputStrides);
const DML_TENSOR_DESC inputDmlTensorDesc = inputTensorDesc.GetDmlDesc();
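
The channels-last stride trick above can be sanity-checked outside DML; the following NumPy sketch (added for illustration, not part of the commit) confirms that those element strides view NHWC data as [batch, groups, channelsPerGroup, height * width] without copying:

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided

batch, height, width, groups, cpg = 2, 4, 4, 8, 40   # channels = groups * cpg
channels = groups * cpg
x = np.arange(batch * height * width * channels, dtype=np.float32)  # flat NHWC buffer

# Element strides from the channels-last branch, converted to bytes for NumPy.
shape = (batch, groups, cpg, height * width)
strides = (cpg * height * width * groups, cpg, 1, groups * cpg)
view = as_strided(x, shape=shape, strides=tuple(s * x.itemsize for s in strides))

# Same result via explicit reshape + transpose of the NHWC buffer.
expected = x.reshape(batch, height * width, groups, cpg).transpose(0, 2, 3, 1)
assert np.array_equal(view, expected)
```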

@@ -119,24 +133,36 @@ class DmlOperatorGroupNorm : public DmlOperator

uint32_t currentNodeIndex = 0;

const uint32_t transposeInputIndex = currentNodeIndex++;
opDescs.push_back(&dmlTransposeInputDesc);

const uint32_t mvnNodeIndex = currentNodeIndex++;
opDescs.push_back(&dmlMvnDesc);

DML_INPUT_GRAPH_EDGE_DESC inputEdge{};
inputEdge.GraphInputIndex = 0;
inputEdge.ToNodeIndex = transposeInputIndex;
inputEdge.ToNodeInputIndex = 0;
inputEdges.push_back(inputEdge);

DML_INTERMEDIATE_GRAPH_EDGE_DESC transposeInputToMvnEdge = {};
transposeInputToMvnEdge.FromNodeIndex = transposeInputIndex;
transposeInputToMvnEdge.FromNodeOutputIndex = 0;
transposeInputToMvnEdge.ToNodeIndex = mvnNodeIndex;
transposeInputToMvnEdge.ToNodeInputIndex = 0;
intermediateEdges.push_back(transposeInputToMvnEdge);
// We only need to transpose the input when the layout is NHWC
if (channelsLast)
{
const uint32_t transposeInputIndex = currentNodeIndex++;
opDescs.push_back(&dmlTransposeInputDesc);

DML_INPUT_GRAPH_EDGE_DESC inputEdge{};
inputEdge.GraphInputIndex = 0;
inputEdge.ToNodeIndex = transposeInputIndex;
inputEdge.ToNodeInputIndex = 0;
inputEdges.push_back(inputEdge);

DML_INTERMEDIATE_GRAPH_EDGE_DESC transposeInputToMvnEdge = {};
transposeInputToMvnEdge.FromNodeIndex = transposeInputIndex;
transposeInputToMvnEdge.FromNodeOutputIndex = 0;
transposeInputToMvnEdge.ToNodeIndex = mvnNodeIndex;
transposeInputToMvnEdge.ToNodeInputIndex = 0;
intermediateEdges.push_back(transposeInputToMvnEdge);
}
else
{
DML_INPUT_GRAPH_EDGE_DESC inputEdge{};
inputEdge.GraphInputIndex = 0;
inputEdge.ToNodeIndex = mvnNodeIndex;
inputEdge.ToNodeInputIndex = 0;
inputEdges.push_back(inputEdge);
}

if (gammaBetaCastNeeded)
{
@@ -224,21 +250,33 @@
}
else
{
const uint32_t transposeOutputNodeIndex = currentNodeIndex++;
opDescs.push_back(&dmlTransposeOutputDesc);

DML_INTERMEDIATE_GRAPH_EDGE_DESC mvnToTransposeOutputEdge = {};
mvnToTransposeOutputEdge.FromNodeIndex = mvnNodeIndex;
mvnToTransposeOutputEdge.FromNodeOutputIndex = 0;
mvnToTransposeOutputEdge.ToNodeIndex = transposeOutputNodeIndex;
mvnToTransposeOutputEdge.ToNodeInputIndex = 0;
intermediateEdges.push_back(mvnToTransposeOutputEdge);

DML_OUTPUT_GRAPH_EDGE_DESC transposeOutputToOutputEdge{};
transposeOutputToOutputEdge.FromNodeIndex = transposeOutputNodeIndex;
transposeOutputToOutputEdge.FromNodeOutputIndex = 0;
transposeOutputToOutputEdge.GraphOutputIndex = 0;
outputEdges.push_back(transposeOutputToOutputEdge);
if (channelsLast)
{
// We only need to transpose the output when the layout is NHWC
const uint32_t transposeOutputNodeIndex = currentNodeIndex++;
opDescs.push_back(&dmlTransposeOutputDesc);

DML_INTERMEDIATE_GRAPH_EDGE_DESC mvnToTransposeOutputEdge = {};
mvnToTransposeOutputEdge.FromNodeIndex = mvnNodeIndex;
mvnToTransposeOutputEdge.FromNodeOutputIndex = 0;
mvnToTransposeOutputEdge.ToNodeIndex = transposeOutputNodeIndex;
mvnToTransposeOutputEdge.ToNodeInputIndex = 0;
intermediateEdges.push_back(mvnToTransposeOutputEdge);

DML_OUTPUT_GRAPH_EDGE_DESC transposeOutputToOutputEdge{};
transposeOutputToOutputEdge.FromNodeIndex = transposeOutputNodeIndex;
transposeOutputToOutputEdge.FromNodeOutputIndex = 0;
transposeOutputToOutputEdge.GraphOutputIndex = 0;
outputEdges.push_back(transposeOutputToOutputEdge);
}
else
{
DML_OUTPUT_GRAPH_EDGE_DESC mvnToOutputEdge{};
mvnToOutputEdge.FromNodeIndex = mvnNodeIndex;
mvnToOutputEdge.FromNodeOutputIndex = 0;
mvnToOutputEdge.GraphOutputIndex = 0;
outputEdges.push_back(mvnToOutputEdge);
}
}

MLOperatorGraphDesc operatorGraphDesc = {};
@@ -23,6 +23,7 @@ namespace AttrName
static constexpr const char* BlockSize = "blocksize";
static constexpr const char* Border = "border";
static constexpr const char* Broadcast = "broadcast";
static constexpr const char* ChannelsLast = "channels_last";
static constexpr const char* CeilMode = "ceil_mode";
static constexpr const char* Clip = "clip";
static constexpr const char* CoordinateTransformationMode = "coordinate_transformation_mode";
14 changes: 11 additions & 3 deletions onnxruntime/python/tools/transformers/float16.py
@@ -173,6 +173,7 @@ def convert_float_to_float16(
op_block_list=None,
node_block_list=None,
force_fp16_initializers=False,
force_fp16_inputs=None,
):
"""Convert tensor float type in the input ONNX model to tensor float16.
@@ -190,6 +191,8 @@
node_block_list (List[str], optional): List of node names to leave as float32. Defaults to None.
force_fp16_initializers(bool): force converting all float initializers to float16.
Defaults to false, which converts only the ones needed to avoid precision loss.
force_fp16_inputs(Dict[str, List[int]]): Force the conversion of the inputs of some operators to float16, even if
this script's preference is to keep them in float32.
Raises:
ValueError: input type is not ModelProto.
Expand All @@ -201,6 +204,8 @@ def convert_float_to_float16(
), "invalid min_positive_val. smallest positive float16 value: subnormal 5.96e-08, and normalized 6.104e-05"
assert max_finite_val <= float(np.finfo(np.float16).max), "invalid max_finite_val. largest float16 value: 65504"

force_fp16_inputs_dict = {} if force_fp16_inputs is None else force_fp16_inputs

if isinstance(model, str):
model_path = model
if version.parse(onnx.__version__) >= version.parse("1.8.0") and not disable_shape_infer:
@@ -327,7 +332,10 @@ def convert_float_to_float16(
for i, input_name in enumerate(n.input):
if input_name in fp32_initializers:
# For Resize/GroupNorm, only the first input can be float16
use_fp32_weight = is_node_blocked or (n.op_type in ["Resize", "GroupNorm"] and i != 0)
use_fp32_weight = is_node_blocked or (
i in ALWAYS_FLOAT_INPUTS.get(n.op_type, [])
and i not in force_fp16_inputs_dict.get(n.op_type, [])
)
fp32_initializers[input_name].add_node(n, use_fp32_weight)

if is_node_blocked:
@@ -340,7 +348,7 @@ def convert_float_to_float16(
break

# For Resize/GroupNorm, attribute data type cannot be changed
if n.op_type not in ["Resize", "GroupNorm"]:
if n.op_type not in ALWAYS_FLOAT_INPUTS or n.op_type in force_fp16_inputs_dict:
for attr in n.attribute:
next_level.append(attr) # noqa: PERF402
else:
@@ -388,7 +396,7 @@ def convert_float_to_float16(
# Some operators have data type fixed as float for some input. Add a float16 to float cast for those inputs.
for node in mixed_float_type_node_list:
for i, input_name in enumerate(node.input):
if i not in ALWAYS_FLOAT_INPUTS[node.op_type]:
if i not in ALWAYS_FLOAT_INPUTS[node.op_type] or i in force_fp16_inputs_dict.get(node.op_type, []):
continue
for value_info in value_info_list:
if input_name == value_info.name:
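
For users of the converter, a hedged usage sketch of the new force_fp16_inputs argument (the argument itself comes from this diff; the model file names are placeholders, and the import assumes float16.py from onnxruntime/python/tools/transformers is on the path):

```python
# Force GroupNorm's gamma (input 1) and beta (input 2) to float16,
# now that the kernels accept float16 for those inputs.
import onnx
from float16 import convert_float_to_float16

model = onnx.load("model.onnx")  # placeholder path
model_fp16 = convert_float_to_float16(
    model,
    force_fp16_inputs={"GroupNorm": [1, 2]},
)
onnx.save(model_fp16, "model_fp16.onnx")
```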
