Skip to content

Commit

Permalink
Refactor: Remove INPUT class (#4815)
Browse files Browse the repository at this point in the history
* Refactor: update md parameters to apply const param_in

* Tests: update unittests

* Tests: update tests

* Tests: update other tests

* [pre-commit.ci lite] apply automatic fixes

* refactor: remove all INPUT

* [pre-commit.ci lite] apply automatic fixes

* delete input.h

* [pre-commit.ci lite] apply automatic fixes

* fix compile failure

* [pre-commit.ci lite] apply automatic fixes

---------

Co-authored-by: YuLiu98 <[email protected]>
Co-authored-by: pre-commit-ci-lite[bot] <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Jul 29, 2024
1 parent 6204641 commit 3845b76
Show file tree
Hide file tree
Showing 30 changed files with 448 additions and 238 deletions.
4 changes: 1 addition & 3 deletions source/Makefile.Objects
Original file line number Diff line number Diff line change
Expand Up @@ -460,9 +460,7 @@ OBJS_XC=xc_functional.o\
xc_funct_corr_gga.o\
xc_funct_hcth.o\

OBJS_IO=input.o\
input_conv_tmp.o\
input_conv.o\
OBJS_IO=input_conv.o\
berryphase.o\
bessel_basis.o\
cal_test.o\
Expand Down
2 changes: 0 additions & 2 deletions source/driver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
#include "module_esolver/esolver.h"
#include "module_hamilt_pw/hamilt_pwdft/global.h"
#include "module_io/cal_test.h"
#include "module_io/input.h"
#include "module_io/input_conv.h"
#include "module_io/para_json.h"
#include "module_io/print_info.h"
Expand Down Expand Up @@ -132,7 +131,6 @@ void Driver::reading()
read_input.write_parameters(PARAM, ss1.str());

// (*temp*) copy the variables from INPUT to each class
Input_Conv::tmp_convert();
Input_Conv::Convert();

// (4) define the 'DIAGONALIZATION' world in MPI
Expand Down
1 change: 0 additions & 1 deletion source/driver_run.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
#include "module_cell/check_atomic_stru.h"
#include "module_cell/module_neighbor/sltk_atom_arrange.h"
#include "module_hamilt_pw/hamilt_pwdft/global.h"
#include "module_io/input.h"
#include "module_parameter/parameter.h"
#include "module_io/para_json.h"
#include "module_io/print_info.h"
Expand Down
20 changes: 12 additions & 8 deletions source/module_elecstate/potentials/H_TDDFT_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
#include "module_base/timer.h"
#include "module_hamilt_lcao/module_tddft/evolve_elec.h"
#include "module_hamilt_pw/hamilt_pwdft/global.h"
#include "module_io/input.h"
#include "module_io/input_conv.h"

namespace elecstate
Expand Down Expand Up @@ -80,7 +79,8 @@ void H_TDDFT_pw::cal_fixed_v(double *vl_pseudo)
ModuleBase::TITLE("H_TDDFT_pw", "cal_fixed_v");

//skip if velocity_gague
if(stype==1)return;
if(stype==1) {return;
}

// time evolve
H_TDDFT_pw::istep++;
Expand Down Expand Up @@ -233,7 +233,7 @@ int H_TDDFT_pw::check_ncut(int t_type)
}
return ncut;
}
void H_TDDFT_pw::update_At(void)
void H_TDDFT_pw::update_At()
{
std::cout << "calculate electric potential" << std::endl;
// time evolve
Expand Down Expand Up @@ -345,7 +345,8 @@ double H_TDDFT_pw::cal_v_time_Gauss(const bool last)

double gauss_t = (istep_int - t0 * ncut) * dt_int;
vext_time = cos(omega * gauss_t + phase) * exp(-gauss_t * gauss_t * 0.5 / (sigma * sigma)) * amp;
if(last)gauss_count++;
if(last) {gauss_count++;
}

return vext_time;
}
Expand Down Expand Up @@ -375,7 +376,8 @@ double H_TDDFT_pw::cal_v_time_trapezoid(const bool last)
}

vext_time = vext_time * amp * cos(omega * istep_int * dt_int + phase);
if(last)trape_count++;
if(last) {trape_count++;
}

return vext_time;
}
Expand All @@ -392,7 +394,8 @@ double H_TDDFT_pw::cal_v_time_trigonometric(const bool last)
const double timenow = istep_int * dt_int;

vext_time = amp * cos(omega1 * timenow + phase1) * sin(omega2 * timenow + phase2) * sin(omega2 * timenow + phase2);
if(last)trigo_count++;
if(last) {trigo_count++;
}

return vext_time;
}
Expand All @@ -402,10 +405,11 @@ double H_TDDFT_pw::cal_v_time_heaviside()
double t0 = *(heavi_t0.begin() + heavi_count);
double amp = *(heavi_amp.begin() + heavi_count);
double vext_time = 0.0;
if (istep < t0)
if (istep < t0) {
vext_time = amp;
else if (istep >= t0)
} else if (istep >= t0) {
vext_time = 0.0;
}
heavi_count++;

return vext_time;
Expand Down
1 change: 0 additions & 1 deletion source/module_elecstate/potentials/H_TDDFT_pw.h
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#ifndef H_TDDFT_PW_H
#define H_TDDFT_PW_H

#include "module_io/input.h"
#include "module_io/input_conv.h"
#include "pot_base.h"

Expand Down
1 change: 0 additions & 1 deletion source/module_esolver/esolver.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

#include "module_base/matrix.h"
#include "module_cell/unitcell.h"
#include "module_io/input.h"
#include "module_parameter/parameter.h"

namespace ModuleESolver
Expand Down
1 change: 0 additions & 1 deletion source/module_esolver/esolver_fp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

#include "module_base/global_variable.h"
#include "module_hamilt_pw/hamilt_pwdft/global.h"
#include "module_io/input.h"
#include "module_parameter/parameter.h"
namespace ModuleESolver
{
Expand Down
1 change: 0 additions & 1 deletion source/module_esolver/esolver_ks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
#include <iostream>

#include "module_base/timer.h"
#include "module_io/input.h"
#include "module_io/json_output/init_info.h"
#include "module_io/print_info.h"
#include "module_parameter/parameter.h"
Expand Down
6 changes: 2 additions & 4 deletions source/module_esolver/esolver_sdft_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -315,15 +315,13 @@ void ESolver_SDFT_PW::after_all_runners()
this->p_hamilt,
(hsolver::HSolverPW_SDFT*)phsol,
&stowf);
sto_elecond
.decide_nche(PARAM.inp.cond_dt, INPUT.cond_dtbatch, 1e-8, this->nche_sto, PARAM.inp.emin_sto, PARAM.inp.emax_sto);
sto_elecond.decide_nche(PARAM.inp.cond_dt, 1e-8, this->nche_sto, PARAM.inp.emin_sto, PARAM.inp.emax_sto);
sto_elecond.sKG(PARAM.inp.cond_smear,
PARAM.inp.cond_fwhm,
PARAM.inp.cond_wcut,
PARAM.inp.cond_dw,
PARAM.inp.cond_dt,
PARAM.inp.cond_nonlocal,
INPUT.cond_dtbatch,
PARAM.inp.npart_sto);
}
}
Expand Down Expand Up @@ -353,7 +351,7 @@ void ESolver_SDFT_PW::nscf()

const int iter = 1;

const double diag_thr = std::max(std::min(1e-5, 0.1 *PARAM.inp.scf_thr / std::max(1.0, GlobalV::nelec)), 1e-12);
const double diag_thr = std::max(std::min(1e-5, 0.1 * PARAM.inp.scf_thr / std::max(1.0, GlobalV::nelec)), 1e-12);

std::cout << " DIGA_THR : " << diag_thr << std::endl;

Expand Down
2 changes: 1 addition & 1 deletion source/module_esolver/print_funcs.h
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#ifndef PRINT_FUNCTIONS_H
#define PRINT_FUNCTIONS_H

#include "module_io/input.h"
#include "module_parameter/input_parameter.h"
#include "module_basis/module_pw/pw_basis_k.h"

namespace Print_functions
Expand Down
1 change: 0 additions & 1 deletion source/module_hamilt_general/module_surchem/test/setcell.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
#include "module_cell/module_neighbor/sltk_atom_arrange.h"
#include "module_cell/module_neighbor/sltk_grid_driver.h"
#include "module_cell/unitcell.h"
#include "module_io/input.h"
#include "module_hamilt_pw/hamilt_pwdft/structure_factor.h"

namespace GlobalC
Expand Down
1 change: 0 additions & 1 deletion source/module_hamilt_general/module_vdw/test/vdw_test.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#include "module_cell/unitcell.h"
#include "module_cell/setup_nonlocal.h"
#include "module_io/input.h"
#include "module_base/mathzone.h"
#include "module_base/vector3.h"
#include"gtest/gtest.h"
Expand Down
8 changes: 4 additions & 4 deletions source/module_hamilt_lcao/hamilt_lcaodft/operator_lcao/test/test_dftu.cpp
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,8 @@ class DFTUTest : public ::testing::Test
GlobalC::dftu.locale[iat][l][0][1].create(2 * l + 1, 2 * l + 1);
}
}
GlobalC::dftu.U = &U_test;
GlobalC::dftu.orbital_corr = &orbital_c_test;
GlobalC::dftu.U = {U_test};
GlobalC::dftu.orbital_corr = {orbital_c_test};

PARAM.input.onsite_radius = 1.0;
}
Expand Down Expand Up @@ -150,7 +150,7 @@ TEST_F(DFTUTest, constructHRd2d)
// test for nspin=1
GlobalV::NSPIN = 1;
std::vector<ModuleBase::Vector3<double>> kvec_d_in(1, ModuleBase::Vector3<double>(0.0, 0.0, 0.0));
hamilt::HS_Matrix_K<double> hsk(paraV, 1);
hamilt::HS_Matrix_K<double> hsk(paraV, true);
hsk.set_zero_hk();
Grid_Driver gd(0, 0, 0);
// check some input values
Expand Down Expand Up @@ -218,7 +218,7 @@ TEST_F(DFTUTest, constructHRd2cd)
// test for nspin=2
GlobalV::NSPIN = 2;
std::vector<ModuleBase::Vector3<double>> kvec_d_in(2, ModuleBase::Vector3<double>(0.0, 0.0, 0.0));
hamilt::HS_Matrix_K<std::complex<double>> hsk(paraV, 1);
hamilt::HS_Matrix_K<std::complex<double>> hsk(paraV, true);
hsk.set_zero_hk();
Grid_Driver gd(0, 0, 0);
EXPECT_EQ(LCAO_Orbitals::get_const_instance().Phi[0].getRcut(), 1.0);
Expand Down
5 changes: 3 additions & 2 deletions source/module_hamilt_lcao/module_dftu/dftu.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "module_hamilt_lcao/hamilt_lcaodft/force_stress_arrays.h" // mohan add 2024-06-15

#include <string>
#include <vector>

//==========================================================
// CLASS :
Expand Down Expand Up @@ -47,9 +48,9 @@ class DFTU
void uramping_update(); // update U by uramping
bool u_converged(); // check if U is converged

double* U; // U (Hubbard parameter U)
std::vector<double> U = {}; // U (Hubbard parameter U)
std::vector<double> U0; // U0 (target Hubbard parameter U0)
int* orbital_corr; //
std::vector<int> orbital_corr = {}; //
double uramping; // increase U by uramping, default is -1.0
int omc; // occupation matrix control
int mixing_dftu; //whether to mix locale
Expand Down
50 changes: 32 additions & 18 deletions source/module_hamilt_lcao/module_dftu/dftu_tools.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,37 +12,42 @@ void DFTU::cal_VU_pot_mat_complex(const int spin, const bool newlocale, std::com

for (int it = 0; it < GlobalC::ucell.ntype; ++it)
{
if (INPUT.orbital_corr[it] == -1)
if (PARAM.inp.orbital_corr[it] == -1) {
continue;
}
for (int ia = 0; ia < GlobalC::ucell.atoms[it].na; ia++)
{
const int iat = GlobalC::ucell.itia2iat(it, ia);
for (int L = 0; L <= GlobalC::ucell.atoms[it].nwl; L++)
{
if (L != INPUT.orbital_corr[it])
if (L != PARAM.inp.orbital_corr[it]) {
continue;
}

for (int n = 0; n < GlobalC::ucell.atoms[it].l_nchi[L]; n++)
{
if (n != 0)
if (n != 0) {
continue;
}

for (int m1 = 0; m1 < 2 * L + 1; m1++)
{
for (int ipol1 = 0; ipol1 < GlobalV::NPOL; ipol1++)
{
const int mu = this->paraV->global2local_row(this->iatlnmipol2iwt[iat][L][n][m1][ipol1]);
if (mu < 0)
if (mu < 0) {
continue;
}

for (int m2 = 0; m2 < 2 * L + 1; m2++)
{
for (int ipol2 = 0; ipol2 < GlobalV::NPOL; ipol2++)
{
const int nu
= this->paraV->global2local_col(this->iatlnmipol2iwt[iat][L][n][m2][ipol2]);
if (nu < 0)
if (nu < 0) {
continue;
}

int m1_all = m1 + (2 * L + 1) * ipol1;
int m2_all = m2 + (2 * L + 1) * ipol2;
Expand All @@ -68,36 +73,41 @@ void DFTU::cal_VU_pot_mat_real(const int spin, const bool newlocale, double* VU)

for (int it = 0; it < GlobalC::ucell.ntype; ++it)
{
if (INPUT.orbital_corr[it] == -1)
if (PARAM.inp.orbital_corr[it] == -1) {
continue;
}
for (int ia = 0; ia < GlobalC::ucell.atoms[it].na; ia++)
{
const int iat = GlobalC::ucell.itia2iat(it, ia);
for (int L = 0; L <= GlobalC::ucell.atoms[it].nwl; L++)
{
if (L != INPUT.orbital_corr[it])
if (L != PARAM.inp.orbital_corr[it]) {
continue;
}

for (int n = 0; n < GlobalC::ucell.atoms[it].l_nchi[L]; n++)
{
if (n != 0)
if (n != 0) {
continue;
}

for (int m1 = 0; m1 < 2 * L + 1; m1++)
{
for (int ipol1 = 0; ipol1 < GlobalV::NPOL; ipol1++)
{
const int mu = this->paraV->global2local_row(this->iatlnmipol2iwt[iat][L][n][m1][ipol1]);
if (mu < 0)
if (mu < 0) {
continue;
}
for (int m2 = 0; m2 < 2 * L + 1; m2++)
{
for (int ipol2 = 0; ipol2 < GlobalV::NPOL; ipol2++)
{
const int nu
= this->paraV->global2local_col(this->iatlnmipol2iwt[iat][L][n][m2][ipol2]);
if (nu < 0)
if (nu < 0) {
continue;
}

int m1_all = m1 + (2 * L + 1) * ipol1;
int m2_all = m2 + (2 * L + 1) * ipol2;
Expand Down Expand Up @@ -145,37 +155,41 @@ double DFTU::get_onebody_eff_pot(const int T,
{
if (Yukawa)
{
if (m0 == m1)
if (m0 == m1) {
VU = (this->U_Yukawa[T][L][N] - this->J_Yukawa[T][L][N])
* (0.5 - this->locale[iat][L][N][spin](m0, m1));
else
} else {
VU = -(this->U_Yukawa[T][L][N] - this->J_Yukawa[T][L][N]) * this->locale[iat][L][N][spin](m0, m1);
}
}
else
{
if (m0 == m1)
if (m0 == m1) {
VU = (this->U[T]) * (0.5 - this->locale[iat][L][N][spin](m0, m1));
else
} else {
VU = -(this->U[T]) * this->locale[iat][L][N][spin](m0, m1);
}
}
}
else
{
if (Yukawa)
{
if (m0 == m1)
if (m0 == m1) {
VU = (this->U_Yukawa[T][L][N] - this->J_Yukawa[T][L][N])
* (0.5 - this->locale_save[iat][L][N][spin](m0, m1));
else
} else {
VU = -(this->U_Yukawa[T][L][N] - this->J_Yukawa[T][L][N])
* this->locale_save[iat][L][N][spin](m0, m1);
}
}
else
{
if (m0 == m1)
if (m0 == m1) {
VU = (this->U[T]) * (0.5 - this->locale_save[iat][L][N][spin](m0, m1));
else
} else {
VU = -(this->U[T]) * this->locale_save[iat][L][N][spin](m0, m1);
}
}
}

Expand Down
Loading

0 comments on commit 3845b76

Please sign in to comment.