Skip to content

Commit 2b7f95c

Browse files
committed
Fix bug in assert_size_argument_jax_compatible
1 parent bb7d70f commit 2b7f95c

File tree

2 files changed

+16
-2
lines changed

2 files changed

+16
-2
lines changed

pytensor/link/jax/dispatch/random.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,10 @@ def assert_size_argument_jax_compatible(node):
4545
4646
"""
4747
size = node.inputs[1]
48-
size_op = size.owner.op
49-
if not isinstance(size_op, (Shape, Shape_i, JAXShapeTuple)):
48+
size_node = size.owner
49+
if (size_node is not None) and (
50+
not isinstance(size_node.op, (Shape, Shape_i, JAXShapeTuple))
51+
):
5052
raise NotImplementedError(SIZE_NOT_COMPATIBLE)
5153

5254

tests/link/jax/test_random.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -693,6 +693,18 @@ def test_random_concrete_shape():
693693
assert jax_fn(np.ones((2, 3))).shape == (2, 3)
694694

695695

696+
def test_random_concrete_shape_from_param():
697+
rng = shared(np.random.RandomState(123))
698+
x_at = at.dmatrix()
699+
out = at.random.normal(x_at, 1, rng=rng)
700+
with pytest.warns(
701+
UserWarning,
702+
match="The RandomType SharedVariables \[.+\] will not be used"
703+
):
704+
jax_fn = function([x_at], out, mode=jax_mode)
705+
assert jax_fn(np.ones((2, 3))).shape == (2, 3)
706+
707+
696708
def test_random_concrete_shape_subtensor():
697709
"""JAX should compile when a concrete value is passed for the `size` parameter.
698710

0 commit comments

Comments
 (0)