Skip to content

Commit

Permalink
[FIX] Linters
Browse files Browse the repository at this point in the history
  • Loading branch information
f-dangel committed Sep 21, 2024
1 parent 06ebbf2 commit f652a02
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 22 deletions.
11 changes: 7 additions & 4 deletions curvlinops/_torch_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,9 @@ def _check_input_and_preprocess(
This is useful for post-processing the multiplication's result.
is_vec: Whether the input is a vector or a matrix.
num_vecs: The number of vectors represented by the input.
Raises:
ValueError: If the input format is invalid.
"""
if isinstance(X, Tensor):
list_format = False
Expand Down Expand Up @@ -295,21 +298,21 @@ def to_scipy(self, dtype: Optional[numpy.dtype] = None) -> LinearOperator:
def _infer_device(self) -> device:
"""Infer the linear operator's device.
Returns:
Returns: # noqa: D402
The device of the linear operator.
Raises: # noqa: D402
Raises:
NotImplementedError: Must be implemented by subclasses.
"""
raise NotImplementedError

def _infer_dtype(self) -> dtype:
"""Infer the linear operator's data type.
Returns:
Returns: # noqa: D402
The data type of the linear operator.
Raises: # noqa: D402
Raises:
NotImplementedError: Must be implemented by subclasses.
"""
raise NotImplementedError
Expand Down
36 changes: 18 additions & 18 deletions test/test__torch_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,44 +39,44 @@ def test_output_formatting():
in_shape = [(2, 3), (4, 5)]
out_shape = [(2, 3), (4, 6)] # NOTE that this will trigger an error

I = IdentityLinearOperator(in_shape, out_shape)
assert I._in_shape_flat == [6, 20]
assert I._out_shape_flat == [6, 24]
assert I.shape == (30, 26)
Id = IdentityLinearOperator(in_shape, out_shape)
assert Id._in_shape_flat == [6, 20]
assert Id._out_shape_flat == [6, 24]
assert Id.shape == (30, 26)

# using valid input vectors/matrices will trigger errors because we
# initialized the identity with different input/output spaces
with raises(ValueError):
_ = I @ [zeros(2, 3), zeros(4, 5)] # valid vector in list format
_ = Id @ [zeros(2, 3), zeros(4, 5)] # valid vector in list format

with raises(ValueError):
_ = I @ [zeros(2, 3, 6), zeros(4, 5, 6)] # valid matrix in list format
_ = Id @ [zeros(2, 3, 6), zeros(4, 5, 6)] # valid matrix in list format

with raises(ValueError):
_ = I @ zeros(26) # valid vector in tensor format
_ = Id @ zeros(26) # valid vector in tensor format

with raises(ValueError):
_ = I @ zeros(26, 6) # valid matrix in tensor format
_ = Id @ zeros(26, 6) # valid matrix in tensor format


def test_preserve_input_format():
"""Test whether the input format is preserved by matrix multiplication."""
in_shape = out_shape = [(2, 3), (4, 5)]
I = IdentityLinearOperator(in_shape, out_shape)
assert I._in_shape_flat == I._out_shape_flat == [6, 20]
Id = IdentityLinearOperator(in_shape, out_shape)
assert Id._in_shape_flat == Id._out_shape_flat == [6, 20]

X = [zeros(2, 3), zeros(4, 5)] # vector in tensor list format
IX = I @ X
assert len(IX) == len(X) and all(Ix.allclose(x) for Ix, x in zip(IX, X))
IdX = Id @ X
assert len(IdX) == len(X) and all(Idx.allclose(x) for Idx, x in zip(IdX, X))

X = [zeros(2, 3, 6), zeros(4, 5, 6)] # matrix in tensor list format
IX = I @ X
assert len(IX) == len(X) and all(Ix.allclose(x) for Ix, x in zip(IX, X))
IdX = Id @ X
assert len(IdX) == len(X) and all(Idx.allclose(x) for Idx, x in zip(IdX, X))

X = zeros(26) # vector in tensor format
IX = I @ X
assert IX.allclose(X)
IdX = Id @ X
assert IdX.allclose(X)

X = zeros(26, 6) # matrix in tensor format
IX = I @ X
assert IX.allclose(X)
IdX = Id @ X
assert IdX.allclose(X)

0 comments on commit f652a02

Please sign in to comment.