Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Commit

Permalink
tests(auto-unroll): fix code according to merge
Browse files Browse the repository at this point in the history
  • Loading branch information
6clc committed Mar 30, 2023
1 parent 3e983a1 commit 0240813
Showing 1 changed file with 26 additions and 7 deletions.
33 changes: 26 additions & 7 deletions cinn/auto_schedule/search_space/auto_gen_rule/auto_unroll_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
#include "cinn/auto_schedule/search_space/auto_gen_rule/test_helper.h"
#include "cinn/cinn.h"
#include "cinn/lang/lower.h"
#include "tests/concrete_program_builder.h"
#include "tests/program_builder.h"
#include "tests/subgraph_program_builder.h"

namespace cinn {
namespace auto_schedule {
Expand Down Expand Up @@ -76,7 +76,7 @@ TEST(AutoUnroll, UnrollableApply) {
auto* init_schedule_block = init_block_realize->schedule_block.As<ir::ScheduleBlock>();
ASSERT_NE(init_schedule_block, nullptr);
ASSERT_TRUE(init_schedule_block->attrs.empty());
VLOG(-6) << "Before auto-unroll:\n" << ast_expr;
VLOG(6) << "Before auto-unroll:\n" << ast_expr;

AutoUnroll test_rule(target);
ir::IRSchedule ir_schedule(ir::ModuleExpr({ast_expr}));
Expand All @@ -99,7 +99,7 @@ TEST(AutoUnroll, UnrollableApply) {
const int* max_step = absl::get_if<int>(&attr_value);
EXPECT_NE(max_step, nullptr);
EXPECT_LE(*max_step, 128);
VLOG(-6) << "After auto-unroll:max_step=" << *max_step << ", Ast:\n" << ir_sch->GetModule().GetExprs().front();
VLOG(6) << "After auto-unroll:max_step=" << *max_step << ", Ast:\n" << ir_sch->GetModule().GetExprs().front();
};

test_func(&ir_schedule);
Expand All @@ -112,14 +112,30 @@ class TestAutoUnroll : public TestAutoGenRuleBase {
std::vector<std::string> default_input_names = {"X", "Y"};
std::vector<std::string> default_output_names = {"temp_matmul_out"};
};
TEST_F(TestAutoUnroll, ApplyOnMatmulWithTiling) {
frontend::Program matmul_op = tests::OpBuilder("matmul").Build({{"X", {32, 4}}, {"Y", {4, 32}}});

/* Before AutoUnroll:
* for (i=0; i < 4; i++):
* for(j=0; j < 4; j++):
* for(k=0; k < 4; k++):
* C(i, j) = C(i, j) + A(i, k) * B(k, j)
*
* After AutoUnroll on 'k', the third loop is unrolled.
* for(i=0; i < 4; i++):
* for(j=0; j < 4; j++):
* C(i, j) = C(i, j) + A(i, 0) * B(0, j)
* C(i, j) = C(i, j) + A(i, 1) * B(1, j)
* C(i, j) = C(i, j) + A(i, 2) * B(2, j)
* C(i, j) = C(i, j) + A(i, 3) * B(3, j)
*/
TEST_F(TestAutoUnroll, ApplyOnMatmulWithUnroll) {
frontend::Program matmul_op = tests::OpBuilder("matmul").Build({{"X", {4, 4}}, {"Y", {4, 4}}});
Initialize(common::DefaultNVGPUTarget());
ir::IRSchedule ir_schedule = MakeIRSchedule(matmul_op);
std::vector<ir::Expr> func_bodys = ir_schedule.GetModule().GetExprs();
ASSERT_EQ(func_bodys.size(), 1UL);
VLOG(6) << "Original Expr:\n" << func_bodys[0];

// Construct the computation graph and convert it to ir::Expr
AutoUnroll auto_unroll(target_);
SearchState state(ir_schedule, 0, {});
const std::string& applied_block_name = default_output_names.back();
Expand All @@ -145,14 +161,15 @@ TEST_F(TestAutoUnroll, ApplyOnMatmulWithTiling) {
VLOG(6) << " auto-schedule source code:\n" << source_code;
// execute and check precision
CheckResult(GenExecutableKernel(build_module),
GenExecutableKernel(BuildIRModule(MakeIRSchedule(matmul_op, /* apply_manual_schedule */ true))),
GenExecutableKernel(BuildIRModule(MakeIRSchedule(matmul_op, -1, true))),
default_input_names,
default_output_names,
{{4, 4}, {4, 4}},
{{4, 4}},
target_);
}

/* Operators of type elementwise or injective can not be auto-unrolled.*/
TEST_F(TestAutoUnroll, PureSpatial) {
Target target = common::DefaultNVGPUTarget();
Initialize(target);
Expand All @@ -162,13 +179,15 @@ TEST_F(TestAutoUnroll, PureSpatial) {
std::vector<int32_t> input_shape{256, 256};
std::vector<tests::VariableInfo> inputs_varinfo({{"x", input_shape}, {"y", input_shape}});

// Construct the computation graph and convert it to ir::Expr
Context::Global().ResetNameId();
ir::IRSchedule ir_schedule = MakeIRSchedule(tests::GatherAddSubSubGraphBuilder().Build(inputs_varinfo));
ir::IRSchedule ir_schedule = MakeIRSchedule(tests::GatherAddSubBuilder().Build(inputs_varinfo));
SearchState state(ir_schedule, 0, {});
std::vector<ir::Expr> func_bodys = ir_schedule.GetModule().GetExprs();
ASSERT_EQ(func_bodys.size(), 1UL);
VLOG(6) << "Original Expr:\n" << func_bodys[0];

// Analyzes whether the block can be unrolled
AutoUnroll auto_unroll(target_);
for (const auto& applied_block_name : output_names) {
EXPECT_EQ(auto_unroll.AnalyseApplyType(state, applied_block_name), RuleApplyType::kCannotApply);
Expand Down

0 comments on commit 0240813

Please sign in to comment.