Skip to content

Commit

Permalink
add pardiso tests for PD matrices.
Browse files Browse the repository at this point in the history
fix test
  • Loading branch information
jcapriot committed Oct 10, 2024
1 parent 46683d6 commit 54156ff
Showing 1 changed file with 45 additions and 7 deletions.
52 changes: 45 additions & 7 deletions tests/test_Pardiso.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,23 +25,61 @@ def test_mat_data():

@pytest.mark.parametrize('transpose', [True, False])
@pytest.mark.parametrize('dtype', [np.float64, np.complex128])
@pytest.mark.parametrize('symmetric', [True, False])
def test_solve(test_mat_data, dtype, transpose, symmetric):
@pytest.mark.parametrize('symmetry', ["S", "H", None])
def test_solve(test_mat_data, dtype, transpose, symmetry):

A, sol = test_mat_data
sol = sol.astype(dtype)
A = A.astype(dtype)
if not symmetric:

if symmetry is None:
D = sp.diags(np.linspace(2, 3, A.shape[0]))
A = D @ A
symmetric = False
hermitian = False
elif symmetry == "H":
D = sp.diags(np.linspace(2, 3, A.shape[0]))
if np.issubdtype(dtype, np.complexfloating):
D = D + 1j * sp.diags(np.linspace(3, 4, A.shape[0]))
A = D @ A @ D.T.conjugate()
symmetric = False
hermitian = True
else:
symmetric = True
hermitian = False

sol = sol.astype(dtype)
A = A.astype(dtype)

rhs = A @ sol
if transpose:
Ainv = pymatsolver.Pardiso(A.T, is_symmetric=symmetric).T
Ainv = pymatsolver.Pardiso(A.T, is_symmetric=symmetric, is_hermitian=hermitian).T
else:
Ainv = pymatsolver.Pardiso(A, is_symmetric=symmetric)
Ainv = pymatsolver.Pardiso(A, is_symmetric=symmetric, is_hermitian=hermitian)
for i in range(rhs.shape[1]):
npt.assert_allclose(Ainv * rhs[:, i], sol[:, i], atol=TOL)
npt.assert_allclose(Ainv * rhs, sol, atol=TOL)

@pytest.mark.parametrize('transpose', [True, False])
@pytest.mark.parametrize('dtype', [np.float64, np.complex128])
def test_pardiso_positive_definite(dtype, transpose):
n = 5
if dtype == np.float64:
L = sp.diags([1, -1], [0, -1], shape=(n, n))
else:
L = sp.diags([1, -1j], [0, -1], shape=(n, n))
D = sp.diags(np.linspace(1, 2, n))
A_pd = L @ D @ (L.T.conjugate())

sol = np.linspace(0, 1, n)
rhs = A_pd @ sol

is_symmetric = dtype == np.float64
if transpose:
Ainv = pymatsolver.Pardiso(A.T, is_symmetric=is_symmetric, is_hermitian=True, is_positive_definite=True).T
else:
Ainv = pymatsolver.Pardiso(A, is_symmetric=is_symmetric, is_hermitian=True, is_positive_definite=True)

npt.assert_allclose(Ainv @ rhs, sol)


def test_refactor(test_mat_data):
A, sol = test_mat_data
Expand Down

0 comments on commit 54156ff

Please sign in to comment.