Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

upgrade op fusion lowering #1216

Open
wants to merge 36 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 34 commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
f31bea4
develop done
SunNy820828449 Feb 22, 2023
d7256cd
compile pass
SunNy820828449 Feb 22, 2023
571924e
fix test code
SunNy820828449 Feb 23, 2023
a227edf
update
SunNy820828449 Feb 24, 2023
c7618dd
update
SunNy820828449 Feb 27, 2023
5a8189a
update
SunNy820828449 Feb 28, 2023
080b3f5
update
SunNy820828449 Feb 28, 2023
d2f4614
update
SunNy820828449 Feb 28, 2023
0321a63
update
SunNy820828449 Mar 2, 2023
2a88930
update
SunNy820828449 Mar 2, 2023
9d6c2a8
update
SunNy820828449 Mar 2, 2023
73060e1
fix conflict
SunNy820828449 Mar 2, 2023
62373b0
update
SunNy820828449 Mar 2, 2023
644f922
update
SunNy820828449 Mar 2, 2023
82d8636
update
SunNy820828449 Mar 2, 2023
39a1938
update
SunNy820828449 Mar 2, 2023
5a2ddfc
update
SunNy820828449 Mar 2, 2023
fdb6048
update
SunNy820828449 Mar 2, 2023
dfffca8
fix
SunNy820828449 Mar 2, 2023
09e5d3d
update
SunNy820828449 Mar 3, 2023
29e02a5
update
SunNy820828449 Mar 3, 2023
a386985
fix
SunNy820828449 Mar 3, 2023
da38edb
fix fusion
SunNy820828449 Mar 3, 2023
881cba9
fix
SunNy820828449 Mar 6, 2023
e6eba07
fix cas
SunNy820828449 Mar 6, 2023
dc47c61
fix cuda
SunNy820828449 Mar 7, 2023
4372080
fix review
SunNy820828449 Mar 8, 2023
c408bc5
Merge branch 'develop' of https://github.com/PaddlePaddle/CINN into u…
SunNy820828449 Mar 9, 2023
ab14a30
support reduce + broadcast
SunNy820828449 Mar 13, 2023
18f90a2
fix
SunNy820828449 Mar 13, 2023
2d7d86a
merge develop
SunNy820828449 Mar 14, 2023
bf7dfb1
fix lowering order
SunNy820828449 Mar 14, 2023
7fa2a99
fix ci test
SunNy820828449 Mar 20, 2023
513cdb8
fix op lowering with output
SunNy820828449 Mar 20, 2023
511e3a4
fix reduce schedule
SunNy820828449 Mar 21, 2023
0c39822
fix get master
SunNy820828449 Mar 21, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion cinn/backends/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,6 @@ void Compiler::CompileCudaModule(const Module& module, const std::string& code)
using runtime::cuda::CUDAModule;

backends::nvrtc::Compiler compiler;

auto ptx = compiler(source_code);
CHECK(!ptx.empty());

Expand Down
22 changes: 11 additions & 11 deletions cinn/common/cas.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2016,17 +2016,17 @@ Expr CasSimplifyMutator::SimplifyFracOp(Expr expr) {
};

{
std::vector<Expr> a_args, b_args;
if (ap)
a_args = ap->operands();
else
a_args.push_back(a);
if (bp)
b_args = bp->operands();
else
b_args.push_back(b);

return reduce_product_div_product(a_args, b_args);
// TODO: fix this
// std::vector<Expr> a_args, b_args;
// if (ap)
// a_args = ap->operands();
// else
// a_args.push_back(a);
// if (bp)
// b_args = bp->operands();
// else
// b_args.push_back(b);
// return reduce_product_div_product(a_args, b_args);
}

// x / x
Expand Down
2 changes: 1 addition & 1 deletion cinn/common/cas_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ TEST(CAS, SimplifyMinMax) {
LOG(INFO) << "p0 " << p0;
auto p2 = CasSimplify(p0);
LOG(INFO) << "simplified " << p2;
EXPECT_EQ(GetStreamCnt(p2), "cinn_min(7, ((x) / (2)))");
EXPECT_EQ(GetStreamCnt(p2), "cinn_min(7, (x / 2))");
}
{ // -(cinn_min(16, 3400-x-1)-1)/2 + x
Var x = ir::_Var_::Make("x", Int(32));
Expand Down
1 change: 1 addition & 0 deletions cinn/hlir/framework/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ gather_srcs(cinnapi_src SRCS
op_lowering.cc
accuracy_checker.cc
visualize_helper.cc
op_lowering_util.cc
)

if(WITH_CUDA)
Expand Down
8 changes: 8 additions & 0 deletions cinn/hlir/framework/graph.h
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,14 @@ class Graph : public cinn::common::Graph {
}
}

std::unordered_set<Node*> NodeSet() {
std::unordered_set<Node*> node_set;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这函数功能完全和CollectNodes重复了而且也不常用啊。。。在用的地方直接定义

const auto& nodes = group->CollectNodes();
std::unordered_set<Node*> node_set(nodes.begin(), nodes.end());

不好么。。。

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

用的地方比较多 这里实现比较方便

for (auto node : CollectNodes()) {
node_set.insert(node);
}
return node_set;
}

std::unordered_set<NodeData*> GetInputNodeDatas();
std::unordered_set<NodeData*> GetOutputNodeDatas();

Expand Down
Loading