Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support RandomVariable graphs with scalar shape parameters in JAX backend #1029

Open
wants to merge 3 commits into
base: main
Choose a base branch
from

Conversation

ricardoV94
Copy link
Member

@ricardoV94 ricardoV94 commented Oct 11, 2024

This should make it possible to do forward sampling in more PyMC models that use dims to define variables shapes

def test_random_scalar_shape_input():
    dim0 = pt.scalar("dim0", dtype=int)
    dim1 = pt.scalar("dim1", dtype=int)

    out = pt.random.normal(0, 1, size=dim0)
    jax_fn = compile_random_function([dim0], out)
    assert jax_fn(np.array(2)).shape == (2,)
    assert jax_fn(np.array(3)).shape == (3,)

    out = pt.random.normal(0, 1, size=[dim0, dim1])
    jax_fn = compile_random_function([dim0, dim1], out)
    assert jax_fn(np.array(2), np.array(3)).shape == (2, 3)
    assert jax_fn(np.array(4), np.array(5)).shape == (4, 5)

These was already a special rewrite to replace make_vector, expand_dims in the shape of RVs, but without handling these inputs from the outside it wouldn't achieve much for PyTensor users:

@node_rewriter([RandomVariable])
def size_parameter_as_tuple(fgraph, node):
"""Replace `MakeVector` and `DimShuffle` (when used to transform a scalar
into a 1d vector) when they are found as the input of a `size` or `shape`
parameter by `JAXShapeTuple` during transpilation.
The JAX implementations of `MakeVector` and `DimShuffle` always return JAX
`TracedArrays`, but JAX only accepts concrete values as inputs for the `size`
or `shape` parameter. When these `Op`s are used to convert scalar or tuple
inputs, however, we can avoid tracing by making them return a tuple of their
inputs instead.
Note that JAX does not accept scalar inputs for the `size` or `shape`
parameters, and this rewrite also ensures that scalar inputs are turned into
tuples during transpilation.
"""
from pytensor.link.jax.dispatch.shape import JAXShapeTuple
size_arg = node.inputs[1]
size_node = size_arg.owner
if size_node is None:
return
if isinstance(size_node.op, JAXShapeTuple):
return
if isinstance(size_node.op, MakeVector) or (
isinstance(size_node.op, DimShuffle)
and size_node.op.input_ndim == 0
and size_node.op.new_order == ("x",)
):
# Here PyTensor converted a tuple or list to a tensor
new_size_args = JAXShapeTuple()(*size_node.inputs)
new_inputs = list(node.inputs)
new_inputs[1] = new_size_args
new_node = node.clone_with_new_inputs(new_inputs)
return new_node.outputs


📚 Documentation preview 📚: https://pytensor--1029.org.readthedocs.build/en/1029/

@ricardoV94 ricardoV94 added enhancement New feature or request jax labels Oct 11, 2024
@ricardoV94 ricardoV94 changed the title Allow running RandomVariable graphs with scalar shape parameters in JAX backend Support RandomVariable graphs with scalar shape parameters in JAX backend Oct 11, 2024
Copy link

codecov bot commented Oct 11, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 81.87%. Comparing base (b248eba) to head (9b9bcba).

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #1029      +/-   ##
==========================================
- Coverage   81.89%   81.87%   -0.02%     
==========================================
  Files         182      182              
  Lines       47778    47790      +12     
  Branches     8597     8598       +1     
==========================================
+ Hits        39126    39130       +4     
- Misses       6487     6491       +4     
- Partials     2165     2169       +4     
Files with missing lines Coverage Δ
pytensor/link/jax/linker.py 96.29% <100.00%> (+1.05%) ⬆️

... and 1 file with indirect coverage changes

@ricardoV94
Copy link
Member Author

ricardoV94 commented Oct 11, 2024

This solves pymc-devs/pymc#7348

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request jax
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant