This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

A100 Speed Benchmark Temporary PR #1422

Open: wants to merge 17 commits into base: develop
25 changes: 23 additions & 2 deletions cinn/common/target.cc
@@ -11,13 +11,16 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "cinn/common/target.h"
#ifdef CINN_WITH_CUDA
#include <cuda_runtime_api.h>
#include <driver_types.h>
#endif

#include <glog/logging.h>

#include <sstream>

#include "cinn/common/target.h"
#include "cinn/runtime/cinn_runtime.h"

namespace cinn {
@@ -49,6 +52,24 @@ int Target::max_num_threads() const {
return 1024;
}

int Target::get_multi_processor_count() const {
CHECK(arch == Arch::NVGPU) << "The target is not NVGPU! Cannot get multi processor count";
int num_sm = 0;
#ifdef CINN_WITH_CUDA
cudaDeviceGetAttribute(&num_sm, cudaDeviceAttr::cudaDevAttrMultiProcessorCount, 0);
#endif
return num_sm;
}

int Target::get_max_threads_per_sm() const {
CHECK(arch == Arch::NVGPU) << "The target is not NVGPU! Cannot get max threads per stream processor";
int max_thread = 0;
#ifdef CINN_WITH_CUDA
cudaDeviceGetAttribute(&max_thread, cudaDeviceAttr::cudaDevAttrMaxThreadsPerMultiProcessor, 0);
#endif
return max_thread;
}

std::vector<Target::Lib> Target::get_target_libs() const { return libs; }

int Target::get_target_bits() const {
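For context, the two new getters are thin wrappers over the CUDA runtime's device-attribute query. A minimal standalone sketch of the same lookups (assuming device 0 and a CUDA-enabled build; illustrative only, not part of this diff):

#include <cuda_runtime_api.h>

#include <cstdio>

int main() {
  int num_sm              = 0;
  int max_threads_per_sm  = 0;
  // Number of streaming multiprocessors on device 0.
  cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, 0);
  // Maximum number of resident threads per multiprocessor on device 0.
  cudaDeviceGetAttribute(&max_threads_per_sm, cudaDevAttrMaxThreadsPerMultiProcessor, 0);
  std::printf("SMs: %d, max threads per SM: %d\n", num_sm, max_threads_per_sm);
  return 0;
}

On a standard A100 this would report 108 SMs and 2048 threads per SM, the two numbers the warp/block reduce heuristic further below is built on.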
4 changes: 4 additions & 0 deletions cinn/common/target.h
@@ -80,6 +80,10 @@ struct Target {

int max_num_threads() const;

int get_multi_processor_count() const;

int get_max_threads_per_sm() const;

int get_target_bits() const;

std::vector<Lib> get_target_libs() const;
9 changes: 8 additions & 1 deletion cinn/frontend/net_builder.cc
@@ -117,7 +117,14 @@ Variable NetBuilder::Reduce(const std::string& op_type, const Variable& x, const
return Reshape(x, new_shape);
}
}
return CustomInstr(op_type, {x}, {{"dim", dim}, {"keep_dim", keep_dim}}).front();
// Convert the negative dim to a positive number
std::vector<int> reduce_dim(dim.begin(), dim.end());
for (int i = 0; i < dim.size(); i++) {
if (reduce_dim[i] < 0) {
reduce_dim[i] = x->shape.size() + reduce_dim[i];
}
}
return CustomInstr(op_type, {x}, {{"dim", reduce_dim}, {"keep_dim", keep_dim}}).front();
}

#define NETBUILDER_UNARY_OP_DEF(func_name__, op_type__) \
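The conversion above makes negative axes count from the end of the input shape, so the lowered instruction always sees non-negative dims. A self-contained sketch of the same normalization (hypothetical free-standing helper, not part of this diff):

#include <vector>

// Map negative reduction axes to their non-negative equivalents.
// For a rank-3 input, {-1} becomes {2} and {-2} becomes {1}.
std::vector<int> NormalizeDims(const std::vector<int>& dim, int rank) {
  std::vector<int> out(dim);
  for (auto& d : out) {
    if (d < 0) d += rank;
  }
  return out;
}

With this in place, ReduceSum(x, {-1}) on a {2, 3, 4} input is emitted with dim = {2}, matching what ReduceSum(x, {2}) would produce.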
5 changes: 3 additions & 2 deletions cinn/hlir/framework/op_lowering.cc
@@ -1226,8 +1226,9 @@ void OpLowerer::IRSchedule(ir::IRSchedule& ir_sch,
}
}

auto masters = GetMasters(node, nodes_inline, nodes_set);
// node can be inline.
if (CanbeInline(node, consumers, reducer, nodes_in_order.front(), group, nodes_set, this->shape_dict_)) {
if (CanbeInline(node, consumers, reducer, masters, group, nodes_set, this->shape_dict_)) {
auto block = ir_sch.GetBlock(GetNodeData(node)->id());
ir::ComputeInlineChecker checker(ir_sch, block);
if (!checker.Check()) {
@@ -1327,7 +1328,7 @@ void OpLowerer::IRSchedule(ir::IRSchedule& ir_sch,
}

VLOG(3) << "Before Sync IRLowerOp schedule, ir is: \n" << ir_sch.GetModule().GetExprs().at(0);
SyncThreadWithShared(ir_sch, nodes_inline, nodes_set, this->shape_dict_, tensor_map);
SyncThreadWithShared(ir_sch, nodes_inline, nodes_set, this->shape_dict_, tensor_map, group);
VLOG(4) << "After IRSchedule, ir is: \n" << ir_sch.GetModule().GetExprs().at(0);
}

69 changes: 69 additions & 0 deletions cinn/hlir/framework/op_lowering_test.cc
@@ -1171,6 +1171,75 @@ TEST(OP_LOWERING, Reduce_Fusion_Test_21) {
Compile(net_builder);
}

TEST(OpFusionPass, Block_Reduce_Fuse_Broadcast) {
int sm_count = common::DefaultNVGPUTarget().get_multi_processor_count();
int max_threads_per_sm = common::DefaultNVGPUTarget().get_max_threads_per_sm();
int warp_reduce_threshold = sm_count * max_threads_per_sm / 32;
int h = warp_reduce_threshold - 10;
int w = 256;
NetBuilder net_builder("Block_Reduce_Fuse_Broadcast");
// create model
{
auto A = net_builder.CreateInput(Float(32), {h, w}, "A");
auto B = net_builder.ReduceSum(A, {1}, true);
auto C = net_builder.BroadcastTo(B, {h, w}, {0, 1});
}

Compile(net_builder);
}

TEST(OpFusionPass, Block_Reduce_Fuse_Elementwise) {
int sm_count = common::DefaultNVGPUTarget().get_multi_processor_count();
int max_threads_per_sm = common::DefaultNVGPUTarget().get_max_threads_per_sm();
int warp_reduce_threshold = sm_count * max_threads_per_sm / 32;
int h = warp_reduce_threshold - 10;
int w = 256;
NetBuilder net_builder("Block_Reduce_Fuse_Elementwise");
// create model
{
auto A = net_builder.CreateInput(Float(32), {h, w}, "A");
auto B = net_builder.CreateInput(Float(32), {h}, "B");
auto C = net_builder.ReduceSum(A, {1}, true);
auto D = net_builder.Add(B, C);
}

Compile(net_builder);
}

TEST(OpFusionPass, Warp_Reduce_Fuse_Broadcast) {
int sm_count = common::DefaultNVGPUTarget().get_multi_processor_count();
int max_threads_per_sm = common::DefaultNVGPUTarget().get_max_threads_per_sm();
int warp_reduce_threshold = sm_count * max_threads_per_sm / 32;
int h = warp_reduce_threshold + 10;
int w = 256;
NetBuilder net_builder("Warp_Reduce_Fuse_Broadcast");
// create model
{
auto A = net_builder.CreateInput(Float(32), {h, w}, "A");
auto B = net_builder.ReduceSum(A, {1}, true);
auto C = net_builder.BroadcastTo(B, {h, w}, {0, 1});
}

Compile(net_builder);
}

TEST(OpFusionPass, Warp_Reduce_Fuse_Elementwise) {
int sm_count = common::DefaultNVGPUTarget().get_multi_processor_count();
int max_threads_per_sm = common::DefaultNVGPUTarget().get_max_threads_per_sm();
int warp_reduce_threshold = sm_count * max_threads_per_sm / 32;
int h = warp_reduce_threshold + 10;
int w = 256;
NetBuilder net_builder("Warp_Reduce_Fuse_Elementwise");
// create model
{
auto A = net_builder.CreateInput(Float(32), {h, w}, "A");
auto B = net_builder.CreateInput(Float(32), {h}, "B");
auto C = net_builder.ReduceSum(A, {1}, true);
auto D = net_builder.Add(B, C);
}

Compile(net_builder);
}
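All four tests above straddle the same cutoff, warp_reduce_threshold = sm_count * max_threads_per_sm / 32, i.e. the number of reduced rows at which giving every row its own 32-thread warp exactly fills the device. As an illustrative calculation (assuming a standard A100 with 108 SMs and 2048 max threads per SM; not a measured result):

  warp_reduce_threshold = 108 * 2048 / 32 = 6912
  h = 6912 - 10 = 6902  -> below the threshold, used by the Block Reduce tests
  h = 6912 + 10 = 6922  -> above the threshold, used by the Warp Reduce tests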

} // namespace framework
} // namespace hlir
} // namespace cinn
95 changes: 59 additions & 36 deletions cinn/hlir/framework/op_lowering_util.cc
@@ -168,11 +168,10 @@ bool IsConstOp(const framework::Node* node) {
}

std::vector<int> GetInputShape(const Node* node, const absl::flat_hash_map<std::string, shape_t>& shape_dict) {
auto producers = GetProducers(node);
CHECK(producers.size());
auto input_data = GetInputNodeData(node);
CHECK(input_data.size());

auto producer_data = GetNodeData(producers.front());
return shape_dict.at(producer_data->id());
return shape_dict.at(input_data.front()->id());
}

std::vector<int> GetOutputShape(const Node* node, const absl::flat_hash_map<std::string, shape_t>& shape_dict) {
@@ -577,10 +576,25 @@ void LoopAssignReduceWithLast(ir::IRSchedule& ir_sch,
const std::vector<int>& inshape,
const std::vector<int>& axes,
const common::Target& target) {
// If the number of SMs on the current device is smaller than the number of SMs
// required by Warp Reduce, Warp Reduce performs better; otherwise, use Block Reduce.
auto max_num_threads = common::DefaultNVGPUTarget().max_num_threads();
int need_reduce_last_count = 1;
for (int i = 0; i < inshape.size(); i++) {
if (find(axes.begin(), axes.end(), i) == axes.end()) {
need_reduce_last_count *= inshape[i];
}
}
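// Number of SMs needed if every remaining (non-reduced) row is handled by one 32-thread warp.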
int warp_reduce_need_sm_count = ceil((need_reduce_last_count * 32) / float(target.get_max_threads_per_sm()));
// Setting max_num_threads to 32 selects Warp Reduce.
if (target.get_multi_processor_count() < warp_reduce_need_sm_count) {
max_num_threads = 32;
}
// find first reduce and second reduce axis.
int lane = 1;
int index = static_cast<int>(axes.size()) - 1;
auto max_num_threads = target.max_num_threads();
int lane = 1;
int index = static_cast<int>(axes.size()) - 1;

for (; index >= 0; --index) {
if (index + 1 < axes.size() && axes[index] != axes[index + 1] - 1) {
break;
@@ -639,7 +653,7 @@ void LoopAssignReduceWithLast(ir::IRSchedule& ir_sch,
bool CanbeInline(Node* node,
const std::vector<Node*> consumers,
const Node* reducer,
const Node* laster,
const std::unordered_set<Node*> masters,
const GroupPtr& group,
const std::unordered_set<Node*>& nodes_set,
const absl::flat_hash_map<std::string, shape_t>& shape_dict) {
@@ -681,10 +695,14 @@ bool CanbeInline(Node* node,
return false;
} else {
auto node_shape = GetOutputShape(node, shape_dict);
auto last_shape = GetOutputShape(laster, shape_dict);
if (std::accumulate(node_shape.begin(), node_shape.end(), 1, std::multiplies<int>()) !=
std::accumulate(last_shape.begin(), last_shape.end(), 1, std::multiplies<int>())) {
return true;
auto node_size = std::accumulate(node_shape.begin(), node_shape.end(), 1, std::multiplies<int>());

for (auto master : masters) {
auto master_shape = GetOutputShape(master, shape_dict);
auto master_size = std::accumulate(master_shape.begin(), master_shape.end(), 1, std::multiplies<int>());
if (node_size != master_size) {
return true;
}
}

return false;
@@ -1316,7 +1334,7 @@ void LoopComputeAt(ir::IRSchedule& ir_sch,
auto& op_pattern_dict = Operator::GetAttrs<OpPatternKind>("OpPattern");
if (!group->output_nodes.count(node)) {
auto block = ir_sch.GetBlock(GetNodeData(node)->id());
ir_sch.SetBuffer(block, "local", true);
ir_sch.SetBuffer(block, "local");
}

if (op_pattern_dict[node->op()] == framework::kReduction) {
@@ -1373,11 +1391,14 @@ std::unordered_map<std::string, NodeData*> GetNodeDataSet(const std::unordered_s
return node_data_set;
}

Node* GetMaster(Node* node, const std::unordered_set<Node*>& nodes_inline, const std::unordered_set<Node*>& nodes_set) {
std::unordered_set<Node*> GetMasters(Node* node,
const std::unordered_set<Node*>& nodes_inline,
const std::unordered_set<Node*>& nodes_set) {
// find consumer
std::unordered_set<Node*> visited;
std::queue<Node*> candidates;
candidates.push(node);
std::unordered_set<Node*> masters;

while (!candidates.empty()) {
auto candidate = candidates.front();
@@ -1392,19 +1413,20 @@ Node* GetMaster(Node* node, const std::unordered_set<Node*>& nodes_inline, const
candidates.push(consumer);
visited.insert(consumer);
} else {
return consumer;
masters.insert(consumer);
}
}
}

return nullptr;
return masters;
}

void SyncThreadWithShared(ir::IRSchedule& ir_sch,
const std::unordered_set<Node*>& nodes_inline,
const std::unordered_set<Node*>& nodes_set,
const absl::flat_hash_map<std::string, shape_t>& shape_dict,
const std::unordered_map<std::string, ir::Tensor>& tensor_map) {
const std::unordered_map<std::string, ir::Tensor>& tensor_map,
const GroupPtr& group) {
auto exprs_inorder = ir_sch.GetAllBlocks();
auto node_data_set = GetNodeDataSet(nodes_set);
auto& op_pattern_dict = Operator::GetAttrs<OpPatternKind>("OpPattern");
@@ -1441,34 +1463,35 @@ void SyncThreadWithShared(ir::IRSchedule& ir_sch,
auto node = node_data->source_node.get();
auto node_shape = shape_dict.at(node_data->id());

auto master = GetMaster(node, nodes_inline, nodes_set);
if (!master) {
auto masters = GetMasters(node, nodes_inline, nodes_set);
if (masters.empty()) {
continue;
}

auto master_data = GetNodeData(master);
auto master_shape = shape_dict.at(master_data->id());
if (op_pattern_dict[master->op()] == framework::kReduction) {
master_shape = shape_dict.at(master->inlinks_in_order()[0]->source()->id());
}
bool do_set_buffer_to_shared = false;
for (auto master : masters) {
auto master_data = GetNodeData(master);
auto master_shape = shape_dict.at(master_data->id());
if (op_pattern_dict[master->op()] == framework::kReduction) {
master_shape = shape_dict.at(master->inlinks_in_order()[0]->source()->id());
}

auto node_size = std::accumulate(node_shape.begin(), node_shape.end(), 1, std::multiplies<int>());
auto master_size = std::accumulate(master_shape.begin(), master_shape.end(), 1, std::multiplies<int>());
auto node_size = std::accumulate(node_shape.begin(), node_shape.end(), 1, std::multiplies<int>());
auto master_size = std::accumulate(master_shape.begin(), master_shape.end(), 1, std::multiplies<int>());

if (node_size == master_size) {
continue;
if (node_size != master_size) {
if (check_sync_mark(idx, master_data->id())) {
auto loops = ir_sch.GetLoops(master_data->id());
ir_sch.SyncThreads(loops.back(), false);
sync_mark.insert(master_data->id());
}
do_set_buffer_to_shared = true;
}
}

{
if (do_set_buffer_to_shared && group->output_nodes.find(node) == group->output_nodes.end()) {
auto block = ir_sch.GetBlock(node_data->id());
ir_sch.SetBuffer(block, "shared", true);
}

if (check_sync_mark(idx, master_data->id())) {
auto loops = ir_sch.GetLoops(master_data->id());
ir_sch.SyncThreads(loops.back(), false);
sync_mark.insert(master_data->id());
}
}
}

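Taken together, the GetMaster to GetMasters change above swaps "first non-inlined consumer reached" for "all non-inlined consumers reachable through inlined nodes", and CanbeInline/SyncThreadWithShared now check every master instead of a single one. A simplified, self-contained illustration of that breadth-first collection over a toy consumer graph (hypothetical string-keyed types, not the real Node/graph API):

#include <queue>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

// node id -> ids of its consumers
using ConsumerMap = std::unordered_map<std::string, std::vector<std::string>>;

std::unordered_set<std::string> GetMastersSketch(const std::string& node,
                                                 const ConsumerMap& consumers,
                                                 const std::unordered_set<std::string>& inlined) {
  std::unordered_set<std::string> visited{node};
  std::unordered_set<std::string> masters;
  std::queue<std::string> candidates;
  candidates.push(node);
  while (!candidates.empty()) {
    auto cur = candidates.front();
    candidates.pop();
    auto it = consumers.find(cur);
    if (it == consumers.end()) continue;
    for (const auto& consumer : it->second) {
      if (visited.count(consumer)) continue;
      visited.insert(consumer);
      if (inlined.count(consumer)) {
        candidates.push(consumer);  // keep walking through inlined consumers
      } else {
        masters.insert(consumer);   // collect every non-inlined consumer, not just the first
      }
    }
  }
  return masters;
}

For a chain a -> b (inlined) -> {c, d}, this sketch returns {c, d}, whereas the old single-master version stopped at whichever of c or d it reached first.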
9 changes: 7 additions & 2 deletions cinn/hlir/framework/op_lowering_util.h
@@ -60,7 +60,7 @@ Node* FindNearestReducer(const Node* node, const std::unordered_set<Node*>& node
bool CanbeInline(Node* node,
const std::vector<Node*> consumers,
const Node* reducer,
const Node* laster,
const std::unordered_set<Node*> masters,
const GroupPtr& group,
const std::unordered_set<Node*>& nodes_set,
const absl::flat_hash_map<std::string, shape_t>& shape_dict);
@@ -72,6 +72,10 @@ Node* GetMasterToComputeAt(Node* node,
const std::unordered_map<Node*, Node*>& virtual_consumers,
const absl::flat_hash_map<std::string, shape_t>& shape_dict);

std::unordered_set<Node*> GetMasters(Node* node,
const std::unordered_set<Node*>& nodes_inline,
const std::unordered_set<Node*>& nodes_set);

void LoopAssignReduce(ir::IRSchedule& ir_sch,
const Node* node,
const Node* reducer,
@@ -90,7 +94,8 @@ void SyncThreadWithShared(ir::IRSchedule& ir_sch,
const std::unordered_set<Node*>& nodes_inline,
const std::unordered_set<Node*>& nodes_set,
const absl::flat_hash_map<std::string, shape_t>& shape_dict,
const std::unordered_map<std::string, ir::Tensor>& tensor_map);
const std::unordered_map<std::string, ir::Tensor>& tensor_map,
const GroupPtr& group);

} // namespace framework
} // namespace hlir