Skip to content

Commit

Permalink
Refactor ModuleIO::read_cube_core()
Browse files Browse the repository at this point in the history
  • Loading branch information
PeizeLin committed Sep 15, 2024
1 parent 8dabb3f commit b8974ec
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 100 deletions.
48 changes: 26 additions & 22 deletions source/module_io/cube_io.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,27 +28,6 @@ extern bool read_cube(
int& prenspin,
const bool warning_flag = true);

extern void read_cube_core(
std::ifstream &ifs,
#ifdef __MPI
const Parallel_Grid*const Pgrid,
#endif
const int my_rank,
const std::string esolver_type,
const int rank_in_stogroup,
#ifdef __MPI
#else
const int is,
std::ofstream& ofs_running,
#endif
double*const data,
const int nx,
const int ny,
const int nz,
const int nx_read,
const int ny_read,
const int nz_read);

extern void write_cube(
#ifdef __MPI
const int bz,
Expand All @@ -69,6 +48,31 @@ extern void write_cube(
const int precision = 11,
const int out_fermi = 1); // mohan add 2007-10-17


extern void read_cube_core_match(
std::ifstream &ifs,
#ifdef __MPI
const Parallel_Grid*const Pgrid,
const bool flag_read_rank,
#endif
double*const data,
const int nxy,
const int nz);

extern void read_cube_core_mismatch(
std::ifstream &ifs,
#ifdef __MPI
const Parallel_Grid*const Pgrid,
const bool flag_read_rank,
#endif
double*const data,
const int nx,
const int ny,
const int nz,
const int nx_read,
const int ny_read,
const int nz_read);

extern void write_cube_core(
std::ofstream &ofs_cube,
#ifdef __MPI
Expand Down Expand Up @@ -120,7 +124,7 @@ extern void write_cube_core(
const int& ny,
const int& nz,
#ifdef __MPI
double** data
std::vector<std::vector<double>> &data
#else
double* data
#endif
Expand Down
128 changes: 50 additions & 78 deletions source/module_io/read_cube.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,29 +112,61 @@ bool ModuleIO::read_cube(
}
}

const bool flag_read_rank = (my_rank == 0 || (esolver_type == "sdft" && rank_in_stogroup == 0));
#ifdef __MPI
ModuleIO::read_cube_core(ifs, Pgrid, my_rank, esolver_type, rank_in_stogroup, data, nx, ny, nz, nx_read, ny_read, nz_read);
if(nx == nx_read && ny == ny_read && nz == nz_read)
ModuleIO::read_cube_core_match(ifs, Pgrid, flag_read_rank, data, nx*ny, nz);
else
ModuleIO::read_cube_core_mismatch(ifs, Pgrid, flag_read_rank, data, nx, ny, nz, nx_read, ny_read, nz_read);
#else
ModuleIO::read_cube_core(ifs, my_rank, esolver_type, rank_in_stogroup, is, ofs_running, data, nx, ny, nz, nx_read, ny_read, nz_read);
ofs_running << " Read SPIN = " << is + 1 << " charge now." << std::endl;
if(nx == nx_read && ny == ny_read && nz == nz_read)
ModuleIO::read_cube_core_match(ifs, data, nx*ny, nz);
else
ModuleIO::read_cube_core_mismatch(ifs, data, nx, ny, nz, nx_read, ny_read, nz_read);
#endif

if (my_rank == 0 || (esolver_type == "sdft" && rank_in_stogroup == 0))
ifs.close();
return true;
}

void ModuleIO::read_cube_core(
void ModuleIO::read_cube_core_match(
std::ifstream &ifs,
#ifdef __MPI
const Parallel_Grid*const Pgrid,
const bool flag_read_rank,
#endif
const int my_rank,
const std::string esolver_type,
const int rank_in_stogroup,
double*const data,
const int nxy,
const int nz)
{
#ifdef __MPI
if (flag_read_rank)
{
std::vector<std::vector<double>> read_rho(nz, std::vector<double>(nxy));
for (int ixy = 0; ixy < nxy; ixy++)
for (int iz = 0; iz < nz; iz++)
ifs >> read_rho[iz][ixy];
for (int iz = 0; iz < nz; iz++)
Pgrid->zpiece_to_all(read_rho[iz].data(), iz, data);
}
else
{
std::vector<double> zpiece(nxy);
for (int iz = 0; iz < nz; iz++)
Pgrid->zpiece_to_all(zpiece.data(), iz, data);
}
#else
const int is,
std::ofstream& ofs_running,
for (int ixy = 0; ixy < nxy; ixy++)
for (int iz = 0; iz < nz; iz++)
ifs >> data[iz * nxy + ixy];
#endif
}

void ModuleIO::read_cube_core_mismatch(
std::ifstream &ifs,
#ifdef __MPI
const Parallel_Grid*const Pgrid,
const bool flag_read_rank,
#endif
double*const data,
const int nx,
Expand All @@ -144,83 +176,23 @@ void ModuleIO::read_cube_core(
const int ny_read,
const int nz_read)
{
const bool same = (nx == nx_read && ny == ny_read && nz == nz_read) ? true : false;

#ifdef __MPI
const int nxy = nx * ny;
double* zpiece = nullptr;
double** read_rho = nullptr;
if (my_rank == 0 || (esolver_type == "sdft" && rank_in_stogroup == 0))
if (flag_read_rank)
{
read_rho = new double*[nz];
std::vector<std::vector<double>> read_rho(nz, std::vector<double>(nxy));
ModuleIO::trilinear_interpolate(ifs, nx_read, ny_read, nz_read, nx, ny, nz, read_rho);
for (int iz = 0; iz < nz; iz++)
{
read_rho[iz] = new double[nxy];
}
if (same)
{
for (int ix = 0; ix < nx; ix++)
{
for (int iy = 0; iy < ny; iy++)
{
for (int iz = 0; iz < nz; iz++)
{
ifs >> read_rho[iz][ix * ny + iy];
}
}
}
}
else
{
ModuleIO::trilinear_interpolate(ifs, nx_read, ny_read, nz_read, nx, ny, nz, read_rho);
}
Pgrid->zpiece_to_all(read_rho[iz].data(), iz, data);
}
else
{
zpiece = new double[nxy];
ModuleBase::GlobalFunc::ZEROS(zpiece, nxy);
}

for (int iz = 0; iz < nz; iz++)
{
if (my_rank == 0 || (esolver_type == "sdft" && rank_in_stogroup == 0))
{
zpiece = read_rho[iz];
}
Pgrid->zpiece_to_all(zpiece, iz, data);
} // iz

if (my_rank == 0 || (esolver_type == "sdft" && rank_in_stogroup == 0))
{
std::vector<double> zpiece(nxy);
for (int iz = 0; iz < nz; iz++)
{
delete[] read_rho[iz];
}
delete[] read_rho;
}
else
{
delete[] zpiece;
Pgrid->zpiece_to_all(zpiece.data(), iz, data);
}
#else
ofs_running << " Read SPIN = " << is + 1 << " charge now." << std::endl;
if (same)
{
for (int i = 0; i < nx; i++)
{
for (int j = 0; j < ny; j++)
{
for (int k = 0; k < nz; k++)
{
ifs >> data[k * nx * ny + i * ny + j];
}
}
}
}
else
{
ModuleIO::trilinear_interpolate(ifs, nx_read, ny_read, nz_read, nx, ny, nz, data);
}
ModuleIO::trilinear_interpolate(ifs, nx_read, ny_read, nz_read, nx, ny, nz, data);
#endif
}

Expand All @@ -232,7 +204,7 @@ void ModuleIO::trilinear_interpolate(std::ifstream& ifs,
const int& ny,
const int& nz,
#ifdef __MPI
double** data
std::vector<std::vector<double>> &data
#else
double* data
#endif
Expand Down

0 comments on commit b8974ec

Please sign in to comment.