diff --git a/source/module_io/cube_io.h b/source/module_io/cube_io.h index 9034f1cb6e..b9c33dfcf4 100644 --- a/source/module_io/cube_io.h +++ b/source/module_io/cube_io.h @@ -28,27 +28,6 @@ extern bool read_cube( int& prenspin, const bool warning_flag = true); -extern void read_cube_core( - std::ifstream &ifs, -#ifdef __MPI - const Parallel_Grid*const Pgrid, -#endif - const int my_rank, - const std::string esolver_type, - const int rank_in_stogroup, -#ifdef __MPI -#else - const int is, - std::ofstream& ofs_running, -#endif - double*const data, - const int nx, - const int ny, - const int nz, - const int nx_read, - const int ny_read, - const int nz_read); - extern void write_cube( #ifdef __MPI const int bz, @@ -69,6 +48,31 @@ extern void write_cube( const int precision = 11, const int out_fermi = 1); // mohan add 2007-10-17 + +extern void read_cube_core_match( + std::ifstream &ifs, +#ifdef __MPI + const Parallel_Grid*const Pgrid, + const bool flag_read_rank, +#endif + double*const data, + const int nxy, + const int nz); + +extern void read_cube_core_mismatch( + std::ifstream &ifs, +#ifdef __MPI + const Parallel_Grid*const Pgrid, + const bool flag_read_rank, +#endif + double*const data, + const int nx, + const int ny, + const int nz, + const int nx_read, + const int ny_read, + const int nz_read); + extern void write_cube_core( std::ofstream &ofs_cube, #ifdef __MPI @@ -120,7 +124,7 @@ extern void write_cube_core( const int& ny, const int& nz, #ifdef __MPI - double** data + std::vector> &data #else double* data #endif diff --git a/source/module_io/read_cube.cpp b/source/module_io/read_cube.cpp index 33a9bd04d0..971d97e474 100644 --- a/source/module_io/read_cube.cpp +++ b/source/module_io/read_cube.cpp @@ -112,29 +112,61 @@ bool ModuleIO::read_cube( } } + const bool flag_read_rank = (my_rank == 0 || (esolver_type == "sdft" && rank_in_stogroup == 0)); #ifdef __MPI - ModuleIO::read_cube_core(ifs, Pgrid, my_rank, esolver_type, rank_in_stogroup, data, nx, ny, nz, nx_read, ny_read, nz_read); + if(nx == nx_read && ny == ny_read && nz == nz_read) + ModuleIO::read_cube_core_match(ifs, Pgrid, flag_read_rank, data, nx*ny, nz); + else + ModuleIO::read_cube_core_mismatch(ifs, Pgrid, flag_read_rank, data, nx, ny, nz, nx_read, ny_read, nz_read); #else - ModuleIO::read_cube_core(ifs, my_rank, esolver_type, rank_in_stogroup, is, ofs_running, data, nx, ny, nz, nx_read, ny_read, nz_read); + ofs_running << " Read SPIN = " << is + 1 << " charge now." << std::endl; + if(nx == nx_read && ny == ny_read && nz == nz_read) + ModuleIO::read_cube_core_match(ifs, data, nx*ny, nz); + else + ModuleIO::read_cube_core_mismatch(ifs, data, nx, ny, nz, nx_read, ny_read, nz_read); #endif - if (my_rank == 0 || (esolver_type == "sdft" && rank_in_stogroup == 0)) - ifs.close(); return true; } -void ModuleIO::read_cube_core( +void ModuleIO::read_cube_core_match( std::ifstream &ifs, #ifdef __MPI const Parallel_Grid*const Pgrid, + const bool flag_read_rank, #endif - const int my_rank, - const std::string esolver_type, - const int rank_in_stogroup, + double*const data, + const int nxy, + const int nz) +{ #ifdef __MPI + if (flag_read_rank) + { + std::vector> read_rho(nz, std::vector(nxy)); + for (int ixy = 0; ixy < nxy; ixy++) + for (int iz = 0; iz < nz; iz++) + ifs >> read_rho[iz][ixy]; + for (int iz = 0; iz < nz; iz++) + Pgrid->zpiece_to_all(read_rho[iz].data(), iz, data); + } + else + { + std::vector zpiece(nxy); + for (int iz = 0; iz < nz; iz++) + Pgrid->zpiece_to_all(zpiece.data(), iz, data); + } #else - const int is, - std::ofstream& ofs_running, + for (int ixy = 0; ixy < nxy; ixy++) + for (int iz = 0; iz < nz; iz++) + ifs >> data[iz * nxy + ixy]; +#endif +} + +void ModuleIO::read_cube_core_mismatch( + std::ifstream &ifs, +#ifdef __MPI + const Parallel_Grid*const Pgrid, + const bool flag_read_rank, #endif double*const data, const int nx, @@ -144,83 +176,23 @@ void ModuleIO::read_cube_core( const int ny_read, const int nz_read) { - const bool same = (nx == nx_read && ny == ny_read && nz == nz_read) ? true : false; - #ifdef __MPI const int nxy = nx * ny; - double* zpiece = nullptr; - double** read_rho = nullptr; - if (my_rank == 0 || (esolver_type == "sdft" && rank_in_stogroup == 0)) + if (flag_read_rank) { - read_rho = new double*[nz]; + std::vector> read_rho(nz, std::vector(nxy)); + ModuleIO::trilinear_interpolate(ifs, nx_read, ny_read, nz_read, nx, ny, nz, read_rho); for (int iz = 0; iz < nz; iz++) - { - read_rho[iz] = new double[nxy]; - } - if (same) - { - for (int ix = 0; ix < nx; ix++) - { - for (int iy = 0; iy < ny; iy++) - { - for (int iz = 0; iz < nz; iz++) - { - ifs >> read_rho[iz][ix * ny + iy]; - } - } - } - } - else - { - ModuleIO::trilinear_interpolate(ifs, nx_read, ny_read, nz_read, nx, ny, nz, read_rho); - } + Pgrid->zpiece_to_all(read_rho[iz].data(), iz, data); } else { - zpiece = new double[nxy]; - ModuleBase::GlobalFunc::ZEROS(zpiece, nxy); - } - - for (int iz = 0; iz < nz; iz++) - { - if (my_rank == 0 || (esolver_type == "sdft" && rank_in_stogroup == 0)) - { - zpiece = read_rho[iz]; - } - Pgrid->zpiece_to_all(zpiece, iz, data); - } // iz - - if (my_rank == 0 || (esolver_type == "sdft" && rank_in_stogroup == 0)) - { + std::vector zpiece(nxy); for (int iz = 0; iz < nz; iz++) - { - delete[] read_rho[iz]; - } - delete[] read_rho; - } - else - { - delete[] zpiece; + Pgrid->zpiece_to_all(zpiece.data(), iz, data); } #else - ofs_running << " Read SPIN = " << is + 1 << " charge now." << std::endl; - if (same) - { - for (int i = 0; i < nx; i++) - { - for (int j = 0; j < ny; j++) - { - for (int k = 0; k < nz; k++) - { - ifs >> data[k * nx * ny + i * ny + j]; - } - } - } - } - else - { - ModuleIO::trilinear_interpolate(ifs, nx_read, ny_read, nz_read, nx, ny, nz, data); - } + ModuleIO::trilinear_interpolate(ifs, nx_read, ny_read, nz_read, nx, ny, nz, data.data()); #endif } @@ -232,7 +204,7 @@ void ModuleIO::trilinear_interpolate(std::ifstream& ifs, const int& ny, const int& nz, #ifdef __MPI - double** data + std::vector> &data #else double* data #endif