Skip to content

Commit

Permalink
Perf: split code_gen.cpp to reduce compilation time (#4210)
Browse files Browse the repository at this point in the history
* split code_gen.cpp to reduce compilation time

* alter the file location of code_gen_*.cu

* fix an error in CMakeLists.txt
  • Loading branch information
dzzz2001 authored May 25, 2024
1 parent f73a83d commit 25584b0
Show file tree
Hide file tree
Showing 20 changed files with 1,647 additions and 719 deletions.
12 changes: 11 additions & 1 deletion source/module_hamilt_lcao/module_gint/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,23 @@ list(APPEND objects
if(USE_CUDA)
list(APPEND objects
kernels/cuda/cuda_tools.cu
kernels/cuda/vbatch_matrix_mul.cu
kernels/cuda/gint_vl.cu
kernels/cuda/gint_rho.cu
kernels/cuda/gint_force.cu
gint_vl_gpu.cu
gint_rho_gpu.cu
gint_force_gpu.cu
kernels/cuda/gemm_selector.cu
kernels/cuda/code_gen_00.cu
kernels/cuda/code_gen_01.cu
kernels/cuda/code_gen_02.cu
kernels/cuda/code_gen_03.cu
kernels/cuda/code_gen_04.cu
kernels/cuda/code_gen_05.cu
kernels/cuda/code_gen_06.cu
kernels/cuda/code_gen_07.cu
kernels/cuda/code_gen_08.cu
kernels/cuda/code_gen_09.cu
gtask_vl.cpp
gtask_rho.cpp
gtask_force.cpp
Expand Down
1 change: 0 additions & 1 deletion source/module_hamilt_lcao/module_gint/gint_rho_gpu.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
#include "kernels/cuda/cuda_tools.cuh"
#include "kernels/cuda/vbatch_matrix_mul.cuh"
#include "module_base/ylm.h"
#include "module_hamilt_lcao/module_gint/gint_rho.h"
#include "module_hamilt_lcao/module_gint/gint_tools.h"
Expand Down
1 change: 0 additions & 1 deletion source/module_hamilt_lcao/module_gint/gint_vl_gpu.cu
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#include <omp.h>

#include "kernels/cuda/cuda_tools.cuh"
#include "kernels/cuda/vbatch_matrix_mul.cuh"
#include "module_base/ylm.h"
#include "module_hamilt_lcao/module_gint/gint_tools.h"
#include "module_hamilt_lcao/module_gint/gint_vl.h"
Expand Down
2 changes: 1 addition & 1 deletion source/module_hamilt_lcao/module_gint/grid_technique.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
#include <cuda_runtime.h>

#include "kernels/cuda/cuda_tools.cuh"
#include "kernels/cuda/vbatch_matrix_mul.cuh"
#include "kernels/cuda/gemm_selector.cuh"
#endif

// Author: mohan
Expand Down
22 changes: 0 additions & 22 deletions source/module_hamilt_lcao/module_gint/kernels/cuda/code_gen.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
// Generate and test the efficiency of matrix multiplication functions with different parameters
// This file takes a long time to compile

gemm_time_measure<double, 2, 16, 16, 32, 2, 2, 16, 2, 16>(max_m,
max_n,
d_m,
Expand Down Expand Up @@ -4181,25 +4178,6 @@ gemm_time_measure<double, 16, 32, 64, 64, 32, 16, 32, 16, 32>(max_m,
h_global_C,
d_global_C);

// Invokes gemm_time_measure for the <double, 16, 32, 64, 64, 32, 16, 32, 16, 32>
// configuration, forwarding the shared problem description, stream, and
// result buffers.
// fastest_time / fastest_algo are handed over so the callee can report back
// through them — presumably they are updated only when this configuration
// beats the current best; verify against gemm_time_measure.
// NOTE(review): the call immediately preceding this one appears to use the
// same template arguments, which would make this a redundant repeat of that
// measurement — confirm and drop one of the two if so.
gemm_time_measure<double, 16, 32, 64, 64, 32, 16, 32, 16, 32>(max_m,
max_n,
d_m,
d_n,
d_k,
d_global_A_array,
d_global_lda,
d_global_B_array,
d_global_ldb,
d_global_C_array,
d_global_ldc,
batchCount,
temp_stream,
fastest_time,
fastest_algo,
cpu_result,
h_global_C,
d_global_C);

gemm_time_measure<double, 20, 8, 40, 24, 20, 20, 8, 20, 8>(max_m,
max_n,
d_m,
Expand Down
473 changes: 473 additions & 0 deletions source/module_hamilt_lcao/module_gint/kernels/cuda/code_gen.cuh

Large diffs are not rendered by default.

48 changes: 48 additions & 0 deletions source/module_hamilt_lcao/module_gint/kernels/cuda/code_gen_00.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
#include "vbatch_matrix_mul.cuh"

// Explicit instantiations of gemm_time_measure<double, ...> emitted by this
// translation unit.
//
// Every instantiation below has the same function parameter list, so it is
// written once in the helper macro; each use supplies only the nine integer
// template arguments that follow `double`. The macro is local to this file
// and is undefined again after the last instantiation.
#define INSTANTIATE_GEMM_TIME_MEASURE_DOUBLE(...)                 \
    template void gemm_time_measure<double, __VA_ARGS__>(         \
        int, int, int*, int*, int*,                               \
        double**, int*, double**, int*, double**, int*,           \
        int, cudaStream_t, float&, matrix_multiple_func_type&,    \
        double*, double*, double*)

INSTANTIATE_GEMM_TIME_MEASURE_DOUBLE(2, 16, 16, 32, 2, 2, 16, 2, 16);
INSTANTIATE_GEMM_TIME_MEASURE_DOUBLE(2, 16, 16, 32, 4, 2, 16, 2, 16);
INSTANTIATE_GEMM_TIME_MEASURE_DOUBLE(2, 16, 16, 32, 6, 2, 16, 2, 16);
INSTANTIATE_GEMM_TIME_MEASURE_DOUBLE(2, 16, 16, 32, 8, 2, 16, 2, 16);
INSTANTIATE_GEMM_TIME_MEASURE_DOUBLE(2, 16, 16, 48, 2, 2, 16, 2, 16);
INSTANTIATE_GEMM_TIME_MEASURE_DOUBLE(2, 16, 16, 48, 4, 2, 16, 2, 16);
INSTANTIATE_GEMM_TIME_MEASURE_DOUBLE(2, 16, 16, 48, 6, 2, 16, 2, 16);
INSTANTIATE_GEMM_TIME_MEASURE_DOUBLE(4, 8, 8, 24, 4, 4, 8, 4, 8);
INSTANTIATE_GEMM_TIME_MEASURE_DOUBLE(4, 8, 8, 24, 8, 4, 8, 4, 8);
INSTANTIATE_GEMM_TIME_MEASURE_DOUBLE(4, 8, 8, 24, 12, 4, 8, 4, 8);
INSTANTIATE_GEMM_TIME_MEASURE_DOUBLE(4, 8, 8, 32, 4, 4, 8, 4, 8);
INSTANTIATE_GEMM_TIME_MEASURE_DOUBLE(4, 8, 8, 32, 8, 4, 8, 4, 8);
INSTANTIATE_GEMM_TIME_MEASURE_DOUBLE(4, 8, 8, 40, 4, 4, 8, 4, 8);
INSTANTIATE_GEMM_TIME_MEASURE_DOUBLE(4, 8, 8, 40, 8, 4, 8, 4, 8);
INSTANTIATE_GEMM_TIME_MEASURE_DOUBLE(4, 8, 8, 48, 4, 4, 8, 4, 8);
INSTANTIATE_GEMM_TIME_MEASURE_DOUBLE(4, 8, 8, 56, 4, 4, 8, 4, 8);
INSTANTIATE_GEMM_TIME_MEASURE_DOUBLE(4, 8, 8, 64, 4, 4, 8, 4, 8);
INSTANTIATE_GEMM_TIME_MEASURE_DOUBLE(4, 8, 16, 16, 4, 4, 8, 4, 8);
INSTANTIATE_GEMM_TIME_MEASURE_DOUBLE(4, 8, 16, 16, 8, 4, 8, 4, 8);
INSTANTIATE_GEMM_TIME_MEASURE_DOUBLE(4, 8, 16, 16, 12, 4, 8, 4, 8);
INSTANTIATE_GEMM_TIME_MEASURE_DOUBLE(4, 8, 16, 24, 4, 4, 8, 4, 8);
INSTANTIATE_GEMM_TIME_MEASURE_DOUBLE(4, 8, 16, 24, 8, 4, 8, 4, 8);
INSTANTIATE_GEMM_TIME_MEASURE_DOUBLE(4, 8, 16, 32, 4, 4, 8, 4, 8);

#undef INSTANTIATE_GEMM_TIME_MEASURE_DOUBLE

48 changes: 48 additions & 0 deletions source/module_hamilt_lcao/module_gint/kernels/cuda/code_gen_01.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
#include "vbatch_matrix_mul.cuh"

// Explicit instantiations of gemm_time_measure<double, ...> emitted by this
// translation unit.
//
// Every instantiation below has the same function parameter list, so it is
// written once in the helper macro; each use supplies only the nine integer
// template arguments that follow `double`. The macro is local to this file
// and is undefined again after the last instantiation.
#define INSTANTIATE_GEMM_TIME_MEASURE_DOUBLE(...)                 \
    template void gemm_time_measure<double, __VA_ARGS__>(         \
        int, int, int*, int*, int*,                               \
        double**, int*, double**, int*, double**, int*,           \
        int, cudaStream_t, float&, matrix_multiple_func_type&,    \
        double*, double*, double*)

INSTANTIATE_GEMM_TIME_MEASURE_DOUBLE(4, 8, 16, 32, 8, 4, 8, 4, 8);
INSTANTIATE_GEMM_TIME_MEASURE_DOUBLE(4, 8, 16, 40, 4, 4, 8, 4, 8);
INSTANTIATE_GEMM_TIME_MEASURE_DOUBLE(4, 8, 16, 48, 4, 4, 8, 4, 8);
INSTANTIATE_GEMM_TIME_MEASURE_DOUBLE(4, 8, 16, 56, 4, 4, 8, 4, 8);
INSTANTIATE_GEMM_TIME_MEASURE_DOUBLE(4, 8, 24, 16, 4, 4, 8, 4, 8);
INSTANTIATE_GEMM_TIME_MEASURE_DOUBLE(4, 8, 24, 16, 8, 4, 8, 4, 8);
INSTANTIATE_GEMM_TIME_MEASURE_DOUBLE(4, 8, 24, 24, 4, 4, 8, 4, 8);
INSTANTIATE_GEMM_TIME_MEASURE_DOUBLE(4, 8, 24, 24, 8, 4, 8, 4, 8);
INSTANTIATE_GEMM_TIME_MEASURE_DOUBLE(4, 8, 24, 32, 4, 4, 8, 4, 8);
INSTANTIATE_GEMM_TIME_MEASURE_DOUBLE(4, 8, 24, 40, 4, 4, 8, 4, 8);
INSTANTIATE_GEMM_TIME_MEASURE_DOUBLE(4, 8, 32, 16, 4, 4, 8, 4, 8);
INSTANTIATE_GEMM_TIME_MEASURE_DOUBLE(4, 8, 32, 16, 8, 4, 8, 4, 8);
INSTANTIATE_GEMM_TIME_MEASURE_DOUBLE(4, 8, 32, 24, 4, 4, 8, 4, 8);
INSTANTIATE_GEMM_TIME_MEASURE_DOUBLE(4, 8, 40, 16, 4, 4, 8, 4, 8);
INSTANTIATE_GEMM_TIME_MEASURE_DOUBLE(4, 8, 40, 24, 4, 4, 8, 4, 8);
INSTANTIATE_GEMM_TIME_MEASURE_DOUBLE(4, 8, 48, 16, 4, 4, 8, 4, 8);
INSTANTIATE_GEMM_TIME_MEASURE_DOUBLE(4, 8, 56, 16, 4, 4, 8, 4, 8);
INSTANTIATE_GEMM_TIME_MEASURE_DOUBLE(4, 16, 16, 32, 4, 4, 16, 4, 16);
INSTANTIATE_GEMM_TIME_MEASURE_DOUBLE(4, 16, 16, 32, 8, 4, 16, 4, 16);
INSTANTIATE_GEMM_TIME_MEASURE_DOUBLE(4, 16, 16, 32, 12, 4, 16, 4, 16);
INSTANTIATE_GEMM_TIME_MEASURE_DOUBLE(4, 16, 16, 32, 16, 4, 16, 4, 16);
INSTANTIATE_GEMM_TIME_MEASURE_DOUBLE(4, 16, 16, 48, 4, 4, 16, 4, 16);
INSTANTIATE_GEMM_TIME_MEASURE_DOUBLE(4, 16, 16, 48, 8, 4, 16, 4, 16);

#undef INSTANTIATE_GEMM_TIME_MEASURE_DOUBLE

48 changes: 48 additions & 0 deletions source/module_hamilt_lcao/module_gint/kernels/cuda/code_gen_02.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
#include "vbatch_matrix_mul.cuh"

// Explicit instantiations of gemm_time_measure<double, ...> emitted by this
// translation unit.
//
// Every instantiation below has the same function parameter list, so it is
// written once in the helper macro; each use supplies only the nine integer
// template arguments that follow `double`. The macro is local to this file
// and is undefined again after the last instantiation.
#define INSTANTIATE_GEMM_TIME_MEASURE_DOUBLE(...)                 \
    template void gemm_time_measure<double, __VA_ARGS__>(         \
        int, int, int*, int*, int*,                               \
        double**, int*, double**, int*, double**, int*,           \
        int, cudaStream_t, float&, matrix_multiple_func_type&,    \
        double*, double*, double*)

INSTANTIATE_GEMM_TIME_MEASURE_DOUBLE(4, 16, 16, 48, 12, 4, 16, 4, 16);
INSTANTIATE_GEMM_TIME_MEASURE_DOUBLE(4, 16, 16, 64, 4, 4, 16, 4, 16);
INSTANTIATE_GEMM_TIME_MEASURE_DOUBLE(4, 16, 16, 64, 8, 4, 16, 4, 16);
INSTANTIATE_GEMM_TIME_MEASURE_DOUBLE(4, 16, 32, 32, 4, 4, 16, 4, 16);
INSTANTIATE_GEMM_TIME_MEASURE_DOUBLE(4, 16, 32, 32, 8, 4, 16, 4, 16);
INSTANTIATE_GEMM_TIME_MEASURE_DOUBLE(4, 16, 32, 32, 12, 4, 16, 4, 16);
INSTANTIATE_GEMM_TIME_MEASURE_DOUBLE(4, 16, 32, 48, 4, 4, 16, 4, 16);
INSTANTIATE_GEMM_TIME_MEASURE_DOUBLE(4, 16, 32, 48, 8, 4, 16, 4, 16);
INSTANTIATE_GEMM_TIME_MEASURE_DOUBLE(4, 16, 48, 32, 4, 4, 16, 4, 16);
INSTANTIATE_GEMM_TIME_MEASURE_DOUBLE(4, 16, 48, 32, 8, 4, 16, 4, 16);
INSTANTIATE_GEMM_TIME_MEASURE_DOUBLE(4, 24, 24, 48, 4, 4, 24, 4, 24);
INSTANTIATE_GEMM_TIME_MEASURE_DOUBLE(4, 24, 24, 48, 8, 4, 24, 4, 24);
INSTANTIATE_GEMM_TIME_MEASURE_DOUBLE(4, 24, 24, 48, 12, 4, 24, 4, 24);
INSTANTIATE_GEMM_TIME_MEASURE_DOUBLE(4, 24, 48, 48, 4, 4, 24, 4, 24);
INSTANTIATE_GEMM_TIME_MEASURE_DOUBLE(4, 24, 48, 48, 8, 4, 24, 4, 24);
INSTANTIATE_GEMM_TIME_MEASURE_DOUBLE(4, 32, 32, 64, 4, 4, 32, 4, 32);
INSTANTIATE_GEMM_TIME_MEASURE_DOUBLE(4, 32, 32, 64, 8, 4, 32, 4, 32);
INSTANTIATE_GEMM_TIME_MEASURE_DOUBLE(4, 32, 32, 64, 12, 4, 32, 4, 32);
INSTANTIATE_GEMM_TIME_MEASURE_DOUBLE(4, 32, 32, 64, 16, 4, 32, 4, 32);
INSTANTIATE_GEMM_TIME_MEASURE_DOUBLE(6, 16, 48, 32, 6, 6, 16, 6, 16);
INSTANTIATE_GEMM_TIME_MEASURE_DOUBLE(6, 16, 48, 32, 12, 6, 16, 6, 16);
INSTANTIATE_GEMM_TIME_MEASURE_DOUBLE(6, 16, 48, 48, 6, 6, 16, 6, 16);
INSTANTIATE_GEMM_TIME_MEASURE_DOUBLE(8, 4, 16, 12, 8, 8, 4, 8, 4);

#undef INSTANTIATE_GEMM_TIME_MEASURE_DOUBLE

Loading

0 comments on commit 25584b0

Please sign in to comment.