Adding leaky_relu fix and oneDNN_primitive_cache fix (#65)
ashiqimranintel authored Sep 13, 2022
1 parent 359c3cd commit c3cc4c1
Showing 6 changed files with 343 additions and 301 deletions.
29 changes: 12 additions & 17 deletions tensorflow/core/grappler/optimizers/remapper.cc
@@ -2368,25 +2368,16 @@ Status AddFusedConv3DNode(RemapperContext* ctx, const PadWithConv3D& matched,
           << " contraction=" << contraction.name();
 
   NodeDef fused_node;
-  fused_node.set_name(contraction.name());
-  fused_node.set_device(contraction.device());
-  fused_node.add_input(pad_node_def.input(0));  // 0: input
-  fused_node.add_input(contraction.input(1));   // 1: filter
+  // Note: Currently, the attributes of fused op are superset of contraction
+  // node. So copy it from contraction node before mutation.
+  fused_node.CopyFrom(contraction);
+  // 0-th input is the 0-th input of pad node, while the remainder inputs are
+  // same as the contraction node.
+  fused_node.set_input(0, pad_node_def.input(0));
   fused_node.set_op(kFusedConv3D);
 
   auto* attr = fused_node.mutable_attr();
-  auto& src_attr = contraction.attr();
-  (*attr)["T"] = src_attr.at("T");
-  (*attr)["strides"] = src_attr.at("strides");
-  (*attr)["data_format"] = src_attr.at("data_format");
-  (*attr)["padding"] = src_attr.at("padding");
-  (*attr)["dilations"] = src_attr.at("dilations");
-
-  if (contraction.op() == kFusedConv3D) {
-    fused_node.add_input(contraction.input(2));  // 2: bias
-    (*attr)["fused_ops"] = src_attr.at("fused_ops");
-    (*attr)["num_args"] = src_attr.at("num_args");
-  } else {
+  // Set num_args attr explicitly, since there is no default value.
+  if (!attr->contains("num_args")) {
     SetAttrValue(0, &(*attr)["num_args"]);
   }
 
@@ -2400,6 +2391,10 @@ Status AddFusedConv3DNode(RemapperContext* ctx, const PadWithConv3D& matched,
       paddings.push_back(const_value(i));
       SetAttrValue(paddings, &(*attr)["padding_list"]);
     }
+  } else {
+    VLOG(2) << "Pad fusion with " << contraction.op() << " is invalidated, "
+            << "it requires padding dim sizes to be constant.";
+    return OkStatus();
   }
 
   utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
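The substance of this remapper change: the fused node is now built by copying the matched contraction node wholesale and mutating only what the fusion touches, instead of hand-copying a fixed list of attributes — the old enumeration silently dropped attrs such as leakyrelu_alpha when the contraction was already a _FusedConv3D, which is the "leaky_relu fix" in the commit title. A minimal standalone sketch of the copy-then-mutate pattern, assuming TensorFlow's NodeDef proto and attr utilities; MakeFusedConv3D is a hypothetical helper written for illustration, not code from this commit:

#include <string>

#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/node_def.pb.h"

// Hypothetical helper: build the fused node by copy-then-mutate.
tensorflow::NodeDef MakeFusedConv3D(const tensorflow::NodeDef& contraction,
                                    const std::string& pad_input) {
  tensorflow::NodeDef fused;
  // CopyFrom preserves name, device, every input, and every attr (T, strides,
  // fused_ops, num_args, leakyrelu_alpha, ...), whether `contraction` is a
  // plain Conv3D or an already-fused _FusedConv3D.
  fused.CopyFrom(contraction);
  fused.set_input(0, pad_input);  // Re-route input 0 to the Pad node's input.
  fused.set_op("_FusedConv3D");
  auto* attr = fused.mutable_attr();
  // num_args has no default value, so set it explicitly when absent.
  if (!attr->contains("num_args")) {
    tensorflow::SetAttrValue(0, &(*attr)["num_args"]);
  }
  return fused;
}

Copy-then-mutate also keeps the fusion robust to attributes added to the contraction op in the future, since nothing has to be enumerated by hand.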
150 changes: 93 additions & 57 deletions tensorflow/core/grappler/optimizers/remapper_test.cc
@@ -2259,73 +2259,109 @@ class RemapperFusePadWithFusedConv3D : public RemapperTest {
  public:
   template <DataType DTYPE>
   void RunTest() {
-    if (!IsMKLEnabled()) GTEST_SKIP() << "Test only applicable to MKL.";
-    using ::tensorflow::ops::Placeholder;
-    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+    if (!IsMKLEnabled()) GTEST_SKIP() << "Test only applicable to oneDNN.";
-    auto input_shape = ops::Placeholder::Shape({8, 4, 32, 32, 3});
-    auto filter_shape = ops::Placeholder::Shape({1, 1, 1, 3, 128});
-    auto bias_shape = ops::Placeholder::Shape({128});
-    auto paddings_shape = ops::Placeholder::Shape({5, 2});
-    auto strides = {1, 1, 1, 1, 1};
+    using ::tensorflow::ops::Placeholder;
 
-    auto input_t = GenerateTensorWithSetRandom<DTYPE>({8, 4, 32, 32, 3});
-    auto filter_t = GenerateTensorWithSetRandom<DTYPE>({1, 1, 1, 3, 128});
-    auto bias_t = GenerateTensorWithSetRandom<DTYPE>({128});
+    // Empty string denotes no activation.
+    for (const string& activation : {"", "Relu", "Relu6", "Elu", "LeakyRelu"}) {
+      tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+      auto input_shape = ops::Placeholder::Shape({8, 4, 32, 32, 3});
+      auto filter_shape = ops::Placeholder::Shape({1, 1, 1, 3, 128});
+      auto bias_shape = ops::Placeholder::Shape({128});
+      auto paddings_shape = ops::Placeholder::Shape({5, 2});
+      auto strides = {1, 1, 1, 1, 1};
+
+      auto input_t = GenerateTensorWithSetRandom<DTYPE>({8, 4, 32, 32, 3});
+      auto filter_t = GenerateTensorWithSetRandom<DTYPE>({1, 1, 1, 3, 128});
+      auto bias_t = GenerateTensorWithSetRandom<DTYPE>({128});
 
-    auto input = Placeholder(s.WithOpName("input"), DTYPE, input_shape);
-    auto filter = Placeholder(s.WithOpName("filter"), DTYPE, filter_shape);
-    auto bias = Placeholder(s.WithOpName("bias"), DTYPE, bias_shape);
+      auto input = Placeholder(s.WithOpName("input"), DTYPE, input_shape);
+      auto filter = Placeholder(s.WithOpName("filter"), DTYPE, filter_shape);
+      auto bias = Placeholder(s.WithOpName("bias"), DTYPE, bias_shape);
 
-    auto padding_const = ops::Const(s.WithOpName("padding"),
-                                    {0, 0, 1, 1, 1, 1, 1, 1, 0, 0}, {5, 2});
-    auto pad = ops::Pad(s.WithOpName("pad"), input, padding_const);
-    auto conv = ops::Conv3D(s.WithOpName("conv"), pad, filter, strides, "SAME");
-    auto bias_add = ops::BiasAdd(s.WithOpName("bias_add"), conv, bias);
-    auto fetch = ops::Identity(s.WithOpName("fetch"), bias_add);
+      auto padding_const = ops::Const(s.WithOpName("padding"),
+                                      {0, 0, 1, 1, 1, 1, 1, 1, 0, 0}, {5, 2});
+      auto pad = ops::Pad(s.WithOpName("pad"), input, padding_const);
+      auto conv =
+          ops::Conv3D(s.WithOpName("conv"), pad, filter, strides, "SAME");
+      auto bias_add = ops::BiasAdd(s.WithOpName("bias_add"), conv, bias);
 
-    GrapplerItem item;
-    item.fetch = {"fetch"};
-    item.feed = {{"input", input_t}, {"filter", filter_t}, {"bias", bias_t}};
-    TF_ASSERT_OK(s.ToGraphDef(&item.graph));
+      float leakyrelu_alpha = 0.5;
+      ops::Identity fetch = [&]() -> ops::Identity {
+        auto activate = s.WithOpName("activation");
+        auto fetch = s.WithOpName("fetch");
+
-    // Place all nodes on CPU.
-    for (int i = 0; i < item.graph.node_size(); ++i) {
-      item.graph.mutable_node(i)->set_device("/device:CPU:0");
-    }
+        if (activation == "Relu") {
+          return ops::Identity(fetch, ops::Relu(activate, bias_add));
+        } else if (activation == "Relu6") {
+          return ops::Identity(fetch, ops::Relu6(activate, bias_add));
+        } else if (activation == "Elu") {
+          return ops::Identity(fetch, ops::Elu(activate, bias_add));
+        } else if (activation == "LeakyRelu") {
+          auto attr = ops::internal::LeakyRelu::Alpha(leakyrelu_alpha);
+          return ops::Identity(
+              fetch, ops::internal::LeakyRelu(activate, bias_add, attr));
+        }
 
-    Remapper optimizer(RewriterConfig::ON);
-    GraphDef output_1;
-    TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output_1));
-    item.graph = std::move(output_1);
-    GraphDef output;
-    TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+        return ops::Identity(fetch, bias);
+      }();
 
-    int found = 0;
-    for (const NodeDef& node : output.node()) {
-      if (node.name() == "bias_add") {
-        EXPECT_EQ(node.op(), "_FusedConv3D");
-        ASSERT_GE(node.input_size(), 3);
-        EXPECT_EQ(node.input(0), "input");
-        EXPECT_EQ(node.input(1), "filter");
-        EXPECT_EQ(node.attr().at("num_args").i(), 1);
-        EXPECT_EQ(node.input(2), "bias");
-        const auto fused_ops = node.attr().at("fused_ops").list().s();
-        ASSERT_EQ(fused_ops.size(), 1);
-        EXPECT_EQ(fused_ops[0], "BiasAdd");
-        found++;
+      GrapplerItem item;
+      item.fetch = {"fetch"};
+      item.feed = {{"input", input_t}, {"filter", filter_t}, {"bias", bias_t}};
+      TF_ASSERT_OK(s.ToGraphDef(&item.graph));
 
+      // Place all nodes on CPU.
+      for (int i = 0; i < item.graph.node_size(); ++i) {
+        item.graph.mutable_node(i)->set_device("/device:CPU:0");
+      }
-    }
-    EXPECT_EQ(found, 1);
 
-    auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
-    ASSERT_EQ(tensors_expected.size(), 1);
-    auto tensors = EvaluateNodes(output, item.fetch, item.feed);
-    ASSERT_EQ(tensors.size(), 1);
-    if (DTYPE == DT_BFLOAT16)
-      test::ExpectClose(tensors[0], tensors_expected[0], 1e-2, 1e-2);
-    else
-      test::ExpectClose(tensors[0], tensors_expected[0], 1e-6);
+      Remapper optimizer(RewriterConfig::ON);
+      GraphDef output_1;
+      TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output_1));
+      item.graph = std::move(output_1);
+      GraphDef output;
+      TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
 
+      string fused_node_name;
+      std::vector<string> expected_fused_ops = {"BiasAdd"};
+      if (activation.empty()) {
+        fused_node_name = "bias_add";
+      } else {
+        fused_node_name = "activation";
+        expected_fused_ops.push_back(activation);
+      }
+      int found = 0;
+      for (const NodeDef& node : output.node()) {
+        if (node.name() == fused_node_name) {
+          EXPECT_EQ(node.op(), "_FusedConv3D");
+          ASSERT_GE(node.input_size(), 3);
+          EXPECT_EQ(node.input(0), "input");
+          EXPECT_EQ(node.input(1), "filter");
+          EXPECT_EQ(node.attr().at("num_args").i(), 1);
+          EXPECT_EQ(node.input(2), "bias");
+          const auto fused_ops = node.attr().at("fused_ops").list().s();
+          ASSERT_EQ(fused_ops.size(), expected_fused_ops.size());
+          for (int i = 0; i < fused_ops.size(); ++i) {
+            EXPECT_EQ(fused_ops[i], expected_fused_ops[i]);
+          }
+          if (activation == "LeakyRelu") {
+            EXPECT_EQ(node.attr().at("leakyrelu_alpha").f(), leakyrelu_alpha);
+          }
+          found++;
+        }
+      }
+      EXPECT_EQ(found, 1);
 
+      auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
+      ASSERT_EQ(tensors_expected.size(), 1);
+      auto tensors = EvaluateNodes(output, item.fetch, item.feed);
+      ASSERT_EQ(tensors.size(), 1);
+      if (DTYPE == DT_BFLOAT16)
+        test::ExpectClose(tensors[0], tensors_expected[0], 1e-2, 1e-2);
+      else
+        test::ExpectClose(tensors[0], tensors_expected[0], 1e-6);
+    }
   }
 };

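The reworked test sweeps Pad + Conv3D + BiasAdd followed by each activation and asserts that the whole chain collapses into a single _FusedConv3D whose fused_ops attr lists the post-ops in application order, and which carries leakyrelu_alpha when LeakyRelu is fused. A compact sketch of that final check in isolation, assuming TensorFlow's NodeDef API; the predicate below is hypothetical, written for illustration rather than taken from this commit:

#include "tensorflow/core/framework/node_def.pb.h"

// Hypothetical predicate mirroring the assertions in the test above.
bool IsFusedConv3DWithLeakyRelu(const tensorflow::NodeDef& node,
                                float expected_alpha) {
  if (node.op() != "_FusedConv3D") return false;
  // fused_ops lists post-ops in application order: {"BiasAdd", "LeakyRelu"}.
  const auto& fused_ops = node.attr().at("fused_ops").list().s();
  if (fused_ops.size() != 2 || fused_ops[0] != "BiasAdd" ||
      fused_ops[1] != "LeakyRelu") {
    return false;
  }
  // The fused op must carry the activation's hyperparameter as an attr,
  // because the standalone LeakyRelu node is gone after fusion.
  return node.attr().at("leakyrelu_alpha").f() == expected_alpha;
}

This is exactly why the remapper's old attribute enumeration was a bug: without copying leakyrelu_alpha onto the fused node, the check above has nothing to read.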
2 changes: 2 additions & 0 deletions tensorflow/core/kernels/mkl/BUILD
@@ -142,6 +142,7 @@ tf_cc_test_mkl(
 tf_mkl_kernel_library(
     name = "mkl_quantize_op",
     srcs = ["mkl_quantize_op.cc"],
+    hdrs = ["mkl_quant_dequant.h"],
     deps = [
         "//tensorflow/core/kernels:quantized_ops",
         "//tensorflow/core/graph:mkl_graph_util",
@@ -280,6 +281,7 @@ tf_mkl_kernel_library(
 tf_mkl_kernel_library(
     name = "mkl_dequantize_op",
     srcs = ["mkl_dequantize_op.cc"],
+    hdrs = ["mkl_quant_dequant.h"],
     deps = [
         "@gemmlowp",
         "//tensorflow/core:array_ops_op_lib",
43 changes: 14 additions & 29 deletions tensorflow/core/kernels/mkl/mkl_dequantize_op.cc
@@ -17,16 +17,9 @@ limitations under the License.
 
 #define EIGEN_USE_THREADS
 
-#include "dnnl.hpp"
-#include "tensorflow/core/framework/op.h"
-#include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/framework/type_traits.h"
-#include "tensorflow/core/framework/types.h"
-#include "tensorflow/core/graph/mkl_graph_util.h"
-#include "tensorflow/core/kernels/meta_support.h"
-#include "tensorflow/core/kernels/quantization_utils.h"
-#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/util/mkl_util.h"
+#include "tensorflow/core/kernels/mkl/mkl_quant_dequant.h"
 
 using dnnl::primitive_attr;
 using dnnl::stream;
@@ -71,10 +64,6 @@ class MklDequantizeOp : public OpKernel {
       MklDnnData<T> src(&cpu_engine);
       MklDnnData<float> dst(&cpu_engine);
 
-      std::shared_ptr<stream> reorder_stream;
-      MklDnnThreadPool eigen_tp(ctx);
-      reorder_stream.reset(CreateStream(&eigen_tp, cpu_engine));
-
       memory::format_tag dst_layout_type;
       switch (src_tf_shape.dims()) {
         case 1:
@@ -105,7 +94,6 @@ class MklDequantizeOp : public OpKernel {
       auto src_md = memory::desc(src_dims, MklDnnType<T>(), dst_layout_type);
 
       src.SetUsrMem(src_md, &src_tensor);
-      src.SetUsrMemDataHandle(&src_tensor, reorder_stream);
 
       Tensor* output_tensor = nullptr;
       MklDnnShape output_mkl_shape;
@@ -122,7 +110,6 @@ class MklDequantizeOp : public OpKernel {
       AllocateOutputSetMklShape(ctx, 0, &output_tensor, output_tf_shape,
                                 output_mkl_shape, native_format);
       dst.SetUsrMem(dst_md, output_tensor);
-      dst.SetUsrMemDataHandle(output_tensor, reorder_stream);
 
       // The quantization logic here for mode SCALED is similar to the logic
       // in QuantizeAndDequantizeV2 and QuantizeAndDequantizeV3.
@@ -145,21 +132,19 @@ class MklDequantizeOp : public OpKernel {
      } else {
        scale_factor = max_range / v_max;
      }
-      std::vector<float> scales;
-      scales.push_back(scale_factor);
-      primitive_attr attr;
-      attr.set_output_scales(0, scales);
-      std::vector<primitive> net;
 
-      // Create reorder primitive and then execute.
-      auto reorder_pd =
-          ReorderPd(cpu_engine, src.GetUsrMem()->get_desc(), cpu_engine,
-                    dst.GetUsrMem()->get_desc(), attr);
-      net.push_back(reorder(reorder_pd));
-      std::vector<std::unordered_map<int, memory>> reorder_net_args;
-      reorder_net_args.push_back(
-          {{DNNL_ARG_FROM, *src.GetUsrMem()}, {DNNL_ARG_TO, *dst.GetUsrMem()}});
-      execute_primitives(net, reorder_stream, reorder_net_args);
+      MklReorderWithScaleFwdParams fwdParams(src_dims, src_md, dst_md);
+      fwdParams.dtypes.append(typeid(T).name());
+      fwdParams.post_op_params.name = "scale";
+      fwdParams.post_op_params.param.push_back(scale_factor);
+      MklReorderWithScalePrimitive* reorder_prim =
+          MklReorderWithScalePrimitiveFactory<T>::Get(src.GetUsrMem(),
+                                                      dst.GetUsrMem(), fwdParams);
+      std::shared_ptr<stream> cpu_stream;
+      MklDnnThreadPool eigen_tp(ctx);
+      cpu_stream.reset(CreateStream(&eigen_tp, reorder_prim->GetEngine()));
+      reorder_prim->Execute(src.GetUsrMemDataHandle(), dst.GetUsrMemDataHandle(),
+                            cpu_stream);
 
     } catch (dnnl::error& e) {
       string error_msg = "Status: " + std::to_string(e.status) +
                          ", message: " + string(e.message) + ", in file " +
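This is the oneDNN primitive-cache half of the commit: instead of constructing a fresh reorder primitive (plus attr and net) on every Compute call, the kernel now asks MklReorderWithScalePrimitiveFactory<T>::Get(...) for a primitive keyed on the memory layout, data type, and scale post-op, so repeated invocations with the same configuration reuse one cached primitive. A minimal sketch of the general factory/cache idea under simplified assumptions; PrimitiveCache below is illustrative only, not TF's actual implementation:

#include <memory>
#include <string>
#include <unordered_map>

// Illustrative keyed cache: build an expensive primitive on first use,
// then return the cached instance for every later call with the same key.
template <typename Primitive>
class PrimitiveCache {
 public:
  template <typename Creator>
  Primitive* GetOrCreate(const std::string& key, Creator&& create) {
    auto it = cache_.find(key);
    if (it == cache_.end()) {
      // Miss: create() returns std::unique_ptr<Primitive>, e.g. a oneDNN
      // reorder with a scale post-op, which we memoize under `key`.
      it = cache_.emplace(key, create()).first;
    }
    return it->second.get();
  }

 private:
  std::unordered_map<std::string, std::unique_ptr<Primitive>> cache_;
};

This is also why fwdParams above folds in src_dims, typeid(T).name(), and the scale parameter: any state baked into the cached primitive must be part of the lookup key, or a primitive could be reused with the wrong scale.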
(Diffs for the remaining 2 of the 6 changed files are not shown.)
