Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

feat(compile): add nvcc compiler #1441

Open
wants to merge 14 commits into
base: develop
Choose a base branch
from
6 changes: 2 additions & 4 deletions cinn/backends/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "cinn/backends/nvrtc/nvrtc_util.h"
#include "cinn/runtime/cuda/cuda_module.h"
#include "cinn/runtime/cuda/cuda_util.h"
#include "cinn/runtime/flags.h"
#endif

DECLARE_string(cinn_source_code_save_path);
Expand Down Expand Up @@ -123,16 +124,13 @@ void Compiler::CompileCudaModule(const Module& module, const std::string& code)
SourceCodePrint::GetInstance()->write(source_code);
using runtime::cuda::CUDAModule;

backends::nvrtc::Compiler compiler;

nvrtc::Compiler compiler;
auto ptx = compiler(source_code);
CHECK(!ptx.empty()) << "Compile PTX failed from source code:\n" << source_code;

cuda_module_.reset(
new CUDAModule(ptx, compiler.compile_to_cubin() ? CUDAModule::Kind::CUBIN : CUDAModule::Kind::PTX));

RuntimeSymbols symbols;

for (auto& fn : device_module.functions()) {
std::string kernel_fn_name = fn->name;
auto fn_kernel = cuda_module_->GetFunction(0, kernel_fn_name);
Expand Down
94 changes: 94 additions & 0 deletions cinn/backends/nvrtc/nvrtc_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,30 @@
#include <cuda.h>
#include <cuda_runtime.h>
#include <nvrtc.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>

#include <fstream>
#include <iostream>

#include "cinn/backends/cuda_util.h"
#include "cinn/backends/nvrtc/header_generator.h"
#include "cinn/common/common.h"
#include "cinn/runtime/flags.h"
#include "cinn/utils/string.h"

DECLARE_string(cinn_nvcc_cmd_path);
DECLARE_bool(nvrtc_compile_to_cubin);

namespace cinn {
namespace backends {
namespace nvrtc {

std::string Compiler::operator()(const std::string& code, bool include_headers) {
if (runtime::CanUseNvccCompiler()) {
return CompileWithNvcc(code);
}
return CompileCudaSource(code, include_headers);
}

Expand Down Expand Up @@ -140,6 +151,89 @@ std::string Compiler::CompileCudaSource(const std::string& code, bool include_he
return data;
}

std::string Compiler::CompileWithNvcc(const std::string& cuda_c) {
// read dir source
std::string dir = "./source";
if (access(dir.c_str(), 0) == -1) {
CHECK(mkdir(dir.c_str(), 7) != -1) << "Fail to mkdir " << dir;
}

// get unqiue prefix name
prefix_name_ = dir + "/" + common::UniqName("rtc_tmp");

auto cuda_c_file = prefix_name_ + ".cu";
std::ofstream ofs(cuda_c_file, std::ios::out);
CHECK(ofs.is_open()) << "Fail to open file " << cuda_c_file;
ofs << cuda_c;
ofs.close();

CompileToPtx();
CompileToCubin();

return prefix_name_ + ".cubin";
}

// std::string Compiler::GetPtx() { return ReadFile(prefix_name_ + ".ptx", std::ios::in); }

void Compiler::CompileToPtx() {
auto include_dir = common::Context::Global().runtime_include_dir();
std::string include_dir_str = "";
for (auto dir : include_dir) {
if (include_dir_str.empty()) {
include_dir_str = dir;
} else {
include_dir_str += ":" + dir;
}
}

std::string options = std::string("export PATH=") + FLAGS_cinn_nvcc_cmd_path +
std::string(":$PATH && nvcc -std=c++14 --ptx -O3 -I ") + include_dir_str;
options += " -arch=" + GetDeviceArch();
options += " -o " + prefix_name_ + ".ptx";
options += " " + prefix_name_ + ".cu";

VLOG(2) << "Nvcc Compile Options : " << options;
CHECK(system(options.c_str()) == 0) << options;
}

void Compiler::CompileToCubin() {
std::string options =
std::string("export PATH=") + FLAGS_cinn_nvcc_cmd_path + std::string(":$PATH && nvcc --cubin -O3");
options += " -arch=" + GetDeviceArch();
options += " -o " + prefix_name_ + ".cubin";
options += " " + prefix_name_ + ".ptx";

VLOG(2) << "Nvcc Compile Options : " << options;
CHECK(system(options.c_str()) == 0) << options;
}

std::string Compiler::GetDeviceArch() {
int major = 0, minor = 0;
if (cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, 0) == cudaSuccess &&
cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, 0) == cudaSuccess) {
return "sm_" + std::to_string(major) + std::to_string(minor);
} else {
LOG(WARNING) << "cannot detect compute capability from your device, "
<< "fall back to compute_30.";
return "sm_30";
}
}

std::string Compiler::ReadFile(const std::string& file_name, std::ios_base::openmode mode) {
// open cubin file
std::ifstream ifs(file_name, mode);
CHECK(ifs.is_open()) << "Fail to open file " << file_name;
ifs.seekg(std::ios::end);
auto len = ifs.tellg();
ifs.seekg(0);

// read cubin file
std::string file_data(len, ' ');
ifs.read(&file_data[0], len);
ifs.close();
return std::move(file_data);
}

} // namespace nvrtc
} // namespace backends
} // namespace cinn
13 changes: 13 additions & 0 deletions cinn/backends/nvrtc/nvrtc_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,19 @@ class Compiler {
* whether to compile the source code into cubin, only works with cuda version > 11.1
*/
bool compile_to_cubin_{false};

// compile with nvcc
std::string CompileWithNvcc(const std::string&);

// compile to ptx
void CompileToPtx();
// compile to cubin
void CompileToCubin();
std::string GetDeviceArch();

std::string ReadFile(const std::string&, std::ios_base::openmode);

std::string prefix_name_{""};
};

} // namespace nvrtc
Expand Down
4 changes: 2 additions & 2 deletions cinn/hlir/framework/parallel_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "cinn/common/context.h"
#include "cinn/hlir/framework/pass.h"
#include "cinn/ir/module.h"
#include "cinn/runtime/flags.h"

DECLARE_int32(cinn_parallel_compile_size);
DECLARE_int32(cinn_parallel_compile_thread);
Expand Down Expand Up @@ -178,10 +179,9 @@ void ParallelCompiler::Task::CodegenAndJit() {
backends::nvrtc::Compiler compiler;
auto ptx = compiler(cuda_c);
CHECK(!ptx.empty()) << "Compile PTX failed from source code:\n" << cuda_c;
graph->SavePTXCode(ptx);

// load cumodule
cumodule.reset(new CUDAModule(ptx, compiler.compile_to_cubin() ? CUDAModule::Kind::CUBIN : CUDAModule::Kind::PTX));

// register kernel
backends::RuntimeSymbols symbols;
for (auto& fn : dmodule.functions()) {
Expand Down
24 changes: 12 additions & 12 deletions cinn/runtime/cuda/cuda_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

#include "cinn/backends/cuda_util.h"
#include "cinn/runtime/cuda/cuda_util.h"
#include "cinn/runtime/flags.h"
#include "cinn/utils/profiler.h"

namespace cinn {
Expand Down Expand Up @@ -106,16 +107,11 @@ CUfunction CUDAModule::GetFunction(int device_id, const std::string& func_name)
jit_options[4] = CU_JIT_GENERATE_LINE_INFO;
jit_opt_vals[4] = reinterpret_cast<void*>(value);

CUresult status = cuModuleLoadDataEx(
&module_per_card_[device_id], data_.c_str(), jit_num_options, jit_options.data(), jit_opt_vals.data());

if (CUDA_SUCCESS != status) {
RAW_LOG(ERROR, "PTX JIT ERROR LOG: %s\n.", log_buffer.data());
const char* name;
cuGetErrorName(status, &name);
const char* msg;
cuGetErrorString(status, &msg);
RAW_LOG(FATAL, "The error `%s` occurs while compiling the ptx! And its message is `%s`.", name, msg);
if (runtime::CanUseNvccCompiler()) {
CUDA_DRIVER_CALL(cuModuleLoad(&module_per_card_[device_id], data_.c_str()));
} else {
CUDA_DRIVER_CALL(cuModuleLoadDataEx(
&module_per_card_[device_id], data_.c_str(), jit_num_options, jit_options.data(), jit_opt_vals.data()));
}
}

Expand All @@ -127,11 +123,15 @@ CUfunction CUDAModule::GetFunction(int device_id, const std::string& func_name)
CUdeviceptr CUDAModule::GetGlobal(int device_id, const std::string& name, size_t nbytes) {
if (!module_per_card_[device_id]) {
std::lock_guard<std::mutex> lock(mutex_);
CUDA_DRIVER_CALL(cuModuleLoadData(&module_per_card_[device_id], data_.c_str()));
if (runtime::CanUseNvccCompiler()) {
CUDA_DRIVER_CALL(cuModuleLoad(&module_per_card_[device_id], data_.c_str()));
} else {
CUDA_DRIVER_CALL(cuModuleLoadData(&module_per_card_[device_id], data_.c_str()));
}
}

CUdeviceptr global;
size_t _nbytes;
CUdeviceptr global;
CUDA_DRIVER_CALL(cuModuleGetGlobal(&global, &_nbytes, module_per_card_[device_id], name.c_str()));
return global;
}
Expand Down
17 changes: 16 additions & 1 deletion cinn/runtime/flags.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@

#include <gflags/gflags.h>
#include <glog/logging.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>

#include <unordered_set>

Expand All @@ -35,6 +38,9 @@ using ::GFLAGS_NAMESPACE::Int64FromEnv;
using ::GFLAGS_NAMESPACE::StringFromEnv;

DEFINE_string(cinn_x86_builtin_code_root, StringFromEnv("FLAGS_cinn_x86_builtin_code_root", ""), "");
DEFINE_string(cinn_nvcc_cmd_path,
StringFromEnv("FLAGS_cinn_nvcc_cmd_path", "/usr/local/cuda/bin"),
"Setting nvcc default path!");

DEFINE_int32(cinn_parallel_compile_size,
Int32FromEnv("FLAGS_cinn_parallel_compile_size", 16),
Expand Down Expand Up @@ -82,9 +88,13 @@ DEFINE_bool(cinn_use_dense_merge_pass,
"Whether use dense merge pass.");

DEFINE_bool(nvrtc_compile_to_cubin,
BoolFromEnv("FLAGS_nvrtc_compile_to_cubin", false),
BoolFromEnv("FLAGS_nvrtc_compile_to_cubin", true),
"Whether nvrtc compile cuda source into cubin instead of ptx (only works after cuda-11.1).");

DEFINE_bool(cinn_compile_with_nvrtc,
BoolFromEnv("FLAGS_cinn_compile_with_nvrtc", true),
"Whether nvrtc compile cuda source with nvrtc(default nvcc).");

// FLAGS for performance analysis and accuracy debug
DEFINE_bool(cinn_sync_run,
BoolFromEnv("FLAGS_cinn_sync_run", false),
Expand Down Expand Up @@ -180,6 +190,11 @@ unsigned long long RandomSeed::Clear() {
return old_seed;
}

bool CanUseNvccCompiler() {
std::string nvcc_dir = FLAGS_cinn_nvcc_cmd_path + "/nvcc";
return (access(nvcc_dir.c_str(), 0) == -1 ? false : true) && (!FLAGS_cinn_compile_with_nvrtc);
}

bool IsCompiledWithCUDA() {
#if !defined(CINN_WITH_CUDA)
return false;
Expand Down
2 changes: 2 additions & 0 deletions cinn/runtime/flags.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ bool CheckStringFlagFalse(const std::string &flag);
void SetCinnCudnnDeterministic(bool state);
bool GetCinnCudnnDeterministic();

bool CanUseNvccCompiler();

class RandomSeed {
public:
static unsigned long long GetOrSet(unsigned long long seed = 0);
Expand Down