From c91dca13a084b540a781da9bf4e93b3fa802a825 Mon Sep 17 00:00:00 2001 From: mdfaijul Date: Tue, 12 Jul 2022 10:03:19 -0700 Subject: [PATCH] QuantizeV2 enabled for bfloat16 input tensor. --- .../api_def/base_api/api_def_QuantizeV2.pbtxt | 8 +- .../core/kernels/mkl/mkl_quantize_op.cc | 18 +++ .../core/kernels/mkl/mkl_quantize_op_test.cc | 106 +++++++++++++----- tensorflow/core/kernels/quantize_op.cc | 69 +++++++++--- tensorflow/core/kernels/quantize_op_test.cc | 71 ++++++++++-- tensorflow/core/ops/array_ops.cc | 3 +- tensorflow/core/ops/mkl_array_ops.cc | 3 +- 7 files changed, 218 insertions(+), 60 deletions(-) diff --git a/tensorflow/core/api_def/base_api/api_def_QuantizeV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_QuantizeV2.pbtxt index fbdc55113f661b..b2b902c7f15ecf 100644 --- a/tensorflow/core/api_def/base_api/api_def_QuantizeV2.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_QuantizeV2.pbtxt @@ -42,7 +42,13 @@ If the `axis` attribute is specified, this will be a 1-D tensor whose size matches the `axis` dimension of the input and output tensors. END } - summary: "Quantize the \'input\' tensor of type float to \'output\' tensor of type \'T\'." + attr { + name: "dtype" + description: <GetAttr("axis", &axis_)); OP_REQUIRES_OK( ctx, ctx->GetAttr("ensure_minimum_range", &ensure_minimum_range_)); + if (ctx->HasAttr("dtype")) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_)); + if (dtype_ == DT_BFLOAT16) { + OP_REQUIRES( + ctx, + ctx->input_type(0) == DT_BFLOAT16 && + (mode_ == QUANTIZE_MODE_MIN_FIRST || + mode_ == QUANTIZE_MODE_SCALED), + errors::InvalidArgument("Input type bfloat16 is supported only " + "with MIN_FIRST and SCLAED modes")); + } + } else { + dtype_ = DT_FLOAT; + } } void ComputeScalar(OpKernelContext* ctx, float min_range, float max_range) { @@ -605,18 +619,22 @@ class MklQuantizeV2Op : public OpKernel { int round_mode_; int axis_; bool narrow_range_; + DataType dtype_; }; #define REGISTER_QUANTIZE(src_type, dst_type) \ REGISTER_KERNEL_BUILDER( \ Name("_MklQuantizeV2") \ .Device(DEVICE_CPU) \ + .TypeConstraint("dtype") \ .TypeConstraint("T") \ .Label(mkl_op_registry::kMklQuantizedOpLabel), \ MklQuantizeV2Op) REGISTER_QUANTIZE(float, qint8); REGISTER_QUANTIZE(float, quint8); +REGISTER_QUANTIZE(bfloat16, qint8); +REGISTER_QUANTIZE(bfloat16, quint8); #undef SET_MKL_LAYOUT diff --git a/tensorflow/core/kernels/mkl/mkl_quantize_op_test.cc b/tensorflow/core/kernels/mkl/mkl_quantize_op_test.cc index 6bad9720909dbc..4056be48deea70 100644 --- a/tensorflow/core/kernels/mkl/mkl_quantize_op_test.cc +++ b/tensorflow/core/kernels/mkl/mkl_quantize_op_test.cc @@ -25,11 +25,13 @@ limitations under the License. namespace tensorflow { -class MklQuantizeV2OpTest : public OpsTestBase {}; +class MklQuantizeV2OpTest : public OpsTestBase, + public ::testing::WithParamInterface {}; -TEST_F(MklQuantizeV2OpTest, small_uint8) { +TEST_P(MklQuantizeV2OpTest, small_uint8) { + const auto dtype = GetParam(); TF_ASSERT_OK(NodeDefBuilder("quantize_op", "_MklQuantizeV2") - .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(dtype)) .Input(FakeInput(DT_FLOAT)) .Input(FakeInput(DT_FLOAT)) .Attr("T", DataTypeToEnum::v()) @@ -37,8 +39,16 @@ TEST_F(MklQuantizeV2OpTest, small_uint8) { .Attr("_kernel", "QuantizedMklOp") .Finalize(node_def())); TF_ASSERT_OK(InitOp()); - AddInputFromArray(TensorShape({8}), - {0.0, 1.0, 1.25, 1.75, 127.0, 255.0, 500.0, 2.0}); + switch (dtype) { + case DT_BFLOAT16: + AddInputFromList( + TensorShape({8}), {0.0, 1.0, 1.25, 1.75, 127.0, 255.0, 500.0, 2.0}); + break; + + default: + AddInputFromArray( + TensorShape({8}), {0.0, 1.0, 1.25, 1.75, 127.0, 255.0, 500.0, 2.0}); + } // min_range = 0 AddInputFromArray(TensorShape({}), {0}); // max_range = 255 @@ -56,9 +66,11 @@ TEST_F(MklQuantizeV2OpTest, small_uint8) { test::ExpectTensorEqual(expected_min, *GetOutput(1)); test::ExpectTensorEqual(expected_max, *GetOutput(2)); } -TEST_F(MklQuantizeV2OpTest, small_int8) { + +TEST_P(MklQuantizeV2OpTest, small_int8) { + const auto dtype = GetParam(); TF_ASSERT_OK(NodeDefBuilder("quantize_op", "_MklQuantizeV2") - .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(dtype)) .Input(FakeInput(DT_FLOAT)) .Input(FakeInput(DT_FLOAT)) .Attr("T", DataTypeToEnum::v()) @@ -66,10 +78,18 @@ TEST_F(MklQuantizeV2OpTest, small_int8) { .Attr("_kernel", "QuantizedMklOp") .Finalize(node_def())); TF_ASSERT_OK(InitOp()); - AddInputFromArray(TensorShape({8}), {0.0, -1.0, 1.25, -1.75, -24.5, - -255.0, -80.315, 256.0}); - AddInputFromArray(TensorShape({}), {-50.0}); - AddInputFromArray(TensorShape({}), {127.0}); + switch (dtype) { + case DT_BFLOAT16: + AddInputFromList( + TensorShape({8}), + {0.0, -1.0, 1.25, -1.75, -24.5, -255.0, -80.315, 256.0}); + break; + default: + AddInputFromArray(TensorShape({8}), {0.0, -1.0, 1.25, -1.75, -24.5, + -255.0, -80.315, 256.0}); + } + AddInputFromArray(TensorShape({1}), {-50.0}); + AddInputFromArray(TensorShape({1}), {127.0}); TF_ASSERT_OK(RunOpKernel()); Tensor expected(allocator(), DT_QINT8, TensorShape({8})); Tensor expected_min(allocator(), DT_FLOAT, TensorShape({})); @@ -82,9 +102,10 @@ TEST_F(MklQuantizeV2OpTest, small_int8) { test::ExpectTensorEqual(expected_max, *GetOutput(2)); } -TEST_F(MklQuantizeV2OpTest, small_minfirst) { +TEST_P(MklQuantizeV2OpTest, small_minfirst) { + const auto dtype = GetParam(); TF_ASSERT_OK(NodeDefBuilder("quantize_op", "_MklQuantizeV2") - .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(dtype)) .Input(FakeInput(DT_FLOAT)) .Input(FakeInput(DT_FLOAT)) .Attr("T", DataTypeToEnum::v()) @@ -92,10 +113,17 @@ TEST_F(MklQuantizeV2OpTest, small_minfirst) { .Attr("_kernel", "QuantizedMklOp") .Finalize(node_def())); TF_ASSERT_OK(InitOp()); - AddInputFromArray(TensorShape({8}), - {1.0, 1.25, 1.75, 2, 3.15, 127.0, 255.0, 500.0}); - AddInputFromArray(TensorShape({}), {0}); - AddInputFromArray(TensorShape({}), {255.0f}); + switch (dtype) { + case DT_BFLOAT16: + AddInputFromList( + TensorShape({8}), {1.0, 1.25, 1.75, 2.0, 3.15, 127.0, 255.0, 500.0}); + break; + default: + AddInputFromArray( + TensorShape({8}), {1.0, 1.25, 1.75, 2.0, 3.15, 127.0, 255.0, 500.0}); + } + AddInputFromArray(TensorShape({1}), {0}); + AddInputFromArray(TensorShape({1}), {255.0f}); TF_ASSERT_OK(RunOpKernel()); Tensor expected(allocator(), DT_QUINT8, TensorShape({8})); test::FillValues(&expected, {1, 1, 2, 2, 3, 127, 255, 255}); @@ -106,9 +134,10 @@ TEST_F(MklQuantizeV2OpTest, small_minfirst) { EXPECT_NEAR(255.0f, output_max, 1e-5f); } -TEST_F(MklQuantizeV2OpTest, small_minfirst_uint) { +TEST_P(MklQuantizeV2OpTest, small_minfirst_uint) { + const auto dtype = GetParam(); TF_ASSERT_OK(NodeDefBuilder("quantize_op", "_MklQuantizeV2") - .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(dtype)) .Input(FakeInput(DT_FLOAT)) .Input(FakeInput(DT_FLOAT)) .Attr("T", DataTypeToEnum::v()) @@ -116,10 +145,17 @@ TEST_F(MklQuantizeV2OpTest, small_minfirst_uint) { .Attr("_kernel", "QuantizedMklOp") .Finalize(node_def())); TF_ASSERT_OK(InitOp()); - AddInputFromArray(TensorShape({8}), - {0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8}); - AddInputFromArray(TensorShape({}), {0.1}); - AddInputFromArray(TensorShape({}), {0.8}); + switch (dtype) { + case DT_BFLOAT16: + AddInputFromList(TensorShape({8}), + {0.1, 0.2, 0.3, 0.4, 0.5, 0.599, 0.7, 0.8}); + break; + default: + AddInputFromArray(TensorShape({8}), + {0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8}); + } + AddInputFromArray(TensorShape({1}), {0.1}); + AddInputFromArray(TensorShape({1}), {0.8}); TF_ASSERT_OK(RunOpKernel()); Tensor expected(allocator(), DT_QUINT8, TensorShape({8})); test::FillValues(&expected, {32, 64, 96, 128, 159, 191, 223, 255}); @@ -130,9 +166,10 @@ TEST_F(MklQuantizeV2OpTest, small_minfirst_uint) { EXPECT_NEAR(0.8f, output_max, 1e-5f); } -TEST_F(MklQuantizeV2OpTest, small_minfirst_int) { +TEST_P(MklQuantizeV2OpTest, small_minfirst_int) { + const auto dtype = GetParam(); TF_ASSERT_OK(NodeDefBuilder("quantize_op", "_MklQuantizeV2") - .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(dtype)) .Input(FakeInput(DT_FLOAT)) .Input(FakeInput(DT_FLOAT)) .Attr("T", DataTypeToEnum::v()) @@ -140,10 +177,18 @@ TEST_F(MklQuantizeV2OpTest, small_minfirst_int) { .Attr("_kernel", "QuantizedMklOp") .Finalize(node_def())); TF_ASSERT_OK(InitOp()); - AddInputFromArray(TensorShape({8}), - {-0.1, -0.2, -0.3, -0.4, -0.5, -0.6, -0.7, -0.8}); - AddInputFromArray(TensorShape({}), {-0.8}); - AddInputFromArray(TensorShape({}), {-0.1}); + switch (dtype) { + case DT_BFLOAT16: + AddInputFromList( + TensorShape({8}), {-0.1, -0.2, -0.3, -0.4, -0.5, -0.6, -0.7, -0.8}); + + break; + default: + AddInputFromArray( + TensorShape({8}), {-0.1, -0.2, -0.3, -0.4, -0.5, -0.6, -0.7, -0.8}); + } + AddInputFromArray(TensorShape({1}), {-0.8}); + AddInputFromArray(TensorShape({1}), {-0.1}); TF_ASSERT_OK(RunOpKernel()); Tensor expected(allocator(), DT_QUINT8, TensorShape({8})); test::FillValues(&expected, {223, 191, 159, 128, 96, 64, 32, 0}); @@ -154,5 +199,8 @@ TEST_F(MklQuantizeV2OpTest, small_minfirst_int) { EXPECT_NEAR(0.0f, output_max, 1e-5f); } +INSTANTIATE_TEST_SUITE_P(All, MklQuantizeV2OpTest, + ::testing::Values(DT_FLOAT, DT_BFLOAT16)); + } // end namespace tensorflow #endif // INTEL_MKL diff --git a/tensorflow/core/kernels/quantize_op.cc b/tensorflow/core/kernels/quantize_op.cc index be73d4f8291f7b..ea383ae74db28d 100644 --- a/tensorflow/core/kernels/quantize_op.cc +++ b/tensorflow/core/kernels/quantize_op.cc @@ -17,8 +17,11 @@ limitations under the License. #define EIGEN_USE_THREADS +#include + #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/type_traits.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/kernels/cwise_ops.h" @@ -55,7 +58,7 @@ typedef Eigen::ThreadPoolDevice CPUDevice; // max_range. // TODO(xbing): Add a new QuantizeOp just taking scale, // rather than min_range and max_range. -template +template class QuantizeV2Op : public OpKernel { public: explicit QuantizeV2Op(OpKernelConstruction* ctx) : OpKernel(ctx) { @@ -106,8 +109,26 @@ class QuantizeV2Op : public OpKernel { ctx, ctx->GetAttr("ensure_minimum_range", &ensure_minimum_range_)); } + void MaybeConvertToFloat(OpKernelContext* ctx, const int idx, + Tensor* converted_tensor) { + if (std::is_same::value) return; + // Convert input tensor of type S to float tensor. + const Tensor& input_tensor = ctx->input(idx); + OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, input_tensor.shape(), + converted_tensor)); + auto flat_input = input_tensor.flat(); + auto d = ctx->eigen_device(); + auto flat_output = converted_tensor->flat(); + flat_output.device(d) = flat_input.template cast(); + } + void Compute(OpKernelContext* ctx) override { - const Tensor& input = ctx->input(0); + // To process bfloat16 input tensor the tensor is converted to float. + Tensor converted_tensor; + MaybeConvertToFloat(ctx, 0, + &converted_tensor); // Does nothing for float input. + const Tensor& input = + (std::is_same::value) ? ctx->input(0) : converted_tensor; const Tensor& input_min_range = ctx->input(1); const Tensor& input_max_range = ctx->input(2); @@ -345,19 +366,33 @@ class QuantizeV2Op : public OpKernel { bool narrow_range_; }; -REGISTER_KERNEL_BUILDER( - Name("QuantizeV2").Device(DEVICE_CPU).TypeConstraint("T"), - QuantizeV2Op); -REGISTER_KERNEL_BUILDER( - Name("QuantizeV2").Device(DEVICE_CPU).TypeConstraint("T"), - QuantizeV2Op); -REGISTER_KERNEL_BUILDER( - Name("QuantizeV2").Device(DEVICE_CPU).TypeConstraint("T"), - QuantizeV2Op); -REGISTER_KERNEL_BUILDER( - Name("QuantizeV2").Device(DEVICE_CPU).TypeConstraint("T"), - QuantizeV2Op); -REGISTER_KERNEL_BUILDER( - Name("QuantizeV2").Device(DEVICE_CPU).TypeConstraint("T"), - QuantizeV2Op); +#define REGISTER_CPU(S) \ + REGISTER_KERNEL_BUILDER(Name("QuantizeV2") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("dtype") \ + .TypeConstraint("T"), \ + QuantizeV2Op); \ + REGISTER_KERNEL_BUILDER(Name("QuantizeV2") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("dtype") \ + .TypeConstraint("T"), \ + QuantizeV2Op); \ + REGISTER_KERNEL_BUILDER(Name("QuantizeV2") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("dtype") \ + .TypeConstraint("T"), \ + QuantizeV2Op); \ + REGISTER_KERNEL_BUILDER(Name("QuantizeV2") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("dtype") \ + .TypeConstraint("T"), \ + QuantizeV2Op); \ + REGISTER_KERNEL_BUILDER(Name("QuantizeV2") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("dtype") \ + .TypeConstraint("T"), \ + QuantizeV2Op); + +TF_CALL_float(REGISTER_CPU); +TF_CALL_bfloat16(REGISTER_CPU); } // namespace tensorflow diff --git a/tensorflow/core/kernels/quantize_op_test.cc b/tensorflow/core/kernels/quantize_op_test.cc index 76fe2e9f963bef..39f377e23d2e53 100644 --- a/tensorflow/core/kernels/quantize_op_test.cc +++ b/tensorflow/core/kernels/quantize_op_test.cc @@ -14,15 +14,20 @@ limitations under the License. ==============================================================================*/ #include +#include +#include "gtest/gtest-param-test.h" +#include "gtest/gtest.h" #include "tensorflow/core/framework/fake_input.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/kernels/ops_testutil.h" #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/util/saved_tensor_slice_util.h" namespace tensorflow { @@ -30,21 +35,37 @@ class QuantizedOpTest : public OpsTestBase { protected: }; +struct TypeParameterizedQuantizeOpTest + : public OpsTestBase, + public ::testing::WithParamInterface {}; + struct ParameterizedQuantizeOpTest : public OpsTestBase, public ::testing::WithParamInterface { }; -TEST_F(QuantizedOpTest, QuantizeV2) { +struct ParamCombinationQuantizeOpTest + : public OpsTestBase, + public ::testing::WithParamInterface> {}; + +TEST_P(TypeParameterizedQuantizeOpTest, QuantizeV2) { + const auto dtype = GetParam(); TF_ASSERT_OK(NodeDefBuilder("quantize_op", "QuantizeV2") - .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(dtype)) .Input(FakeInput(DT_FLOAT)) .Input(FakeInput(DT_FLOAT)) .Attr("T", DataTypeToEnum::v()) .Attr("mode", "MIN_FIRST") .Finalize(node_def())); TF_ASSERT_OK(InitOp()); - AddInputFromArray(TensorShape({7}), - {0.0, 1.0, 1.25, 1.75, 127.0, 255.0, 500.0}); + switch (dtype) { + case DT_BFLOAT16: + AddInputFromList(TensorShape({7}), + {0.0, 1.0, 1.25, 1.75, 127.0, 255.0, 500.0}); + break; + default: + AddInputFromArray(TensorShape({7}), + {0.0, 1.0, 1.25, 1.75, 127.0, 255.0, 500.0}); + } // min_range = 0 AddInputFromArray(TensorShape({1}), {0}); // max_range = 255 @@ -57,6 +78,9 @@ TEST_F(QuantizedOpTest, QuantizeV2) { test::ExpectTensorEqual(expected, *GetOutput(0)); } +INSTANTIATE_TEST_SUITE_P(All, TypeParameterizedQuantizeOpTest, + ::testing::Values(DT_FLOAT, DT_BFLOAT16)); + // Creates a tensor with the specified dims, using values chosen from data, // multiplied by (1 + index) along the axis dimension. template @@ -82,10 +106,20 @@ std::vector ScalePerSliceAlongAxis(std::vector dims, int axis, return out; } -TEST_P(ParameterizedQuantizeOpTest, QuantizeV2Quint8Scaled) { - const int axis = GetParam(); +std::vector ConvertFloatToBloat16(const std::vector& data) { + std::vector out(data.size()); + size_t i = 0; + for (auto itr = out.begin(); itr != out.end(); ++itr, ++i) { + *itr = bfloat16(data[i]); + } + return out; +} + +TEST_P(ParamCombinationQuantizeOpTest, QuantizeV2Quint8Scaled) { + const auto dtype = std::get<0>(GetParam()); + const int axis = std::get<1>(GetParam()); TF_ASSERT_OK(NodeDefBuilder("quantize_op", "QuantizeV2") - .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(dtype)) .Input(FakeInput(DT_FLOAT)) .Input(FakeInput(DT_FLOAT)) .Attr("T", DataTypeToEnum::v()) @@ -97,10 +131,20 @@ TEST_P(ParameterizedQuantizeOpTest, QuantizeV2Quint8Scaled) { int num_slices = (axis == -1) ? 1 : dims[axis]; // Each channel contains the same 8 values multiplied by (channel + 1). - AddInputFromArray( - TensorShape(dims), - ScalePerSliceAlongAxis( - dims, axis, {-255.0, 0.0, 1.0, 1.25, 1.75, 64.0, 127.0, 500.0})); + switch (dtype) { + case DT_BFLOAT16: { + auto data = ScalePerSliceAlongAxis( + dims, axis, {-255.0, 0.0, 1.0, 1.25, 1.75, 64.0, 127.01, 500.0}); + AddInputFromArray(TensorShape(dims), + ConvertFloatToBloat16(data)); + } break; + + default: + AddInputFromArray( + TensorShape(dims), + ScalePerSliceAlongAxis( + dims, axis, {-255.0, 0.0, 1.0, 1.25, 1.75, 64.0, 127.0, 500.0})); + } std::vector min_ranges(num_slices), max_ranges(num_slices); for (int slice_idx = 0; slice_idx < num_slices; ++slice_idx) { min_ranges[slice_idx] = (slice_idx + 1) * -255.0; @@ -132,6 +176,11 @@ TEST_P(ParameterizedQuantizeOpTest, QuantizeV2Quint8Scaled) { test::ExpectTensorEqual(expected, *GetOutput(0)); } +INSTANTIATE_TEST_SUITE_P(All, ParamCombinationQuantizeOpTest, + ::testing::Combine(::testing::Values(DT_FLOAT, + DT_BFLOAT16), + ::testing::Values(-1, 1, 3))); + TEST_F(QuantizedOpTest, QuantizeV2Quint8ScaledSmallInputRange) { TF_ASSERT_OK(NodeDefBuilder("quantize_op", "QuantizeV2") .Input(FakeInput(DT_FLOAT)) diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc index 331cfa84728eea..945dcb6e10e09c 100644 --- a/tensorflow/core/ops/array_ops.cc +++ b/tensorflow/core/ops/array_ops.cc @@ -3031,12 +3031,13 @@ REGISTER_OP("QuantizeAndDequantizeV3") }); REGISTER_OP("QuantizeV2") - .Input("input: float") + .Input("input: dtype") .Input("min_range: float") .Input("max_range: float") .Output("output: T") .Output("output_min: float") .Output("output_max: float") + .Attr("dtype: {bfloat16, float} = DT_FLOAT") .Attr("T: quantizedtype") .Attr("mode: {'MIN_COMBINED', 'MIN_FIRST', 'SCALED'} = 'MIN_COMBINED'") .Attr( diff --git a/tensorflow/core/ops/mkl_array_ops.cc b/tensorflow/core/ops/mkl_array_ops.cc index 6dad4c08527498..54c27891a119ea 100644 --- a/tensorflow/core/ops/mkl_array_ops.cc +++ b/tensorflow/core/ops/mkl_array_ops.cc @@ -82,7 +82,7 @@ REGISTER_OP("_MklQuantizedConcatV2") }); REGISTER_OP("_MklQuantizeV2") - .Input("input: float") + .Input("input: dtype") .Input("min_range: float") .Input("max_range: float") .Output("output: T") @@ -96,6 +96,7 @@ REGISTER_OP("_MklQuantizeV2") .Attr("narrow_range: bool = false") .Attr("axis: int = -1") .Attr("ensure_minimum_range: float = 0.01") + .Attr("dtype: {bfloat16, float} = DT_FLOAT") .SetShapeFn(shape_inference::QuantizeV2Shape); REGISTER_OP("_MklDequantize")