
[ADD] Minimal linear operator interface for PyTorch #130

Merged: 3 commits into main from pytorch-linop on Oct 25, 2024

Conversation

f-dangel (Owner) commented Sep 21, 2024

Long-term, I want to add native PyTorch support for linear operators in curvlinops, both to address inefficiencies like #71 and to clearly separate PyTorch from SciPy, which will make it easier to tackle features like supporting distributed settings.

This PR is a first step towards this goal.

From an API perspective, I plan to keep the constructors of all existing linear operators identical. The only backward-incompatible change is that the produced linear operator will be purely PyTorch; to obtain the old behaviour, one has to call .to_scipy() after construction.

Old: H = HessianLinearOperator(...)
Planned new: H = HessianLinearOperator(...).to_scipy(dtype=...)

The PR defines a linear operator interface in PyTorch that allows easy export to SciPy linear operators.
Importantly, the interface can multiply onto vectors/matrices represented either by a single Tensor or by a List[Tensor], which is more common in PyTorch. It verifies the input and output formats, and all methods that subclasses must implement assume the (more natural) tensor-list format.
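The format-handling pattern described above can be sketched in a torch-free way: the operator accepts a vector either as one flat sequence or as a list of per-parameter blocks, normalizes it to the block (tensor-list) format, checks shapes, and hands it to a subclass-implemented multiply. All names below are hypothetical illustrations, not the merged code in curvlinops/_torch_base.py, which operates on torch.Tensors.

```python
class LinearOperatorSketch:
    """Sketch of the dual-format interface (hypothetical; plain lists
    stand in for torch.Tensors)."""

    def __init__(self, block_sizes):
        self._block_sizes = list(block_sizes)  # one entry per parameter

    def __matmul__(self, x):
        # Detect the caller's format: flat vector vs. list of blocks.
        is_flat = all(isinstance(entry, (int, float)) for entry in x)
        blocks = self._split(x) if is_flat else [list(block) for block in x]
        # Verify the block format before calling the implementation.
        if [len(block) for block in blocks] != self._block_sizes:
            raise ValueError("input shape does not match operator")
        out = self._matmat(blocks)
        # Return the result in the same format the caller used.
        return [entry for block in out for entry in block] if is_flat else out

    def _matmat(self, blocks):
        # Subclasses implement the multiply in the (more natural) block format.
        raise NotImplementedError

    def _split(self, flat):
        # Chop a flat vector into per-parameter blocks.
        blocks, start = [], 0
        for size in self._block_sizes:
            blocks.append(list(flat[start:start + size]))
            start += size
        return blocks


class ScaleBy2(LinearOperatorSketch):
    """Toy operator for illustration: multiplies every entry by 2."""

    def _matmat(self, blocks):
        return [[2 * entry for entry in block] for block in blocks]
```

For example, with `op = ScaleBy2([2, 3])`, calling `op @ [1.0, 2.0, 3.0, 4.0, 5.0]` returns a flat list, while `op @ [[1.0, 2.0], [3.0, 4.0, 5.0]]` returns the result in block format.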

The next steps will be:

  • Define a base class CurvatureLinearOperator that replicates curvlinops._base._LinearOperator but inherits from our PyTorchLinearOperator rather than from scipy.sparse.linalg.LinearOperator.
  • Migrate each supported linear operator to inherit from CurvatureLinearOperator. I already tried this for the Hessian and was able to migrate it without breaking the tests; I will set up a separate PR to keep the diffs manageable.
  • Once all operators have been migrated (at which point we can probably get rid of a lot of shape-checking boilerplate, e.g. in KFAC), we can remove the current base class in curvlinops._base.
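The class hierarchy the steps above describe can be sketched as a skeleton. Only the class names come from the PR description; the method bodies and signatures here are assumptions, not the merged implementation:

```python
class PyTorchLinearOperator:
    """This PR: format checks + multiplication in tensor-list format."""

    def _matmat(self, X):
        raise NotImplementedError

    def to_scipy(self, dtype=None):
        # Planned export path: wrap this operator as a
        # scipy.sparse.linalg.LinearOperator (not implemented in this sketch).
        raise NotImplementedError


class CurvatureLinearOperator(PyTorchLinearOperator):
    """Planned replacement for curvlinops._base._LinearOperator: shared
    boilerplate (parameter/data handling, shape checks) would live here."""


class HessianLinearOperator(CurvatureLinearOperator):
    """Example of a migrated operator (the first one tried in the PR)."""

    def _matmat(self, X):
        # The real code would compute Hessian-matrix products via autodiff.
        raise NotImplementedError
```

Under this plan, user code keeps constructing `HessianLinearOperator(...)` as before and calls `.to_scipy()` only when a SciPy operator is needed.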

Let me know if this makes sense.

coveralls commented Sep 21, 2024

Pull Request Test Coverage Report for Build 11423244838

Details

  • 57 of 77 (74.03%) changed or added relevant lines in 1 file are covered.
  • No unchanged relevant lines lost coverage.
  • Overall coverage decreased (-0.7%) to 88.295%

File with missing coverage   Covered Lines   Changed/Added Lines   %
curvlinops/_torch_base.py    57              77                    74.03%

Totals:
  Change from base Build 10967050112: -0.7%
  Covered Lines: 1403
  Relevant Lines: 1589

💛 - Coveralls

runame (Collaborator) left a comment

I haven't thought much about which functions have to be implemented, but this LGTM! This will definitely be a useful refactor and enable/simplify new functionality.

Review comments on curvlinops/_torch_base.py and test/test__torch_base.py (resolved).
@f-dangel merged commit 2b7a745 into main on Oct 25, 2024 (13 checks passed).
@f-dangel deleted the pytorch-linop branch on October 25, 2024 at 18:48.
3 participants