|
5 | 5 | import pytensor
|
6 | 6 | import pytensor.tensor as pt
|
7 | 7 | import pytensor.tensor.random.basic as ptr
|
| 8 | +from pytensor import clone_replace |
8 | 9 | from pytensor.compile.function import function
|
9 | 10 | from pytensor.compile.sharedvalue import SharedVariable, shared
|
10 | 11 | from pytensor.graph.basic import Constant
|
|
26 | 27 | from pytensor.link.jax.dispatch.random import numpyro_available # noqa: E402
|
27 | 28 |
|
28 | 29 |
|
29 |
| -def compile_random_function(*args, **kwargs): |
| 30 | +def compile_random_function(*args, mode="JAX", **kwargs): |
30 | 31 | with pytest.warns(
|
31 | 32 | UserWarning, match=r"The RandomType SharedVariables \[.+\] will not be used"
|
32 | 33 | ):
|
33 |
| - return function(*args, **kwargs) |
| 34 | + return function(*args, mode=mode, **kwargs) |
34 | 35 |
|
35 | 36 |
|
36 | 37 | def test_random_RandomStream():
|
@@ -896,3 +897,24 @@ def test_random_concrete_shape_graph_input():
|
896 | 897 | out = pt.random.normal(0, 1, size=size_pt, rng=rng)
|
897 | 898 | jax_fn = compile_random_function([size_pt], out, mode=jax_mode)
|
898 | 899 | assert jax_fn(10).shape == (10,)
|
| 900 | + |
| 901 | + |
| 902 | +def test_constant_shape_after_graph_rewriting(): |
| 903 | + size = pt.vector("size", shape=(2,), dtype=int) |
| 904 | + x = pt.random.normal(size=size) |
| 905 | + assert x.type.shape == (None, None) |
| 906 | + |
| 907 | + with pytest.raises(TypeError): |
| 908 | + compile_random_function([size], x)([2, 5]) |
| 909 | + |
| 910 | + # Rebuild with strict=False so output type is not updated |
| 911 | + # This reflects cases where size is constant folded during rewrites but the RV node is not recreated |
| 912 | + new_x = clone_replace(x, {size: pt.constant([2, 5])}, rebuild_strict=True) |
| 913 | + assert new_x.type.shape == (None, None) |
| 914 | + assert compile_random_function([], new_x)().shape == (2, 5) |
| 915 | + |
| 916 | + # Rebuild with strict=True, so output type is updated |
| 917 | + # This uses a different path in the dispatch implementation |
| 918 | + new_x = clone_replace(x, {size: pt.constant([2, 5])}, rebuild_strict=False) |
| 919 | + assert new_x.type.shape == (2, 5) |
| 920 | + assert compile_random_function([], new_x)().shape == (2, 5) |
0 commit comments