Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ADD] Implement adjoint for sub-matrix linear operator #115

Merged
merged 4 commits into from
May 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions curvlinops/submatrix.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Implements slices of linear operators."""

from __future__ import annotations

from typing import List

from numpy import column_stack, ndarray, zeros
Expand Down Expand Up @@ -78,3 +80,13 @@ def _matmat(self, X: ndarray) -> ndarray:
``A[row_idxs, :][:, col_idxs] @ x``. Has shape ``[len(row_idxs), N]``.
"""
return column_stack([self @ col for col in X.T])

def _adjoint(self) -> SubmatrixLinearOperator:
"""Return the adjoint of the sub-matrix.

For that, we need to take the adjoint operator, and swap row and column indices.

Returns:
The linear operator for the adjoint sub-matrix.
"""
return type(self)(self._A.adjoint(), self._col_idxs, self._row_idxs)
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ lint =

# Dependencies needed to build/view the documentation (semicolon/line-separated)
docs =
setuptools==69.5.1 # RTD fails with setuptools>=70, see https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/15863
transformers
datasets
matplotlib
Expand Down
45 changes: 38 additions & 7 deletions test/test_submatrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import List, Tuple

from numpy import eye, ndarray, random
from pytest import fixture, raises
from pytest import fixture, mark, raises
from scipy.sparse.linalg import aslinearoperator

from curvlinops.examples.utils import report_nonclose
Expand Down Expand Up @@ -34,29 +34,60 @@ def submatrix_case(request) -> Tuple[ndarray, List[int], List[int]]:
return case["A_fn"](), case["row_idxs_fn"](), case["col_idxs_fn"]()


def test_SubmatrixLinearOperator__matvec(submatrix_case):
@mark.parametrize("adjoint", [False, True], ids=["", "adjoint"])
def test_SubmatrixLinearOperator__matvec(
submatrix_case: Tuple[ndarray, List[int], List[int]], adjoint: bool
):
"""Test the matrix-vector multiplication of a submatrix linear operator.

Args:
submatrix_case: A tuple with a random matrix and two index lists.
adjoint: Whether to take the operator's adjoint before multiplying.
"""
A, row_idxs, col_idxs = submatrix_case

A_sub = A[row_idxs, :][:, col_idxs]
A_sub_linop = SubmatrixLinearOperator(aslinearoperator(A), row_idxs, col_idxs)

x = random.rand(len(col_idxs))
if adjoint:
A_sub = A_sub.conj().T
A_sub_linop = A_sub_linop.adjoint()

x = random.rand(A_sub.shape[1])
A_sub_linop_x = A_sub_linop @ x

assert A_sub_linop_x.shape == (len(row_idxs),)
assert A_sub_linop_x.shape == ((len(col_idxs),) if adjoint else (len(row_idxs),))
report_nonclose(A_sub @ x, A_sub_linop_x)


def test_SubmatrixLinearOperator__matmat(submatrix_case, num_vecs: int = 3):
@mark.parametrize("adjoint", [False, True], ids=["", "adjoint"])
def test_SubmatrixLinearOperator__matmat(
submatrix_case: Tuple[ndarray, List[int], List[int]],
adjoint: bool,
num_vecs: int = 3,
):
"""Test the matrix-matrix multiplication of a submatrix linear operator.

Args:
submatrix_case: A tuple with a random matrix and two index lists.
adjoint: Whether to take the operator's adjoint before multiplying.
num_vecs: The number of vectors to multiply. Default: ``3``.
"""
A, row_idxs, col_idxs = submatrix_case

A_sub = A[row_idxs, :][:, col_idxs]
A_sub_linop = SubmatrixLinearOperator(aslinearoperator(A), row_idxs, col_idxs)

X = random.rand(len(col_idxs), num_vecs)
if adjoint:
A_sub = A_sub.conj().T
A_sub_linop = A_sub_linop.adjoint()

X = random.rand(A_sub.shape[1], num_vecs)
A_sub_linop_X = A_sub_linop @ X

assert A_sub_linop_X.shape == (len(row_idxs), num_vecs)
assert A_sub_linop_X.shape == (
(len(col_idxs), num_vecs) if adjoint else (len(row_idxs), num_vecs)
)
report_nonclose(A_sub @ X, A_sub_linop_X)


Expand Down
Loading