diff --git a/tensorflow/core/kernels/mkl/mkl_batch_matmul_op.cc b/tensorflow/core/kernels/mkl/mkl_batch_matmul_op.cc
index 6b526264496e80..05ff773271cc9e 100644
--- a/tensorflow/core/kernels/mkl/mkl_batch_matmul_op.cc
+++ b/tensorflow/core/kernels/mkl/mkl_batch_matmul_op.cc
@@ -38,7 +38,6 @@ limitations under the License.
 #include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/util/matmul_bcast.h"
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 
 namespace tensorflow {
 
diff --git a/tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h b/tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h
index 6a652ba0ec5273..5f74e09e97d7d5 100644
--- a/tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h
+++ b/tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h
@@ -21,7 +21,6 @@ limitations under the License.
 #include 
 #include 
 
-#include "unsupported/Eigen/CXX11/Tensor"  // from @eigen_archive
 #include "dnnl.hpp"
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/op_kernel.h"
@@ -29,6 +28,7 @@ limitations under the License.
 #include "tensorflow/core/kernels/mkl/mkl_kernel_util.h"
 #include "tensorflow/core/util/mkl_util.h"
 #include "tensorflow/core/util/onednn_env_vars.h"
+#include "unsupported/Eigen/CXX11/Tensor"  // from @eigen_archive
 #if defined(DNNL_AARCH64_USE_ACL) && defined(ENABLE_ONEDNN_OPENMP)
 #include "tensorflow/core/platform/mutex.h"
 #endif  // DNNL_AARCH64_USE_ACL && ENABLE_ONEDNN_OPENMP
@@ -814,7 +814,7 @@ class MklMatMulPrimitive : public MklPrimitive {
   void Execute(const std::shared_ptr<stream>& stream, const Tlhs* a_data,
                const Trhs* b_data, const Toutput* c_data,
                const MklMatMulParams& matmul_params, void* sp_data,
-               const std::vector<void*> binary_op_fusions_data = {}) {
+               const std::vector<void*>& binary_op_fusions_data = {}) {
 #if defined(DNNL_AARCH64_USE_ACL) && defined(ENABLE_ONEDNN_OPENMP)
     mutex_lock lock(primitive_execution_mu_);
 #endif
@@ -829,9 +829,10 @@ class MklMatMulPrimitive : public MklPrimitive {
         static_cast<void*>(const_cast<Toutput*>(c_data)), *stream);
     context_.sp_mem->set_data_handle(sp_data, *stream);
 
-    for (int i = 0; i < num_post_ops_data; ++i)
+    for (int i = 0; i < num_post_ops_data; ++i) {
       context_.post_ops_mem[i]->set_data_handle(binary_op_fusions_data[i],
                                                 *stream);
+    }
 #else
     context_.a_mem->set_data_handle(
         static_cast<void*>(const_cast<Tlhs*>(a_data)));
@@ -861,8 +862,9 @@ class MklMatMulPrimitive : public MklPrimitive {
     context_.b_mem->set_data_handle(DummyData);
     context_.c_mem->set_data_handle(DummyData);
     context_.sp_mem->set_data_handle(DummyData);
-    for (int i = 0; i < num_post_ops_data; ++i)
+    for (int i = 0; i < num_post_ops_data; ++i) {
       context_.post_ops_mem[i]->set_data_handle(DummyData);
+    }
   }
 
   std::shared_ptr<dnnl::matmul::primitive_desc> GetPrimitiveDesc() const {
@@ -878,7 +880,7 @@ class MklMatMulPrimitive : public MklPrimitive {
     std::shared_ptr<dnnl::memory> c_mem;
     std::shared_ptr<dnnl::memory> sp_mem;
 
-    // Quantization scale related memory
+    // Quantization scale related memory.
     std::shared_ptr<dnnl::memory> lhs_scale_mem;
     std::shared_ptr<dnnl::memory> rhs_scale_mem;
     std::shared_ptr<dnnl::memory> dst_scale_mem;
@@ -896,7 +898,7 @@ class MklMatMulPrimitive : public MklPrimitive {
     std::shared_ptr<dnnl::memory::desc> b_md;
     std::shared_ptr<dnnl::memory::desc> c_md;
 
-    // Quantization scale related memory descriptors
+    // Quantization scale related memory descriptors.
     std::shared_ptr<dnnl::memory::desc> lhs_scale_md;
     std::shared_ptr<dnnl::memory::desc> rhs_scale_md;
     std::shared_ptr<dnnl::memory::desc> dst_scale_md;
@@ -969,7 +971,7 @@ class MklMatMulPrimitive : public MklPrimitive {
     dnnl::primitive_attr post_ops_attr;
     dnnl::post_ops post_ops;
     std::unordered_map<string, bool> is_scale_set;
-    int binary_post_ops_count = 0;  // Keep track op binary fusions
+    int binary_post_ops_count = 0;  // Keep track of binary fusions.
     if (!post_op_params.empty()) {
       for (auto const& post_op_param : post_op_params) {
         if (post_op_param.name == "lhs_scale") {
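
Note on the main functional change: Execute() now takes the binary post-op data handles by const reference while keeping the `= {}` default, and the loops that bind those handles gained braces. Below is a minimal, self-contained sketch of that idiom outside of TensorFlow; all names (Execute, fusion_data, handles) are illustrative only and are not the patched API.

#include <cstddef>
#include <cstdio>
#include <vector>

// Mirrors the signature change: take the fusion data by const reference with an
// empty-vector default. Callers that pass a vector avoid a per-call copy;
// callers that pass nothing bind a temporary empty vector for this call only.
void Execute(const std::vector<void*>& fusion_data = {}) {
  for (std::size_t i = 0; i < fusion_data.size(); ++i) {  // braced loop, as in the patch
    std::printf("handle %zu = %p\n", i, fusion_data[i]);
  }
}

int main() {
  int a = 0, b = 0;
  std::vector<void*> handles = {&a, &b};
  Execute(handles);  // binds by reference; `handles` is not copied
  Execute();         // default argument: an empty temporary, valid for the call
  return 0;
}

The temporary created for the default argument lives until the end of the full expression containing the call, so binding `{}` to a const reference is well defined; the benefit of the change is simply that non-empty vectors are no longer copied on every Execute() invocation.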