Skip to content

Commit

Permalink
[REF] Apply review suggestions
Browse files Browse the repository at this point in the history
  • Loading branch information
f-dangel committed Oct 20, 2024
1 parent f652a02 commit ee34a0f
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions curvlinops/_torch_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import numpy
from scipy.sparse.linalg import LinearOperator
from torch import Size, Tensor, cat, device, dtype, from_numpy
from torch import Size, Tensor, as_tensor, cat, device, dtype


class PyTorchLinearOperator:
Expand All @@ -19,7 +19,7 @@ class PyTorchLinearOperator:
One main difference is that the linear operators cannot only multiply
vectors/matrices specified as single PyTorch tensors, but also
vectors/matrices specified in tensor list format. This is common in
PyTorch, where the space a linear operator acts on is a tensor product
PyTorch, where the space a linear operator acts on is a tensor product.
Functions that need to be implemented are ``_matmat`` and ``_adjoint``.
Expand Down Expand Up @@ -342,7 +342,7 @@ def f_scipy(X: numpy.ndarray) -> numpy.ndarray:
The output matrix in NumPy format.
"""
X_dtype = X.dtype
X_torch = from_numpy(X).to(device, dtype)
X_torch = as_tensor(X, dtype=dtype, device=device)
AX_torch = f(X_torch)
return AX_torch.detach().cpu().numpy().astype(X_dtype)

Expand Down

0 comments on commit ee34a0f

Please sign in to comment.