Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Commit

Permalink
tests(auto-unroll): Check whether the Auto-unroll is correct
Browse files Browse the repository at this point in the history
  • Loading branch information
6clc committed Mar 27, 2023
1 parent 4eeb0d3 commit a4b16f8
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,4 @@ endif()
cc_test(test_auto_inline SRCS auto_inline_test.cc DEPS cinncore auto_gen_rule_test_helper)
cc_test(test_multi_level_tiling SRCS multi_level_tiling_test.cc DEPS cinncore)
cc_test(test_skip_rule SRCS skip_rule_test.cc DEPS cinncore)
cc_test(test_auto_unroll SRCS auto_unroll_test.cc DEPS cinncore)
cc_test(test_auto_unroll SRCS auto_unroll_test.cc DEPS cinncore auto_gen_rule_test_helper test_program_builder)
77 changes: 75 additions & 2 deletions cinn/auto_schedule/search_space/auto_gen_rule/auto_unroll_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,11 @@
#include <glog/logging.h>
#include <gtest/gtest.h>

#include "cinn/auto_schedule/search_space/auto_gen_rule/test_helper.h"
#include "cinn/cinn.h"
#include "cinn/lang/lower.h"
#include "tests/program_builder.h"
#include "tests/subgraph_program_builder.h"

namespace cinn {
namespace auto_schedule {
Expand Down Expand Up @@ -73,7 +76,7 @@ TEST(AutoUnroll, UnrollableApply) {
auto* init_schedule_block = init_block_realize->schedule_block.As<ir::ScheduleBlock>();
ASSERT_NE(init_schedule_block, nullptr);
ASSERT_TRUE(init_schedule_block->attrs.empty());
VLOG(6) << "Before auto-unroll:\n" << ast_expr;
VLOG(-6) << "Before auto-unroll:\n" << ast_expr;

AutoUnroll test_rule(target);
ir::IRSchedule ir_schedule(ir::ModuleExpr({ast_expr}));
Expand All @@ -96,12 +99,82 @@ TEST(AutoUnroll, UnrollableApply) {
const int* max_step = absl::get_if<int>(&attr_value);
EXPECT_NE(max_step, nullptr);
EXPECT_LE(*max_step, 128);
VLOG(6) << "After auto-unroll:max_step=" << *max_step << ", Ast:\n" << ir_sch->GetModule().GetExprs().front();
VLOG(-6) << "After auto-unroll:max_step=" << *max_step << ", Ast:\n" << ir_sch->GetModule().GetExprs().front();
};

test_func(&ir_schedule);
test_func(&states[0]->ir_schedule);
}

#ifdef CINN_WITH_CUDA
class TestAutoUnroll : public TestAutoGenRuleBase {
public:
std::vector<std::string> default_input_names = {"X", "Y"};
std::vector<std::string> default_output_names = {"temp_matmul_out"};
};
TEST_F(TestAutoUnroll, ApplyOnMatmulWithTiling) {
frontend::Program matmul_op = tests::OpBuilder("matmul").Build({{"X", {32, 4}}, {"Y", {4, 32}}});
Initialize(common::DefaultNVGPUTarget());
ir::IRSchedule ir_schedule = MakeIRSchedule(matmul_op);
std::vector<ir::Expr> func_bodys = ir_schedule.GetModule().GetExprs();
ASSERT_EQ(func_bodys.size(), 1UL);
VLOG(6) << "Original Expr:\n" << func_bodys[0];

AutoUnroll auto_unroll(target_);
SearchState state(ir_schedule, 0, {});
const std::string& applied_block_name = default_output_names.back();
EXPECT_EQ(auto_unroll.AnalyseApplyType(state, applied_block_name), RuleApplyType::kApplyAndPruneOtherRules);
auto new_states = auto_unroll.ApplyOnBlock(state, applied_block_name);
std::vector<ir::Expr> exprs = new_states[0]->ir_schedule.GetModule().GetExprs();
EXPECT_EQ(exprs.size(), 1UL);

// Check if the block has an 'auto_unroll_max_step' attribute
auto* applied_block_realize = exprs.front().As<ir::Block>()->stmts.front().As<ir::ScheduleBlockRealize>();
auto* applied_schedule_block = applied_block_realize->schedule_block.As<ir::ScheduleBlock>();
ASSERT_FALSE(applied_schedule_block->attrs.empty());
EXPECT_EQ(applied_schedule_block->attrs.count(ir::attr::auto_unroll_max_step), 1);
const auto& attr_value = applied_schedule_block->attrs.at(ir::attr::auto_unroll_max_step);
const int* max_step = absl::get_if<int>(&attr_value);
EXPECT_NE(max_step, nullptr);
EXPECT_LE(*max_step, 128);
VLOG(6) << "Expr after AutoUnroll applied on block:max_step=" << *max_step << ", Ast:\n" << exprs.front();

// build ir::Module and debug source code
auto build_module = BuildIRModule(new_states[0]->ir_schedule);
auto source_code = GenSourceCode(build_module);
VLOG(6) << " auto-schedule source code:\n" << source_code;
// execute and check precision
CheckResult(GenExecutableKernel(build_module),
GenExecutableKernel(BuildIRModule(MakeIRSchedule(matmul_op, /* apply_manual_schedule */ true))),
default_input_names,
default_output_names,
{{4, 4}, {4, 4}},
{{4, 4}},
target_);
}

TEST_F(TestAutoUnroll, PureSpatial) {
Target target = common::DefaultNVGPUTarget();
Initialize(target);
std::vector<std::string> input_names = {"x", "y"};
std::vector<std::string> output_names = {
"var_6", "var_4", "constant_idx_last", "constant_idx_first", "var_2", "var_5"};
std::vector<int32_t> input_shape{256, 256};
std::vector<tests::VariableInfo> inputs_varinfo({{"x", input_shape}, {"y", input_shape}});

Context::Global().ResetNameId();
ir::IRSchedule ir_schedule = MakeIRSchedule(tests::GatherAddSubSubGraphBuilder().Build(inputs_varinfo));
SearchState state(ir_schedule, 0, {});
std::vector<ir::Expr> func_bodys = ir_schedule.GetModule().GetExprs();
ASSERT_EQ(func_bodys.size(), 1UL);
VLOG(6) << "Original Expr:\n" << func_bodys[0];

AutoUnroll auto_unroll(target_);
for (const auto& applied_block_name : output_names) {
EXPECT_EQ(auto_unroll.AnalyseApplyType(state, applied_block_name), RuleApplyType::kCannotApply);
}
}
#endif

} // namespace auto_schedule
} // namespace cinn
16 changes: 8 additions & 8 deletions tests/subgraph_program_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class BiasBnReLUSubGraphBuilder : public ProgramBuilder {
public:
BiasBnReLUSubGraphBuilder() : ProgramBuilder("bias_bn_relu_builder") {}
frontend::Program Build(const std::vector<VariableInfo>& inputs_varinfo, const utils::AttributeMap& attrs = {}) {
CHECK(inputs_varinfo.size()==4);
CHECK(inputs_varinfo.size() == 4);
auto conv_output = builder_.CreateInput(inputs_varinfo[0].type, inputs_varinfo[0].shape, inputs_varinfo[0].id);
auto bias = builder_.CreateInput(inputs_varinfo[1].type, inputs_varinfo[1].shape, inputs_varinfo[1].id);
auto bn_scale = builder_.CreateInput(inputs_varinfo[2].type, inputs_varinfo[2].shape, inputs_varinfo[2].id);
Expand All @@ -41,7 +41,7 @@ class ExpTwoConsumersOpBuilder : public ProgramBuilder {
public:
ExpTwoConsumersOpBuilder() : ProgramBuilder("exp_two_consumers_builder") {}
frontend::Program Build(const std::vector<VariableInfo>& inputs_varinfo, const utils::AttributeMap& attrs = {}) {
CHECK(inputs_varinfo.size()==1);
CHECK(inputs_varinfo.size() == 1);
auto x = builder_.CreateInput(inputs_varinfo[0].type, inputs_varinfo[0].shape, inputs_varinfo[0].id);
auto exp_x = builder_.Exp(x);
auto add_x = builder_.Add(exp_x, x);
Expand All @@ -54,11 +54,11 @@ class GatherAddSubSubGraphBuilder : public ProgramBuilder {
public:
GatherAddSubSubGraphBuilder() : ProgramBuilder("gather_add_sub_builder") {}
frontend::Program Build(const std::vector<VariableInfo>& inputs_varinfo, const utils::AttributeMap& attrs = {}) {
CHECK(inputs_varinfo.size()==2);
auto x = builder_.CreateInput(inputs_varinfo[0].type, inputs_varinfo[0].shape, inputs_varinfo[0].id);
auto y = builder_.CreateInput(inputs_varinfo[1].type, inputs_varinfo[1].shape, inputs_varinfo[1].id);
CHECK(inputs_varinfo.size() == 2);
auto x = builder_.CreateInput(inputs_varinfo[0].type, inputs_varinfo[0].shape, inputs_varinfo[0].id);
auto y = builder_.CreateInput(inputs_varinfo[1].type, inputs_varinfo[1].shape, inputs_varinfo[1].id);
auto input_x_shape = inputs_varinfo[0].shape;
auto where_x_0 = builder_.Gather(x, builder_.FillConstant({input_x_shape[0]}, 0, "constant_idx_first"));
auto where_x_0 = builder_.Gather(x, builder_.FillConstant({input_x_shape[0]}, 0, "constant_idx_first"));
auto where_x_last =
builder_.Gather(x, builder_.FillConstant({input_x_shape[0]}, input_x_shape[0] - 1, "constant_idx_last"));
auto add_1 = builder_.Add(where_x_0, y);
Expand All @@ -71,8 +71,8 @@ class FillConstantAddSubGraphBuilder : public ProgramBuilder {
public:
FillConstantAddSubGraphBuilder() : ProgramBuilder("fill_constant_add_builder") {}
frontend::Program Build(const std::vector<VariableInfo>& inputs_varinfo, const utils::AttributeMap& attrs = {}) {
CHECK(inputs_varinfo.size()==1);
auto x = builder_.CreateInput(inputs_varinfo[0].type, inputs_varinfo[0].shape, inputs_varinfo[0].id);
CHECK(inputs_varinfo.size() == 1);
auto x = builder_.CreateInput(inputs_varinfo[0].type, inputs_varinfo[0].shape, inputs_varinfo[0].id);
auto fill_constant = builder_.FillConstant(inputs_varinfo[0].shape, 1.0f, "fill_constant");
builder_.Add(x, fill_constant);
return builder_.Build();
Expand Down

0 comments on commit a4b16f8

Please sign in to comment.