Skip to content

Commit

Permalink
Feature: Add diago_dav_subspace module to pyabacus (#4883)
Browse files Browse the repository at this point in the history
* <Feature>[pyabacus]: Add diago_dav_subspace module to pyabacus

This commit adds the `diago_dav_subspace` module to the `pyabacus` package, which is responsible for diagonalizing the Hamiltonian matrix using Davison subspace method.

The current Pythonization version targets matrix diagonalization. The original algorithm does not require passing a matrix; instead, it uses a function pointer to pass a linear operator for diagonalization.

So the current implementation still needs to wrap std::function<void(T*, T*, const int, const int, const int, const int)> to a Python callable type, which requires further iteration.

* Add Python function and class signatures for ModuleBase and hsolver to improve readability and enhance type checking in Python

* modify dav_subspace, now the module could work but the result is wrong

* Fix memory issue in Diago_DavSubspace destructor

The destructor of Diago_DavSubspace was causing a memory issue, preventing the program from running correctly. Lines 72-74 were previously commented out to avoid the issue, but this led to a memory leak. This commit addresses the root cause of the memory issue and ensures proper memory deallocation without causing a crash.

- Uncommented lines 72-74 in the destructor.

This ensures that all allocated memory is properly deallocated, preventing memory leaks but still unable to run the `pyabacus::hsolver::dav_subspace()`

* add pbdoc to the pybind file and add document to the python func signatures

* modify some docs

* fixed the memory issue in dav_subspace, add an example(python/pyabacus/diago_matrix.py)

* add pytest to diago_dav_subspace

* delete an matrix

* modify _hsolver.py

* Refactor diagonalization to accept operators instead of explicit matrices

Previously, our module could only accept explicit matrices for diagonalization. I have now modified the code to accept operators as parameters for diagonalization, eliminating the need to store the entire matrix explicitly. This allows us to compute eigenvalues for larger Hamiltonians using the sparse matrix storage provided by `SciPy`.

The operator is a function pointer that accepts a vector and returns a vector. To diagonalize an operator A, we define a function or lambda function `mv_op` such that `mv_op(x) = Ax`, which can then be used in the diagonalization process.

Additionally, I have added a new test case for the Hamiltonian matrix corresponding to H2O. The matrix size is `67024x67024`, which previously could not be stored explicitly due to memory constraints. The updated operator method now supports the computation of eigenvalues for such large matrices.

Test cases have been updated to reflect these changes.

* Refactor: Replace manual loop with std::copy for better readability and potential performance improvement

- Replaced manual loop with `std::copy` to copy elements from `psi_in` to `psi` and from `hpsi_ptr` to `hpsi_out`.
- This change simplifies the code, enhances readability and may improve performance due to optimized standard library `std::copy` implementations.

* Refactor hpsi_func to support matrix-matrix multiplication operator and rename two variables' name

- Modified `hpsi_func` to handle matrix-matrix multiplication instead of vector operations.
- This vectorization significantly improved computation speed.
- Renamed variables in the Python frontend interface:
  - `nbasis` to `dim`
  - `nband` to `num_eigs`
- These changes align with the module's design as a pure mathematical interface.

* Enhance test cases to improve precision requirements

- Updated test cases to ensure absolute error is less than 1e-8.
- Added a new test case for random sparse matrices to validate the implementation.

* update README.md

* update `README.md` and `pyproject.toml`
  • Loading branch information
a1henu committed Aug 16, 2024
1 parent 3d44f1f commit d0b25ea
Show file tree
Hide file tree
Showing 19 changed files with 730 additions and 10 deletions.
31 changes: 30 additions & 1 deletion python/pyabacus/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ find_package(pybind11 CONFIG REQUIRED)
set(ABACUS_SOURCE_DIR "${PROJECT_SOURCE_DIR}/../../source")
set(BASE_PATH "${ABACUS_SOURCE_DIR}/module_base")
set(NAO_PATH "${ABACUS_SOURCE_DIR}/module_basis/module_nao")
set(HSOLVER_PATH "${ABACUS_SOURCE_DIR}/module_hsolver")
set(HAMILT_PATH "${ABACUS_SOURCE_DIR}/module_hamilt_general")
set(PSI_PATH "${ABACUS_SOURCE_DIR}/module_psi")
set(ENABLE_LCAO ON)
list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/../../cmake")

Expand Down Expand Up @@ -100,6 +103,29 @@ list(APPEND _naos
add_library(naopack SHARED
${_naos}
)
# add diago shared library
list(APPEND _diago
${HSOLVER_PATH}/diago_dav_subspace.cpp
${HSOLVER_PATH}/diag_const_nums.cpp
${HSOLVER_PATH}/diago_iter_assist.cpp

${HSOLVER_PATH}/kernels/dngvd_op.cpp
${HSOLVER_PATH}/kernels/math_kernel_op.cpp
# dependency
${BASE_PATH}/module_device/device.cpp
${BASE_PATH}/module_device/memory_op.cpp

${HAMILT_PATH}/operator.cpp
${PSI_PATH}/psi.cpp
)
add_library(diagopack SHARED
${_diago}
)
target_link_libraries(diagopack
PRIVATE
${OpenBLAS_LIBRARIES}
${LAPACK_LIBRARIES}
)
# link math_libs
if(MKLROOT)
target_link_libraries(naopack
Expand All @@ -125,9 +151,10 @@ list(APPEND _sources
${PROJECT_SOURCE_DIR}/src/py_abacus.cpp
${PROJECT_SOURCE_DIR}/src/py_base_math.cpp
${PROJECT_SOURCE_DIR}/src/py_m_nao.cpp
${PROJECT_SOURCE_DIR}/src/py_diago_dav_subspace.cpp
)
pybind11_add_module(_core MODULE ${_sources})
target_link_libraries(_core PRIVATE pybind11::headers naopack)
target_link_libraries(_core PRIVATE pybind11::headers naopack diagopack)
target_compile_definitions(_core PRIVATE VERSION_INFO=${PROJECT_VERSION})
# set RPATH
execute_process(
Expand All @@ -141,5 +168,7 @@ set(TARGET_PACK pyabacus)
set(CMAKE_INSTALL_RPATH "${PYTHON_SITE_PACKAGES}/${TARGET_PACK}")
set_target_properties(_core PROPERTIES INSTALL_RPATH "$ORIGIN")
set_target_properties(naopack PROPERTIES INSTALL_RPATH "$ORIGIN")
set_target_properties(diagopack PROPERTIES INSTALL_RPATH "$ORIGIN")
install(TARGETS _core naopack DESTINATION ${TARGET_PACK})
install(TARGETS _core diagopack DESTINATION ${TARGET_PACK})

20 changes: 15 additions & 5 deletions python/pyabacus/README.md
Original file line number Diff line number Diff line change
@@ -1,33 +1,33 @@
Build Example: TwoCenterIntegral Section in ABACUS
==============
==================================================

An example project built with [pybind11](https://github.com/pybind/pybind11)
and scikit-build-core. Python 3.7+ (see older commits for older versions of
Python).


Installation
------------

- Create and activate a new conda env, e.g. `conda create -n myenv python=3.8 & conda activate myenv`.
- Clone ABACUS main repository and `cd abacus-develop/python/pyabacus`.
- Build pyabacus by `pip install -v .` or install test dependencies & build pyabacus by `pip install .[test]`. (Use `pip install -v .[test] -i https://pypi.tuna.tsinghua.edu.cn/simple` to accelerate installation process.)


CI Examples
-----------

There are examples for CI in `.github/workflows`. A simple way to produces
binary "wheels" for all platforms is illustrated in the "wheels.yml" file,
using [`cibuildwheel`][].
using .

Use `pytest -v` to run all the unit tests for pyabacus in the local machine.

```shell
$ cd tests/
$ pytest -v
```

Run `python vis_nao.py` to visualize the numerical orbital.

```shell
$ cd examples/
$ python vis_nao.py
Expand All @@ -41,6 +41,16 @@ $ python ex_s_rotate.py
norm(S_e3 - S_numer) = 3.341208104032616e-15
```

Run `python diago_matrix.py` in `examples` to check the diagonalization of a matrix.

```shell
$ cd examples/
$ python diago_matrix.py
eigenvalues calculated by pyabacus: [-0.38440611 0.24221155 0.31593272 0.53144616 0.85155108 1.06950155 1.11142051 1.12462152]
eigenvalues calculated by scipy: [-0.38440611 0.24221155 0.31593272 0.53144616 0.85155108 1.06950154 1.11142051 1.12462151]
error: [9.26164700e-12 2.42959514e-10 2.96529468e-11 7.77933273e-12 7.53686002e-12 2.95628810e-09 1.04678111e-09 7.79106313e-09]
```

License
-------

Expand All @@ -58,4 +68,4 @@ s.sphbesj(1, 0.0)
0.0
```

[`cibuildwheel`]: https://cibuildwheel.readthedocs.io
[`cibuildwheel`]: https://cibuildwheel.readthedocs.io
Binary file added python/pyabacus/examples/Si2.mat
Binary file not shown.
38 changes: 38 additions & 0 deletions python/pyabacus/examples/diago_matrix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from pyabacus import hsolver
import numpy as np
import scipy

h_mat = scipy.io.loadmat('./Si2.mat')['Problem']['A'][0, 0]

nbasis = h_mat.shape[0]
nband = 8

v0 = np.random.rand(nbasis, nband)

diag_elem = h_mat.diagonal()
diag_elem = np.where(np.abs(diag_elem) < 1e-8, 1e-8, diag_elem)
precond = 1.0 / np.abs(diag_elem)


def mm_op(x):
return h_mat.dot(x)

e, v = hsolver.dav_subspace(
mm_op,
v0,
nbasis,
nband,
precond,
dav_ndim=8,
tol=1e-8,
max_iter=1000,
scf_type=False
)

print('eigenvalues calculated by pyabacus: ', e)

e_scipy, v_scipy = scipy.sparse.linalg.eigsh(h_mat, k=nband, which='SA', maxiter=1000)
e_scipy = np.sort(e_scipy)
print('eigenvalues calculated by scipy: ', e_scipy)

print('error:', e - e_scipy)
1 change: 1 addition & 0 deletions python/pyabacus/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ description="A minimal two_center_integral package (with pybind11)"
readme = "README.md"
authors = [
{ name = "Jie Li", email = "[email protected]" },
{ name = "Chenxu Bai", email = "[email protected]" },
]
requires-python = ">=3.7"
classifiers = [
Expand Down
2 changes: 2 additions & 0 deletions python/pyabacus/src/py_abacus.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@ namespace py = pybind11;

void bind_base_math(py::module& m);
void bind_m_nao(py::module& m);
void bind_diago_dav_subspace(py::module& m);

PYBIND11_MODULE(_core, m)
{
m.doc() = "Python extension for ABACUS built with pybind11 and scikit-build.";
bind_base_math(m);
bind_m_nao(m);
bind_diago_dav_subspace(m);
}
93 changes: 93 additions & 0 deletions python/pyabacus/src/py_diago_dav_subspace.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
#include <complex>
#include <functional>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/complex.h>
#include <pybind11/numpy.h>

#include "module_hsolver/diago_dav_subspace.h"
#include "module_hsolver/kernels/math_kernel_op.h"
#include "module_base/module_device/types.h"

#include "./py_diago_dav_subspace.hpp"

namespace py = pybind11;
using namespace pybind11::literals;

void bind_diago_dav_subspace(py::module& m)
{
py::module hsolver = m.def_submodule("hsolver");

py::class_<hsolver::diag_comm_info>(hsolver, "diag_comm_info")
.def(py::init<const int, const int>(), "rank"_a, "nproc"_a)
.def_readonly("rank", &hsolver::diag_comm_info::rank)
.def_readonly("nproc", &hsolver::diag_comm_info::nproc);

py::class_<py_hsolver::PyDiagoDavSubspace>(hsolver, "diago_dav_subspace")
.def(py::init<int, int>(), R"pbdoc(
Constructor of diago_dav_subspace, a class for diagonalizing
a linear operator using the Davidson-Subspace Method.
This class serves as a backend computation class. The interface
for invoking this class is a function defined in _hsolver.py,
which uses this class to perform the calculations.
Parameters
----------
nbasis : int
The number of basis functions.
nband : int
The number of bands to be calculated.
)pbdoc", "nbasis"_a, "nband"_a)
.def("diag", &py_hsolver::PyDiagoDavSubspace::diag, R"pbdoc(
Diagonalize the linear operator using the Davidson-Subspace Method.
Parameters
----------
mm_op : Callable[[NDArray[np.complex128]], NDArray[np.complex128]],
The operator to be diagonalized, which is a function that takes a matrix as input
and returns a matrix mv_op(X) = H * X as output.
precond_vec : np.ndarray
The preconditioner vector.
dav_ndim : int
The number of vectors, which is a multiple of the number of
eigenvectors to be calculated.
tol : double
The tolerance for the convergence.
max_iter : int
The maximum number of iterations.
need_subspace : bool
Whether to use the subspace function.
is_occupied : list[bool]
A list of boolean values indicating whether the band is occupied,
meaning that the corresponding eigenvalue is to be calculated.
scf_type : bool
Whether to use the SCF type, which is used to determine the
convergence criterion.
If true, it indicates a self-consistent field (SCF) calculation,
where the initial precision of eigenvalue calculation can be coarse.
If false, it indicates a non-self-consistent field (non-SCF) calculation,
where high precision in eigenvalue calculation is required from the start.
)pbdoc",
"mm_op"_a,
"precond_vec"_a,
"dav_ndim"_a,
"tol"_a,
"max_iter"_a,
"need_subspace"_a,
"is_occupied"_a,
"scf_type"_a,
"comm_info"_a)
.def("set_psi", &py_hsolver::PyDiagoDavSubspace::set_psi, R"pbdoc(
Set the initial guess of the eigenvectors, i.e. the wave functions.
)pbdoc", "psi_in"_a)
.def("get_psi", &py_hsolver::PyDiagoDavSubspace::get_psi, R"pbdoc(
Get the eigenvectors.
)pbdoc")
.def("init_eigenvalue", &py_hsolver::PyDiagoDavSubspace::init_eigenvalue, R"pbdoc(
Initialize the eigenvalues as zero.
)pbdoc")
.def("get_eigenvalue", &py_hsolver::PyDiagoDavSubspace::get_eigenvalue, R"pbdoc(
Get the eigenvalues.
)pbdoc");
}
Loading

0 comments on commit d0b25ea

Please sign in to comment.