diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 1e82b853c2186..f90f8cff2fead 100755 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -2176,6 +2176,8 @@ This version of the operator has been available since version 1 of the 'com.micr
activation : int (required)
Activation after group normalization: 0 for None, 1 for Swish
+
channels_last : int
+
1 if the input and output are in the NHWC layout, 0 if they are in the NCHW layout. Defaults to 1.
epsilon : float
The epsilon value to use to avoid division by zero
groups : int (required)
@@ -2186,7 +2188,7 @@ This version of the operator has been available since version 1 of the 'com.micr
X : T
-
Input data tensor. Dimensions are (N x H x W x C), where N is the batch size, C is the number of channels, and H and W are the height and width of the data
+
Input data tensor. Dimensions are (N x H x W x C) when channels_last is 1 or (N x C x H x W) otherwise, where N is the batch size, C is the number of channels, and H and W are the height and width of the data
gamma : M
1D gamma tensor for normalization with shape (C), where C is number of channels
beta : M
@@ -2205,7 +2207,7 @@ This version of the operator has been available since version 1 of the 'com.micr
T : tensor(float16), tensor(float)
Constrain input X and output Y types to float tensors.
-
M : tensor(float)
+
M : tensor(float16), tensor(float)
Constrain gamma and beta to float tensors.
diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc index 0fa7627414ffc..301b2e76b1b2d 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc @@ -66,6 +66,8 @@ GroupNorm::GroupNorm(const OpKernelInfo& op_info) : CudaKernel(op_info) { ORT_ENFORCE(op_info.GetAttr("activation", &activation).IsOK()); ORT_ENFORCE(activation == 0 || activation == 1); // 0 is None, 1 is Swish use_swish_activation_ = (activation == 1); + + channels_last_ = (op_info.GetAttrOrDefault("channels_last", static_cast(1)) != 0); } Status GroupNorm::ComputeInternal(OpKernelContext* context) const { @@ -74,6 +76,11 @@ Status GroupNorm::ComputeInternal(OpKernelContext* context) const { const Tensor* beta = context->Input(2); Tensor* output = context->Output(0, input->Shape()); + if (!channels_last_) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "only the channels_last layout is supported"); + } + const auto& input_dims = input->Shape().GetDims(); if (input_dims.size() != 4) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.h b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.h index 8578a1642198f..52c006e6bdb96 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.h +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.h @@ -20,6 +20,7 @@ class GroupNorm final : public CudaKernel { bool use_swish_activation_; float epsilon_; int num_groups_; + bool channels_last_; }; } // namespace cuda diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm.cc b/onnxruntime/contrib_ops/rocm/diffusion/group_norm.cc index 112ac10c38707..dd7cae61bc66b 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm.cc +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm.cc @@ -68,6 +68,8 @@ GroupNorm::GroupNorm(const OpKernelInfo& op_info) : RocmKernel(op_info) { 
ORT_ENFORCE(op_info.GetAttr("activation", &activation).IsOK()); ORT_ENFORCE(activation == 0 || activation == 1); // 0 is None, 1 is Swish use_swish_activation_ = (activation == 1); + + channels_last_ = (op_info.GetAttrOrDefault("channels_last", static_cast(1)) != 0); } Status GroupNorm::ComputeInternal(OpKernelContext* context) const { @@ -76,6 +78,11 @@ Status GroupNorm::ComputeInternal(OpKernelContext* context) const { const Tensor* beta = context->Input(2); Tensor* output = context->Output(0, input->Shape()); + if (!channels_last_) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "only the channels_last layout is supported"); + } + const auto& input_dims = input->Shape().GetDims(); if (input_dims.size() != 4) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, diff --git a/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc b/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc index c6d3db7fbe6da..c2f5edaa6149b 100644 --- a/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc @@ -44,9 +44,13 @@ ONNX_MS_OPERATOR_SET_SCHEMA( .Attr("activation", "Activation after group normalization: 0 for None, 1 for Swish", AttributeProto::INT) + .Attr("channels_last", + "1 if the input and output are in the NHWC layout, 0 if it is in the NCHW layout. Defaults to 1.", + AttributeProto::INT, + static_cast(1)) .Input(0, "X", - "Input data tensor. Dimensions are (N x H x W x C), where N is the batch size, C is the number of channels, and H and W are the height and width of the data", + "Input data tensor. 
Dimensions are (N x H x W x C) when channels_last is 1 or (N x C x H x W) otherwise, where N is the batch size, C is the number of channels, and H and W are the height and width of the data", "T") .Input(1, "gamma", @@ -61,7 +65,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "The output tensor of the same shape as X", "T") .TypeConstraint("T", {"tensor(float16)", "tensor(float)"}, "Constrain input X and output Y types to float tensors.") - .TypeConstraint("M", {"tensor(float)"}, "Constrain gamma and beta to float tensors.") + .TypeConstraint("M", {"tensor(float16)", "tensor(float)"}, "Constrain gamma and beta to float tensors.") .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput)); constexpr const char* BiasSplitGelu_ver1_doc = R"DOC( diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorGroupNorm.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorGroupNorm.cpp index 66215877030c4..fed0e4645ffd8 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorGroupNorm.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorGroupNorm.cpp @@ -19,6 +19,7 @@ class DmlOperatorGroupNorm : public DmlOperator std::vector inputDescs = GetDmlInputDescs(); std::vector outputDescs = GetDmlOutputDescs(); + const bool channelsLast = kernelCreationContext.GetOptionalAttribute(AttrName::ChannelsLast, true); const float epsilon = kernelCreationContext.GetOptionalAttribute(AttrName::Epsilon, DefaultEpsilon); const bool activation = gsl::narrow_cast(kernelCreationContext.GetAttribute(AttrName::Activation)); const uint32_t groups = gsl::narrow_cast(kernelCreationContext.GetAttribute(AttrName::Groups)); @@ -36,19 +37,32 @@ class DmlOperatorGroupNorm : public DmlOperator // Data is in NHWC format const uint32_t batch = inputTensorShape[0]; - const uint32_t height = inputTensorShape[1]; - const uint32_t width = inputTensorShape[2]; - const 
uint32_t channels = inputTensorShape[3]; + const uint32_t height = channelsLast ? inputTensorShape[1] : inputTensorShape[2]; + const uint32_t width = channelsLast ? inputTensorShape[2] : inputTensorShape[3]; + const uint32_t channels = channelsLast ? inputTensorShape[3] : inputTensorShape[1]; ML_CHECK_VALID_ARGUMENT(gammaTensorShape[0] == channels); ML_CHECK_VALID_ARGUMENT(betaTensorShape[0] == channels); ML_CHECK_VALID_ARGUMENT(channels % groups == 0); ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[1].GetDmlDataType() == m_inputTensorDescs[2].GetDmlDataType()); const uint32_t channelsPerGroup = channels / groups; - // 1. Reshape the input from [batch, height, width, channels] to [batch, height * width, groups, channelsPerGroup] - // 2. Stride the reshaped input from [batch, height * width, groups, channelsPerGroup] to [batch, groups, channelsPerGroup, height * width] - const std::array inputShape = {batch, groups, channelsPerGroup, height * width}; - const std::array inputStrides = {channelsPerGroup * height * width * groups, channelsPerGroup, 1, groups * channelsPerGroup}; + std::array inputShape; + std::array inputStrides; + + if (channelsLast) + { + // 1. Reshape the input from [batch, height, width, channels] to [batch, height * width, groups, channelsPerGroup] + // 2. 
Stride the reshaped input from [batch, height * width, groups, channelsPerGroup] to [batch, groups, channelsPerGroup, height * width] + inputShape = {batch, groups, channelsPerGroup, height * width}; + inputStrides = {channelsPerGroup * height * width * groups, channelsPerGroup, 1, groups * channelsPerGroup}; + } + else + { + // Reshape the input from [batch, channels, height, width] to [batch, groups, channelsPerGroup, height * width] + inputShape = {batch, groups, channelsPerGroup, height * width}; + inputStrides = {groups * channelsPerGroup * height * width, channelsPerGroup * height * width, height * width, 1}; + } + TensorDesc inputTensorDesc = TensorDesc(m_inputTensorDescs[0].GetDmlDataType(), inputShape, inputStrides); const DML_TENSOR_DESC inputDmlTensorDesc = inputTensorDesc.GetDmlDesc(); @@ -119,24 +133,36 @@ class DmlOperatorGroupNorm : public DmlOperator uint32_t currentNodeIndex = 0; - const uint32_t transposeInputIndex = currentNodeIndex++; - opDescs.push_back(&dmlTransposeInputDesc); - const uint32_t mvnNodeIndex = currentNodeIndex++; opDescs.push_back(&dmlMvnDesc); - DML_INPUT_GRAPH_EDGE_DESC inputEdge{}; - inputEdge.GraphInputIndex = 0; - inputEdge.ToNodeIndex = transposeInputIndex; - inputEdge.ToNodeInputIndex = 0; - inputEdges.push_back(inputEdge); - - DML_INTERMEDIATE_GRAPH_EDGE_DESC transposeInputToMvnEdge = {}; - transposeInputToMvnEdge.FromNodeIndex = transposeInputIndex; - transposeInputToMvnEdge.FromNodeOutputIndex = 0; - transposeInputToMvnEdge.ToNodeIndex = mvnNodeIndex; - transposeInputToMvnEdge.ToNodeInputIndex = 0; - intermediateEdges.push_back(transposeInputToMvnEdge); + // We only need a transpose the input when the layout is NHWC + if (channelsLast) + { + const uint32_t transposeInputIndex = currentNodeIndex++; + opDescs.push_back(&dmlTransposeInputDesc); + + DML_INPUT_GRAPH_EDGE_DESC inputEdge{}; + inputEdge.GraphInputIndex = 0; + inputEdge.ToNodeIndex = transposeInputIndex; + inputEdge.ToNodeInputIndex = 0; + 
inputEdges.push_back(inputEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC transposeInputToMvnEdge = {}; + transposeInputToMvnEdge.FromNodeIndex = transposeInputIndex; + transposeInputToMvnEdge.FromNodeOutputIndex = 0; + transposeInputToMvnEdge.ToNodeIndex = mvnNodeIndex; + transposeInputToMvnEdge.ToNodeInputIndex = 0; + intermediateEdges.push_back(transposeInputToMvnEdge); + } + else + { + DML_INPUT_GRAPH_EDGE_DESC inputEdge{}; + inputEdge.GraphInputIndex = 0; + inputEdge.ToNodeIndex = mvnNodeIndex; + inputEdge.ToNodeInputIndex = 0; + inputEdges.push_back(inputEdge); + } if (gammaBetaCastNeeded) { @@ -224,21 +250,33 @@ class DmlOperatorGroupNorm : public DmlOperator } else { - const uint32_t transposeOutputNodeIndex = currentNodeIndex++; - opDescs.push_back(&dmlTransposeOutputDesc); - - DML_INTERMEDIATE_GRAPH_EDGE_DESC mvnToTransposeOutputEdge = {}; - mvnToTransposeOutputEdge.FromNodeIndex = mvnNodeIndex; - mvnToTransposeOutputEdge.FromNodeOutputIndex = 0; - mvnToTransposeOutputEdge.ToNodeIndex = transposeOutputNodeIndex; - mvnToTransposeOutputEdge.ToNodeInputIndex = 0; - intermediateEdges.push_back(mvnToTransposeOutputEdge); - - DML_OUTPUT_GRAPH_EDGE_DESC transposeOutputToOutputEdge{}; - transposeOutputToOutputEdge.FromNodeIndex = transposeOutputNodeIndex; - transposeOutputToOutputEdge.FromNodeOutputIndex = 0; - transposeOutputToOutputEdge.GraphOutputIndex = 0; - outputEdges.push_back(transposeOutputToOutputEdge); + if (channelsLast) + { + // We only need a transpose the output when the layout is NHWC + const uint32_t transposeOutputNodeIndex = currentNodeIndex++; + opDescs.push_back(&dmlTransposeOutputDesc); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC mvnToTransposeOutputEdge = {}; + mvnToTransposeOutputEdge.FromNodeIndex = mvnNodeIndex; + mvnToTransposeOutputEdge.FromNodeOutputIndex = 0; + mvnToTransposeOutputEdge.ToNodeIndex = transposeOutputNodeIndex; + mvnToTransposeOutputEdge.ToNodeInputIndex = 0; + intermediateEdges.push_back(mvnToTransposeOutputEdge); + + 
DML_OUTPUT_GRAPH_EDGE_DESC transposeOutputToOutputEdge{}; + transposeOutputToOutputEdge.FromNodeIndex = transposeOutputNodeIndex; + transposeOutputToOutputEdge.FromNodeOutputIndex = 0; + transposeOutputToOutputEdge.GraphOutputIndex = 0; + outputEdges.push_back(transposeOutputToOutputEdge); + } + else + { + DML_OUTPUT_GRAPH_EDGE_DESC mvnToOutputEdge{}; + mvnToOutputEdge.FromNodeIndex = mvnNodeIndex; + mvnToOutputEdge.FromNodeOutputIndex = 0; + mvnToOutputEdge.GraphOutputIndex = 0; + outputEdges.push_back(mvnToOutputEdge); + } } MLOperatorGraphDesc operatorGraphDesc = {}; diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h index 5be84a931f4f1..dac128f92ae0c 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h @@ -23,6 +23,7 @@ namespace AttrName static constexpr const char* BlockSize = "blocksize"; static constexpr const char* Border = "border"; static constexpr const char* Broadcast = "broadcast"; + static constexpr const char* ChannelsLast = "channels_last"; static constexpr const char* CeilMode = "ceil_mode"; static constexpr const char* Clip = "clip"; static constexpr const char* CoordinateTransformationMode = "coordinate_transformation_mode"; diff --git a/onnxruntime/python/tools/transformers/float16.py b/onnxruntime/python/tools/transformers/float16.py index f1f19f3eaaf5b..02a260b784621 100644 --- a/onnxruntime/python/tools/transformers/float16.py +++ b/onnxruntime/python/tools/transformers/float16.py @@ -173,6 +173,7 @@ def convert_float_to_float16( op_block_list=None, node_block_list=None, force_fp16_initializers=False, + force_fp16_inputs=None, ): """Convert tensor float type in the input ONNX model to tensor float16. @@ -190,6 +191,8 @@ def convert_float_to_float16( node_block_list (List[str], optional): List of node names to leave as float32. Defaults to None. 
force_fp16_initializers(bool): force converting all float initializers to float16. Default to false, which will convert only the one needed to avoid precision loss. + force_fp16_inputs(Dict[str, List[int]]): Force the conversion of the inputs of some operators to float16, even if + this script's preference it to keep them in float32. Raises: ValueError: input type is not ModelProto. @@ -201,6 +204,8 @@ def convert_float_to_float16( ), "invalid min_positive_val. smallest positive float16 value: subnormal 5.96e-08, and normalized 6.104e-05" assert max_finite_val <= float(np.finfo(np.float16).max), "invalid max_finite_val. largest float16 value: 65504" + force_fp16_inputs_dict = {} if force_fp16_inputs is None else force_fp16_inputs + if isinstance(model, str): model_path = model if version.parse(onnx.__version__) >= version.parse("1.8.0") and not disable_shape_infer: @@ -327,7 +332,10 @@ def convert_float_to_float16( for i, input_name in enumerate(n.input): if input_name in fp32_initializers: # For Resize/GroupNorm, only the first input can be float16 - use_fp32_weight = is_node_blocked or (n.op_type in ["Resize", "GroupNorm"] and i != 0) + use_fp32_weight = is_node_blocked or ( + i in ALWAYS_FLOAT_INPUTS.get(n.op_type, []) + and i not in force_fp16_inputs_dict.get(n.op_type, []) + ) fp32_initializers[input_name].add_node(n, use_fp32_weight) if is_node_blocked: @@ -340,7 +348,7 @@ def convert_float_to_float16( break # For Resize/GroupNorm, attribute data type cannot be changed - if n.op_type not in ["Resize", "GroupNorm"]: + if n.op_type not in ALWAYS_FLOAT_INPUTS or n.op_type in force_fp16_inputs_dict: for attr in n.attribute: next_level.append(attr) # noqa: PERF402 else: @@ -388,7 +396,7 @@ def convert_float_to_float16( # Some operators have data type fixed as float for some input. Add a float16 to float cast for those inputs. 
for node in mixed_float_type_node_list: for i, input_name in enumerate(node.input): - if i not in ALWAYS_FLOAT_INPUTS[node.op_type]: + if i not in ALWAYS_FLOAT_INPUTS[node.op_type] or i in force_fp16_inputs_dict.get(node.op_type, []): continue for value_info in value_info_list: if input_name == value_info.name: diff --git a/onnxruntime/python/tools/transformers/fusion_group_norm.py b/onnxruntime/python/tools/transformers/fusion_group_norm.py index f866bfce86d60..2cae366d3f9bd 100644 --- a/onnxruntime/python/tools/transformers/fusion_group_norm.py +++ b/onnxruntime/python/tools/transformers/fusion_group_norm.py @@ -14,8 +14,9 @@ class FusionGroupNorm(Fusion): - def __init__(self, model: OnnxModel): + def __init__(self, model: OnnxModel, channels_last=True): super().__init__(model, "GroupNorm", "Add") + self.channels_last = channels_last def fuse(self, add_node, input_name_to_nodes: Dict, output_name_to_node: Dict): """ @@ -145,40 +146,47 @@ def fuse(self, add_node, input_name_to_nodes: Dict, output_name_to_node: Dict): input_name = root output_name = last_node.output[0] + group_norm_input_name = input_name + "_NHWC" if self.channels_last else input_name + group_norm_output_name = output_name + "_NHWC" if self.channels_last else output_name + # NCHW to NHWC - transpose_input = helper.make_node( - "Transpose", - [input_name], - [input_name + "_NHWC"], - name=self.model.create_node_name("Transpose", name_prefix="Transpose_NCHW_to_NHWC"), - perm=[0, 2, 3, 1], - ) + if self.channels_last: + transpose_input = helper.make_node( + "Transpose", + [input_name], + [group_norm_input_name], + name=self.model.create_node_name("Transpose", name_prefix="Transpose_NCHW_to_NHWC"), + perm=[0, 2, 3, 1], + ) + self.nodes_to_add.append(transpose_input) + self.node_name_to_graph_name[transpose_input.name] = self.this_graph_name new_node = helper.make_node( "GroupNorm", - inputs=[input_name + "_NHWC", group_norm_name + "_gamma", group_norm_name + "_beta"], - outputs=[output_name + 
"_NHWC"], + inputs=[group_norm_input_name, group_norm_name + "_gamma", group_norm_name + "_beta"], + outputs=[group_norm_output_name], name=group_norm_name, ) new_node.attribute.extend(instance_norm.attribute) new_node.attribute.extend([helper.make_attribute("groups", 32)]) new_node.attribute.extend([helper.make_attribute("activation", 1 if has_swish_activation else 0)]) - new_node.domain = "com.microsoft" - # NHWC to NCHW - transpose_output = helper.make_node( - "Transpose", - [output_name + "_NHWC"], - [output_name], - name=self.model.create_node_name("Transpose", name_prefix="Transpose_NHWC_to_NCHW"), - perm=[0, 3, 1, 2], - ) + if not self.channels_last: + new_node.attribute.extend([helper.make_attribute("channels_last", 0)]) + new_node.domain = "com.microsoft" self.nodes_to_add.append(new_node) - self.nodes_to_add.append(transpose_input) - self.nodes_to_add.append(transpose_output) - self.node_name_to_graph_name[new_node.name] = self.this_graph_name - self.node_name_to_graph_name[transpose_input.name] = self.this_graph_name - self.node_name_to_graph_name[transpose_output.name] = self.this_graph_name + + # NHWC to NCHW + if self.channels_last: + transpose_output = helper.make_node( + "Transpose", + [group_norm_output_name], + [output_name], + name=self.model.create_node_name("Transpose", name_prefix="Transpose_NHWC_to_NCHW"), + perm=[0, 3, 1, 2], + ) + self.nodes_to_add.append(transpose_output) + self.node_name_to_graph_name[transpose_output.name] = self.this_graph_name diff --git a/onnxruntime/python/tools/transformers/fusion_options.py b/onnxruntime/python/tools/transformers/fusion_options.py index 59a49a248dd40..57f0fea99d145 100644 --- a/onnxruntime/python/tools/transformers/fusion_options.py +++ b/onnxruntime/python/tools/transformers/fusion_options.py @@ -43,6 +43,7 @@ def __init__(self, model_type): self.enable_shape_inference = True self.enable_gemm_fast_gelu = False + self.group_norm_channels_last = True # Set default to sequence length for BERT model 
to use fused attention to speed up. # Note that embed layer normalization will convert 2D mask to 1D when mask type is MaskIndexEnd. @@ -103,6 +104,8 @@ def parse(args): options.disable_attention_mask() if args.model_type in ["unet", "vae", "clip"]: + if args.use_group_norm_channels_first: + options.group_norm_channels_last = False if args.disable_nhwc_conv: options.enable_nhwc_conv = False if args.disable_group_norm: @@ -280,3 +283,11 @@ def add_arguments(parser: ArgumentParser): help="Do not use NhwcConv. Only works for model_type=unet or vae", ) parser.set_defaults(disable_nhwc_conv=False) + + parser.add_argument( + "--use_group_norm_channels_first", + required=False, + action="store_true", + help="Use channels_first (NCHW) instead of channels_last (NHWC) for GroupNorm. Only works for model_type=unet or vae", + ) + parser.set_defaults(use_group_norm_channels_first=False) diff --git a/onnxruntime/python/tools/transformers/onnx_model.py b/onnxruntime/python/tools/transformers/onnx_model.py index 7d3b36365015a..4f74da577dfee 100644 --- a/onnxruntime/python/tools/transformers/onnx_model.py +++ b/onnxruntime/python/tools/transformers/onnx_model.py @@ -627,6 +627,8 @@ def convert_float_to_float16(self, use_symbolic_shape_infer=True, **kwargs): Default to false. min_positive_val (float, optional): minimal positive value. Defaults to 1e-7. max_finite_val (float, optional): maximal finite value. Defaults to 1e4. + force_fp16_inputs(Dict[str, List[int]]): Force the conversion of the inputs of some operators to float16, even if + this script's preference it to keep them in float32. 
""" if "keep_io_types" not in kwargs: kwargs["keep_io_types"] = True @@ -677,6 +679,7 @@ def convert_float_to_float16(self, use_symbolic_shape_infer=True, **kwargs): "op_block_list", "node_block_list", "force_fp16_initializers", + "force_fp16_inputs", ] if key in kwargs } diff --git a/onnxruntime/python/tools/transformers/onnx_model_unet.py b/onnxruntime/python/tools/transformers/onnx_model_unet.py index 00fc0763d820c..294641dd1e067 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_unet.py +++ b/onnxruntime/python/tools/transformers/onnx_model_unet.py @@ -128,7 +128,8 @@ def optimize(self, options: Optional[FusionOptions] = None): self.fuse_reshape() if (options is None) or options.enable_group_norm: - group_norm_fusion = FusionGroupNorm(self) + channels_last = (options is None) or options.group_norm_channels_last + group_norm_fusion = FusionGroupNorm(self, channels_last) group_norm_fusion.apply() insert_transpose_fusion = FusionInsertTranspose(self) diff --git a/onnxruntime/test/contrib_ops/group_norm_op_test.cc b/onnxruntime/test/contrib_ops/group_norm_op_test.cc index f0c6488a35d57..b02c13570279c 100644 --- a/onnxruntime/test/contrib_ops/group_norm_op_test.cc +++ b/onnxruntime/test/contrib_ops/group_norm_op_test.cc @@ -21,8 +21,8 @@ TEST(GroupNormTest, GroupNorm_128) { constexpr int64_t H = 2; constexpr int64_t W = 2; - std::vector dims{B, H, W, C}; - std::vector input_data = { + std::vector dims_nhwc{B, H, W, C}; + std::vector input_data_nhwc = { 0.696469f, 0.719469f, 0.480932f, 0.438572f, 0.182492f, 0.634401f, 0.722443f, 0.293714f, 0.430863f, 0.426351f, 0.623953f, 0.866309f, 0.519485f, 0.603060f, 0.417022f, 0.669314f, 0.842342f, 0.194223f, 0.627249f, 0.556785f, 0.318766f, 0.925132f, 0.304768f, 0.355915f, 0.151127f, 0.513128f, 0.321981f, 0.854452f, 0.171082f, 0.578551f, @@ -127,6 +127,112 @@ TEST(GroupNormTest, GroupNorm_128) { 0.623269f, 0.016948f, 0.826530f, 0.308751f, 0.290656f, 0.058387f, 0.264397f, 0.294895f, 0.639992f, 0.489059f, 0.343698f, 
0.929770f, 0.390125f, 0.397707f}; + std::vector dims_nchw{B, C, H, W}; + std::vector input_data_nchw = { + 0.696469f, 0.286139f, 0.226851f, 0.551315f, 0.719469f, 0.423106f, 0.980764f, 0.684830f, 0.480932f, 0.392118f, + 0.343178f, 0.729050f, 0.438572f, 0.059678f, 0.398044f, 0.737995f, 0.182492f, 0.175452f, 0.531551f, 0.531828f, + 0.634401f, 0.849432f, 0.724455f, 0.611024f, 0.722443f, 0.322959f, 0.361789f, 0.228263f, 0.293714f, 0.630976f, + 0.092105f, 0.433701f, 0.430863f, 0.493685f, 0.425830f, 0.312261f, 0.426351f, 0.893389f, 0.944160f, 0.501837f, + 0.623953f, 0.115618f, 0.317285f, 0.414826f, 0.866309f, 0.250455f, 0.483034f, 0.985560f, 0.519485f, 0.612895f, + 0.120629f, 0.826341f, 0.603060f, 0.545068f, 0.342764f, 0.304121f, 0.417022f, 0.681301f, 0.875457f, 0.510422f, + 0.669314f, 0.585937f, 0.624904f, 0.674689f, 0.842342f, 0.083195f, 0.763683f, 0.243666f, 0.194223f, 0.572457f, + 0.095713f, 0.885327f, 0.627249f, 0.723416f, 0.016129f, 0.594432f, 0.556785f, 0.158960f, 0.153071f, 0.695530f, + 0.318766f, 0.691970f, 0.554383f, 0.388951f, 0.925132f, 0.841670f, 0.357398f, 0.043591f, 0.304768f, 0.398186f, + 0.704959f, 0.995358f, 0.355915f, 0.762548f, 0.593177f, 0.691702f, 0.151127f, 0.398876f, 0.240856f, 0.343456f, + 0.513128f, 0.666625f, 0.105908f, 0.130895f, 0.321981f, 0.661564f, 0.846506f, 0.553257f, 0.854452f, 0.384838f, + 0.316788f, 0.354265f, 0.171082f, 0.829113f, 0.338671f, 0.552370f, 0.578551f, 0.521533f, 0.002688f, 0.988345f, + 0.905342f, 0.207636f, 0.292489f, 0.520010f, 0.901911f, 0.983631f, 0.257542f, 0.564359f, 0.806969f, 0.394370f, + 0.731073f, 0.161069f, 0.600699f, 0.865864f, 0.983522f, 0.079366f, 0.428347f, 0.204543f, 0.450636f, 0.547764f, + 0.093327f, 0.296861f, 0.927584f, 0.569004f, 0.457412f, 0.753526f, 0.741862f, 0.048579f, 0.708697f, 0.839243f, + 0.165938f, 0.780998f, 0.286537f, 0.306470f, 0.665261f, 0.111392f, 0.664872f, 0.887857f, 0.696311f, 0.440328f, + 0.438214f, 0.765096f, 0.565642f, 0.084904f, 0.582671f, 0.814844f, 0.337066f, 0.927577f, 0.750717f, 
0.574064f, + 0.751644f, 0.079149f, 0.859389f, 0.821504f, 0.909872f, 0.128631f, 0.081780f, 0.138416f, 0.399379f, 0.424307f, + 0.562218f, 0.122244f, 0.201400f, 0.811644f, 0.467988f, 0.807938f, 0.007426f, 0.551593f, 0.931932f, 0.582175f, + 0.206096f, 0.717758f, 0.378986f, 0.668384f, 0.029320f, 0.635900f, 0.032198f, 0.744781f, 0.472913f, 0.121754f, + 0.542636f, 0.066774f, 0.653365f, 0.996086f, 0.769397f, 0.573774f, 0.102635f, 0.699834f, 0.661168f, 0.049097f, + 0.792299f, 0.518717f, 0.425868f, 0.788187f, 0.411569f, 0.481026f, 0.181629f, 0.321319f, 0.845533f, 0.186904f, + 0.417291f, 0.989035f, 0.236600f, 0.916832f, 0.918397f, 0.091296f, 0.463653f, 0.502216f, 0.313669f, 0.047340f, + 0.241686f, 0.095530f, 0.238250f, 0.807791f, 0.894978f, 0.043223f, 0.301947f, 0.980582f, 0.539505f, 0.626309f, + 0.005545f, 0.484909f, 0.988329f, 0.375186f, 0.097038f, 0.461909f, 0.963004f, 0.341831f, 0.798923f, 0.798846f, + 0.208248f, 0.443368f, 0.715601f, 0.410520f, 0.191007f, 0.967494f, 0.650750f, 0.865460f, 0.025242f, 0.266906f, + 0.502071f, 0.067449f, 0.993033f, 0.236462f, 0.374292f, 0.214012f, 0.105446f, 0.232480f, 0.300610f, 0.634442f, + 0.281235f, 0.362277f, 0.005943f, 0.365719f, 0.533886f, 0.162016f, 0.597433f, 0.293152f, 0.632050f, 0.026197f, + 0.887593f, 0.016119f, 0.126958f, 0.777162f, 0.045895f, 0.710999f, 0.971046f, 0.871683f, 0.710162f, 0.958510f, + 0.429813f, 0.872879f, 0.355958f, 0.929764f, 0.148778f, 0.940029f, 0.832716f, 0.846055f, 0.123923f, 0.596487f, + 0.016392f, 0.721184f, 0.007738f, 0.084822f, 0.225498f, 0.875125f, 0.363576f, 0.539960f, 0.568103f, 0.225463f, + 0.572147f, 0.660952f, 0.298245f, 0.418627f, 0.453089f, 0.932351f, 0.587494f, 0.948252f, 0.556035f, 0.500561f, + 0.003532f, 0.480889f, 0.927455f, 0.198366f, 0.052091f, 0.406779f, 0.372396f, 0.857153f, 0.026611f, 0.920149f, + 0.680903f, 0.904226f, 0.607529f, 0.811953f, 0.335544f, 0.349566f, 0.389874f, 0.754797f, 0.369291f, 0.242220f, + 0.937668f, 0.908011f, 0.348797f, 0.634638f, 0.273842f, 0.206115f, 0.336340f, 
0.327100f, 0.882276f, 0.822304f, + 0.709623f, 0.959345f, 0.422543f, 0.245033f, 0.117398f, 0.301053f, 0.145264f, 0.092186f, 0.602932f, 0.364187f, + 0.564570f, 0.191336f, 0.676906f, 0.215505f, 0.278024f, 0.741760f, 0.559738f, 0.334836f, 0.542989f, 0.693985f, + 0.912132f, 0.580713f, 0.232686f, 0.746698f, 0.777769f, 0.200401f, 0.820574f, 0.464935f, 0.779767f, 0.237478f, + 0.332580f, 0.953697f, 0.657815f, 0.772878f, 0.688374f, 0.204304f, 0.470689f, 0.808964f, 0.675035f, 0.006028f, + 0.087408f, 0.346795f, 0.944366f, 0.491190f, 0.270176f, 0.360424f, 0.210653f, 0.421200f, 0.218035f, 0.845753f, + 0.456271f, 0.279802f, 0.932892f, 0.314351f, 0.909715f, 0.043418f, 0.707115f, 0.483889f, 0.444221f, 0.036323f, + 0.040683f, 0.332754f, 0.947120f, 0.617660f, 0.368875f, 0.611977f, 0.206132f, 0.165066f, 0.361817f, 0.863353f, + 0.509402f, 0.296902f, 0.950252f, 0.815966f, 0.322974f, 0.972098f, 0.987351f, 0.408660f, 0.655923f, 0.405653f, + 0.257348f, 0.082653f, 0.263610f, 0.271480f, 0.398639f, 0.184886f, 0.953818f, 0.102880f, 0.625209f, 0.441697f, + 0.423518f, 0.371992f, 0.868315f, 0.280477f, 0.020576f, 0.918097f, 0.864480f, 0.276902f, 0.523488f, 0.109088f, + 0.093427f, 0.837466f, 0.410266f, 0.661717f, 0.943201f, 0.245131f, 0.013160f, 0.024148f, 0.709386f, 0.924552f, + 0.467330f, 0.375109f, 0.542860f, 0.858917f, 0.652154f, 0.232980f, 0.774580f, 0.134613f, 0.165560f, 0.612682f, + 0.238783f, 0.704779f, 0.349519f, 0.277424f, 0.998918f, 0.040616f, 0.645823f, 0.038700f, 0.760210f, 0.230090f, + 0.089832f, 0.648450f, 0.732601f, 0.678095f, 0.051901f, 0.294307f, 0.451088f, 0.287103f, 0.810513f, 0.131115f, + 0.612179f, 0.988215f, 0.902557f, 0.222157f, 0.000082f, 0.980597f, 0.882713f, 0.919472f, 0.415504f, 0.744615f, + 0.212831f, 0.392304f, 0.851548f, 0.127612f, 0.893865f, 0.496508f, 0.426096f, 0.305646f, 0.916849f, 0.517623f, + 0.804026f, 0.857652f, 0.922382f, 0.303381f, 0.339811f, 0.595074f, 0.441324f, 0.932843f, 0.397564f, 0.477778f, + 0.617186f, 0.404739f, 0.992478f, 0.098851f, 0.220603f, 
0.322655f, 0.147723f, 0.284219f, 0.779245f, 0.522892f, + 0.033954f, 0.982623f, 0.616006f, 0.058939f, 0.661169f, 0.378369f, 0.135673f, 0.563665f, 0.727080f, 0.671127f, + 0.247513f, 0.524866f, 0.537663f, 0.716803f, 0.359867f, 0.797733f, 0.627922f, 0.038332f, 0.546479f, 0.861912f, + 0.567574f, 0.175828f, 0.510376f, 0.756946f, 0.110105f, 0.817099f, 0.167482f, 0.534076f, 0.385743f, 0.248624f, + 0.647433f, 0.037392f, 0.760046f, 0.526941f, 0.875771f, 0.520718f, 0.035033f, 0.143601f, 0.795605f, 0.491976f, + 0.441879f, 0.318435f, 0.284549f, 0.965886f, 0.432969f, 0.884003f, 0.648163f, 0.858428f, 0.852450f, 0.956312f, + 0.697942f, 0.805397f, 0.733128f, 0.605227f, 0.717354f, 0.715750f, 0.040908f, 0.516111f, 0.792651f, 0.242962f, + 0.465148f, 0.434986f, 0.402787f, 0.121840f, 0.525712f, 0.446248f, 0.663393f, 0.549413f, 0.027543f, 0.031918f, + 0.701360f, 0.707581f, 0.959939f, 0.876705f, 0.468060f, 0.625907f, 0.457182f, 0.222946f, 0.376677f, 0.103884f, + 0.666527f, 0.192030f, 0.475468f, 0.967437f, 0.031669f, 0.151730f, 0.298579f, 0.941807f, 0.908842f, 0.162001f, + 0.981118f, 0.750748f, 0.539977f, 0.931703f, 0.880607f, 0.391316f, 0.656343f, 0.647385f, 0.326968f, 0.179390f, + 0.466810f, 0.263281f, 0.355065f, 0.954144f, 0.461138f, 0.684891f, 0.336230f, 0.995861f, 0.658768f, 0.196009f, + 0.098184f, 0.943181f, 0.944778f, 0.621328f, 0.016991f, 0.225535f, 0.801277f, 0.875460f, 0.453990f, 0.365521f, + 0.274225f, 0.116971f, 0.115745f, 0.952603f, 0.808626f, 0.164779f, 0.207050f, 0.655552f, 0.764664f, 0.810315f, + 0.163338f, 0.984128f, 0.227802f, 0.589415f, 0.587616f, 0.967362f, 0.657667f, 0.584904f, 0.518773f, 0.764658f, + 0.106055f, 0.002092f, 0.952489f, 0.498658f, 0.328335f, 0.368053f, 0.803843f, 0.382370f, 0.770169f, 0.440462f, + 0.844077f, 0.076204f, 0.481128f, 0.466850f, 0.264328f, 0.943615f, 0.905028f, 0.443596f, 0.097160f, 0.206783f, + 0.271492f, 0.484220f, 0.338377f, 0.774136f, 0.476027f, 0.870371f, 0.995782f, 0.219836f, 0.611671f, 0.847502f, + 0.945237f, 0.290086f, 0.727043f, 
0.015016f, 0.879142f, 0.063939f, 0.733395f, 0.994610f, 0.501190f, 0.209334f, + 0.594644f, 0.624150f, 0.668073f, 0.172612f, 0.898713f, 0.620991f, 0.043569f, 0.684041f, 0.196084f, 0.027341f, + 0.550953f, 0.813314f, 0.859941f, 0.103521f, 0.663043f, 0.710075f, 0.294517f, 0.971364f, 0.278687f, 0.069982f, + 0.519280f, 0.694315f, 0.244660f, 0.338582f, 0.563628f, 0.886678f, 0.747326f, 0.209592f, 0.251777f, 0.523881f, + 0.768959f, 0.618762f, 0.501324f, 0.597125f, 0.756060f, 0.537080f, 0.897753f, 0.947067f, 0.915355f, 0.754518f, + 0.246321f, 0.385271f, 0.280000f, 0.657660f, 0.324222f, 0.754392f, 0.113509f, 0.775365f, 0.585902f, 0.835389f, + 0.430876f, 0.624964f, 0.554412f, 0.975671f, 0.755474f, 0.544813f, 0.174032f, 0.904114f, 0.205838f, 0.650043f, + 0.936472f, 0.223580f, 0.225924f, 0.851819f, 0.827655f, 0.351703f, 0.265096f, 0.127388f, 0.987936f, 0.835343f, + 0.899392f, 0.513679f, 0.114385f, 0.052580f, 0.330582f, 0.920330f, 0.947582f, 0.841164f, 0.158679f, 0.419923f, + 0.246243f, 0.205350f, 0.684826f, 0.486112f, 0.324910f, 0.100214f, 0.544763f, 0.347025f, 0.391096f, 0.310509f, + 0.387195f, 0.555860f, 0.014144f, 0.847647f, 0.921920f, 0.550530f, 0.268021f, 0.990239f, 0.383194f, 0.693655f, + 0.689953f, 0.434309f, 0.199158f, 0.966579f, 0.063691f, 0.485149f, 0.220731f, 0.293974f, 0.828527f, 0.367266f, + 0.083348f, 0.196309f, 0.860373f, 0.977029f, 0.267982f, 0.675409f, 0.081199f, 0.723466f, 0.416437f, 0.918160f, + 0.311536f, 0.941467f, 0.503247f, 0.348893f, 0.647020f, 0.249746f, 0.229764f, 0.196346f, 0.959900f, 0.492914f, + 0.751615f, 0.473992f, 0.587540f, 0.584139f, 0.979886f, 0.668433f, 0.239769f, 0.015198f, 0.218682f, 0.455520f, + 0.393420f, 0.812326f, 0.785557f, 0.089096f, 0.952011f, 0.527457f, 0.596404f, 0.405057f, 0.649501f, 0.871326f, + 0.673936f, 0.970099f, 0.701122f, 0.821721f, 0.045040f, 0.672699f, 0.654753f, 0.101746f, 0.842387f, 0.614172f, + 0.098328f, 0.594467f, 0.478416f, 0.233294f, 0.019756f, 0.365567f, 0.619851f, 0.329279f, 0.307255f, 0.751121f, + 0.758625f, 
0.718766f, 0.101182f, 0.516166f, 0.557799f, 0.744805f, 0.903178f, 0.369039f, 0.428663f, 0.732767f, + 0.662636f, 0.557870f, 0.350140f, 0.195352f, 0.183807f, 0.081583f, 0.081201f, 0.845798f, 0.383673f, 0.060740f, + 0.896426f, 0.223270f, 0.268124f, 0.194498f, 0.967501f, 0.112540f, 0.722163f, 0.932089f, 0.668001f, 0.858727f, + 0.242447f, 0.673928f, 0.700871f, 0.458333f, 0.870546f, 0.694386f, 0.894878f, 0.753204f, 0.520290f, 0.498688f, + 0.453728f, 0.021647f, 0.535141f, 0.422973f, 0.157534f, 0.119070f, 0.449352f, 0.039913f, 0.986580f, 0.378121f, + 0.382109f, 0.051126f, 0.426672f, 0.015745f, 0.030094f, 0.339099f, 0.820969f, 0.458821f, 0.014841f, 0.163220f, + 0.739923f, 0.738294f, 0.754523f, 0.351669f, 0.352277f, 0.802076f, 0.398138f, 0.727191f, 0.581123f, 0.364342f, + 0.080007f, 0.116125f, 0.889559f, 0.452341f, 0.994005f, 0.363897f, 0.249954f, 0.350539f, 0.343086f, 0.637357f, + 0.012738f, 0.763269f, 0.416415f, 0.432239f, 0.481115f, 0.449212f, 0.497471f, 0.345904f, 0.453346f, 0.404651f, + 0.518243f, 0.623269f, 0.241041f, 0.508437f, 0.594622f, 0.016948f, 0.520494f, 0.239293f, 0.404539f, 0.826530f, + 0.326236f, 0.483217f, 0.024741f, 0.308751f, 0.639721f, 0.315162f, 0.205798f, 0.290656f, 0.954378f, 0.086802f, + 0.463358f, 0.058387f, 0.538658f, 0.146036f, 0.634085f, 0.264397f, 0.690915f, 0.347146f, 0.004168f, 0.294895f, + 0.081894f, 0.495040f, 0.288890f, 0.639992f, 0.499936f, 0.036045f, 0.318634f, 0.489059f, 0.572204f, 0.104871f, + 0.649971f, 0.343698f, 0.182921f, 0.805327f, 0.068623f, 0.929770f, 0.706266f, 0.475591f, 0.011161f, 0.390125f, + 0.645798f, 0.858913f, 0.617764f, 0.397707f}; + std::vector gamma_data = { 0.447359f, 0.873295f, 0.351357f, 0.065158f, 0.442673f, 0.998459f, 0.379773f, 0.193055f, 0.045130f, 0.170969f, 0.324064f, 0.574278f, 0.665588f, 0.042819f, 0.936180f, 0.235638f, 0.149062f, 0.530829f, 0.677586f, 0.307253f, @@ -157,7 +263,7 @@ TEST(GroupNormTest, GroupNorm_128) { 0.413157f, 0.595286f, 0.133620f, 0.484188f, 0.972134f, 0.427721f, 0.242881f, 0.927507f, 
0.610774f, 0.727857f, 0.543405f, 0.011202f, 0.755700f, 0.978697f, 0.716188f, 0.808757f, 0.851587f, 0.999201f}; - std::vector norm_data = { + std::vector norm_data_nhwc = { 0.406306f, 1.632045f, 0.095849f, 0.919355f, -0.458834f, 1.632483f, 0.876482f, 0.729815f, 0.750835f, 0.782631f, 0.590117f, 1.476163f, 0.183714f, 0.057787f, -0.474648f, 0.143954f, 0.561618f, 0.031635f, 0.426744f, 0.118848f, 0.054676f, 0.526575f, -0.827396f, -0.206514f, 0.631899f, 1.033381f, -0.028056f, @@ -273,7 +379,123 @@ TEST(GroupNormTest, GroupNorm_128) { 0.181189f, 0.244843f, 1.995885f, -1.411448f, 1.422581f, 0.658642f, 0.243404f, 0.442854f, 0.230959f, -0.272532f, 0.778544f, 1.461264f, 0.670758f, 2.274148f, 0.642745f, 0.948315f}; - std::vector swish_data = { + std::vector norm_data_nchw = { + 0.406306f, -0.397960f, -0.514167f, 0.121796f, 1.632045f, 0.498094f, 2.631821f, 1.499508f, 0.095849f, + -0.040874f, -0.116213f, 0.477808f, 0.919355f, 0.811189f, 0.907785f, 1.004834f, -0.458834f, -0.472885f, + 0.237840f, 0.238391f, 1.632483f, 2.600490f, 2.037882f, 1.527244f, 0.876482f, 0.192458f, 0.258945f, + 0.030314f, 0.729815f, 1.023374f, 0.554331f, 0.851662f, 0.750835f, 0.762038f, 0.749937f, 0.729685f, + 0.782631f, 1.098150f, 1.132450f, 0.833627f, 0.590117f, -0.060817f, 0.197422f, 0.322326f, 1.476163f, + 0.078648f, 0.606424f, 1.746771f, 0.183714f, 0.518953f, -1.247748f, 1.284994f, 0.057787f, 0.044398f, + -0.002311f, -0.011233f, -0.474648f, 0.859423f, 1.839518f, -0.003167f, 0.143954f, 0.038016f, 0.087527f, + 0.150784f, 0.561618f, 0.176986f, 0.521764f, 0.258291f, 0.031635f, 0.714081f, -0.146106f, 1.278590f, + 0.426744f, 0.648229f, -0.980738f, 0.351162f, 0.118848f, -0.296623f, -0.302773f, 0.263747f, 0.054676f, + 1.040141f, 0.676835f, 0.240001f, 0.526575f, 0.429690f, -0.132461f, -0.496733f, -0.827396f, -0.494966f, + 0.596699f, 1.630098f, -0.206514f, 1.206059f, 0.617694f, 0.959952f, 0.631899f, 0.709146f, 0.659876f, + 0.691867f, 1.033381f, 1.134488f, 0.765150f, 0.781609f, -0.028056f, 1.010104f, 1.575500f, 
0.678993f, + 0.117742f, 0.117495f, 0.117460f, 0.117479f, -0.928939f, 0.857719f, -0.473908f, 0.106319f, 0.254703f, + 0.187595f, -0.423069f, 0.737017f, 1.002641f, -0.713799f, -0.505049f, 0.054679f, 0.056505f, 0.068448f, + -0.037672f, 0.007170f, 0.502409f, 0.201653f, 0.447086f, 0.031594f, 0.186869f, 0.252268f, 0.281287f, + 0.058290f, -0.032152f, -0.172338f, -0.018190f, 0.042648f, -0.201724f, 0.070818f, 0.915389f, 0.435231f, + 0.683548f, 1.228964f, 1.207481f, -0.069486f, 0.900928f, 1.349056f, -0.962214f, 1.149115f, 0.126877f, + 0.173722f, 1.016921f, -0.284731f, 1.073324f, 1.663625f, 1.156551f, 0.478892f, -0.017409f, 0.077027f, + 0.019404f, -0.119480f, 0.957481f, 1.191751f, 0.709657f, 1.305503f, 0.710492f, 0.094092f, 0.713726f, + -1.632824f, 1.254686f, 1.179984f, 1.354227f, -0.186219f, -0.620889f, -0.462022f, 0.270004f, 0.339929f, + 0.882544f, 0.831658f, 0.840813f, 0.911391f, 1.003820f, 1.105588f, 0.865947f, 1.028848f, 0.385277f, + 0.249245f, 0.102975f, 0.301977f, 0.814893f, 0.829719f, 0.796979f, 0.828055f, -0.841305f, 1.360636f, + 0.520542f, -0.564568f, 1.028838f, 0.624319f, 1.122967f, 1.414307f, 1.664626f, 1.011229f, -0.562413f, + 1.432279f, 0.982238f, -0.634975f, 1.328713f, 0.605853f, 0.150513f, 0.475544f, 0.137686f, 0.199995f, + -0.461095f, 0.034839f, 1.895931f, -0.442368f, -0.012286f, 1.765260f, -0.574054f, 1.540784f, 1.094831f, + 0.660444f, 0.856002f, 0.876256f, 0.900296f, 0.743193f, 0.857834f, 0.771619f, -0.437987f, 0.795097f, + 0.983861f, -0.860229f, 0.919201f, 1.088295f, 0.978393f, 1.000022f, -0.604762f, 0.300263f, 1.250703f, + 0.093107f, 0.398245f, 0.476736f, 0.584533f, 0.450905f, 1.126501f, 1.126446f, 0.704302f, 0.872359f, + 1.388226f, 0.453643f, -0.218810f, 2.159872f, 0.740287f, 1.137416f, -0.416660f, 0.030324f, 0.352386f, + -0.572652f, 1.397336f, -0.212928f, 0.833504f, 0.673148f, 0.564530f, 0.691624f, 0.614170f, 1.159168f, + 0.582539f, 0.714844f, 0.687727f, 0.829472f, 0.895726f, 0.749217f, 0.626510f, 0.160861f, 0.679485f, + -0.247668f, 0.563813f, 0.424527f, 
0.442242f, 0.546163f, 0.408836f, 0.503895f, 0.541062f, 0.526861f, + 0.651389f, 1.131327f, 0.109609f, 0.965844f, 0.307533f, 0.397239f, 0.275143f, 0.398844f, 1.158524f, + 1.178295f, 0.107930f, 0.808378f, 0.360064f, 0.893187f, 0.353517f, 0.411826f, 0.588918f, 1.147333f, + 0.707609f, 0.859227f, 0.904664f, 0.005007f, 0.915281f, 1.148453f, 0.418446f, 0.581892f, 0.628682f, + 1.279391f, 0.420879f, 1.174909f, 0.355126f, 0.239180f, 0.495571f, 0.703488f, 0.897993f, 0.580433f, + 0.796672f, 0.937277f, 0.923647f, 1.115814f, 0.759542f, 1.057870f, 0.977992f, 1.052553f, 0.996513f, + 1.042361f, 0.935513f, 0.938658f, -0.328335f, 0.414783f, -0.370250f, -0.629015f, 1.636925f, 1.554468f, + -0.000332f, 0.794400f, -0.644444f, -0.841804f, -0.462323f, -0.489248f, 1.350502f, 1.139242f, 0.742310f, + 1.621988f, 0.891792f, 0.742398f, 0.634979f, 0.789545f, 0.600690f, 0.564714f, 0.910902f, 0.749079f, + 0.795602f, -0.081046f, 1.059454f, -0.024277f, 0.142066f, 2.137630f, 1.354346f, 0.386545f, -0.015730f, + 0.467942f, 1.166715f, 0.105109f, -0.867947f, 0.330163f, 0.402587f, -0.943201f, 1.039989f, 0.807147f, + 1.013271f, 0.658228f, 0.261774f, 1.276604f, 0.793169f, 0.981167f, 1.182381f, -0.094400f, 0.608214f, + 1.500447f, 0.375100f, -0.540889f, -0.429466f, -0.074319f, 0.493101f, 0.428099f, 0.396397f, 0.409342f, + -0.112225f, 0.338536f, -0.096419f, 1.247461f, 0.136779f, -0.296175f, 1.306138f, -0.211410f, 1.225890f, + -0.883684f, 0.732527f, 0.188935f, 0.158450f, -0.070659f, -0.068210f, 0.095841f, 1.142486f, 0.765356f, + 0.480573f, 0.758850f, -0.296101f, -0.351806f, -0.084915f, 0.595416f, 0.228868f, -0.067355f, 0.843406f, + 0.656214f, 0.873088f, 1.118756f, 1.124528f, 0.905517f, 0.397857f, 0.077982f, -0.111570f, -0.334851f, + 0.432766f, 0.446440f, 0.667385f, 0.295979f, 1.815673f, -0.258010f, 1.014872f, 0.567667f, 0.353312f, + 0.252682f, 1.221989f, 0.073956f, -0.006854f, 1.239576f, 1.165116f, 0.349117f, 0.251850f, -0.979634f, + -1.026174f, 1.184909f, 0.343477f, 0.825275f, 1.364619f, 0.027066f, -0.497336f, 
-0.463020f, 1.676924f, + 2.348872f, 0.382225f, 0.125961f, 0.592108f, 1.470366f, 0.758787f, -0.208515f, 1.041303f, -0.435509f, + 0.117172f, 1.494655f, 0.342757f, 1.778383f, 0.342274f, 0.097464f, 2.547432f, -0.706661f, 0.892228f, + 0.432844f, 0.978781f, 0.577661f, -0.293386f, 0.867343f, 1.042198f, 0.928943f, -1.206122f, -0.536458f, + -0.103338f, -0.556358f, 0.772336f, 0.736790f, 0.761959f, 0.781633f, 1.964310f, 0.328702f, -0.205143f, + 2.151912f, 0.807267f, 0.819557f, 0.651057f, 0.761094f, -0.553660f, 0.061518f, 1.635670f, -0.845767f, + 1.500599f, 0.591131f, 0.429972f, 0.154289f, 1.184999f, 0.943027f, 1.116617f, 1.149119f, 0.798352f, + -0.237060f, -0.176123f, 0.250859f, 0.738550f, 2.343516f, 0.595660f, 0.857584f, 0.334614f, 0.055512f, + 0.827656f, -0.346350f, 0.879107f, 0.903969f, 0.861351f, 0.894605f, 0.544361f, 0.112821f, -0.710248f, + 0.886723f, 1.241048f, -0.874084f, 1.412525f, 0.338762f, -0.116848f, 0.501252f, 0.737254f, 0.656447f, + 0.680143f, 0.883760f, 0.893155f, 1.024669f, 0.749525f, 0.825862f, 0.796258f, 0.693469f, 0.903967f, + 1.112298f, 0.917900f, 0.659168f, 0.521876f, 0.830550f, 0.020787f, 0.905854f, 0.044571f, 0.857847f, + 0.528776f, 0.224581f, 0.636013f, -0.774066f, 0.896313f, 0.357502f, 0.101543f, 0.048746f, -0.023476f, + -0.007332f, 1.160492f, 0.173347f, 0.010474f, -0.390864f, -0.183245f, 0.374310f, -0.061789f, 0.307303f, + 0.374511f, 0.508790f, 0.504972f, 0.571301f, 0.647929f, 0.892303f, 0.727948f, 0.437075f, 0.272462f, + 0.267807f, -1.691226f, -0.311736f, 0.221596f, -0.501987f, -0.209513f, -0.249217f, 0.477392f, -0.221902f, + 0.783358f, 0.585570f, 0.293685f, 0.168966f, -0.402073f, -0.397286f, 0.793616f, 0.814484f, 1.660988f, + 1.381788f, 0.434287f, 0.951160f, 0.398667f, -0.368342f, 0.685965f, 0.628689f, 0.746822f, 0.647196f, + 0.952972f, 1.171188f, 0.756122f, 0.809376f, -0.181046f, 1.143145f, 1.075280f, -0.462215f, 0.117678f, + 0.117596f, 0.117522f, 0.117660f, 1.207595f, -0.374746f, 0.482337f, 0.453367f, -0.074850f, -0.281733f, + 0.121187f, 
-0.164130f, -0.407813f, 1.347597f, -0.097000f, 0.558638f, -0.030066f, 0.084762f, 0.026081f, + -0.054476f, 0.048566f, 0.563618f, 0.564591f, 0.367439f, 0.067439f, 0.110448f, 0.229187f, 0.244487f, + 0.001379f, -0.044959f, -0.092778f, -0.175144f, -0.060172f, 0.876871f, 0.715658f, -0.005267f, 0.280818f, + 1.021856f, 1.202137f, 1.277564f, -0.846823f, 1.680601f, -0.648320f, 0.465179f, 0.816884f, 1.617434f, + 0.964561f, 0.811168f, 0.685541f, 1.269441f, -0.294534f, -0.541415f, 0.148579f, 0.006120f, -0.047344f, + -0.034877f, 1.228496f, 0.766407f, 1.191577f, 0.830097f, 1.213856f, -1.697397f, -0.162200f, -0.216335f, + 0.082768f, 1.538109f, 1.455440f, 0.466843f, -0.675884f, -0.396112f, -0.230969f, 0.311936f, 0.850093f, + 0.895946f, 0.864577f, 0.906072f, 1.127087f, 0.915749f, 1.022470f, 1.086701f, 0.347097f, 0.115267f, + 0.269888f, 0.017932f, 0.837999f, 0.798699f, 0.830973f, 0.843566f, 0.524987f, -0.323668f, 0.796731f, + 0.882529f, 1.104285f, 0.707952f, 1.288781f, 1.066624f, -0.759169f, 1.253857f, -0.279808f, -0.810174f, + 0.635460f, 1.336810f, 1.461457f, -0.560630f, 0.345593f, 0.388281f, 0.011112f, 0.625432f, -0.202532f, + -0.952190f, 0.661665f, 1.290380f, -0.625566f, -0.330132f, 0.377751f, 1.393908f, 0.947332f, 0.567214f, + 0.597034f, 0.789381f, 1.108524f, 0.989273f, 0.896032f, 0.972095f, 0.451968f, -0.186156f, 0.864871f, + 1.008577f, 1.059174f, 1.005235f, 0.834800f, 0.881400f, -0.345810f, 0.538783f, -0.242229f, 0.765357f, + 0.363634f, 0.540277f, 0.489711f, 0.556296f, 0.791247f, 0.963361f, 0.900796f, 1.274361f, 1.440297f, + 0.639664f, -0.769517f, 2.005213f, -0.205800f, 0.462482f, 0.893398f, -0.179109f, -0.385072f, 0.698468f, + 0.656636f, -0.167324f, 0.646567f, 0.534505f, 1.234794f, 1.110618f, 1.271695f, 0.759512f, 0.229293f, + 0.147224f, 0.794720f, 1.099447f, 1.113528f, 1.058541f, -0.208087f, 0.316237f, -0.032344f, -0.114418f, + 0.540560f, 0.498906f, 0.465116f, 0.418016f, 0.482087f, 0.445022f, 0.453282f, 0.438177f, -0.006379f, + 0.377702f, -0.855888f, 1.042157f, 0.408202f, 
0.339785f, 0.287742f, 0.420788f, 0.465379f, 1.007626f, + 1.001159f, 0.554656f, 0.459783f, 1.143811f, 0.339036f, 0.714696f, 0.691498f, 0.735108f, 1.053392f, + 0.778748f, 0.068571f, 0.274017f, 1.481772f, 1.693937f, 0.526139f, 0.909311f, 0.350476f, 0.954506f, + 0.197028f, 0.923411f, 0.045156f, 0.957155f, 0.714096f, 0.633157f, 0.789485f, 0.581167f, 0.845790f, + 0.829842f, 1.194247f, 0.971378f, 1.019175f, 0.907585f, 0.953225f, 0.951858f, 1.102269f, 1.018174f, + 0.902432f, 0.841796f, -0.858393f, -0.330711f, -0.469070f, 0.464267f, 1.114611f, -1.004036f, 1.620967f, + 0.329466f, 0.139467f, -0.470611f, 0.308757f, 1.016010f, 0.453660f, 1.595124f, 0.558440f, 1.023249f, + 0.601039f, 1.007291f, 0.995676f, 0.637742f, 0.970108f, 0.851145f, 0.582246f, 0.840873f, 0.433405f, + -0.009376f, -0.395102f, 0.229559f, 1.179632f, 0.217997f, 0.145108f, 1.614064f, 1.010146f, 0.887566f, + -1.011727f, 0.264498f, 0.152422f, 0.570916f, 0.925334f, -0.269998f, 0.860524f, 1.051678f, 1.007595f, + 0.941741f, 0.488055f, 0.245246f, 0.227135f, 0.066780f, -0.402708f, 1.265329f, 0.257161f, -0.447346f, + 0.493756f, -0.268568f, -0.217773f, -0.301152f, 0.475332f, 0.373900f, 0.446225f, 0.471130f, 0.663021f, + 1.000752f, -0.090537f, 0.673516f, 0.781955f, 0.128213f, 1.239298f, 0.764475f, 1.281084f, 0.902059f, + 0.278935f, 0.221142f, 0.160415f, -0.106214f, 0.210654f, 0.141437f, 0.198334f, 0.149962f, 0.565323f, + 0.050416f, 0.888878f, 0.074347f, 0.079686f, -0.363394f, 0.253592f, -0.311712f, -0.291973f, 0.133119f, + 1.097622f, 0.962363f, 0.796541f, 0.851959f, 0.628367f, 0.626313f, 0.646783f, 0.138650f, 0.510147f, + 1.394106f, 0.600274f, 1.246940f, 0.872970f, 0.275462f, -0.508244f, -0.408690f, 1.314789f, 0.349021f, + 1.545499f, 0.153658f, 0.231785f, 0.389777f, 0.378070f, 0.840290f, -1.853665f, 1.786896f, 0.104429f, + 0.181189f, 0.667719f, 0.567943f, 0.718873f, 0.244843f, 1.129714f, 0.881495f, 1.460520f, 1.995885f, + -0.395025f, 0.817815f, 1.208726f, -1.411448f, 0.606279f, -0.143777f, 0.296987f, 1.422581f, 0.720905f, 
+ 1.279913f, -0.352711f, 0.658642f, 1.613478f, 0.339589f, -0.089663f, 0.243404f, 1.226488f, 0.467706f, + 0.797042f, 0.442854f, 1.121590f, -0.153407f, 1.431477f, 0.230959f, 1.437285f, -0.046937f, -1.527740f, + -0.272532f, 0.732910f, 0.766692f, 0.749836f, 0.778544f, 1.502128f, -0.240678f, 0.820989f, 1.461264f, + 0.744201f, 0.593997f, 0.769196f, 0.670758f, -0.186752f, 1.864102f, -0.563369f, 2.274148f, 1.338321f, + 0.830787f, -0.191057f, 0.642745f, 1.092864f, 1.217034f, 1.076530f, 0.948315f}; + + std::vector swish_data_nhwc = { 0.243866f, 1.365124f, 0.050220f, 0.657257f, -0.177689f, 1.365588f, 0.618877f, 0.492453f, 0.510088f, 0.537078f, 0.379677f, 1.201586f, 0.100271f, 0.029728f, -0.182035f, 0.077149f, 0.357653f, 0.016068f, 0.258221f, 0.062951f, 0.028085f, 0.331049f, -0.251691f, -0.092633f, 0.412580f, 0.762192f, -0.013831f, @@ -389,67 +611,220 @@ TEST(GroupNormTest, GroupNorm_128) { 0.098779f, 0.137334f, 1.757106f, -0.276652f, 1.146234f, 0.434016f, 0.136440f, 0.269671f, 0.128756f, -0.117812f, 0.533588f, 1.186146f, 0.443822f, 2.062000f, 0.421238f, 0.683523f}; + std::vector swish_data_nchw = { + 0.243866f, -0.159901f, -0.192410f, 0.064602f, 1.365124f, 0.309820f, 2.455177f, 1.225849f, 0.050220f, + -0.020019f, -0.054734f, 0.294918f, 0.657257f, 0.561637f, 0.646839f, 0.735547f, -0.177689f, -0.181556f, + 0.132996f, 0.133336f, 1.365588f, 2.420778f, 1.802949f, 1.254788f, 0.618877f, 0.105460f, 0.146142f, + 0.015386f, 0.492453f, 0.752824f, 0.352078f, 0.596943f, 0.510088f, 0.519554f, 0.509331f, 0.492345f, + 0.537078f, 0.823517f, 0.856461f, 0.581139f, 0.379677f, -0.029484f, 0.108424f, 0.186914f, 1.201586f, + 0.040870f, 0.392432f, 1.487454f, 0.100271f, 0.325333f, -0.278360f, 1.006534f, 0.029728f, 0.022692f, + -0.001154f, -0.005585f, -0.182035f, 0.603779f, 1.587304f, -0.001581f, 0.077149f, 0.019369f, 0.045678f, + 0.081065f, 0.357653f, 0.096304f, 0.327438f, 0.145732f, 0.016068f, 0.479364f, -0.067726f, 1.000125f, + 0.258221f, 0.425634f, -0.267492f, 0.206097f, 0.062951f, -0.126475f, 
-0.128642f, 0.149164f, 0.028085f, + 0.768537f, 0.448763f, 0.134332f, 0.331049f, 0.260306f, -0.061851f, -0.187918f, -0.251691f, -0.187456f, + 0.384811f, 1.363060f, -0.092633f, 0.928184f, 0.401312f, 0.694153f, 0.412580f, 0.475279f, 0.435012f, + 0.461047f, 0.762192f, 0.858428f, 0.522193f, 0.536204f, -0.013831f, 0.740447f, 1.305406f, 0.450521f, + 0.062333f, 0.062195f, 0.062175f, 0.062186f, -0.263020f, 0.602277f, -0.181835f, 0.055983f, 0.143483f, + 0.102570f, -0.167443f, 0.498477f, 0.733510f, -0.234669f, -0.190078f, 0.028087f, 0.029050f, 0.035395f, + -0.018481f, 0.003598f, 0.313013f, 0.110958f, 0.272698f, 0.016046f, 0.102139f, 0.141960f, 0.160295f, + 0.029994f, -0.015817f, -0.078762f, -0.009012f, 0.021779f, -0.090723f, 0.036662f, 0.653681f, 0.264239f, + 0.454239f, 0.950773f, 0.929582f, -0.033536f, 0.640686f, 1.071117f, -0.265990f, 0.872580f, 0.067457f, + 0.094387f, 0.746799f, -0.122234f, 0.799871f, 1.398649f, 0.879794f, 0.295709f, -0.008629f, 0.039996f, + 0.009796f, -0.056175f, 0.691892f, 0.914138f, 0.475701f, 1.027117f, 0.476392f, 0.049258f, 0.479070f, + -0.266875f, 0.976283f, 0.902623f, 1.076367f, -0.084465f, -0.217050f, -0.178574f, 0.153117f, 0.198578f, + 0.624266f, 0.579420f, 0.587422f, 0.650082f, 0.734605f, 0.830635f, 0.609541f, 0.757945f, 0.229296f, + 0.140073f, 0.054136f, 0.173615f, 0.564844f, 0.577730f, 0.549380f, 0.576280f, -0.253452f, 1.082880f, + 0.326523f, -0.204651f, 0.757935f, 0.406556f, 0.847322f, 1.137731f, 1.399714f, 0.741494f, -0.204150f, + 1.156216f, 0.714629f, -0.219945f, 1.050518f, 0.391983f, 0.080910f, 0.293266f, 0.073575f, 0.109964f, + -0.178318f, 0.017723f, 1.648380f, -0.173044f, -0.006105f, 1.507298f, -0.206833f, 1.268957f, 0.820346f, + 0.435470f, 0.600764f, 0.618676f, 0.640120f, 0.503657f, 0.602378f, 0.527688f, -0.171787f, 0.547762f, + 0.716126f, -0.255739f, 0.657118f, 0.814111f, 0.711085f, 0.731079f, -0.213635f, 0.172503f, 0.972323f, + 0.048719f, 0.238256f, 0.294135f, 0.375335f, 0.275437f, 0.850725f, 0.850672f, 0.471277f, 0.615220f, + 
1.111010f, 0.277405f, -0.097483f, 1.936515f, 0.501217f, 0.861257f, -0.165546f, 0.015392f, 0.206920f, + -0.206513f, 1.120330f, -0.095172f, 0.581032f, 0.445763f, 0.359888f, 0.460849f, 0.398530f, 0.882337f, + 0.373787f, 0.479997f, 0.457656f, 0.577514f, 0.636029f, 0.508724f, 0.408295f, 0.086886f, 0.450923f, + -0.108577f, 0.359337f, 0.256654f, 0.269234f, 0.345855f, 0.245632f, 0.314115f, 0.341984f, 0.331264f, + 0.428173f, 0.855378f, 0.057805f, 0.699551f, 0.177226f, 0.237559f, 0.156379f, 0.238672f, 0.881711f, + 0.900973f, 0.056874f, 0.559207f, 0.212098f, 0.633758f, 0.207681f, 0.247724f, 0.378743f, 0.870852f, + 0.474008f, 0.603606f, 0.644037f, 0.002510f, 0.653584f, 0.871938f, 0.252370f, 0.373285f, 0.410021f, + 1.000926f, 0.254082f, 0.897667f, 0.208764f, 0.133824f, 0.307957f, 0.470606f, 0.638058f, 0.372154f, + 0.549116f, 0.673480f, 0.661132f, 0.840444f, 0.517441f, 0.785239f, 0.710716f, 0.780221f, 0.727826f, + 0.770623f, 0.671878f, 0.674734f, -0.137456f, 0.249797f, -0.151240f, -0.218730f, 1.370297f, 1.283304f, + -0.000166f, 0.547163f, -0.221845f, -0.253514f, -0.178658f, -0.185949f, 1.072585f, 0.863022f, 0.502916f, + 1.354473f, 0.632512f, 0.502989f, 0.415034f, 0.542996f, 0.387934f, 0.360029f, 0.649641f, 0.508608f, + 0.548196f, -0.038882f, 0.786735f, -0.011991f, 0.076070f, 1.912126f, 1.076488f, 0.230168f, -0.007803f, + 0.287736f, 0.889679f, 0.055314f, -0.256636f, 0.192089f, 0.241274f, -0.264336f, 0.768393f, 0.558143f, + 0.743397f, 0.433681f, 0.147922f, 0.998140f, 0.546106f, 0.713642f, 0.904965f, -0.044974f, 0.393839f, + 1.226827f, 0.222318f, -0.199037f, -0.169319f, -0.035779f, 0.306135f, 0.259179f, 0.236975f, 0.245986f, + -0.052967f, 0.197649f, -0.045887f, 0.969103f, 0.073060f, -0.126316f, 1.027756f, -0.094573f, 0.947734f, + -0.258402f, 0.494719f, 0.103365f, 0.085489f, -0.034082f, -0.032942f, 0.050215f, 0.866159f, 0.522367f, + 0.296938f, 0.516856f, -0.126290f, -0.145276f, -0.040656f, 0.383809f, 0.127472f, -0.032544f, 0.589695f, + 0.432058f, 0.615866f, 0.843271f, 0.848825f, 
0.644802f, 0.237987f, 0.040511f, -0.052676f, -0.139653f, + 0.262488f, 0.272236f, 0.441086f, 0.169732f, 1.561563f, -0.112454f, 0.744888f, 0.362299f, 0.207543f, + 0.142219f, 0.943881f, 0.038345f, -0.003415f, 0.961279f, 0.888123f, 0.204724f, 0.141699f, -0.267405f, + -0.270732f, 0.907438f, 0.200946f, 0.573859f, 1.086931f, 0.013716f, -0.188076f, -0.178851f, 1.412803f, + 2.144154f, 0.227198f, 0.066942f, 0.381228f, 1.195574f, 0.516802f, -0.093427f, 0.769628f, -0.171073f, + 0.062015f, 1.220799f, 0.200465f, 1.521402f, 0.200143f, 0.051105f, 2.362491f, -0.233436f, 0.632902f, + 0.262543f, 0.711443f, 0.370009f, -0.125327f, 0.610777f, 0.770470f, 0.665922f, -0.277876f, -0.197959f, + -0.049002f, -0.202732f, 0.528298f, 0.498287f, 0.519488f, 0.536225f, 1.722697f, 0.191122f, -0.092087f, + 1.927784f, 0.558246f, 0.568889f, 0.427906f, 0.518755f, -0.202095f, 0.031705f, 1.368965f, -0.254002f, + 1.226985f, 0.380467f, 0.260506f, 0.083084f, 0.907526f, 0.678707f, 0.841215f, 0.872584f, 0.550561f, + -0.104546f, -0.080327f, 0.141080f, 0.499761f, 2.138265f, 0.384000f, 0.602158f, 0.195041f, 0.028526f, + 0.575932f, -0.143482f, 0.621209f, 0.643413f, 0.605480f, 0.635026f, 0.344486f, 0.059589f, -0.234058f, + 0.627989f, 0.962738f, -0.257335f, 1.135901f, 0.197799f, -0.055014f, 0.312156f, 0.498675f, 0.432245f, + 0.451459f, 0.625349f, 0.633730f, 0.754035f, 0.508984f, 0.574370f, 0.548760f, 0.462362f, 0.643412f, + 0.837068f, 0.655944f, 0.434440f, 0.327522f, 0.578454f, 0.010501f, 0.645105f, 0.022782f, 0.602389f, + 0.332705f, 0.124847f, 0.415858f, -0.244295f, 0.636554f, 0.210367f, 0.053347f, 0.024967f, -0.011600f, + -0.003652f, 0.883625f, 0.094167f, 0.005264f, -0.157717f, -0.083251f, 0.221779f, -0.029940f, 0.177076f, + 0.221916f, 0.317751f, 0.314914f, 0.365097f, 0.425394f, 0.632969f, 0.490896f, 0.265550f, 0.154676f, + 0.151727f, -0.263180f, -0.131768f, 0.123024f, -0.189286f, -0.093822f, -0.109161f, 0.294614f, -0.098691f, + 0.537700f, 0.376140f, 0.168252f, 0.091604f, -0.161157f, -0.159695f, 0.546489f, 
0.564490f, 1.395845f, + 1.104433f, 0.263568f, 0.686118f, 0.238550f, -0.150630f, 0.456214f, 0.410026f, 0.506708f, 0.424806f, + 0.687772f, 0.894037f, 0.514549f, 0.560069f, -0.082351f, 0.866797f, 0.801729f, -0.178628f, 0.062297f, + 0.062251f, 0.062210f, 0.062287f, 0.929695f, -0.152669f, 0.298229f, 0.277207f, -0.036025f, -0.121153f, + 0.064261f, -0.075345f, -0.162895f, 1.069637f, -0.046150f, 0.355371f, -0.014807f, 0.044176f, 0.013210f, + -0.026496f, 0.024873f, 0.359187f, 0.359935f, 0.217098f, 0.034856f, 0.058271f, 0.127668f, 0.137113f, + 0.000690f, -0.021974f, -0.044238f, -0.079923f, -0.029181f, 0.619223f, 0.480672f, -0.002627f, 0.159995f, + 0.751405f, 0.924329f, 0.999099f, -0.254131f, 1.416720f, -0.222613f, 0.285733f, 0.566570f, 1.349653f, + 0.698375f, 0.561619f, 0.455867f, 0.990985f, -0.125735f, -0.199164f, 0.079798f, 0.003070f, -0.023112f, + -0.017134f, 0.950309f, 0.523259f, 0.913967f, 0.578059f, 0.935859f, -0.262766f, -0.074537f, -0.096513f, + 0.043096f, 1.266156f, 1.180120f, 0.286938f, -0.227895f, -0.159335f, -0.102207f, 0.180099f, 0.595564f, + 0.636225f, 0.608330f, 0.645301f, 0.851290f, 0.654005f, 0.751979f, 0.812592f, 0.203369f, 0.060952f, + 0.153044f, 0.009046f, 0.584960f, 0.550860f, 0.578823f, 0.589835f, 0.329856f, -0.135870f, 0.549166f, + 0.624253f, 0.829387f, 0.474291f, 1.010328f, 0.793519f, -0.242043f, 0.975459f, -0.120458f, -0.249415f, + 0.415417f, 1.058707f, 1.186345f, -0.203734f, 0.202362f, 0.231365f, 0.005587f, 0.407439f, -0.091046f, + -0.265132f, 0.436457f, 1.011931f, -0.218020f, -0.138064f, 0.224131f, 1.116821f, 0.682626f, 0.361950f, + 0.385073f, 0.542856f, 0.833448f, 0.721125f, 0.636303f, 0.705290f, 0.276201f, -0.084439f, 0.608590f, + 0.739027f, 0.786472f, 0.735919f, 0.582164f, 0.623249f, -0.143303f, 0.340258f, -0.106517f, 0.522368f, + 0.214515f, 0.341388f, 0.303639f, 0.353579f, 0.544456f, 0.697275f, 0.640568f, 0.995898f, 1.164481f, + 0.418773f, -0.243616f, 1.767281f, -0.092349f, 0.283780f, 0.633947f, -0.081556f, -0.155917f, 0.466470f, + 0.432398f, 
-0.076679f, 0.424301f, 0.337023f, 0.956541f, 0.835456f, 0.993236f, 0.517416f, 0.127733f, + 0.079021f, 0.547438f, 0.824758f, 0.838249f, 0.785873f, -0.093257f, 0.182914f, -0.015911f, -0.053940f, + 0.341603f, 0.310421f, 0.285687f, 0.252067f, 0.298046f, 0.271221f, 0.277146f, 0.266335f, -0.003180f, + 0.224097f, -0.255225f, 0.770431f, 0.245189f, 0.198482f, 0.164428f, 0.254018f, 0.285878f, 0.738142f, + 0.732134f, 0.352326f, 0.281830f, 0.867442f, 0.197982f, 0.479874f, 0.460745f, 0.496878f, 0.781012f, + 0.533761f, 0.035461f, 0.155663f, 1.207408f, 1.430939f, 0.330722f, 0.648210f, 0.205636f, 0.689173f, + 0.108188f, 0.660919f, 0.023088f, 0.691594f, 0.479376f, 0.413581f, 0.542946f, 0.372724f, 0.591785f, + 0.577837f, 0.916584f, 0.704632f, 0.748902f, 0.646659f, 0.688003f, 0.686755f, 0.827456f, 0.747968f, + 0.642034f, 0.588283f, -0.255522f, -0.138260f, -0.180515f, 0.285072f, 0.839288f, -0.269231f, 1.353391f, + 0.191627f, 0.074588f, -0.180937f, 0.178024f, 0.745949f, 0.277417f, 1.326083f, 0.355219f, 0.752707f, + 0.388207f, 0.737830f, 0.727050f, 0.417238f, 0.703465f, 0.596488f, 0.373560f, 0.587475f, 0.262941f, + -0.004666f, -0.159025f, 0.127896f, 0.902279f, 0.120832f, 0.077809f, 1.346089f, 0.740486f, 0.628741f, + -0.269769f, 0.149638f, 0.082008f, 0.364801f, 0.662657f, -0.116884f, 0.604751f, 0.779395f, 0.738113f, + 0.677536f, 0.302422f, 0.137584f, 0.126410f, 0.034505f, -0.161350f, 0.986884f, 0.145023f, -0.174461f, + 0.306618f, -0.116360f, -0.097077f, -0.128073f, 0.293111f, 0.221499f, 0.272082f, 0.290052f, 0.437553f, + 0.731756f, -0.043221f, 0.446063f, 0.536501f, 0.068210f, 0.961003f, 0.521620f, 1.002620f, 0.641700f, + 0.158794f, 0.122747f, 0.086627f, -0.050289f, 0.116380f, 0.075711f, 0.108969f, 0.080593f, 0.360497f, + 0.025843f, 0.629911f, 0.038555f, 0.041430f, -0.149042f, 0.142787f, -0.131760f, -0.124825f, 0.070983f, + 0.823012f, 0.696361f, 0.549003f, 0.597205f, 0.409770f, 0.408138f, 0.424474f, 0.074123f, 0.318761f, + 1.117023f, 0.387609f, 0.968585f, 0.615761f, 0.156582f, -0.190899f, 
-0.163160f, 1.036466f, 0.204659f, + 1.273897f, 0.082720f, 0.129264f, 0.232396f, 0.224349f, 0.586964f, -0.251066f, 1.530559f, 0.054939f, + 0.098779f, 0.441357f, 0.362511f, 0.483341f, 0.137334f, 0.853822f, 0.623333f, 1.185376f, 1.757106f, + -0.159001f, 0.567378f, 0.930808f, -0.276652f, 0.392318f, -0.066730f, 0.170383f, 1.146234f, 0.485029f, + 1.001448f, -0.145573f, 0.434016f, 1.345469f, 0.198351f, -0.042823f, 0.136440f, 0.948325f, 0.287564f, + 0.549434f, 0.269671f, 0.845997f, -0.070832f, 1.155390f, 0.128756f, 1.161375f, -0.022918f, -0.272434f, + -0.117812f, 0.495040f, 0.523501f, 0.509246f, 0.533588f, 1.228578f, -0.105927f, 0.570132f, 1.186146f, + 0.504504f, 0.382702f, 0.525628f, 0.443822f, -0.084682f, 1.613891f, -0.204372f, 2.062000f, 1.060236f, + 0.578661f, -0.086430f, 0.421238f, 0.818468f, 0.938992f, 0.802915f, 0.683523f}; + // Test float16, without activation int min_cuda_architecture = 530; bool enable_cuda = HasCudaEnvironment(min_cuda_architecture); bool enable_rocm = (nullptr != DefaultRocmExecutionProvider().get()); bool enable_dml = (nullptr != DefaultDmlExecutionProvider().get()); - if (enable_cuda || enable_rocm || enable_dml) { - OpTester test("GroupNorm", 1, onnxruntime::kMSDomain); - test.AddAttribute("epsilon", 1e-05f); - test.AddAttribute("groups", 32); - test.AddAttribute("activation", 0); + std::array channels_last_values = {-1, 0, 1}; - test.AddInput("X", dims, ToFloat16(input_data)); - test.AddInput("gamma", {C}, gamma_data); + for (const int channels_last : channels_last_values) { + if (enable_cuda || enable_rocm || enable_dml) { + OpTester test("GroupNorm", 1, onnxruntime::kMSDomain); + test.AddAttribute("epsilon", 1e-05f); + test.AddAttribute("groups", 32); + test.AddAttribute("activation", 0); - test.AddInput("beta", {C}, beta_data); + // We interpret channels_last==-1 as the attribute not being provided + if (channels_last != -1) { + test.AddAttribute("channels_last", channels_last); + } - constexpr float rel_error = 0.0f; - constexpr float 
abs_error = 0.02f; - test.AddOutput("Y", dims, ToFloat16(norm_data), false, rel_error, abs_error); + if (channels_last == 0) { + test.AddInput("X", dims_nchw, ToFloat16(input_data_nchw)); - std::vector> execution_providers; - if (enable_cuda) { - execution_providers.push_back(DefaultCudaExecutionProvider()); - } - if (enable_rocm) { - execution_providers.push_back(DefaultRocmExecutionProvider()); - } - if (enable_dml) { - execution_providers.push_back(DefaultDmlExecutionProvider()); - } - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); - } + constexpr float rel_error = 0.0f; + constexpr float abs_error = 0.02f; + test.AddOutput("Y", dims_nchw, ToFloat16(norm_data_nchw), false, rel_error, abs_error); + } else { + test.AddInput("X", dims_nhwc, ToFloat16(input_data_nhwc)); - // Test float32, with activation - enable_cuda = HasCudaEnvironment(0); - if (enable_cuda || enable_rocm || enable_dml) { - OpTester test("GroupNorm", 1, onnxruntime::kMSDomain); - test.AddAttribute("epsilon", 1e-05f); - test.AddAttribute("groups", 32); - test.AddAttribute("activation", 1); - - test.AddInput("X", dims, input_data); - test.AddInput("gamma", {C}, gamma_data); - test.AddInput("beta", {C}, beta_data); - - constexpr float rel_error = 0.0f; - constexpr float abs_error = 0.01f; - test.AddOutput("Y", dims, swish_data, false, rel_error, abs_error); - - std::vector> execution_providers; - if (enable_cuda) { - execution_providers.push_back(DefaultCudaExecutionProvider()); - } - if (enable_rocm) { - execution_providers.push_back(DefaultRocmExecutionProvider()); + constexpr float rel_error = 0.0f; + constexpr float abs_error = 0.02f; + test.AddOutput("Y", dims_nhwc, ToFloat16(norm_data_nhwc), false, rel_error, abs_error); + } + + test.AddInput("gamma", {C}, gamma_data); + test.AddInput("beta", {C}, beta_data); + + std::vector> execution_providers; + if (enable_cuda && channels_last != 0) { + execution_providers.push_back(DefaultCudaExecutionProvider()); 
+ } + if (enable_rocm && channels_last != 0) { + execution_providers.push_back(DefaultRocmExecutionProvider()); + } + if (enable_dml) { + execution_providers.push_back(DefaultDmlExecutionProvider()); + } + + if (!execution_providers.empty()) { + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } } - if (enable_dml) { - execution_providers.push_back(DefaultDmlExecutionProvider()); + + // Test float32, with activation + enable_cuda = HasCudaEnvironment(0); + if (enable_cuda || enable_rocm || enable_dml) { + OpTester test("GroupNorm", 1, onnxruntime::kMSDomain); + test.AddAttribute("epsilon", 1e-05f); + test.AddAttribute("groups", 32); + test.AddAttribute("activation", 1); + + // We interpret channels_last==-1 as the attribute not being provided + if (channels_last != -1) { + test.AddAttribute("channels_last", channels_last); + } + + if (channels_last == 0) { + test.AddInput("X", dims_nchw, input_data_nchw); + + constexpr float rel_error = 0.0f; + constexpr float abs_error = 0.01f; + test.AddOutput("Y", dims_nchw, swish_data_nchw, false, rel_error, abs_error); + } else { + test.AddInput("X", dims_nhwc, input_data_nhwc); + + constexpr float rel_error = 0.0f; + constexpr float abs_error = 0.01f; + test.AddOutput("Y", dims_nhwc, swish_data_nhwc, false, rel_error, abs_error); + } + + test.AddInput("gamma", {C}, gamma_data); + test.AddInput("beta", {C}, beta_data); + + std::vector> execution_providers; + if (enable_cuda && channels_last != 0) { + execution_providers.push_back(DefaultCudaExecutionProvider()); + } + if (enable_rocm && channels_last != 0) { + execution_providers.push_back(DefaultRocmExecutionProvider()); + } + if (enable_dml) { + execution_providers.push_back(DefaultDmlExecutionProvider()); + } + + if (!execution_providers.empty()) { + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } } - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); 
} }