1
- import re
2
-
3
1
import numpy as np
4
2
import pytest
5
3
import scipy .stats as stats
22
20
from pytensor .link .jax .dispatch .random import numpyro_available # noqa: E402
23
21
24
22
23
+ def random_function (* args , ** kwargs ):
24
+ with pytest .warns (
25
+ UserWarning , match = r"The RandomType SharedVariables \[.+\] will not be used"
26
+ ):
27
+ return function (* args , ** kwargs )
28
+
29
+
25
30
def test_random_RandomStream ():
26
31
"""Two successive calls of a compiled graph using `RandomStream` should
27
32
return different values.
@@ -30,11 +35,7 @@ def test_random_RandomStream():
30
35
srng = RandomStream (seed = 123 )
31
36
out = srng .normal () - srng .normal ()
32
37
33
- with pytest .warns (
34
- UserWarning ,
35
- match = r"The RandomType SharedVariables \[.+\] will not be used" ,
36
- ):
37
- fn = function ([], out , mode = jax_mode )
38
+ fn = random_function ([], out , mode = jax_mode )
38
39
jax_res_1 = fn ()
39
40
jax_res_2 = fn ()
40
41
@@ -47,13 +48,7 @@ def test_random_updates(rng_ctor):
47
48
rng = shared (original_value , name = "original_rng" , borrow = False )
48
49
next_rng , x = at .random .normal (name = "x" , rng = rng ).owner .outputs
49
50
50
- with pytest .warns (
51
- UserWarning ,
52
- match = re .escape (
53
- "The RandomType SharedVariables [original_rng] will not be used"
54
- ),
55
- ):
56
- f = pytensor .function ([], [x ], updates = {rng : next_rng }, mode = jax_mode )
51
+ f = random_function ([], [x ], updates = {rng : next_rng }, mode = jax_mode )
57
52
assert f () != f ()
58
53
59
54
# Check that original rng variable content was not overwritten when calling jax_typify
@@ -83,17 +78,14 @@ def test_random_updates_input_storage_order():
83
78
84
79
# This function replaces inp by input_shared in the update expression
85
80
# This is what caused the RNG to appear later than inp_shared in the input_storage
86
- with pytest .warns (
87
- UserWarning ,
88
- match = r"The RandomType SharedVariables \[.+\] will not be used" ,
89
- ):
90
- fn = pytensor .function (
91
- inputs = [],
92
- outputs = [],
93
- updates = {inp_shared : inp_update },
94
- givens = {inp : inp_shared },
95
- mode = "JAX" ,
96
- )
81
+
82
+ fn = random_function (
83
+ inputs = [],
84
+ outputs = [],
85
+ updates = {inp_shared : inp_update },
86
+ givens = {inp : inp_shared },
87
+ mode = "JAX" ,
88
+ )
97
89
fn ()
98
90
np .testing .assert_allclose (inp_shared .get_value (), 5 , rtol = 1e-3 )
99
91
fn ()
@@ -457,7 +449,7 @@ def test_random_RandomVariable(rv_op, dist_params, base_size, cdf_name, params_c
457
449
else :
458
450
rng = shared (np .random .RandomState (29402 ))
459
451
g = rv_op (* dist_params , size = (10_000 ,) + base_size , rng = rng )
460
- g_fn = function (dist_params , g , mode = jax_mode )
452
+ g_fn = random_function (dist_params , g , mode = jax_mode )
461
453
samples = g_fn (
462
454
* [
463
455
i .tag .test_value
@@ -481,7 +473,7 @@ def test_random_RandomVariable(rv_op, dist_params, base_size, cdf_name, params_c
481
473
def test_random_bernoulli (size ):
482
474
rng = shared (np .random .RandomState (123 ))
483
475
g = at .random .bernoulli (0.5 , size = (1000 ,) + size , rng = rng )
484
- g_fn = function ([], g , mode = jax_mode )
476
+ g_fn = random_function ([], g , mode = jax_mode )
485
477
samples = g_fn ()
486
478
np .testing .assert_allclose (samples .mean (axis = 0 ), 0.5 , 1 )
487
479
@@ -492,7 +484,7 @@ def test_random_mvnormal():
492
484
mu = np .ones (4 )
493
485
cov = np .eye (4 )
494
486
g = at .random .multivariate_normal (mu , cov , size = (10000 ,), rng = rng )
495
- g_fn = function ([], g , mode = jax_mode )
487
+ g_fn = random_function ([], g , mode = jax_mode )
496
488
samples = g_fn ()
497
489
np .testing .assert_allclose (samples .mean (axis = 0 ), mu , atol = 0.1 )
498
490
@@ -507,7 +499,7 @@ def test_random_mvnormal():
507
499
def test_random_dirichlet (parameter , size ):
508
500
rng = shared (np .random .RandomState (123 ))
509
501
g = at .random .dirichlet (parameter , size = (1000 ,) + size , rng = rng )
510
- g_fn = function ([], g , mode = jax_mode )
502
+ g_fn = random_function ([], g , mode = jax_mode )
511
503
samples = g_fn ()
512
504
np .testing .assert_allclose (samples .mean (axis = 0 ), 0.5 , 1 )
513
505
@@ -517,29 +509,29 @@ def test_random_choice():
517
509
num_samples = 10000
518
510
rng = shared (np .random .RandomState (123 ))
519
511
g = at .random .choice (np .arange (4 ), size = num_samples , rng = rng )
520
- g_fn = function ([], g , mode = jax_mode )
512
+ g_fn = random_function ([], g , mode = jax_mode )
521
513
samples = g_fn ()
522
514
np .testing .assert_allclose (np .sum (samples == 3 ) / num_samples , 0.25 , 2 )
523
515
524
516
# `replace=False` produces unique results
525
517
rng = shared (np .random .RandomState (123 ))
526
518
g = at .random .choice (np .arange (100 ), replace = False , size = 99 , rng = rng )
527
- g_fn = function ([], g , mode = jax_mode )
519
+ g_fn = random_function ([], g , mode = jax_mode )
528
520
samples = g_fn ()
529
521
assert len (np .unique (samples )) == 99
530
522
531
523
# We can pass an array with probabilities
532
524
rng = shared (np .random .RandomState (123 ))
533
525
g = at .random .choice (np .arange (3 ), p = np .array ([1.0 , 0.0 , 0.0 ]), size = 10 , rng = rng )
534
- g_fn = function ([], g , mode = jax_mode )
526
+ g_fn = random_function ([], g , mode = jax_mode )
535
527
samples = g_fn ()
536
528
np .testing .assert_allclose (samples , np .zeros (10 ))
537
529
538
530
539
531
def test_random_categorical ():
540
532
rng = shared (np .random .RandomState (123 ))
541
533
g = at .random .categorical (0.25 * np .ones (4 ), size = (10000 , 4 ), rng = rng )
542
- g_fn = function ([], g , mode = jax_mode )
534
+ g_fn = random_function ([], g , mode = jax_mode )
543
535
samples = g_fn ()
544
536
np .testing .assert_allclose (samples .mean (axis = 0 ), 6 / 4 , 1 )
545
537
@@ -548,7 +540,7 @@ def test_random_permutation():
548
540
array = np .arange (4 )
549
541
rng = shared (np .random .RandomState (123 ))
550
542
g = at .random .permutation (array , rng = rng )
551
- g_fn = function ([], g , mode = jax_mode )
543
+ g_fn = random_function ([], g , mode = jax_mode )
552
544
permuted = g_fn ()
553
545
with pytest .raises (AssertionError ):
554
546
np .testing .assert_allclose (array , permuted )
@@ -558,7 +550,7 @@ def test_random_geometric():
558
550
rng = shared (np .random .RandomState (123 ))
559
551
p = np .array ([0.3 , 0.7 ])
560
552
g = at .random .geometric (p , size = (10_000 , 2 ), rng = rng )
561
- g_fn = function ([], g , mode = jax_mode )
553
+ g_fn = random_function ([], g , mode = jax_mode )
562
554
samples = g_fn ()
563
555
np .testing .assert_allclose (samples .mean (axis = 0 ), 1 / p , rtol = 0.1 )
564
556
np .testing .assert_allclose (samples .std (axis = 0 ), np .sqrt ((1 - p ) / p ** 2 ), rtol = 0.1 )
@@ -569,7 +561,7 @@ def test_negative_binomial():
569
561
n = np .array ([10 , 40 ])
570
562
p = np .array ([0.3 , 0.7 ])
571
563
g = at .random .negative_binomial (n , p , size = (10_000 , 2 ), rng = rng )
572
- g_fn = function ([], g , mode = jax_mode )
564
+ g_fn = random_function ([], g , mode = jax_mode )
573
565
samples = g_fn ()
574
566
np .testing .assert_allclose (samples .mean (axis = 0 ), n * (1 - p ) / p , rtol = 0.1 )
575
567
np .testing .assert_allclose (
@@ -583,7 +575,7 @@ def test_binomial():
583
575
n = np .array ([10 , 40 ])
584
576
p = np .array ([0.3 , 0.7 ])
585
577
g = at .random .binomial (n , p , size = (10_000 , 2 ), rng = rng )
586
- g_fn = function ([], g , mode = jax_mode )
578
+ g_fn = random_function ([], g , mode = jax_mode )
587
579
samples = g_fn ()
588
580
np .testing .assert_allclose (samples .mean (axis = 0 ), n * p , rtol = 0.1 )
589
581
np .testing .assert_allclose (samples .std (axis = 0 ), np .sqrt (n * p * (1 - p )), rtol = 0.1 )
@@ -598,7 +590,7 @@ def test_beta_binomial():
598
590
a = np .array ([1.5 , 13 ])
599
591
b = np .array ([0.5 , 9 ])
600
592
g = at .random .betabinom (n , a , b , size = (10_000 , 2 ), rng = rng )
601
- g_fn = function ([], g , mode = jax_mode )
593
+ g_fn = random_function ([], g , mode = jax_mode )
602
594
samples = g_fn ()
603
595
np .testing .assert_allclose (samples .mean (axis = 0 ), n * a / (a + b ), rtol = 0.1 )
604
596
np .testing .assert_allclose (
@@ -616,7 +608,7 @@ def test_multinomial():
616
608
n = np .array ([10 , 40 ])
617
609
p = np .array ([[0.3 , 0.7 , 0.0 ], [0.1 , 0.4 , 0.5 ]])
618
610
g = at .random .multinomial (n , p , size = (10_000 , 2 ), rng = rng )
619
- g_fn = function ([], g , mode = jax_mode )
611
+ g_fn = random_function ([], g , mode = jax_mode )
620
612
samples = g_fn ()
621
613
np .testing .assert_allclose (samples .mean (axis = 0 ), n [..., None ] * p , rtol = 0.1 )
622
614
np .testing .assert_allclose (
@@ -632,7 +624,7 @@ def test_vonmises_mu_outside_circle():
632
624
mu = np .array ([- 30 , 40 ])
633
625
kappa = np .array ([100 , 10 ])
634
626
g = at .random .vonmises (mu , kappa , size = (10_000 , 2 ), rng = rng )
635
- g_fn = function ([], g , mode = jax_mode )
627
+ g_fn = random_function ([], g , mode = jax_mode )
636
628
samples = g_fn ()
637
629
np .testing .assert_allclose (
638
630
samples .mean (axis = 0 ), (mu + np .pi ) % (2.0 * np .pi ) - np .pi , rtol = 0.1
@@ -678,7 +670,10 @@ def rng_fn(cls, rng, size):
678
670
fgraph = FunctionGraph ([out .owner .inputs [0 ]], [out ], clone = False )
679
671
680
672
with pytest .raises (NotImplementedError ):
681
- compare_jax_and_py (fgraph , [])
673
+ with pytest .warns (
674
+ UserWarning , match = r"The RandomType SharedVariables \[.+\] will not be used"
675
+ ):
676
+ compare_jax_and_py (fgraph , [])
682
677
683
678
684
679
def test_random_custom_implementation ():
@@ -709,7 +704,10 @@ def sample_fn(rng, size, dtype, *parameters):
709
704
rng = shared (np .random .RandomState (123 ))
710
705
out = nonexistentrv (rng = rng )
711
706
fgraph = FunctionGraph ([out .owner .inputs [0 ]], [out ], clone = False )
712
- compare_jax_and_py (fgraph , [])
707
+ with pytest .warns (
708
+ UserWarning , match = r"The RandomType SharedVariables \[.+\] will not be used"
709
+ ):
710
+ compare_jax_and_py (fgraph , [])
713
711
714
712
715
713
def test_random_concrete_shape ():
@@ -726,19 +724,15 @@ def test_random_concrete_shape():
726
724
rng = shared (np .random .RandomState (123 ))
727
725
x_at = at .dmatrix ()
728
726
out = at .random .normal (0 , 1 , size = x_at .shape , rng = rng )
729
- jax_fn = function ([x_at ], out , mode = jax_mode )
727
+ jax_fn = random_function ([x_at ], out , mode = jax_mode )
730
728
assert jax_fn (np .ones ((2 , 3 ))).shape == (2 , 3 )
731
729
732
730
733
731
def test_random_concrete_shape_from_param ():
734
732
rng = shared (np .random .RandomState (123 ))
735
733
x_at = at .dmatrix ()
736
734
out = at .random .normal (x_at , 1 , rng = rng )
737
- with pytest .warns (
738
- UserWarning ,
739
- match = "The RandomType SharedVariables \[.+\] will not be used"
740
- ):
741
- jax_fn = function ([x_at ], out , mode = jax_mode )
735
+ jax_fn = random_function ([x_at ], out , mode = jax_mode )
742
736
assert jax_fn (np .ones ((2 , 3 ))).shape == (2 , 3 )
743
737
744
738
@@ -757,7 +751,7 @@ def test_random_concrete_shape_subtensor():
757
751
rng = shared (np .random .RandomState (123 ))
758
752
x_at = at .dmatrix ()
759
753
out = at .random .normal (0 , 1 , size = x_at .shape [1 ], rng = rng )
760
- jax_fn = function ([x_at ], out , mode = jax_mode )
754
+ jax_fn = random_function ([x_at ], out , mode = jax_mode )
761
755
assert jax_fn (np .ones ((2 , 3 ))).shape == (3 ,)
762
756
763
757
@@ -773,7 +767,7 @@ def test_random_concrete_shape_subtensor_tuple():
773
767
rng = shared (np .random .RandomState (123 ))
774
768
x_at = at .dmatrix ()
775
769
out = at .random .normal (0 , 1 , size = (x_at .shape [0 ],), rng = rng )
776
- jax_fn = function ([x_at ], out , mode = jax_mode )
770
+ jax_fn = random_function ([x_at ], out , mode = jax_mode )
777
771
assert jax_fn (np .ones ((2 , 3 ))).shape == (2 ,)
778
772
779
773
@@ -784,5 +778,5 @@ def test_random_concrete_shape_graph_input():
784
778
rng = shared (np .random .RandomState (123 ))
785
779
size_at = at .scalar ()
786
780
out = at .random .normal (0 , 1 , size = size_at , rng = rng )
787
- jax_fn = function ([size_at ], out , mode = jax_mode )
781
+ jax_fn = random_function ([size_at ], out , mode = jax_mode )
788
782
assert jax_fn (10 ).shape == (10 ,)
0 commit comments