Skip to content

Commit

Permalink
Allow running JAX functions with scalar inputs for RV shapes
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Oct 11, 2024
1 parent ec793a9 commit 9b9bcba
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 14 deletions.
40 changes: 33 additions & 7 deletions pytensor/link/jax/linker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,19 @@
from numpy.random import Generator, RandomState

from pytensor.compile.sharedvalue import SharedVariable, shared
from pytensor.graph.basic import Constant
from pytensor.link.basic import JITLinker


class JAXLinker(JITLinker):
"""A `Linker` that JIT-compiles NumPy-based operations using JAX."""

def __init__(self, *args, **kwargs):
self.scalar_shape_inputs: tuple[int] = () # type: ignore[annotation-unchecked]
super().__init__(*args, **kwargs)

def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
from pytensor.link.jax.dispatch import jax_funcify
from pytensor.link.jax.dispatch.shape import JAXShapeTuple
from pytensor.tensor.random.type import RandomType

shared_rng_inputs = [
Expand Down Expand Up @@ -63,19 +67,41 @@ def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
fgraph.inputs.remove(new_inp)
fgraph.inputs.insert(old_inp_fgrap_index, new_inp)

fgraph_inputs = fgraph.inputs
clients = fgraph.clients
# Detect scalar shape inputs that are used only in JAXShapeTuple nodes
scalar_shape_inputs = [
inp
for node in fgraph.apply_nodes
if isinstance(node.op, JAXShapeTuple)
for inp in node.inputs
if inp in fgraph_inputs
and all(isinstance(node.op, JAXShapeTuple) for node, _ in clients[inp])
]
self.scalar_shape_inputs = tuple(
fgraph_inputs.index(inp) for inp in scalar_shape_inputs
)

return jax_funcify(
fgraph, input_storage=input_storage, storage_map=storage_map, **kwargs
)

def jit_compile(self, fn):
import jax

# I suppose we can consider `Constant`s to be "static" according to
# JAX.
static_argnums = [
n for n, i in enumerate(self.fgraph.inputs) if isinstance(i, Constant)
]
return jax.jit(fn, static_argnums=static_argnums)
jit_fn = jax.jit(fn, static_argnums=self.scalar_shape_inputs)

if not self.scalar_shape_inputs:
return jit_fn

def convert_scalar_shape_inputs(*args):
new_args = [
int(arg) if i in self.scalar_shape_inputs else arg
for i, arg in enumerate(args)
]
return jit_fn(*new_args)

return convert_scalar_shape_inputs

def create_thunk_inputs(self, storage_map):
from pytensor.link.jax.dispatch import jax_typify
Expand Down
54 changes: 47 additions & 7 deletions tests/link/jax/test_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -863,15 +863,55 @@ def test_random_concrete_shape_subtensor_tuple(self):
jax_fn = compile_random_function([x_pt], out)
assert jax_fn(np.ones((2, 3))).shape == (2,)

def test_random_scalar_shape_input(self):
dim0 = pt.scalar("dim0", dtype=int)
dim1 = pt.scalar("dim1", dtype=int)

out = pt.random.normal(0, 1, size=dim0)
jax_fn = compile_random_function([dim0], out)
assert jax_fn(np.array(2)).shape == (2,)
assert jax_fn(np.array(3)).shape == (3,)

out = pt.random.normal(0, 1, size=[dim0, dim1])
jax_fn = compile_random_function([dim0, dim1], out)
assert jax_fn(np.array(2), np.array(3)).shape == (2, 3)
assert jax_fn(np.array(4), np.array(5)).shape == (4, 5)

@pytest.mark.xfail(
reason="`size_pt` should be specified as a static argument", strict=True
raises=TypeError, reason="Cannot convert scalar input to integer"
)
def test_random_concrete_shape_graph_input(self):
rng = shared(np.random.default_rng(123))
size_pt = pt.scalar()
out = pt.random.normal(0, 1, size=size_pt, rng=rng)
jax_fn = compile_random_function([size_pt], out)
assert jax_fn(10).shape == (10,)
def test_random_scalar_shape_input_not_supported(self):
dim = pt.scalar("dim", dtype=int)
out1 = pt.random.normal(0, 1, size=dim)
# An operation that wouldn't work if we replaced 0d array by integer
out2 = dim[...].set(1)
jax_fn = compile_random_function([dim], [out1, out2])

res1, res2 = jax_fn(np.array(2))
assert res1.shape == (2,)
assert res2 == 1

@pytest.mark.xfail(
raises=TypeError, reason="Cannot convert scalar input to integer"
)
def test_random_scalar_shape_input_not_supported2(self):
dim = pt.scalar("dim", dtype=int)
# This could theoretically be supported
# but would require knowing that * 2 is a safe operation for a python integer
out = pt.random.normal(0, 1, size=dim * 2)
jax_fn = compile_random_function([dim], out)
assert jax_fn(np.array(2)).shape == (4,)

@pytest.mark.xfail(
raises=TypeError, reason="Cannot convert tensor input to shape tuple"
)
def test_random_vector_shape_graph_input(self):
shape = pt.vector("shape", shape=(2,), dtype=int)
out = pt.random.normal(0, 1, size=shape)

jax_fn = compile_random_function([shape], out)
assert jax_fn(np.array([2, 3])).shape == (2, 3)
assert jax_fn(np.array([4, 5])).shape == (4, 5)

def test_constant_shape_after_graph_rewriting(self):
size = pt.vector("size", shape=(2,), dtype=int)
Expand Down

0 comments on commit 9b9bcba

Please sign in to comment.