From f652a023c49a830a1b04a102a445ee84ec66f47c Mon Sep 17 00:00:00 2001 From: Felix Dangel Date: Fri, 20 Sep 2024 21:58:01 -0400 Subject: [PATCH] [FIX] Linters --- curvlinops/_torch_base.py | 11 +++++++---- test/test__torch_base.py | 36 ++++++++++++++++++------------------ 2 files changed, 25 insertions(+), 22 deletions(-) diff --git a/curvlinops/_torch_base.py b/curvlinops/_torch_base.py index a4fcced..497c376 100644 --- a/curvlinops/_torch_base.py +++ b/curvlinops/_torch_base.py @@ -128,6 +128,9 @@ def _check_input_and_preprocess( This is useful for post-processing the multiplication's result. is_vec: Whether the input is a vector or a matrix. num_vecs: The number of vectors represented by the input. + + Raises: + ValueError: If the input format is invalid. """ if isinstance(X, Tensor): list_format = False @@ -295,10 +298,10 @@ def to_scipy(self, dtype: Optional[numpy.dtype] = None) -> LinearOperator: def _infer_device(self) -> device: """Infer the linear operator's device. - Returns: + Returns: # noqa: D402 The device of the linear operator. - Raises: # noqa: D402 + Raises: NotImplementedError: Must be implemented by subclasses. """ raise NotImplementedError @@ -306,10 +309,10 @@ def _infer_device(self) -> device: def _infer_dtype(self) -> dtype: """Infer the linear operator's data type. - Returns: + Returns: # noqa: D402 The data type of the linear operator. - Raises: # noqa: D402 + Raises: NotImplementedError: Must be implemented by subclasses. """ raise NotImplementedError diff --git a/test/test__torch_base.py b/test/test__torch_base.py index 7dfe792..dd4821f 100644 --- a/test/test__torch_base.py +++ b/test/test__torch_base.py @@ -39,44 +39,44 @@ def test_output_formatting(): in_shape = [(2, 3), (4, 5)] out_shape = [(2, 3), (4, 6)] # NOTE that this will trigger an error - I = IdentityLinearOperator(in_shape, out_shape) - assert I._in_shape_flat == [6, 20] - assert I._out_shape_flat == [6, 24] - assert I.shape == (30, 26) + Id = IdentityLinearOperator(in_shape, out_shape) + assert Id._in_shape_flat == [6, 20] + assert Id._out_shape_flat == [6, 24] + assert Id.shape == (30, 26) # using valid input vectors/matrices will trigger errors because we # initialized the identity with different input/output spaces with raises(ValueError): - _ = I @ [zeros(2, 3), zeros(4, 5)] # valid vector in list format + _ = Id @ [zeros(2, 3), zeros(4, 5)] # valid vector in list format with raises(ValueError): - _ = I @ [zeros(2, 3, 6), zeros(4, 5, 6)] # valid matrix in list format + _ = Id @ [zeros(2, 3, 6), zeros(4, 5, 6)] # valid matrix in list format with raises(ValueError): - _ = I @ zeros(26) # valid vector in tensor format + _ = Id @ zeros(26) # valid vector in tensor format with raises(ValueError): - _ = I @ zeros(26, 6) # valid matrix in tensor format + _ = Id @ zeros(26, 6) # valid matrix in tensor format def test_preserve_input_format(): """Test whether the input format is preserved by matrix multiplication.""" in_shape = out_shape = [(2, 3), (4, 5)] - I = IdentityLinearOperator(in_shape, out_shape) - assert I._in_shape_flat == I._out_shape_flat == [6, 20] + Id = IdentityLinearOperator(in_shape, out_shape) + assert Id._in_shape_flat == Id._out_shape_flat == [6, 20] X = [zeros(2, 3), zeros(4, 5)] # vector in tensor list format - IX = I @ X - assert len(IX) == len(X) and all(Ix.allclose(x) for Ix, x in zip(IX, X)) + IdX = Id @ X + assert len(IdX) == len(X) and all(Idx.allclose(x) for Idx, x in zip(IdX, X)) X = [zeros(2, 3, 6), zeros(4, 5, 6)] # matrix in tensor list format - IX = I @ X - assert len(IX) == len(X) and all(Ix.allclose(x) for Ix, x in zip(IX, X)) + IdX = Id @ X + assert len(IdX) == len(X) and all(Idx.allclose(x) for Idx, x in zip(IdX, X)) X = zeros(26) # vector in tensor format - IX = I @ X - assert IX.allclose(X) + IdX = Id @ X + assert IdX.allclose(X) X = zeros(26, 6) # matrix in tensor format - IX = I @ X - assert IX.allclose(X) + IdX = Id @ X + assert IdX.allclose(X)