diff --git a/docs/advanced/input_files/input-main.md b/docs/advanced/input_files/input-main.md index 6e58330990..9e87b52dcc 100644 --- a/docs/advanced/input_files/input-main.md +++ b/docs/advanced/input_files/input-main.md @@ -343,6 +343,8 @@ - [td\_heavi\_amp](#td_heavi_amp) - [out\_dipole](#out_dipole) - [out\_efield](#out_efield) + - [out\_vecpot](#out_vecpot) + - [init\_vecpot\_file](#init_vecpot_file) - [ocp](#ocp) - [ocp\_set](#ocp_set) - [Variables useful for debugging](#variables-useful-for-debugging) @@ -1836,6 +1838,14 @@ Warning: this function is not robust enough for the current version. Please try - **Note**: A trained, traced model file is needed. - **Default**: False +### deepks_equiv + +- **Type**: Boolean +- **Availability**: numerical atomic orbital basis +- **Description**: whether to use the equivariant version of DeePKS +- **Note**: the equivariant version of DeePKS-kit is still under development, so this feature is currently intended for internal use only. +- **Default**: False + ### deepks_model - **Type**: String @@ -3254,6 +3264,22 @@ These variables are used to control berry phase and wannier90 interface paramete - False: do not output efield. - **Default**: False +### out_vecpot + +- **Type**: Boolean +- **Description**: whether to output the TDDFT vector potential (in a.u.) + - True: output the vector potential to file "OUT.suffix/At.dat" + - False: do not output the vector potential. +- **Default**: False + +### init_vecpot_file + +- **Type**: Boolean +- **Description**: whether to initialize the vector potential from a file + - True: initialize the vector potential from file "At.dat" (in a.u.). The file consists of four columns: istep and the vector potential along each direction. 
+ - False: calculate vector potential by integral of Efield +- **Default**: False + ### ocp - **Type**: Boolean diff --git a/examples/interface_wannier90/ABACUS_towannier90_lcao/INPUT-nscf b/examples/interface_wannier90/ABACUS_towannier90_lcao/INPUT-nscf index b2c769166d..1dbecf0753 100644 --- a/examples/interface_wannier90/ABACUS_towannier90_lcao/INPUT-nscf +++ b/examples/interface_wannier90/ABACUS_towannier90_lcao/INPUT-nscf @@ -8,9 +8,11 @@ nbands 12 calculation nscf scf_nmax 50 pw_diag_thr 1.0e-12 -scf_thr 1.0e-15 +scf_thr 1.0e-13 init_chg file symmetry 0 towannier90 1 nnkpfile diamond.nnkp -basis_type lcao +basis_type lcao +wannier_method 2 +out_wannier_unk 0 diff --git a/examples/interface_wannier90/ABACUS_towannier90_lcao/INPUT-scf b/examples/interface_wannier90/ABACUS_towannier90_lcao/INPUT-scf index 9aefe7c367..89c777e981 100644 --- a/examples/interface_wannier90/ABACUS_towannier90_lcao/INPUT-scf +++ b/examples/interface_wannier90/ABACUS_towannier90_lcao/INPUT-scf @@ -7,4 +7,4 @@ ecutwfc 50 calculation scf scf_thr 1e-13 out_chg 1 -basis_type lcao +basis_type lcao diff --git a/examples/interface_wannier90/ABACUS_towannier90_lcao/diamond.win b/examples/interface_wannier90/ABACUS_towannier90_lcao/diamond.win index 675745083d..816324eb2e 100644 --- a/examples/interface_wannier90/ABACUS_towannier90_lcao/diamond.win +++ b/examples/interface_wannier90/ABACUS_towannier90_lcao/diamond.win @@ -1,15 +1,17 @@ num_bands = 12 num_wann = 8 -Begin Projections -C:s;px;py;pz -End Projections +dis_win_min = -7.2 +dis_win_max = 45 +dis_froz_min = -7.2 +dis_froz_max = 20 +dis_num_iter = 200 -dis_num_iter = 5000 -num_iter = 5000 -num_print_cycles = 50 +begin projections +C:s;px;py;pz +end projections -wannier_plot=.true. +wannier_plot=.false. wannier_plot_supercell = 3 wvfn_formatted = .true. 
@@ -26,6 +28,14 @@ begin unit_cell_cart -1.613990 1.613990 0.000000 end unit_cell_cart +bands_plot = true + +begin kpoint_path +G 0.0000000000 0.0000000000 0.0000000000 L 0.5000000000 0.5000000000 0.5000000000 +L 0.5000000000 0.5000000000 0.5000000000 W 0.5000000000 0.2500000000 0.7500000000 +W 0.5000000000 0.2500000000 0.7500000000 X 0.5000000000 0.0000000000 0.5000000000 +end kpoint_path + mp_grid : 4 4 4 begin kpoints diff --git a/examples/interface_wannier90/ABACUS_towannier90_lcao_in_pw/INPUT-nscf b/examples/interface_wannier90/ABACUS_towannier90_lcao_in_pw/INPUT-nscf index bd326e7944..f39b37ab96 100644 --- a/examples/interface_wannier90/ABACUS_towannier90_lcao_in_pw/INPUT-nscf +++ b/examples/interface_wannier90/ABACUS_towannier90_lcao_in_pw/INPUT-nscf @@ -8,11 +8,12 @@ nbands 12 calculation nscf scf_nmax 50 pw_diag_thr 1.0e-12 -scf_thr 1.0e-15 +scf_thr 1.0e-13 init_chg file symmetry 0 towannier90 1 -wannier_method 1 nnkpfile diamond.nnkp -basis_type lcao +basis_type lcao +wannier_method 1 +out_wannier_unk 0 diff --git a/examples/interface_wannier90/ABACUS_towannier90_lcao_in_pw/INPUT-scf b/examples/interface_wannier90/ABACUS_towannier90_lcao_in_pw/INPUT-scf index 91d87fa7e0..50eb2f10c0 100644 --- a/examples/interface_wannier90/ABACUS_towannier90_lcao_in_pw/INPUT-scf +++ b/examples/interface_wannier90/ABACUS_towannier90_lcao_in_pw/INPUT-scf @@ -7,5 +7,5 @@ ecutwfc 50 calculation scf scf_thr 1e-13 out_chg 1 -basis_type lcao_in_pw +basis_type lcao_in_pw ks_solver lapack diff --git a/examples/interface_wannier90/ABACUS_towannier90_lcao_in_pw/diamond.win b/examples/interface_wannier90/ABACUS_towannier90_lcao_in_pw/diamond.win index 675745083d..816324eb2e 100644 --- a/examples/interface_wannier90/ABACUS_towannier90_lcao_in_pw/diamond.win +++ b/examples/interface_wannier90/ABACUS_towannier90_lcao_in_pw/diamond.win @@ -1,15 +1,17 @@ num_bands = 12 num_wann = 8 -Begin Projections -C:s;px;py;pz -End Projections +dis_win_min = -7.2 +dis_win_max = 45 +dis_froz_min = -7.2 
+dis_froz_max = 20 +dis_num_iter = 200 -dis_num_iter = 5000 -num_iter = 5000 -num_print_cycles = 50 +begin projections +C:s;px;py;pz +end projections -wannier_plot=.true. +wannier_plot=.false. wannier_plot_supercell = 3 wvfn_formatted = .true. @@ -26,6 +28,14 @@ begin unit_cell_cart -1.613990 1.613990 0.000000 end unit_cell_cart +bands_plot = true + +begin kpoint_path +G 0.0000000000 0.0000000000 0.0000000000 L 0.5000000000 0.5000000000 0.5000000000 +L 0.5000000000 0.5000000000 0.5000000000 W 0.5000000000 0.2500000000 0.7500000000 +W 0.5000000000 0.2500000000 0.7500000000 X 0.5000000000 0.0000000000 0.5000000000 +end kpoint_path + mp_grid : 4 4 4 begin kpoints diff --git a/examples/interface_wannier90/ABACUS_towannier90_pw/INPUT-nscf b/examples/interface_wannier90/ABACUS_towannier90_pw/INPUT-nscf index 9010e1527f..d58eb06007 100644 --- a/examples/interface_wannier90/ABACUS_towannier90_pw/INPUT-nscf +++ b/examples/interface_wannier90/ABACUS_towannier90_pw/INPUT-nscf @@ -5,11 +5,14 @@ orbital_dir ../../../tests/PP_ORB ntype 1 ecutwfc 50 nbands 4 +smearing_method fixed calculation nscf scf_nmax 50 pw_diag_thr 1.0e-12 -scf_thr 1.0e-15 +scf_thr 1.0e-13 init_chg file symmetry 0 towannier90 1 nnkpfile diamond.nnkp +basis_type pw +out_wannier_unk 0 \ No newline at end of file diff --git a/examples/interface_wannier90/ABACUS_towannier90_pw/diamond.win b/examples/interface_wannier90/ABACUS_towannier90_pw/diamond.win index 321f3058e0..69225b0213 100644 --- a/examples/interface_wannier90/ABACUS_towannier90_pw/diamond.win +++ b/examples/interface_wannier90/ABACUS_towannier90_pw/diamond.win @@ -1,7 +1,7 @@ num_wann = 4 num_iter = 20 -wannier_plot=.true. +wannier_plot=.false. wannier_plot_supercell = 3 wvfn_formatted = .true. 
diff --git a/source/Makefile.Objects b/source/Makefile.Objects index 6209241fa5..3415ccb7b4 100644 --- a/source/Makefile.Objects +++ b/source/Makefile.Objects @@ -226,7 +226,6 @@ OBJS_ESOLVER_LCAO=esolver_ks_lcao.o\ io_npz.o\ OBJS_GINT=gint.o\ - gint_gamma.o\ gint_gamma_env.o\ gint_gamma_vl.o\ gint_fvl.o\ @@ -298,7 +297,7 @@ OBJS_HSOLVER=diago_cg.o\ dngvd_op.o\ OBJS_HSOLVER_LCAO=hsolver_lcao.o\ - diago_blas.o\ + diago_scalapack.o\ diago_elpa.o\ elpa_new.o\ elpa_new_real.o\ diff --git a/source/module_base/global_variable.cpp b/source/module_base/global_variable.cpp index 9c27000423..250739b6de 100644 --- a/source/module_base/global_variable.cpp +++ b/source/module_base/global_variable.cpp @@ -207,6 +207,8 @@ bool deepks_scf = false; // caoyu add 2021-10-16 for DeePKS, wenfei 2022-1-16 bool deepks_bandgap = false; // for bandgap label. QO added 2021-12-15 bool deepks_out_unittest = false; +bool deepks_equiv = false; + bool deepks_setorb = false; bool out_element_info = false; // added by zhengdy 2021-11-26 diff --git a/source/module_base/global_variable.h b/source/module_base/global_variable.h index b9168664f1..06f2cf9771 100644 --- a/source/module_base/global_variable.h +++ b/source/module_base/global_variable.h @@ -231,6 +231,8 @@ extern bool deepks_scf; //(need libnpy and libtorch) if set 1, a trained model would be needed to cal V_delta and F_delta extern bool deepks_bandgap; // for bandgap label. 
QO added 2021-12-15 +extern bool deepks_equiv; //whether to use equivariant version of DeePKS + extern bool deepks_setorb; extern bool deepks_out_unittest; // if set 1, prints intermediate quantities that shall be used for making unit test diff --git a/source/module_base/test/timer_test.cpp b/source/module_base/test/timer_test.cpp index f90454ea16..59c206b110 100644 --- a/source/module_base/test/timer_test.cpp +++ b/source/module_base/test/timer_test.cpp @@ -118,10 +118,10 @@ TEST_F(TimerTest, PrintAll) EXPECT_THAT(output,testing::HasSubstr("TIME STATISTICS")); EXPECT_THAT(output,testing::HasSubstr("CLASS_NAME")); EXPECT_THAT(output,testing::HasSubstr("NAME")); - EXPECT_THAT(output,testing::HasSubstr("TIME(Sec)")); + EXPECT_THAT(output,testing::HasSubstr("TIME/s")); EXPECT_THAT(output,testing::HasSubstr("CALLS")); - EXPECT_THAT(output,testing::HasSubstr("AVG(Sec)")); - EXPECT_THAT(output,testing::HasSubstr("PER(%)")); + EXPECT_THAT(output,testing::HasSubstr("AVG/s")); + EXPECT_THAT(output,testing::HasSubstr("PER/%")); // check output in file ifs.open("tmp"); std::cout << "Capture contents line by line from output file: \n" << std::endl; @@ -131,10 +131,10 @@ TEST_F(TimerTest, PrintAll) getline(ifs,output); EXPECT_THAT(output,testing::HasSubstr("CLASS_NAME")); EXPECT_THAT(output,testing::HasSubstr("NAME")); - EXPECT_THAT(output,testing::HasSubstr("TIME(Sec)")); + EXPECT_THAT(output,testing::HasSubstr("TIME/s")); EXPECT_THAT(output,testing::HasSubstr("CALLS")); - EXPECT_THAT(output,testing::HasSubstr("AVG(Sec)")); - EXPECT_THAT(output,testing::HasSubstr("PER(%)")); + EXPECT_THAT(output,testing::HasSubstr("AVG/s")); + EXPECT_THAT(output,testing::HasSubstr("PER/%")); ifs.close(); remove("time.json"); } @@ -162,10 +162,10 @@ TEST_F(TimerTest, Finish) EXPECT_THAT(output,testing::HasSubstr("TIME STATISTICS")); EXPECT_THAT(output,testing::HasSubstr("CLASS_NAME")); EXPECT_THAT(output,testing::HasSubstr("NAME")); - EXPECT_THAT(output,testing::HasSubstr("TIME(Sec)")); + 
EXPECT_THAT(output,testing::HasSubstr("TIME/s")); EXPECT_THAT(output,testing::HasSubstr("CALLS")); - EXPECT_THAT(output,testing::HasSubstr("AVG(Sec)")); - EXPECT_THAT(output,testing::HasSubstr("PER(%)")); + EXPECT_THAT(output,testing::HasSubstr("AVG/s")); + EXPECT_THAT(output,testing::HasSubstr("PER/%")); // check output in file ifs.open("tmp"); std::cout << "Capture contents line by line from output file: \n" << std::endl; @@ -175,10 +175,10 @@ TEST_F(TimerTest, Finish) getline(ifs,output); EXPECT_THAT(output,testing::HasSubstr("CLASS_NAME")); EXPECT_THAT(output,testing::HasSubstr("NAME")); - EXPECT_THAT(output,testing::HasSubstr("TIME(Sec)")); + EXPECT_THAT(output,testing::HasSubstr("TIME/s")); EXPECT_THAT(output,testing::HasSubstr("CALLS")); - EXPECT_THAT(output,testing::HasSubstr("AVG(Sec)")); - EXPECT_THAT(output,testing::HasSubstr("PER(%)")); + EXPECT_THAT(output,testing::HasSubstr("AVG/s")); + EXPECT_THAT(output,testing::HasSubstr("PER/%")); ifs.close(); } diff --git a/source/module_base/timer.cpp b/source/module_base/timer.cpp index cede14041f..c47d58f2a9 100644 --- a/source/module_base/timer.cpp +++ b/source/module_base/timer.cpp @@ -275,7 +275,7 @@ void timer::print_all(std::ofstream &ofs) assert(class_names.size() == calls.size()); assert(class_names.size() == avgs.size()); assert(class_names.size() == pers.size()); - std::vector titles = {"CLASS_NAME", "NAME", "TIME(Sec)", "CALLS", "AVG(Sec)", "PER(%)"}; + std::vector titles = {"CLASS_NAME", "NAME", "TIME/s", "CALLS", "AVG/s", "PER/%"}; std::vector formats = {"%-10s", "%-10s", "%6.2f", "%8d", "%6.2f", "%6.2f"}; FmtTable time_statistics(titles, pers.size(), formats, {FmtTable::Align::LEFT, FmtTable::Align::CENTER}); time_statistics << class_names << names << times << calls << avgs << pers; diff --git a/source/module_elecstate/elecstate_print.cpp b/source/module_elecstate/elecstate_print.cpp index b57c5f2712..4ce3b4c246 100644 --- a/source/module_elecstate/elecstate_print.cpp +++ 
b/source/module_elecstate/elecstate_print.cpp @@ -10,6 +10,78 @@ #include "module_base/formatter.h" namespace elecstate { + /** + * Notes on refactor of ESolver's functions + * + * the print of SCF iteration on-the-fly information. + * 1. Previously, for nspin 1, 2, and 4, and depending on whether xc_type is 3/5 or not, the information was organized in different ways. + * This brings inconsistencies between print patterns and makes it hard to vectorize the information. + * 2. the function print_etot actually does two kinds of things: 1) print information into running_*.log, 2) print information onto + * the screen. These two tasks should in no way be placed/implemented directly in one function. + * 3. there is information redundancy: the istep of SCF can determine whether to print out the SCF iteration info + * table header or not, rather than dividing into two functions and hard-coding the format. + * + * For nspin 1, print: ITER, ETOT, EDIFF, DRHO, TIME + * nspin 2, print: ITER, TMAG, AMAG, ETOT, EDIFF, DRHO, TIME + * nspin 4 with nlcc, print: ITER, TMAGX, TMAGY, TMAGZ, AMAG, ETOT, EDIFF, DRHO, TIME + * xc type_id 3/5: DKIN + * + * Based on summary above, there are several groups of info: + * 1. counting: ITER + * 2. (optional) magnetization: TMAG or TMAGX-TMAGY-TMAGZ, AMAG + * 3. energies: ETOT, EDIFF + * 4. densities: DRHO, DKIN(optional) + * 5. 
time: TIME + */ + void print_scf_iterinfo(const std::string& ks_solver, const int& istep, const int& witer, + const std::vector& mag, const int& wmag, + const double& etot, const double& ediff, const int& wener, + const std::vector& drho, const int& wrho, + const double& time, const int& wtime) + { + std::map iter_header_dict = { + {"cg", "CG"}, {"cg_in_lcao", "CG"}, {"lapack", "LA"}, {"genelpa", "GE"}, + {"dav", "DA"}, {"dav_subspace", "DS"}, {"scalapack_gvx", "GV"}, {"cusolver", "CU"}, + {"bpcg", "BP"}, {"pexsi", "PE"} + }; // I change the key of "cg_in_lcao" to "CG" because all the other are only two letters + // ITER column + std::vector th_fmt = {" %-" + std::to_string(witer) + "s"}; // table header: th: ITER + std::vector td_fmt = {" " + iter_header_dict[ks_solver] + "%-" + std::to_string(witer - 2) + ".0f"}; // table data: td: GE10086 + // magnetization column, might be non-exist, but size of mag can only be 0, 2 or 4 + for(int i = 0; i < mag.size(); i++) {th_fmt.emplace_back("%" + std::to_string(wmag) + "s");} + for(int i = 0; i < mag.size(); i++) {td_fmt.emplace_back("%" + std::to_string(wmag) + ".4e");} // hard-code precision here + // energies + for(int i = 0; i < 2; i++) {th_fmt.emplace_back("%" + std::to_string(wener) + "s");} + for(int i = 0; i < 2; i++) {td_fmt.emplace_back("%" + std::to_string(wener) + ".8e");} + // densities column, size can be 1 or 2, DRHO or DRHO, DKIN + for(int i = 0; i < drho.size(); i++) {th_fmt.emplace_back("%" + std::to_string(wrho) + "s");} + for(int i = 0; i < drho.size(); i++) {td_fmt.emplace_back("%" + std::to_string(wrho) + ".8e");} + // time column, trivial + th_fmt.emplace_back("%" + std::to_string(wtime) + "s\n"); + td_fmt.emplace_back("%" + std::to_string(wtime) + ".2f\n"); + // contents + std::vector titles; std::vector values; + switch (mag.size()) + { + case 2: + titles = {"ITER", FmtCore::center("TMAG", wmag), FmtCore::center("AMAG", wmag), + FmtCore::center("ETOT/eV", wener), FmtCore::center("EDIFF/eV", wener), 
FmtCore::center("DRHO", wrho)}; + values = {double(istep), mag[0], mag[1], etot, ediff, drho[0]}; break; + case 4: + titles = {"ITER", FmtCore::center("TMAGX", wmag), FmtCore::center("TMAGY", wmag), FmtCore::center("TMAGZ", wmag), + FmtCore::center("AMAG", wmag), FmtCore::center("ETOT/eV", wener), FmtCore::center("EDIFF/eV", wener), FmtCore::center("DRHO", wrho)}; + values = {double(istep), mag[0], mag[1], mag[2], mag[3], etot, ediff, drho[0]}; break; + default: + titles = {"ITER", FmtCore::center("ETOT/eV", wener), FmtCore::center("EDIFF/eV", wener), FmtCore::center("DRHO", wrho)}; + values = {double(istep), etot, ediff, drho[0]}; break; + } + if(drho.size() > 1) {titles.push_back(FmtCore::center("DKIN", wrho)); values.push_back(drho[1]);} + titles.push_back(FmtCore::center("TIME/s", wtime)); values.push_back(time); + std::string buf; + if(istep == 1) { for(int i = 0; i < titles.size(); i++) {buf += FmtCore::format(th_fmt[i].c_str(), titles[i]);} } + for(int i = 0; i < values.size(); i++) {buf += FmtCore::format(td_fmt[i].c_str(), values[i]);} + std::cout << buf; + } /// @brief print and check for band energy and occupations /// @param ofs void ElecState::print_eigenvalue(std::ofstream& ofs) @@ -259,144 +331,21 @@ void ElecState::print_etot(const bool converged, this->f_en.etot_old = this->f_en.etot; } this->f_en.etot_delta = this->f_en.etot - this->f_en.etot_old; - - // mohan update 2011-02-26 - std::stringstream ss; - - // xiaohui add 2013-09-02, Peize Lin update 2020.11.14 - std::string label; - std::string ks_solver_type = get_ks_solver_type(); - if (ks_solver_type == "cg") - { - label = "CG"; - } - else if (ks_solver_type == "cg_in_lcao") - { - label = "CGAO"; - } - else if (ks_solver_type == "lapack") - { - label = "LA"; - } - else if (ks_solver_type == "genelpa") - { - label = "GE"; - } - else if (ks_solver_type == "dav") - { - label = "DA"; - } - else if (ks_solver_type == "dav_subspace") - { - label = "DS"; - } - else if (ks_solver_type == 
"scalapack_gvx") - { - label = "GV"; - } - else if (ks_solver_type == "cusolver") - { - label = "CU"; - } - else if (ks_solver_type == "bpcg") - { - label = "BP"; - } - else if (ks_solver_type == "pexsi") - { - label = "PE"; - } - else - { - ModuleBase::WARNING_QUIT("Energy", "print_etot found unknown ks_solver_type"); - } - ss << label << iter; - // xiaohui add 2013-09-02 - - bool scientific = true; - int prec = 6; - - if (!print) - return; - if (GlobalV::OUT_LEVEL == "ie" || GlobalV::OUT_LEVEL == "m") // xiaohui add 'm' option, 2015-09-16 { - std::cout << " " << std::setw(7) << ss.str(); - // std::cout << std::setiosflags(ios::fixed); - // std::cout << std::setiosflags(ios::showpos); - if (scientific) - { - std::cout << std::scientific; - } - - if (GlobalV::COLOUR) - { - if (GlobalV::MY_RANK == 0) - { - printf("\e[36m%-15f\e[0m", this->f_en.etot); - if (GlobalV::NSPIN == 2) - { - std::cout << std::setprecision(2); - std::cout << std::setw(10) << get_ucell_tot_magnetization(); - std::cout << std::setw(10) << get_ucell_abs_magnetization(); - } - else if (GlobalV::NSPIN == 4 && GlobalV::NONCOLIN) - { - std::cout << std::setprecision(2); - std::cout << std::setw(10) << get_ucell_tot_magnetization_nc_x() << std::setw(10) - << get_ucell_tot_magnetization_nc_y() << std::setw(10) - << get_ucell_tot_magnetization_nc_z(); - std::cout << std::setw(10) << get_ucell_abs_magnetization(); - } - if (scf_thr > 1.0) - { - // 31 is red - printf("\e[31m%-14e\e[0m", scf_thr); - // printf( "[31m%-14e[0m", scf_thr); - } - else - { - // 32 is green - printf("\e[32m%-14e\e[0m", scf_thr); - // printf( "[32m%-14e[0m", scf_thr); - } - // 34 is blue - printf("\e[36m%-15f\e[0m", this->f_en.etot * ModuleBase::Ry_to_eV); - std::cout << std::setprecision(3); - std::cout << std::resetiosflags(std::ios::scientific); - - std::cout << std::setw(11) << duration; - std::cout << std::endl; - } - } - else + std::vector mag; + switch (GlobalV::NSPIN) { - std::cout << std::setprecision(prec); - if 
(GlobalV::NSPIN == 2) - { - std::cout << std::setprecision(2); - std::cout << std::setw(10) << get_ucell_tot_magnetization(); - std::cout << std::setw(10) << get_ucell_abs_magnetization(); - } - std::cout << std::setprecision(6); - std::cout << std::setw(15) << this->f_en.etot * ModuleBase::Ry_to_eV; - std::cout << std::setw(15) << this->f_en.etot_delta * ModuleBase::Ry_to_eV; - std::cout << std::setprecision(3); - std::cout << std::setw(11) << scf_thr; - if (elecstate::get_xc_func_type() == 3 || elecstate::get_xc_func_type() == 5) - { - std::cout << std::setprecision(3); - std::cout << std::setw(11) << scf_thr_kin; - } - std::cout << std::setprecision(3); - std::cout << std::setw(11) << duration; - std::cout << std::endl; + case 2: mag = {get_ucell_tot_magnetization(), get_ucell_abs_magnetization()}; break; + case 4: mag = {get_ucell_tot_magnetization_nc_x(), get_ucell_tot_magnetization_nc_y(), + get_ucell_tot_magnetization_nc_z(), get_ucell_abs_magnetization()}; break; + default: mag = {}; break; } + std::vector drho = {scf_thr}; + if(elecstate::get_xc_func_type() == 3 || elecstate::get_xc_func_type() == 5) {drho.push_back(scf_thr_kin);} + elecstate::print_scf_iterinfo(get_ks_solver_type(), iter, 6, mag, 12, this->f_en.etot * ModuleBase::Ry_to_eV, + this->f_en.etot_delta * ModuleBase::Ry_to_eV, 16, drho, 16, duration, 6); } - else - { - } - this->f_en.etot_old = this->f_en.etot; return; } @@ -414,5 +363,4 @@ void ElecState::print_format(const std::string& name, const double& value) GlobalV::ofs_running << std::resetiosflags(std::ios::showpos); return; } - } // namespace elecstate \ No newline at end of file diff --git a/source/module_elecstate/potentials/H_TDDFT_pw.cpp b/source/module_elecstate/potentials/H_TDDFT_pw.cpp index 93601bac40..6f37cd87d3 100644 --- a/source/module_elecstate/potentials/H_TDDFT_pw.cpp +++ b/source/module_elecstate/potentials/H_TDDFT_pw.cpp @@ -38,7 +38,7 @@ double H_TDDFT_pw::lcut1; double H_TDDFT_pw::lcut2; //velocity gauge -double 
H_TDDFT_pw::At[3]={0.0,0.0,0.0}; +ModuleBase::Vector3 H_TDDFT_pw::At; // time domain parameters diff --git a/source/module_elecstate/potentials/H_TDDFT_pw.h b/source/module_elecstate/potentials/H_TDDFT_pw.h index ea6b4f0ef9..8b3ffcaf42 100644 --- a/source/module_elecstate/potentials/H_TDDFT_pw.h +++ b/source/module_elecstate/potentials/H_TDDFT_pw.h @@ -54,7 +54,7 @@ class H_TDDFT_pw : public PotBase static double lcut2; //velocity gauge, vector magnetic potential - static double At[3]; + static ModuleBase::Vector3 At; // time domain parameters diff --git a/source/module_elecstate/test/elecstate_print_test.cpp b/source/module_elecstate/test/elecstate_print_test.cpp index 69f93df393..555103b473 100644 --- a/source/module_elecstate/test/elecstate_print_test.cpp +++ b/source/module_elecstate/test/elecstate_print_test.cpp @@ -392,36 +392,36 @@ TEST_F(ElecStatePrintTest, PrintEtotColorS4) elecstate.print_etot(converged, iter, scf_thr, scf_thr_kin, duration, printe, pw_diag_thr, avg_iter, print); } -TEST_F(ElecStatePrintTest, PrintEtotWarning) -{ - GlobalV::ofs_running.open("test.dat", std::ios::out); - bool converged = false; - int iter = 1; - double scf_thr = 0.1; - double scf_thr_kin = 0.0; - double duration = 2.0; - int printe = 0; - double pw_diag_thr = 0.1; - int avg_iter = 2; - bool print = true; - elecstate.charge = new Charge; - elecstate.charge->nrxx = 100; - elecstate.charge->nxyz = 1000; - GlobalV::imp_sol = true; - GlobalV::EFIELD_FLAG = true; - GlobalV::GATE_FLAG = true; - GlobalV::TWO_EFERMI = false; - GlobalV::out_bandgap = true; - GlobalV::COLOUR = false; - GlobalV::MY_RANK = 0; - GlobalV::BASIS_TYPE = "pw"; - GlobalV::SCF_NMAX = 100; - elecstate::tmp_ks_solver = "unknown"; - testing::internal::CaptureStdout(); - EXPECT_EXIT(elecstate.print_etot(converged, iter, scf_thr, scf_thr_kin, duration, printe, pw_diag_thr, avg_iter, print), ::testing::ExitedWithCode(0), ""); - output = testing::internal::GetCapturedStdout(); - EXPECT_THAT(output, 
testing::HasSubstr("print_etot found unknown ks_solver_type")); - GlobalV::ofs_running.close(); - delete elecstate.charge; - std::remove("test.dat"); -} +// TEST_F(ElecStatePrintTest, PrintEtotWarning) +// { +// GlobalV::ofs_running.open("test.dat", std::ios::out); +// bool converged = false; +// int iter = 1; +// double scf_thr = 0.1; +// double scf_thr_kin = 0.0; +// double duration = 2.0; +// int printe = 0; +// double pw_diag_thr = 0.1; +// int avg_iter = 2; +// bool print = true; +// elecstate.charge = new Charge; +// elecstate.charge->nrxx = 100; +// elecstate.charge->nxyz = 1000; +// GlobalV::imp_sol = true; +// GlobalV::EFIELD_FLAG = true; +// GlobalV::GATE_FLAG = true; +// GlobalV::TWO_EFERMI = false; +// GlobalV::out_bandgap = true; +// GlobalV::COLOUR = false; +// GlobalV::MY_RANK = 0; +// GlobalV::BASIS_TYPE = "pw"; +// GlobalV::SCF_NMAX = 100; +// elecstate::tmp_ks_solver = "unknown"; +// testing::internal::CaptureStdout(); +// EXPECT_EXIT(elecstate.print_etot(converged, iter, scf_thr, scf_thr_kin, duration, printe, pw_diag_thr, avg_iter, print), ::testing::ExitedWithCode(0), ""); +// output = testing::internal::GetCapturedStdout(); +// EXPECT_THAT(output, testing::HasSubstr("print_etot found unknown ks_solver_type")); +// GlobalV::ofs_running.close(); +// delete elecstate.charge; +// std::remove("test.dat"); +// } diff --git a/source/module_esolver/esolver_ks.cpp b/source/module_esolver/esolver_ks.cpp index cb608a142f..d0f63bf67f 100644 --- a/source/module_esolver/esolver_ks.cpp +++ b/source/module_esolver/esolver_ks.cpp @@ -455,16 +455,17 @@ void ESolver_KS::runner(const int istep, UnitCell& ucell) ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "INIT SCF"); // 3) print head - if(this->maxniter > 0) - { - this->print_head(); //print the headline on the screen. - } + // if(this->maxniter > 0) + // { + // this->print_head(); //print the headline on the screen. 
+ // } bool firstscf = true; this->conv_elec = false; this->niter = this->maxniter; // 4) SCF iterations + std::cout << " * * * * * *\n << Start SCF iteration." << std::endl; for (int iter = 1; iter <= this->maxniter; ++iter) { // 5) write head @@ -622,7 +623,7 @@ void ESolver_KS::runner(const int istep, UnitCell& ucell) std::cout<<" SCF restart after this step!"<> Leave SCF iteration.\n * * * * * *" << std::endl; #ifdef __RAPIDJSON // 14) add Json of efermi energy converge diff --git a/source/module_esolver/esolver_ks.h b/source/module_esolver/esolver_ks.h index bfd3fd2e63..b63fea304c 100644 --- a/source/module_esolver/esolver_ks.h +++ b/source/module_esolver/esolver_ks.h @@ -16,7 +16,6 @@ #include "module_io/cal_test.h" #include "module_io/output_potential.h" #include "module_io/output_rho.h" - namespace ModuleESolver { @@ -101,7 +100,6 @@ class ESolver_KS : public ESolver_FP const double duration, const double ethr); - // Write the headline in the running_log file // "PW/LCAO" ALGORITHM --------------- ION= 1 ELEC= 1-------------------------------- void write_head( @@ -140,7 +138,6 @@ class ESolver_KS : public ESolver_FP std::string basisname; //PW or LCAO void print_wfcfft(Input& inp, std::ofstream &ofs); -}; - +}; } // end of namespace #endif diff --git a/source/module_hamilt_lcao/hamilt_lcaodft/FORCE_STRESS.cpp b/source/module_hamilt_lcao/hamilt_lcaodft/FORCE_STRESS.cpp index 4e98f79458..022b87e269 100644 --- a/source/module_hamilt_lcao/hamilt_lcaodft/FORCE_STRESS.cpp +++ b/source/module_hamilt_lcao/hamilt_lcaodft/FORCE_STRESS.cpp @@ -396,28 +396,29 @@ void Force_Stress_LCAO::getForceStress(const bool isforce, { GlobalC::ld.save_npy_f(fcs - GlobalC::ld.F_delta, "f_base.npy", GlobalC::ucell.nat); // Ry/Bohr, F_base - if (GlobalV::GAMMA_ONLY_LOCAL) + if(!GlobalV::deepks_equiv) //training with force label not supported by equivariant version now { - const std::vector>& dm_gamma - = dynamic_cast*>(pelec)->get_DM()->get_DMK_vector(); - 
GlobalC::ld.cal_gdmx(dm_gamma[0], GlobalC::ucell, GlobalC::ORB, GlobalC::GridD, isstress); - } - else - { - const std::vector>>& dm_k - = dynamic_cast>*>(pelec) - ->get_DM() - ->get_DMK_vector(); - GlobalC::ld - .cal_gdmx_k(dm_k, GlobalC::ucell, GlobalC::ORB, GlobalC::GridD, kv.nks, kv.kvec_d, isstress); - } - if (GlobalV::deepks_out_unittest) - GlobalC::ld.check_gdmx(GlobalC::ucell.nat); - GlobalC::ld.cal_gvx(GlobalC::ucell.nat); + if (GlobalV::GAMMA_ONLY_LOCAL) + { + const std::vector>& dm_gamma + = dynamic_cast*>(pelec)->get_DM()->get_DMK_vector(); + GlobalC::ld.cal_gdmx(dm_gamma[0], GlobalC::ucell, GlobalC::ORB, GlobalC::GridD, isstress); + } + else + { + const std::vector>>& dm_k + = dynamic_cast>*>(pelec) + ->get_DM() + ->get_DMK_vector(); + GlobalC::ld + .cal_gdmx_k(dm_k, GlobalC::ucell, GlobalC::ORB, GlobalC::GridD, kv.nks, kv.kvec_d, isstress); + } + if (GlobalV::deepks_out_unittest) { GlobalC::ld.check_gdmx(GlobalC::ucell.nat); } + GlobalC::ld.cal_gvx(GlobalC::ucell.nat); - if (GlobalV::deepks_out_unittest) - GlobalC::ld.check_gvx(GlobalC::ucell.nat); - GlobalC::ld.save_npy_gvx(GlobalC::ucell.nat); // /Bohr, grad_vx + if (GlobalV::deepks_out_unittest) { GlobalC::ld.check_gvx(GlobalC::ucell.nat); } + GlobalC::ld.save_npy_gvx(GlobalC::ucell.nat); // /Bohr, grad_vx + } } else { @@ -606,8 +607,11 @@ void Force_Stress_LCAO::getForceStress(const bool isforce, GlobalC::ld.save_npy_s(scs, "s_tot.npy", GlobalC::ucell.omega); // change to energy unit Ry when printing, S_tot, w/ model - GlobalC::ld.cal_gvepsl(GlobalC::ucell.nat); - GlobalC::ld.save_npy_gvepsl(GlobalC::ucell.nat); // unitless, grad_vepsl + if(!GlobalV::deepks_equiv) //training with stress label not supported by equivariant version now + { + GlobalC::ld.cal_gvepsl(GlobalC::ucell.nat); + GlobalC::ld.save_npy_gvepsl(GlobalC::ucell.nat); // unitless, grad_vepsl + } } else { diff --git a/source/module_hamilt_lcao/hamilt_lcaodft/operator_lcao/deepks_lcao.cpp 
b/source/module_hamilt_lcao/hamilt_lcaodft/operator_lcao/deepks_lcao.cpp index 577475f6d8..f75e0e9d85 100644 --- a/source/module_hamilt_lcao/hamilt_lcaodft/operator_lcao/deepks_lcao.cpp +++ b/source/module_hamilt_lcao/hamilt_lcaodft/operator_lcao/deepks_lcao.cpp @@ -337,7 +337,7 @@ void hamilt::DeePKS>::calculate_HR() std::vector trace_alpha_row; std::vector trace_alpha_col; std::vector gedms; - if(!GlobalC::ld.get_if_equiv()) + if(!GlobalV::deepks_equiv) { int ib=0; for (int L0 = 0; L0 <= orb.Alpha[0].getLmax();++L0) diff --git a/source/module_hamilt_lcao/hamilt_lcaodft/operator_lcao/td_ekinetic_lcao.cpp b/source/module_hamilt_lcao/hamilt_lcaodft/operator_lcao/td_ekinetic_lcao.cpp index 4102aae997..03bfdfdda5 100644 --- a/source/module_hamilt_lcao/hamilt_lcaodft/operator_lcao/td_ekinetic_lcao.cpp +++ b/source/module_hamilt_lcao/hamilt_lcaodft/operator_lcao/td_ekinetic_lcao.cpp @@ -60,7 +60,7 @@ void TDEkinetic, double>>::td_ekinetic_scalar( template void TDEkinetic>::td_ekinetic_grad(std::complex* Hloc, int nnr, ModuleBase::Vector3 grad_overlap){ std::complex tmp= {0, grad_overlap*cart_At}; - Hloc[nnr] += tmp; + Hloc[nnr] -= tmp; return; } @@ -193,10 +193,9 @@ void TDEkinetic>::init_td(void) { TD_Velocity::td_vel_op = &td_velocity; //calculate At in cartesian coorinates. - double l_norm[3]={this->ucell->a1.norm() ,this->ucell->a2.norm() ,this->ucell->a3.norm()}; - double (&A)[3] = elecstate::H_TDDFT_pw::At; - cart_At = this->ucell->a1*A[0]/l_norm[0] + this->ucell->a2*A[1]/l_norm[1] + this->ucell->a3*A[2]/l_norm[2]; - std::cout << "cart_At: " << cart_At[0] << " " <ucell->a1, this->ucell->a2, this->ucell->a3, elecstate::H_TDDFT_pw::At); + this->cart_At = td_velocity.cart_At; + std::cout<<"cart_At: "< void hamilt::TDNonlocal>::init_td(void) { //calculate At in cartesian coorinates. 
- double l_norm[3]={this->ucell->a1.norm() ,this->ucell->a2.norm() ,this->ucell->a3.norm()}; - double (&A)[3] = elecstate::H_TDDFT_pw::At; - cart_At = -(this->ucell->a1*A[0]/l_norm[0] + this->ucell->a2*A[1]/l_norm[1] + this->ucell->a3*A[2]/l_norm[2]); - std::cout << "cart_At: " << cart_At[0] << " " <cart_At=TD_Velocity::td_vel_op->cart_At; } // initialize_HR() template @@ -207,7 +204,7 @@ void hamilt::TDNonlocal>::calculate_HR() atom1->iw2n[iw1], tau0 * this->ucell->lat0, T0, - -cart_At, + -cart_At/2.0, 0); #else uot.snap_psibeta_half_tddft(orb, @@ -220,7 +217,7 @@ void hamilt::TDNonlocal>::calculate_HR() atom1->iw2n[iw1], tau0 * this->ucell->lat0, T0, - -cart_At, + -cart_At/2.0, 0); #endif nlm_tot[ad].insert({all_indexes[iw1l], nlm[0]}); diff --git a/source/module_hamilt_lcao/hamilt_lcaodft/spar_hsr.cpp b/source/module_hamilt_lcao/hamilt_lcaodft/spar_hsr.cpp index 8829bef774..162f7857e5 100644 --- a/source/module_hamilt_lcao/hamilt_lcaodft/spar_hsr.cpp +++ b/source/module_hamilt_lcao/hamilt_lcaodft/spar_hsr.cpp @@ -376,4 +376,4 @@ void sparse_format::clear_zero_elements( } return; -} +} \ No newline at end of file diff --git a/source/module_hamilt_lcao/hamilt_lcaodft/spar_hsr.h b/source/module_hamilt_lcao/hamilt_lcaodft/spar_hsr.h index d7b36f397f..d46a0304d3 100644 --- a/source/module_hamilt_lcao/hamilt_lcaodft/spar_hsr.h +++ b/source/module_hamilt_lcao/hamilt_lcaodft/spar_hsr.h @@ -46,4 +46,4 @@ namespace sparse_format } -#endif +#endif \ No newline at end of file diff --git a/source/module_hamilt_lcao/module_deepks/LCAO_deepks.cpp b/source/module_hamilt_lcao/module_deepks/LCAO_deepks.cpp index a8b3d8cfcd..1b32601a27 100644 --- a/source/module_hamilt_lcao/module_deepks/LCAO_deepks.cpp +++ b/source/module_hamilt_lcao/module_deepks/LCAO_deepks.cpp @@ -89,7 +89,7 @@ void LCAO_Deepks::init( int tot_inl = tot_inl_per_atom * nat; - if(if_equiv) tot_inl = nat; + if(GlobalV::deepks_equiv) tot_inl = nat; this->lmaxd = lm; this->nmaxd = nm; @@ -99,7 +99,7 @@ void 
LCAO_Deepks::init( int pdm_size = 0; this->inlmax = tot_inl; - if(!if_equiv) + if(!GlobalV::deepks_equiv) { GlobalV::ofs_running << " total basis (all atoms) for descriptor= " << std::endl; @@ -125,7 +125,7 @@ void LCAO_Deepks::init( } // cal n(descriptor) per atom , related to Lmax, nchi(L) and m. (not total_nchi!) - if(!if_equiv) + if(!GlobalV::deepks_equiv) { this->des_per_atom=0; // mohan add 2021-04-21 for (int l = 0; l <= this->lmaxd; l++) @@ -215,7 +215,7 @@ void LCAO_Deepks::init_gdmx(const int nat) this->gdmy = new double** [nat]; this->gdmz = new double** [nat]; int pdm_size = 0; - if(!if_equiv) + if(!GlobalV::deepks_equiv) { pdm_size = (this->lmaxd * 2 + 1) * (this->lmaxd * 2 + 1); } @@ -269,7 +269,7 @@ void LCAO_Deepks::init_gdmepsl() this->gdm_epsl = new double** [6]; int pdm_size = 0; - if(!if_equiv) + if(!GlobalV::deepks_equiv) { pdm_size = (this->lmaxd * 2 + 1) * (this->lmaxd * 2 + 1); } @@ -328,7 +328,7 @@ void LCAO_Deepks::allocate_V_delta(const int nat, const int nks) //init gedm** int pdm_size = 0; - if(!if_equiv) + if(!GlobalV::deepks_equiv) { pdm_size = (this->lmaxd * 2 + 1) * (this->lmaxd * 2 + 1); } diff --git a/source/module_hamilt_lcao/module_deepks/LCAO_deepks.h b/source/module_hamilt_lcao/module_deepks/LCAO_deepks.h index 1b9c4a668e..481a8612d7 100644 --- a/source/module_hamilt_lcao/module_deepks/LCAO_deepks.h +++ b/source/module_hamilt_lcao/module_deepks/LCAO_deepks.h @@ -82,7 +82,6 @@ class LCAO_Deepks int get_inl(const int& T0, const int& I0, const int& L0, const int& N0) { return inl_index[T0](I0, L0, N0); } const double* get_gedms(const int& inl){ return gedm[inl]; } - bool get_if_equiv(){return if_equiv;} int get_lmaxd(){return lmaxd;} //------------------- // private variables @@ -96,8 +95,6 @@ class LCAO_Deepks int nks_V_delta = 0; bool init_pdm = false; //for DeePKS NSCF calculation - - bool if_equiv = false; //equivariant version // deep neural network module that provides corrected Hamiltonian term and // related derivatives. 
diff --git a/source/module_hamilt_lcao/module_deepks/LCAO_deepks_fdelta.cpp b/source/module_hamilt_lcao/module_deepks/LCAO_deepks_fdelta.cpp index 3bdafecf04..5adf7f5195 100644 --- a/source/module_hamilt_lcao/module_deepks/LCAO_deepks_fdelta.cpp +++ b/source/module_hamilt_lcao/module_deepks/LCAO_deepks_fdelta.cpp @@ -133,7 +133,7 @@ void LCAO_Deepks::cal_f_delta_gamma(const std::vector>& dm, assert(nlm1.size()==nlm2[0].size()); - if(!if_equiv) + if(!GlobalV::deepks_equiv) { int ib=0; for (int L0 = 0; L0 <= orb.Alpha[0].getLmax();++L0) @@ -196,7 +196,7 @@ void LCAO_Deepks::cal_f_delta_gamma(const std::vector>& dm, assert(nlm1.size()==nlm2[0].size()); - if(!if_equiv) + if(!GlobalV::deepks_equiv) { int ib=0; for (int L0 = 0; L0 <= orb.Alpha[0].getLmax();++L0) @@ -378,7 +378,7 @@ void LCAO_Deepks::cal_f_delta_k(const std::vector npy_des; for (int inl = 0;inl < inlmax;++inl) diff --git a/source/module_hamilt_lcao/module_deepks/LCAO_deepks_pdm.cpp b/source/module_hamilt_lcao/module_deepks/LCAO_deepks_pdm.cpp index 47da26b27f..be06283cd9 100644 --- a/source/module_hamilt_lcao/module_deepks/LCAO_deepks_pdm.cpp +++ b/source/module_hamilt_lcao/module_deepks/LCAO_deepks_pdm.cpp @@ -35,7 +35,7 @@ void LCAO_Deepks::cal_projected_DM(const elecstate::DensityMatrixlmaxd * 2 + 1) * (this->lmaxd * 2 + 1); } @@ -93,7 +93,7 @@ void LCAO_Deepks::cal_projected_DM(const elecstate::DensityMatrix trace_alpha_row; std::vector trace_alpha_col; - if(!if_equiv) + if(!GlobalV::deepks_equiv) { int ib=0; for (int L0 = 0; L0 <= orb.Alpha[0].getLmax();++L0) @@ -231,7 +231,7 @@ void LCAO_Deepks::cal_projected_DM(const elecstate::DensityMatrixlmaxd * 2 + 1) * (this->lmaxd * 2 + 1); } @@ -350,7 +350,7 @@ void LCAO_Deepks::cal_projected_DM_k(const elecstate::DensityMatrix trace_alpha_row; std::vector trace_alpha_col; - if(!if_equiv) + if(!GlobalV::deepks_equiv) { int ib=0; for (int L0 = 0; L0 <= orb.Alpha[0].getLmax();++L0) @@ -495,7 +495,7 @@ void LCAO_Deepks::cal_projected_DM_k(const 
elecstate::DensityMatrixcal_descriptor_equiv(nat); return; @@ -113,7 +113,7 @@ void LCAO_Deepks::check_descriptor(const UnitCell &ucell) if(GlobalV::MY_RANK!=0) return; std::ofstream ofs("descriptor.dat"); ofs<cal_gedm_equiv(nat); return; diff --git a/source/module_hamilt_lcao/module_gint/CMakeLists.txt b/source/module_hamilt_lcao/module_gint/CMakeLists.txt index e0fb786318..5cd0ca654e 100644 --- a/source/module_hamilt_lcao/module_gint/CMakeLists.txt +++ b/source/module_hamilt_lcao/module_gint/CMakeLists.txt @@ -2,7 +2,6 @@ list(APPEND objects gint.cpp - gint_gamma.cpp gint_gamma_env.cpp gint_gamma_vl.cpp gint_fvl.cpp @@ -25,13 +24,23 @@ list(APPEND objects if(USE_CUDA) list(APPEND objects kernels/cuda/cuda_tools.cu - kernels/cuda/vbatch_matrix_mul.cu kernels/cuda/gint_vl.cu kernels/cuda/gint_rho.cu kernels/cuda/gint_force.cu gint_vl_gpu.cu gint_rho_gpu.cu gint_force_gpu.cu + kernels/cuda/gemm_selector.cu + kernels/cuda/code_gen_00.cu + kernels/cuda/code_gen_01.cu + kernels/cuda/code_gen_02.cu + kernels/cuda/code_gen_03.cu + kernels/cuda/code_gen_04.cu + kernels/cuda/code_gen_05.cu + kernels/cuda/code_gen_06.cu + kernels/cuda/code_gen_07.cu + kernels/cuda/code_gen_08.cu + kernels/cuda/code_gen_09.cu gtask_vl.cpp gtask_rho.cpp gtask_force.cpp diff --git a/source/module_hamilt_lcao/module_gint/gint_gamma.cpp b/source/module_hamilt_lcao/module_gint/gint_gamma.cpp deleted file mode 100644 index 7b2683d807..0000000000 --- a/source/module_hamilt_lcao/module_gint/gint_gamma.cpp +++ /dev/null @@ -1,74 +0,0 @@ -#include "gint_gamma.h" -#include "module_hamilt_pw/hamilt_pwdft/global.h" -#include "module_base/ylm.h" -#include "module_cell/module_neighbor/sltk_atom_arrange.h" -#include "module_base/timer.h" -#ifdef _OPENMP -#include -#endif - -#ifdef __MKL -#include -#endif - -Gint_Gamma::Gint_Gamma() -{ - - sender_index_size = 1; - sender_local_index = nullptr; - sender_size_process = nullptr; - sender_displacement_process = nullptr; - sender_size=1; - sender_buffer=nullptr; 
- - receiver_index_size=1; - receiver_global_index = nullptr; - receiver_size_process = nullptr; - receiver_displacement_process = nullptr; - receiver_size=1; - receiver_buffer=nullptr; -} - -Gint_Gamma::~Gint_Gamma() -{ - // mohan add if 2024-04-09 - if(sender_local_index != nullptr) - { - delete[] sender_local_index; - } - - if(sender_size_process != nullptr) - { - delete[] sender_size_process; - } - - if(sender_displacement_process != nullptr) - { - delete[] sender_displacement_process; - } - - if(sender_buffer != nullptr) - { - delete[] sender_buffer; - } - - if(receiver_global_index != nullptr) - { - delete[] receiver_global_index; - } - - if(receiver_size_process != nullptr) - { - delete[] receiver_size_process; - } - - if(receiver_displacement_process != nullptr) - { - delete[] receiver_displacement_process; - } - - if(receiver_buffer != nullptr) - { - delete[] receiver_buffer; - } -} diff --git a/source/module_hamilt_lcao/module_gint/gint_gamma.h b/source/module_hamilt_lcao/module_gint/gint_gamma.h index 7ed700828f..275b8cebe0 100644 --- a/source/module_hamilt_lcao/module_gint/gint_gamma.h +++ b/source/module_hamilt_lcao/module_gint/gint_gamma.h @@ -24,8 +24,8 @@ class Gint_Gamma : public Gint { public: - Gint_Gamma(); - ~Gint_Gamma(); + // Gint_Gamma(); + // ~Gint_Gamma(); //------------------------------------------------------ // in gint_gamma_vl.cpp @@ -51,37 +51,6 @@ class Gint_Gamma : public Gint double*** DM = nullptr; //pointer to LOC.DM - ///------------------------------------------------------ - /// in gint_gamma_vl.cpp - ///------------------------------------------------------ - /// method for redistributing the Hamiltonian - /// from grid to 2D format - /// pass a setter function to customize row/col major and outputs - void vl_grid_to_2D(const double* vl_grid, - const Parallel_2D& p2d, - const int loc_grid_dim, - const bool new_e_iteration, - double* vl_2d, - std::function setfunc); - - ///=============================== - /// Use 
MPI_Alltoallv to convert a grid distributed matrix - /// to 2D - block cyclic distributed matrix. - ///=============================== - int sender_index_size; - int *sender_local_index; - int sender_size; - int *sender_size_process; - int *sender_displacement_process; - double* sender_buffer; - - int receiver_index_size; - int *receiver_global_index; - int receiver_size; - int *receiver_size_process; - int *receiver_displacement_process; - double* receiver_buffer; - }; #endif diff --git a/source/module_hamilt_lcao/module_gint/gint_gamma_vl.cpp b/source/module_hamilt_lcao/module_gint/gint_gamma_vl.cpp index 5177036c11..ce135db9b1 100644 --- a/source/module_hamilt_lcao/module_gint/gint_gamma_vl.cpp +++ b/source/module_hamilt_lcao/module_gint/gint_gamma_vl.cpp @@ -44,247 +44,6 @@ void Gint_Gamma::cal_vlocal(Gint_inout* inout,bool new_e_iteration) } } -#ifdef __MPI -//------------------------------------------------------------------ -// mohan add notes: 2021-03-11 -// this subroutine is used to transform data from grid integrals -// to 2D-block distribution -// s stands for 'sender' and r stands for 'receiver' -//------------------------------------------------------------------ -inline int setBufferParameter( - const Grid_Technique& gt, - MPI_Comm comm_2D, - int blacs_ctxt, - int nblk, - int& s_index_siz, - int*& s_local_index, - int*& s_siz_pro, - int*& s_dis_pro, - int& s_siz, - double*& s_buffer, - int& r_index_siz, - int*& r_global_index, - int*& r_siz_pro, - int*& r_dis_pro, - int& r_siz, - double*& r_buffer) -{ - const int nlocal = GlobalV::NLOCAL; - - //----------------------------------------- - // setup blacs parameters - //----------------------------------------- - int nprows=0; - int npcols=0; - int nprocs=0; - - int myprow=0; - int mypcol=0; - int myproc=0; - - Cblacs_gridinfo(blacs_ctxt, &nprows, &npcols, &myprow, &mypcol); - - //----------------------------------------- - // set index of current proor: myproc - // set number of total proors: nprocs - 
//----------------------------------------- - Cblacs_pinfo(&myproc, &nprocs); - - // initialize data arrays - delete[] s_siz_pro; - delete[] s_dis_pro; - delete[] r_siz_pro; - delete[] r_dis_pro; - - s_siz_pro=new int[nprocs]; - s_dis_pro=new int[nprocs]; - r_siz_pro=new int[nprocs]; - r_dis_pro=new int[nprocs]; - - //--------------------------------------------------------------------- - // build the local index to be sent to other pro (s_local_index), - // the global index to be received from other pro (r_global_index), - // the send/receive siz/dis for data exchange by MPI_Alltoall - //--------------------------------------------------------------------- - s_index_siz=gt.lgd*gt.lgd*2; - - delete[] s_local_index; - s_local_index=new int[s_index_siz]; - - int *s_global_index=new int[s_index_siz]; - - int pos=0; - s_siz_pro[0]=0; - for(int iproc=0; iproc= nlocal) - { - continue; - } - int lrow = gt.trace_lo[grow]; - if (lrow < 0) - { - continue; - } - - for(int icol=0, gcol=0; gcol= nlocal) - { - continue; - } - int lcol = gt.trace_lo[gcol]; - if (lcol < 0) - { - continue; - } - s_global_index[pos]=grow; - s_global_index[pos+1]=gcol; - s_local_index[pos]=lrow; - s_local_index[pos+1]=lcol; - pos+=2; - } - } - s_siz_pro[iproc]=pos-s_dis_pro[iproc]; - } - - MPI_Alltoall(s_siz_pro, 1, MPI_INT, - r_siz_pro, 1, MPI_INT, comm_2D); - - r_index_siz=r_siz_pro[0]; - r_dis_pro[0]=0; - for(int i=1; i setfunc) -{ - ModuleBase::timer::tick("Gint_Gamma","distri_vl"); - // setup send buffer and receive buffer size - // OUT(GlobalV::ofs_running, "Start transforming vlocal from grid distribute to 2D block"); - if(new_e_iteration) - { - ModuleBase::timer::tick("Gint_Gamma","distri_vl_index"); -#ifdef __MPI - setBufferParameter(*this->gridt, p2d.comm_2D, p2d.blacs_ctxt, p2d.nb, - this->sender_index_size, this->sender_local_index, - this->sender_size_process, this->sender_displacement_process, - this->sender_size, this->sender_buffer, - this->receiver_index_size, 
this->receiver_global_index, - this->receiver_size_process, this->receiver_displacement_process, - this->receiver_size, this->receiver_buffer); -#endif -#ifdef __DEBUG - ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "vlocal exchange index is built"); - ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "buffer size(M):", - (this->sender_size+this->receiver_size)*sizeof(double)/1024/1024); - ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "buffer index size(M):", - (this->sender_index_size+this->receiver_index_size)*sizeof(int)/1024/1024); -#endif - ModuleBase::timer::tick("Gint_Gamma","distri_vl_index"); - } - - ModuleBase::timer::tick("Gint_Gamma","distri_vl_value"); - - // put data to send buffer - for(int i=0; isender_index_size; i+=2) - { - const int irow=this->sender_local_index[i]; - const int icol=this->sender_local_index[i+1]; - if(irow<=icol) - { - this->sender_buffer[i / 2] = vl_grid[irow * loc_grid_dim + icol]; - } - else - { - this->sender_buffer[i / 2] = vl_grid[icol * loc_grid_dim + irow]; - } - } - -#ifdef __DEBUG - ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, - "vlocal data are put in sender_buffer, size(M):", - this->sender_size*8/1024/1024); -#endif - - // use mpi_alltoall to get local data -#ifdef __MPI - MPI_Alltoallv(this->sender_buffer, this->sender_size_process, this->sender_displacement_process, MPI_DOUBLE, - this->receiver_buffer, this->receiver_size_process, - this->receiver_displacement_process, MPI_DOUBLE, p2d.comm_2D); -#endif - -#ifdef __DEBUG - ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, - "vlocal data are exchanged, received size(M):", - this->receiver_size*8/1024/1024); -#endif - - // put local data to H matrix - for(int i=0; ireceiver_index_size; i+=2) - { - const int g_row = this->receiver_global_index[i]; - const int g_col = this->receiver_global_index[i + 1]; - setfunc(g_row, g_col, this->receiver_buffer[i / 2], vl_2d); - } - - ModuleBase::timer::tick("Gint_Gamma","distri_vl_value"); - 
ModuleBase::timer::tick("Gint_Gamma","distri_vl"); -} #ifdef __MPI #include "module_hamilt_lcao/module_hcontainer/hcontainer_funcs.h" diff --git a/source/module_hamilt_lcao/module_gint/gint_rho_gpu.cu b/source/module_hamilt_lcao/module_gint/gint_rho_gpu.cu index a598720c8c..13ffe9b2d2 100644 --- a/source/module_hamilt_lcao/module_gint/gint_rho_gpu.cu +++ b/source/module_hamilt_lcao/module_gint/gint_rho_gpu.cu @@ -1,5 +1,4 @@ #include "kernels/cuda/cuda_tools.cuh" -#include "kernels/cuda/vbatch_matrix_mul.cuh" #include "module_base/ylm.h" #include "module_hamilt_lcao/module_gint/gint_rho.h" #include "module_hamilt_lcao/module_gint/gint_tools.h" diff --git a/source/module_hamilt_lcao/module_gint/gint_vl_gpu.cu b/source/module_hamilt_lcao/module_gint/gint_vl_gpu.cu index 2c3d5b3922..73d6c8d201 100644 --- a/source/module_hamilt_lcao/module_gint/gint_vl_gpu.cu +++ b/source/module_hamilt_lcao/module_gint/gint_vl_gpu.cu @@ -1,7 +1,6 @@ #include #include "kernels/cuda/cuda_tools.cuh" -#include "kernels/cuda/vbatch_matrix_mul.cuh" #include "module_base/ylm.h" #include "module_hamilt_lcao/module_gint/gint_tools.h" #include "module_hamilt_lcao/module_gint/gint_vl.h" diff --git a/source/module_hamilt_lcao/module_gint/grid_technique.h b/source/module_hamilt_lcao/module_gint/grid_technique.h index 03cca8e924..898b0ffb51 100644 --- a/source/module_hamilt_lcao/module_gint/grid_technique.h +++ b/source/module_hamilt_lcao/module_gint/grid_technique.h @@ -11,7 +11,7 @@ #include #include "kernels/cuda/cuda_tools.cuh" -#include "kernels/cuda/vbatch_matrix_mul.cuh" +#include "kernels/cuda/gemm_selector.cuh" #endif // Author: mohan diff --git a/source/module_hamilt_lcao/module_gint/kernels/cuda/code_gen.cpp b/source/module_hamilt_lcao/module_gint/kernels/cuda/code_gen.cpp index 4edfce05cb..42e8c4f0c5 100644 --- a/source/module_hamilt_lcao/module_gint/kernels/cuda/code_gen.cpp +++ b/source/module_hamilt_lcao/module_gint/kernels/cuda/code_gen.cpp @@ -1,6 +1,3 @@ -// Generate and test the 
efficiency of matrix multiplication functions with different parameters -// This file takes a long time to compile - gemm_time_measure(max_m, max_n, d_m, @@ -4181,25 +4178,6 @@ gemm_time_measure(max_m, h_global_C, d_global_C); -gemm_time_measure(max_m, - max_n, - d_m, - d_n, - d_k, - d_global_A_array, - d_global_lda, - d_global_B_array, - d_global_ldb, - d_global_C_array, - d_global_ldc, - batchCount, - temp_stream, - fastest_time, - fastest_algo, - cpu_result, - h_global_C, - d_global_C); - gemm_time_measure(max_m, max_n, d_m, diff --git a/source/module_hamilt_lcao/module_gint/kernels/cuda/code_gen.cuh b/source/module_hamilt_lcao/module_gint/kernels/cuda/code_gen.cuh new file mode 100644 index 0000000000..a4b1a75916 --- /dev/null +++ b/source/module_hamilt_lcao/module_gint/kernels/cuda/code_gen.cuh @@ -0,0 +1,473 @@ +#ifndef CODE_GEN_CUH +#define CODE_GEN_CUH + +#include "gemm_selector.cuh" +#include + +extern template void gemm_time_measure(int, int, int*, int*, int*, double**, int*, double**, int*, double**, int*, int, cudaStream_t, float&, matrix_multiple_func_type&, double*, double*, double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void 
gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); 
+ +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void 
gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); 
+ +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void 
gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); 
+ +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void 
gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); 
+ +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void 
gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); 
+ +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void 
gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); 
+ +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void 
gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); 
+ +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void 
gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); 
+ +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void 
gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); 
+ +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void 
gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); 
+ +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void 
gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); 
+ +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +extern template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +#endif \ No newline at end of file diff --git a/source/module_hamilt_lcao/module_gint/kernels/cuda/code_gen_00.cu b/source/module_hamilt_lcao/module_gint/kernels/cuda/code_gen_00.cu new file mode 100644 index 0000000000..a07c411485 --- /dev/null +++ b/source/module_hamilt_lcao/module_gint/kernels/cuda/code_gen_00.cu @@ -0,0 +1,48 @@ +#include "vbatch_matrix_mul.cuh" + +template void gemm_time_measure(int, int, int*, int*, int*, double**, int*, double**, int*, double**, int*, int, cudaStream_t, float&, matrix_multiple_func_type&, double*, double*, double*); + +template void 
gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void 
gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + diff --git 
a/source/module_hamilt_lcao/module_gint/kernels/cuda/code_gen_01.cu b/source/module_hamilt_lcao/module_gint/kernels/cuda/code_gen_01.cu new file mode 100644 index 0000000000..9f725c23c6 --- /dev/null +++ b/source/module_hamilt_lcao/module_gint/kernels/cuda/code_gen_01.cu @@ -0,0 +1,48 @@ +#include "vbatch_matrix_mul.cuh" + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void 
gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void 
gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + diff --git a/source/module_hamilt_lcao/module_gint/kernels/cuda/code_gen_02.cu b/source/module_hamilt_lcao/module_gint/kernels/cuda/code_gen_02.cu new file mode 100644 index 0000000000..090eab0709 --- /dev/null +++ b/source/module_hamilt_lcao/module_gint/kernels/cuda/code_gen_02.cu @@ -0,0 +1,48 @@ +#include "vbatch_matrix_mul.cuh" + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void 
gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void 
gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + diff --git a/source/module_hamilt_lcao/module_gint/kernels/cuda/code_gen_03.cu b/source/module_hamilt_lcao/module_gint/kernels/cuda/code_gen_03.cu new file mode 100644 index 0000000000..046d0e5063 --- /dev/null +++ b/source/module_hamilt_lcao/module_gint/kernels/cuda/code_gen_03.cu @@ -0,0 +1,48 @@ +#include "vbatch_matrix_mul.cuh" + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void 
gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void 
gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + diff --git a/source/module_hamilt_lcao/module_gint/kernels/cuda/code_gen_04.cu b/source/module_hamilt_lcao/module_gint/kernels/cuda/code_gen_04.cu new file mode 100644 index 0000000000..f74209d829 --- /dev/null +++ b/source/module_hamilt_lcao/module_gint/kernels/cuda/code_gen_04.cu @@ -0,0 +1,48 @@ +#include "vbatch_matrix_mul.cuh" + +template void 
gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void 
gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void 
gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + diff --git a/source/module_hamilt_lcao/module_gint/kernels/cuda/code_gen_05.cu b/source/module_hamilt_lcao/module_gint/kernels/cuda/code_gen_05.cu new file mode 100644 index 0000000000..c9cb81bd7c --- /dev/null +++ b/source/module_hamilt_lcao/module_gint/kernels/cuda/code_gen_05.cu @@ -0,0 +1,48 @@ +#include "vbatch_matrix_mul.cuh" + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void 
gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void 
gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + diff --git a/source/module_hamilt_lcao/module_gint/kernels/cuda/code_gen_06.cu b/source/module_hamilt_lcao/module_gint/kernels/cuda/code_gen_06.cu new file mode 100644 index 0000000000..f5fac39df2 --- /dev/null +++ b/source/module_hamilt_lcao/module_gint/kernels/cuda/code_gen_06.cu @@ -0,0 +1,48 @@ +#include "vbatch_matrix_mul.cuh" + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void 
gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void 
gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + diff --git a/source/module_hamilt_lcao/module_gint/kernels/cuda/code_gen_07.cu b/source/module_hamilt_lcao/module_gint/kernels/cuda/code_gen_07.cu new file mode 100644 index 0000000000..971c6eb0c0 --- /dev/null +++ b/source/module_hamilt_lcao/module_gint/kernels/cuda/code_gen_07.cu @@ -0,0 +1,48 @@ +#include "vbatch_matrix_mul.cuh" + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void 
gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void 
gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + diff --git a/source/module_hamilt_lcao/module_gint/kernels/cuda/code_gen_08.cu b/source/module_hamilt_lcao/module_gint/kernels/cuda/code_gen_08.cu new file mode 100644 index 0000000000..8643faae70 --- /dev/null +++ 
b/source/module_hamilt_lcao/module_gint/kernels/cuda/code_gen_08.cu @@ -0,0 +1,48 @@ +#include "vbatch_matrix_mul.cuh" + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void 
gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void 
gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + diff --git a/source/module_hamilt_lcao/module_gint/kernels/cuda/code_gen_09.cu b/source/module_hamilt_lcao/module_gint/kernels/cuda/code_gen_09.cu new file mode 100644 index 0000000000..8cf333bf6f --- /dev/null +++ b/source/module_hamilt_lcao/module_gint/kernels/cuda/code_gen_09.cu @@ -0,0 +1,53 @@ +#include "vbatch_matrix_mul.cuh" + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void 
gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void 
gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); + +template void gemm_time_measure(int,int,int*,int*,int*,double**,int*,double**,int*,double**,int*,int,cudaStream_t,float&,matrix_multiple_func_type&,double*,double*,double*); \ No newline at end of file diff --git a/source/module_hamilt_lcao/module_gint/kernels/cuda/gemm_selector.cu b/source/module_hamilt_lcao/module_gint/kernels/cuda/gemm_selector.cu new file mode 100644 index 0000000000..cfad7440f3 --- /dev/null +++ b/source/module_hamilt_lcao/module_gint/kernels/cuda/gemm_selector.cu @@ -0,0 +1,138 @@ +#include + +#include "gemm_selector.cuh" +#include "vbatch_matrix_mul.cuh" +#include "cuda_tools.cuh" +#include "module_base/blas_connector.h" +#include "code_gen.cuh" + +/* + * Here we have utilized a very straightforward and brute-force method to select + * the 
optimal matrix multiplication kernel for a given scale of computation: we + * compute with all scales of kernels under the current computational task to + * find the fastest parameter combination. This approach can lead to an increase + * in compilation time. + */ +void gemm_algo_selector(int matrix_k, matrix_multiple_func_type& fastest_algo,const UnitCell& ucell) +{ + int batchCount_per_type = 32; + int batchCount + = batchCount_per_type * ucell.ntype * ucell.ntype; + + Cuda_Mem_Wrapper m(batchCount); + Cuda_Mem_Wrapper n(batchCount); + Cuda_Mem_Wrapper k(batchCount); + + int max_m = ucell.nwmax, max_n = ucell.nwmax; + + Cuda_Mem_Wrapper A(batchCount * max_m * matrix_k); + Cuda_Mem_Wrapper B(batchCount * max_n * matrix_k); + Cuda_Mem_Wrapper C(batchCount * max_m * max_n); + + Cuda_Mem_Wrapper lda(batchCount); + Cuda_Mem_Wrapper ldb(batchCount); + Cuda_Mem_Wrapper ldc(batchCount); + + Cuda_Mem_Wrapper A_array(batchCount); + Cuda_Mem_Wrapper B_array(batchCount); + Cuda_Mem_Wrapper C_array(batchCount); + + for (int i = 0; i < batchCount * max_m * matrix_k; ++i) + { + A.get_host_pointer()[i] = i * 0.001; + } + for (int i = 0; i < batchCount * max_n * matrix_k; ++i) + { + B.get_host_pointer()[i] = i * 0.002; + } + + double* cpu_result = new double[batchCount * max_m * max_n]; + memset(cpu_result, 0, batchCount * max_m * max_n * sizeof(double)); + int index = 0; + for (int i = 0; i < batchCount_per_type; ++i) + { + for (int j = 0; j < ucell.ntype; j++) + { + for (int l = 0; l < ucell.ntype; l++) + { + m.get_host_pointer()[index] = ucell.atoms[j].nw; + n.get_host_pointer()[index] = ucell.atoms[l].nw; + k.get_host_pointer()[index] = matrix_k; + + lda.get_host_pointer()[index] = matrix_k; + ldb.get_host_pointer()[index] = matrix_k; + ldc.get_host_pointer()[index] = ucell.atoms[l].nw; + + A_array.get_host_pointer()[index] + = &A.get_device_pointer()[index * max_m * matrix_k]; + B_array.get_host_pointer()[index] + = &B.get_device_pointer()[index * max_n * matrix_k]; + 
C_array.get_host_pointer()[index] + = &C.get_device_pointer()[index * max_n + * max_m]; // test atom add + BlasConnector::gemm( + 'N', + 'T', + m.get_host_pointer()[index], + n.get_host_pointer()[index], + matrix_k, + 1.0, + &A.get_host_pointer()[index * max_m * matrix_k], + matrix_k, + &B.get_host_pointer()[index * max_n * matrix_k], + matrix_k, + 1.0, + &cpu_result[index * max_m * max_n], + n.get_host_pointer()[index]); + index++; + } + } + } + + m.copy_host_to_device_sync(); + n.copy_host_to_device_sync(); + k.copy_host_to_device_sync(); + + lda.copy_host_to_device_sync(); + ldb.copy_host_to_device_sync(); + ldc.copy_host_to_device_sync(); + + A.copy_host_to_device_sync(); + B.copy_host_to_device_sync(); + A_array.copy_host_to_device_sync(); + B_array.copy_host_to_device_sync(); + C_array.copy_host_to_device_sync(); + + cudaStream_t temp_stream; + checkCuda(cudaStreamCreate(&temp_stream)); + + float fastest_time = 1000000; + fastest_algo = vbatched_gemm_impl; + + int* d_m = m.get_device_pointer(); + int* d_n = n.get_device_pointer(); + int* d_k = k.get_device_pointer(); + + double** d_global_A_array = A_array.get_device_pointer(); + double** d_global_B_array = B_array.get_device_pointer(); + double** d_global_C_array = C_array.get_device_pointer(); + + double* h_global_C = C.get_host_pointer(); + double* d_global_C = C.get_device_pointer(); + + int* d_global_lda = lda.get_device_pointer(); + int* d_global_ldb = ldb.get_device_pointer(); + int* d_global_ldc = ldc.get_device_pointer(); + +/* + * Please do not manually modify the code in the following file; + * it should simply be generated through a loop using a short Python program. 
+ */ +#include "code_gen.cpp" + checkCuda(cudaStreamDestroy(temp_stream)); + std::cout << " gemm_algo_selector::Fastest time: " << fastest_time << " ms" + << std::endl; + // fastest_algo = vbatched_gemm_impl; + delete[] cpu_result; +} \ No newline at end of file diff --git a/source/module_hamilt_lcao/module_gint/kernels/cuda/gemm_selector.cuh b/source/module_hamilt_lcao/module_gint/kernels/cuda/gemm_selector.cuh new file mode 100644 index 0000000000..380a16c842 --- /dev/null +++ b/source/module_hamilt_lcao/module_gint/kernels/cuda/gemm_selector.cuh @@ -0,0 +1,24 @@ +#ifndef GEMM_SELECTOR_H +#define GEMM_SELECTOR_H + +#include "module_cell/unitcell.h" + +typedef std::function +matrix_multiple_func_type; + +void gemm_algo_selector(int k, matrix_multiple_func_type& func,const UnitCell& ucell); + +#endif \ No newline at end of file diff --git a/source/module_hamilt_lcao/module_gint/kernels/cuda/vbatch_matrix_mul.cu b/source/module_hamilt_lcao/module_gint/kernels/cuda/vbatch_matrix_mul.cu deleted file mode 100644 index b84b76a840..0000000000 --- a/source/module_hamilt_lcao/module_gint/kernels/cuda/vbatch_matrix_mul.cu +++ /dev/null @@ -1,659 +0,0 @@ -#include - -#include "cuda_tools.cuh" -#include "module_base/blas_connector.h" -#include "module_hamilt_pw/hamilt_pwdft/global.h" -#include "vbatch_matrix_mul.cuh" - -#define sA(i, j) sA[(j)*slda + (i)] -#define sB(i, j) sB[(j)*sldb + (i)] -#define fetch(A, m, n, bound) offs_d##A[min(n * LD##A + m, bound)] - -template -static __device__ void vbatched_gemm_device(int M, - int N, - int K, - T* __restrict__ A, - int LDA, - T* __restrict__ B, - int LDB, - T* __restrict__ C, - int LDC, - T* sA, - int slda, - T* sB, - int sldb, - T alpha) -{ - int idx = threadIdx.x; // thread's m dimension - int idy = threadIdx.y; // thread's n dimension - - int idt = DIM_X * idy + idx; // thread's global number - - int idxA = idt % DIM_XA; // idx within A - int idyA = idt / DIM_XA; // idy within A - - int idxB = idt % DIM_XB; // idx within B - 
int idyB = idt / DIM_XB; // idy within B - - int blx = blockIdx.x; // block's m dimension - int bly = blockIdx.y; // block's n dimension - - // Registers for the innermost loop - T rC[THR_N][THR_M]; - T rA[THR_M]; - T rB[THR_N]; - - // Registers for the dev->shmem copy - T ra[BLK_M / DIM_YA][BLK_K / DIM_XA]; - T rb[BLK_N / DIM_YB][BLK_K / DIM_XB]; - - // bound is the correction to offs_d in order to not get out of memory bound - // so bound could be negative value since offs_d could be out of bound - T* offs_dA = A + blx * BLK_M * LDA + idyA * LDA + idxA; - int boundA - = (LDA * (M - 1) + K) - (blx * BLK_M * LDA + idyA * LDA + idxA) - 1; - - T* offs_dB = B + bly * BLK_N * LDB + idyB * LDB + idxB; - int boundB - = (LDB * (N - 1) + K) - (bly * BLK_N * LDB + idyB * LDB + idxB) - 1; - - int m, n, k, kk; - -// Zero C -#pragma unroll - for (n = 0; n < THR_N; n++) - { -#pragma unroll - for (m = 0; m < THR_M; m++) - { - rC[n][m] = 0.0; - } - } - -// Load A dev->shmem -#pragma unroll - for (n = 0; n < BLK_M; n += DIM_YA) - { -#pragma unroll - for (m = 0; m < BLK_K; m += DIM_XA) - { - sA(n + idyA, m + idxA) = fetch(A, m, n, boundA); - } - } - -#pragma unroll - for (n = 0; n < BLK_N; n += DIM_YB) - { -#pragma unroll - for (m = 0; m < BLK_K; m += DIM_XB) - { - sB(m + idxB, n + idyB) = fetch(B, m, n, boundB); - } - } - - __syncthreads(); - - for (kk = 0; kk < K - BLK_K; kk += BLK_K) - { - offs_dA += BLK_K; - boundA -= BLK_K; - - offs_dB += BLK_K; - boundB -= BLK_K; - -// Load A dev->regs -#pragma unroll - for (n = 0; n < BLK_M / DIM_YA; n++) - { -#pragma unroll - for (m = 0; m < BLK_K / DIM_XA; m++) - { - ra[n][m] = fetch(A, m * DIM_XA, n * DIM_YA, boundA); - } - } - -// Load B dev->regs -#pragma unroll - for (n = 0; n < BLK_N / DIM_YB; n++) - { -#pragma unroll - for (m = 0; m < BLK_K / DIM_XB; m++) - { - rb[n][m] = fetch(B, m * DIM_XB, n * DIM_YB, boundB); - } - } - -// Multiply -#pragma unroll - for (k = 0; k < BLK_K; k++) - { -// Load A shmem->regs -#pragma unroll - for (m = 
0; m < THR_M; m++) - { - rA[m] = sA(m * DIM_X + idx, k); - } - -// Load B shmem->regs -#pragma unroll - for (n = 0; n < THR_N; n++) - { - rB[n] = sB(k, n * DIM_Y + idy); - } - -// Compute -#pragma unroll - for (n = 0; n < THR_N; n++) - { -#pragma unroll - for (m = 0; m < THR_M; m++) - { - rC[n][m] += rA[m] * rB[n]; - } - } - } - - __syncthreads(); - -// Load A regs->shmem -#pragma unroll - for (n = 0; n < BLK_M / DIM_YA; n++) - { -#pragma unroll - for (m = 0; m < BLK_K / DIM_XA; m++) - { - sA(n * DIM_YA + idyA, m * DIM_XA + idxA) = ra[n][m]; - } - } - -// Load B regs->shmem -#pragma unroll - for (n = 0; n < BLK_N / DIM_YB; n++) - { -#pragma unroll - for (m = 0; m < BLK_K / DIM_XB; m++) - { - sB(m * DIM_XB + idxB, n * DIM_YB + idyB) = rb[n][m]; - } - } - __syncthreads(); - } - - // Multiply last full (BLK_K) or partial block of - // columns of op(A) and rows of op(B). - // It's okay that m,n exceed matrix bounds as all work is in registers - // or shared memory, and out-of-bounds rC[n][m] will not be saved later. 
- kk = K - kk; -#pragma unroll - for (k = 0; k < kk; k++) - { -// Load A shmem->regs -#pragma unroll - for (m = 0; m < THR_M; m++) - { - rA[m] = sA(m * DIM_X + idx, k); - } - -// Load B shmem->regs -#pragma unroll - for (n = 0; n < THR_N; n++) - { - rB[n] = sB(k, n * DIM_Y + idy); - } - -// Compute -#pragma unroll - for (n = 0; n < THR_N; n++) - { -#pragma unroll - for (m = 0; m < THR_M; m++) - { - rC[n][m] += rA[m] * rB[n]; - } - } - } - -// Store C regs->dev -#pragma unroll - for (n = 0; n < THR_N; n++) - { - int coord_dCn = bly * BLK_N + n * DIM_Y + idy; -#pragma unroll - for (m = 0; m < THR_M; m++) - { - int coord_dCm = blx * BLK_M + m * DIM_X + idx; - if (coord_dCm < M && coord_dCn < N) - { - int offsC = coord_dCn * LDC + coord_dCm; - - atomicAdd(C + offsC, rC[n][m] * alpha); - } - } - } -} - -/******************************************************************************/ -template -static __global__ void vbatched_gemm_kernel(int* M, - int* N, - int* K, - T** global_A_array, - int* global_lda, - T** global_B_array, - int* global_ldb, - T** global_C_array, - int* global_ldc, - T* alpha) -{ - extern __shared__ __align__(sizeof(T)) unsigned char smem[]; - T* shared_mem = reinterpret_cast(smem); - - int batchid = blockIdx.z; - int local_M = (int)M[batchid]; - int local_N = (int)N[batchid]; - int local_K = (int)K[batchid]; - - if (blockIdx.x >= (local_M + BLK_M - 1) / BLK_M) - return; - if (blockIdx.y >= (local_N + BLK_N - 1) / BLK_N) - return; - - int shared_lda = BLK_M + 1; - int shared_ldb = BLK_K + 1; - T* shared_A = (T*)shared_mem; - T* shared_B = shared_A + shared_lda * BLK_K; - double alpha_tmp = 1.0; - if (alpha != nullptr) - { - alpha_tmp = alpha[batchid]; - } - vbatched_gemm_device(local_M, - local_N, - local_K, - global_A_array[batchid], - (int)global_lda[batchid], - global_B_array[batchid], - (int)global_ldb[batchid], - global_C_array[batchid], - (int)global_ldc[batchid], - shared_A, - shared_lda, - shared_B, - shared_ldb, - alpha_tmp); -} - -static 
inline int ceildiv(int x, int y) -{ - return (x + y - 1) / y; -} - -template -void vbatched_gemm_impl(int max_m, - int max_n, - int* m, - int* n, - int* k, - T** global_A_array, - int* global_lda, - T** global_B_array, - int* global_ldb, - T** global_C_array, - int* global_ldc, - int batchCount, - cudaStream_t stream, - T* alpha) -{ - // The positions of A and B have been swapped here. - // This is because the original code is for column-major matrices. - // We use row-major matrices, so we need to swap A and B. - // The vbatched_gemm_impl is for C = trans(A) * B + C, but we need trans(C). - // Which means: trans(C) = trans(trans(A)*B + C) = trans(B) * A + trans(C) - // Then, ldc should be N, lda and ldb should be K - - size_t shared_mem_size = 0; - shared_mem_size += (BLK_M + 1) * BLK_K * sizeof(T); - shared_mem_size += (BLK_K + 1) * BLK_N * sizeof(T); - dim3 dimBlock(DIM_X, DIM_Y); - const int max_batch_count = 32768; - const int loop_num = batchCount / max_batch_count; - const int remain_num = batchCount % max_batch_count; - - for (int i = 0; i < loop_num; ++i) - { - dim3 dimGrid(ceildiv(max_n, BLK_M), - ceildiv(max_m, BLK_N), - max_batch_count); - T* alpha_tmp = nullptr; - if (alpha != nullptr) - { - alpha_tmp = alpha + i * max_batch_count; - } - - vbatched_gemm_kernel - <<>>( - n + i * max_batch_count, - m + i * max_batch_count, - k + i * max_batch_count, - global_B_array + i * max_batch_count, - global_ldb + i * max_batch_count, - global_A_array + i * max_batch_count, - global_lda + i * max_batch_count, - global_C_array + i * max_batch_count, - global_ldc + i * max_batch_count, - alpha_tmp); - checkCudaLastError(); - } - if (remain_num > 0) - { - dim3 dimGrid(ceildiv(max_n, BLK_M), ceildiv(max_m, BLK_N), remain_num); - T* alpha_tmp = nullptr; - if (alpha != nullptr) - { - alpha_tmp = alpha + loop_num * max_batch_count; - } - vbatched_gemm_kernel - <<>>( - n + loop_num * max_batch_count, - m + loop_num * max_batch_count, - k + loop_num * max_batch_count, - 
global_B_array + loop_num * max_batch_count, - global_ldb + loop_num * max_batch_count, - global_A_array + loop_num * max_batch_count, - global_lda + loop_num * max_batch_count, - global_C_array + loop_num * max_batch_count, - global_ldc + loop_num * max_batch_count, - alpha_tmp); - checkCudaLastError(); - } -} - -template -void gemm_time_measure(int max_m, - int max_n, - int* m, - int* n, - int* k, - T** global_A_array, - int* global_lda, - T** global_B_array, - int* global_ldb, - T** global_C_array, - int* global_ldc, - int batchCount, - cudaStream_t stream, - float& fast_time, - matrix_multiple_func_type& fastest_algo, - double* cpu_result, - double* h_global_C, - double* d_global_C) -{ - cudaEvent_t start, stop; - checkCuda( - cudaMemset(d_global_C, 0, batchCount * max_m * max_n * sizeof(double))); - checkCuda(cudaEventCreate(&start)); - checkCuda(cudaEventCreate(&stop)); - checkCuda(cudaEventRecord(start, stream)); - vbatched_gemm_impl(max_m, - max_n, - m, - n, - k, - global_A_array, - global_lda, - global_B_array, - global_ldb, - global_C_array, - global_ldc, - batchCount, - stream); - checkCuda(cudaEventRecord(stop, stream)); - cudaError_t cuda_status = cudaGetLastError(); - checkCuda(cudaStreamSynchronize(stream)); - float milliseconds = 0; - checkCuda(cudaEventElapsedTime(&milliseconds, start, stop)); - - // WARNING !!!!! Here we assume that all m and n are the same - checkCuda(cudaMemcpy(h_global_C, - d_global_C, - batchCount * max_m * max_n * sizeof(double), - cudaMemcpyDeviceToHost)); - bool check_result = true; - for (int i = 0; i < batchCount * max_m * max_n; ++i) - { - if (abs(cpu_result[i] - h_global_C[i]) > 0.001) - { - check_result = false; - break; - } - } - if (milliseconds < fast_time && cuda_status == cudaSuccess && check_result) - { - fast_time = milliseconds; - fastest_algo = vbatched_gemm_impl; -#ifdef __DEBUG - std::cout << "found! 
fastest time: " << fast_time << std::endl; - std::cout << DIM_X << "," << DIM_Y << "," << BLK_M << "," << BLK_N - << "," << BLK_K << "," << DIM_XA << "," << DIM_YA << "," - << DIM_XB << "," << DIM_YB << std::endl; -#endif - } -} - -/* - * Here we have utilized a very straightforward and brute-force method to select - * the optimal matrix multiplication kernel for a given scale of computation: we - * compute with all scales of kernels under the current computational task to - * find the fastest parameter combination. This approach can lead to an increase - * in compilation time (TODO: so in the future, it will be necessary to split - * this large section of code into multiple files, multiple compilation units). - */ -void gemm_algo_selector(int matrix_k, matrix_multiple_func_type& fastest_algo,const UnitCell& ucell) -{ - int batchCount_per_type = 32; - int batchCount - = batchCount_per_type * ucell.ntype * ucell.ntype; - - Cuda_Mem_Wrapper m(batchCount); - Cuda_Mem_Wrapper n(batchCount); - Cuda_Mem_Wrapper k(batchCount); - - int max_m = ucell.nwmax, max_n = ucell.nwmax; - - Cuda_Mem_Wrapper A(batchCount * max_m * matrix_k); - Cuda_Mem_Wrapper B(batchCount * max_n * matrix_k); - Cuda_Mem_Wrapper C(batchCount * max_m * max_n); - - Cuda_Mem_Wrapper lda(batchCount); - Cuda_Mem_Wrapper ldb(batchCount); - Cuda_Mem_Wrapper ldc(batchCount); - - Cuda_Mem_Wrapper A_array(batchCount); - Cuda_Mem_Wrapper B_array(batchCount); - Cuda_Mem_Wrapper C_array(batchCount); - - for (int i = 0; i < batchCount * max_m * matrix_k; ++i) - { - A.get_host_pointer()[i] = i * 0.001; - } - for (int i = 0; i < batchCount * max_n * matrix_k; ++i) - { - B.get_host_pointer()[i] = i * 0.002; - } - - double* cpu_result = new double[batchCount * max_m * max_n]; - memset(cpu_result, 0, batchCount * max_m * max_n * sizeof(double)); - int index = 0; - for (int i = 0; i < batchCount_per_type; ++i) - { - for (int j = 0; j < ucell.ntype; j++) - { - for (int l = 0; l < ucell.ntype; l++) - { - 
m.get_host_pointer()[index] = ucell.atoms[j].nw; - n.get_host_pointer()[index] = ucell.atoms[l].nw; - k.get_host_pointer()[index] = matrix_k; - - lda.get_host_pointer()[index] = matrix_k; - ldb.get_host_pointer()[index] = matrix_k; - ldc.get_host_pointer()[index] = ucell.atoms[l].nw; - - A_array.get_host_pointer()[index] - = &A.get_device_pointer()[index * max_m * matrix_k]; - B_array.get_host_pointer()[index] - = &B.get_device_pointer()[index * max_n * matrix_k]; - C_array.get_host_pointer()[index] - = &C.get_device_pointer()[index * max_n - * max_m]; // test atom add - BlasConnector::gemm( - 'N', - 'T', - m.get_host_pointer()[index], - n.get_host_pointer()[index], - matrix_k, - 1.0, - &A.get_host_pointer()[index * max_m * matrix_k], - matrix_k, - &B.get_host_pointer()[index * max_n * matrix_k], - matrix_k, - 1.0, - &cpu_result[index * max_m * max_n], - n.get_host_pointer()[index]); - index++; - } - } - } - - m.copy_host_to_device_sync(); - n.copy_host_to_device_sync(); - k.copy_host_to_device_sync(); - - lda.copy_host_to_device_sync(); - ldb.copy_host_to_device_sync(); - ldc.copy_host_to_device_sync(); - - A.copy_host_to_device_sync(); - B.copy_host_to_device_sync(); - A_array.copy_host_to_device_sync(); - B_array.copy_host_to_device_sync(); - C_array.copy_host_to_device_sync(); - - cudaStream_t temp_stream; - checkCuda(cudaStreamCreate(&temp_stream)); - - float fastest_time = 1000000; - fastest_algo = vbatched_gemm_impl; - - int* d_m = m.get_device_pointer(); - int* d_n = n.get_device_pointer(); - int* d_k = k.get_device_pointer(); - - double** d_global_A_array = A_array.get_device_pointer(); - double** d_global_B_array = B_array.get_device_pointer(); - double** d_global_C_array = C_array.get_device_pointer(); - - double* h_global_C = C.get_host_pointer(); - double* d_global_C = C.get_device_pointer(); - - int* d_global_lda = lda.get_device_pointer(); - int* d_global_ldb = ldb.get_device_pointer(); - int* d_global_ldc = ldc.get_device_pointer(); - -/* - * Please 
do not manually modify the code in the following file; - * it should simply be generated through a loop using a short Python program. - */ -#include "code_gen.cpp" - checkCuda(cudaStreamDestroy(temp_stream)); - std::cout << " gemm_algo_selector::Fastest time: " << fastest_time << " ms" - << std::endl; - // fastest_algo = vbatched_gemm_impl; - delete[] cpu_result; -} \ No newline at end of file diff --git a/source/module_hamilt_lcao/module_gint/kernels/cuda/vbatch_matrix_mul.cuh b/source/module_hamilt_lcao/module_gint/kernels/cuda/vbatch_matrix_mul.cuh index ea4e42e521..24e8ba91e1 100644 --- a/source/module_hamilt_lcao/module_gint/kernels/cuda/vbatch_matrix_mul.cuh +++ b/source/module_hamilt_lcao/module_gint/kernels/cuda/vbatch_matrix_mul.cuh @@ -1,13 +1,333 @@ -#ifndef VBATCH_MATRIX_MUL_H -#define VBATCH_MATRIX_MUL_H +#ifndef VBATCH_MATRIX_MUL_CUH +#define VBATCH_MATRIX_MUL_CUH #include // for assert #include #include // for CUDA_VERSION #include #include // for fprintf and stderr +#include "cuda_tools.cuh" #include #include "module_cell/unitcell.h" +#include "module_hamilt_pw/hamilt_pwdft/global.h" +#include + + +#define sA(i, j) sA[(j)*slda + (i)] +#define sB(i, j) sB[(j)*sldb + (i)] +#define fetch(A, m, n, bound) offs_d##A[min(n * LD##A + m, bound)] + +template +static __device__ void vbatched_gemm_device(int M, + int N, + int K, + T* __restrict__ A, + int LDA, + T* __restrict__ B, + int LDB, + T* __restrict__ C, + int LDC, + T* sA, + int slda, + T* sB, + int sldb, + T alpha) +{ + int idx = threadIdx.x; // thread's m dimension + int idy = threadIdx.y; // thread's n dimension + + int idt = DIM_X * idy + idx; // thread's global number + + int idxA = idt % DIM_XA; // idx within A + int idyA = idt / DIM_XA; // idy within A + + int idxB = idt % DIM_XB; // idx within B + int idyB = idt / DIM_XB; // idy within B + + int blx = blockIdx.x; // block's m dimension + int bly = blockIdx.y; // block's n dimension + + // Registers for the innermost loop + T rC[THR_N][THR_M]; + 
T rA[THR_M]; + T rB[THR_N]; + + // Registers for the dev->shmem copy + T ra[BLK_M / DIM_YA][BLK_K / DIM_XA]; + T rb[BLK_N / DIM_YB][BLK_K / DIM_XB]; + + // bound is the correction to offs_d in order to not get out of memory bound + // so bound could be negative value since offs_d could be out of bound + T* offs_dA = A + blx * BLK_M * LDA + idyA * LDA + idxA; + int boundA + = (LDA * (M - 1) + K) - (blx * BLK_M * LDA + idyA * LDA + idxA) - 1; + + T* offs_dB = B + bly * BLK_N * LDB + idyB * LDB + idxB; + int boundB + = (LDB * (N - 1) + K) - (bly * BLK_N * LDB + idyB * LDB + idxB) - 1; + + int m, n, k, kk; + +// Zero C +#pragma unroll + for (n = 0; n < THR_N; n++) + { +#pragma unroll + for (m = 0; m < THR_M; m++) + { + rC[n][m] = 0.0; + } + } + +// Load A dev->shmem +#pragma unroll + for (n = 0; n < BLK_M; n += DIM_YA) + { +#pragma unroll + for (m = 0; m < BLK_K; m += DIM_XA) + { + sA(n + idyA, m + idxA) = fetch(A, m, n, boundA); + } + } + +#pragma unroll + for (n = 0; n < BLK_N; n += DIM_YB) + { +#pragma unroll + for (m = 0; m < BLK_K; m += DIM_XB) + { + sB(m + idxB, n + idyB) = fetch(B, m, n, boundB); + } + } + + __syncthreads(); + + for (kk = 0; kk < K - BLK_K; kk += BLK_K) + { + offs_dA += BLK_K; + boundA -= BLK_K; + + offs_dB += BLK_K; + boundB -= BLK_K; + +// Load A dev->regs +#pragma unroll + for (n = 0; n < BLK_M / DIM_YA; n++) + { +#pragma unroll + for (m = 0; m < BLK_K / DIM_XA; m++) + { + ra[n][m] = fetch(A, m * DIM_XA, n * DIM_YA, boundA); + } + } + +// Load B dev->regs +#pragma unroll + for (n = 0; n < BLK_N / DIM_YB; n++) + { +#pragma unroll + for (m = 0; m < BLK_K / DIM_XB; m++) + { + rb[n][m] = fetch(B, m * DIM_XB, n * DIM_YB, boundB); + } + } + +// Multiply +#pragma unroll + for (k = 0; k < BLK_K; k++) + { +// Load A shmem->regs +#pragma unroll + for (m = 0; m < THR_M; m++) + { + rA[m] = sA(m * DIM_X + idx, k); + } + +// Load B shmem->regs +#pragma unroll + for (n = 0; n < THR_N; n++) + { + rB[n] = sB(k, n * DIM_Y + idy); + } + +// Compute +#pragma 
unroll + for (n = 0; n < THR_N; n++) + { +#pragma unroll + for (m = 0; m < THR_M; m++) + { + rC[n][m] += rA[m] * rB[n]; + } + } + } + + __syncthreads(); + +// Load A regs->shmem +#pragma unroll + for (n = 0; n < BLK_M / DIM_YA; n++) + { +#pragma unroll + for (m = 0; m < BLK_K / DIM_XA; m++) + { + sA(n * DIM_YA + idyA, m * DIM_XA + idxA) = ra[n][m]; + } + } + +// Load B regs->shmem +#pragma unroll + for (n = 0; n < BLK_N / DIM_YB; n++) + { +#pragma unroll + for (m = 0; m < BLK_K / DIM_XB; m++) + { + sB(m * DIM_XB + idxB, n * DIM_YB + idyB) = rb[n][m]; + } + } + __syncthreads(); + } + + // Multiply last full (BLK_K) or partial block of + // columns of op(A) and rows of op(B). + // It's okay that m,n exceed matrix bounds as all work is in registers + // or shared memory, and out-of-bounds rC[n][m] will not be saved later. + kk = K - kk; +#pragma unroll + for (k = 0; k < kk; k++) + { +// Load A shmem->regs +#pragma unroll + for (m = 0; m < THR_M; m++) + { + rA[m] = sA(m * DIM_X + idx, k); + } + +// Load B shmem->regs +#pragma unroll + for (n = 0; n < THR_N; n++) + { + rB[n] = sB(k, n * DIM_Y + idy); + } + +// Compute +#pragma unroll + for (n = 0; n < THR_N; n++) + { +#pragma unroll + for (m = 0; m < THR_M; m++) + { + rC[n][m] += rA[m] * rB[n]; + } + } + } + +// Store C regs->dev +#pragma unroll + for (n = 0; n < THR_N; n++) + { + int coord_dCn = bly * BLK_N + n * DIM_Y + idy; +#pragma unroll + for (m = 0; m < THR_M; m++) + { + int coord_dCm = blx * BLK_M + m * DIM_X + idx; + if (coord_dCm < M && coord_dCn < N) + { + int offsC = coord_dCn * LDC + coord_dCm; + + atomicAdd(C + offsC, rC[n][m] * alpha); + } + } + } +} + +/******************************************************************************/ +template +static __global__ void vbatched_gemm_kernel(int* M, + int* N, + int* K, + T** global_A_array, + int* global_lda, + T** global_B_array, + int* global_ldb, + T** global_C_array, + int* global_ldc, + T* alpha) +{ + extern __shared__ __align__(sizeof(T)) unsigned char 
smem[]; + T* shared_mem = reinterpret_cast(smem); + + int batchid = blockIdx.z; + int local_M = (int)M[batchid]; + int local_N = (int)N[batchid]; + int local_K = (int)K[batchid]; + + if (blockIdx.x >= (local_M + BLK_M - 1) / BLK_M) + return; + if (blockIdx.y >= (local_N + BLK_N - 1) / BLK_N) + return; + + int shared_lda = BLK_M + 1; + int shared_ldb = BLK_K + 1; + T* shared_A = (T*)shared_mem; + T* shared_B = shared_A + shared_lda * BLK_K; + double alpha_tmp = 1.0; + if (alpha != nullptr) + { + alpha_tmp = alpha[batchid]; + } + vbatched_gemm_device(local_M, + local_N, + local_K, + global_A_array[batchid], + (int)global_lda[batchid], + global_B_array[batchid], + (int)global_ldb[batchid], + global_C_array[batchid], + (int)global_ldc[batchid], + shared_A, + shared_lda, + shared_B, + shared_ldb, + alpha_tmp); +} + +static inline int ceildiv(int x, int y) +{ + return (x + y - 1) / y; +} + /** * Performs a batched matrix multiplication using the vbatched_gemm_impl * function. @@ -81,35 +401,198 @@ template void vbatched_gemm_impl(int max_m, - int max_n, - int* m, - int* n, - int* k, - T** global_A_array, - int* global_lda, - T** global_B_array, - int* global_ldb, - T** global_C_array, - int* global_ldc, - int batchCount, - cudaStream_t stream, - T* alpha = nullptr); - -typedef std::function - matrix_multiple_func_type; - -void gemm_algo_selector(int k, matrix_multiple_func_type& func,const UnitCell& ucell); -#endif // VBATCH_MATRIX_MUL_H \ No newline at end of file + int max_n, + int* m, + int* n, + int* k, + T** global_A_array, + int* global_lda, + T** global_B_array, + int* global_ldb, + T** global_C_array, + int* global_ldc, + int batchCount, + cudaStream_t stream, + T* alpha = nullptr) +{ + // The positions of A and B have been swapped here. + // This is because the original code is for column-major matrices. + // We use row-major matrices, so we need to swap A and B. + // The vbatched_gemm_impl is for C = trans(A) * B + C, but we need trans(C). 
+ // Which means: trans(C) = trans(trans(A)*B + C) = trans(B) * A + trans(C) + // Then, ldc should be N, lda and ldb should be K + + size_t shared_mem_size = 0; + shared_mem_size += (BLK_M + 1) * BLK_K * sizeof(T); + shared_mem_size += (BLK_K + 1) * BLK_N * sizeof(T); + dim3 dimBlock(DIM_X, DIM_Y); + const int max_batch_count = 32768; + const int loop_num = batchCount / max_batch_count; + const int remain_num = batchCount % max_batch_count; + + for (int i = 0; i < loop_num; ++i) + { + dim3 dimGrid(ceildiv(max_n, BLK_M), + ceildiv(max_m, BLK_N), + max_batch_count); + T* alpha_tmp = nullptr; + if (alpha != nullptr) + { + alpha_tmp = alpha + i * max_batch_count; + } + + vbatched_gemm_kernel + <<>>( + n + i * max_batch_count, + m + i * max_batch_count, + k + i * max_batch_count, + global_B_array + i * max_batch_count, + global_ldb + i * max_batch_count, + global_A_array + i * max_batch_count, + global_lda + i * max_batch_count, + global_C_array + i * max_batch_count, + global_ldc + i * max_batch_count, + alpha_tmp); + checkCudaLastError(); + } + if (remain_num > 0) + { + dim3 dimGrid(ceildiv(max_n, BLK_M), ceildiv(max_m, BLK_N), remain_num); + T* alpha_tmp = nullptr; + if (alpha != nullptr) + { + alpha_tmp = alpha + loop_num * max_batch_count; + } + vbatched_gemm_kernel + <<>>( + n + loop_num * max_batch_count, + m + loop_num * max_batch_count, + k + loop_num * max_batch_count, + global_B_array + loop_num * max_batch_count, + global_ldb + loop_num * max_batch_count, + global_A_array + loop_num * max_batch_count, + global_lda + loop_num * max_batch_count, + global_C_array + loop_num * max_batch_count, + global_ldc + loop_num * max_batch_count, + alpha_tmp); + checkCudaLastError(); + } +} + +template +void gemm_time_measure(int max_m, + int max_n, + int* m, + int* n, + int* k, + T** global_A_array, + int* global_lda, + T** global_B_array, + int* global_ldb, + T** global_C_array, + int* global_ldc, + int batchCount, + cudaStream_t stream, + float& fast_time, + 
matrix_multiple_func_type& fastest_algo, + double* cpu_result, + double* h_global_C, + double* d_global_C) +{ + cudaEvent_t start, stop; + checkCuda( + cudaMemset(d_global_C, 0, batchCount * max_m * max_n * sizeof(double))); + checkCuda(cudaEventCreate(&start)); + checkCuda(cudaEventCreate(&stop)); + checkCuda(cudaEventRecord(start, stream)); + vbatched_gemm_impl(max_m, + max_n, + m, + n, + k, + global_A_array, + global_lda, + global_B_array, + global_ldb, + global_C_array, + global_ldc, + batchCount, + stream); + checkCuda(cudaEventRecord(stop, stream)); + cudaError_t cuda_status = cudaGetLastError(); + checkCuda(cudaStreamSynchronize(stream)); + float milliseconds = 0; + checkCuda(cudaEventElapsedTime(&milliseconds, start, stop)); + + // WARNING !!!!! Here we assume that all m and n are the same + checkCuda(cudaMemcpy(h_global_C, + d_global_C, + batchCount * max_m * max_n * sizeof(double), + cudaMemcpyDeviceToHost)); + bool check_result = true; + for (int i = 0; i < batchCount * max_m * max_n; ++i) + { + if (abs(cpu_result[i] - h_global_C[i]) > 0.001) + { + check_result = false; + break; + } + } + if (milliseconds < fast_time && cuda_status == cudaSuccess && check_result) + { + fast_time = milliseconds; + fastest_algo = vbatched_gemm_impl; +#ifdef __DEBUG + std::cout << "found! 
fastest time: " << fast_time << std::endl; + std::cout << DIM_X << "," << DIM_Y << "," << BLK_M << "," << BLK_N + << "," << BLK_K << "," << DIM_XA << "," << DIM_YA << "," + << DIM_XB << "," << DIM_YB << std::endl; +#endif + } +} +#endif // VBATCH_MATRIX_MUL_CUH \ No newline at end of file diff --git a/source/module_hamilt_lcao/module_hcontainer/base_matrix.cpp b/source/module_hamilt_lcao/module_hcontainer/base_matrix.cpp index ff02d7744a..a7531a1245 100644 --- a/source/module_hamilt_lcao/module_hcontainer/base_matrix.cpp +++ b/source/module_hamilt_lcao/module_hcontainer/base_matrix.cpp @@ -109,8 +109,11 @@ assert(this->value_begin != nullptr); template void BaseMatrix::add_array(T* array) { - // if allocated, save data from array into matrix - // if whole matrix and 2d-block format, save data from array into matrix either +#ifdef __DEBUG +assert(this->value_begin != nullptr); +#endif + // if allocated, add data from array into matrix + // if whole matrix and 2d-block format, add data from array into matrix either for (int i = 0; i < nrow_local * ncol_local; ++i) { value_begin[i] += array[i]; @@ -120,6 +123,9 @@ void BaseMatrix::add_array(T* array) template void BaseMatrix::add_element(int mu, int nu, const T& value) { +#ifdef __DEBUG +assert(this->value_begin != nullptr); +#endif int index = mu * this->ncol_local + nu; value_begin[index] += value; } @@ -127,6 +133,9 @@ void BaseMatrix::add_element(int mu, int nu, const T& value) template T& BaseMatrix::get_value(const size_t& i_row, const size_t& j_col) const { +#ifdef __DEBUG +assert(this->value_begin != nullptr); +#endif int index = i_row * this->ncol_local + j_col; return value_begin[index]; } @@ -145,6 +154,12 @@ BaseMatrix& BaseMatrix::operator=(const BaseMatrix& other) { this->nrow_local = other.nrow_local; this->ncol_local = other.ncol_local; + + if (this->allocated) + { + delete[] this->value_begin; + } + if (other.allocated) { this->value_begin = new T[nrow_local * ncol_local]; @@ -172,6 +187,12 @@ 
BaseMatrix& BaseMatrix::operator=(BaseMatrix&& other) noexcept { this->nrow_local = other.nrow_local; this->ncol_local = other.ncol_local; + + if (this->allocated) + { + delete[] this->value_begin; + } + this->value_begin = other.value_begin; this->allocated = other.allocated; if (other.allocated) diff --git a/source/module_hamilt_lcao/module_hcontainer/base_matrix.h b/source/module_hamilt_lcao/module_hcontainer/base_matrix.h index d866c6f67f..ef591546a1 100644 --- a/source/module_hamilt_lcao/module_hcontainer/base_matrix.h +++ b/source/module_hamilt_lcao/module_hcontainer/base_matrix.h @@ -39,11 +39,12 @@ class BaseMatrix void set_zero(); /** - * @brief save an array to the matrix + * @brief add an array to the matrix * - * @param array array to be saved + * @param array array to be added */ void add_array(T* array); + /** * @brief add a single element to the matrix * @@ -52,6 +53,7 @@ class BaseMatrix * @param value value to be added */ void add_element(int mu, int nu, const T& value); + // for inside matrix /** * @brief get value from a whole matrix @@ -61,6 +63,7 @@ class BaseMatrix * @return T& */ T& get_value(const size_t& i_row, const size_t& j_col) const; + /** * @brief get pointer of value from a submatrix */ diff --git a/source/module_hamilt_lcao/module_tddft/td_velocity.cpp b/source/module_hamilt_lcao/module_tddft/td_velocity.cpp index 5de3d35209..3056c3d414 100644 --- a/source/module_hamilt_lcao/module_tddft/td_velocity.cpp +++ b/source/module_hamilt_lcao/module_tddft/td_velocity.cpp @@ -1,24 +1,145 @@ -#include "module_base/timer.h" #include "td_velocity.h" +#include "module_base/timer.h" +#include "module_elecstate/potentials/H_TDDFT_pw.h" bool TD_Velocity::tddft_velocity = false; bool TD_Velocity::out_mat_R = false; +bool TD_Velocity::out_vecpot = false; +bool TD_Velocity::init_vecpot_file = false; + TD_Velocity* TD_Velocity::td_vel_op = nullptr; +int TD_Velocity::istep = -1; +int TD_Velocity::max_istep = -1; +std::vector> TD_Velocity::At_from_file; 
+ TD_Velocity::TD_Velocity() { - return; + if (init_vecpot_file && istep == -1) + { + this->read_cart_At(); + } + return; } TD_Velocity::~TD_Velocity() { - this->destroy_HS_R_td_sparse(); + this->destroy_HS_R_td_sparse(); +} + +void TD_Velocity::output_cart_At(const std::string& out_dir) +{ + if (GlobalV::MY_RANK == 0) + { + std::string out_file; + // generate the output file name + out_file = out_dir + "At.dat"; + std::ofstream ofs; + // output title + if (istep == 0) + { + ofs.open(out_file.c_str(), std::ofstream::out); + ofs << std::left << std::setw(8) << "#istep" << std::setw(15) << "A_x" << std::setw(15) << "A_y" + << std::setw(15) << "A_z" << std::endl; + } + else + { + ofs.open(out_file.c_str(), std::ofstream::app); + } + // output the vector potential + ofs << std::left << std::setw(8) << istep; + // divide by 2.0 to get the atomic unit + for (int i = 0; i < 3; i++) + { + ofs << std::scientific << std::setprecision(4) << std::setw(15) << cart_At[i] / 2.0; + } + ofs << std::endl; + ofs.close(); + } + return; } +void TD_Velocity::cal_cart_At(const ModuleBase::Vector3& a0, + const ModuleBase::Vector3& a1, + const ModuleBase::Vector3& a2, + const ModuleBase::Vector3& At) +{ + istep++; + if (init_vecpot_file) + { + this->cart_At = At_from_file[istep > max_istep ? 
max_istep : istep]; + } + else + { + const double l_norm[3] = {a0.norm(), a1.norm(), a2.norm()}; + this->cart_At = a0 * At[0] / l_norm[0] + a1 * At[1] / l_norm[1] + a2 * At[2] / l_norm[2]; + } + // output the vector potential if needed + if (out_vecpot == true) + { + this->output_cart_At(GlobalV::global_out_dir); + } +} + +void TD_Velocity::read_cart_At(void) +{ + std::string in_file; + // generate the input file name + in_file = "At.dat"; + std::ifstream ifs(in_file.c_str()); + // check if the file is exist + if (!ifs) + { + ModuleBase::WARNING_QUIT("TD_Velocity::read_cart_At", "Cannot open Vector potential file!"); + } + std::string line; + std::vector str_vec; + // use tmp to skip the istep number + int tmp = 0; + while (std::getline(ifs, line)) + { + // A tmporary vector3 to store the data of this line + ModuleBase::Vector3 At; + if (line[0] == '#') + { + continue; + } + std::istringstream iss(line); + // skip the istep number + if (!(iss >> tmp)) + { + ModuleBase::WARNING_QUIT("TD_Velocity::read_cart_At", "Error reading istep!"); + } + // read the vector potential + double component = 0; + // Read three components + for (int i = 0; i < 3; i++) + { + if (!(iss >> component)) + { + ModuleBase::WARNING_QUIT("TD_Velocity::read_cart_At", + "Error reading component " + std::to_string(i + 1) + " for istep " + + std::to_string(tmp) + "!"); + } + At[i] = component; + } + // unit transform ,change unit into Ry/bohr/e*t_a.u. 
+ At *= 2.0; + // add the tmporary vector3 to the vector potential vector + At_from_file.push_back(At); + } + // set the max_istep + max_istep = At_from_file.size() - 1; + ifs.close(); + + return; +} void TD_Velocity::destroy_HS_R_td_sparse(void) { - std::map, std::map>>> empty_HR_sparse_td_vel_up; - std::map, std::map>>> empty_HR_sparse_td_vel_down; - HR_sparse_td_vel[0].swap(empty_HR_sparse_td_vel_up); - HR_sparse_td_vel[1].swap(empty_HR_sparse_td_vel_down); -} \ No newline at end of file + std::map, std::map>>> + empty_HR_sparse_td_vel_up; + std::map, std::map>>> + empty_HR_sparse_td_vel_down; + HR_sparse_td_vel[0].swap(empty_HR_sparse_td_vel_up); + HR_sparse_td_vel[1].swap(empty_HR_sparse_td_vel_down); +} diff --git a/source/module_hamilt_lcao/module_tddft/td_velocity.h b/source/module_hamilt_lcao/module_tddft/td_velocity.h index 0d274228dd..ddfaaf87cf 100644 --- a/source/module_hamilt_lcao/module_tddft/td_velocity.h +++ b/source/module_hamilt_lcao/module_tddft/td_velocity.h @@ -1,17 +1,19 @@ #ifndef TD_VELOCITY_H #define TD_VELOCITY_H -#include - -#include "module_base/timer.h" #include "module_base/abfs-vector3_order.h" -//Class to store TDDFT velocity gague infos. +#include "module_base/timer.h" + +#include +// Class to store TDDFT velocity gague infos. 
class TD_Velocity { public: TD_Velocity(); ~TD_Velocity(); - /// @brief Judge if in tddft calculation or not + void init(); + + /// @brief Judge if in tddft calculation or not static bool tddft_velocity; /// @brief switch to control the output of HR @@ -20,15 +22,42 @@ class TD_Velocity /// @brief pointer to the only TD_Velocity object itself static TD_Velocity* td_vel_op; - //For TDDFT velocity gague, to fix the output of HR - std::map, std::map>>> HR_sparse_td_vel[2]; + /// @brief switch to control the output of At + static bool out_vecpot; + + /// @brief switch to control the source of At + static bool init_vecpot_file; + /// @brief Store the vector potential for tddft calculation + ModuleBase::Vector3 cart_At; + + /// @brief calculate the At in cartesian coordinate + void cal_cart_At(const ModuleBase::Vector3& a0, + const ModuleBase::Vector3& a1, + const ModuleBase::Vector3& a2, + const ModuleBase::Vector3& At); + + // For TDDFT velocity gague, to fix the output of HR + std::map, std::map>>> HR_sparse_td_vel[2]; private: + /// @brief read At from output file + void read_cart_At(void); + + /// @brief output cart_At to output file + void output_cart_At(const std::string& out_dir); + + /// @brief store isteps now + static int istep; + + /// @brief total steps of read in At + static int max_istep; + + /// @brief store the read in At_data + static std::vector> At_from_file; /// @brief destory HSR data stored void destroy_HS_R_td_sparse(void); - }; -#endif \ No newline at end of file +#endif diff --git a/source/module_hamilt_pw/hamilt_pwdft/kernels/cuda/force_op.cu b/source/module_hamilt_pw/hamilt_pwdft/kernels/cuda/force_op.cu index 81a0e176c3..5a1cdeecac 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/kernels/cuda/force_op.cu +++ b/source/module_hamilt_pw/hamilt_pwdft/kernels/cuda/force_op.cu @@ -7,6 +7,7 @@ #include #include #include +#include #define THREADS_PER_BLOCK 256 diff --git a/source/module_hamilt_pw/hamilt_pwdft/kernels/cuda/stress_op.cu 
b/source/module_hamilt_pw/hamilt_pwdft/kernels/cuda/stress_op.cu index 01c9f18175..91ac04ebc0 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/kernels/cuda/stress_op.cu +++ b/source/module_hamilt_pw/hamilt_pwdft/kernels/cuda/stress_op.cu @@ -4,6 +4,7 @@ #include #include #include +#include #include diff --git a/source/module_hsolver/CMakeLists.txt b/source/module_hsolver/CMakeLists.txt index c4e7410da5..3c0b35383e 100644 --- a/source/module_hsolver/CMakeLists.txt +++ b/source/module_hsolver/CMakeLists.txt @@ -13,7 +13,7 @@ list(APPEND objects if(ENABLE_LCAO) list(APPEND objects hsolver_lcao.cpp - diago_blas.cpp + diago_scalapack.cpp ) if (USE_ELPA) list(APPEND objects diff --git a/source/module_hsolver/diago_blas.cpp b/source/module_hsolver/diago_scalapack.cpp similarity index 94% rename from source/module_hsolver/diago_blas.cpp rename to source/module_hsolver/diago_scalapack.cpp index 1b9183520f..3a1432dee7 100644 --- a/source/module_hsolver/diago_blas.cpp +++ b/source/module_hsolver/diago_scalapack.cpp @@ -5,7 +5,7 @@ // DATE : 2022-04-14 //===================== -#include "diago_blas.h" +#include "diago_scalapack.h" #include #include @@ -21,9 +21,9 @@ typedef hamilt::MatrixBlock> matcd; namespace hsolver { template<> - void DiagoBlas::diag(hamilt::Hamilt* phm_in, psi::Psi& psi, Real* eigenvalue_in) + void DiagoScalapack::diag(hamilt::Hamilt* phm_in, psi::Psi& psi, Real* eigenvalue_in) { - ModuleBase::TITLE("DiagoElpa", "diag"); + ModuleBase::TITLE("DiagoScalapack", "diag"); matd h_mat, s_mat; phm_in->matrix(h_mat, s_mat); assert(h_mat.col == s_mat.col && h_mat.row == s_mat.row && h_mat.desc == s_mat.desc); @@ -33,9 +33,9 @@ namespace hsolver BlasConnector::copy(GlobalV::NBANDS, eigen.data(), inc, eigenvalue_in, inc); } template<> - void DiagoBlas>::diag(hamilt::Hamilt>* phm_in, psi::Psi>& psi, Real* eigenvalue_in) + void DiagoScalapack>::diag(hamilt::Hamilt>* phm_in, psi::Psi>& psi, Real* eigenvalue_in) { - ModuleBase::TITLE("DiagoElpa", "diag"); + 
ModuleBase::TITLE("DiagoScalapack", "diag"); matcd h_mat, s_mat; phm_in->matrix(h_mat, s_mat); assert(h_mat.col == s_mat.col && h_mat.row == s_mat.row && h_mat.desc == s_mat.desc); @@ -46,7 +46,7 @@ namespace hsolver } template - std::pair> DiagoBlas::pdsygvx_once(const int* const desc, + std::pair> DiagoScalapack::pdsygvx_once(const int* const desc, const int ncol, const int nrow, const double *const h_mat, @@ -169,7 +169,7 @@ namespace hsolver + ModuleBase::GlobalFunc::TO_STRING(__LINE__)); } template - std::pair> DiagoBlas::pzhegvx_once(const int* const desc, + std::pair> DiagoScalapack::pzhegvx_once(const int* const desc, const int ncol, const int nrow, const std::complex *const h_mat, @@ -303,7 +303,7 @@ namespace hsolver + ModuleBase::GlobalFunc::TO_STRING(__LINE__)); } template - void DiagoBlas::pdsygvx_diag(const int* const desc, + void DiagoScalapack::pdsygvx_diag(const int* const desc, const int ncol, const int nrow, const double *const h_mat, @@ -321,7 +321,7 @@ namespace hsolver } template - void DiagoBlas ::pzhegvx_diag(const int* const desc, + void DiagoScalapack ::pzhegvx_diag(const int* const desc, const int ncol, const int nrow, const std::complex *const h_mat, @@ -339,7 +339,7 @@ namespace hsolver } template - void DiagoBlas::post_processing(const int info, const std::vector& vec) + void DiagoScalapack::post_processing(const int info, const std::vector& vec) { const std::string str_info = "info = " + ModuleBase::GlobalFunc::TO_STRING(info) + ".\n"; const std::string str_FILE diff --git a/source/module_hsolver/diago_blas.h b/source/module_hsolver/diago_scalapack.h similarity index 96% rename from source/module_hsolver/diago_blas.h rename to source/module_hsolver/diago_scalapack.h index b445f0cc12..92f557de75 100644 --- a/source/module_hsolver/diago_blas.h +++ b/source/module_hsolver/diago_scalapack.h @@ -5,8 +5,8 @@ // DATE : 2022-04-14 //===================== -#ifndef DIAGOBLAS_H -#define DIAGOBLAS_H +#ifndef DIAGO_SCALAPACK_H +#define 
DIAGO_SCALAPACK_H #include #include @@ -20,7 +20,7 @@ namespace hsolver { template - class DiagoBlas : public DiagH + class DiagoScalapack : public DiagH { private: using Real = typename GetTypeReal::type; diff --git a/source/module_hsolver/hsolver_lcao.cpp b/source/module_hsolver/hsolver_lcao.cpp index 0c3a787607..eb573c7cb4 100644 --- a/source/module_hsolver/hsolver_lcao.cpp +++ b/source/module_hsolver/hsolver_lcao.cpp @@ -1,6 +1,6 @@ #include "hsolver_lcao.h" -#include "diago_blas.h" +#include "diago_scalapack.h" #include "diago_cg.h" #include #include @@ -49,7 +49,7 @@ void HSolverLCAO::solveTemplate(hamilt::Hamilt* pHamilt, } if (this->pdiagh == nullptr) { - this->pdiagh = new DiagoBlas(); + this->pdiagh = new DiagoScalapack(); this->pdiagh->method = this->method; } } diff --git a/source/module_hsolver/test/CMakeLists.txt b/source/module_hsolver/test/CMakeLists.txt index cb785e83bb..9889d33de4 100644 --- a/source/module_hsolver/test/CMakeLists.txt +++ b/source/module_hsolver/test/CMakeLists.txt @@ -84,13 +84,13 @@ if(ENABLE_LCAO) AddTest( TARGET HSolver_LCAO LIBS ${math_libs} ELPA::ELPA base genelpa psi device - SOURCES diago_lcao_test.cpp ../diago_elpa.cpp ../diago_blas.cpp + SOURCES diago_lcao_test.cpp ../diago_elpa.cpp ../diago_scalapack.cpp ) else() AddTest( TARGET HSolver_LCAO LIBS ${math_libs} base psi device - SOURCES diago_lcao_test.cpp ../diago_blas.cpp + SOURCES diago_lcao_test.cpp ../diago_scalapack.cpp ) endif() @@ -106,7 +106,7 @@ if (USE_CUDA) AddTest( TARGET HSolver_LCAO_cusolver LIBS ${math_libs} base psi device - SOURCES diago_lcao_cusolver_test.cpp ../diago_cusolver.cpp ../diago_blas.cpp + SOURCES diago_lcao_cusolver_test.cpp ../diago_cusolver.cpp ../diago_scalapack.cpp ../kernels/math_kernel_op.cpp ../kernels/dngvd_op.cpp ../kernels/cuda/diag_cusolver.cu diff --git a/source/module_hsolver/test/diago_lcao_cusolver_test.cpp b/source/module_hsolver/test/diago_lcao_cusolver_test.cpp index 3f29d823d4..8139c8e07e 100644 --- 
a/source/module_hsolver/test/diago_lcao_cusolver_test.cpp +++ b/source/module_hsolver/test/diago_lcao_cusolver_test.cpp @@ -1,7 +1,7 @@ #include #include "gtest/gtest.h" -#include "module_hsolver/diago_blas.h" +#include "module_hsolver/diago_scalapack.h" #include "module_hsolver/test/diago_elpa_utils.h" #include "mpi.h" #include "string.h" @@ -24,7 +24,7 @@ /** * Tested function: * - hsolver::DiagoElpa::diag (for ELPA) - * - hsolver::DiagoBlas::diag (for Scalapack) + * - hsolver::DiagoScalapack::diag (for Scalapack) * * The 2d block cyclic distribution of H/S matrix is done by * self-realized functions in module_hsolver/test/diago_elpa_utils.h @@ -76,7 +76,7 @@ class DiagoPrepare MPI_Comm_rank(MPI_COMM_WORLD, &myrank); if (ks_solver == "scalapack_gvx") - dh = new hsolver::DiagoBlas; + dh = new hsolver::DiagoScalapack; #ifdef __CUDA else if (ks_solver == "cusolver") dh = new hsolver::DiagoCusolver; diff --git a/source/module_hsolver/test/diago_lcao_test.cpp b/source/module_hsolver/test/diago_lcao_test.cpp index 487820a9d5..d9ac6429a4 100644 --- a/source/module_hsolver/test/diago_lcao_test.cpp +++ b/source/module_hsolver/test/diago_lcao_test.cpp @@ -1,5 +1,5 @@ #include "module_hsolver/test/diago_elpa_utils.h" -#include "module_hsolver/diago_blas.h" +#include "module_hsolver/diago_scalapack.h" #include "mpi.h" #include "string.h" #include "gtest/gtest.h" @@ -20,7 +20,7 @@ /** * Tested function: * - hsolver::DiagoElpa::diag (for ELPA) - * - hsolver::DiagoBlas::diag (for Scalapack) + * - hsolver::DiagoScalapack::diag (for Scalapack) * * The 2d block cyclic distribution of H/S matrix is done by * self-realized functions in module_hsolver/test/diago_elpa_utils.h @@ -60,7 +60,7 @@ template class DiagoPrepare MPI_Comm_rank(MPI_COMM_WORLD, &myrank); if (ks_solver == "scalapack_gvx") - dh = new hsolver::DiagoBlas; + dh = new hsolver::DiagoScalapack; #ifdef __ELPA else if(ks_solver == "genelpa") dh = new hsolver::DiagoElpa; diff --git a/source/module_io/input.cpp 
b/source/module_io/input.cpp index 19bbe5deb8..64fae37c81 100644 --- a/source/module_io/input.cpp +++ b/source/module_io/input.cpp @@ -338,6 +338,7 @@ void Input::Default(void) deepks_scf = 0; deepks_bandgap = 0; deepks_out_unittest = 0; + deepks_equiv = 0; out_pot = 0; out_wfc_pw = 0; @@ -469,6 +470,8 @@ void Input::Default(void) out_dipole = false; out_efield = false; out_current = false; + out_vecpot = false; + init_vecpot_file = false; td_print_eij = -1.0; td_edm = 0; @@ -1412,6 +1415,10 @@ bool Input::Read(const std::string& fn) { read_bool(ifs, deepks_scf); } + else if (strcmp("deepks_equiv", word) == 0) + { + read_bool(ifs, deepks_equiv); + } else if (strcmp("deepks_bandgap", word) == 0) // caoyu added 2020-11-24, mohan modified 2021-01-03 { read_bool(ifs, deepks_bandgap); @@ -1808,6 +1815,14 @@ bool Input::Read(const std::string& fn) { read_value(ifs, out_efield); } + else if (strcmp("out_vecpot", word) == 0) + { + read_value(ifs, out_vecpot); + } + else if (strcmp("init_vecpot_file", word) == 0) + { + read_value(ifs, init_vecpot_file); + } else if (strcmp("td_print_eij", word) == 0) { read_value(ifs, td_print_eij); @@ -3578,6 +3593,7 @@ void Input::Bcast() Parallel_Common::bcast_bool(deepks_bandgap); Parallel_Common::bcast_bool(deepks_out_unittest); Parallel_Common::bcast_string(deepks_model); + Parallel_Common::bcast_bool(deepks_equiv); Parallel_Common::bcast_int(out_pot); Parallel_Common::bcast_int(out_wfc_pw); @@ -3746,6 +3762,8 @@ void Input::Bcast() Parallel_Common::bcast_bool(out_dipole); Parallel_Common::bcast_bool(out_efield); Parallel_Common::bcast_bool(out_current); + Parallel_Common::bcast_bool(out_vecpot); + Parallel_Common::bcast_bool(init_vecpot_file); Parallel_Common::bcast_double(td_print_eij); Parallel_Common::bcast_int(td_edm); Parallel_Common::bcast_bool(test_skip_ewald); diff --git a/source/module_io/input.h b/source/module_io/input.h index fb1b79f3ef..3aad48b7ce 100644 --- a/source/module_io/input.h +++ b/source/module_io/input.h @@ 
-416,6 +416,8 @@ class Input bool out_dipole; // output the dipole or not bool out_efield; // output the efield or not bool out_current; //output the current or not + bool out_vecpot; // output the vector potential or not + bool init_vecpot_file; // initialize the vector potential, though file or integral double td_print_eij; // threshold to output Eij elements int td_edm; //0: new edm method 1: old edm method @@ -517,7 +519,7 @@ class Input // 2022-1-12 bool deepks_scf; //(need libnpy and libtorch) if set 1, a trained model would be needed to cal V_delta and F_delta bool deepks_bandgap; // for bandgap label. QO added 2021-12-15 - + bool deepks_equiv; bool deepks_out_unittest; // if set 1, prints intermediate quantities that shall be used for making unit test std::string deepks_model; // needed when deepks_scf=1 diff --git a/source/module_io/input_conv.cpp b/source/module_io/input_conv.cpp index 8f6aaf256a..b2685ed702 100644 --- a/source/module_io/input_conv.cpp +++ b/source/module_io/input_conv.cpp @@ -550,6 +550,8 @@ void Input_Conv::Convert(void) module_tddft::Evolve_elec::out_current = INPUT.out_current; module_tddft::Evolve_elec::td_print_eij = INPUT.td_print_eij; module_tddft::Evolve_elec::td_edm = INPUT.td_edm; + TD_Velocity::out_vecpot = INPUT.out_vecpot; + TD_Velocity::init_vecpot_file = INPUT.init_vecpot_file; read_td_efield(); #endif @@ -769,6 +771,12 @@ void Input_Conv::Convert(void) GlobalV::deepks_bandgap = INPUT.deepks_bandgap; // QO added for bandgap label 2021-12-15 GlobalV::deepks_out_unittest = INPUT.deepks_out_unittest; GlobalV::deepks_out_labels = INPUT.deepks_out_labels; + GlobalV::deepks_equiv = INPUT.deepks_equiv; + + if(GlobalV::deepks_equiv && GlobalV::deepks_bandgap) + { + ModuleBase::WARNING_QUIT("Input_conv", "deepks_equiv and deepks_bandgap cannot be used together"); + } if (GlobalV::deepks_out_unittest) { GlobalV::deepks_out_labels = 1; diff --git a/source/module_io/parameter_pool.cpp b/source/module_io/parameter_pool.cpp index 
c51eced6d7..f1f2a21ef3 100644 --- a/source/module_io/parameter_pool.cpp +++ b/source/module_io/parameter_pool.cpp @@ -1529,6 +1529,10 @@ void input_parameters_set(std::map input_parameters { INPUT.deepks_scf = *static_cast(input_parameters["deepks_scf"].get()); } + else if (input_parameters.count("deepks_equiv") != 0) + { + INPUT.deepks_equiv = *static_cast(input_parameters["deepks_equiv"].get()); + } else if (input_parameters.count("deepks_bandgap") != 0) { INPUT.deepks_bandgap = *static_cast(input_parameters["deepks_bandgap"].get()); diff --git a/source/module_io/test/for_testing_input_conv.h b/source/module_io/test/for_testing_input_conv.h index d4449bdbf1..02d8c59484 100644 --- a/source/module_io/test/for_testing_input_conv.h +++ b/source/module_io/test/for_testing_input_conv.h @@ -16,6 +16,7 @@ #include "module_hamilt_lcao/hamilt_lcaodft/local_orbital_charge.h" #include "module_hamilt_lcao/module_dftu/dftu.h" #include "module_hamilt_lcao/module_tddft/evolve_elec.h" +#include "module_hamilt_lcao/module_tddft/td_velocity.h" #include "module_hamilt_pw/hamilt_pwdft/VNL_in_pw.h" #include "module_hamilt_pw/hamilt_pwdft/structure_factor.h" #include "module_hamilt_pw/hamilt_pwdft/wavefunc.h" @@ -41,6 +42,8 @@ std::vector module_tddft::Evolve_elec::td_vext_dire_case; bool module_tddft::Evolve_elec::out_dipole; bool module_tddft::Evolve_elec::out_efield; bool module_tddft::Evolve_elec::out_current; +bool TD_Velocity::out_vecpot; +bool TD_Velocity::init_vecpot_file; double module_tddft::Evolve_elec::td_print_eij; int module_tddft::Evolve_elec::td_edm; double elecstate::Gatefield::zgate = 0.5; diff --git a/source/module_io/test/input_test.cpp b/source/module_io/test/input_test.cpp index 547178db88..e92bc2e9ab 100644 --- a/source/module_io/test/input_test.cpp +++ b/source/module_io/test/input_test.cpp @@ -170,6 +170,7 @@ TEST_F(InputTest, Default) EXPECT_EQ(INPUT.out_dm1,0); EXPECT_EQ(INPUT.deepks_out_labels,0); EXPECT_EQ(INPUT.deepks_scf,0); + 
EXPECT_EQ(INPUT.deepks_equiv,0); EXPECT_EQ(INPUT.deepks_bandgap,0); EXPECT_EQ(INPUT.deepks_out_unittest,0); EXPECT_EQ(INPUT.out_pot,0); @@ -537,6 +538,7 @@ TEST_F(InputTest, Read) EXPECT_EQ(INPUT.out_dm1,0); EXPECT_EQ(INPUT.deepks_out_labels,0); EXPECT_EQ(INPUT.deepks_scf,0); + EXPECT_EQ(INPUT.deepks_equiv,0); EXPECT_EQ(INPUT.deepks_bandgap,0); EXPECT_EQ(INPUT.deepks_out_unittest,0); EXPECT_EQ(INPUT.out_pot,2); diff --git a/source/module_io/test/input_test_para.cpp b/source/module_io/test/input_test_para.cpp index 5a6490551a..8246e057bd 100644 --- a/source/module_io/test/input_test_para.cpp +++ b/source/module_io/test/input_test_para.cpp @@ -177,6 +177,7 @@ TEST_F(InputParaTest, Bcast) EXPECT_EQ(INPUT.out_dm1, 0); EXPECT_EQ(INPUT.deepks_out_labels, 0); EXPECT_EQ(INPUT.deepks_scf, 0); + EXPECT_EQ(INPUT.deepks_equiv, 0); EXPECT_EQ(INPUT.deepks_bandgap, 0); EXPECT_EQ(INPUT.deepks_out_unittest, 0); EXPECT_EQ(INPUT.out_pot, 0); diff --git a/source/module_io/write_HS_sparse.cpp b/source/module_io/write_HS_sparse.cpp index c7e7d5475a..2ca1d58049 100644 --- a/source/module_io/write_HS_sparse.cpp +++ b/source/module_io/write_HS_sparse.cpp @@ -849,4 +849,4 @@ template void ModuleIO::save_sparse>( const Parallel_Orbitals&, const std::string&, const int&, - const bool&); + const bool&); \ No newline at end of file diff --git a/source/module_io/write_input.cpp b/source/module_io/write_input.cpp index 91a04201ab..73f3ca345c 100644 --- a/source/module_io/write_input.cpp +++ b/source/module_io/write_input.cpp @@ -210,6 +210,7 @@ ModuleBase::GlobalFunc::OUTP(ofs, "out_bandgap", out_bandgap, "if true, print ou // for deepks ModuleBase::GlobalFunc::OUTP(ofs, "deepks_out_labels", deepks_out_labels, ">0 compute descriptor for deepks"); ModuleBase::GlobalFunc::OUTP(ofs, "deepks_scf", deepks_scf, ">0 add V_delta to Hamiltonian"); + ModuleBase::GlobalFunc::OUTP(ofs, "deepks_equiv", deepks_equiv, "whether to use equivariant version of DeePKS"); ModuleBase::GlobalFunc::OUTP(ofs, 
"deepks_bandgap", deepks_bandgap, ">0 for bandgap label"); ModuleBase::GlobalFunc::OUTP(ofs, "deepks_out_unittest", diff --git a/source/version.h b/source/version.h index 087959a258..3ada3d837e 100644 --- a/source/version.h +++ b/source/version.h @@ -1,3 +1,3 @@ #ifndef VERSION -#define VERSION "v3.6.3" +#define VERSION "v3.6.4" #endif