Skip to content

Commit bb028ae

Browse files
committed
Inline static size inputs in JAX implementation of RandomVariables
This gets around some limitations in JAX jitting system
1 parent 863efc0 commit bb028ae

File tree

2 files changed

+42
-6
lines changed

2 files changed

+42
-6
lines changed

pytensor/link/jax/dispatch/random.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
)
99

1010
import pytensor.tensor.random.basic as ptr
11+
from pytensor.graph import Constant
1112
from pytensor.link.jax.dispatch.basic import jax_funcify, jax_typify
1213
from pytensor.link.jax.dispatch.shape import JAXShapeTuple
1314
from pytensor.tensor.shape import Shape, Shape_i
@@ -91,15 +92,26 @@ def jax_funcify_RandomVariable(op, node, **kwargs):
9192
"""JAX implementation of random variables."""
9293
rv = node.outputs[1]
9394
out_dtype = rv.type.dtype
94-
out_size = rv.type.shape
95+
static_shape = rv.type.shape
9596

9697
batch_ndim = op.batch_ndim(node)
97-
out_size = node.default_output().type.shape[:batch_ndim]
98+
99+
# Try to pass static size directly to JAX
100+
static_size = static_shape[:batch_ndim]
101+
if None in static_size:
102+
# Sometimes size can be constant folded during rewrites,
103+
# without the RandomVariable node being updated with new static types
104+
size_param = node.inputs[1]
105+
if isinstance(size_param, Constant):
106+
size_tuple = tuple(size_param.data)
107+
# PyTensor uses empty size to represent size = None
108+
if len(size_tuple):
109+
static_size = tuple(size_param.data)
98110

99111
# If one dimension has unknown size, either the size is determined
100112
# by a `Shape` operator in which case JAX will compile, or it is
101113
# not and we fail gracefully.
102-
if None in out_size:
114+
if None in static_size:
103115
assert_size_argument_jax_compatible(node)
104116

105117
def sample_fn(rng, size, dtype, *parameters):
@@ -111,7 +123,9 @@ def sample_fn(rng, size, dtype, *parameters):
111123
else:
112124

113125
def sample_fn(rng, size, dtype, *parameters):
114-
return jax_sample_fn(op, node=node)(rng, out_size, out_dtype, *parameters)
126+
return jax_sample_fn(op, node=node)(
127+
rng, static_size, out_dtype, *parameters
128+
)
115129

116130
return sample_fn
117131

tests/link/jax/test_random.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import pytensor
66
import pytensor.tensor as pt
77
import pytensor.tensor.random.basic as ptr
8+
from pytensor import clone_replace
89
from pytensor.compile.function import function
910
from pytensor.compile.sharedvalue import SharedVariable, shared
1011
from pytensor.graph.basic import Constant
@@ -26,11 +27,11 @@
2627
from pytensor.link.jax.dispatch.random import numpyro_available # noqa: E402
2728

2829

29-
def compile_random_function(*args, **kwargs):
30+
def compile_random_function(*args, mode="JAX", **kwargs):
3031
with pytest.warns(
3132
UserWarning, match=r"The RandomType SharedVariables \[.+\] will not be used"
3233
):
33-
return function(*args, **kwargs)
34+
return function(*args, mode=mode, **kwargs)
3435

3536

3637
def test_random_RandomStream():
@@ -896,3 +897,24 @@ def test_random_concrete_shape_graph_input():
896897
out = pt.random.normal(0, 1, size=size_pt, rng=rng)
897898
jax_fn = compile_random_function([size_pt], out, mode=jax_mode)
898899
assert jax_fn(10).shape == (10,)
900+
901+
902+
def test_constant_shape_after_graph_rewriting():
903+
size = pt.vector("size", shape=(2,), dtype=int)
904+
x = pt.random.normal(size=size)
905+
assert x.type.shape == (None, None)
906+
907+
with pytest.raises(TypeError):
908+
compile_random_function([size], x)([2, 5])
909+
910+
# Rebuild with strict=False so output type is not updated
911+
# This reflects cases where size is constant folded during rewrites but the RV node is not recreated
912+
new_x = clone_replace(x, {size: pt.constant([2, 5])}, rebuild_strict=True)
913+
assert new_x.type.shape == (None, None)
914+
assert compile_random_function([], new_x)().shape == (2, 5)
915+
916+
# Rebuild with strict=True, so output type is updated
917+
# This uses a different path in the dispatch implementation
918+
new_x = clone_replace(x, {size: pt.constant([2, 5])}, rebuild_strict=False)
919+
assert new_x.type.shape == (2, 5)
920+
assert compile_random_function([], new_x)().shape == (2, 5)

0 commit comments

Comments
 (0)