Skip to content

Commit

Permalink
[pre-commit.ci lite] apply automatic fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
pre-commit-ci-lite[bot] authored Jul 3, 2024
1 parent c2ff90f commit 0bbe40c
Show file tree
Hide file tree
Showing 3 changed files with 206 additions and 144 deletions.
4 changes: 2 additions & 2 deletions source/module_hamilt_lcao/module_gint/grid_technique.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ class Grid_Technique : public Grid_MeshBall {
const int& iat2) const;

private:
void cal_max_box_index(void);
void cal_max_box_index();

int maxB1;
int maxB2;
Expand All @@ -144,7 +144,7 @@ class Grid_Technique : public Grid_MeshBall {
const int& startz_current,
const UnitCell& ucell);
void init_atoms_on_grid2(const int* index2normal, const UnitCell& ucell);
void cal_grid_integration_index(void);
void cal_grid_integration_index();
void cal_trace_lo(const UnitCell& ucell);
void check_bigcell(int* ind_bigcell, char* bigcell_on_processor);
void get_startind(const int& ny,
Expand Down
151 changes: 86 additions & 65 deletions source/module_hamilt_lcao/module_gint/mult_psi_dm_new.cpp
Original file line number Diff line number Diff line change
@@ -1,17 +1,25 @@
#include "gint_tools.h"
#include "module_base/timer.h"
#include "module_base/ylm.h"
namespace Gint_Tools{
namespace Gint_Tools {
void mult_psi_DM_new(
const Grid_Technique& gt, const int bxyz, const int& grid_index,
const Grid_Technique& gt,
const int bxyz,
const int& grid_index,
const int na_grid, // how many atoms on this (i,j,k) grid
const int LD_pool,
const int* const block_iw, // block_iw[na_grid], index of wave functions for each block
const int* const block_size, // block_size[na_grid], number of columns of a band
const int* const block_index, // block_index[na_grid+1], count total number of atomic orbitals
const bool* const* const cal_flag, // cal_flag[bxyz][na_grid], whether the atom-grid distance is larger than cutoff
const double* const* const psi, // psir_vlbr3[bxyz][LD_pool]
double** psi_DM, const hamilt::HContainer<double>* DM,
const int* const
block_iw, // block_iw[na_grid], index of wave functions for each block
const int* const
block_size, // block_size[na_grid], number of columns of a band
const int* const block_index, // block_index[na_grid+1], count total number
// of atomic orbitals
const bool* const* const
cal_flag, // cal_flag[bxyz][na_grid], whether the atom-grid distance
// is larger than cutoff
const double* const* const psi, // psir_vlbr3[bxyz][LD_pool]
double** psi_DM,
const hamilt::HContainer<double>* DM,
const bool if_symm) // 1: density, 2: force
{
constexpr char side = 'L', uplo = 'U';
Expand All @@ -20,8 +28,7 @@ void mult_psi_DM_new(
constexpr int inc = 1;
double alpha_gemm = if_symm ? 2.0 : 1.0;

for (int ia1 = 0; ia1 < na_grid; ia1++)
{
for (int ia1 = 0; ia1 < na_grid; ia1++) {
const int mcell_index1 = gt.bcell_start[grid_index] + ia1;
const int iat1 = gt.which_atom[mcell_index1];
const double* tmp_matrix = DM->find_pair(iat1, iat1)->get_pointer(0);
Expand All @@ -32,18 +39,14 @@ void mult_psi_DM_new(
// ia1==ia2, diagonal part
// find the first ib and last ib for non-zeros cal_flag
int first_ib = 0, last_ib = 0;
for (int ib = 0; ib < bxyz; ++ib)
{
if (cal_flag[ib][ia1])
{
for (int ib = 0; ib < bxyz; ++ib) {
if (cal_flag[ib][ia1]) {
first_ib = ib;
break;
}
}
for (int ib = bxyz - 1; ib >= 0; --ib)
{
if (cal_flag[ib][ia1])
{
for (int ib = bxyz - 1; ib >= 0; --ib) {
if (cal_flag[ib][ia1]) {
last_ib = ib + 1;
break;
}
Expand All @@ -53,43 +56,51 @@ void mult_psi_DM_new(
continue;

int cal_num = 0;
for (int ib = first_ib; ib < last_ib; ++ib)
{
for (int ib = first_ib; ib < last_ib; ++ib) {
cal_num += cal_flag[ib][ia1];
}
// if enough cal_flag entries are nonzero
if (cal_num > ib_length / 4)
{
dsymm_(&side, &uplo, &block_size[ia1], &ib_length, &alpha_symm, tmp_matrix, &block_size[ia1],
&psi[first_ib][block_index[ia1]], &LD_pool, &beta, &psi_DM[first_ib][block_index[ia1]],
if (cal_num > ib_length / 4) {
dsymm_(&side,
&uplo,
&block_size[ia1],
&ib_length,
&alpha_symm,
tmp_matrix,
&block_size[ia1],
&psi[first_ib][block_index[ia1]],
&LD_pool,
&beta,
&psi_DM[first_ib][block_index[ia1]],
&LD_pool);
}
else
{
} else {
// int k=1;
for (int ib = first_ib; ib < last_ib; ++ib)
{
if (cal_flag[ib][ia1])
{
dsymv_(&uplo, &block_size[ia1], &alpha_symm, tmp_matrix, &block_size[ia1],
&psi[ib][block_index[ia1]], &inc, &beta, &psi_DM[ib][block_index[ia1]], &inc);
for (int ib = first_ib; ib < last_ib; ++ib) {
if (cal_flag[ib][ia1]) {
dsymv_(&uplo,
&block_size[ia1],
&alpha_symm,
tmp_matrix,
&block_size[ia1],
&psi[ib][block_index[ia1]],
&inc,
&beta,
&psi_DM[ib][block_index[ia1]],
&inc);
}
}
}
}

int start = if_symm ? ia1 + 1 : 0;

for (int ia2 = start; ia2 < na_grid; ia2++)
{
for (int ia2 = start; ia2 < na_grid; ia2++) {
//---------------------------------------------
// check if we need to calculate the big cell.
//---------------------------------------------
bool same_flag = false;
for (int ib = 0; ib < gt.bxyz; ++ib)
{
if (cal_flag[ib][ia1] && cal_flag[ib][ia2])
{
for (int ib = 0; ib < gt.bxyz; ++ib) {
if (cal_flag[ib][ia1] && cal_flag[ib][ia2]) {
same_flag = true;
break;
}
Expand All @@ -100,20 +111,17 @@ void mult_psi_DM_new(

const int bcell2 = gt.bcell_start[grid_index] + ia2;
const int iat2 = gt.which_atom[bcell2];
const double* tmp_matrix = DM->find_pair(iat1, iat2)->get_pointer(0);
const double* tmp_matrix
= DM->find_pair(iat1, iat2)->get_pointer(0);
int first_ib = 0, last_ib = 0;
for (int ib = 0; ib < bxyz; ++ib)
{
if (cal_flag[ib][ia1] && cal_flag[ib][ia2])
{
for (int ib = 0; ib < bxyz; ++ib) {
if (cal_flag[ib][ia1] && cal_flag[ib][ia2]) {
first_ib = ib;
break;
}
}
for (int ib = bxyz - 1; ib >= 0; --ib)
{
if (cal_flag[ib][ia1] && cal_flag[ib][ia2])
{
for (int ib = bxyz - 1; ib >= 0; --ib) {
if (cal_flag[ib][ia1] && cal_flag[ib][ia2]) {
last_ib = ib + 1;
break;
}
Expand All @@ -123,29 +131,42 @@ void mult_psi_DM_new(
continue;

int cal_pair_num = 0;
for (int ib = first_ib; ib < last_ib; ++ib)
{
for (int ib = first_ib; ib < last_ib; ++ib) {
cal_pair_num += cal_flag[ib][ia1] && cal_flag[ib][ia2];
}
const int iw2_lo = block_iw[ia2];
if (cal_pair_num > ib_length / 4)
{
dgemm_(&transa, &transb, &block_size[ia2], &ib_length, &block_size[ia1], &alpha_gemm, tmp_matrix,
&block_size[ia2], &psi[first_ib][block_index[ia1]], &LD_pool, &beta,
&psi_DM[first_ib][block_index[ia2]], &LD_pool);
}
else
{
for (int ib = first_ib; ib < last_ib; ++ib)
{
if (cal_flag[ib][ia1] && cal_flag[ib][ia2])
{
dgemv_(&transa, &block_size[ia2], &block_size[ia1], &alpha_gemm, tmp_matrix, &block_size[ia2],
&psi[ib][block_index[ia1]], &inc, &beta, &psi_DM[ib][block_index[ia2]], &inc);
if (cal_pair_num > ib_length / 4) {
dgemm_(&transa,
&transb,
&block_size[ia2],
&ib_length,
&block_size[ia1],
&alpha_gemm,
tmp_matrix,
&block_size[ia2],
&psi[first_ib][block_index[ia1]],
&LD_pool,
&beta,
&psi_DM[first_ib][block_index[ia2]],
&LD_pool);
} else {
for (int ib = first_ib; ib < last_ib; ++ib) {
if (cal_flag[ib][ia1] && cal_flag[ib][ia2]) {
dgemv_(&transa,
&block_size[ia2],
&block_size[ia1],
&alpha_gemm,
tmp_matrix,
&block_size[ia2],
&psi[ib][block_index[ia1]],
&inc,
&beta,
&psi_DM[ib][block_index[ia2]],
&inc);
}
}
}
} // ia2
} // ia1
}
}
} // namespace Gint_Tools
Loading

0 comments on commit 0bbe40c

Please sign in to comment.