[ADD] Minimal linear operator interface for PyTorch #130
Long-term, I want to add native PyTorch support for linear operators in curvlinops to address inefficiencies like #71, and also to clearly separate PyTorch from SciPy so that it becomes easier to tackle features like supporting distributed settings. This PR is a first step towards this goal.
From an API perspective, I plan to keep the constructors of all existing linear operators identical. The only backward-incompatible change will be that the produced linear operator is purely PyTorch. To obtain the old behaviour, one has to call `.to_scipy()` after the constructor.

Old:

```python
H = HessianLinearOperator(...)
```

Planned new:

```python
H = HessianLinearOperator(...).to_scipy(dtype=...)
```
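One reason to keep the SciPy export is that downstream SciPy routines such as `eigsh` only need a `matvec`. A minimal stand-in sketch (the `LinearOperator` below mimics what the object returned by `.to_scipy()` would look like; the diagonal operator is purely illustrative):

```python
import numpy as np
from scipy.sparse.linalg import LinearOperator, eigsh

# Stand-in for the operator returned by .to_scipy(): a SciPy
# LinearOperator defined only through its matrix-vector product.
diag = np.arange(1.0, 6.0)  # diagonal matrix with eigenvalues 1..5
H_scipy = LinearOperator((5, 5), matvec=lambda v: diag * v, dtype=np.float64)

# SciPy's iterative eigensolver works directly with the exported operator.
top_eigval = eigsh(H_scipy, k=1, return_eigenvectors=False)[0]
```

This is exactly the consumption pattern the `.to_scipy()` export is meant to preserve for existing users.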
The PR defines a linear operator interface in PyTorch which allows easy export to SciPy linear operators. Importantly, the interface can multiply onto vectors/matrices represented either by a single `Tensor` or by a `List[Tensor]`, which is more common in PyTorch. It verifies the input and output formats, and all methods that need to be implemented assume the (more natural) tensor list format.

The next steps will be:

- Implement a `CurvatureLinearOperator` that replicates `curvlinops._base._LinearOperator`, but inherits from our `PyTorchLinearOperator` rather than `scipy.sparse.linalg.LinearOperator`.
- Migrate the existing linear operators to `CurvatureLinearOperator`. I already tried that for the Hessian and was able to migrate without breaking the tests. I will set up a separate PR to keep the diffs manageable.
- Deprecate `curvlinops._base`.

Let me know if this makes sense.
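To make the tensor-list idea concrete, here is a hedged sketch of such an interface. The class name `PyTorchLinearOperator` and the `.to_scipy()` method come from this PR, but the constructor, `_matmat` hook, and all signatures below are assumptions, not the actual implementation:

```python
from typing import List

import numpy as np
import torch
from scipy.sparse.linalg import LinearOperator


class PyTorchLinearOperator:
    """Illustrative sketch of a PyTorch-native linear operator interface.

    Sub-classes implement ``_matmat`` in tensor-list format (one tensor per
    parameter block); ``__matmul__`` also accepts a single flattened
    ``Tensor``, and ``to_scipy`` exports a SciPy ``LinearOperator``.
    """

    def __init__(self, shapes: List[torch.Size]):
        self._shapes = shapes
        self.dim = sum(s.numel() for s in shapes)  # total flattened dimension

    def _matmat(self, M: List[torch.Tensor]) -> List[torch.Tensor]:
        """Multiply onto a matrix in tensor-list format (sub-class hook)."""
        raise NotImplementedError

    def __matmul__(self, X):
        is_tensor = isinstance(X, torch.Tensor)
        if is_tensor:  # convert a flat vector/matrix into tensor-list format
            X_mat = X if X.ndim == 2 else X.unsqueeze(-1)
            sizes = [s.numel() for s in self._shapes]
            X_list = [
                x.reshape(*s, X_mat.shape[1])
                for x, s in zip(X_mat.split(sizes), self._shapes)
            ]
        else:
            X_list = X
        Y_list = self._matmat(X_list)
        if is_tensor:  # convert the result back into the input format
            Y = torch.cat([y.reshape(-1, y.shape[-1]) for y in Y_list])
            return Y.squeeze(-1) if X.ndim == 1 else Y
        return Y_list

    def to_scipy(self, dtype=np.float64) -> LinearOperator:
        """Export as a SciPy linear operator whose matvec calls into PyTorch."""

        def matvec(v: np.ndarray) -> np.ndarray:
            result = self @ torch.from_numpy(v).to(torch.float32)
            return result.detach().cpu().numpy().astype(dtype)

        return LinearOperator((self.dim, self.dim), matvec=matvec, dtype=dtype)


class ScaledIdentity(PyTorchLinearOperator):
    """Toy sub-class representing ``c * I`` to exercise the interface."""

    def __init__(self, shapes: List[torch.Size], c: float):
        super().__init__(shapes)
        self._c = c

    def _matmat(self, M: List[torch.Tensor]) -> List[torch.Tensor]:
        return [self._c * m for m in M]
```

A sub-class only has to supply `_matmat` in tensor-list format; the single-`Tensor` path and the SciPy export (e.g. `ScaledIdentity([torch.Size((2, 3))], 2.0).to_scipy()`) come for free from the base class.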