From 3ecb01233789d4057a902d3e559ac6a9b410cbed Mon Sep 17 00:00:00 2001
From: Jing Fang <126209182+fajin-corp@users.noreply.github.com>
Date: Tue, 4 Jun 2024 14:44:40 -0700
Subject: [PATCH] [CPU EP] Add blocked quantization to DequantizeLinear op kernel (#20901)

### Description
Added blocked quantization to the DequantizeLinear op kernel. All existing [input types and output types](https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#commicrosoftdequantizelinear) are supported. All axes are supported. The implementation in this PR is naive: single-threaded and using scalar instructions. Multi-threading and vectorization are planned for the future as needed.

### Motivation and Context
ONNX introduced blocked quantization in opset 21 for [DequantizeLinear](https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#commicrosoftdequantizelinear). This PR adds support for that part of the spec to ONNX Runtime.
---
 .../cpu/quantization/quantize_linear.cc | 360 ++++++---
 .../cpu/tensor/quantize_linear_test.cc | 755 ++++++++++++++++++
 2 files changed, 1024 insertions(+), 91 deletions(-)

diff --git a/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc b/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc index 05dea2a05c97b..91e21b3690b27 100644 --- a/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc +++ b/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc @@ -25,10 +25,7 @@ class DequantizeLinear final : public OpKernel { block_size_ = 0; } - // TODO(adrianlizarraga): Support the block_size attribute added in opset 21. - if (block_size_ != 0) { - ORT_THROW("DequantizeLinear does not yet support the 'block_size' attribute."); - } + ORT_ENFORCE(block_size_ >= 0, "'block_size' must be non-negative."); } Status Compute(OpKernelContext* context) const override; @@ -71,31 +68,55 @@ static void PrepareForQDQ(const TensorShape& input_shape, const Tensor& scale, const Tensor* zero_point_ptr, int64_t axis, - int64_t& quant_block_count, // A "quant block" is a block of elems with the same scale/zp - int64_t& axis_dim_val, - int64_t& quant_block_size) { + int64_t quant_block_size, + int64_t& process_block_count, + int64_t& broadcast_dim, + int64_t& process_block_size) { if (IsScalarOr1ElementVector(&scale)) { // per-tensor QuantizeLinear/DequantizeLinear - quant_block_count = 1; - axis_dim_val = 1; - quant_block_size = static_cast(input_shape.Size()); + process_block_count = 1; + broadcast_dim = 1; + process_block_size = static_cast(input_shape.Size()); // enforce that zero point are scalars ORT_ENFORCE(zero_point_ptr == nullptr || IsScalarOr1ElementVector(zero_point_ptr), "x_zero_point must be null or a scalar or 1D tensor or size 1."); - } else { // per-channel QuantizeLinear/DequantizeLinear + ORT_ENFORCE(quant_block_size == 0, "block_size must be 0 for per-tensor quantization."); } else { // per-axis or blocked QuantizeLinear/DequantizeLinear const int64_t axis_no_neg = HandleNegativeAxis(axis, input_shape.NumDimensions()); - quant_block_count = input_shape.SizeToDimension(onnxruntime::narrow(axis_no_neg)); - axis_dim_val = input_shape[onnxruntime::narrow(axis_no_neg)]; - quant_block_size = input_shape.SizeFromDimension(SafeInt(axis_no_neg) + 1); + process_block_count = input_shape.SizeToDimension(onnxruntime::narrow(axis_no_neg)); + broadcast_dim = input_shape[onnxruntime::narrow(axis_no_neg)]; + process_block_size = input_shape.SizeFromDimension(SafeInt(axis_no_neg) + 1); // if an axis was specified, ensure the scale and zero point
are compatible - ORT_ENFORCE(scale.Shape().NumDimensions() == 1 && scale.Shape()[0] == axis_dim_val, - "scale must be 1D tensor with size ", - axis_dim_val); - ORT_ENFORCE(zero_point_ptr == nullptr || - (zero_point_ptr->Shape().NumDimensions() == 1 && zero_point_ptr->Shape()[0] == axis_dim_val), - "x_zero_point must be null or 1D tensor with size ", - axis_dim_val); + if (quant_block_size) { // blocked quantization + ORT_ENFORCE(scale.Shape().NumDimensions() == input_shape.NumDimensions(), + "x_scale and x must have the same rank for blocked quantization"); + ORT_ENFORCE(zero_point_ptr == nullptr || zero_point_ptr->Shape().NumDimensions() == input_shape.NumDimensions(), + "x_zero_point must be null or have the same rank as x for blocked quantization"); + + for (size_t i = 0, ndim = input_shape.NumDimensions(); i < ndim; ++i) { + if (i == SafeInt(axis_no_neg)) { + ORT_ENFORCE(scale.Shape()[i] == (input_shape[i] + quant_block_size - 1) / quant_block_size, + "x_scale must be ceil(Di/block_size) on the quantize axis i for blocked quantization"); + } else { + ORT_ENFORCE(scale.Shape()[i] == input_shape[i], + "x_scale and x must have the same shape despite the quantize axis for blocked quantization"); + } + + if (zero_point_ptr) { + ORT_ENFORCE(zero_point_ptr->Shape()[i] == scale.Shape()[i], + "x_zero_point and x_scale must have the same shape for blocked quantization"); + } + } + } else { // per-axis quantization + ORT_ENFORCE(scale.Shape().NumDimensions() == 1 && scale.Shape()[0] == broadcast_dim, + "For per axis quantization, scale must be 1D tensor with size ", + broadcast_dim); + ORT_ENFORCE(zero_point_ptr == nullptr || (zero_point_ptr->Shape().NumDimensions() == 1 && + zero_point_ptr->Shape()[0] == broadcast_dim), + "For per axis quantization, x_zero_point must be null or 1D tensor with size ", + broadcast_dim); + } } } @@ -244,66 +265,198 @@ ONNX_CPU_OPERATOR_TYPED_MS_KERNEL( } // namespace contrib #endif // !defined(DISABLE_CONTRIB_OPS) +template +struct DequantizeLinearApply; + +// The dimensions before quantize axis and after quantize axis can be flattened. +// After flattening, the tensor can be represented by a rank-3 tensor. +// If the quantization happens on the first or last axis, the flattened tensor is +// effectively rank-2. +// For per tensor quantization, the tensor is effectively rank-1. template -struct DequantizeLinearApply { - void op(int64_t N, int64_t axis_dim_val, int64_t quant_block_size, const T* input, const OutT* scale, OutT* output, - const T* zero_point) { - for (size_t n = 0; n < static_cast(N); n++) { - for (size_t bd = 0; bd < static_cast(axis_dim_val); bd++) { - auto zp = zero_point ? static_cast(zero_point[bd]) : 0; - auto sc = static_cast(scale[bd]); - for (size_t bs = 0; bs < static_cast(quant_block_size); bs++) { +struct DequantizeLinearApply { + /** + * @brief Calculate per-tensor/layer or per-axis quantization of DequantizeLinear on the + * flattened tensors. + * @param[in] M size of dimensions before the quantize axis + * @param[in] K dimension on the quantize axis + * @param[in] N size of dimensions after the quantize axis + * @param[in] input 1D array of flattened [D0, ..., Di, ..., Dn] + * @param[in] scale scalar for per-tensor/layer quantization and 1D array [Di] + * for per-axis quantization. i is the quantize axis. 
+ * @param[out] output same shape as input + * @param[in] zero_point same shape as scale + */ + void op(size_t M, size_t K, size_t N, const T* input, + const OutT* scale, OutT* output, const T* zero_point) { + for (size_t m = 0; m < M; m++) { + for (size_t k = 0; k < K; k++) { + auto zp = zero_point ? static_cast(zero_point[k]) : 0; + auto sc = static_cast(scale[k]); + for (size_t n = 0; n < N; n++) { *output++ = static_cast(static_cast(static_cast(*input++) - zp) * sc); } } } } + + /** + * @brief Calculate blocked quantization of DequantizeLinear on the flattened tensors. + * TODO(fajin): add mlas kernel to utilize multithreading, refer MlasDequantizeBlockwise. + * @param[in] M size of dimensions before the quantize axis + * @param[in] K dimension of the quantize axis + * @param[in] N size of dimensions after the quantize axis + * @param[in] quant_block_size quantize block size along the quantize axis + * @param[in] input 1D array of flattened [D0, ..., Di, ..., Dn] + * @param[in] scale 1D array of flattened [D0, ..., ceil(Di/quant_block_size), ..., Dn]. + * i is the quantize axis. + * @param[out] output same shape as input + * @param[in] zero_point same shape as scale + */ + void op(size_t M, size_t K, size_t N, size_t quant_block_size, + const T* input, const OutT* scale, OutT* output, const T* zero_point) { + if (zero_point) { + for (size_t m = 0; m < M; m++) { + for (size_t bd = 0; bd < K; bd += quant_block_size) { + for (size_t qb = 0, qb_end = std::min(quant_block_size, K - bd); qb < qb_end; ++qb) { + // within the quantize block, the zero point and scale are the same. + for (size_t bs = 0; bs < N; bs++) { + auto zp = static_cast(zero_point[bs]); + auto sc = static_cast(scale[bs]); + *output++ = static_cast(static_cast(static_cast(*input++) - zp) * sc); + } + } + + // move to the next quantize block + zero_point += N; + scale += N; + } + } + } else { + for (size_t m = 0; m < M; m++) { + for (size_t bd = 0; bd < K; bd += quant_block_size) { + for (size_t qb = 0, qb_end = std::min(quant_block_size, K - bd); qb < qb_end; ++qb) { + // within the quantize block, the zero point and scale are the same. + for (size_t bs = 0; bs < N; bs++) { + auto sc = static_cast(scale[bs]); + *output++ = static_cast(static_cast(static_cast(*input++)) * sc); + } + } + + // move to the next quantize block + scale += N; + } + } + } + } }; -#define DEQUANTIZE_LINEAR_APPLY_INT4(T) \ - template \ - struct DequantizeLinearApply { \ - void op(int64_t N, int64_t axis_dim_val, int64_t quant_block_size, const T* input, const OutT* scale, \ - OutT* output, const T* zero_point) { \ - size_t input_index = 0; \ - for (size_t n = 0; n < static_cast(N); n++) { \ - for (size_t bd = 0; bd < static_cast(axis_dim_val); bd++) { \ - size_t bd_i = bd >> 1; /*bd / 2*/ \ - size_t bd_j = bd & 0x1; /*bd % 2*/ \ - auto zp = zero_point ? 
static_cast(zero_point[bd_i].GetElem(bd_j)) : 0; \ - auto sc = static_cast(scale[bd]); \ - for (size_t bs = 0; bs < static_cast(quant_block_size); bs++) { \ - size_t input_i = input_index >> 1; \ - size_t input_j = input_index & 0x1; \ - int32_t val = static_cast(input[input_i].GetElem(input_j)); \ - *output++ = static_cast(static_cast(val - zp) * sc); \ - input_index += 1; \ - } \ - } \ - } \ - assert(input_index == static_cast(N * axis_dim_val * quant_block_size)); \ - } \ - }; +template +struct DequantizeLinearApply { + // per-tensor/layer or per-axis quantization + void op(size_t M, size_t K, size_t N, + const T* input, const OutT* scale, OutT* output, const T* zero_point) { + size_t input_index = 0; + + for (size_t m = 0; m < M; m++) { + for (size_t bd = 0; bd < K; bd++) { + size_t bd_i = bd >> 1; /*bd / 2*/ + size_t bd_j = bd & 0x1; /*bd % 2*/ + auto zp = zero_point ? static_cast(zero_point[bd_i].GetElem(bd_j)) : 0; + auto sc = static_cast(scale[bd]); + + for (size_t bs = 0; bs < N; bs++) { + size_t input_i = input_index >> 1; + size_t input_j = input_index & 0x1; + int32_t val = static_cast(input[input_i].GetElem(input_j)); + *output++ = static_cast(static_cast(val - zp) * sc); + input_index += 1; + } + } + } + + assert(input_index == M * K * N); + } + + // Blocked quantization + // TODO(fajin) : add mlas kernel to utilize multithreading, refer MlasDequantizeBlockwise. + void op(size_t M, size_t K, size_t N, size_t quant_block_size, + const T* input, const OutT* scale, OutT* output, const T* zero_point) { + size_t input_index = 0; + + if (zero_point) { + size_t zp_index = 0; + + for (size_t n = 0; n < M; n++) { + for (size_t bd = 0; bd < K; bd += quant_block_size) { + for (size_t qb = 0, qb_end = std::min(quant_block_size, K - bd); qb < qb_end; ++qb) { + auto q_zp_index = zp_index; + for (size_t bs = 0; bs < N; ++bs, ++input_index, ++q_zp_index) { + auto zp = static_cast(zero_point[q_zp_index >> 1].GetElem(q_zp_index & 0x1)); + auto sc = static_cast(scale[bs]); + + int32_t val = static_cast(input[input_index >> 1].GetElem(input_index & 0x1)); + *output++ = static_cast(static_cast(val - zp) * sc); + } + } + + scale += N; + zp_index += N; + } + } + } else { + for (size_t n = 0; n < M; n++) { + for (size_t bd = 0; bd < K; bd += quant_block_size) { + for (size_t qb = 0, qb_end = std::min(quant_block_size, K - bd); qb < qb_end; ++qb) { + for (size_t bs = 0; bs < N; ++bs, ++input_index) { + auto sc = static_cast(scale[bs]); + + int32_t val = static_cast(input[input_index >> 1].GetElem(input_index & 0x1)); + *output++ = static_cast(static_cast(val) * sc); + } + } + + scale += N; + } + } + } -DEQUANTIZE_LINEAR_APPLY_INT4(Int4x2); -DEQUANTIZE_LINEAR_APPLY_INT4(UInt4x2); + assert(input_index == M * K * N); + } +}; #if !defined(DISABLE_FLOAT8_TYPES) -#define DEQUANTIZE_LINEAR_APPLY_FLOAT8(T) \ - template \ - struct DequantizeLinearApply { \ - void op(int64_t N, int64_t axis_dim_val, int64_t quant_block_size, const T* input, const OutT* scale, \ - OutT* output, const T*) { \ - for (size_t n = 0; n < static_cast(N); n++) { \ - for (size_t bd = 0; bd < static_cast(axis_dim_val); bd++) { \ - auto sc = scale[bd]; \ - for (size_t bs = 0; bs < static_cast(quant_block_size); bs++, input++) { \ - *output++ = static_cast(input->ToFloat() * sc); \ - } \ - } \ - } \ - } \ +#define DEQUANTIZE_LINEAR_APPLY_FLOAT8(T) \ + template \ + struct DequantizeLinearApply { \ + /* Per-tensor/layer or per-axis quantization */ \ + void op(size_t M, size_t K, size_t N, \ + const T* input, const OutT* scale, OutT* output, 
const T*) { \ + for (size_t m = 0; m < M; m++) { \ + for (size_t bd = 0; bd < K; bd++) { \ + auto sc = scale[bd]; \ + for (size_t bs = 0; bs < N; bs++, input++) { \ + *output++ = static_cast(input->ToFloat() * sc); \ + } \ + } \ + } \ + } \ + /* Blocked quantization */ \ + void op(size_t M, size_t K, size_t N, size_t quant_block_size, \ + const T* input, const OutT* scale, OutT* output, const T*) { \ + for (size_t m = 0; m < M; m++) { \ + for (size_t bd = 0; bd < K; bd += quant_block_size) { \ + for (size_t qb = 0, qb_end = std::min(quant_block_size, K - bd); qb < qb_end; ++qb) { \ + for (size_t bs = 0; bs < N; bs++, input++) { \ + auto sc = static_cast(scale[bs]); \ + *output++ = static_cast(input->ToFloat() * sc); \ + } \ + } \ + scale += N; \ + } \ + } \ + } \ }; DEQUANTIZE_LINEAR_APPLY_FLOAT8(Float8E4M3FN) @@ -323,11 +476,12 @@ Status DequantizeLinear::Compute(OpKernelContext* ctx) const { const auto& x_shape = x.Shape(); auto& y = *ctx->Output(0, x_shape); - int64_t N; - int64_t axis_dim_val; - int64_t quant_block_size; + int64_t process_block_count; + int64_t broadcast_dim; + int64_t process_block_size; - PrepareForQDQ(x.Shape(), x_scale, x_zero_point, axis_, N, axis_dim_val, quant_block_size); + PrepareForQDQ(x.Shape(), x_scale, x_zero_point, axis_, block_size_, + process_block_count, broadcast_dim, process_block_size); const T* zero_point = x_zero_point ? x_zero_point->Data() : nullptr; @@ -345,15 +499,38 @@ Status DequantizeLinear::Compute(OpKernelContext* ctx) const { const auto to = x_scale.GetElementType(); const T* input = x.Data(); + constexpr bool is_4bit = boost::mp11::mp_contains, T>::value; if (to == ONNX_NAMESPACE::TensorProto::FLOAT) { const float* scale = x_scale.Data(); float* output = y.MutableData(); - DequantizeLinearApply().op(N, axis_dim_val, quant_block_size, input, scale, output, zero_point); + if (block_size_) { + DequantizeLinearApply().op(static_cast(process_block_count), + static_cast(broadcast_dim), + static_cast(process_block_size), + static_cast(block_size_), + input, scale, output, zero_point); + } else { + DequantizeLinearApply().op(static_cast(process_block_count), + static_cast(broadcast_dim), + static_cast(process_block_size), + input, scale, output, zero_point); + } } else if (to == ONNX_NAMESPACE::TensorProto::FLOAT16) { const MLFloat16* scale = x_scale.Data(); MLFloat16* output = y.MutableData(); - DequantizeLinearApply().op(N, axis_dim_val, quant_block_size, input, scale, output, zero_point); + if (block_size_) { + DequantizeLinearApply().op(static_cast(process_block_count), + static_cast(broadcast_dim), + static_cast(process_block_size), + static_cast(block_size_), + input, scale, output, zero_point); + } else { + DequantizeLinearApply().op(static_cast(process_block_count), + static_cast(broadcast_dim), + static_cast(process_block_size), + input, scale, output, zero_point); + } } else if (to == ONNX_NAMESPACE::TensorProto::BFLOAT16) { ORT_THROW("DequantizeLinear into BFLOAT16 is not implemented yet."); } else { @@ -524,14 +701,14 @@ void ParQuantizeLinear(const InputType* Input, } template -void ComputeLoop(OpKernelContext* ctx, const InT* input, const InT* scale, const T* zero_point, T* output, int64_t N, - int64_t axis_dim_val, int64_t quant_block_size, bool saturate) { - for (size_t n = 0; n < static_cast(N); n++) { - for (size_t bd = 0; bd < static_cast(axis_dim_val); bd++) { - ParQuantizeLinear(input, output, static_cast(quant_block_size), scale[bd], bd, zero_point, saturate, - ctx->GetOperatorThreadPool()); - input += quant_block_size; - 
output += quant_block_size; +void ComputeLoop(OpKernelContext* ctx, const InT* input, const InT* scale, const T* zero_point, T* output, + int64_t process_block_count, int64_t broadcast_dim, int64_t process_block_size, bool saturate) { + for (size_t n = 0; n < static_cast(process_block_count); n++) { + for (size_t bd = 0; bd < static_cast(broadcast_dim); bd++) { + ParQuantizeLinear(input, output, static_cast(process_block_size), scale[bd], bd, zero_point, + saturate, ctx->GetOperatorThreadPool()); + input += process_block_size; + output += process_block_size; } } } @@ -611,20 +788,21 @@ Status QuantizeLinear::Compute(OpKernelContext* ctx) const { const auto& x_shape = x.Shape(); auto& y = *ctx->Output(0, x_shape); - int64_t N; - int64_t axis_dim_val; - int64_t quant_block_size; - PrepareForQDQ(x.Shape(), y_scale, y_zero_point, axis_, N, axis_dim_val, quant_block_size); + int64_t process_block_count; + int64_t broadcast_dim; + int64_t process_block_size; + PrepareForQDQ(x.Shape(), y_scale, y_zero_point, axis_, block_size_, + process_block_count, broadcast_dim, process_block_size); const T* zero_point = y_zero_point != nullptr ? y_zero_point->Data() : nullptr; T* output = y.MutableData(); if (x.IsDataType()) { - ComputeLoop(ctx, x.Data(), y_scale.Data(), zero_point, output, N, axis_dim_val, - quant_block_size, saturate_); + ComputeLoop(ctx, x.Data(), y_scale.Data(), zero_point, output, + process_block_count, broadcast_dim, process_block_size, saturate_); } else if (x.IsDataType()) { - ComputeLoop(ctx, x.Data(), y_scale.Data(), zero_point, output, N, - axis_dim_val, quant_block_size, saturate_); + ComputeLoop(ctx, x.Data(), y_scale.Data(), zero_point, output, + process_block_count, broadcast_dim, process_block_size, saturate_); } else { ORT_THROW("Unsupported input type."); } diff --git a/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc b/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc index 5eeda5a3b8949..054dcfc75b92e 100644 --- a/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc @@ -794,5 +794,760 @@ TEST(QuantizeLinearOpMLFloat16Test, Float8) { #endif +namespace blocked_dequantization { + +template +void DequantizeLinearOp21BlockedTest_InvalidBlockSize_Int(int64_t block_size, + int64_t scale_block_count, + int64_t zero_point_block_count) { + OpTester test("DequantizeLinear", 21); + std::vector dims{2, 4}; + std::vector x_scale, y; + std::vector x, x_zero_point; + SessionOptions so; + std::vector log_msgs; // redirect error messages + std::vector> eps; + eps.push_back(DefaultCpuExecutionProvider()); + so.user_logging_function = [](void* param, OrtLoggingLevel severity, const char* category, + const char* logid, const char* code_location, const char* message) { + ORT_UNUSED_PARAMETER(severity); + ORT_UNUSED_PARAMETER(category); + ORT_UNUSED_PARAMETER(logid); + ORT_UNUSED_PARAMETER(code_location); + std::vector* v_ptr = reinterpret_cast*>(param); + std::vector& msg_vector = *v_ptr; + msg_vector.push_back(std::string(message)); + }; + so.user_logging_param = &log_msgs; + so.session_logid = "DequantizeLinear"; + so.use_per_session_threads = false; + so.session_log_verbosity_level = 1; + so.graph_optimization_level = TransformerLevel::Default; + + for (int64_t i = 0, n = 2 * zero_point_block_count; i < n; ++i) x_zero_point.push_back(0); + for (int64_t i = 0, n = 2 * scale_block_count; i < n; i++) x_scale.push_back(Tout(2.0f)); + for (int i = 0; i < 8; ++i) { + x.push_back(i); + 
y.push_back(Tout(static_cast(i) * 2.0f)); + } + + test.AddInput("x", dims, x); + test.AddAttribute("axis", 1); + test.AddAttribute("block_size", block_size); + test.AddInput("x_scale", {2, scale_block_count}, x_scale); + test.AddInput("x_zero_point", {2, zero_point_block_count}, x_zero_point); + test.AddOutput("y", dims, y); + test.Run(so, OpTester::ExpectResult::kExpectFailure, "", {}, nullptr, &eps); +} + +template +void DequantizeLinearOp21BlockedTest_InvalidBlockSize_Int4(int64_t block_size, + int64_t scale_block_count, + int64_t zero_point_block_count) { + OpTester test("DequantizeLinear", 21); + std::vector dims{2, 4}; + std::vector x_scale, y; + std::vector x, x_zero_point; + SessionOptions so; + std::vector log_msgs; // redirect error messages + std::vector> eps; + eps.push_back(DefaultCpuExecutionProvider()); + so.user_logging_function = [](void* param, OrtLoggingLevel severity, const char* category, + const char* logid, const char* code_location, const char* message) { + ORT_UNUSED_PARAMETER(severity); + ORT_UNUSED_PARAMETER(category); + ORT_UNUSED_PARAMETER(logid); + ORT_UNUSED_PARAMETER(code_location); + std::vector* v_ptr = reinterpret_cast*>(param); + std::vector& msg_vector = *v_ptr; + msg_vector.push_back(std::string(message)); + }; + so.user_logging_param = &log_msgs; + so.session_logid = "DequantizeLinear"; + so.use_per_session_threads = false; + so.session_log_verbosity_level = 1; + so.graph_optimization_level = TransformerLevel::Default; + + for (int64_t i = 0, n = zero_point_block_count; i < n; ++i) x_zero_point.push_back(Tin(0, 0)); + for (int64_t i = 0, n = 2 * scale_block_count; i < n; i++) x_scale.push_back(Tout(2.0f)); + for (int i = 0; i < 8; ++i) { + if (i & 1) x.push_back(Tin(i - 1, i)); + y.push_back(Tout(static_cast(i) * 2.0f)); + } + + test.AddInput("x", dims, x); + test.AddAttribute("axis", 1); + test.AddAttribute("block_size", block_size); + test.AddInput("x_scale", {2, scale_block_count}, x_scale); + test.AddInput("x_zero_point", {2, zero_point_block_count}, x_zero_point); + test.AddOutput("y", dims, y); + test.Run(so, OpTester::ExpectResult::kExpectFailure, "", {}, nullptr, &eps); +} + +template +void DequantizeLinearOp21BlockedTest_InvalidBlockSize_Float8(int64_t block_size, + int64_t scale_block_count, + int64_t zero_point_block_count) { + OpTester test("DequantizeLinear", 21); + std::vector dims{2, 4}; + std::vector x_scale, y; + std::vector x, x_zero_point; + SessionOptions so; + std::vector log_msgs; // redirect error messages + std::vector> eps; + eps.push_back(DefaultCpuExecutionProvider()); + so.user_logging_function = [](void* param, OrtLoggingLevel severity, const char* category, + const char* logid, const char* code_location, const char* message) { + ORT_UNUSED_PARAMETER(severity); + ORT_UNUSED_PARAMETER(category); + ORT_UNUSED_PARAMETER(logid); + ORT_UNUSED_PARAMETER(code_location); + std::vector* v_ptr = reinterpret_cast*>(param); + std::vector& msg_vector = *v_ptr; + msg_vector.push_back(std::string(message)); + }; + so.user_logging_param = &log_msgs; + so.session_logid = "DequantizeLinear"; + so.use_per_session_threads = false; + so.session_log_verbosity_level = 1; + so.graph_optimization_level = TransformerLevel::Default; + + for (int64_t i = 0, n = 2 * zero_point_block_count; i < n; i++) x_zero_point.push_back(Tin(0.0f)); + for (int64_t i = 0, n = 2 * scale_block_count; i < n; i++) x_scale.push_back(Tout(2.0f)); + for (int i = 0; i < 8; ++i) x.push_back(Tin(static_cast(i))); + for (int i = 0; i < 8; ++i) y.push_back(Tout(static_cast(i) * 
2.0f)); + + test.AddInput("x", dims, x); + test.AddAttribute("axis", 1); + test.AddAttribute("block_size", block_size); + test.AddInput("x_scale", {2, scale_block_count}, x_scale); + test.AddInput("x_zero_point", {2, zero_point_block_count}, x_zero_point); + test.AddOutput("y", dims, y); + test.Run(so, OpTester::ExpectResult::kExpectFailure, "", {}, nullptr, &eps); +} + +// test negative block size fail +TEST(DequantizeLinearOp21BlockedTest, NagativeBlockSize_Int) { + DequantizeLinearOp21BlockedTest_InvalidBlockSize_Int4(-1, 2, 2); + DequantizeLinearOp21BlockedTest_InvalidBlockSize_Int4(-1, 2, 2); + DequantizeLinearOp21BlockedTest_InvalidBlockSize_Int4(-2, 2, 2); + DequantizeLinearOp21BlockedTest_InvalidBlockSize_Int4(-2, 2, 2); + DequantizeLinearOp21BlockedTest_InvalidBlockSize_Int(-3, 2, 2); + DequantizeLinearOp21BlockedTest_InvalidBlockSize_Int(-3, 2, 2); + DequantizeLinearOp21BlockedTest_InvalidBlockSize_Int(-4, 2, 2); + DequantizeLinearOp21BlockedTest_InvalidBlockSize_Int(-4, 2, 2); + DequantizeLinearOp21BlockedTest_InvalidBlockSize_Int(-5, 2, 2); + DequantizeLinearOp21BlockedTest_InvalidBlockSize_Int(-5, 2, 2); + DequantizeLinearOp21BlockedTest_InvalidBlockSize_Int(-6, 2, 2); + DequantizeLinearOp21BlockedTest_InvalidBlockSize_Int(-1, 2, 2); + DequantizeLinearOp21BlockedTest_InvalidBlockSize_Int(-1, 2, 2); + DequantizeLinearOp21BlockedTest_InvalidBlockSize_Int(-1, 2, 2); +} + +#if !defined(DISABLE_FLOAT8_TYPES) +TEST(DequantizeLinearOp21BlockedTest, NagativeBlockSize_Float8) { + constexpr int min_cuda_architecture = 11080; + bool enable_cuda = (nullptr != DefaultCpuExecutionProvider().get()) && HasCudaEnvironment(min_cuda_architecture); + bool enable_cpu = (nullptr != DefaultCpuExecutionProvider().get()); + + if (enable_cpu || enable_cuda) { + DequantizeLinearOp21BlockedTest_InvalidBlockSize_Float8(-1, 2, 2); + DequantizeLinearOp21BlockedTest_InvalidBlockSize_Float8(-2, 2, 2); + } + if (enable_cpu) { + DequantizeLinearOp21BlockedTest_InvalidBlockSize_Float8(-3, 2, 2); + DequantizeLinearOp21BlockedTest_InvalidBlockSize_Float8(-4, 2, 2); + } + if (enable_cpu || enable_cuda) { + DequantizeLinearOp21BlockedTest_InvalidBlockSize_Float8(-5, 2, 2); + DequantizeLinearOp21BlockedTest_InvalidBlockSize_Float8(-6, 2, 2); + } + if (enable_cpu) { + DequantizeLinearOp21BlockedTest_InvalidBlockSize_Float8(-1, 2, 2); + DequantizeLinearOp21BlockedTest_InvalidBlockSize_Float8(-1, 2, 2); + } +} +#endif + +// test block size incompatible with x_scale shape fail +TEST(DequantizeLinearOp21BlockedTest, IncompatibleBlockSizeWithX_Int) { + DequantizeLinearOp21BlockedTest_InvalidBlockSize_Int4(3, 1, 1); + DequantizeLinearOp21BlockedTest_InvalidBlockSize_Int4(3, 3, 3); + DequantizeLinearOp21BlockedTest_InvalidBlockSize_Int4(3, 3, 3); + DequantizeLinearOp21BlockedTest_InvalidBlockSize_Int4(3, 1, 1); + DequantizeLinearOp21BlockedTest_InvalidBlockSize_Int(3, 1, 1); + DequantizeLinearOp21BlockedTest_InvalidBlockSize_Int(3, 3, 3); + DequantizeLinearOp21BlockedTest_InvalidBlockSize_Int(3, 3, 3); + DequantizeLinearOp21BlockedTest_InvalidBlockSize_Int(3, 1, 1); + DequantizeLinearOp21BlockedTest_InvalidBlockSize_Int(3, 1, 1); + DequantizeLinearOp21BlockedTest_InvalidBlockSize_Int(3, 3, 3); + DequantizeLinearOp21BlockedTest_InvalidBlockSize_Int(3, 3, 3); + DequantizeLinearOp21BlockedTest_InvalidBlockSize_Int(3, 1, 1); + DequantizeLinearOp21BlockedTest_InvalidBlockSize_Int(3, 3, 3); + DequantizeLinearOp21BlockedTest_InvalidBlockSize_Int(3, 1, 1); +} + +#if !defined(DISABLE_FLOAT8_TYPES) 
+TEST(DequantizeLinearOp21BlockedTest, IncompatibleBlockSizeWithX_Float8) { + constexpr int min_cuda_architecture = 11080; + bool enable_cuda = (nullptr != DefaultCpuExecutionProvider().get()) && HasCudaEnvironment(min_cuda_architecture); + bool enable_cpu = (nullptr != DefaultCpuExecutionProvider().get()); + + if (enable_cpu || enable_cuda) { + DequantizeLinearOp21BlockedTest_InvalidBlockSize_Float8(3, 1, 1); + DequantizeLinearOp21BlockedTest_InvalidBlockSize_Float8(3, 3, 3); + DequantizeLinearOp21BlockedTest_InvalidBlockSize_Float8(3, 1, 1); + DequantizeLinearOp21BlockedTest_InvalidBlockSize_Float8(3, 3, 3); + } + if (enable_cpu) { + DequantizeLinearOp21BlockedTest_InvalidBlockSize_Float8(3, 1, 1); + DequantizeLinearOp21BlockedTest_InvalidBlockSize_Float8(3, 3, 3); + DequantizeLinearOp21BlockedTest_InvalidBlockSize_Float8(3, 1, 1); + DequantizeLinearOp21BlockedTest_InvalidBlockSize_Float8(3, 3, 3); + } +} +#endif + +// test x_scale vs. x_zero_point shape incompatible fail +TEST(DequantizeLinearOp21BlockedTest, ScaleShapeUnmatchZeroPoint_Int) { + DequantizeLinearOp21BlockedTest_InvalidBlockSize_Int4(3, 2, 1); + DequantizeLinearOp21BlockedTest_InvalidBlockSize_Int4(3, 2, 3); + DequantizeLinearOp21BlockedTest_InvalidBlockSize_Int4(3, 2, 3); + DequantizeLinearOp21BlockedTest_InvalidBlockSize_Int4(3, 2, 1); + DequantizeLinearOp21BlockedTest_InvalidBlockSize_Int(3, 2, 1); + DequantizeLinearOp21BlockedTest_InvalidBlockSize_Int(3, 2, 3); + DequantizeLinearOp21BlockedTest_InvalidBlockSize_Int(3, 2, 3); + DequantizeLinearOp21BlockedTest_InvalidBlockSize_Int(3, 2, 1); + DequantizeLinearOp21BlockedTest_InvalidBlockSize_Int(3, 2, 1); + DequantizeLinearOp21BlockedTest_InvalidBlockSize_Int(3, 2, 3); + DequantizeLinearOp21BlockedTest_InvalidBlockSize_Int(3, 2, 3); + DequantizeLinearOp21BlockedTest_InvalidBlockSize_Int(3, 2, 1); + DequantizeLinearOp21BlockedTest_InvalidBlockSize_Int(3, 2, 3); + DequantizeLinearOp21BlockedTest_InvalidBlockSize_Int(3, 2, 1); +} + +#if !defined(DISABLE_FLOAT8_TYPES) +TEST(DequantizeLinearOp21BlockedTest, ScaleShapeUnmatchZeroPoint_Float8) { + constexpr int min_cuda_architecture = 11080; + bool enable_cuda = (nullptr != DefaultCpuExecutionProvider().get()) && HasCudaEnvironment(min_cuda_architecture); + bool enable_cpu = (nullptr != DefaultCpuExecutionProvider().get()); + + if (enable_cpu || enable_cuda) { + DequantizeLinearOp21BlockedTest_InvalidBlockSize_Float8(3, 2, 1); + DequantizeLinearOp21BlockedTest_InvalidBlockSize_Float8(3, 2, 3); + DequantizeLinearOp21BlockedTest_InvalidBlockSize_Float8(3, 2, 1); + DequantizeLinearOp21BlockedTest_InvalidBlockSize_Float8(3, 2, 3); + } + if (enable_cpu) { + DequantizeLinearOp21BlockedTest_InvalidBlockSize_Float8(3, 2, 1); + DequantizeLinearOp21BlockedTest_InvalidBlockSize_Float8(3, 2, 3); + DequantizeLinearOp21BlockedTest_InvalidBlockSize_Float8(3, 2, 1); + DequantizeLinearOp21BlockedTest_InvalidBlockSize_Float8(3, 2, 3); + } +} +#endif + +// test DQ with blocked quantization succeed +template +void DequantizeLinearOp21BlockedTest_Int4_Succeed(std::vector&& dims, + int64_t axis, + int64_t block_size, + std::vector& x_, + std::vector& x_scale_, + std::vector& x_zero_point_, + std::vector& y_) { + OpTester test("DequantizeLinear", 21); + std::vector x_scale_shape; + std::vector x_scale, y; + std::vector x, x_zero_point; + std::vector> eps; + eps.push_back(DefaultCpuExecutionProvider()); + + int64_t non_neg_axis = axis < 0 ? 
axis + dims.size() : axis; + bool use_zero_point = !x_zero_point_.empty(); + + for (auto v : y_) y.push_back(Tout(v)); + for (auto v : x_scale_) x_scale.push_back(Tout(v)); + for (size_t i = 0, n = dims.size(); i < n; ++i) { + x_scale_shape.push_back((int64_t)i == non_neg_axis ? (dims[i] + block_size - 1) / block_size : dims[i]); + } + + size_t i = 0, n = x_.size(); + for (; i < n - 1; i += 2) x.push_back(Tin(x_[i], x_[i + 1])); + if (i < n) x.push_back(Tin(x_[i], 0xF)); + + if (use_zero_point) { + i = 0, n = x_zero_point_.size(); + for (; i < n - 1; i += 2) x_zero_point.push_back(Tin(x_zero_point_[i], x_zero_point_[i + 1])); + if (i < n) x_zero_point.push_back(Tin(x_zero_point_[i], 0xF)); + } + + test.AddInput("x", dims, x); + test.AddAttribute("axis", axis); + test.AddAttribute("block_size", block_size); + test.AddInput("x_scale", x_scale_shape, x_scale); + if (use_zero_point) test.AddInput("x_zero_point", x_scale_shape, x_zero_point); + test.AddOutput("y", dims, y); + test.Run(BaseTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &eps); +} + +template +void DequantizeLinearOp21BlockedTest_Int_Succeed(std::vector&& dims, + int64_t axis, + int64_t block_size, + std::vector& x_, + std::vector& x_scale_, + std::vector& x_zero_point_, + std::vector& y_) { + OpTester test("DequantizeLinear", 21); + std::vector x_scale_shape; + std::vector x_scale, y; + std::vector x, x_zero_point; + std::vector> eps; + eps.push_back(DefaultCpuExecutionProvider()); + + int64_t non_neg_axis = axis < 0 ? axis + dims.size() : axis; + bool use_zero_point = !x_zero_point_.empty(); + + for (auto v : y_) y.push_back(Tout(v)); + for (auto v : x_scale_) x_scale.push_back(Tout(v)); + for (size_t i = 0, n = dims.size(); i < n; ++i) { + x_scale_shape.push_back((int64_t)i == non_neg_axis ? (dims[i] + block_size - 1) / block_size : dims[i]); + } + for (auto v : x_) x.push_back(v); + if (use_zero_point) + for (auto v : x_zero_point_) x_zero_point.push_back(v); + + test.AddInput("x", dims, x); + test.AddAttribute("axis", axis); + test.AddAttribute("block_size", block_size); + test.AddInput("x_scale", x_scale_shape, x_scale); + if (use_zero_point) test.AddInput("x_zero_point", x_scale_shape, x_zero_point); + test.AddOutput("y", dims, y); + test.Run(BaseTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &eps); +} + +template +void DequantizeLinearOp21BlockedTest_Float8_Succeed(std::vector&& dims, + int64_t axis, + int64_t block_size, + std::vector& x_, + std::vector& x_scale_, + std::vector& x_zero_point_, + std::vector& y_) { + OpTester test("DequantizeLinear", 21); + std::vector x_scale_shape; + std::vector x_scale, y; + std::vector x, x_zero_point; + std::vector> eps; + eps.push_back(DefaultCpuExecutionProvider()); + + int64_t non_neg_axis = axis < 0 ? axis + dims.size() : axis; + bool use_zero_point = !x_zero_point_.empty(); + + for (auto v : y_) y.push_back(Tout(v)); + for (auto v : x_scale_) x_scale.push_back(Tout(v)); + for (size_t i = 0, n = dims.size(); i < n; ++i) { + x_scale_shape.push_back((int64_t)i == non_neg_axis ? 
(dims[i] + block_size - 1) / block_size : dims[i]); + } + + for (auto v : x_) x.push_back(Tin(static_cast(v))); + if (use_zero_point) { + for (auto v : x_zero_point_) x_zero_point.push_back(Tin(static_cast(v))); + } + + test.AddInput("x", dims, x); + test.AddAttribute("axis", axis); + test.AddAttribute("block_size", block_size); + test.AddInput("x_scale", x_scale_shape, x_scale); + if (use_zero_point) test.AddInput("x_zero_point", x_scale_shape, x_zero_point); + test.AddOutput("y", dims, y); + test.Run(BaseTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &eps); +} + +TEST(DequantizeLinearOp21BlockedTest, SignedInt_NoZeroPoint_FirstAxis) { + std::vector x_scale{-2.0, -4.0, 3.5, 1.0, 2.0, 4.0, -3.5, -1.0}; + std::vector zero_point; + std::vector x{-7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, -8}; + std::vector y_2{14.0, 24.0, -17.5, -4.0, 6.0, 8.0, -3.5, 0.0, 2.0, 8.0, -10.5, -4.0, 10.0, 24.0, -24.5, 8.0}; + std::vector y_3{14.0, 24.0, -17.5, -4.0, 6.0, 8.0, -3.5, 0.0, -2.0, -8.0, 10.5, 4.0, 10.0, 24.0, -24.5, 8.0}; + + DequantizeLinearOp21BlockedTest_Int4_Succeed({4, 2, 2}, 0, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Int4_Succeed({4, 2, 2}, 0, 3, x, x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Int_Succeed({4, 2, 2}, 0, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Int_Succeed({4, 2, 2}, 0, 3, x, x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Int_Succeed({4, 2, 2}, 0, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Int_Succeed({4, 2, 2}, 0, 3, x, x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Int_Succeed({4, 2, 2}, 0, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Int_Succeed({4, 2, 2}, 0, 3, x, x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Int4_Succeed({4, 2, 2}, 0, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Int4_Succeed({4, 2, 2}, 0, 3, x, x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Int_Succeed({4, 2, 2}, 0, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Int_Succeed({4, 2, 2}, 0, 3, x, x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Int_Succeed({4, 2, 2}, 0, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Int_Succeed({4, 2, 2}, 0, 3, x, x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Int_Succeed({4, 2, 2}, 0, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Int_Succeed({4, 2, 2}, 0, 3, x, x_scale, zero_point, y_3); +} + +TEST(DequantizeLinearOp21BlockedTest, SignedInt_UseZeroPoint_FirstAxis) { + std::vector x_scale{-2.0, -4.0, 3.5, 1.0, 2.0, 4.0, -3.5, -1.0}; + std::vector zero_point{-6, -4, -3, -1, 0, 2, 4, 7}; + std::vector x{-7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, -8}; + std::vector y_2{2.0, 8.0, -7.0, -3, -6.0, -8.0, 7.0, 1, 2.0, 0, 3.5, 3.0, 10.0, 16.0, -10.5, 15}; + std::vector y_3{2.0, 8.0, -7.0, -3, -6.0, -8.0, 7.0, 1, -14.0, -24, 21, 5, 10.0, 16.0, -10.5, 15}; + + DequantizeLinearOp21BlockedTest_Int4_Succeed({4, 2, 2}, 0, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Int4_Succeed({4, 2, 2}, 0, 3, x, x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Int_Succeed({4, 2, 2}, 0, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Int_Succeed({4, 2, 2}, 0, 3, x, x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Int_Succeed({4, 2, 2}, 0, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Int_Succeed({4, 2, 2}, 0, 3, x, 
x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Int4_Succeed({4, 2, 2}, 0, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Int4_Succeed({4, 2, 2}, 0, 3, x, x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Int_Succeed({4, 2, 2}, 0, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Int_Succeed({4, 2, 2}, 0, 3, x, x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Int_Succeed({4, 2, 2}, 0, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Int_Succeed({4, 2, 2}, 0, 3, x, x_scale, zero_point, y_3); +} + +TEST(DequantizeLinearOp21BlockedTest, SignedInt_NoZeroPoint_MiddleAxis) { + std::vector zero_point{}; + std::vector x_scale{-2.0, -4.0, 3.5, 1.0, 2.0, 4.0, -3.5, -1.0}; + std::vector x{-7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, -8}; + std::vector y_2{14, 24, 10, 16, -10.5, -2, -3.5, 0, 2, 8, 6, 16, -17.5, -6, -24.5, 8}; + std::vector y_3{14, 24, 10, 16, 6, 8, -3.5, 0, 2, 8, 6, 16, 10, 24, -24.5, 8}; + + DequantizeLinearOp21BlockedTest_Int4_Succeed({2, 4, 2}, 1, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Int4_Succeed({2, 4, 2}, 1, 3, x, x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 2}, 1, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 2}, 1, 3, x, x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 2}, 1, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 2}, 1, 3, x, x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 2}, 1, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 2}, 1, 3, x, x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Int4_Succeed({2, 4, 2}, 1, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Int4_Succeed({2, 4, 2}, 1, 3, x, x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 2}, 1, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 2}, 1, 3, x, x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 2}, 1, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 2}, 1, 3, x, x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 2}, 1, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 2}, 1, 3, x, x_scale, zero_point, y_3); +} + +TEST(DequantizeLinearOp21BlockedTest, SignedInt_UseZeroPoint_MiddleAxis) { + std::vector x_scale{-2.0, -4.0, 3.5, 1.0, 2.0, 4.0, -3.5, -1.0}; + std::vector zero_point{-6, -4, -3, -1, 0, 2, 4, 7}; + std::vector x{-7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, -8}; + std::vector y_2{2, 8, -2, 0, 0, -1, 7, 1, 2, 0, 6, 8, -3.5, 1, -10.5, 15}; + std::vector y_3{2, 8, -2, 0, -6, -8, 7, 1, 2, 0, 6, 8, 10, 16, -10.5, 15}; + + DequantizeLinearOp21BlockedTest_Int4_Succeed({2, 4, 2}, 1, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Int4_Succeed({2, 4, 2}, 1, 3, x, x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 2}, 1, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 2}, 1, 3, x, x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 2}, 1, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 2}, 1, 3, x, x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Int4_Succeed({2, 4, 2}, 1, 2, 
x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Int4_Succeed({2, 4, 2}, 1, 3, x, x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 2}, 1, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 2}, 1, 3, x, x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 2}, 1, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 2}, 1, 3, x, x_scale, zero_point, y_3); +} + +TEST(DequantizeLinearOp21BlockedTest, SignedInt_NoZeroPoint_LastAxis) { + std::vector zero_point{}; + std::vector x_scale{-2.0, -4.0, 3.5, 1.0, 2.0, 4.0, -3.5, -1.0}; + std::vector x{-7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, -8}; + std::vector y_2{14, 12, 20, 16, -10.5, -7, -1, 0, 2, 4, 12, 16, -17.5, -21, -7, 8}; + std::vector y_3{14, 12, 10, 16, -10.5, -7, -3.5, 0, 2, 4, 6, 16, -17.5, -21, -24.5, 8}; + + DequantizeLinearOp21BlockedTest_Int4_Succeed({2, 2, 4}, 2, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Int4_Succeed({2, 2, 4}, 2, 3, x, x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Int_Succeed({2, 2, 4}, 2, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Int_Succeed({2, 2, 4}, 2, 3, x, x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Int_Succeed({2, 2, 4}, 2, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Int_Succeed({2, 2, 4}, 2, 3, x, x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Int_Succeed({2, 2, 4}, 2, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Int_Succeed({2, 2, 4}, 2, 3, x, x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Int4_Succeed({2, 2, 4}, 2, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Int4_Succeed({2, 2, 4}, 2, 3, x, x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Int_Succeed({2, 2, 4}, 2, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Int_Succeed({2, 2, 4}, 2, 3, x, x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Int_Succeed({2, 2, 4}, 2, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Int_Succeed({2, 2, 4}, 2, 3, x, x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Int_Succeed({2, 2, 4}, 2, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Int_Succeed({2, 2, 4}, 2, 3, x, x_scale, zero_point, y_3); +} + +TEST(DequantizeLinearOp21BlockedTest, SignedInt_UseZeroPoint_LastAxis) { + std::vector x_scale{-2.0, -4.0, 3.5, 1.0, 2.0, 4.0, -3.5, -1.0}; + std::vector zero_point{-6, -4, -3, -1, 0, 2, 4, 7}; + std::vector x{-7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, -8}; + std::vector y_2{2, 0, 4, 0, 0, 3.5, 0, 1, 2, 4, 4, 8, -3.5, -7, 0, 15}; + std::vector y_3{2, 0, -2, 0, 0, 3.5, 7, 1, 2, 4, 6, 8, -3.5, -7, -10.5, 15}; + + DequantizeLinearOp21BlockedTest_Int4_Succeed({2, 2, 4}, 2, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Int4_Succeed({2, 2, 4}, 2, 3, x, x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Int_Succeed({2, 2, 4}, 2, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Int_Succeed({2, 2, 4}, 2, 3, x, x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Int_Succeed({2, 2, 4}, 2, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Int_Succeed({2, 2, 4}, 2, 3, x, x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Int4_Succeed({2, 2, 4}, 2, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Int4_Succeed({2, 2, 4}, 2, 
3, x, x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Int_Succeed({2, 2, 4}, 2, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Int_Succeed({2, 2, 4}, 2, 3, x, x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Int_Succeed({2, 2, 4}, 2, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Int_Succeed({2, 2, 4}, 2, 3, x, x_scale, zero_point, y_3); +} + +TEST(DequantizeLinearOp21BlockedTest, UnsignedInt_NoZeroPoint_FirstAxis) { + std::vector x_scale{-2.0, -4.0, 3.5, 1.0, 2.0, 4.0, -3.5, -1.0}; + std::vector zero_point{}; + std::vector x{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + std::vector y_2{0, -4, 7, 3, -8, -20, 21, 7, 16, 36, -35, -11, 24, 52, -49, -15}; + std::vector y_3{0, -4, 7, 3, -8, -20, 21, 7, -16, -36, 35, 11, 24, 52, -49, -15}; + + DequantizeLinearOp21BlockedTest_Int4_Succeed({4, 2, 2}, 0, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Int4_Succeed({4, 2, 2}, 0, 3, x, x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Int_Succeed({4, 2, 2}, 0, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Int_Succeed({4, 2, 2}, 0, 3, x, x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Int_Succeed({4, 2, 2}, 0, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Int_Succeed({4, 2, 2}, 0, 3, x, x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Int4_Succeed({4, 2, 2}, 0, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Int4_Succeed({4, 2, 2}, 0, 3, x, x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Int_Succeed({4, 2, 2}, 0, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Int_Succeed({4, 2, 2}, 0, 3, x, x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Int_Succeed({4, 2, 2}, 0, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Int_Succeed({4, 2, 2}, 0, 3, x, x_scale, zero_point, y_3); +} + +TEST(DequantizeLinearOp21BlockedTest, UnsignedInt_UseZeroPoint_FirstAxis) { + std::vector x_scale{-2.0, -4.0, 3.5, 1.0, 2.0, 4.0, -3.5, -1.0}; + std::vector zero_point{2, 0, 1, 9, 13, 5, 11, 6}; + std::vector x{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + std::vector y_2{4, -4, 3.5, -6, -4, -20, 17.5, -2, -10, 16, 3.5, -5, -2, 32, -10.5, -9}; + std::vector y_3{4, -4, 3.5, -6, -4, -20, 17.5, -2, -12, -36, 31.5, 2, -2, 32, -10.5, -9}; + + DequantizeLinearOp21BlockedTest_Int4_Succeed({4, 2, 2}, 0, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Int4_Succeed({4, 2, 2}, 0, 3, x, x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Int_Succeed({4, 2, 2}, 0, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Int_Succeed({4, 2, 2}, 0, 3, x, x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Int_Succeed({4, 2, 2}, 0, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Int_Succeed({4, 2, 2}, 0, 3, x, x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Int4_Succeed({4, 2, 2}, 0, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Int4_Succeed({4, 2, 2}, 0, 3, x, x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Int_Succeed({4, 2, 2}, 0, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Int_Succeed({4, 2, 2}, 0, 3, x, x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Int_Succeed({4, 2, 2}, 0, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Int_Succeed({4, 2, 2}, 0, 3, x, x_scale, zero_point, y_3); +} + 
+TEST(DequantizeLinearOp21BlockedTest, UnsignedInt_NoZeroPoint_MiddleAxis) { + std::vector x_scale{-2.0, -4.0, 3.5, 1.0, 2.0, 4.0, -3.5, -1.0}; + std::vector zero_point{}; + std::vector x{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + std::vector y_2{0, -4, -4, -12, 14, 5, 21, 7, 16, 36, 20, 44, -42, -13, -49, -15}; + std::vector y_3{0, -4, -4, -12, -8, -20, 21, 7, 16, 36, 20, 44, 24, 52, -49, -15}; + + DequantizeLinearOp21BlockedTest_Int4_Succeed({2, 4, 2}, 1, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Int4_Succeed({2, 4, 2}, 1, 3, x, x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 2}, 1, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 2}, 1, 3, x, x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 2}, 1, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 2}, 1, 3, x, x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Int4_Succeed({2, 4, 2}, 1, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Int4_Succeed({2, 4, 2}, 1, 3, x, x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 2}, 1, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 2}, 1, 3, x, x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 2}, 1, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 2}, 1, 3, x, x_scale, zero_point, y_3); +} + +TEST(DequantizeLinearOp21BlockedTest, UnsignedInt_UseZeroPoint_MiddleAxis) { + std::vector x_scale{-2.0, -4.0, 3.5, 1.0, 2.0, 4.0, -3.5, -1.0}; + std::vector zero_point{2, 0, 1, 9, 13, 5, 11, 6}; + std::vector x{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + std::vector y_2{4, -4, 0, -12, 10.5, -4, 17.5, -2, -10, 16, -6, 24, -3.5, -7, -10.5, -9}; + std::vector y_3{4, -4, 0, -12, -4, -20, 17.5, -2, -10, 16, -6, 24, -2, 32, -10.5, -9}; + + DequantizeLinearOp21BlockedTest_Int4_Succeed({2, 4, 2}, 1, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Int4_Succeed({2, 4, 2}, 1, 3, x, x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 2}, 1, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 2}, 1, 3, x, x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 2}, 1, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 2}, 1, 3, x, x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Int4_Succeed({2, 4, 2}, 1, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Int4_Succeed({2, 4, 2}, 1, 3, x, x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 2}, 1, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 2}, 1, 3, x, x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 2}, 1, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 2}, 1, 3, x, x_scale, zero_point, y_3); +} + +TEST(DequantizeLinearOp21BlockedTest, UnsignedInt_NoZeroPoint_LastAxis) { + std::vector x_scale{-2.0, -4.0, 3.5, 1.0, 2.0, 4.0, -3.5, -1.0}; + std::vector zero_point{}; + std::vector x{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + std::vector y_2{0, -2, -8, -12, 14, 17.5, 6, 7, 16, 18, 40, 44, -42, -45.5, -14, -15}; + std::vector y_3{0, -2, -4, -12, 14, 17.5, 21, 7, 16, 18, 20, 44, -42, -45.5, -49, -15}; + + 
DequantizeLinearOp21BlockedTest_Int4_Succeed({2, 2, 4}, 2, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Int4_Succeed({2, 2, 4}, 2, 3, x, x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Int_Succeed({2, 2, 4}, 2, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Int_Succeed({2, 2, 4}, 2, 3, x, x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Int_Succeed({2, 2, 4}, 2, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Int_Succeed({2, 2, 4}, 2, 3, x, x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Int4_Succeed({2, 2, 4}, 2, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Int4_Succeed({2, 2, 4}, 2, 3, x, x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Int_Succeed({2, 2, 4}, 2, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Int_Succeed({2, 2, 4}, 2, 3, x, x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Int_Succeed({2, 2, 4}, 2, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Int_Succeed({2, 2, 4}, 2, 3, x, x_scale, zero_point, y_3); +} + +TEST(DequantizeLinearOp21BlockedTest, UnsignedInt_UseZeroPoint_LastAxis) { + std::vector x_scale{-2.0, -4.0, 3.5, 1.0, 2.0, 4.0, -3.5, -1.0}; + std::vector zero_point{2, 0, 1, 9, 13, 5, 11, 6}; + std::vector x{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + std::vector y_2{4, 2, -8, -12, 10.5, 14, -3, -2, -10, -8, 20, 24, -3.5, -7, -8, -9}; + std::vector y_3{4, 2, 0, -12, 10.5, 14, 17.5, -2, -10, -8, -6, 24, -3.5, -7, -10.5, -9}; + + DequantizeLinearOp21BlockedTest_Int4_Succeed({2, 2, 4}, 2, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Int4_Succeed({2, 2, 4}, 2, 3, x, x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Int_Succeed({2, 2, 4}, 2, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Int_Succeed({2, 2, 4}, 2, 3, x, x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Int_Succeed({2, 2, 4}, 2, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Int_Succeed({2, 2, 4}, 2, 3, x, x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Int4_Succeed({2, 2, 4}, 2, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Int4_Succeed({2, 2, 4}, 2, 3, x, x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Int_Succeed({2, 2, 4}, 2, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Int_Succeed({2, 2, 4}, 2, 3, x, x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Int_Succeed({2, 2, 4}, 2, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Int_Succeed({2, 2, 4}, 2, 3, x, x_scale, zero_point, y_3); +} + +#if !defined(DISABLE_FLOAT8_TYPES) +TEST(DequantizeLinearOp21BlockedTest, Float8_NoZeroPoint_FirstAxis) { + constexpr int min_cuda_architecture = 11080; + bool enable_cuda = (nullptr != DefaultCpuExecutionProvider().get()) && HasCudaEnvironment(min_cuda_architecture); + bool enable_cpu = (nullptr != DefaultCpuExecutionProvider().get()); + + std::vector x_scale{-2.0, -4.0, 3.5, 1.0, 2.0, 4.0, -3.5, -1.0}; + std::vector zero_point{}; + std::vector x{-7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8}; + std::vector y_2{14.0, 24.0, -17.5, -4.0, 6.0, 8.0, -3.5, 0.0, 2.0, 8.0, -10.5, -4.0, 10.0, 24.0, -24.5, -8.0}; + std::vector y_3{14.0, 24.0, -17.5, -4.0, 6.0, 8.0, -3.5, 0.0, -2.0, -8.0, 10.5, 4.0, 10.0, 24.0, -24.5, -8.0}; + + if (enable_cpu || enable_cuda) { + DequantizeLinearOp21BlockedTest_Float8_Succeed({4, 2, 2}, 0, 2, x, x_scale, 
zero_point, y_2); + DequantizeLinearOp21BlockedTest_Float8_Succeed( + {4, 2, 2}, 0, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Float8_Succeed({4, 2, 2}, 0, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Float8_Succeed({4, 2, 2}, 0, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Float8_Succeed({4, 2, 2}, 0, 3, x, x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Float8_Succeed( + {4, 2, 2}, 0, 3, x, x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Float8_Succeed({4, 2, 2}, 0, 3, x, x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Float8_Succeed({4, 2, 2}, 0, 3, x, x_scale, zero_point, y_3); + } + if (enable_cpu) { + DequantizeLinearOp21BlockedTest_Float8_Succeed({4, 2, 2}, 0, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Float8_Succeed( + {4, 2, 2}, 0, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Float8_Succeed({4, 2, 2}, 0, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Float8_Succeed( + {4, 2, 2}, 0, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Float8_Succeed({4, 2, 2}, 0, 3, x, x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Float8_Succeed( + {4, 2, 2}, 0, 3, x, x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Float8_Succeed({4, 2, 2}, 0, 3, x, x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Float8_Succeed( + {4, 2, 2}, 0, 3, x, x_scale, zero_point, y_3); + } +} + +TEST(DequantizeLinearOp21BlockedTest, Float8_NoZeroPoint_MiddleAxis) { + constexpr int min_cuda_architecture = 11080; + bool enable_cuda = (nullptr != DefaultCpuExecutionProvider().get()) && HasCudaEnvironment(min_cuda_architecture); + bool enable_cpu = (nullptr != DefaultCpuExecutionProvider().get()); + + std::vector zero_point{}; + std::vector x_scale{-2.0, -4.0, 3.5, 1.0, 2.0, 4.0, -3.5, -1.0}; + std::vector x{-7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8}; + std::vector y_2{14, 24, 10, 16, -10.5, -2, -3.5, 0, 2, 8, 6, 16, -17.5, -6, -24.5, -8}; + std::vector y_3{14, 24, 10, 16, 6, 8, -3.5, 0, 2, 8, 6, 16, 10, 24, -24.5, -8}; + + if (enable_cpu || enable_cuda) { + DequantizeLinearOp21BlockedTest_Float8_Succeed({2, 4, 2}, 1, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Float8_Succeed( + {2, 4, 2}, 1, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Float8_Succeed({2, 4, 2}, 1, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Float8_Succeed({2, 4, 2}, 1, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Float8_Succeed({2, 4, 2}, 1, 3, x, x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Float8_Succeed( + {2, 4, 2}, 1, 3, x, x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Float8_Succeed({2, 4, 2}, 1, 3, x, x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Float8_Succeed({2, 4, 2}, 1, 3, x, x_scale, zero_point, y_3); + } + if (enable_cpu) { + DequantizeLinearOp21BlockedTest_Float8_Succeed({2, 4, 2}, 1, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Float8_Succeed( + {2, 4, 2}, 1, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Float8_Succeed({2, 4, 2}, 1, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Float8_Succeed( + {2, 4, 2}, 1, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Float8_Succeed({2, 4, 2}, 1, 3, x, x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Float8_Succeed( + {2, 4, 2}, 1, 3, x, 
x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Float8_Succeed({2, 4, 2}, 1, 3, x, x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Float8_Succeed( + {2, 4, 2}, 1, 3, x, x_scale, zero_point, y_3); + } +} + +TEST(DequantizeLinearOp21BlockedTest, Float8_NoZeroPoint_LastAxis) { + constexpr int min_cuda_architecture = 11080; + bool enable_cuda = (nullptr != DefaultCpuExecutionProvider().get()) && HasCudaEnvironment(min_cuda_architecture); + bool enable_cpu = (nullptr != DefaultCpuExecutionProvider().get()); + + std::vector zero_point{}; + std::vector x_scale{-2.0, -4.0, 3.5, 1.0, 2.0, 4.0, -3.5, -1.0}; + std::vector x{-7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8}; + std::vector y_2{14, 12, 20, 16, -10.5, -7, -1, 0, 2, 4, 12, 16, -17.5, -21, -7, -8}; + std::vector y_3{14, 12, 10, 16, -10.5, -7, -3.5, 0, 2, 4, 6, 16, -17.5, -21, -24.5, -8}; + + if (enable_cpu || enable_cuda) { + DequantizeLinearOp21BlockedTest_Float8_Succeed({2, 2, 4}, 2, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Float8_Succeed( + {2, 2, 4}, 2, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Float8_Succeed({2, 2, 4}, 2, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Float8_Succeed({2, 2, 4}, 2, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Float8_Succeed({2, 2, 4}, 2, 3, x, x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Float8_Succeed( + {2, 2, 4}, 2, 3, x, x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Float8_Succeed({2, 2, 4}, 2, 3, x, x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Float8_Succeed({2, 2, 4}, 2, 3, x, x_scale, zero_point, y_3); + } + if (enable_cpu) { + DequantizeLinearOp21BlockedTest_Float8_Succeed({2, 2, 4}, 2, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Float8_Succeed( + {2, 2, 4}, 2, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Float8_Succeed({2, 2, 4}, 2, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Float8_Succeed( + {2, 2, 4}, 2, 2, x, x_scale, zero_point, y_2); + DequantizeLinearOp21BlockedTest_Float8_Succeed({2, 2, 4}, 2, 3, x, x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Float8_Succeed( + {2, 2, 4}, 2, 3, x, x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Float8_Succeed({2, 2, 4}, 2, 3, x, x_scale, zero_point, y_3); + DequantizeLinearOp21BlockedTest_Float8_Succeed( + {2, 2, 4}, 2, 3, x, x_scale, zero_point, y_3); + } +} +#endif +} // namespace blocked_dequantization + } // namespace test } // namespace onnxruntime
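
For reference, the blocked path added above flattens the input to [M, K, N] around the quantize axis and computes y = (x - zero_point) * scale, where one scale/zero_point row of length N is shared by every block of `block_size` consecutive indices along K, and scale/zero_point are flattened to [M, ceil(K / block_size), N]. Below is a minimal standalone sketch of that computation, not the ORT kernel or its API: types are simplified to int8 input and float output and all names are illustrative. With the shown inputs it reproduces the first batch of the SignedInt_NoZeroPoint_MiddleAxis test above.

```cpp
// Minimal sketch of blocked dequantization over a tensor flattened to [M, K, N],
// where K is the quantize axis and scale/zero_point are flattened to
// [M, ceil(K / block_size), N]. Illustrative only; not the ORT implementation.
#include <cstdint>
#include <cstdio>
#include <vector>

void DequantizeBlocked(size_t M, size_t K, size_t N, size_t block_size,
                       const int8_t* x, const float* scale,
                       const int8_t* zero_point, float* y) {
  const size_t blocks_per_axis = (K + block_size - 1) / block_size;  // ceil(K / block_size)
  for (size_t m = 0; m < M; ++m) {
    for (size_t k = 0; k < K; ++k) {
      // All elements whose axis index falls in the same block share one
      // scale/zero_point row of length N.
      const size_t block = k / block_size;
      for (size_t n = 0; n < N; ++n) {
        const size_t q_idx = (m * blocks_per_axis + block) * N + n;  // index into scale/zp
        const size_t x_idx = (m * K + k) * N + n;                    // index into x/y
        const int32_t zp = zero_point ? zero_point[q_idx] : 0;
        y[x_idx] = static_cast<float>(x[x_idx] - zp) * scale[q_idx];
      }
    }
  }
}

int main() {
  // x has shape [1, 4, 2], quantized along axis 1 with block_size = 2,
  // so scale has shape [1, 2, 2] and zero_point is omitted.
  std::vector<int8_t> x{-7, -6, -5, -4, -3, -2, -1, 0};
  std::vector<float> scale{-2.0f, -4.0f, 3.5f, 1.0f};
  std::vector<float> y(x.size());
  DequantizeBlocked(1, 4, 2, 2, x.data(), scale.data(), /*zero_point=*/nullptr, y.data());
  for (float v : y) std::printf("%g ", v);  // prints: 14 24 10 16 -10.5 -2 -3.5 0
  std::printf("\n");
  return 0;
}
```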