diff --git a/cinn/frontend/base_builder.cc b/cinn/frontend/base_builder.cc index fe3e53c668063..9834499c994e6 100644 --- a/cinn/frontend/base_builder.cc +++ b/cinn/frontend/base_builder.cc @@ -1,27 +1,40 @@ #include "cinn/frontend/base_builder.h" +#include +#include #include #include +#include #include "cinn/common/common.h" #include "cinn/common/context.h" +#include "cinn/common/type.h" +#include "cinn/frontend/syntax.h" +#include "cinn/hlir/framework/node.h" +#include "cinn/hlir/framework/op.h" namespace cinn { namespace frontend { +using common::Context; +using common::Type; +using hlir::framework::AttrMapType; +using hlir::framework::Operator; +using hlir::framework::shape_t; + +BaseBuilder::BaseBuilder(const std::string& name) : name_(name) {} + Program BaseBuilder::Build() { Program program{std::move(instrs_), std::move(inputs_)}; program.Validate(); return program; } -Placeholder BaseBuilder::CreateInput(const common::Type& type, - const std::vector& shape, - const std::string& id_hint) { +Placeholder BaseBuilder::CreateInput(const Type& type, const std::vector& shape, const std::string& id_hint) { if (!id_hint.empty()) { CheckVarNameValid(id_hint); } - std::string id = id_hint.empty() ? common::Context::Global().NewName("placeholder") : id_hint; + std::string id = id_hint.empty() ? Context::Global().NewName("placeholder") : id_hint; inputs_.emplace_back(id); auto& var = inputs_.back(); @@ -30,5 +43,30 @@ Placeholder BaseBuilder::CreateInput(const common::Type& type, return Placeholder(var); } +void BaseBuilder::InferShape(Instruction instr) const { + using shape_func_t = std::function(const std::vector&, const AttrMapType&)>; + using type_func_t = std::function(const std::vector&, const AttrMapType&)>; + const auto& op_infershape = Operator::GetAttrs("infershape"); + const auto& op_inferdtype = Operator::GetAttrs("inferdtype"); + + size_t size = instr->inputs.size(); + std::vector in_shapes(size); + std::vector in_types(size); + std::transform( + instr->inputs.begin(), instr->inputs.end(), in_shapes.begin(), [](const Variable& var) { return var->shape; }); + std::transform( + instr->inputs.begin(), instr->inputs.end(), in_types.begin(), [](const Variable& var) { return var->type; }); + + auto key = Operator::Get(instr->op_type); + auto out_shapes = op_infershape[key](in_shapes, instr->attrs); + auto out_types = op_inferdtype[key](in_types, instr->attrs); + + auto& outs = instr->outputs; + for (size_t i = 0; i < outs.size(); i++) { + outs[i]->shape = out_shapes[i]; + outs[i]->type = out_types[i]; + } +} + } // namespace frontend } // namespace cinn diff --git a/cinn/frontend/base_builder.h b/cinn/frontend/base_builder.h index b621a5343cdc7..927849d196ea4 100644 --- a/cinn/frontend/base_builder.h +++ b/cinn/frontend/base_builder.h @@ -1,18 +1,18 @@ #pragma once #include -#include #include #include "cinn/common/type.h" #include "cinn/frontend/syntax.h" +#include "cinn/hlir/framework/op.h" namespace cinn { namespace frontend { class BaseBuilder { public: - explicit BaseBuilder(const std::string& name) : name_(name) {} + explicit BaseBuilder(const std::string& name); Program Build(); @@ -26,6 +26,8 @@ class BaseBuilder { protected: void AppendInstruction(const Instruction& instr) { instrs_.push_back(instr); } + void InferShape(Instruction instr) const; + std::string name_; std::vector instrs_; std::vector inputs_; diff --git a/cinn/frontend/net_builder.cc b/cinn/frontend/net_builder.cc index 8bfb1330fee12..6e289ad473130 100644 --- a/cinn/frontend/net_builder.cc +++ b/cinn/frontend/net_builder.cc @@ -1,8 +1,8 @@ #include "cinn/frontend/net_builder.h" #include -#include #include +#include #include "cinn/frontend/syntax.h" @@ -11,6 +11,7 @@ namespace frontend { Variable NetBuilder::add(const Variable& a, const Variable& b) { Instruction instr("elementwise_add", {a, b}); + InferShape(instr); AppendInstruction(instr); return instr.GetOutput(0); } @@ -19,6 +20,7 @@ Variable NetBuilder::mul(const Variable& a, const Variable& b, int x_num_col_dim Instruction instr("mul", {a, b}); instr.SetAttr("x_num_col_dims", x_num_col_dims); instr.SetAttr("y_num_col_dims", y_num_col_dims); + InferShape(instr); AppendInstruction(instr); return instr.GetOutput(0); } @@ -28,6 +30,7 @@ Variable NetBuilder::mulbias( Instruction instr("mulbias", {a, b, c}); instr.SetAttr("x_num_col_dims", x_num_col_dims); instr.SetAttr("y_num_col_dims", y_num_col_dims); + InferShape(instr); AppendInstruction(instr); return instr.GetOutput(1); } @@ -35,6 +38,7 @@ Variable NetBuilder::mulbias( Variable NetBuilder::elementwise_add(const Variable& a, const Variable& b, int axis) { Instruction instr("elementwise_add", {a, b}); instr.SetAttr("axis", axis); + InferShape(instr); AppendInstruction(instr); return instr.GetOutput(0); } @@ -42,12 +46,14 @@ Variable NetBuilder::elementwise_add(const Variable& a, const Variable& b, int a Variable NetBuilder::elementwise_mul(const Variable& a, const Variable& b, int axis) { Instruction instr("elementwise_mul", {a, b}); instr.SetAttr("axis", axis); + InferShape(instr); AppendInstruction(instr); return instr.GetOutput(0); } Variable NetBuilder::relu(const Variable& a) { Instruction instr("relu", {a}); + InferShape(instr); AppendInstruction(instr); return instr.GetOutput(0); } @@ -55,6 +61,7 @@ Variable NetBuilder::relu(const Variable& a) { Variable NetBuilder::relu6(const Variable& a, float threshold) { Instruction instr("relu6", {a}); instr.SetAttr("threshold", threshold); + InferShape(instr); AppendInstruction(instr); return instr.GetOutput(0); } @@ -75,6 +82,7 @@ Variable NetBuilder::conv2d(const Variable& a, instr.SetAttr("groups", groups); instr.SetAttr("data_format", data_format); instr.SetAttr("padding_algorithm", padding_algorithm); + InferShape(instr); AppendInstruction(instr); return instr.GetOutput(0); } @@ -95,6 +103,7 @@ Variable NetBuilder::depthwise_conv2d(const Variable& a, instr.SetAttr("groups", groups); instr.SetAttr("data_format", data_format); instr.SetAttr("padding_algorithm", padding_algorithm); + InferShape(instr); AppendInstruction(instr); return instr.GetOutput(0); } @@ -122,6 +131,7 @@ Variable NetBuilder::pool2d(const Variable& a, instr.SetAttr("data_format", data_format); instr.SetAttr("adaptive", adaptive); instr.SetAttr("padding_algorithm", padding_algorithm); + InferShape(instr); AppendInstruction(instr); return instr.GetOutput(0); } @@ -139,6 +149,7 @@ Variable NetBuilder::batchnorm(const Variable& a, instr.SetAttr("epsilon", epsilon); instr.SetAttr("momentum", momentum); instr.SetAttr("data_layout", data_layout); + InferShape(instr); AppendInstruction(instr); return instr.GetOutput(0); } @@ -148,6 +159,7 @@ Variable NetBuilder::scale(const Variable& a, float scale, float bias, bool bias instr.SetAttr("scale", scale); instr.SetAttr("bias", bias); instr.SetAttr("bias_after_scale", bias_after_scale); + InferShape(instr); AppendInstruction(instr); return instr.GetOutput(0); } @@ -156,12 +168,14 @@ Variable NetBuilder::softmax(const Variable& a, int axis, const std::string& dat Instruction instr("softmax", {a}); instr.SetAttr("axis", axis); instr.SetAttr("data_format", data_format); + InferShape(instr); AppendInstruction(instr); return instr.GetOutput(0); } Variable NetBuilder::sigmoid(const Variable& a) { Instruction instr("sigmoid", {a}); + InferShape(instr); AppendInstruction(instr); return instr.GetOutput(0); } @@ -178,6 +192,7 @@ Variable NetBuilder::slice(const Variable& a, instr.SetAttr("ends", ends); instr.SetAttr("infer_flags", infer_flags); instr.SetAttr("decrease_axis", decrease_axis); + InferShape(instr); AppendInstruction(instr); return instr.GetOutput(0); } @@ -186,6 +201,7 @@ Variable NetBuilder::dropout_infer(const Variable& a, float dropout_prob, const Instruction instr("dropout_infer", {a}); instr.SetAttr("dropout_prob", dropout_prob); instr.SetAttr("dropout_implementation", dropout_implementation); + InferShape(instr); AppendInstruction(instr); return instr.GetOutput(0); } diff --git a/cinn/frontend/net_builder_test.cc b/cinn/frontend/net_builder_test.cc index 2543f9b48527f..ac7cb445295bb 100644 --- a/cinn/frontend/net_builder_test.cc +++ b/cinn/frontend/net_builder_test.cc @@ -69,12 +69,11 @@ TEST(net_build, program_execute_multi_elementwise_add) { #else Target target = common::DefaultHostTarget(); #endif + auto graph = std::make_shared(program, target); LOG(INFO) << "graph:\n" << graph->Visualize(); - hlir::framework::ApplyPass(graph.get(), "InferShape"); auto scope = BuildScope(target, graph); - hlir::framework::GraphCompiler gc(target, scope, graph); auto runtime_program = gc.Build(); @@ -109,11 +108,9 @@ TEST(net_build, program_execute_fc) { #else Target target = common::DefaultHostTarget(); #endif - auto graph = std::make_shared(program, target); - hlir::framework::ApplyPass(graph.get(), "InferShape"); + auto graph = std::make_shared(program, target); auto scope = BuildScope(target, graph); - hlir::framework::GraphCompiler gc(target, scope, graph); auto runtime_program = gc.Build(); diff --git a/cinn/hlir/framework/graph.cc b/cinn/hlir/framework/graph.cc index 9543e44fecc7f..fff87fe30d08a 100644 --- a/cinn/hlir/framework/graph.cc +++ b/cinn/hlir/framework/graph.cc @@ -29,7 +29,9 @@ Graph::Graph(const frontend::Program& prog, const Target& target) { } int out_idx = 0; for (auto& output_v : temp->outputs) { - auto* output_data = new NodeData(node_ptr, out_idx++, 0, output_v->id); + dtype_dict[output_v->id] = output_v->type; + shape_dict[output_v->id] = output_v->shape; + auto* output_data = new NodeData(node_ptr, out_idx++, 0, output_v->id); node_tmp->LinkTo(output_data); this->RegisterNode(output_v->id, output_data); } diff --git a/cinn/hlir/framework/node.h b/cinn/hlir/framework/node.h index 18ee1eda72d2e..94ad4873f9bbd 100644 --- a/cinn/hlir/framework/node.h +++ b/cinn/hlir/framework/node.h @@ -1,14 +1,13 @@ #pragma once +#include +#include + #include #include #include -#include #include #include -#include -#include - #include "cinn/common/graph_utils.h" #include "cinn/common/shared.h" #include "cinn/hlir/framework/op.h" @@ -19,15 +18,16 @@ namespace framework { class Node; class NodeData; -using NodePtr = std::shared_ptr; -using AttrType = absl::variant, - std::vector, - std::vector, - std::vector>; +using NodePtr = std::shared_ptr; +using AttrType = absl::variant, + std::vector, + std::vector, + std::vector>; +using AttrMapType = absl::flat_hash_map; /** * \brief Attributes of each node in graph. @@ -93,7 +93,7 @@ class Node : public common::GraphNode { inline uint32_t num_inputs() { return is_variable() ? 1 : this->op()->num_inputs; } template - static NodePtr Create(Args &&... args) { + static NodePtr Create(Args &&...args) { return std::make_shared(std::forward(args)...); } @@ -125,7 +125,7 @@ class NodeData : public common::GraphNode { const char *op_name, std::string node_name, std::vector inputs, - std::string id = nullptr, + std::string id = nullptr, absl::flat_hash_map attrs = absl::flat_hash_map()) { auto res = std::make_shared(); res->id_ = std::move(id); diff --git a/cinn/hlir/framework/op.h b/cinn/hlir/framework/op.h index d34b696fc7825..1fb73435c73da 100644 --- a/cinn/hlir/framework/op.h +++ b/cinn/hlir/framework/op.h @@ -1,4 +1,6 @@ #pragma once +#include +#include #include #include @@ -6,22 +8,24 @@ #include #include //NOLINT #include -#include #include #include -#include - #include "cinn/common/macros.h" #include "cinn/utils/registry.h" +template +inline auto MakeOpFunction(R (*func)(Args...)) { + return std::function(func); +} + namespace cinn { namespace hlir { namespace framework { class Operator; using shape_t = std::vector; -using dim_t = shape_t ::value_type; +using dim_t = shape_t::value_type; /*! \brief operator pattern used in graph fusion */ enum OpPatternKind { @@ -70,6 +74,8 @@ class OpValueType { inline bool Find(const Operator* op) const; + size_t Size() const { return data.size(); } + private: friend class Operator; std::string attr_name; @@ -203,11 +209,6 @@ bool OpValueType::Find(const Operator* op) const { return idx < data.size(); } -template -inline auto MakeOpFunction(R(*func)(Args...)) { - return std::function(func); -} - // internal macros to make #define CINN_REGISTER_VAR_DEF(OpName) static ::cinn::hlir::framework::Operator& __make_##HlirOp##_##OpName @@ -228,8 +229,6 @@ inline auto MakeOpFunction(R(*func)(Args...)) { CINN_STR_CONCAT(CINN_REGISTER_VAR_DEF(OpName), __COUNTER__) = \ ::cinn::hlir::framework::OpRegistry::Global()->__REGISTER_OR_GET__(#OpName) - - } // namespace framework } // namespace hlir } // namespace cinn diff --git a/cinn/hlir/framework/op_strategy.h b/cinn/hlir/framework/op_strategy.h index 5e7957b915a30..b505483d99752 100644 --- a/cinn/hlir/framework/op_strategy.h +++ b/cinn/hlir/framework/op_strategy.h @@ -17,13 +17,13 @@ using CINNSchedule = lang::PackedFunc; class OpStrategy; -using StrategyFunction = std::function(const NodeAttr&, +using StrategyFunction = std::function(const NodeAttr&, const std::vector&, const std::vector&, const std::vector>&, const common::Target&)>; -using InferShapeFunction = std::function>( - const std::vector>&, NodeAttr&, const common::Target&)>; +using InferShapeFunction = + std::function>(const std::vector>&, const AttrMapType&)>; //! Operator implementation that includes compute and schedule function. class OpImpl : public common::Object { diff --git a/cinn/hlir/op/broadcast.cc b/cinn/hlir/op/broadcast.cc index 5a4a3fdf0930f..654662b3d52fb 100644 --- a/cinn/hlir/op/broadcast.cc +++ b/cinn/hlir/op/broadcast.cc @@ -20,15 +20,14 @@ using common::CINNValuePack; using framework::OpStrategy; using framework::shape_t; using framework::StrategyFunction; -using namespace pe; -#define StrategyForBinary(op_name__, pe__) \ - std::shared_ptr StrategyFor##pe__(const framework::NodeAttr &attrs, \ - const std::vector &inputs, \ - const std::vector &out_type, \ - const std::vector> &output_shapes, \ - const Target &target) { \ - return StrategyForBroadcast(attrs, inputs, out_type, output_shapes, target, #op_name__, pe__); \ +#define StrategyForBinary(op_name__, pe__) \ + std::shared_ptr StrategyFor##pe__(const framework::NodeAttr &attrs, \ + const std::vector &inputs, \ + const std::vector &out_type, \ + const std::vector> &output_shapes, \ + const Target &target) { \ + return StrategyForBroadcast(attrs, inputs, out_type, output_shapes, target, #op_name__, pe::pe__); \ } std::shared_ptr StrategyForBroadcast( @@ -83,13 +82,12 @@ std::shared_ptr StrategyForBroadcast( } std::vector InferShapeForBroadcast(const std::vector &inputs_shape, - framework::NodeAttr &attrs, - const Target &target) { + const framework::AttrMapType &attrs) { CHECK_EQ(inputs_shape.size(), 2UL); std::vector out_shape; int axis = -1; - for (auto &iter : attrs.attr_store) { + for (auto &iter : attrs) { if (iter.first == "axis") { axis = absl::get(iter.second); break; @@ -102,9 +100,7 @@ std::vector InferShapeForBroadcast(const std::vector &inputs_s return {out_shape}; } -std::vector InferDtypeForBroadcast(const std::vector &inputs_type, - const framework::NodeAttr &attrs, - const Target &target) { +std::vector InferDtypeForBroadcast(const std::vector &inputs_type, const framework::AttrMapType &attrs) { CHECK(!inputs_type.empty()) << "The input's type size is 0! Please check again."; std::vector res{inputs_type[0]}; return res; @@ -168,7 +164,7 @@ std::shared_ptr StrategyForBroadcastTo(const framework::NodeAttr &at Expr A_expr = a[0]; CHECK(A_expr.as_tensor()); ir::Tensor A = A_expr.as_tensor_ref(); - auto out = BroadcastTo(A, out_shape, broadcast_axes, UniqName("broadcast_to_Out")); + auto out = pe::BroadcastTo(A, out_shape, broadcast_axes, UniqName("broadcast_to_Out")); auto stages = CreateStages({A, out}); *ret = CINNValuePack{{CINNValue(out), CINNValue(stages)}}; }); @@ -195,15 +191,14 @@ std::shared_ptr StrategyForBroadcastTo(const framework::NodeAttr &at } std::vector InferShapeForBroadcastTo(const std::vector &inputs_shape, - framework::NodeAttr &attrs, - const Target &target) { + const framework::AttrMapType &attrs) { CHECK_EQ(inputs_shape.size(), 1UL) << "input_shape size should be one. Please Check."; std::vector broadcast_axes; std::vector out_shape; - CHECK(attrs.attr_store.count("broadcast_axes")); - CHECK(attrs.attr_store.count("out_shape")); - out_shape = absl::get>(attrs.attr_store.at("out_shape")); - broadcast_axes = absl::get>(attrs.attr_store.at("broadcast_axes")); + CHECK(attrs.count("broadcast_axes")); + CHECK(attrs.count("out_shape")); + out_shape = absl::get>(attrs.at("out_shape")); + broadcast_axes = absl::get>(attrs.at("broadcast_axes")); CHECK_EQ(inputs_shape[0].size(), broadcast_axes.size()) << "broadcast_axes's size should be same with the input shape's size"; @@ -258,7 +253,6 @@ StrategyForBinary(right_shift, RightShift); } // namespace hlir } // namespace cinn - CINN_REGISTER_HELPER(broadcast_ops) { #define CINN_REGISTER_BINARY(op__, op_stragegy__) \ CINN_REGISTER_OP(op__) \ @@ -266,9 +260,9 @@ CINN_REGISTER_HELPER(broadcast_ops) { .set_num_inputs(1) \ .set_num_outputs(1) \ .set_attr("CINNStrategy", cinn::hlir::op::StrategyFor##op_stragegy__) \ - .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForBroadcast)) \ - .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForBroadcast)) \ - .set_attr("inferlayout", MakeOpFunction(cinn::hlir::op::InferLayoutForBroadcast)) \ + .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForBroadcast)) \ + .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForBroadcast)) \ + .set_attr("inferlayout", MakeOpFunction(cinn::hlir::op::InferLayoutForBroadcast)) \ .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kBroadcast) \ .set_support_level(4); diff --git a/cinn/hlir/op/elementwise.cc b/cinn/hlir/op/elementwise.cc index bae102ac6f271..8f10605485748 100644 --- a/cinn/hlir/op/elementwise.cc +++ b/cinn/hlir/op/elementwise.cc @@ -18,16 +18,15 @@ using common::CINNValuePack; using framework::OpStrategy; using framework::shape_t; using framework::StrategyFunction; -using namespace pe; using PeFunc = std::function(const ir::Tensor &A, const std::string &out_name)>; -#define StrategyForUnary(op_name__, pe__) \ - std::shared_ptr StrategyFor##pe__(const framework::NodeAttr &attrs, \ - const std::vector &inputs, \ - const std::vector &out_type, \ - const std::vector> &output_shapes, \ - const Target &target) { \ - return StrategyForElementwise(attrs, inputs, out_type, output_shapes, target, #op_name__, pe__); \ +#define StrategyForUnary(op_name__, pe__) \ + std::shared_ptr StrategyFor##pe__(const framework::NodeAttr &attrs, \ + const std::vector &inputs, \ + const std::vector &out_type, \ + const std::vector> &output_shapes, \ + const Target &target) { \ + return StrategyForElementwise(attrs, inputs, out_type, output_shapes, target, #op_name__, pe::pe__); \ } std::shared_ptr StrategyForElementwise(const framework::NodeAttr &attrs, @@ -77,16 +76,13 @@ std::shared_ptr StrategyForElementwise(const framework::NodeAttr &at } std::vector InferShapeForElementwise(const std::vector &inputs_shape, - framework::NodeAttr &attrs, - const Target &target) { + const framework::AttrMapType &attrs) { CHECK_EQ(inputs_shape.size(), 1UL); std::vector res{inputs_shape[0]}; return res; } -std::vector InferDtypeForElementwise(const std::vector &inputs_type, - const framework::NodeAttr &attrs, - const Target &target) { +std::vector InferDtypeForElementwise(const std::vector &inputs_type, const framework::AttrMapType &attrs) { CHECK(!inputs_type.empty()) << "The input's type size is 0! Please check again."; std::vector res{inputs_type[0]}; return res; @@ -213,16 +209,13 @@ std::shared_ptr StrategyForConstScalar(const framework::NodeAttr &at } std::vector InferShapeForConstScalar(const std::vector &inputs_shape, - framework::NodeAttr &attrs, - const Target &target) { + const framework::AttrMapType &attrs) { return {{1}}; } -std::vector InferDtypeForConstScalar(const std::vector &inputs_type, - const framework::NodeAttr &attrs, - const Target &target) { - CHECK(attrs.attr_store.count("value")); - auto scalar = GetScalarExpr(attrs.attr_store.at("value")); +std::vector InferDtypeForConstScalar(const std::vector &inputs_type, const framework::AttrMapType &attrs) { + CHECK(attrs.count("value")); + auto scalar = GetScalarExpr(attrs.at("value")); auto out_type = scalar->type(); VLOG(3) << "scalar type: " << out_type; return {out_type}; @@ -284,9 +277,9 @@ CINN_REGISTER_HELPER(elementwise_ops) { .set_num_inputs(1) \ .set_num_outputs(1) \ .set_attr("CINNStrategy", cinn::hlir::op::StrategyFor##op_stragegy__) \ - .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForElementwise)) \ - .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForElementwise)) \ - .set_attr("inferlayout", MakeOpFunction(cinn::hlir::op::InferLayoutForElementwise)) \ + .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForElementwise)) \ + .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForElementwise)) \ + .set_attr("inferlayout", MakeOpFunction(cinn::hlir::op::InferLayoutForElementwise)) \ .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kElemWise) \ .set_support_level(4); diff --git a/cinn/hlir/op/nn.cc b/cinn/hlir/op/nn.cc index c8425aec24261..3c606ea3932a1 100644 --- a/cinn/hlir/op/nn.cc +++ b/cinn/hlir/op/nn.cc @@ -67,16 +67,13 @@ std::shared_ptr StrategyForRelu(const framework::NodeAttr &attrs, } std::vector InferShapeForRelu(const std::vector &inputs_shape, - framework::NodeAttr &attrs, - const Target &target) { + const framework::AttrMapType &attrs) { CHECK(!inputs_shape.empty() && !inputs_shape[0].empty()) << "The input's shape size is 0! Please check again."; std::vector res{inputs_shape[0]}; return res; } -std::vector InferDtypeForRelu(const std::vector &inputs_type, - const framework::NodeAttr &attrs, - const Target &target) { +std::vector InferDtypeForRelu(const std::vector &inputs_type, const framework::AttrMapType &attrs) { CHECK(!inputs_type.empty()) << "The input's type size is 0! Please check again."; std::vector res{inputs_type[0]}; return res; @@ -333,24 +330,23 @@ std::shared_ptr StrategyForConv2d(const framework::NodeAttr &attrs, } std::vector InferShapeForConv2d(const std::vector &inputs_shape, - framework::NodeAttr &attrs, - const Target &target) { + const framework::AttrMapType &attrs) { CHECK(!inputs_shape.empty() && !inputs_shape[0].empty()) << "The input's shape size is 0! Please check again."; std::vector padding({0, 0}); std::vector stride({1, 1}); std::vector dilation({1, 1}); std::string data_format = "NCHW"; - if (attrs.attr_store.find("padding") != attrs.attr_store.end()) { - padding = absl::get>(attrs.attr_store.at("padding")); + if (attrs.find("padding") != attrs.end()) { + padding = absl::get>(attrs.at("padding")); } - if (attrs.attr_store.find("stride") != attrs.attr_store.end()) { - stride = absl::get>(attrs.attr_store.at("stride")); + if (attrs.find("stride") != attrs.end()) { + stride = absl::get>(attrs.at("stride")); } - if (attrs.attr_store.find("dilation") != attrs.attr_store.end()) { - dilation = absl::get>(attrs.attr_store.at("dilation")); + if (attrs.find("dilation") != attrs.end()) { + dilation = absl::get>(attrs.at("dilation")); } - if (attrs.attr_store.find("data_format") != attrs.attr_store.end()) { - data_format = absl::get(attrs.attr_store.at("data_format")); + if (attrs.find("data_format") != attrs.end()) { + data_format = absl::get(attrs.at("data_format")); } CHECK_EQ(padding.size(), 2) << "The size of padding in conv2d op is not 2! Please check."; CHECK_EQ(stride.size(), 2) << "The size of stride in conv2d op is not 2! Please check."; @@ -378,7 +374,6 @@ std::vector InferShapeForConv2d(const std::vector &inputs_shap int pad_w = padding[1]; std::string key = pe::GenerateX86ConvKey(inputs_shape[0], inputs_shape[1], stride, padding, dilation); VLOG(3) << "key: " << key; - attrs.attr_store["key"] = key; pe::GetConv2dFactors(&conv2d_factors, oc, ic, fc, -1, -1, Float(32), common::DefaultHostTarget(), key); int ic_bn = conv2d_factors["ic_bn"]; int oc_bn = conv2d_factors["oc_bn"]; @@ -413,9 +408,7 @@ std::vector InferShapeForConv2d(const std::vector &inputs_shap return res; } -std::vector InferDtypeForConv2d(const std::vector &inputs_type, - const framework::NodeAttr &attrs, - const Target &target) { +std::vector InferDtypeForConv2d(const std::vector &inputs_type, const framework::AttrMapType &attrs) { CHECK(!inputs_type.empty()) << "The input's type size is 0! Please check again."; #ifdef CINN_WITH_CUDA std::vector res{inputs_type[0]}; @@ -552,24 +545,23 @@ std::shared_ptr StrategyForConv2dNCHWc(const framework::NodeAttr &at } std::vector InferShapeForConv2dNCHWc(const std::vector &inputs_shape, - framework::NodeAttr &attrs, - const Target &target) { + const framework::AttrMapType &attrs) { CHECK(!inputs_shape.empty() && !inputs_shape[0].empty()) << "The input's shape size is 0! Please check again."; std::vector padding({0, 0}); std::vector stride({1, 1}); std::vector dilation({1, 1}); std::string data_format = "NCHWc"; - if (attrs.attr_store.find("padding") != attrs.attr_store.end()) { - padding = absl::get>(attrs.attr_store.at("padding")); + if (attrs.find("padding") != attrs.end()) { + padding = absl::get>(attrs.at("padding")); } - if (attrs.attr_store.find("stride") != attrs.attr_store.end()) { - stride = absl::get>(attrs.attr_store.at("stride")); + if (attrs.find("stride") != attrs.end()) { + stride = absl::get>(attrs.at("stride")); } - if (attrs.attr_store.find("dilation") != attrs.attr_store.end()) { - dilation = absl::get>(attrs.attr_store.at("dilation")); + if (attrs.find("dilation") != attrs.end()) { + dilation = absl::get>(attrs.at("dilation")); } - if (attrs.attr_store.find("data_format") != attrs.attr_store.end()) { - data_format = absl::get(attrs.attr_store.at("data_format")); + if (attrs.find("data_format") != attrs.end()) { + data_format = absl::get(attrs.at("data_format")); } CHECK_EQ(padding.size(), 2) << "The size of padding in conv2d_NCHWc op is not 2! Please check."; CHECK_EQ(stride.size(), 2) << "The size of stride in conv2d_NCHWc op is not 2! Please check."; @@ -618,9 +610,7 @@ std::vector> InferLayoutForConv2dNCHWc(const std::vecto return {{outlayout, outlayout, input_layouts[0]}, input_layouts}; } -std::vector InferDtypeForConv2dNCHWc(const std::vector &inputs_type, - const framework::NodeAttr &attrs, - const Target &target) { +std::vector InferDtypeForConv2dNCHWc(const std::vector &inputs_type, const framework::AttrMapType &attrs) { CHECK(!inputs_type.empty()) << "The input's type size is 0! Please check again."; std::vector res{inputs_type[0], inputs_type[0], inputs_type[0]}; return res; @@ -773,22 +763,21 @@ std::shared_ptr StrategyForDepthwiseConv2d(const framework::NodeAttr } std::vector InferShapeForDepthwiseConv2d(const std::vector &inputs_shape, - framework::NodeAttr &attrs, - const Target &target) { + const framework::AttrMapType &attrs) { CHECK_EQ(inputs_shape.size(), 2U) << "at least 2 input tensors for depthwise_conv2d op\n"; CHECK_EQ(inputs_shape[0].size(), 4U) << "The input tensor's shape should be 4! Please check again."; CHECK_EQ(inputs_shape[1].size(), 4U) << "The input tensor's shape should be 4! Please check again."; std::vector padding = {0, 0}; std::vector stride = {1, 1}; std::string data_format = "NCHW"; - if (attrs.attr_store.find("padding") != attrs.attr_store.end()) { - padding = absl::get>(attrs.attr_store.at("padding")); + if (attrs.find("padding") != attrs.end()) { + padding = absl::get>(attrs.at("padding")); } - if (attrs.attr_store.find("stride") != attrs.attr_store.end()) { - stride = absl::get>(attrs.attr_store.at("stride")); + if (attrs.find("stride") != attrs.end()) { + stride = absl::get>(attrs.at("stride")); } - if (attrs.attr_store.find("data_format") != attrs.attr_store.end()) { - data_format = absl::get(attrs.attr_store.at("data_format")); + if (attrs.find("data_format") != attrs.end()) { + data_format = absl::get(attrs.at("data_format")); } std::vector res; CHECK_EQ(padding.size(), 2U) << "The size of padding in depthwise_conv2d op is not 2! Please check."; @@ -810,8 +799,7 @@ std::vector InferShapeForDepthwiseConv2d(const std::vector &in } std::vector InferDtypeForDepthwiseConv2d(const std::vector &inputs_type, - const framework::NodeAttr &attrs, - const Target &target) { + const framework::AttrMapType &attrs) { CHECK(!inputs_type.empty()) << "The input's type size is 0! Please check again."; std::vector res{inputs_type[0]}; return res; @@ -899,16 +887,13 @@ std::shared_ptr StrategyForBatchNorm(const framework::NodeAttr &attr } std::vector InferShapeForBatchNorm(const std::vector &inputs_shape, - framework::NodeAttr &attrs, - const Target &target) { + const framework::AttrMapType &attrs) { CHECK(!inputs_shape.empty() && !inputs_shape[0].empty()) << "The input's shape size is 0! Please check again."; std::vector res{inputs_shape[0]}; return res; } -std::vector InferDtypeForBatchNorm(const std::vector &inputs_type, - const framework::NodeAttr &attrs, - const Target &target) { +std::vector InferDtypeForBatchNorm(const std::vector &inputs_type, const framework::AttrMapType &attrs) { CHECK(!inputs_type.empty()) << "The input's type size is 0! Please check again."; std::vector res{inputs_type[0]}; return res; @@ -1015,10 +1000,8 @@ std::shared_ptr StrategyForPool1d(const framework::NodeAttr &attrs, } std::vector> InferShapeForPool1d(const std::vector> &inputs_shape, - framework::NodeAttr &attrs, - const Target &target) { + const framework::AttrMapType &attrs) { CHECK(!inputs_shape.empty() && !inputs_shape[0].empty()) << "The input's shape size is 0! Please check again."; - auto attr_store = attrs.attr_store; std::vector kernel_size; // [kernel_w] std::vector stride_size; // [stride_w] std::vector padding_size; // [padding_left, padding_right] @@ -1026,7 +1009,7 @@ std::vector> InferShapeForPool1d(const std::vector>(iter.second); } else if (iter.first == "stride_size") { @@ -1192,11 +1175,9 @@ std::shared_ptr StrategyForPool2d(const framework::NodeAttr &attrs, } std::vector> InferShapeForPool2d(const std::vector> &inputs_shape, - framework::NodeAttr &attrs, - const Target &target) { + const framework::AttrMapType &attrs) { CHECK(inputs_shape[0].size() == 4 || inputs_shape[0].size() == 5) << "The input's shape size of pool2d should be 4 or 5! Please check again."; - auto attr_store = attrs.attr_store; std::vector kernel_size; std::vector stride_size; std::vector padding_size; @@ -1206,7 +1187,7 @@ std::vector> InferShapeForPool2d(const std::vector>(iter.second); } else if (iter.first == "stride_size") { @@ -1267,7 +1248,7 @@ std::vector> InferShapeForPool2d(const std::vector>(attr_store["kernel_size"]); + kernel_size = absl::get>(attrs.at("kernel_size")); if (kernel_size.size() == 1UL) kernel_size.push_back(kernel_size[0]); CHECK(kernel_size.size() >= 2UL) << "In pool2d, kernel_size's size should be >= 2, please check!"; output_shape1[height_axis] = kernel_size[0]; @@ -1369,10 +1350,8 @@ std::shared_ptr StrategyForPool3d(const framework::NodeAttr &attrs, } std::vector> InferShapeForPool3d(const std::vector> &inputs_shape, - framework::NodeAttr &attrs, - const Target &target) { + const framework::AttrMapType &attrs) { CHECK(!inputs_shape.empty() && !inputs_shape[0].empty()) << "The input's shape size is 0! Please check again."; - auto attr_store = attrs.attr_store; std::vector kernel_size; // [kernel_d, kernel_h, kernel_w] std::vector stride_size; // [stride_d, stride_h, stride_w] std::vector @@ -1381,7 +1360,7 @@ std::vector> InferShapeForPool3d(const std::vector>(iter.second); } else if (iter.first == "stride_size") { @@ -1443,9 +1422,7 @@ std::vector> InferShapeForPool3d(const std::vector InferDtypeForPool(const std::vector &inputs_type, - const framework::NodeAttr &attrs, - const Target &target) { +std::vector InferDtypeForPool(const std::vector &inputs_type, const framework::AttrMapType &attrs) { CHECK(!inputs_type.empty()) << "The input's type size is 0! Please check again."; std::vector res{inputs_type[0]}; return res; @@ -1533,16 +1510,13 @@ std::shared_ptr StrategyForSoftmax(const framework::NodeAttr &attrs, } std::vector> InferShapeForSoftmax(const std::vector> &inputs_shape, - framework::NodeAttr &attrs, - const Target &target) { + const framework::AttrMapType &attrs) { CHECK(!inputs_shape.empty() && !inputs_shape[0].empty()) << "The input's shape size is 0! Please check again."; std::vector> res{inputs_shape[0], inputs_shape[0]}; return res; } -std::vector InferDtypeForSoftmax(const std::vector &inputs_type, - const framework::NodeAttr &attrs, - const Target &target) { +std::vector InferDtypeForSoftmax(const std::vector &inputs_type, const framework::AttrMapType &attrs) { CHECK(!inputs_type.empty()) << "The input's type size is 0! Please check again."; std::vector res{inputs_type[0], inputs_type[0]}; return res; @@ -1628,13 +1602,12 @@ std::shared_ptr StrategyForSlice(const framework::NodeAttr &attrs, } std::vector> InferShapeForSlice(const std::vector> &inputs_shape, - framework::NodeAttr &attrs, - const Target &target) { + const framework::AttrMapType &attrs) { CHECK(!inputs_shape.empty() && !inputs_shape[0].empty()) << "The input's shape size is 0! Please check again."; std::vector starts; std::vector ends; std::vector axes; - for (auto &iter : attrs.attr_store) { + for (auto &iter : attrs) { if (iter.first == "starts") { starts = absl::get>(iter.second); } else if (iter.first == "ends") { @@ -1675,9 +1648,7 @@ std::vector> InferShapeForSlice(const std::vector InferDtypeForSlice(const std::vector &inputs_type, - const framework::NodeAttr &attrs, - const Target &target) { +std::vector InferDtypeForSlice(const std::vector &inputs_type, const framework::AttrMapType &attrs) { CHECK(!inputs_type.empty()) << "The input's type size is 0! Please check again."; std::vector res{inputs_type[0]}; return res; @@ -1767,12 +1738,11 @@ std::shared_ptr StrategyForDropoutInfer(const framework::NodeAttr &a } std::vector> InferShapeForDropoutInfer(const std::vector> &inputs_shape, - framework::NodeAttr &attrs, - const Target &target) { + const framework::AttrMapType &attrs) { CHECK(!inputs_shape.empty() && !inputs_shape[0].empty()) << "The input's shape size is 0! Please check again."; float dropout_prob = 0; std::string dropout_implementation = "downgrade_in_infer"; - for (auto &iter : attrs.attr_store) { + for (auto &iter : attrs) { if (iter.first == "dropout_prob") { dropout_prob = absl::get(iter.second); } else if (iter.first == "dropout_implementation") { @@ -1786,9 +1756,7 @@ std::vector> InferShapeForDropoutInfer(const std::vector InferDtypeForDropoutInfer(const std::vector &inputs_type, - const framework::NodeAttr &attrs, - const Target &target) { +std::vector InferDtypeForDropoutInfer(const std::vector &inputs_type, const framework::AttrMapType &attrs) { CHECK(!inputs_type.empty()) << "The input's type size is 0! Please check again."; std::vector res{inputs_type[0]}; return res; @@ -1842,8 +1810,7 @@ std::shared_ptr StrategyForSelect(const framework::NodeAttr &attrs, } std::vector InferShapeForSelect(const std::vector &inputs_shape, - framework::NodeAttr &attrs, - const Target &target) { + const framework::AttrMapType &attrs) { CHECK_GE(inputs_shape.size(), 3) << "The input's shape size is 0! Please check again."; CHECK(inputs_shape[0].size() == inputs_shape[1].size() && inputs_shape[1].size() == inputs_shape[2].size()) << "input tensors n_dim is not equal!"; @@ -1853,9 +1820,7 @@ std::vector InferShapeForSelect(const std::vector InferDtypeForSelect(const std::vector &inputs_type, - const framework::NodeAttr &attrs, - const Target &target) { +std::vector InferDtypeForSelect(const std::vector &inputs_type, const framework::AttrMapType &attrs) { CHECK_GE(inputs_type.size(), 3) << "The input's type size is less than three! Please check again."; CHECK(inputs_type[0].is_bool()) << "The condition tensor type should be bool"; std::vector res{inputs_type[1]}; diff --git a/cinn/hlir/op/op_nn_test.cc b/cinn/hlir/op/op_nn_test.cc index 6bbb692c99859..73aa78e8b78a7 100644 --- a/cinn/hlir/op/op_nn_test.cc +++ b/cinn/hlir/op/op_nn_test.cc @@ -321,7 +321,7 @@ TEST(Operator, Operator_Select_Test0) { const common::Target target = common::DefaultHostTarget(); const std::vector input_shapes = {{16, 64, 64}, {16, 64, 64}, {16, 64, 64}}; - auto infer_shape = infer_shape_func(input_shapes, attrs, target); + auto infer_shape = infer_shape_func(input_shapes, attrs.attr_store); ASSERT_EQ(infer_shape[0][0], 16); ASSERT_EQ(infer_shape[0][1], 64); ASSERT_EQ(infer_shape[0][2], 64); diff --git a/cinn/hlir/op/reduction.cc b/cinn/hlir/op/reduction.cc index 5eff11010bf0e..12984fe5310db 100644 --- a/cinn/hlir/op/reduction.cc +++ b/cinn/hlir/op/reduction.cc @@ -85,24 +85,21 @@ std::shared_ptr StrategyForReduce(const framework::NodeAttr &attrs, } std::vector InferShapeForReduction(const std::vector &inputs_shape, - framework::NodeAttr &attrs, - const Target &target) { + const framework::AttrMapType &attrs) { CHECK_EQ(inputs_shape.size(), 1UL); std::vector dim; bool keep_dim = false; - if (attrs.attr_store.find("dim") != attrs.attr_store.end()) { - dim = absl::get>(attrs.attr_store.at("dim")); + if (attrs.find("dim") != attrs.end()) { + dim = absl::get>(attrs.at("dim")); } - if (attrs.attr_store.find("keep_dim") != attrs.attr_store.end()) { - keep_dim = absl::get(attrs.attr_store.at("keep_dim")); + if (attrs.find("keep_dim") != attrs.end()) { + keep_dim = absl::get(attrs.at("keep_dim")); } std::vector res{inputs_shape[0]}; return res; } -std::vector InferDtypeForReduction(const std::vector &inputs_type, - const framework::NodeAttr &attrs, - const Target &target) { +std::vector InferDtypeForReduction(const std::vector &inputs_type, const framework::AttrMapType &attrs) { CHECK(!inputs_type.empty()) << "The input's type size is 0! Please check again."; std::vector res{inputs_type[0]}; return res; @@ -134,9 +131,9 @@ CINN_REGISTER_HELPER(reduce_ops) { .set_num_inputs(1) \ .set_num_outputs(1) \ .set_attr("CINNStrategy", cinn::hlir::op::StrategyFor##op_stragegy__) \ - .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForReduction)) \ - .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForReduction)) \ - .set_attr("inferlayout", MakeOpFunction(cinn::hlir::op::InferLayoutForReduction)) \ + .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForReduction)) \ + .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForReduction)) \ + .set_attr("inferlayout", MakeOpFunction(cinn::hlir::op::InferLayoutForReduction)) \ .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kCommReduce) \ .set_support_level(4); diff --git a/cinn/hlir/op/transform.cc b/cinn/hlir/op/transform.cc index 94039bed36d78..ec58226a00a40 100644 --- a/cinn/hlir/op/transform.cc +++ b/cinn/hlir/op/transform.cc @@ -212,8 +212,7 @@ std::shared_ptr StrategyForMatMul(const framework::NodeAttr &attrs, } std::vector> InferShapeForMatMul(const std::vector> &inputs_shape, - framework::NodeAttr &attrs, - const Target &target) { + const framework::AttrMapType &attrs) { CHECK_EQ(inputs_shape.size(), 2U) << "The input's shape size should be 2! Please check again."; std::vector output_shape; std::vector new_shape_A; @@ -221,7 +220,7 @@ std::vector> InferShapeForMatMul(const std::vector(iter.second); } else if (iter.first == "trans_b") { @@ -250,9 +249,7 @@ std::vector> InferShapeForMatMul(const std::vector InferDtypeForMatMul(const std::vector &inputs_type, - const framework::NodeAttr &attrs, - const Target &target) { +std::vector InferDtypeForMatMul(const std::vector &inputs_type, const framework::AttrMapType &attrs) { CHECK(!inputs_type.empty()) << "The input's type size is 0! Please check again."; std::vector res{inputs_type[0], inputs_type[0], inputs_type[0]}; return res; @@ -324,18 +321,19 @@ std::shared_ptr StrategyForReshape(const framework::NodeAttr &attrs, } std::vector> InferShapeForReshape(const std::vector> &inputs_shape, - framework::NodeAttr &attrs, - const Target &target) { + const framework::AttrMapType &attrs) { CHECK_EQ(inputs_shape.size(), 1U) << "The input's shape size should be 1! Please check again."; std::vector output_shape; - for (auto &iter : attrs.attr_store) { + for (auto &iter : attrs) { if (iter.first == "shape") { output_shape = absl::get>(iter.second); break; } } int tensor_size = 1; - for (auto i : inputs_shape[0]) tensor_size *= i; + for (auto i : inputs_shape[0]) { + tensor_size *= i; + } CHECK(!output_shape.empty()) << "infer_shape for reshape turns out to be empty. Please check\n"; int flag_index = -1; for (int i = 0; i < output_shape.size(); i++) { @@ -364,9 +362,7 @@ std::vector> InferShapeForReshape(const std::vector InferDtypeForReshape(const std::vector &inputs_type, - const framework::NodeAttr &attrs, - const Target &target) { +std::vector InferDtypeForReshape(const std::vector &inputs_type, const framework::AttrMapType &attrs) { CHECK(!inputs_type.empty()) << "The input's type size is 0! Please check again."; std::vector res{inputs_type[0]}; return res; @@ -447,11 +443,10 @@ std::shared_ptr StrategyForConcat(const framework::NodeAttr &attrs, } std::vector> InferShapeForConcat(const std::vector> &inputs_shape, - framework::NodeAttr &attrs, - const Target &target) { + const framework::AttrMapType &attrs) { CHECK_EQ(inputs_shape.size(), 2U) << "The input's shape size should be 2! Please check again."; int axis = 0; - for (auto &iter : attrs.attr_store) { + for (auto &iter : attrs) { if (iter.first == "axis") { axis = absl::get(iter.second); break; @@ -468,9 +463,7 @@ std::vector> InferShapeForConcat(const std::vector InferDtypeForConcat(const std::vector &inputs_type, - const framework::NodeAttr &attrs, - const Target &target) { +std::vector InferDtypeForConcat(const std::vector &inputs_type, const framework::AttrMapType &attrs) { CHECK(!inputs_type.empty()) << "The input's type size is 0! Please check again."; std::vector res{inputs_type[0]}; return res; @@ -689,8 +682,7 @@ std::shared_ptr StrategyForMulBias(const framework::NodeAttr &attrs, } std::vector> InferShapeForMul(const std::vector> &inputs_shape, - framework::NodeAttr &attrs, - const Target &target) { + const framework::AttrMapType &attrs) { // CHECK_EQ(inputs_shape.size(), 2U) << "The input's shape size should be 2! Please check again."; CHECK_GE(inputs_shape[0].size(), 2U) << "Input matrix X's dim should be >= 2! Please check."; CHECK_GE(inputs_shape[1].size(), 2U) << "Input matrix Y's dim should be >= 2! Please check."; @@ -698,7 +690,7 @@ std::vector> InferShapeForMul(const std::vector output_shape; int x_num_col_dims = 1; int y_num_col_dims = 1; - for (auto &iter : attrs.attr_store) { + for (auto &iter : attrs) { if (iter.first == "x_num_col_dims") { x_num_col_dims = absl::get(iter.second); } else if (iter.first == "y_num_col_dims") { @@ -738,9 +730,7 @@ std::vector> InferShapeForMul(const std::vector InferDtypeForMul(const std::vector &inputs_type, - const framework::NodeAttr &attrs, - const Target &target) { +std::vector InferDtypeForMul(const std::vector &inputs_type, const framework::AttrMapType &attrs) { CHECK(!inputs_type.empty()) << "The input's type size is 0! Please check again."; std::vector res{inputs_type[0], inputs_type[0]}; return res; @@ -764,8 +754,7 @@ std::vector> InferLayoutForMul(const std::vector> InferShapeForMulBias(const std::vector> &inputs_shape, - framework::NodeAttr &attrs, - const Target &target) { + const framework::AttrMapType &attrs) { // CHECK_EQ(inputs_shape.size(), 2U) << "The input's shape size should be 2! Please check again."; CHECK_GE(inputs_shape[0].size(), 2U) << "Input matrix X's dim should be >= 2! Please check."; CHECK_GE(inputs_shape[1].size(), 2U) << "Input matrix Y's dim should be >= 2! Please check."; @@ -773,7 +762,7 @@ std::vector> InferShapeForMulBias(const std::vector output_shape; int x_num_col_dims = 1; int y_num_col_dims = 1; - for (auto &iter : attrs.attr_store) { + for (auto &iter : attrs) { if (iter.first == "x_num_col_dims") { x_num_col_dims = absl::get(iter.second); } else if (iter.first == "y_num_col_dims") { @@ -807,9 +796,7 @@ std::vector> InferShapeForMulBias(const std::vector InferDtypeForMulBias(const std::vector &inputs_type, - const framework::NodeAttr &attrs, - const Target &target) { +std::vector InferDtypeForMulBias(const std::vector &inputs_type, const framework::AttrMapType &attrs) { CHECK(!inputs_type.empty()) << "The input's type size is 0! Please check again."; std::vector res{inputs_type[0], inputs_type[0]}; return res; @@ -872,15 +859,14 @@ std::shared_ptr StrategyForLayoutTransform(const framework::NodeAttr } std::vector InferShapeForLayoutTransform(const std::vector &inputs_shape, - framework::NodeAttr &attrs, - const Target &target) { + const framework::AttrMapType &attrs) { std::string src_layout; std::string dst_layout; - if (attrs.attr_store.find("src_layout") != attrs.attr_store.end()) { - src_layout = absl::get(attrs.attr_store.at("src_layout")); + if (attrs.find("src_layout") != attrs.end()) { + src_layout = absl::get(attrs.at("src_layout")); } - if (attrs.attr_store.find("dst_layout") != attrs.attr_store.end()) { - dst_layout = absl::get(attrs.attr_store.at("dst_layout")); + if (attrs.find("dst_layout") != attrs.end()) { + dst_layout = absl::get(attrs.at("dst_layout")); } CHECK_EQ(inputs_shape.size(), 1UL); @@ -900,8 +886,7 @@ std::vector InferShapeForLayoutTransform(const std::vector &in } std::vector InferDtypeForLayoutTransform(const std::vector &inputs_type, - const framework::NodeAttr &attrs, - const Target &target) { + const framework::AttrMapType &attrs) { CHECK(!inputs_type.empty()) << "The input's type size is 0! Please check again."; std::vector res{inputs_type[0]}; return res; @@ -968,12 +953,11 @@ std::shared_ptr StrategyForReverse(const framework::NodeAttr &attrs, } std::vector InferShapeForReverse(const std::vector &inputs_shape, - framework::NodeAttr &attrs, - const Target &target) { + const framework::AttrMapType &attrs) { CHECK(!inputs_shape.empty() && !inputs_shape[0].empty()) << "The input's shape size is 0! Please check again."; std::vector res{inputs_shape[0]}; - if (attrs.attr_store.find("axis") != attrs.attr_store.end()) { - auto axis = absl::get>(attrs.attr_store.at("axis")); + if (attrs.find("axis") != attrs.end()) { + auto axis = absl::get>(attrs.at("axis")); CHECK(!axis.empty()) << "axis is empty! Please check setting.\n"; for (auto &e : axis) { if (e >= inputs_shape[0].size() || e < -1 * static_cast(inputs_shape[0].size())) { diff --git a/cinn/hlir/pass/alterlayout.cc b/cinn/hlir/pass/alterlayout.cc index bd6f8e02147c8..5fc9020ce6a31 100644 --- a/cinn/hlir/pass/alterlayout.cc +++ b/cinn/hlir/pass/alterlayout.cc @@ -19,10 +19,9 @@ using framework::NodeData; using framework::Operator; using framework::OpValueType; -using InferShapeFunc = std::function( - const std::vector&, framework::NodeAttr&, const Target&)>; -using InferTypeFunc = - std::function(const std::vector&, const framework::NodeAttr&, const Target&)>; +using InferShapeFunc = std::function(const std::vector&, + const framework::AttrMapType&)>; +using InferTypeFunc = std::function(const std::vector&, const framework::AttrMapType&)>; using InferLayoutFunc = std::function>(const std::vector&, const std::vector&, const framework::NodeAttr&, @@ -68,7 +67,7 @@ std::tuple InsertLayoutTransformNodeBefore(Graph* graph, return std::make_tuple(trans_node, temp_outdata); } -std::vector updateInferInfos(Node* node, +std::vector UpdateInferInfos(Node* node, const std::vector& input_shapes, const std::vector& input_types, const std::vector& input_layouts, @@ -85,15 +84,14 @@ std::vector updateInferInfos(Node* node, CHECK(op_infershape[node->op()]) << "find no InferShape function for op " << node->op()->name; CHECK(op_infertype[node->op()]) << "find no InferDtype function for op " << node->op()->name; CHECK(op_inferlayout[node->op()]) << "find no InferLayout function for op " << node->op()->name; - auto infershapes = op_infershape[node->op()](input_shapes, node->attrs, target); - auto infertypes = op_infertype[node->op()](input_types, node->attrs, target); + auto infershapes = op_infershape[node->op()](input_shapes, node->attrs.attr_store); + auto infertypes = op_infertype[node->op()](input_types, node->attrs.attr_store); auto inferlayouts = op_inferlayout[node->op()](input_shapes, input_layouts, node->attrs, target); CHECK(!infershapes.empty()) << node->op()->name << " finds no infershape"; CHECK(!infertypes.empty()) << node->op()->name << " finds no infertype"; CHECK(!inferlayouts.empty()) << node->op()->name << " finds no inferlayout"; auto outlinks = node->outlinks_in_order(true); - // check opt CHECK_EQ(infershapes.size(), infertypes.size()); CHECK_EQ(inferlayouts.size(), 2U); CHECK_EQ(infertypes.size(), inferlayouts[0].size()); @@ -122,6 +120,35 @@ void AlterLayoutPass(Graph* graph) { auto& op_inferdtype = Operator::GetAttrs("inferdtype"); auto& op_inferlayout = Operator::GetAttrs("inferlayout"); absl::flat_hash_map layout_dict; + // collect all convs' original input config before altering layout for loading tune params afterwards + for (int i = 0; i < store_nodes.size(); i++) { + auto node = store_nodes[i]->safe_as(); + if (node && node->op()->name == "conv2d") { + std::vector padding({0, 0}); + std::vector stride({1, 1}); + std::vector dilation({1, 1}); + if (node->attrs.attr_store.find("padding") != node->attrs.attr_store.end()) { + padding = absl::get>(node->attrs.attr_store.at("padding")); + } + if (node->attrs.attr_store.find("stride") != node->attrs.attr_store.end()) { + stride = absl::get>(node->attrs.attr_store.at("stride")); + } + if (node->attrs.attr_store.find("dilation") != node->attrs.attr_store.end()) { + dilation = absl::get>(node->attrs.attr_store.at("dilation")); + } + auto& conv_inlinks = node->inlinks_in_order(true); + CHECK_EQ(conv_inlinks.size(), 2U) << "conv2d should have 2 inputs"; + std::vector> inputs_shape; + for (auto& link : conv_inlinks) { + auto* source = link->source(); + CHECK(shape_dict.count(source->id())) << source->id() << " finds no infershape"; + inputs_shape.push_back(shape_dict.at(source->id())); + } + std::string key = pe::GenerateX86ConvKey(inputs_shape[0], inputs_shape[1], stride, padding, dilation); + VLOG(3) << "key: " << key; + node->attrs.attr_store["key"] = key; + } + } bool has_altered = false; for (int i = 0; i < store_nodes.size(); i++) { @@ -202,7 +229,7 @@ void AlterLayoutPass(Graph* graph) { src_input_layout, dst_input_layout, common::UniqName(node->op()->name + "_input_layout_tranform")); - updateInferInfos(input_trans_node, + UpdateInferInfos(input_trans_node, {input_shape}, {input_type}, {src_input_layout}, @@ -244,7 +271,7 @@ void AlterLayoutPass(Graph* graph) { dst_kernel_layout, common::UniqName(node->op()->name + "_weight_layout_tranform")); weight_trans_node->attrs.attr_store["pre_run"] = true; - updateInferInfos(weight_trans_node, + UpdateInferInfos(weight_trans_node, {weight_shape}, {weight_type}, {src_kernel_layout}, @@ -270,7 +297,7 @@ void AlterLayoutPass(Graph* graph) { conv2d_NCHWc_inputlayouts.push_back(layout_dict[weight_node->id()]); } // replace conv2d to conv2d_NCHWc - auto infershapes = op_infershape[new_node->op()](conv2d_NCHWc_inputshapes, new_node->attrs, graph->target_); + auto infershapes = op_infershape[new_node->op()](conv2d_NCHWc_inputshapes, new_node->attrs.attr_store); auto& old_inlinks = node->inlinks_in_order(true); auto& old_outlinks = node->outlinks_in_order(true); for (auto& link : old_inlinks) { @@ -301,7 +328,7 @@ void AlterLayoutPass(Graph* graph) { } graph->RegisterNode(new_node->id(), new_node); // update conv2d_NCHWc's infershape, infertype, inferlayout and set attrs - updateInferInfos(new_node, + UpdateInferInfos(new_node, conv2d_NCHWc_inputshapes, conv2d_NCHWc_inputtypes, conv2d_NCHWc_inputlayouts, @@ -379,7 +406,7 @@ void AlterLayoutPass(Graph* graph) { src_layout, new_input_layouts[i], common::UniqName(source->id() + "_layout_tranform")); - updateInferInfos(trans_node, + UpdateInferInfos(trans_node, {new_shapes}, {input_types[i]}, {src_layout}, @@ -409,7 +436,7 @@ void AlterLayoutPass(Graph* graph) { src_layout, new_input_layouts[i], common::UniqName(source->id() + "_layout_tranform")); - updateInferInfos(trans_node, + UpdateInferInfos(trans_node, {input_shapes[i]}, {input_types[i]}, {src_layout}, @@ -439,7 +466,7 @@ void AlterLayoutPass(Graph* graph) { src_layout, new_input_layouts[i], common::UniqName(source->id() + "_layout_tranform")); - updateInferInfos(trans_node, + UpdateInferInfos(trans_node, {input_shapes[i]}, {input_types[i]}, {src_layout}, @@ -471,7 +498,7 @@ void AlterLayoutPass(Graph* graph) { input_layouts.push_back(""); } } - updateInferInfos(node, + UpdateInferInfos(node, input_shapes, input_types, input_layouts, @@ -521,7 +548,7 @@ void AlterLayoutPass(Graph* graph) { shape_dict[temp_out->id()] = shape; type_dict[temp_out->id()] = type; layout_dict[temp_out->id()] = src_layout; - updateInferInfos(trans_node, + UpdateInferInfos(trans_node, {shape}, {type}, {src_layout}, diff --git a/cinn/hlir/pass/infershape.cc b/cinn/hlir/pass/infershape.cc old mode 100755 new mode 100644 index cd137abd031df..8adb772d68f78 --- a/cinn/hlir/pass/infershape.cc +++ b/cinn/hlir/pass/infershape.cc @@ -21,10 +21,10 @@ void InferShapePass(Graph* graph) { auto& dtype_dict = graph->GetMutableAttrs>("inferdtype"); auto store_nodes = std::get<0>(graph->topological_order()); auto& op_infershape = Operator::GetAttrs( - const std::vector&, framework::NodeAttr&, const Target&)>>("infershape"); - auto& op_inferdtype = Operator::GetAttrs< - std::function(const std::vector&, const framework::NodeAttr&, const Target&)>>( - "inferdtype"); + const std::vector&, const framework::AttrMapType&)>>("infershape"); + auto& op_inferdtype = + Operator::GetAttrs(const std::vector&, const framework::AttrMapType&)>>( + "inferdtype"); for (auto& n : store_nodes) { auto node = n->safe_as(); @@ -41,9 +41,9 @@ void InferShapePass(Graph* graph) { } auto out_shape = - op_infershape[node->safe_as()->op()](inputs_shape, node->safe_as()->attrs, graph->target_); + op_infershape[node->safe_as()->op()](inputs_shape, node->safe_as()->attrs.attr_store); auto out_dtype = - op_inferdtype[node->safe_as()->op()](inputs_dtype, node->safe_as()->attrs, graph->target_); + op_inferdtype[node->safe_as()->op()](inputs_dtype, node->safe_as()->attrs.attr_store); CHECK_GE(node->outlinks_in_order().size(), out_shape.size()) << "The output number of node " << node->id() << " is " << node->outlinks_in_order().size() @@ -66,8 +66,6 @@ void InferShapePass(Graph* graph) { } } } - graph->attrs["infershape"] = std::make_shared(shape_dict); - graph->attrs["inferdtype"] = std::make_shared(dtype_dict); } } // namespace pass diff --git a/cinn/pybind/framework.cc b/cinn/pybind/framework.cc index 27d28a77a959f..63fe670060ef4 100644 --- a/cinn/pybind/framework.cc +++ b/cinn/pybind/framework.cc @@ -3,7 +3,6 @@ #include #include -#include "cinn/pybind/bind.h" #include "cinn/common/cinn_value.h" #include "cinn/frontend/interpreter.h" #include "cinn/hlir/framework/node.h" @@ -11,6 +10,7 @@ #include "cinn/hlir/framework/op_strategy.h" #include "cinn/hlir/framework/scope.h" #include "cinn/hlir/op/use_ops.h" +#include "cinn/pybind/bind.h" namespace cinn::pybind { @@ -59,10 +59,9 @@ void BindFramework(pybind11::module *m) { [](OpValueType &self, const std::string &key, const std::vector> &input_shapes, - NodeAttr &attrs, - const Target &target) { + const AttrMapType &attrs) { const Operator *op_ptr = Operator::Get(key); - auto shapes = self[op_ptr](input_shapes, attrs, target); + auto shapes = self[op_ptr](input_shapes, attrs); return shapes; }); diff --git a/python/tests/test_utils.py b/python/tests/test_utils.py index 322dfdcb66519..99290bfd33cf4 100644 --- a/python/tests/test_utils.py +++ b/python/tests/test_utils.py @@ -102,8 +102,8 @@ def to_test_op(self, runtime.cinn_x86_device, alignment)) if do_infer_shape: infer_shapes = framework.Operator.get_op_shape_attrs("infershape") - out_shapes = infer_shapes.infer_shape(op_name, input_shapes, attrs, - self.target) + out_shapes = infer_shapes.infer_shape(op_name, input_shapes, + attrs.attr_store) print("out_shapes", out_shapes) for out_shape in out_shapes[1:]: out.append(