Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Perf: optimize td_current #5181

Merged
merged 2 commits into from
Sep 27, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
191 changes: 102 additions & 89 deletions source/module_hamilt_lcao/module_tddft/td_current.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "module_hamilt_lcao/module_tddft/snap_psibeta_half_tddft.h"
#ifdef _OPENMP
#include <unordered_set>
#include <omp.h>
#endif

TD_current::TD_current(const UnitCell* ucell_in,
Expand Down Expand Up @@ -161,7 +162,7 @@ void TD_current::initialize_grad_term(Grid_Driver* GridD, const Parallel_Orbital
this->current_term[dir]->allocate(nullptr, true);
}

ModuleBase::timer::tick("EkineticNew", "initialize_HR");
ModuleBase::timer::tick("TD_current", "initialize_grad_term");
}

void TD_current::calculate_vcomm_r()
Expand All @@ -171,17 +172,8 @@ void TD_current::calculate_vcomm_r()

const Parallel_Orbitals* paraV = this->current_term[0]->get_atom_pair(0).get_paraV();
const int npol = this->ucell->get_npol();

// 1. calculate <psi|beta> for each pair of atoms
#ifdef _OPENMP
#pragma omp parallel
{
std::unordered_set<int> atom_row_list;
#pragma omp for
for (int iat0 = 0; iat0 < this->ucell->nat; iat0++)
{
atom_row_list.insert(iat0);
}
#endif
for (int iat0 = 0; iat0 < this->ucell->nat; iat0++)
{
auto tau0 = ucell->get_tau(iat0);
Expand All @@ -196,99 +188,120 @@ void TD_current::calculate_vcomm_r()
nlm_tot[i].resize(4);
}

for (int ad = 0; ad < adjs.adj_num + 1; ++ad)
#pragma omp parallel
{
const int T1 = adjs.ntype[ad];
const int I1 = adjs.natom[ad];
const int iat1 = ucell->itia2iat(T1, I1);
const ModuleBase::Vector3<double>& tau1 = adjs.adjacent_tau[ad];
const Atom* atom1 = &ucell->atoms[T1];

auto all_indexes = paraV->get_indexes_row(iat1);
#ifdef _OPENMP
if(atom_row_list.find(iat1) == atom_row_list.end())
#pragma omp for schedule(dynamic)
for (int ad = 0; ad < adjs.adj_num + 1; ++ad)
{
all_indexes.clear();
}
#endif
auto col_indexes = paraV->get_indexes_col(iat1);
// insert col_indexes into all_indexes to get universal set with no repeat elements
all_indexes.insert(all_indexes.end(), col_indexes.begin(), col_indexes.end());
std::sort(all_indexes.begin(), all_indexes.end());
all_indexes.erase(std::unique(all_indexes.begin(), all_indexes.end()), all_indexes.end());
for (int iw1l = 0; iw1l < all_indexes.size(); iw1l += npol)
{
const int iw1 = all_indexes[iw1l] / npol;
std::vector<std::vector<std::complex<double>>> nlm;
// nlm is a vector of vectors, but size of outer vector is only 1 when out_current is false
// and size of outer vector is 4 when out_current is true (3 for <psi|r_i * exp(-iAr)|beta>, 1 for
// <psi|exp(-iAr)|beta>) inner loop : all projectors (L0,M0)

// snap_psibeta_half_tddft() are used to calculate <psi|exp(-iAr)|beta>
// and <psi|rexp(-iAr)|beta> as well if current are needed

module_tddft::snap_psibeta_half_tddft(orb_,
this->ucell->infoNL,
nlm,
tau1 * this->ucell->lat0,
T1,
atom1->iw2l[iw1],
atom1->iw2m[iw1],
atom1->iw2n[iw1],
tau0 * this->ucell->lat0,
T0,
this->cart_At,
true);
for (int dir = 0; dir < 4; dir++)
const int T1 = adjs.ntype[ad];
const int I1 = adjs.natom[ad];
const int iat1 = ucell->itia2iat(T1, I1);
const ModuleBase::Vector3<double>& tau1 = adjs.adjacent_tau[ad];
const Atom* atom1 = &ucell->atoms[T1];
auto all_indexes = paraV->get_indexes_row(iat1);
auto col_indexes = paraV->get_indexes_col(iat1);
// insert col_indexes into all_indexes to get universal set with no repeat elements
all_indexes.insert(all_indexes.end(), col_indexes.begin(), col_indexes.end());
std::sort(all_indexes.begin(), all_indexes.end());
all_indexes.erase(std::unique(all_indexes.begin(), all_indexes.end()), all_indexes.end());
for (int iw1l = 0; iw1l < all_indexes.size(); iw1l += npol)
{
nlm_tot[ad][dir].insert({all_indexes[iw1l], nlm[dir]});
const int iw1 = all_indexes[iw1l] / npol;
std::vector<std::vector<std::complex<double>>> nlm;
// nlm is a vector of vectors, but size of outer vector is only 1 when out_current is false
// and size of outer vector is 4 when out_current is true (3 for <psi|r_i * exp(-iAr)|beta>, 1 for
// <psi|exp(-iAr)|beta>) inner loop : all projectors (L0,M0)

// snap_psibeta_half_tddft() are used to calculate <psi|exp(-iAr)|beta>
// and <psi|rexp(-iAr)|beta> as well if current are needed

module_tddft::snap_psibeta_half_tddft(orb_,
this->ucell->infoNL,
nlm,
tau1 * this->ucell->lat0,
T1,
atom1->iw2l[iw1],
atom1->iw2m[iw1],
atom1->iw2n[iw1],
tau0 * this->ucell->lat0,
T0,
this->cart_At,
true);
for (int dir = 0; dir < 4; dir++)
{
nlm_tot[ad][dir].insert({all_indexes[iw1l], nlm[dir]});
}
}
}
}
// 2. calculate <psi_I|beta>D<beta|psi_{J,R}> for each pair of <IJR> atoms
for (int ad1 = 0; ad1 < adjs.adj_num + 1; ++ad1)
{
const int T1 = adjs.ntype[ad1];
const int I1 = adjs.natom[ad1];
const int iat1 = ucell->itia2iat(T1, I1);
#ifdef _OPENMP
if(atom_row_list.find(iat1) == atom_row_list.end())

#ifdef _OPENMP
// record the iat number of the adjacent atoms
std::set<int> ad_atom_set;
for (int ad = 0; ad < adjs.adj_num + 1; ++ad)
{
continue;
const int T1 = adjs.ntype[ad];
const int I1 = adjs.natom[ad];
const int iat1 = ucell->itia2iat(T1, I1);
ad_atom_set.insert(iat1);
}
#endif
ModuleBase::Vector3<int>& R_index1 = adjs.box[ad1];
for (int ad2 = 0; ad2 < adjs.adj_num + 1; ++ad2)

// split the ad_atom_set into num_threads parts
const int num_threads = omp_get_num_threads();
const int thread_id = omp_get_thread_num();
std::set<int> ad_atom_set_thread;
int i = 0;
for(const auto iat1 : ad_atom_set)
{
const int T2 = adjs.ntype[ad2];
const int I2 = adjs.natom[ad2];
const int iat2 = ucell->itia2iat(T2, I2);
ModuleBase::Vector3<int>& R_index2 = adjs.box[ad2];
ModuleBase::Vector3<int> R_vector(R_index2[0] - R_index1[0],
R_index2[1] - R_index1[1],
R_index2[2] - R_index1[2]);
std::complex<double>* tmp_c[3] = {nullptr, nullptr, nullptr};
for (int i = 0; i < 3; i++)
if (i % num_threads == thread_id)
{
tmp_c[i] = this->current_term[i]->find_matrix(iat1, iat2, R_vector[0], R_vector[1], R_vector[2])->get_pointer();
ad_atom_set_thread.insert(iat1);
}
// if not found , skip this pair of atoms
if (tmp_c[0] != nullptr)
i++;
}
#endif

// 2. calculate <psi_I|beta>D<beta|psi_{J,R}> for each pair of <IJR> atoms
for (int ad1 = 0; ad1 < adjs.adj_num + 1; ++ad1)
{
const int T1 = adjs.ntype[ad1];
const int I1 = adjs.natom[ad1];
const int iat1 = ucell->itia2iat(T1, I1);
#ifdef _OPENMP
if (ad_atom_set_thread.find(iat1) == ad_atom_set_thread.end())
{
continue;
}
#endif
ModuleBase::Vector3<int>& R_index1 = adjs.box[ad1];
for (int ad2 = 0; ad2 < adjs.adj_num + 1; ++ad2)
{
this->cal_vcomm_r_IJR(iat1,
iat2,
T0,
paraV,
nlm_tot[ad1],
nlm_tot[ad2],
tmp_c);
const int T2 = adjs.ntype[ad2];
const int I2 = adjs.natom[ad2];
const int iat2 = ucell->itia2iat(T2, I2);
ModuleBase::Vector3<int>& R_index2 = adjs.box[ad2];
ModuleBase::Vector3<int> R_vector(R_index2[0] - R_index1[0],
R_index2[1] - R_index1[1],
R_index2[2] - R_index1[2]);
std::complex<double>* tmp_c[3] = {nullptr, nullptr, nullptr};
for (int i = 0; i < 3; i++)
{
tmp_c[i] = this->current_term[i]->find_matrix(iat1, iat2, R_vector[0], R_vector[1], R_vector[2])->get_pointer();
}
// if not found , skip this pair of atoms
if (tmp_c[0] != nullptr)
{
this->cal_vcomm_r_IJR(iat1,
iat2,
T0,
paraV,
nlm_tot[ad1],
nlm_tot[ad2],
tmp_c);
}
}
}
}
}
#ifdef _OPENMP
}
#endif
ModuleBase::timer::tick("TD_current", "calculate_vcomm_r");
}

Expand Down
Loading