Skip to content

Commit

Permalink
enable the computation of all zeros in one function call (deepmodelin…
Browse files Browse the repository at this point in the history
…g#3449)

Co-authored-by: wqzhou <[email protected]>
  • Loading branch information
jinzx10 and WHUweiqingzhou authored Jan 18, 2024
1 parent 5692438 commit 2a92dd8
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 21 deletions.
49 changes: 34 additions & 15 deletions source/module_base/math_sphbes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -808,7 +808,7 @@ void Sphbes::dsphbesj(const int n,
}
}

void Sphbes::sphbes_zeros(const int l, const int n, double* const zeros)
void Sphbes::sphbes_zeros(const int l, const int n, double* const zeros, const bool return_all)
{
assert( n > 0 );
assert( l >= 0 );
Expand All @@ -818,10 +818,22 @@ void Sphbes::sphbes_zeros(const int l, const int n, double* const zeros)
// This property enables us to use bracketing method recursively
// to find all zeros of j_l from the zeros of j_0.

// if l is odd , j_0 --> j_1 --> j_3 --> j_5 --> ...
// if l is even, j_0 --> j_2 --> j_4 --> j_6 --> ...

int nz = n + (l+1)/2; // number of effective zeros in buffer
// If return_all is true, zeros of j_0, j_1, ..., j_l will all be returned
// such that zeros[l*n+i] is the i-th zero of j_l. As such, it is required
// that the array "zeros" has a size of (l+1)*n.
//
// If return_all is false, only the zeros of j_l will be returned
// and "zeros" is merely required to have a size of n.
// Note that in this case the bracketing method can be applied with a stride
// of 2 instead of 1:
// j_0 --> j_1 --> j_3 --> j_5 --> ... --> j_l (odd l)
// j_0 --> j_2 --> j_4 --> j_6 --> ... --> j_l (even l)

// Every recursion step reduces the number of zeros by 1.
// If return_all is true, one needs to start with n+l zeros of j_0
// to ensure n zeros of j_l; otherwise with a stride of 2 one only
// needs to start with n+(l+1)/2 zeros of j_0
int nz = n + ( return_all ? l : (l+1)/2 );
double* buffer = new double[nz];

// zeros of j_0 = sin(x)/x is just n*pi
Expand All @@ -831,27 +843,34 @@ void Sphbes::sphbes_zeros(const int l, const int n, double* const zeros)
buffer[i] = (i+1) * PI;
}

int ll = 1;
int ll; // active l
auto jl = [&ll] (double x) { return sphbesj(ll, x); };

if (l % 2 == 1)
int stride;
std::function<void()> copy_if_needed;
int offset = 0; // keeps track of the position in zeros for next copy (used when return_all == true)
if (return_all)
{
for (int i = 0; i < nz-1; i++)
{
buffer[i] = illinois(jl, buffer[i], buffer[i+1], 1e-15, 50);
}
--nz;
copy_if_needed = [&](){ std::copy(buffer, buffer + n, zeros + offset); offset += n; };
stride = 1;
ll = 1;
}
else
{
copy_if_needed = [](){};
stride = 2;
ll = 2 - l % 2;
}

for (ll = 2 + l%2; ll <= l; ll += 2, --nz)
for (; ll <= l; ll += stride, --nz)
{
copy_if_needed();
for (int i = 0; i < nz-1; i++)
{
buffer[i] = illinois(jl, buffer[i], buffer[i+1], 1e-15, 50);
}
}

std::copy(buffer, buffer + n, zeros);
std::copy(buffer, buffer + n, zeros + offset);
delete[] buffer;
}

Expand Down
13 changes: 9 additions & 4 deletions source/module_base/math_sphbes.h
Original file line number Diff line number Diff line change
Expand Up @@ -126,13 +126,18 @@ class Sphbes
* This function computes the first n positive zeros of the l-th order
* spherical Bessel function of the first kind.
*
* @param[in] l order of the spherical Bessel function
* @param[in] n number of zeros to be computed
* @param[out] zeros on exit, contains the first n positive zeros in ascending order
* @param[in] l (maximum) order of the spherical Bessel function
* @param[in] n number of zeros to be computed (for each j_l if return_all is true)
* @param[out] zeros on exit, contains the positive zeros.
* @param[in] return_all if true, return all zeros from j_0 to j_l such that zeros[l*n+i]
* is the i-th zero of j_l. If false, return only the first n zeros of j_l.
*
* @note The size of array "zeros" must be at least (l+1)*n if return_all is true, and n otherwise.
*/
static void sphbes_zeros(const int l,
const int n,
double* const zeros
double* const zeros,
bool return_all = false
);

private:
Expand Down
16 changes: 14 additions & 2 deletions source/module_base/test/math_sphbes_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -352,15 +352,27 @@ TEST_F(Sphbes, Zeros)

int lmax = 20;
int nzeros = 500;
double* zeros = new double[nzeros];
double* zeros = new double[nzeros*(lmax+1)];
for (int l = 0; l <= lmax; ++l)
{
ModuleBase::Sphbes::sphbes_zeros(l, nzeros, zeros);
ModuleBase::Sphbes::sphbes_zeros(l, nzeros, zeros, false);
for (int i = 0; i < nzeros; ++i)
{
EXPECT_LT(std::abs(ModuleBase::Sphbes::sphbesj(l, zeros[i])), 1e-14);
}
}


ModuleBase::Sphbes::sphbes_zeros(lmax, nzeros, zeros, true);
for (int l = 0; l <= lmax; ++l)
{
for (int i = 0; i < nzeros; ++i)
{
EXPECT_LT(std::abs(ModuleBase::Sphbes::sphbesj(l, zeros[l*nzeros+i])), 1e-14);
}
}

delete[] zeros;
}

TEST_F(Sphbes, ZerosOld)
Expand Down

0 comments on commit 2a92dd8

Please sign in to comment.