diff --git a/tensorflow/core/grappler/optimizers/remapper.cc b/tensorflow/core/grappler/optimizers/remapper.cc index 9bf853ed05e658..55ddcedc6bbdac 100644 --- a/tensorflow/core/grappler/optimizers/remapper.cc +++ b/tensorflow/core/grappler/optimizers/remapper.cc @@ -2368,25 +2368,16 @@ Status AddFusedConv3DNode(RemapperContext* ctx, const PadWithConv3D& matched, << " contraction=" << contraction.name(); NodeDef fused_node; - fused_node.set_name(contraction.name()); - fused_node.set_device(contraction.device()); - fused_node.add_input(pad_node_def.input(0)); // 0: input - fused_node.add_input(contraction.input(1)); // 1: filter + // Note: Currently, the attributes of fused op are superset of contraction + // node. So copy it from contraction node before mutation. + fused_node.CopyFrom(contraction); + // 0-th input is the 0-th input of pad node, while the remainder inputs are + // same as the contraction node. + fused_node.set_input(0, pad_node_def.input(0)); fused_node.set_op(kFusedConv3D); - auto* attr = fused_node.mutable_attr(); - auto& src_attr = contraction.attr(); - (*attr)["T"] = src_attr.at("T"); - (*attr)["strides"] = src_attr.at("strides"); - (*attr)["data_format"] = src_attr.at("data_format"); - (*attr)["padding"] = src_attr.at("padding"); - (*attr)["dilations"] = src_attr.at("dilations"); - - if (contraction.op() == kFusedConv3D) { - fused_node.add_input(contraction.input(2)); // 2: bias - (*attr)["fused_ops"] = src_attr.at("fused_ops"); - (*attr)["num_args"] = src_attr.at("num_args"); - } else { + // Set num_args attr explicitly, since there is no default value. + if (!attr->contains("num_args")) { SetAttrValue(0, &(*attr)["num_args"]); } @@ -2400,6 +2391,10 @@ Status AddFusedConv3DNode(RemapperContext* ctx, const PadWithConv3D& matched, paddings.push_back(const_value(i)); SetAttrValue(paddings, &(*attr)["padding_list"]); } + } else { + VLOG(2) << "Pad fusion with " << contraction.op() << " is invalidated, " + << "it requires padding dim sizes to be constant."; + return OkStatus(); } utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder(); diff --git a/tensorflow/core/grappler/optimizers/remapper_test.cc b/tensorflow/core/grappler/optimizers/remapper_test.cc index b10c6f35a4bfc0..648db1ca9f7325 100644 --- a/tensorflow/core/grappler/optimizers/remapper_test.cc +++ b/tensorflow/core/grappler/optimizers/remapper_test.cc @@ -2259,73 +2259,109 @@ class RemapperFusePadWithFusedConv3D : public RemapperTest { public: template void RunTest() { - if (!IsMKLEnabled()) GTEST_SKIP() << "Test only applicable to MKL."; - using ::tensorflow::ops::Placeholder; - tensorflow::Scope s = tensorflow::Scope::NewRootScope(); if (!IsMKLEnabled()) GTEST_SKIP() << "Test only applicable to oneDNN."; - auto input_shape = ops::Placeholder::Shape({8, 4, 32, 32, 3}); - auto filter_shape = ops::Placeholder::Shape({1, 1, 1, 3, 128}); - auto bias_shape = ops::Placeholder::Shape({128}); - auto paddings_shape = ops::Placeholder::Shape({5, 2}); - auto strides = {1, 1, 1, 1, 1}; + using ::tensorflow::ops::Placeholder; - auto input_t = GenerateTensorWithSetRandom({8, 4, 32, 32, 3}); - auto filter_t = GenerateTensorWithSetRandom({1, 1, 1, 3, 128}); - auto bias_t = GenerateTensorWithSetRandom({128}); + // Empty string denotes no activation. + for (const string& activation : {"", "Relu", "Relu6", "Elu", "LeakyRelu"}) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + auto input_shape = ops::Placeholder::Shape({8, 4, 32, 32, 3}); + auto filter_shape = ops::Placeholder::Shape({1, 1, 1, 3, 128}); + auto bias_shape = ops::Placeholder::Shape({128}); + auto paddings_shape = ops::Placeholder::Shape({5, 2}); + auto strides = {1, 1, 1, 1, 1}; + + auto input_t = GenerateTensorWithSetRandom({8, 4, 32, 32, 3}); + auto filter_t = GenerateTensorWithSetRandom({1, 1, 1, 3, 128}); + auto bias_t = GenerateTensorWithSetRandom({128}); - auto input = Placeholder(s.WithOpName("input"), DTYPE, input_shape); - auto filter = Placeholder(s.WithOpName("filter"), DTYPE, filter_shape); - auto bias = Placeholder(s.WithOpName("bias"), DTYPE, bias_shape); + auto input = Placeholder(s.WithOpName("input"), DTYPE, input_shape); + auto filter = Placeholder(s.WithOpName("filter"), DTYPE, filter_shape); + auto bias = Placeholder(s.WithOpName("bias"), DTYPE, bias_shape); - auto padding_const = ops::Const(s.WithOpName("padding"), - {0, 0, 1, 1, 1, 1, 1, 1, 0, 0}, {5, 2}); - auto pad = ops::Pad(s.WithOpName("pad"), input, padding_const); - auto conv = ops::Conv3D(s.WithOpName("conv"), pad, filter, strides, "SAME"); - auto bias_add = ops::BiasAdd(s.WithOpName("bias_add"), conv, bias); - auto fetch = ops::Identity(s.WithOpName("fetch"), bias_add); + auto padding_const = ops::Const(s.WithOpName("padding"), + {0, 0, 1, 1, 1, 1, 1, 1, 0, 0}, {5, 2}); + auto pad = ops::Pad(s.WithOpName("pad"), input, padding_const); + auto conv = + ops::Conv3D(s.WithOpName("conv"), pad, filter, strides, "SAME"); + auto bias_add = ops::BiasAdd(s.WithOpName("bias_add"), conv, bias); - GrapplerItem item; - item.fetch = {"fetch"}; - item.feed = {{"input", input_t}, {"filter", filter_t}, {"bias", bias_t}}; - TF_ASSERT_OK(s.ToGraphDef(&item.graph)); + float leakyrelu_alpha = 0.5; + ops::Identity fetch = [&]() -> ops::Identity { + auto activate = s.WithOpName("activation"); + auto fetch = s.WithOpName("fetch"); - // Place all nodes on CPU. - for (int i = 0; i < item.graph.node_size(); ++i) { - item.graph.mutable_node(i)->set_device("/device:CPU:0"); - } + if (activation == "Relu") { + return ops::Identity(fetch, ops::Relu(activate, bias_add)); + } else if (activation == "Relu6") { + return ops::Identity(fetch, ops::Relu6(activate, bias_add)); + } else if (activation == "Elu") { + return ops::Identity(fetch, ops::Elu(activate, bias_add)); + } else if (activation == "LeakyRelu") { + auto attr = ops::internal::LeakyRelu::Alpha(leakyrelu_alpha); + return ops::Identity( + fetch, ops::internal::LeakyRelu(activate, bias_add, attr)); + } - Remapper optimizer(RewriterConfig::ON); - GraphDef output_1; - TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output_1)); - item.graph = std::move(output_1); - GraphDef output; - TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); + return ops::Identity(fetch, bias); + }(); - int found = 0; - for (const NodeDef& node : output.node()) { - if (node.name() == "bias_add") { - EXPECT_EQ(node.op(), "_FusedConv3D"); - ASSERT_GE(node.input_size(), 3); - EXPECT_EQ(node.input(0), "input"); - EXPECT_EQ(node.input(1), "filter"); - EXPECT_EQ(node.attr().at("num_args").i(), 1); - EXPECT_EQ(node.input(2), "bias"); - const auto fused_ops = node.attr().at("fused_ops").list().s(); - ASSERT_EQ(fused_ops.size(), 1); - EXPECT_EQ(fused_ops[0], "BiasAdd"); - found++; + GrapplerItem item; + item.fetch = {"fetch"}; + item.feed = {{"input", input_t}, {"filter", filter_t}, {"bias", bias_t}}; + TF_ASSERT_OK(s.ToGraphDef(&item.graph)); + + // Place all nodes on CPU. + for (int i = 0; i < item.graph.node_size(); ++i) { + item.graph.mutable_node(i)->set_device("/device:CPU:0"); } - } - EXPECT_EQ(found, 1); - auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed); - ASSERT_EQ(tensors_expected.size(), 1); - auto tensors = EvaluateNodes(output, item.fetch, item.feed); - ASSERT_EQ(tensors.size(), 1); - if (DTYPE == DT_BFLOAT16) - test::ExpectClose(tensors[0], tensors_expected[0], 1e-2, 1e-2); - else - test::ExpectClose(tensors[0], tensors_expected[0], 1e-6); + Remapper optimizer(RewriterConfig::ON); + GraphDef output_1; + TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output_1)); + item.graph = std::move(output_1); + GraphDef output; + TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); + + string fused_node_name; + std::vector expected_fused_ops = {"BiasAdd"}; + if (activation.empty()) { + fused_node_name = "bias_add"; + } else { + fused_node_name = "activation"; + expected_fused_ops.push_back(activation); + } + int found = 0; + for (const NodeDef& node : output.node()) { + if (node.name() == fused_node_name) { + EXPECT_EQ(node.op(), "_FusedConv3D"); + ASSERT_GE(node.input_size(), 3); + EXPECT_EQ(node.input(0), "input"); + EXPECT_EQ(node.input(1), "filter"); + EXPECT_EQ(node.attr().at("num_args").i(), 1); + EXPECT_EQ(node.input(2), "bias"); + const auto fused_ops = node.attr().at("fused_ops").list().s(); + ASSERT_EQ(fused_ops.size(), expected_fused_ops.size()); + for (int i = 0; i < fused_ops.size(); ++i) { + EXPECT_EQ(fused_ops[i], expected_fused_ops[i]); + } + if (activation == "LeakyRelu") { + EXPECT_EQ(node.attr().at("leakyrelu_alpha").f(), leakyrelu_alpha); + } + found++; + } + } + EXPECT_EQ(found, 1); + + auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed); + ASSERT_EQ(tensors_expected.size(), 1); + auto tensors = EvaluateNodes(output, item.fetch, item.feed); + ASSERT_EQ(tensors.size(), 1); + if (DTYPE == DT_BFLOAT16) + test::ExpectClose(tensors[0], tensors_expected[0], 1e-2, 1e-2); + else + test::ExpectClose(tensors[0], tensors_expected[0], 1e-6); + } } }; diff --git a/tensorflow/core/kernels/mkl/BUILD b/tensorflow/core/kernels/mkl/BUILD index bdf8b2b9714077..ce5940eca670a1 100644 --- a/tensorflow/core/kernels/mkl/BUILD +++ b/tensorflow/core/kernels/mkl/BUILD @@ -142,6 +142,7 @@ tf_cc_test_mkl( tf_mkl_kernel_library( name = "mkl_quantize_op", srcs = ["mkl_quantize_op.cc"], + hdrs = ["mkl_quant_dequant.h"], deps = [ "//tensorflow/core/kernels:quantized_ops", "//tensorflow/core/graph:mkl_graph_util", @@ -280,6 +281,7 @@ tf_mkl_kernel_library( tf_mkl_kernel_library( name = "mkl_dequantize_op", srcs = ["mkl_dequantize_op.cc"], + hdrs = ["mkl_quant_dequant.h"], deps = [ "@gemmlowp", "//tensorflow/core:array_ops_op_lib", diff --git a/tensorflow/core/kernels/mkl/mkl_dequantize_op.cc b/tensorflow/core/kernels/mkl/mkl_dequantize_op.cc index 3165fa69b7e744..c8ded409bdb3e7 100644 --- a/tensorflow/core/kernels/mkl/mkl_dequantize_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_dequantize_op.cc @@ -17,16 +17,9 @@ limitations under the License. #define EIGEN_USE_THREADS -#include "dnnl.hpp" -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/type_traits.h" -#include "tensorflow/core/framework/types.h" -#include "tensorflow/core/graph/mkl_graph_util.h" #include "tensorflow/core/kernels/meta_support.h" #include "tensorflow/core/kernels/quantization_utils.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/util/mkl_util.h" +#include "tensorflow/core/kernels/mkl/mkl_quant_dequant.h" using dnnl::primitive_attr; using dnnl::stream; @@ -71,10 +64,6 @@ class MklDequantizeOp : public OpKernel { MklDnnData src(&cpu_engine); MklDnnData dst(&cpu_engine); - std::shared_ptr reorder_stream; - MklDnnThreadPool eigen_tp(ctx); - reorder_stream.reset(CreateStream(&eigen_tp, cpu_engine)); - memory::format_tag dst_layout_type; switch (src_tf_shape.dims()) { case 1: @@ -105,7 +94,6 @@ class MklDequantizeOp : public OpKernel { auto src_md = memory::desc(src_dims, MklDnnType(), dst_layout_type); src.SetUsrMem(src_md, &src_tensor); - src.SetUsrMemDataHandle(&src_tensor, reorder_stream); Tensor* output_tensor = nullptr; MklDnnShape output_mkl_shape; @@ -122,7 +110,6 @@ class MklDequantizeOp : public OpKernel { AllocateOutputSetMklShape(ctx, 0, &output_tensor, output_tf_shape, output_mkl_shape, native_format); dst.SetUsrMem(dst_md, output_tensor); - dst.SetUsrMemDataHandle(output_tensor, reorder_stream); // The quantization logic here for mode SCALED is similar to the logic // in QuantizeAndDequantizeV2 and QuantizeAndDequantizeV3. @@ -145,21 +132,19 @@ class MklDequantizeOp : public OpKernel { } else { scale_factor = max_range / v_max; } - std::vector scales; - scales.push_back(scale_factor); - primitive_attr attr; - attr.set_output_scales(0, scales); - std::vector net; - - // Create reorder primitive and then execute. - auto reorder_pd = - ReorderPd(cpu_engine, src.GetUsrMem()->get_desc(), cpu_engine, - dst.GetUsrMem()->get_desc(), attr); - net.push_back(reorder(reorder_pd)); - std::vector> reorder_net_args; - reorder_net_args.push_back( - {{DNNL_ARG_FROM, *src.GetUsrMem()}, {DNNL_ARG_TO, *dst.GetUsrMem()}}); - execute_primitives(net, reorder_stream, reorder_net_args); + MklReorderWithScaleFwdParams fwdParams(src_dims, src_md, dst_md); + fwdParams.dtypes.append(typeid(T).name()); + fwdParams.post_op_params.name = "scale"; + fwdParams.post_op_params.param.push_back(scale_factor); + MklReorderWithScalePrimitive* reorder_prim = + MklReorderWithScalePrimitiveFactory::Get(src.GetUsrMem(), + dst.GetUsrMem(), fwdParams); + std::shared_ptr cpu_stream; + MklDnnThreadPool eigen_tp(ctx); + cpu_stream.reset(CreateStream(&eigen_tp, reorder_prim->GetEngine())); + reorder_prim->Execute(src.GetUsrMemDataHandle(), dst.GetUsrMemDataHandle(), + cpu_stream); + } catch (dnnl::error& e) { string error_msg = "Status: " + std::to_string(e.status) + ", message: " + string(e.message) + ", in file " + diff --git a/tensorflow/core/kernels/mkl/mkl_quant_dequant.h b/tensorflow/core/kernels/mkl/mkl_quant_dequant.h new file mode 100644 index 00000000000000..5eb8343a17887f --- /dev/null +++ b/tensorflow/core/kernels/mkl/mkl_quant_dequant.h @@ -0,0 +1,220 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifdef INTEL_MKL + +#define EIGEN_USE_THREADS + +#include "dnnl.hpp" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/type_traits.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/graph/mkl_graph_util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/util/mkl_util.h" +#ifdef DNNL_AARCH64_USE_ACL +#include "tensorflow/core/platform/mutex.h" +#endif + +using dnnl::primitive_attr; +using dnnl::prop_kind; +using dnnl::reorder; +using dnnl::stream; + +namespace { +enum { + QUANTIZE_MODE_MIN_COMBINED, + QUANTIZE_MODE_MIN_FIRST, + QUANTIZE_MODE_SCALED, +}; +enum { + // Round half away from zero: if the fraction of y is exactly 0.5, then + // round(y) = y + 0.5 if y > 0 + // round(y) = y - 0.5 if y < 0 + // E.g., -5.5 gets rounded to -6, -5.4 goes to -5, + // 5.4 goes to 5, and 5.5 goes to 6. + ROUND_HALF_AWAY_FROM_ZERO, + // Round half to even: if the fraction of y is exactly 0.5, then round(y) is + // the nearest even integer to y. + // E.g., 23.5 gets rounded to 24, 24.5 gets rounded to 24, while -23.5 becomes + // -24, and -24.5 gets rounded to 24. + ROUND_HALF_TO_EVEN, +}; +} // namespace + +namespace tensorflow { +typedef Eigen::ThreadPoolDevice CPUDevice; + +struct MklReorderWithScaleFwdParams { + memory::dims src_dims; + memory::desc src_md; + memory::desc dst_md; + string dtypes = string(""); + struct PostOpParam { + string name; + std::vector param; + }; + PostOpParam post_op_params; + + MklReorderWithScaleFwdParams(memory::dims src_dims, memory::desc src_md, + memory::desc dst_md) + : src_dims(src_dims), src_md(src_md), dst_md(dst_md) {} +}; + +class MklReorderWithScalePrimitive : public MklPrimitive { + public: + explicit MklReorderWithScalePrimitive( + const MklReorderWithScaleFwdParams& fwdParams) + : MklPrimitive(engine(engine::kind::cpu, 0)) { + // Create reorder primitive + Setup(fwdParams); + } + + ~MklReorderWithScalePrimitive() {} + + std::shared_ptr GetPrimitive() { return context_.reorder_prim; } + + void Execute(void* src_data, void* dst_data, + std::shared_ptr reorder_stream) { +#ifdef DNNL_AARCH64_USE_ACL + mutex_lock lock(primitive_execution_mu_); +#endif +#ifndef ENABLE_ONEDNN_OPENMP + context_.src_mem->set_data_handle(src_data, *reorder_stream); + context_.dst_mem->set_data_handle(dst_data, *reorder_stream); +#else + context_.src_mem->set_data_handle(src_data); + context_.dst_mem->set_data_handle(dst_data); +#endif // !ENABLE_ONEDNN_OPENMP + context_.reorder_prim->execute(*reorder_stream, context_.prim_args); + // After execution, set data handle back. + context_.src_mem->set_data_handle(DummyData); + context_.dst_mem->set_data_handle(DummyData); + } + + private: + // Primitive reuse context for reorder + struct ReorderContext { + // MKL-DNN memory + std::shared_ptr src_mem; + std::shared_ptr dst_mem; + + // Reorder primitive descriptor and primitive + std::shared_ptr reorder_pd; + std::shared_ptr reorder_prim; + + // Stream and primitive vector + std::shared_ptr reorder_stream; + + std::unordered_map prim_args; + + ReorderContext() + : src_mem(nullptr), + dst_mem(nullptr), + reorder_pd(nullptr), + reorder_prim(nullptr) {} + } context_; + + // Reorder primitive setup + void Setup(const MklReorderWithScaleFwdParams& fwdParams) { + // Create memory descriptors for reorder data with specified format + context_.src_mem.reset( + new memory(fwdParams.src_md, cpu_engine_, DummyData)); + context_.dst_mem.reset( + new memory(fwdParams.dst_md, cpu_engine_, DummyData)); + + // Check if there is any fusion as post-ops + auto const& post_op_params = fwdParams.post_op_params; + dnnl::primitive_attr post_ops_attr; + + DCHECK(post_op_params.name == "scale"); + DCHECK_EQ(post_op_params.param.size(), 1); + std::vector scales; + scales.push_back(post_op_params.param[0]); + post_ops_attr.set_output_scales(0, scales); + + context_.reorder_pd.reset( + new ReorderPd(cpu_engine_, context_.src_mem->get_desc(), cpu_engine_, + context_.dst_mem->get_desc(), post_ops_attr)); + + // Create reorder primitive + context_.reorder_prim.reset(new reorder(*context_.reorder_pd)); + context_.prim_args.insert({DNNL_ARG_FROM, *context_.src_mem}); + context_.prim_args.insert({DNNL_ARG_TO, *context_.dst_mem}); + } + +#ifdef DNNL_AARCH64_USE_ACL + mutex primitive_execution_mu_; +#endif +}; + +template +class MklReorderWithScalePrimitiveFactory : public MklPrimitiveFactory { + public: + static MklReorderWithScalePrimitive* Get( + const memory* from, const memory* to, + const MklReorderWithScaleFwdParams& fwdParams) { + // Try to find a suitable primitive from the cached pool + auto reorderPrim = static_cast( + MklReorderWithScalePrimitiveFactory::GetInstance().GetReorder( + from, to, fwdParams)); + if (reorderPrim == nullptr) { + reorderPrim = new MklReorderWithScalePrimitive(fwdParams); + MklReorderWithScalePrimitiveFactory::GetInstance().SetReorder( + from, to, reorderPrim, fwdParams); + } + return reorderPrim; + } + + static MklReorderWithScalePrimitiveFactory& GetInstance() { + static MklReorderWithScalePrimitiveFactory instance_; + return instance_; + } + + private: + MklReorderWithScalePrimitiveFactory() {} + ~MklReorderWithScalePrimitiveFactory() {} + + static string CreateKey(const memory* from, const memory* to, + const MklReorderWithScaleFwdParams& fwdParams) { + FactoryKeyCreator key_creator; + key_creator.AddAsKey(MklReorderPrimitiveFactory::CreateKey(from, to)); + // Generate key for post-op scale + if (fwdParams.post_op_params.name == "scale") { + DCHECK_EQ(fwdParams.post_op_params.param.size(), 1); + key_creator.AddAsKey(fwdParams.post_op_params.name); + key_creator.AddAsKey(fwdParams.post_op_params.param[0]); + } else { + return string("not_a_key"); + } + + return key_creator.GetKey(); + } + + MklPrimitive* GetReorder(const memory* from, const memory* to, + const MklReorderWithScaleFwdParams& fwdParams) { + string key = CreateKey(from, to, fwdParams); + return this->GetOp(key); + } + + void SetReorder(const memory* from, const memory* to, MklPrimitive* op, + const MklReorderWithScaleFwdParams& fwdParams) { + string key = CreateKey(from, to, fwdParams); + this->SetOp(key, op); + } +}; +} // namespace tensorflow +#endif // INTEL_MKL diff --git a/tensorflow/core/kernels/mkl/mkl_quantize_op.cc b/tensorflow/core/kernels/mkl/mkl_quantize_op.cc index 31698a798c0b82..2c3f55d1012eb7 100644 --- a/tensorflow/core/kernels/mkl/mkl_quantize_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_quantize_op.cc @@ -14,211 +14,15 @@ limitations under the License. ==============================================================================*/ #ifdef INTEL_MKL - -#define EIGEN_USE_THREADS - -#include "dnnl.hpp" -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/type_traits.h" -#include "tensorflow/core/framework/types.h" -#include "tensorflow/core/graph/mkl_graph_util.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/util/mkl_util.h" -#ifdef DNNL_AARCH64_USE_ACL -#include "tensorflow/core/platform/mutex.h" -#endif - +#include "tensorflow/core/kernels/mkl/mkl_quant_dequant.h" using dnnl::primitive_attr; using dnnl::prop_kind; using dnnl::reorder; using dnnl::stream; -namespace { -enum { - QUANTIZE_MODE_MIN_COMBINED, - QUANTIZE_MODE_MIN_FIRST, - QUANTIZE_MODE_SCALED, -}; -enum { - // Round half away from zero: if the fraction of y is exactly 0.5, then - // round(y) = y + 0.5 if y > 0 - // round(y) = y - 0.5 if y < 0 - // E.g., -5.5 gets rounded to -6, -5.4 goes to -5, - // 5.4 goes to 5, and 5.5 goes to 6. - ROUND_HALF_AWAY_FROM_ZERO, - // Round half to even: if the fraction of y is exactly 0.5, then round(y) is - // the nearest even integer to y. - // E.g., 23.5 gets rounded to 24, 24.5 gets rounded to 24, while -23.5 becomes - // -24, and -24.5 gets rounded to 24. - ROUND_HALF_TO_EVEN, -}; -} // namespace - -namespace tensorflow { -typedef Eigen::ThreadPoolDevice CPUDevice; - -struct MklReorderWithScaleFwdParams { - memory::dims src_dims; - memory::desc src_md; - memory::desc dst_md; - string dtypes = string(""); - struct PostOpParam { - string name; - std::vector param; - }; - PostOpParam post_op_params; - - MklReorderWithScaleFwdParams(memory::dims src_dims, memory::desc src_md, - memory::desc dst_md) - : src_dims(src_dims), src_md(src_md), dst_md(dst_md) {} -}; - -class MklReorderWithScalePrimitive : public MklPrimitive { - public: - explicit MklReorderWithScalePrimitive( - const MklReorderWithScaleFwdParams& fwdParams) - : MklPrimitive(engine(engine::kind::cpu, 0)) { - // Create reorder primitive - Setup(fwdParams); - } - - ~MklReorderWithScalePrimitive() {} - - std::shared_ptr GetPrimitive() { return context_.reorder_prim; } - - void Execute(void* src_data, void* dst_data, - std::shared_ptr reorder_stream) { -#ifdef DNNL_AARCH64_USE_ACL - mutex_lock lock(primitive_execution_mu_); -#endif -#ifndef ENABLE_ONEDNN_OPENMP - context_.src_mem->set_data_handle(src_data, *reorder_stream); - context_.dst_mem->set_data_handle(dst_data, *reorder_stream); -#else - context_.src_mem->set_data_handle(src_data); - context_.dst_mem->set_data_handle(dst_data); -#endif // !ENABLE_ONEDNN_OPENMP - context_.reorder_prim->execute(*reorder_stream, context_.prim_args); - // After execution, set data handle back. - context_.src_mem->set_data_handle(DummyData); - context_.dst_mem->set_data_handle(DummyData); - } - - private: - // Primitive reuse context for reorder - struct ReorderContext { - // MKL-DNN memory - std::shared_ptr src_mem; - std::shared_ptr dst_mem; - - // Reorder primitive descriptor and primitive - std::shared_ptr reorder_pd; - std::shared_ptr reorder_prim; - - // Stream and primitive vector - std::shared_ptr reorder_stream; - - std::unordered_map prim_args; - - ReorderContext() - : src_mem(nullptr), - dst_mem(nullptr), - reorder_pd(nullptr), - reorder_prim(nullptr) {} - } context_; - - // Reorder primitive setup - void Setup(const MklReorderWithScaleFwdParams& fwdParams) { - // Create memory descriptors for reorder data with specified format - context_.src_mem.reset( - new memory(fwdParams.src_md, cpu_engine_, DummyData)); - context_.dst_mem.reset( - new memory(fwdParams.dst_md, cpu_engine_, DummyData)); - - // Check if there is any fusion as post-ops - auto const& post_op_params = fwdParams.post_op_params; - dnnl::primitive_attr post_ops_attr; - - DCHECK(post_op_params.name == "scale"); - DCHECK_EQ(post_op_params.param.size(), 1); - std::vector scales; - scales.push_back(post_op_params.param[0]); - post_ops_attr.set_output_scales(0, scales); - - context_.reorder_pd.reset( - new ReorderPd(cpu_engine_, context_.src_mem->get_desc(), cpu_engine_, - context_.dst_mem->get_desc(), post_ops_attr)); - - // Create reorder primitive - context_.reorder_prim.reset(new reorder(*context_.reorder_pd)); - context_.prim_args.insert({DNNL_ARG_FROM, *context_.src_mem}); - context_.prim_args.insert({DNNL_ARG_TO, *context_.dst_mem}); - } - -#ifdef DNNL_AARCH64_USE_ACL - mutex primitive_execution_mu_; -#endif -}; - -template -class MklReorderWithScalePrimitiveFactory : public MklPrimitiveFactory { - public: - static MklReorderWithScalePrimitive* Get( - const memory* from, const memory* to, - const MklReorderWithScaleFwdParams& fwdParams) { - // Try to find a suitable primitive from the cached pool - auto reorderPrim = static_cast( - MklReorderWithScalePrimitiveFactory::GetInstance().GetReorder( - from, to, fwdParams)); - if (reorderPrim == nullptr) { - reorderPrim = new MklReorderWithScalePrimitive(fwdParams); - MklReorderWithScalePrimitiveFactory::GetInstance().SetReorder( - from, to, reorderPrim, fwdParams); - } - return reorderPrim; - } - - static MklReorderWithScalePrimitiveFactory& GetInstance() { - static MklReorderWithScalePrimitiveFactory instance_; - return instance_; - } - - private: - MklReorderWithScalePrimitiveFactory() {} - ~MklReorderWithScalePrimitiveFactory() {} - - static string CreateKey(const memory* from, const memory* to, - const MklReorderWithScaleFwdParams& fwdParams) { - FactoryKeyCreator key_creator; - key_creator.AddAsKey(MklReorderPrimitiveFactory::CreateKey(from, to)); - // Generate key for post-op scale - if (fwdParams.post_op_params.name == "scale") { - DCHECK_EQ(fwdParams.post_op_params.param.size(), 1); - key_creator.AddAsKey(fwdParams.post_op_params.name); - key_creator.AddAsKey(fwdParams.post_op_params.param[0]); - } else { - return string("not_a_key"); - } - - return key_creator.GetKey(); - } - - MklPrimitive* GetReorder(const memory* from, const memory* to, - const MklReorderWithScaleFwdParams& fwdParams) { - string key = CreateKey(from, to, fwdParams); - return this->GetOp(key); - } - - void SetReorder(const memory* from, const memory* to, MklPrimitive* op, - const MklReorderWithScaleFwdParams& fwdParams) { - string key = CreateKey(from, to, fwdParams); - this->SetOp(key, op); - } -}; - // Quantizes a tensor from float to T, with user-specified min_range and // max_range. +namespace tensorflow { template class MklQuantizeV2Op : public OpKernel { public: