
QuantizeV2 enabled for bfloat16 input tensor.
mdfaijul committed Jan 29, 2024
1 parent 9d32123 commit c91dca1
Showing 7 changed files with 218 additions and 60 deletions.
8 changes: 7 additions & 1 deletion tensorflow/core/api_def/base_api/api_def_QuantizeV2.pbtxt
@@ -42,7 +42,13 @@ If the `axis` attribute is specified, this will be a 1-D tensor whose size
matches the `axis` dimension of the input and output tensors.
END
}
summary: "Quantize the \'input\' tensor of type float to \'output\' tensor of type \'T\'."
attr {
name: "dtype"
description: <<END
Type of the input tensor. Currently QuantizeV2 supports float and bfloat16.
END
}
summary: "Quantize the \'input\' tensor of types float and bfloat16 to \'output\' tensor of type \'T\'."
description: <<END
[min_range, max_range] are scalar floats that specify the range for
the 'input' data. The 'mode' attribute controls exactly which calculations are
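The numeric semantics are unchanged by the new attr: a bfloat16 input is widened to float and quantized exactly as a float input would be. As a rough sketch of what SCALED-mode quantization to quint8 does over a non-negative range [0, max_range] (an illustration, not the kernel's code; QuantizeScaledToQuint8 is a hypothetical name):

#include <algorithm>
#include <cmath>
#include <cstdint>

// Maps x in [0, max_range] onto [0, 255]; out-of-range values saturate.
uint8_t QuantizeScaledToQuint8(float x, float max_range) {
  const float scale = 255.0f / max_range;
  const float q = std::round(x * scale);
  return static_cast<uint8_t>(std::min(255.0f, std::max(0.0f, q)));
}

With max_range = 255 this reduces to round-and-saturate, matching the setup of the small_uint8 test further below.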
18 changes: 18 additions & 0 deletions tensorflow/core/kernels/mkl/mkl_quantize_op.cc
@@ -311,6 +311,20 @@ class MklQuantizeV2Op : public OpKernel {
OP_REQUIRES_OK(ctx, ctx->GetAttr("axis", &axis_));
OP_REQUIRES_OK(
ctx, ctx->GetAttr("ensure_minimum_range", &ensure_minimum_range_));
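// The "dtype" attr is new and optional, so older graphs still load.
// A bfloat16 input is accepted only in MIN_FIRST or SCALED mode; any
// other mode fails op construction with an InvalidArgument error. When
// the attr is absent, the input type defaults to float.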
if (ctx->HasAttr("dtype")) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
if (dtype_ == DT_BFLOAT16) {
OP_REQUIRES(
ctx,
ctx->input_type(0) == DT_BFLOAT16 &&
(mode_ == QUANTIZE_MODE_MIN_FIRST ||
mode_ == QUANTIZE_MODE_SCALED),
errors::InvalidArgument("Input type bfloat16 is supported only "
"with MIN_FIRST and SCLAED modes"));
}
} else {
dtype_ = DT_FLOAT;
}
}

void ComputeScalar(OpKernelContext* ctx, float min_range, float max_range) {
@@ -605,18 +619,22 @@ class MklQuantizeV2Op : public OpKernel {
int round_mode_;
int axis_;
bool narrow_range_;
DataType dtype_;
};

#define REGISTER_QUANTIZE(src_type, dst_type) \
REGISTER_KERNEL_BUILDER( \
Name("_MklQuantizeV2") \
.Device(DEVICE_CPU) \
.TypeConstraint<src_type>("dtype") \
.TypeConstraint<dst_type>("T") \
.Label(mkl_op_registry::kMklQuantizedOpLabel), \
MklQuantizeV2Op<CPUDevice, dst_type, src_type, true>)

REGISTER_QUANTIZE(float, qint8);
REGISTER_QUANTIZE(float, quint8);
REGISTER_QUANTIZE(bfloat16, qint8);
REGISTER_QUANTIZE(bfloat16, quint8);

#undef SET_MKL_LAYOUT

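For reference, the new bfloat16 registration REGISTER_QUANTIZE(bfloat16, qint8) expands to:

REGISTER_KERNEL_BUILDER(
    Name("_MklQuantizeV2")
        .Device(DEVICE_CPU)
        .TypeConstraint<bfloat16>("dtype")
        .TypeConstraint<qint8>("T")
        .Label(mkl_op_registry::kMklQuantizedOpLabel),
    MklQuantizeV2Op<CPUDevice, qint8, bfloat16, true>);

so the same kernel template serves both source types, with the source type selected by the "dtype" constraint.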
106 changes: 77 additions & 29 deletions tensorflow/core/kernels/mkl/mkl_quantize_op_test.cc
@@ -25,20 +25,30 @@ limitations under the License.

namespace tensorflow {

class MklQuantizeV2OpTest : public OpsTestBase {};
class MklQuantizeV2OpTest : public OpsTestBase,
public ::testing::WithParamInterface<DataType> {};

TEST_F(MklQuantizeV2OpTest, small_uint8) {
TEST_P(MklQuantizeV2OpTest, small_uint8) {
const auto dtype = GetParam();
TF_ASSERT_OK(NodeDefBuilder("quantize_op", "_MklQuantizeV2")
.Input(FakeInput(DT_FLOAT))
.Input(FakeInput(dtype))
.Input(FakeInput(DT_FLOAT))
.Input(FakeInput(DT_FLOAT))
.Attr("T", DataTypeToEnum<quint8>::v())
.Attr("mode", "SCALED")
.Attr("_kernel", "QuantizedMklOp")
.Finalize(node_def()));
TF_ASSERT_OK(InitOp());
AddInputFromArray<float>(TensorShape({8}),
{0.0, 1.0, 1.25, 1.75, 127.0, 255.0, 500.0, 2.0});
switch (dtype) {
case DT_BFLOAT16:
AddInputFromList<bfloat16>(
TensorShape({8}), {0.0, 1.0, 1.25, 1.75, 127.0, 255.0, 500.0, 2.0});
break;

default:
AddInputFromArray<float>(
TensorShape({8}), {0.0, 1.0, 1.25, 1.75, 127.0, 255.0, 500.0, 2.0});
}
// min_range = 0
AddInputFromArray<float>(TensorShape({}), {0});
// max_range = 255
@@ -56,20 +66,30 @@ TEST_F(MklQuantizeV2OpTest, small_uint8) {
test::ExpectTensorEqual<float>(expected_min, *GetOutput(1));
test::ExpectTensorEqual<float>(expected_max, *GetOutput(2));
}
TEST_F(MklQuantizeV2OpTest, small_int8) {

TEST_P(MklQuantizeV2OpTest, small_int8) {
const auto dtype = GetParam();
TF_ASSERT_OK(NodeDefBuilder("quantize_op", "_MklQuantizeV2")
.Input(FakeInput(DT_FLOAT))
.Input(FakeInput(dtype))
.Input(FakeInput(DT_FLOAT))
.Input(FakeInput(DT_FLOAT))
.Attr("T", DataTypeToEnum<qint8>::v())
.Attr("mode", "SCALED")
.Attr("_kernel", "QuantizedMklOp")
.Finalize(node_def()));
TF_ASSERT_OK(InitOp());
AddInputFromArray<float>(TensorShape({8}), {0.0, -1.0, 1.25, -1.75, -24.5,
-255.0, -80.315, 256.0});
AddInputFromArray<float>(TensorShape({}), {-50.0});
AddInputFromArray<float>(TensorShape({}), {127.0});
switch (dtype) {
case DT_BFLOAT16:
AddInputFromList<bfloat16>(
TensorShape({8}),
{0.0, -1.0, 1.25, -1.75, -24.5, -255.0, -80.315, 256.0});
break;
default:
AddInputFromArray<float>(TensorShape({8}), {0.0, -1.0, 1.25, -1.75, -24.5,
-255.0, -80.315, 256.0});
}
AddInputFromArray<float>(TensorShape({1}), {-50.0});
AddInputFromArray<float>(TensorShape({1}), {127.0});
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_QINT8, TensorShape({8}));
Tensor expected_min(allocator(), DT_FLOAT, TensorShape({}));
@@ -82,20 +102,28 @@ TEST_F(MklQuantizeV2OpTest, small_int8) {
test::ExpectTensorEqual<float>(expected_max, *GetOutput(2));
}

TEST_F(MklQuantizeV2OpTest, small_minfirst) {
TEST_P(MklQuantizeV2OpTest, small_minfirst) {
const auto dtype = GetParam();
TF_ASSERT_OK(NodeDefBuilder("quantize_op", "_MklQuantizeV2")
.Input(FakeInput(DT_FLOAT))
.Input(FakeInput(dtype))
.Input(FakeInput(DT_FLOAT))
.Input(FakeInput(DT_FLOAT))
.Attr("T", DataTypeToEnum<quint8>::v())
.Attr("mode", "MIN_FIRST")
.Attr("_kernel", "QuantizedMklOp")
.Finalize(node_def()));
TF_ASSERT_OK(InitOp());
AddInputFromArray<float>(TensorShape({8}),
{1.0, 1.25, 1.75, 2, 3.15, 127.0, 255.0, 500.0});
AddInputFromArray<float>(TensorShape({}), {0});
AddInputFromArray<float>(TensorShape({}), {255.0f});
switch (dtype) {
case DT_BFLOAT16:
AddInputFromList<bfloat16>(
TensorShape({8}), {1.0, 1.25, 1.75, 2.0, 3.15, 127.0, 255.0, 500.0});
break;
default:
AddInputFromArray<float>(
TensorShape({8}), {1.0, 1.25, 1.75, 2.0, 3.15, 127.0, 255.0, 500.0});
}
AddInputFromArray<float>(TensorShape({1}), {0});
AddInputFromArray<float>(TensorShape({1}), {255.0f});
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_QUINT8, TensorShape({8}));
test::FillValues<quint8>(&expected, {1, 1, 2, 2, 3, 127, 255, 255});
@@ -106,20 +134,28 @@ TEST_F(MklQuantizeV2OpTest, small_minfirst) {
EXPECT_NEAR(255.0f, output_max, 1e-5f);
}

TEST_F(MklQuantizeV2OpTest, small_minfirst_uint) {
TEST_P(MklQuantizeV2OpTest, small_minfirst_uint) {
const auto dtype = GetParam();
TF_ASSERT_OK(NodeDefBuilder("quantize_op", "_MklQuantizeV2")
.Input(FakeInput(DT_FLOAT))
.Input(FakeInput(dtype))
.Input(FakeInput(DT_FLOAT))
.Input(FakeInput(DT_FLOAT))
.Attr("T", DataTypeToEnum<quint8>::v())
.Attr("mode", "MIN_FIRST")
.Attr("_kernel", "QuantizedMklOp")
.Finalize(node_def()));
TF_ASSERT_OK(InitOp());
AddInputFromArray<float>(TensorShape({8}),
{0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8});
AddInputFromArray<float>(TensorShape({}), {0.1});
AddInputFromArray<float>(TensorShape({}), {0.8});
switch (dtype) {
case DT_BFLOAT16:
AddInputFromList<bfloat16>(TensorShape({8}),
{0.1, 0.2, 0.3, 0.4, 0.5, 0.599, 0.7, 0.8});
break;
default:
AddInputFromArray<float>(TensorShape({8}),
{0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8});
}
AddInputFromArray<float>(TensorShape({1}), {0.1});
AddInputFromArray<float>(TensorShape({1}), {0.8});
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_QUINT8, TensorShape({8}));
test::FillValues<quint8>(&expected, {32, 64, 96, 128, 159, 191, 223, 255});
@@ -130,20 +166,29 @@ TEST_F(MklQuantizeV2OpTest, small_minfirst_uint) {
EXPECT_NEAR(0.8f, output_max, 1e-5f);
}

TEST_F(MklQuantizeV2OpTest, small_minfirst_int) {
TEST_P(MklQuantizeV2OpTest, small_minfirst_int) {
const auto dtype = GetParam();
TF_ASSERT_OK(NodeDefBuilder("quantize_op", "_MklQuantizeV2")
.Input(FakeInput(DT_FLOAT))
.Input(FakeInput(dtype))
.Input(FakeInput(DT_FLOAT))
.Input(FakeInput(DT_FLOAT))
.Attr("T", DataTypeToEnum<quint8>::v())
.Attr("mode", "MIN_FIRST")
.Attr("_kernel", "QuantizedMklOp")
.Finalize(node_def()));
TF_ASSERT_OK(InitOp());
AddInputFromArray<float>(TensorShape({8}),
{-0.1, -0.2, -0.3, -0.4, -0.5, -0.6, -0.7, -0.8});
AddInputFromArray<float>(TensorShape({}), {-0.8});
AddInputFromArray<float>(TensorShape({}), {-0.1});
switch (dtype) {
case DT_BFLOAT16:
AddInputFromList<bfloat16>(
TensorShape({8}), {-0.1, -0.2, -0.3, -0.4, -0.5, -0.6, -0.7, -0.8});
break;
default:
AddInputFromArray<float>(
TensorShape({8}), {-0.1, -0.2, -0.3, -0.4, -0.5, -0.6, -0.7, -0.8});
}
AddInputFromArray<float>(TensorShape({1}), {-0.8});
AddInputFromArray<float>(TensorShape({1}), {-0.1});
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_QUINT8, TensorShape({8}));
test::FillValues<quint8>(&expected, {223, 191, 159, 128, 96, 64, 32, 0});
@@ -154,5 +199,8 @@ TEST_F(MklQuantizeV2OpTest, small_minfirst_int) {
EXPECT_NEAR(0.0f, output_max, 1e-5f);
}

INSTANTIATE_TEST_SUITE_P(All, MklQuantizeV2OpTest,
::testing::Values(DT_FLOAT, DT_BFLOAT16));

} // end namespace tensorflow
#endif // INTEL_MKL
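The tests above rely on googletest value parameterization: each TEST_P body runs once per value supplied by INSTANTIATE_TEST_SUITE_P, so every test executes for both DT_FLOAT and DT_BFLOAT16. A minimal self-contained sketch of the pattern (googletest only; the names are illustrative):

#include <gtest/gtest.h>

class DTypeTest : public ::testing::TestWithParam<int> {};

TEST_P(DTypeTest, RunsOncePerValue) {
  const int dtype = GetParam();  // stand-in for DT_FLOAT / DT_BFLOAT16
  EXPECT_TRUE(dtype == 0 || dtype == 1);
}

// Generates All/DTypeTest.RunsOncePerValue/0 and .../1.
INSTANTIATE_TEST_SUITE_P(All, DTypeTest, ::testing::Values(0, 1));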
69 changes: 52 additions & 17 deletions tensorflow/core/kernels/quantize_op.cc
@@ -17,8 +17,11 @@ limitations under the License.

#define EIGEN_USE_THREADS

#include <type_traits>

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/cwise_ops.h"
@@ -55,7 +58,7 @@ typedef Eigen::ThreadPoolDevice CPUDevice;
// max_range.
// TODO(xbing): Add a new QuantizeOp just taking scale,
// rather than min_range and max_range.
template <typename Device, typename T>
template <typename Device, typename S, typename T>
class QuantizeV2Op : public OpKernel {
public:
explicit QuantizeV2Op(OpKernelConstruction* ctx) : OpKernel(ctx) {
@@ -106,8 +109,26 @@ class QuantizeV2Op : public OpKernel {
ctx, ctx->GetAttr("ensure_minimum_range", &ensure_minimum_range_));
}

void MaybeConvertToFloat(OpKernelContext* ctx, const int idx,
Tensor* converted_tensor) {
if (std::is_same<S, float>::value) return;
// Convert input tensor of type S to float tensor.
const Tensor& input_tensor = ctx->input(idx);
OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, input_tensor.shape(),
converted_tensor));
auto flat_input = input_tensor.flat<S>();
auto d = ctx->eigen_device<Device>();
auto flat_output = converted_tensor->flat<float>();
flat_output.device(d) = flat_input.template cast<float>();
}

void Compute(OpKernelContext* ctx) override {
const Tensor& input = ctx->input(0);
// A bfloat16 input tensor is first converted to float before quantization.
Tensor converted_tensor;
MaybeConvertToFloat(ctx, 0,
&converted_tensor); // Does nothing for float input.
const Tensor& input =
(std::is_same<S, float>::value) ? ctx->input(0) : converted_tensor;
const Tensor& input_min_range = ctx->input(1);
const Tensor& input_max_range = ctx->input(2);

@@ -345,19 +366,33 @@ class QuantizeV2Op : public OpKernel {
bool narrow_range_;
};
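The only bfloat16-specific step above is MaybeConvertToFloat, which widens the input with an element-wise Eigen cast so the rest of Compute stays float-only. A standalone sketch of the cast (assuming a recent Eigen that ships Eigen::bfloat16; WidenToFloat is an illustrative name):

#include <unsupported/Eigen/CXX11/Tensor>

// Equivalent in effect to:
//   flat_output.device(d) = flat_input.template cast<float>();
Eigen::Tensor<float, 1> WidenToFloat(
    const Eigen::Tensor<Eigen::bfloat16, 1>& in) {
  return in.cast<float>();
}

Since every bfloat16 value is exactly representable in float32, the widening cast is lossless.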

REGISTER_KERNEL_BUILDER(
Name("QuantizeV2").Device(DEVICE_CPU).TypeConstraint<quint8>("T"),
QuantizeV2Op<CPUDevice, quint8>);
REGISTER_KERNEL_BUILDER(
Name("QuantizeV2").Device(DEVICE_CPU).TypeConstraint<qint8>("T"),
QuantizeV2Op<CPUDevice, qint8>);
REGISTER_KERNEL_BUILDER(
Name("QuantizeV2").Device(DEVICE_CPU).TypeConstraint<quint16>("T"),
QuantizeV2Op<CPUDevice, quint16>);
REGISTER_KERNEL_BUILDER(
Name("QuantizeV2").Device(DEVICE_CPU).TypeConstraint<qint16>("T"),
QuantizeV2Op<CPUDevice, qint16>);
REGISTER_KERNEL_BUILDER(
Name("QuantizeV2").Device(DEVICE_CPU).TypeConstraint<qint32>("T"),
QuantizeV2Op<CPUDevice, qint32>);
#define REGISTER_CPU(S) \
REGISTER_KERNEL_BUILDER(Name("QuantizeV2") \
.Device(DEVICE_CPU) \
.TypeConstraint<S>("dtype") \
.TypeConstraint<quint8>("T"), \
QuantizeV2Op<CPUDevice, S, quint8>); \
REGISTER_KERNEL_BUILDER(Name("QuantizeV2") \
.Device(DEVICE_CPU) \
.TypeConstraint<S>("dtype") \
.TypeConstraint<qint8>("T"), \
QuantizeV2Op<CPUDevice, S, qint8>); \
REGISTER_KERNEL_BUILDER(Name("QuantizeV2") \
.Device(DEVICE_CPU) \
.TypeConstraint<S>("dtype") \
.TypeConstraint<quint16>("T"), \
QuantizeV2Op<CPUDevice, S, quint16>); \
REGISTER_KERNEL_BUILDER(Name("QuantizeV2") \
.Device(DEVICE_CPU) \
.TypeConstraint<S>("dtype") \
.TypeConstraint<qint16>("T"), \
QuantizeV2Op<CPUDevice, S, qint16>); \
REGISTER_KERNEL_BUILDER(Name("QuantizeV2") \
.Device(DEVICE_CPU) \
.TypeConstraint<S>("dtype") \
.TypeConstraint<qint32>("T"), \
QuantizeV2Op<CPUDevice, S, qint32>);

TF_CALL_float(REGISTER_CPU);
TF_CALL_bfloat16(REGISTER_CPU);
} // namespace tensorflow
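TF_CALL_float(REGISTER_CPU) and TF_CALL_bfloat16(REGISTER_CPU) invoke the macro once per source type, so all five quantized output types are registered for both input dtypes. The first registration produced by TF_CALL_bfloat16(REGISTER_CPU), for example, is:

REGISTER_KERNEL_BUILDER(Name("QuantizeV2")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<bfloat16>("dtype")
                            .TypeConstraint<quint8>("T"),
                        QuantizeV2Op<CPUDevice, bfloat16, quint8>);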
