
Use Blockwise for matmul #452

Merged (2 commits, Sep 24, 2023)
Conversation

@ricardoV94 (Member) commented on Sep 23, 2023:

We use a Blockwised Dot for matmul so that we get gradients for free.
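(For illustration, a minimal sketch of what this enables at the user level, using only the public pytensor API; this is not code from the PR:)

import pytensor
import pytensor.tensor as pt

x = pt.tensor3("x")  # stack of matrices, shape (b, n, k)
y = pt.tensor3("y")  # stack of matrices, shape (b, k, m)

z = x @ y  # matmul on batched inputs now lowers to Blockwise(Dot)

# Because Blockwise vectorizes the core Op's gradient as well,
# the batched graph is differentiable with no extra work:
g = pt.grad(z.sum(), x)

f = pytensor.function([x, y], [z, g])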

C performance won't be great for the Blockwised Dot, since Blockwise doesn't have a C implementation.
We could Blockwise the more specialized Dot22 / GEMM Ops, but that code is a bit of a mess at the moment and not useful long term as we deprecate the C backend.

Alternatively, we could probably use tensor_dot/batched_dot? But those are fundamentally different operations (or rather, it would be inefficient to convert one into the other).
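(An illustration of that difference, using plain NumPy semantics for reference; not code from this PR:)

import numpy as np

a = np.ones((5, 2, 3))  # stack of 5 matrices
b = np.ones((3, 4))     # a single matrix

# matmul broadcasts the leading batch dimensions:
assert (a @ b).shape == (5, 2, 4)

# A batched_dot-style op instead pairs the i-th matrix of one input with
# the i-th matrix of the other, so b would first have to be tiled to
# shape (5, 3, 4); that tiling is the inefficient conversion mentioned above.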

Closes #451

Needed for pymc-devs/pymc#6897

@ricardoV94 added labels on Sep 23, 2023: bug (Something isn't working), enhancement (New feature or request), NumPy compatibility, Op implementation
@ricardoV94 force-pushed the blockwise_for_matmul branch 2 times, most recently from c6b4410 to 7e3c2dd on September 23, 2023
@codecov-commenter commented:

Codecov Report

Merging #452 (8d4054b) into main (4efbd19) will decrease coverage by 0.02%.
Report is 6 commits behind head on main.
The diff coverage is 98.00%.

Additional details and impacted files:

@@            Coverage Diff             @@
##             main     #452      +/-   ##
==========================================
- Coverage   80.76%   80.75%   -0.02%     
==========================================
  Files         159      159              
  Lines       45869    45849      -20     
  Branches    11238    11234       -4     
==========================================
- Hits        37048    37026      -22     
- Misses       6593     6595       +2     
  Partials     2228     2228              
Files Changed                                  Coverage Δ
pytensor/compile/mode.py                       84.40% <ø> (ø)
pytensor/ifelse.py                             51.56% <ø> (ø)
pytensor/tensor/variable.py                    87.42% <83.33%> (-0.34%) ⬇️
pytensor/link/jax/dispatch/scan.py             100.00% <100.00%> (ø)
pytensor/link/numba/dispatch/scan.py           95.91% <100.00%> (+0.02%) ⬆️
pytensor/scalar/basic.py                       80.25% <100.00%> (ø)
pytensor/tensor/math.py                        90.16% <100.00%> (-0.33%) ⬇️
pytensor/tensor/rewriting/blockwise.py         96.15% <100.00%> (+1.15%) ⬆️
pytensor/tensor/rewriting/uncanonicalize.py    96.21% <100.00%> (+0.24%) ⬆️

@jessegrabowski (Member) left a comment:

Looks great, I played around with it and everything seems to work as expected. Currently you can't compile graphs with blockwise matmul to jax/numba; is that beyond the scope of this PR? I thought there was a jax.vectorize-type function that would make that trivial (for the jax case at least)?

Review thread on the following hunk:

elif x1.type.ndim == 1:
    # vector @ matrix: promote x1 to a (1, k) row matrix, then drop the row axis
    out = _matrix_matrix_matmul(x1[None], x2).squeeze(-2)
elif x2.type.ndim == 1:
    # matrix @ vector: promote x2 to a (k, 1) column matrix, then drop the column axis
    out = _matrix_matrix_matmul(x1, x2[:, None]).squeeze(-1)
@jessegrabowski (Member):

Is all this better than a separate _matrix_vector_matmul function? I only ask because BLAS makes the distinction.

@ricardoV94 (Member, Author):
This should be fine; if we ever get to optimizing this further in the jax/numba backends, we should be able to tell which case is which by inspecting the inputs' static types.
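(A hypothetical sketch of that inspection, not code from this PR: after the rewrite above, a promoted vector still shows up in the static type as a broadcastable leading dimension.)

def looks_like_vector_matrix(x1):
    # x1 came from x1[None]: its leading dim is statically length-1
    # (broadcastable), whereas a genuine matrix input generally is not.
    return x1.type.ndim == 2 and x1.type.broadcastable[0]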

@ricardoV94 (Member, Author) commented:

> I thought there was a jax.vectorize-type function that would make that trivial (for the jax case at least)?

Yes, it should be pretty simple. It's on the todo list: #430
Interested in picking it up :) ?
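(A rough sketch of what that JAX dispatch could look like, assuming it follows pytensor's existing jax_funcify registration pattern; this is not the implementation from #430:)

import jax.numpy as jnp

from pytensor.link.jax.dispatch import jax_funcify
from pytensor.tensor.blockwise import Blockwise

@jax_funcify.register(Blockwise)
def jax_funcify_Blockwise(op, **kwargs):
    # Compile the core Op (e.g. Dot) to a JAX function ...
    core_fn = jax_funcify(op.core_op, **kwargs)
    # ... then let jnp.vectorize generalize it over leading batch dims,
    # using the Blockwise gufunc signature, e.g. "(m,k),(k,n)->(m,n)".
    return jnp.vectorize(core_fn, signature=op.signature)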

@ricardoV94 merged commit 071eadd into pymc-devs:main on Sep 24, 2023
53 checks passed
@ricardoV94 deleted the blockwise_for_matmul branch on October 12, 2023
Labels: bug (Something isn't working), enhancement (New feature or request), NumPy compatibility, Op implementation

Successfully merging this pull request may close these issues:

@ operator returns wrong graph for batched matrices