Skip to content

Commit 53b00ea

Browse files
committed
Intercept UserWarning on JAX random function tests
1 parent 93bfa1b commit 53b00ea

File tree

1 file changed

+45
-51
lines changed

1 file changed

+45
-51
lines changed

tests/link/jax/test_random.py

Lines changed: 45 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
import re
2-
31
import numpy as np
42
import pytest
53
import scipy.stats as stats
@@ -22,6 +20,13 @@
2220
from pytensor.link.jax.dispatch.random import numpyro_available # noqa: E402
2321

2422

23+
def random_function(*args, **kwargs):
24+
with pytest.warns(
25+
UserWarning, match=r"The RandomType SharedVariables \[.+\] will not be used"
26+
):
27+
return function(*args, **kwargs)
28+
29+
2530
def test_random_RandomStream():
2631
"""Two successive calls of a compiled graph using `RandomStream` should
2732
return different values.
@@ -30,11 +35,7 @@ def test_random_RandomStream():
3035
srng = RandomStream(seed=123)
3136
out = srng.normal() - srng.normal()
3237

33-
with pytest.warns(
34-
UserWarning,
35-
match=r"The RandomType SharedVariables \[.+\] will not be used",
36-
):
37-
fn = function([], out, mode=jax_mode)
38+
fn = random_function([], out, mode=jax_mode)
3839
jax_res_1 = fn()
3940
jax_res_2 = fn()
4041

@@ -47,13 +48,7 @@ def test_random_updates(rng_ctor):
4748
rng = shared(original_value, name="original_rng", borrow=False)
4849
next_rng, x = at.random.normal(name="x", rng=rng).owner.outputs
4950

50-
with pytest.warns(
51-
UserWarning,
52-
match=re.escape(
53-
"The RandomType SharedVariables [original_rng] will not be used"
54-
),
55-
):
56-
f = pytensor.function([], [x], updates={rng: next_rng}, mode=jax_mode)
51+
f = random_function([], [x], updates={rng: next_rng}, mode=jax_mode)
5752
assert f() != f()
5853

5954
# Check that original rng variable content was not overwritten when calling jax_typify
@@ -83,17 +78,14 @@ def test_random_updates_input_storage_order():
8378

8479
# This function replaces inp by input_shared in the update expression
8580
# This is what caused the RNG to appear later than inp_shared in the input_storage
86-
with pytest.warns(
87-
UserWarning,
88-
match=r"The RandomType SharedVariables \[.+\] will not be used",
89-
):
90-
fn = pytensor.function(
91-
inputs=[],
92-
outputs=[],
93-
updates={inp_shared: inp_update},
94-
givens={inp: inp_shared},
95-
mode="JAX",
96-
)
81+
82+
fn = random_function(
83+
inputs=[],
84+
outputs=[],
85+
updates={inp_shared: inp_update},
86+
givens={inp: inp_shared},
87+
mode="JAX",
88+
)
9789
fn()
9890
np.testing.assert_allclose(inp_shared.get_value(), 5, rtol=1e-3)
9991
fn()
@@ -457,7 +449,7 @@ def test_random_RandomVariable(rv_op, dist_params, base_size, cdf_name, params_c
457449
else:
458450
rng = shared(np.random.RandomState(29402))
459451
g = rv_op(*dist_params, size=(10_000,) + base_size, rng=rng)
460-
g_fn = function(dist_params, g, mode=jax_mode)
452+
g_fn = random_function(dist_params, g, mode=jax_mode)
461453
samples = g_fn(
462454
*[
463455
i.tag.test_value
@@ -481,7 +473,7 @@ def test_random_RandomVariable(rv_op, dist_params, base_size, cdf_name, params_c
481473
def test_random_bernoulli(size):
482474
rng = shared(np.random.RandomState(123))
483475
g = at.random.bernoulli(0.5, size=(1000,) + size, rng=rng)
484-
g_fn = function([], g, mode=jax_mode)
476+
g_fn = random_function([], g, mode=jax_mode)
485477
samples = g_fn()
486478
np.testing.assert_allclose(samples.mean(axis=0), 0.5, 1)
487479

@@ -492,7 +484,7 @@ def test_random_mvnormal():
492484
mu = np.ones(4)
493485
cov = np.eye(4)
494486
g = at.random.multivariate_normal(mu, cov, size=(10000,), rng=rng)
495-
g_fn = function([], g, mode=jax_mode)
487+
g_fn = random_function([], g, mode=jax_mode)
496488
samples = g_fn()
497489
np.testing.assert_allclose(samples.mean(axis=0), mu, atol=0.1)
498490

@@ -507,7 +499,7 @@ def test_random_mvnormal():
507499
def test_random_dirichlet(parameter, size):
508500
rng = shared(np.random.RandomState(123))
509501
g = at.random.dirichlet(parameter, size=(1000,) + size, rng=rng)
510-
g_fn = function([], g, mode=jax_mode)
502+
g_fn = random_function([], g, mode=jax_mode)
511503
samples = g_fn()
512504
np.testing.assert_allclose(samples.mean(axis=0), 0.5, 1)
513505

@@ -517,29 +509,29 @@ def test_random_choice():
517509
num_samples = 10000
518510
rng = shared(np.random.RandomState(123))
519511
g = at.random.choice(np.arange(4), size=num_samples, rng=rng)
520-
g_fn = function([], g, mode=jax_mode)
512+
g_fn = random_function([], g, mode=jax_mode)
521513
samples = g_fn()
522514
np.testing.assert_allclose(np.sum(samples == 3) / num_samples, 0.25, 2)
523515

524516
# `replace=False` produces unique results
525517
rng = shared(np.random.RandomState(123))
526518
g = at.random.choice(np.arange(100), replace=False, size=99, rng=rng)
527-
g_fn = function([], g, mode=jax_mode)
519+
g_fn = random_function([], g, mode=jax_mode)
528520
samples = g_fn()
529521
assert len(np.unique(samples)) == 99
530522

531523
# We can pass an array with probabilities
532524
rng = shared(np.random.RandomState(123))
533525
g = at.random.choice(np.arange(3), p=np.array([1.0, 0.0, 0.0]), size=10, rng=rng)
534-
g_fn = function([], g, mode=jax_mode)
526+
g_fn = random_function([], g, mode=jax_mode)
535527
samples = g_fn()
536528
np.testing.assert_allclose(samples, np.zeros(10))
537529

538530

539531
def test_random_categorical():
540532
rng = shared(np.random.RandomState(123))
541533
g = at.random.categorical(0.25 * np.ones(4), size=(10000, 4), rng=rng)
542-
g_fn = function([], g, mode=jax_mode)
534+
g_fn = random_function([], g, mode=jax_mode)
543535
samples = g_fn()
544536
np.testing.assert_allclose(samples.mean(axis=0), 6 / 4, 1)
545537

@@ -548,7 +540,7 @@ def test_random_permutation():
548540
array = np.arange(4)
549541
rng = shared(np.random.RandomState(123))
550542
g = at.random.permutation(array, rng=rng)
551-
g_fn = function([], g, mode=jax_mode)
543+
g_fn = random_function([], g, mode=jax_mode)
552544
permuted = g_fn()
553545
with pytest.raises(AssertionError):
554546
np.testing.assert_allclose(array, permuted)
@@ -558,7 +550,7 @@ def test_random_geometric():
558550
rng = shared(np.random.RandomState(123))
559551
p = np.array([0.3, 0.7])
560552
g = at.random.geometric(p, size=(10_000, 2), rng=rng)
561-
g_fn = function([], g, mode=jax_mode)
553+
g_fn = random_function([], g, mode=jax_mode)
562554
samples = g_fn()
563555
np.testing.assert_allclose(samples.mean(axis=0), 1 / p, rtol=0.1)
564556
np.testing.assert_allclose(samples.std(axis=0), np.sqrt((1 - p) / p**2), rtol=0.1)
@@ -569,7 +561,7 @@ def test_negative_binomial():
569561
n = np.array([10, 40])
570562
p = np.array([0.3, 0.7])
571563
g = at.random.negative_binomial(n, p, size=(10_000, 2), rng=rng)
572-
g_fn = function([], g, mode=jax_mode)
564+
g_fn = random_function([], g, mode=jax_mode)
573565
samples = g_fn()
574566
np.testing.assert_allclose(samples.mean(axis=0), n * (1 - p) / p, rtol=0.1)
575567
np.testing.assert_allclose(
@@ -583,7 +575,7 @@ def test_binomial():
583575
n = np.array([10, 40])
584576
p = np.array([0.3, 0.7])
585577
g = at.random.binomial(n, p, size=(10_000, 2), rng=rng)
586-
g_fn = function([], g, mode=jax_mode)
578+
g_fn = random_function([], g, mode=jax_mode)
587579
samples = g_fn()
588580
np.testing.assert_allclose(samples.mean(axis=0), n * p, rtol=0.1)
589581
np.testing.assert_allclose(samples.std(axis=0), np.sqrt(n * p * (1 - p)), rtol=0.1)
@@ -598,7 +590,7 @@ def test_beta_binomial():
598590
a = np.array([1.5, 13])
599591
b = np.array([0.5, 9])
600592
g = at.random.betabinom(n, a, b, size=(10_000, 2), rng=rng)
601-
g_fn = function([], g, mode=jax_mode)
593+
g_fn = random_function([], g, mode=jax_mode)
602594
samples = g_fn()
603595
np.testing.assert_allclose(samples.mean(axis=0), n * a / (a + b), rtol=0.1)
604596
np.testing.assert_allclose(
@@ -616,7 +608,7 @@ def test_multinomial():
616608
n = np.array([10, 40])
617609
p = np.array([[0.3, 0.7, 0.0], [0.1, 0.4, 0.5]])
618610
g = at.random.multinomial(n, p, size=(10_000, 2), rng=rng)
619-
g_fn = function([], g, mode=jax_mode)
611+
g_fn = random_function([], g, mode=jax_mode)
620612
samples = g_fn()
621613
np.testing.assert_allclose(samples.mean(axis=0), n[..., None] * p, rtol=0.1)
622614
np.testing.assert_allclose(
@@ -632,7 +624,7 @@ def test_vonmises_mu_outside_circle():
632624
mu = np.array([-30, 40])
633625
kappa = np.array([100, 10])
634626
g = at.random.vonmises(mu, kappa, size=(10_000, 2), rng=rng)
635-
g_fn = function([], g, mode=jax_mode)
627+
g_fn = random_function([], g, mode=jax_mode)
636628
samples = g_fn()
637629
np.testing.assert_allclose(
638630
samples.mean(axis=0), (mu + np.pi) % (2.0 * np.pi) - np.pi, rtol=0.1
@@ -678,7 +670,10 @@ def rng_fn(cls, rng, size):
678670
fgraph = FunctionGraph([out.owner.inputs[0]], [out], clone=False)
679671

680672
with pytest.raises(NotImplementedError):
681-
compare_jax_and_py(fgraph, [])
673+
with pytest.warns(
674+
UserWarning, match=r"The RandomType SharedVariables \[.+\] will not be used"
675+
):
676+
compare_jax_and_py(fgraph, [])
682677

683678

684679
def test_random_custom_implementation():
@@ -709,7 +704,10 @@ def sample_fn(rng, size, dtype, *parameters):
709704
rng = shared(np.random.RandomState(123))
710705
out = nonexistentrv(rng=rng)
711706
fgraph = FunctionGraph([out.owner.inputs[0]], [out], clone=False)
712-
compare_jax_and_py(fgraph, [])
707+
with pytest.warns(
708+
UserWarning, match=r"The RandomType SharedVariables \[.+\] will not be used"
709+
):
710+
compare_jax_and_py(fgraph, [])
713711

714712

715713
def test_random_concrete_shape():
@@ -726,19 +724,15 @@ def test_random_concrete_shape():
726724
rng = shared(np.random.RandomState(123))
727725
x_at = at.dmatrix()
728726
out = at.random.normal(0, 1, size=x_at.shape, rng=rng)
729-
jax_fn = function([x_at], out, mode=jax_mode)
727+
jax_fn = random_function([x_at], out, mode=jax_mode)
730728
assert jax_fn(np.ones((2, 3))).shape == (2, 3)
731729

732730

733731
def test_random_concrete_shape_from_param():
734732
rng = shared(np.random.RandomState(123))
735733
x_at = at.dmatrix()
736734
out = at.random.normal(x_at, 1, rng=rng)
737-
with pytest.warns(
738-
UserWarning,
739-
match="The RandomType SharedVariables \[.+\] will not be used"
740-
):
741-
jax_fn = function([x_at], out, mode=jax_mode)
735+
jax_fn = random_function([x_at], out, mode=jax_mode)
742736
assert jax_fn(np.ones((2, 3))).shape == (2, 3)
743737

744738

@@ -757,7 +751,7 @@ def test_random_concrete_shape_subtensor():
757751
rng = shared(np.random.RandomState(123))
758752
x_at = at.dmatrix()
759753
out = at.random.normal(0, 1, size=x_at.shape[1], rng=rng)
760-
jax_fn = function([x_at], out, mode=jax_mode)
754+
jax_fn = random_function([x_at], out, mode=jax_mode)
761755
assert jax_fn(np.ones((2, 3))).shape == (3,)
762756

763757

@@ -773,7 +767,7 @@ def test_random_concrete_shape_subtensor_tuple():
773767
rng = shared(np.random.RandomState(123))
774768
x_at = at.dmatrix()
775769
out = at.random.normal(0, 1, size=(x_at.shape[0],), rng=rng)
776-
jax_fn = function([x_at], out, mode=jax_mode)
770+
jax_fn = random_function([x_at], out, mode=jax_mode)
777771
assert jax_fn(np.ones((2, 3))).shape == (2,)
778772

779773

@@ -784,5 +778,5 @@ def test_random_concrete_shape_graph_input():
784778
rng = shared(np.random.RandomState(123))
785779
size_at = at.scalar()
786780
out = at.random.normal(0, 1, size=size_at, rng=rng)
787-
jax_fn = function([size_at], out, mode=jax_mode)
781+
jax_fn = random_function([size_at], out, mode=jax_mode)
788782
assert jax_fn(10).shape == (10,)

0 commit comments

Comments
 (0)