Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Tricky Branch #1469

Open
wants to merge 3 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cinn/auto_schedule/cost_model/feature_extractor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ VisitForMultiOperandsDtypePattern(Product, mul);
VisitCountMemberPattern(And, bool_op);
VisitCountMemberPattern(Or, bool_op);
VisitCountMemberPattern(Not, bool_op);
VisitCountMemberPattern(GetReference, mem_read);
VisitCountMemberPattern(Max, select_op);
VisitCountMemberPattern(Min, select_op);
VisitCountMemberPattern(IfThenElse, select_op);
Expand Down
4 changes: 4 additions & 0 deletions cinn/backends/codegen_c.cc
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,10 @@ std::string CodeGenC::GetTypeRepr(Type type) {
str += "*";
} else if (type.is_cpp_handle2()) {
str += "**";
} else if (type.is_cpp_reference()) {
str += "&";
}

return str;
}
// Integer immediates need no C-specific handling; delegate to the generic IR printer.
void CodeGenC::Visit(const ir::IntImm *op) { IrPrinter::Visit(op); }
Expand Down Expand Up @@ -186,6 +189,7 @@ void CodeGenC::Visit(const ir::Not *op) {
IrPrinter::Print(op->v());
os() << ")";
}
// Address-of (&x) prints identically in C, so delegate to IrPrinter, which
// emits "&" followed by the operand.
void CodeGenC::Visit(const ir::GetReference *op) { IrPrinter::Visit(op); }
// Casts need C-specific syntax, so emit them via PrintCastExpr instead of the base printer.
void CodeGenC::Visit(const ir::Cast *op) { PrintCastExpr(op->type(), op->v()); }
void CodeGenC::Visit(const ir::For *op) {
Expr extent = op->extent;
Expand Down
6 changes: 6 additions & 0 deletions cinn/backends/codegen_cuda_dev.cc
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,12 @@ std::vector<Expr> CodeGenCUDA_Dev::GenerateBufferAliasExprs(const ir::_LoweredFu
}

void CodeGenCUDA_Dev::Visit(const ir::_LoweredFunc_ *op) {
std::set<Expr> device_count_exprs = op->PrepareDeviceCountExprs();
for (auto dce : device_count_exprs) {
VLOG(6) << "PrepareDeviceCountExprs " << dce;
os() << "__device__ int " << dce.As<ir::_Var_>()->name << " = 0;\n";
}

// clear names valid within scope when enter a new function
vectorized_tensor_names_.clear();
os() << "__global__\n";
Expand Down
5 changes: 5 additions & 0 deletions cinn/backends/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,11 @@ llvm::Value *CodeGenLLVM::Visit(const ir::Minus *op) {
return (op->type().is_int() || op->type().is_uint()) ? Neg(v) : FNeg(v);
}

// GetReference (address-of, &x) has no LLVM lowering yet; fail loudly rather
// than silently emitting wrong IR.
llvm::Value *CodeGenLLVM::Visit(const ir::GetReference *op) {
  // Fix: "Unimplementd" -> "Unimplemented" in the fatal message.
  LOG(FATAL) << "TODO: Unimplemented CodeGenLLVM::Visit(const ir::GetReference *op)";
  return nullptr;  // unreachable after LOG(FATAL); satisfies the return-type check
}

// Logical not: lower the operand, then emit an LLVM `not` over the result.
llvm::Value *CodeGenLLVM::Visit(const ir::Not *op) { return Not(Visit(&op->v())); }

llvm::Value *CodeGenLLVM::Visit(const ir::Cast *op) {
Expand Down
18 changes: 18 additions & 0 deletions cinn/common/type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,21 @@ Type &Type::set_cpp_handle2(bool x) {
return *this;
}

// Marks (or unmarks) this type as a C++ reference (`T&`).
// A reference is mutually exclusive with the pointer kinds, so the Handle and
// HandleHandle bits are always cleared, regardless of `x`.
Type &Type::set_cpp_reference(bool x) {
  uint8_t &bits = *reinterpret_cast<uint8_t *>(&GetStorage().cpp_type_);

  const uint8_t handle_mask =
      static_cast<uint8_t>(cpp_type_t::Handle) | static_cast<uint8_t>(cpp_type_t::HandleHandle);
  bits = static_cast<uint8_t>(bits & ~handle_mask);

  const uint8_t ref_bit = static_cast<uint8_t>(cpp_type_t::Reference);
  bits = x ? static_cast<uint8_t>(bits | ref_bit) : static_cast<uint8_t>(bits & ~ref_bit);

  return *this;
}

Type Type::VectorOf(int w) const {
CheckTypeValid();
return Type(type(), bits(), w, specific_type());
Expand Down Expand Up @@ -263,6 +278,9 @@ bool Type::is_cpp_handle() const {
bool Type::is_cpp_handle2() const {
return static_cast<uint8_t>(GetStorage().cpp_type_) & static_cast<uint8_t>(cpp_type_t::HandleHandle);
}
// True iff the Reference bit is set in the C++ type flags.
bool Type::is_cpp_reference() const {
  const auto flags = static_cast<uint8_t>(GetStorage().cpp_type_);
  return (flags & static_cast<uint8_t>(cpp_type_t::Reference)) != 0;
}
bool Type::is_cpp_const() const {
return static_cast<uint8_t>(cpp_type_t::Const) & static_cast<uint8_t>(GetStorage().cpp_type_);
}
Expand Down
4 changes: 4 additions & 0 deletions cinn/common/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ struct Type {
Const = 1, // const.
Handle = 1 << 1, // pointer type, such as `cinn_buffer_t*`.
HandleHandle = 1 << 2, // pointer of pointer, such as `cinn_buffer_t**`.
Reference = 1 << 4, // reference type, such as `cinn_buffer_t&`.
};

Type();
Expand Down Expand Up @@ -100,6 +101,9 @@ struct Type {
Type& set_cpp_handle2(bool x = true);
CINN_NODISCARD bool is_cpp_handle2() const;

Type& set_cpp_reference(bool x = true);
CINN_NODISCARD bool is_cpp_reference() const;

Type& set_cpp_const(bool is_const = true);
CINN_NODISCARD bool is_cpp_const() const;

Expand Down
7 changes: 7 additions & 0 deletions cinn/hlir/framework/op_lowering.cc
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ std::vector<ir::LoweredFunc> OpLowerer::IRLowerOp(IRComputeFunction compute,
ast_exprs = (this->*compute)(stages, arg_tensors, tensor_map, group, group, /*apply_impl_schedule = */ true);
} else {
for (auto& sub_group : group->fused_sub_groups) {
VLOG(4) << "sub_group->group_id = " << sub_group->group_id;
auto exprs = (this->*compute)(stages, arg_tensors, tensor_map, group, sub_group, /*apply_impl_schedule = */ true);
ast_exprs.insert(ast_exprs.end(), exprs.begin(), exprs.end());
}
Expand Down Expand Up @@ -1277,8 +1278,14 @@ void OpLowerer::IRSchedule(ir::IRSchedule& ir_sch,
}
VLOG(3) << "Before loop fusion, ir is: \n" << ir_sch.GetModule().GetExprs().at(0);
VLOG(4) << " FUSION " << node->op()->name;
Node* fusion_master = master ? master : nodes_in_order.front();

// if (CanFuseReduceByBlockSync(ir_sch, node, fusion_master, group, this->shape_dict_, tensor_map)) {
// SyncGpuBlocks(ir_sch, node, fusion_master, group, this->shape_dict_, tensor_map);
// } else {
// do loop fuse.
LoopComputeAt(ir_sch, node, master ? master : nodes_in_order.front(), group, this->shape_dict_, tensor_map);
//}
VLOG(3) << "After loop fusion, ir is: \n" << ir_sch.GetModule().GetExprs().at(0);
}

Expand Down
63 changes: 63 additions & 0 deletions cinn/hlir/framework/op_lowering_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "cinn/hlir/framework/op_lowering_util.h"

#include "cinn/hlir/pe/nn_util.h"
#include "cinn/utils/string.h"
#ifdef CINN_WITH_CUDA
#include "cinn/common/bfloat16.h"
#include "cinn/common/float16.h"
Expand Down Expand Up @@ -1429,6 +1430,68 @@ void LoopComputeAt(ir::IRSchedule& ir_sch,
} while (--index >= 0);
}

// Decides whether two distinct reduction nodes may be fused via a GPU
// block-level sync: both must be "<reduce-op>_split" nodes of the same reduce
// op AND their input shapes must differ (equal-shape pairs are handled by the
// regular fusion paths, so this returns false for them).
// NOTE: `ir_sch`, `group` and `tensor_map` are currently unused; they are kept
// for signature parity with the other fusion predicates.
bool CanFuseReduceByBlockSync(ir::IRSchedule& ir_sch,
                              Node* node,
                              const Node* master,
                              const GroupPtr& group,
                              const absl::flat_hash_map<std::string, shape_t>& shape_dict,
                              const std::unordered_map<std::string, ir::Tensor>& tensor_map) {
  auto& op_pattern_dict = Operator::GetAttrs<OpPatternKind>("OpPattern");
  if (op_pattern_dict[node->op()] == framework::kReduction && op_pattern_dict[master->op()] == framework::kReduction &&
      node != master) {
    auto node_shape   = shape_dict.at(node->inlinks_in_order()[0]->source()->id());
    auto master_shape = shape_dict.at(master->inlinks_in_order()[0]->source()->id());

    VLOG(6) << "Checking CanFuseReduceByBlockSync";
    VLOG(6) << "node->id() = " << node->id() << ", node_shape.size() = " << node_shape.size();
    for (auto x : node_shape) {
      VLOG(6) << x;
    }
    VLOG(6) << "master->id() = " << master->id() << ", master_shape.size() = " << master_shape.size();
    for (auto x : master_shape) {
      VLOG(6) << x;
    }

    static std::unordered_set<std::string> reduce_op_type = {
        "reduce_sum", "reduce_mean", "reduce_max", "reduce_min", "reduce_all", "reduce_any"};
    for (const std::string& op_type : reduce_op_type) {
      // TODO: this may speed up not only reduce_xxx_split nodes, but we limit it to reduce_xxx_split nodes for accuracy
      // safety
      if (cinn::utils::Startswith(master->id(), op_type + "_split") &&
          cinn::utils::Startswith(node->id(), op_type + "_split")) {
        // Returns true only when shape is not equal. Shape equal is handled by other fusion
        if (node_shape.size() != master_shape.size()) {
          return true;
        }
        // Same rank: fusible iff any dimension differs.
        // Fix: use size_t for the index to avoid a signed/unsigned comparison.
        for (size_t i = 0; i < node_shape.size(); ++i) {
          if (node_shape[i] != master_shape[i]) {
            return true;
          }
        }
      }
    }
  }
  return false;
}

// Fuses `node` into `master` at GPU-block granularity: when `node` is not a
// group output its buffer is marked "local", then the schedule inserts a
// block-level sync between the two schedule blocks.
// NOTE: `shape_dict` and `tensor_map` are unused; kept for signature parity
// with the other fusion helpers.
void SyncGpuBlocks(ir::IRSchedule& ir_sch,
                   Node* node,
                   const Node* master,
                   const GroupPtr& group,
                   const absl::flat_hash_map<std::string, shape_t>& shape_dict,
                   const std::unordered_map<std::string, ir::Tensor>& tensor_map) {
  VLOG(6) << "Calling SyncGpuBlocks";
  if (!group->output_nodes.count(node)) {
    auto block = ir_sch.GetBlock(GetNodeData(node)->id());
    ir_sch.SetBuffer(block, "local", true);
  }
  // NOTE(review): `node_data` is never used below, while `node_block` is looked
  // up by `node->id()` and `master_block` by `master_data->id()`. This asymmetry
  // looks unintentional — confirm whether node_block should use node_data->id().
  auto node_data = GetNodeData(node);
  auto master_data = GetNodeData(master);
  auto node_block = ir_sch.GetBlock(node->id());
  auto master_block = ir_sch.GetBlock(master_data->id());
  ir_sch.SyncGpuBlocks(master_block, node_block);
}

std::unordered_map<std::string, NodeData*> GetNodeDataSet(const std::unordered_set<Node*>& nodes_set) {
std::unordered_map<std::string, NodeData*> node_data_set;
for (auto node : nodes_set) {
Expand Down
14 changes: 14 additions & 0 deletions cinn/hlir/framework/op_lowering_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,20 @@ void LoopComputeAt(ir::IRSchedule& ir_sch,
const absl::flat_hash_map<std::string, shape_t>& shape_dict,
const std::unordered_map<std::string, ir::Tensor>& tensor_map);

bool CanFuseReduceByBlockSync(ir::IRSchedule& ir_sch,
Node* node,
const Node* master,
const GroupPtr& group,
const absl::flat_hash_map<std::string, shape_t>& shape_dict,
const std::unordered_map<std::string, ir::Tensor>& tensor_map);

void SyncGpuBlocks(ir::IRSchedule& ir_sch,
Node* node,
const Node* master,
const GroupPtr& group,
const absl::flat_hash_map<std::string, shape_t>& shape_dict,
const std::unordered_map<std::string, ir::Tensor>& tensor_map);

void SyncThreadWithShared(ir::IRSchedule& ir_sch,
const GroupPtr& group,
const std::unordered_set<Node*>& nodes_inline,
Expand Down
19 changes: 19 additions & 0 deletions cinn/hlir/pass/fusion_merge_pass_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,20 @@ CONDITION_FUNC(injective_horizontal_with_reduce) {
return elementwise_fuse_reduce(helper, first, second);
}

// Returns true when both `producer` and `reducer` are "<reduce-op>_split*"
// nodes of the same reduce op type; such pairs are allowed to fuse.
inline bool ReduceSplitCanFuse(const Node* producer, const Node* reducer) {
  static std::unordered_set<std::string> reduce_op_type = {
      "reduce_sum", "reduce_mean", "reduce_max", "reduce_min", "reduce_all", "reduce_any"};
  VLOG(6) << "Checking ReduceSplitCanFuse";
  VLOG(6) << "producer->id() = " << producer->id();
  VLOG(6) << "reducer->id() = " << reducer->id();
  for (const auto& base_op : reduce_op_type) {
    // Compute the "<op>_split" prefix once per candidate op type.
    const std::string split_prefix = base_op + "_split";
    if (utils::Startswith(producer->id(), split_prefix) && utils::Startswith(reducer->id(), split_prefix)) {
      return true;
    }
  }
  return false;
}

CONDITION_FUNC(reduce_fuse_broadcast) {
// if same shape with horizontal relation
if (is_same_size(helper, first, second)) {
Expand Down Expand Up @@ -388,6 +402,7 @@ CONDITION_FUNC(reduce_fuse_broadcast) {
}

CONDITION_FUNC(reduce_fuse_reduce) {
VLOG(6) << "In reduce_fuse_reduce";
if (!limit_args(helper, first, second)) {
return false;
}
Expand All @@ -409,6 +424,10 @@ CONDITION_FUNC(reduce_fuse_reduce) {
}
CHECK(reducer_1) << "Can't find reduce op in group " << second->group_id;

if (ReduceSplitCanFuse(reducer_0, reducer_1)) {
return true;
}

// check reduce has same input shape and output shape
auto reducer_0_input_shape = helper->shape_dict_.at(reducer_0->inlinks_in_order()[0]->source()->id());
auto reducer_0_output_shape = helper->shape_dict_.at(reducer_0->outlinks_in_order()[0]->sink()->id());
Expand Down
20 changes: 20 additions & 0 deletions cinn/hlir/pass/op_fusion_pass_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,34 @@ CONDITION_FUNC(without_last_dimension_in_reduce) {
return helper->WithoutLastDimInReduce(in_shape, reduce_axes);
}

// Returns true when both `producer` and `reducer` are "<reduce-op>_split*"
// nodes of the same reduce op type; such pairs are allowed to fuse.
// NOTE(review): this helper is duplicated in fusion_merge_pass_util.h —
// consider hoisting it into a shared header.
inline bool ReduceSplitCanFuse(const Node* producer, const Node* reducer) {
  static std::unordered_set<std::string> reduce_op_type = {
      "reduce_sum", "reduce_mean", "reduce_max", "reduce_min", "reduce_all", "reduce_any"};
  VLOG(6) << "Checking ReduceSplitCanFuse";
  VLOG(6) << "producer->id() = " << producer->id();
  VLOG(6) << "reducer->id() = " << reducer->id();
  for (const std::string& op_type : reduce_op_type) {
    if (utils::Startswith(producer->id(), op_type + "_split") && utils::Startswith(reducer->id(), op_type + "_split")) {
      return true;
    }
  }
  return false;
}

CONDITION_FUNC(reduce_fuse_reduce) {
VLOG(6) << "In reduce_fuse_reduce";
Node* reducer = NULL;
for (auto* master : consumer->master_nodes) {
if (helper->GetOpKind(master) == framework::kReduction) {
reducer = master;
break;
}
}

if (ReduceSplitCanFuse(producer, reducer)) {
return true;
}

// check reduce has same input shape and output shape
auto producer_input_shape = helper->shape_dict_.at(producer->inlinks_in_order()[0]->source()->id());
auto producer_output_shape = helper->shape_dict_.at(producer->outlinks_in_order()[0]->sink()->id());
Expand Down
9 changes: 9 additions & 0 deletions cinn/ir/ir.cc
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,15 @@ void Not::Verify() const { CHECK_EQ(v().type(), type_of<bool>()); }

Type Not::type() const { return type_; }

// Builds a GetReference (address-of) node over `v` and returns it as an Expr.
Expr GetReference::Make(Expr v) { return Expr(make_shared<GetReference>(v)); }

// A GetReference node is well-formed as long as it has an operand to take the address of.
void GetReference::Verify() const { CHECK(v().defined()); }

// Returns the node's stored type (set at construction; see the GetReference ctor).
Type GetReference::type() const { return type_; }

Expr Let::Make(Expr symbol, Expr body) {
auto *n = make_shared<Let>();
CHECK(symbol.type().valid());
Expand Down
15 changes: 15 additions & 0 deletions cinn/ir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,21 @@ struct Not : public UnaryOpNode<Not> {
static const IrNodeTy _node_type_ = IrNodeTy::Not;
};

/**
 * Address-of / get-reference node, such as C++ `&x`.
 */
struct GetReference : public UnaryOpNode<GetReference> {
  // NOTE(review): the node type is hard-coded to `int32` with the reference bit
  // set, regardless of the operand's type — confirm this is intended rather
  // than deriving the reference type from v.type().
  explicit GetReference(Expr v) : UnaryOpNode<GetReference>(common::Int(32).set_cpp_reference(), v) {}

  static Expr Make(Expr v);

  Type type() const override;
  void Verify() const override;

  static const IrNodeTy _node_type_ = IrNodeTy::GetReference;
};

struct Let : public ExprNode<Let> {
Expr symbol;
Expr body;
Expand Down
1 change: 1 addition & 0 deletions cinn/ir/ir_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ class ScheduleBlockRealize;
#define NODETY_UNARY_OP_FOR_EACH(macro__) \
macro__(Minus) \
macro__(Not) \
macro__(GetReference) \

#define NODETY_OP_FOR_EACH(macro__) NODETY_BINARY_OP_FOR_EACH(macro__) NODETY_UNARY_OP_FOR_EACH(macro__)

Expand Down
4 changes: 4 additions & 0 deletions cinn/ir/ir_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,10 @@ void IrPrinter::Visit(const Minus *x) {
Print(x->v());
os_ << ")";
}
// Print a GetReference node as the C-style address-of expression: "&" followed
// by the operand.
void IrPrinter::Visit(const GetReference *x) {
  os_ << "&";
  Print(x->v());
}
void IrPrinter::Visit(const For *x) {
if (x->is_parallel()) {
os() << "parallel for (";
Expand Down
Loading