Skip to content

Commit

Permalink
optimize td_current and fix timer
Browse files Browse the repository at this point in the history
  • Loading branch information
dzzz2001 committed Sep 26, 2024
1 parent e32341a commit 3f41ce4
Showing 1 changed file with 102 additions and 89 deletions.
191 changes: 102 additions & 89 deletions source/module_hamilt_lcao/module_tddft/td_current.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "module_hamilt_lcao/module_tddft/snap_psibeta_half_tddft.h"
#ifdef _OPENMP
#include <unordered_set>
#include <omp.h>
#endif

TD_current::TD_current(const UnitCell* ucell_in,
Expand Down Expand Up @@ -161,7 +162,7 @@ void TD_current::initialize_grad_term(Grid_Driver* GridD, const Parallel_Orbital
this->current_term[dir]->allocate(nullptr, true);
}

ModuleBase::timer::tick("EkineticNew", "initialize_HR");
ModuleBase::timer::tick("TD_current", "initialize_grad_term");
}

void TD_current::calculate_vcomm_r()
Expand All @@ -171,17 +172,8 @@ void TD_current::calculate_vcomm_r()

const Parallel_Orbitals* paraV = this->current_term[0]->get_atom_pair(0).get_paraV();
const int npol = this->ucell->get_npol();

// 1. calculate <psi|beta> for each pair of atoms
#ifdef _OPENMP
#pragma omp parallel
{
std::unordered_set<int> atom_row_list;
#pragma omp for
for (int iat0 = 0; iat0 < this->ucell->nat; iat0++)
{
atom_row_list.insert(iat0);
}
#endif
for (int iat0 = 0; iat0 < this->ucell->nat; iat0++)
{
auto tau0 = ucell->get_tau(iat0);
Expand All @@ -196,99 +188,120 @@ void TD_current::calculate_vcomm_r()
nlm_tot[i].resize(4);
}

for (int ad = 0; ad < adjs.adj_num + 1; ++ad)
#pragma omp parallel
{
const int T1 = adjs.ntype[ad];
const int I1 = adjs.natom[ad];
const int iat1 = ucell->itia2iat(T1, I1);
const ModuleBase::Vector3<double>& tau1 = adjs.adjacent_tau[ad];
const Atom* atom1 = &ucell->atoms[T1];

auto all_indexes = paraV->get_indexes_row(iat1);
#ifdef _OPENMP
if(atom_row_list.find(iat1) == atom_row_list.end())
#pragma omp for schedule(dynamic)
for (int ad = 0; ad < adjs.adj_num + 1; ++ad)
{
all_indexes.clear();
}
#endif
auto col_indexes = paraV->get_indexes_col(iat1);
// insert col_indexes into all_indexes to get universal set with no repeat elements
all_indexes.insert(all_indexes.end(), col_indexes.begin(), col_indexes.end());
std::sort(all_indexes.begin(), all_indexes.end());
all_indexes.erase(std::unique(all_indexes.begin(), all_indexes.end()), all_indexes.end());
for (int iw1l = 0; iw1l < all_indexes.size(); iw1l += npol)
{
const int iw1 = all_indexes[iw1l] / npol;
std::vector<std::vector<std::complex<double>>> nlm;
// nlm is a vector of vectors, but size of outer vector is only 1 when out_current is false
// and size of outer vector is 4 when out_current is true (3 for <psi|r_i * exp(-iAr)|beta>, 1 for
// <psi|exp(-iAr)|beta>) inner loop : all projectors (L0,M0)

// snap_psibeta_half_tddft() are used to calculate <psi|exp(-iAr)|beta>
// and <psi|rexp(-iAr)|beta> as well if current are needed

module_tddft::snap_psibeta_half_tddft(orb_,
this->ucell->infoNL,
nlm,
tau1 * this->ucell->lat0,
T1,
atom1->iw2l[iw1],
atom1->iw2m[iw1],
atom1->iw2n[iw1],
tau0 * this->ucell->lat0,
T0,
this->cart_At,
true);
for (int dir = 0; dir < 4; dir++)
const int T1 = adjs.ntype[ad];
const int I1 = adjs.natom[ad];
const int iat1 = ucell->itia2iat(T1, I1);
const ModuleBase::Vector3<double>& tau1 = adjs.adjacent_tau[ad];
const Atom* atom1 = &ucell->atoms[T1];
auto all_indexes = paraV->get_indexes_row(iat1);
auto col_indexes = paraV->get_indexes_col(iat1);
// insert col_indexes into all_indexes to get universal set with no repeat elements
all_indexes.insert(all_indexes.end(), col_indexes.begin(), col_indexes.end());
std::sort(all_indexes.begin(), all_indexes.end());
all_indexes.erase(std::unique(all_indexes.begin(), all_indexes.end()), all_indexes.end());
for (int iw1l = 0; iw1l < all_indexes.size(); iw1l += npol)
{
nlm_tot[ad][dir].insert({all_indexes[iw1l], nlm[dir]});
const int iw1 = all_indexes[iw1l] / npol;
std::vector<std::vector<std::complex<double>>> nlm;
// nlm is a vector of vectors, but size of outer vector is only 1 when out_current is false
// and size of outer vector is 4 when out_current is true (3 for <psi|r_i * exp(-iAr)|beta>, 1 for
// <psi|exp(-iAr)|beta>) inner loop : all projectors (L0,M0)

// snap_psibeta_half_tddft() are used to calculate <psi|exp(-iAr)|beta>
// and <psi|rexp(-iAr)|beta> as well if current are needed

module_tddft::snap_psibeta_half_tddft(orb_,
this->ucell->infoNL,
nlm,
tau1 * this->ucell->lat0,
T1,
atom1->iw2l[iw1],
atom1->iw2m[iw1],
atom1->iw2n[iw1],
tau0 * this->ucell->lat0,
T0,
this->cart_At,
true);
for (int dir = 0; dir < 4; dir++)
{
nlm_tot[ad][dir].insert({all_indexes[iw1l], nlm[dir]});
}
}
}
}
// 2. calculate <psi_I|beta>D<beta|psi_{J,R}> for each pair of <IJR> atoms
for (int ad1 = 0; ad1 < adjs.adj_num + 1; ++ad1)
{
const int T1 = adjs.ntype[ad1];
const int I1 = adjs.natom[ad1];
const int iat1 = ucell->itia2iat(T1, I1);
#ifdef _OPENMP
if(atom_row_list.find(iat1) == atom_row_list.end())

#ifdef _OPENMP
// record the iat number of the adjacent atoms
std::set<int> ad_atom_set;
for (int ad = 0; ad < adjs.adj_num + 1; ++ad)
{
continue;
const int T1 = adjs.ntype[ad];
const int I1 = adjs.natom[ad];
const int iat1 = ucell->itia2iat(T1, I1);
ad_atom_set.insert(iat1);
}
#endif
ModuleBase::Vector3<int>& R_index1 = adjs.box[ad1];
for (int ad2 = 0; ad2 < adjs.adj_num + 1; ++ad2)

// split the ad_atom_set into num_threads parts
const int num_threads = omp_get_num_threads();
const int thread_id = omp_get_thread_num();
std::set<int> ad_atom_set_thread;
int i = 0;
for(const auto iat1 : ad_atom_set)
{
const int T2 = adjs.ntype[ad2];
const int I2 = adjs.natom[ad2];
const int iat2 = ucell->itia2iat(T2, I2);
ModuleBase::Vector3<int>& R_index2 = adjs.box[ad2];
ModuleBase::Vector3<int> R_vector(R_index2[0] - R_index1[0],
R_index2[1] - R_index1[1],
R_index2[2] - R_index1[2]);
std::complex<double>* tmp_c[3] = {nullptr, nullptr, nullptr};
for (int i = 0; i < 3; i++)
if (i % num_threads == thread_id)
{
tmp_c[i] = this->current_term[i]->find_matrix(iat1, iat2, R_vector[0], R_vector[1], R_vector[2])->get_pointer();
ad_atom_set_thread.insert(iat1);
}
// if not found , skip this pair of atoms
if (tmp_c[0] != nullptr)
i++;
}
#endif

// 2. calculate <psi_I|beta>D<beta|psi_{J,R}> for each pair of <IJR> atoms
for (int ad1 = 0; ad1 < adjs.adj_num + 1; ++ad1)
{
const int T1 = adjs.ntype[ad1];
const int I1 = adjs.natom[ad1];
const int iat1 = ucell->itia2iat(T1, I1);
#ifdef _OPENMP
if (ad_atom_set_thread.find(iat1) == ad_atom_set_thread.end())
{
continue;
}
#endif
ModuleBase::Vector3<int>& R_index1 = adjs.box[ad1];
for (int ad2 = 0; ad2 < adjs.adj_num + 1; ++ad2)
{
this->cal_vcomm_r_IJR(iat1,
iat2,
T0,
paraV,
nlm_tot[ad1],
nlm_tot[ad2],
tmp_c);
const int T2 = adjs.ntype[ad2];
const int I2 = adjs.natom[ad2];
const int iat2 = ucell->itia2iat(T2, I2);
ModuleBase::Vector3<int>& R_index2 = adjs.box[ad2];
ModuleBase::Vector3<int> R_vector(R_index2[0] - R_index1[0],
R_index2[1] - R_index1[1],
R_index2[2] - R_index1[2]);
std::complex<double>* tmp_c[3] = {nullptr, nullptr, nullptr};
for (int i = 0; i < 3; i++)
{
tmp_c[i] = this->current_term[i]->find_matrix(iat1, iat2, R_vector[0], R_vector[1], R_vector[2])->get_pointer();
}
// if not found , skip this pair of atoms
if (tmp_c[0] != nullptr)
{
this->cal_vcomm_r_IJR(iat1,
iat2,
T0,
paraV,
nlm_tot[ad1],
nlm_tot[ad2],
tmp_c);
}
}
}
}
}
#ifdef _OPENMP
}
#endif
ModuleBase::timer::tick("TD_current", "calculate_vcomm_r");
}

Expand Down

0 comments on commit 3f41ce4

Please sign in to comment.