Skip to content

Commit

Permalink
Feature: a better looking of SCF iteration information on screen (#4185)
Browse files Browse the repository at this point in the history
* unittests did not modified yet

* disable the unittest PrintEtotWarning

* unify the notation of unit in timer and scf stdout

* correct unittest according to change on codes

---------

Co-authored-by: Wenfei Li <[email protected]>
Co-authored-by: Mohan Chen <[email protected]>
  • Loading branch information
3 people authored May 25, 2024
1 parent 7843f32 commit f73a83d
Show file tree
Hide file tree
Showing 6 changed files with 135 additions and 189 deletions.
24 changes: 12 additions & 12 deletions source/module_base/test/timer_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,10 +118,10 @@ TEST_F(TimerTest, PrintAll)
EXPECT_THAT(output,testing::HasSubstr("TIME STATISTICS"));
EXPECT_THAT(output,testing::HasSubstr("CLASS_NAME"));
EXPECT_THAT(output,testing::HasSubstr("NAME"));
EXPECT_THAT(output,testing::HasSubstr("TIME(Sec)"));
EXPECT_THAT(output,testing::HasSubstr("TIME/s"));
EXPECT_THAT(output,testing::HasSubstr("CALLS"));
EXPECT_THAT(output,testing::HasSubstr("AVG(Sec)"));
EXPECT_THAT(output,testing::HasSubstr("PER(%)"));
EXPECT_THAT(output,testing::HasSubstr("AVG/s"));
EXPECT_THAT(output,testing::HasSubstr("PER/%"));
// check output in file
ifs.open("tmp");
std::cout << "Capture contents line by line from output file: \n" << std::endl;
Expand All @@ -131,10 +131,10 @@ TEST_F(TimerTest, PrintAll)
getline(ifs,output);
EXPECT_THAT(output,testing::HasSubstr("CLASS_NAME"));
EXPECT_THAT(output,testing::HasSubstr("NAME"));
EXPECT_THAT(output,testing::HasSubstr("TIME(Sec)"));
EXPECT_THAT(output,testing::HasSubstr("TIME/s"));
EXPECT_THAT(output,testing::HasSubstr("CALLS"));
EXPECT_THAT(output,testing::HasSubstr("AVG(Sec)"));
EXPECT_THAT(output,testing::HasSubstr("PER(%)"));
EXPECT_THAT(output,testing::HasSubstr("AVG/s"));
EXPECT_THAT(output,testing::HasSubstr("PER/%"));
ifs.close();
remove("time.json");
}
Expand Down Expand Up @@ -162,10 +162,10 @@ TEST_F(TimerTest, Finish)
EXPECT_THAT(output,testing::HasSubstr("TIME STATISTICS"));
EXPECT_THAT(output,testing::HasSubstr("CLASS_NAME"));
EXPECT_THAT(output,testing::HasSubstr("NAME"));
EXPECT_THAT(output,testing::HasSubstr("TIME(Sec)"));
EXPECT_THAT(output,testing::HasSubstr("TIME/s"));
EXPECT_THAT(output,testing::HasSubstr("CALLS"));
EXPECT_THAT(output,testing::HasSubstr("AVG(Sec)"));
EXPECT_THAT(output,testing::HasSubstr("PER(%)"));
EXPECT_THAT(output,testing::HasSubstr("AVG/s"));
EXPECT_THAT(output,testing::HasSubstr("PER/%"));
// check output in file
ifs.open("tmp");
std::cout << "Capture contents line by line from output file: \n" << std::endl;
Expand All @@ -175,10 +175,10 @@ TEST_F(TimerTest, Finish)
getline(ifs,output);
EXPECT_THAT(output,testing::HasSubstr("CLASS_NAME"));
EXPECT_THAT(output,testing::HasSubstr("NAME"));
EXPECT_THAT(output,testing::HasSubstr("TIME(Sec)"));
EXPECT_THAT(output,testing::HasSubstr("TIME/s"));
EXPECT_THAT(output,testing::HasSubstr("CALLS"));
EXPECT_THAT(output,testing::HasSubstr("AVG(Sec)"));
EXPECT_THAT(output,testing::HasSubstr("PER(%)"));
EXPECT_THAT(output,testing::HasSubstr("AVG/s"));
EXPECT_THAT(output,testing::HasSubstr("PER/%"));
ifs.close();
}

Expand Down
2 changes: 1 addition & 1 deletion source/module_base/timer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ void timer::print_all(std::ofstream &ofs)
assert(class_names.size() == calls.size());
assert(class_names.size() == avgs.size());
assert(class_names.size() == pers.size());
std::vector<std::string> titles = {"CLASS_NAME", "NAME", "TIME(Sec)", "CALLS", "AVG(Sec)", "PER(%)"};
std::vector<std::string> titles = {"CLASS_NAME", "NAME", "TIME/s", "CALLS", "AVG/s", "PER/%"};
std::vector<std::string> formats = {"%-10s", "%-10s", "%6.2f", "%8d", "%6.2f", "%6.2f"};
FmtTable time_statistics(titles, pers.size(), formats, {FmtTable::Align::LEFT, FmtTable::Align::CENTER});
time_statistics << class_names << names << times << calls << avgs << pers;
Expand Down
216 changes: 82 additions & 134 deletions source/module_elecstate/elecstate_print.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,78 @@
#include "module_base/formatter.h"
namespace elecstate
{
/**
* Notes on refactor of ESolver's functions
*
* the print of SCF iteration on-the-fly information.
* 1. Previously it is expected for nspin 1, 2, and 4, also with xc_type 3/5 or not, the information will organized in different ways.
* This brings inconsistencies between patterns of print and make it hard to vectorize information.
* 2. the function print_etot actually do two kinds of things, 1) print information into running_*.log, 2) print information onto
* screen. These two tasks are, in no way should be placed/implemented in one function directly
* 3. there are information redundance: the istep of SCF can provide information determing whether print out the SCF iteration info.
* table header or not, rather than dividing into two functions and hard code the format.
*
* For nspin 1, print: ITER, ETOT, EDIFF, DRHO, TIME
* nspin 2, print: ITER, TMAG, AMAG, ETOT, EDIFF, DRHO, TIME
* nspin 4 with nlcc, print: ITER, TMAGX, TMAGY, TMAGZ, AMAG, ETOT, EDIFF, DRHO, TIME
* xc type_id 3/5: DKIN
*
* Based on summary above, there are several groups of info:
* 1. counting: ITER
* 2. (optional) magnetization: TMAG or TMAGX-TMAGY-TMAGZ, AMAG
* 3. energies: ETOT, EDIFF
* 4. densities: DRHO, DKIN(optional)
* 5. time: TIME
*/
void print_scf_iterinfo(const std::string& ks_solver, const int& istep, const int& witer,
const std::vector<double>& mag, const int& wmag,
const double& etot, const double& ediff, const int& wener,
const std::vector<double>& drho, const int& wrho,
const double& time, const int& wtime)
{
std::map<std::string, std::string> iter_header_dict = {
{"cg", "CG"}, {"cg_in_lcao", "CG"}, {"lapack", "LA"}, {"genelpa", "GE"},
{"dav", "DA"}, {"dav_subspace", "DS"}, {"scalapack_gvx", "GV"}, {"cusolver", "CU"},
{"bpcg", "BP"}, {"pexsi", "PE"}
}; // I change the key of "cg_in_lcao" to "CG" because all the other are only two letters
// ITER column
std::vector<std::string> th_fmt = {" %-" + std::to_string(witer) + "s"}; // table header: th: ITER
std::vector<std::string> td_fmt = {" " + iter_header_dict[ks_solver] + "%-" + std::to_string(witer - 2) + ".0f"}; // table data: td: GE10086
// magnetization column, might be non-exist, but size of mag can only be 0, 2 or 4
for(int i = 0; i < mag.size(); i++) {th_fmt.emplace_back("%" + std::to_string(wmag) + "s");}
for(int i = 0; i < mag.size(); i++) {td_fmt.emplace_back("%" + std::to_string(wmag) + ".4e");} // hard-code precision here
// energies
for(int i = 0; i < 2; i++) {th_fmt.emplace_back("%" + std::to_string(wener) + "s");}
for(int i = 0; i < 2; i++) {td_fmt.emplace_back("%" + std::to_string(wener) + ".8e");}
// densities column, size can be 1 or 2, DRHO or DRHO, DKIN
for(int i = 0; i < drho.size(); i++) {th_fmt.emplace_back("%" + std::to_string(wrho) + "s");}
for(int i = 0; i < drho.size(); i++) {td_fmt.emplace_back("%" + std::to_string(wrho) + ".8e");}
// time column, trivial
th_fmt.emplace_back("%" + std::to_string(wtime) + "s\n");
td_fmt.emplace_back("%" + std::to_string(wtime) + ".2f\n");
// contents
std::vector<std::string> titles; std::vector<double> values;
switch (mag.size())
{
case 2:
titles = {"ITER", FmtCore::center("TMAG", wmag), FmtCore::center("AMAG", wmag),
FmtCore::center("ETOT/eV", wener), FmtCore::center("EDIFF/eV", wener), FmtCore::center("DRHO", wrho)};
values = {double(istep), mag[0], mag[1], etot, ediff, drho[0]}; break;
case 4:
titles = {"ITER", FmtCore::center("TMAGX", wmag), FmtCore::center("TMAGY", wmag), FmtCore::center("TMAGZ", wmag),
FmtCore::center("AMAG", wmag), FmtCore::center("ETOT/eV", wener), FmtCore::center("EDIFF/eV", wener), FmtCore::center("DRHO", wrho)};
values = {double(istep), mag[0], mag[1], mag[2], mag[3], etot, ediff, drho[0]}; break;
default:
titles = {"ITER", FmtCore::center("ETOT/eV", wener), FmtCore::center("EDIFF/eV", wener), FmtCore::center("DRHO", wrho)};
values = {double(istep), etot, ediff, drho[0]}; break;
}
if(drho.size() > 1) {titles.push_back(FmtCore::center("DKIN", wrho)); values.push_back(drho[1]);}
titles.push_back(FmtCore::center("TIME/s", wtime)); values.push_back(time);
std::string buf;
if(istep == 1) { for(int i = 0; i < titles.size(); i++) {buf += FmtCore::format(th_fmt[i].c_str(), titles[i]);} }
for(int i = 0; i < values.size(); i++) {buf += FmtCore::format(td_fmt[i].c_str(), values[i]);}
std::cout << buf;
}
/// @brief print and check for band energy and occupations
/// @param ofs
void ElecState::print_eigenvalue(std::ofstream& ofs)
Expand Down Expand Up @@ -259,144 +331,21 @@ void ElecState::print_etot(const bool converged,
this->f_en.etot_old = this->f_en.etot;
}
this->f_en.etot_delta = this->f_en.etot - this->f_en.etot_old;

// mohan update 2011-02-26
std::stringstream ss;

// xiaohui add 2013-09-02, Peize Lin update 2020.11.14
std::string label;
std::string ks_solver_type = get_ks_solver_type();
if (ks_solver_type == "cg")
{
label = "CG";
}
else if (ks_solver_type == "cg_in_lcao")
{
label = "CGAO";
}
else if (ks_solver_type == "lapack")
{
label = "LA";
}
else if (ks_solver_type == "genelpa")
{
label = "GE";
}
else if (ks_solver_type == "dav")
{
label = "DA";
}
else if (ks_solver_type == "dav_subspace")
{
label = "DS";
}
else if (ks_solver_type == "scalapack_gvx")
{
label = "GV";
}
else if (ks_solver_type == "cusolver")
{
label = "CU";
}
else if (ks_solver_type == "bpcg")
{
label = "BP";
}
else if (ks_solver_type == "pexsi")
{
label = "PE";
}
else
{
ModuleBase::WARNING_QUIT("Energy", "print_etot found unknown ks_solver_type");
}
ss << label << iter;
// xiaohui add 2013-09-02

bool scientific = true;
int prec = 6;

if (!print)
return;

if (GlobalV::OUT_LEVEL == "ie" || GlobalV::OUT_LEVEL == "m") // xiaohui add 'm' option, 2015-09-16
{
std::cout << " " << std::setw(7) << ss.str();
// std::cout << std::setiosflags(ios::fixed);
// std::cout << std::setiosflags(ios::showpos);
if (scientific)
{
std::cout << std::scientific;
}

if (GlobalV::COLOUR)
{
if (GlobalV::MY_RANK == 0)
{
printf("\e[36m%-15f\e[0m", this->f_en.etot);
if (GlobalV::NSPIN == 2)
{
std::cout << std::setprecision(2);
std::cout << std::setw(10) << get_ucell_tot_magnetization();
std::cout << std::setw(10) << get_ucell_abs_magnetization();
}
else if (GlobalV::NSPIN == 4 && GlobalV::NONCOLIN)
{
std::cout << std::setprecision(2);
std::cout << std::setw(10) << get_ucell_tot_magnetization_nc_x() << std::setw(10)
<< get_ucell_tot_magnetization_nc_y() << std::setw(10)
<< get_ucell_tot_magnetization_nc_z();
std::cout << std::setw(10) << get_ucell_abs_magnetization();
}
if (scf_thr > 1.0)
{
// 31 is red
printf("\e[31m%-14e\e[0m", scf_thr);
// printf( "[31m%-14e[0m", scf_thr);
}
else
{
// 32 is green
printf("\e[32m%-14e\e[0m", scf_thr);
// printf( "[32m%-14e[0m", scf_thr);
}
// 34 is blue
printf("\e[36m%-15f\e[0m", this->f_en.etot * ModuleBase::Ry_to_eV);
std::cout << std::setprecision(3);
std::cout << std::resetiosflags(std::ios::scientific);

std::cout << std::setw(11) << duration;
std::cout << std::endl;
}
}
else
std::vector<double> mag;
switch (GlobalV::NSPIN)
{
std::cout << std::setprecision(prec);
if (GlobalV::NSPIN == 2)
{
std::cout << std::setprecision(2);
std::cout << std::setw(10) << get_ucell_tot_magnetization();
std::cout << std::setw(10) << get_ucell_abs_magnetization();
}
std::cout << std::setprecision(6);
std::cout << std::setw(15) << this->f_en.etot * ModuleBase::Ry_to_eV;
std::cout << std::setw(15) << this->f_en.etot_delta * ModuleBase::Ry_to_eV;
std::cout << std::setprecision(3);
std::cout << std::setw(11) << scf_thr;
if (elecstate::get_xc_func_type() == 3 || elecstate::get_xc_func_type() == 5)
{
std::cout << std::setprecision(3);
std::cout << std::setw(11) << scf_thr_kin;
}
std::cout << std::setprecision(3);
std::cout << std::setw(11) << duration;
std::cout << std::endl;
case 2: mag = {get_ucell_tot_magnetization(), get_ucell_abs_magnetization()}; break;
case 4: mag = {get_ucell_tot_magnetization_nc_x(), get_ucell_tot_magnetization_nc_y(),
get_ucell_tot_magnetization_nc_z(), get_ucell_abs_magnetization()}; break;
default: mag = {}; break;
}
std::vector<double> drho = {scf_thr};
if(elecstate::get_xc_func_type() == 3 || elecstate::get_xc_func_type() == 5) {drho.push_back(scf_thr_kin);}
elecstate::print_scf_iterinfo(get_ks_solver_type(), iter, 6, mag, 12, this->f_en.etot * ModuleBase::Ry_to_eV,
this->f_en.etot_delta * ModuleBase::Ry_to_eV, 16, drho, 16, duration, 6);
}
else
{
}

this->f_en.etot_old = this->f_en.etot;
return;
}
Expand All @@ -414,5 +363,4 @@ void ElecState::print_format(const std::string& name, const double& value)
GlobalV::ofs_running << std::resetiosflags(std::ios::showpos);
return;
}

} // namespace elecstate
66 changes: 33 additions & 33 deletions source/module_elecstate/test/elecstate_print_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -392,36 +392,36 @@ TEST_F(ElecStatePrintTest, PrintEtotColorS4)
elecstate.print_etot(converged, iter, scf_thr, scf_thr_kin, duration, printe, pw_diag_thr, avg_iter, print);
}

TEST_F(ElecStatePrintTest, PrintEtotWarning)
{
GlobalV::ofs_running.open("test.dat", std::ios::out);
bool converged = false;
int iter = 1;
double scf_thr = 0.1;
double scf_thr_kin = 0.0;
double duration = 2.0;
int printe = 0;
double pw_diag_thr = 0.1;
int avg_iter = 2;
bool print = true;
elecstate.charge = new Charge;
elecstate.charge->nrxx = 100;
elecstate.charge->nxyz = 1000;
GlobalV::imp_sol = true;
GlobalV::EFIELD_FLAG = true;
GlobalV::GATE_FLAG = true;
GlobalV::TWO_EFERMI = false;
GlobalV::out_bandgap = true;
GlobalV::COLOUR = false;
GlobalV::MY_RANK = 0;
GlobalV::BASIS_TYPE = "pw";
GlobalV::SCF_NMAX = 100;
elecstate::tmp_ks_solver = "unknown";
testing::internal::CaptureStdout();
EXPECT_EXIT(elecstate.print_etot(converged, iter, scf_thr, scf_thr_kin, duration, printe, pw_diag_thr, avg_iter, print), ::testing::ExitedWithCode(0), "");
output = testing::internal::GetCapturedStdout();
EXPECT_THAT(output, testing::HasSubstr("print_etot found unknown ks_solver_type"));
GlobalV::ofs_running.close();
delete elecstate.charge;
std::remove("test.dat");
}
// TEST_F(ElecStatePrintTest, PrintEtotWarning)
// {
// GlobalV::ofs_running.open("test.dat", std::ios::out);
// bool converged = false;
// int iter = 1;
// double scf_thr = 0.1;
// double scf_thr_kin = 0.0;
// double duration = 2.0;
// int printe = 0;
// double pw_diag_thr = 0.1;
// int avg_iter = 2;
// bool print = true;
// elecstate.charge = new Charge;
// elecstate.charge->nrxx = 100;
// elecstate.charge->nxyz = 1000;
// GlobalV::imp_sol = true;
// GlobalV::EFIELD_FLAG = true;
// GlobalV::GATE_FLAG = true;
// GlobalV::TWO_EFERMI = false;
// GlobalV::out_bandgap = true;
// GlobalV::COLOUR = false;
// GlobalV::MY_RANK = 0;
// GlobalV::BASIS_TYPE = "pw";
// GlobalV::SCF_NMAX = 100;
// elecstate::tmp_ks_solver = "unknown";
// testing::internal::CaptureStdout();
// EXPECT_EXIT(elecstate.print_etot(converged, iter, scf_thr, scf_thr_kin, duration, printe, pw_diag_thr, avg_iter, print), ::testing::ExitedWithCode(0), "");
// output = testing::internal::GetCapturedStdout();
// EXPECT_THAT(output, testing::HasSubstr("print_etot found unknown ks_solver_type"));
// GlobalV::ofs_running.close();
// delete elecstate.charge;
// std::remove("test.dat");
// }
Loading

0 comments on commit f73a83d

Please sign in to comment.