Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support RandomVariable graphs with scalar shape parameters in JAX backend #1029

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 33 additions & 7 deletions pytensor/link/jax/linker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,19 @@
from numpy.random import Generator, RandomState

from pytensor.compile.sharedvalue import SharedVariable, shared
from pytensor.graph.basic import Constant
from pytensor.link.basic import JITLinker


class JAXLinker(JITLinker):
"""A `Linker` that JIT-compiles NumPy-based operations using JAX."""

def __init__(self, *args, **kwargs):
self.scalar_shape_inputs: tuple[int] = () # type: ignore[annotation-unchecked]
super().__init__(*args, **kwargs)

def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
from pytensor.link.jax.dispatch import jax_funcify
from pytensor.link.jax.dispatch.shape import JAXShapeTuple
from pytensor.tensor.random.type import RandomType

shared_rng_inputs = [
Expand Down Expand Up @@ -63,19 +67,41 @@ def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
fgraph.inputs.remove(new_inp)
fgraph.inputs.insert(old_inp_fgrap_index, new_inp)

fgraph_inputs = fgraph.inputs
clients = fgraph.clients
# Detect scalar shape inputs that are used only in JAXShapeTuple nodes
scalar_shape_inputs = [
inp
for node in fgraph.apply_nodes
if isinstance(node.op, JAXShapeTuple)
for inp in node.inputs
if inp in fgraph_inputs
and all(isinstance(node.op, JAXShapeTuple) for node, _ in clients[inp])
]
self.scalar_shape_inputs = tuple(
fgraph_inputs.index(inp) for inp in scalar_shape_inputs
)

return jax_funcify(
fgraph, input_storage=input_storage, storage_map=storage_map, **kwargs
)

def jit_compile(self, fn):
import jax

# I suppose we can consider `Constant`s to be "static" according to
# JAX.
static_argnums = [
n for n, i in enumerate(self.fgraph.inputs) if isinstance(i, Constant)
]
return jax.jit(fn, static_argnums=static_argnums)
jit_fn = jax.jit(fn, static_argnums=self.scalar_shape_inputs)

if not self.scalar_shape_inputs:
return jit_fn

def convert_scalar_shape_inputs(*args):
new_args = [
int(arg) if i in self.scalar_shape_inputs else arg
for i, arg in enumerate(args)
]
return jit_fn(*new_args)

return convert_scalar_shape_inputs

def create_thunk_inputs(self, storage_map):
from pytensor.link.jax.dispatch import jax_typify
Expand Down
Loading
Loading