25
25
26
26
27
27
@pytest .mark .parametrize (
28
- "comparison_op, exp_logp_true, exp_logp_false" ,
28
+ "comparison_op, exp_logp_true, exp_logp_false, inputs " ,
29
29
[
30
- ((pt .lt , pt .le ), "logcdf" , "logsf" ),
31
- ((pt .gt , pt .ge ), "logsf" , "logcdf" ),
30
+ ((pt .lt , pt .le ), "logcdf" , "logsf" , (pt .random .normal (0 , 1 ), 0.5 )),
31
+ ((pt .gt , pt .ge ), "logsf" , "logcdf" , (pt .random .normal (0 , 1 ), 0.5 )),
32
+ ((pt .lt , pt .le ), "logsf" , "logcdf" , (0.5 , pt .random .normal (0 , 1 ))),
33
+ ((pt .gt , pt .ge ), "logcdf" , "logsf" , (0.5 , pt .random .normal (0 , 1 ))),
32
34
],
33
35
)
34
- def test_continuous_rv_comparison (comparison_op , exp_logp_true , exp_logp_false ):
35
- x_rv = pt .random .normal (0 , 1 )
36
+ def test_continuous_rv_comparison (comparison_op , exp_logp_true , exp_logp_false , inputs ):
36
37
for op in comparison_op :
37
- comp_x_rv = op (x_rv , 0.5 )
38
+ comp_x_rv = op (* inputs )
38
39
39
40
comp_x_vv = comp_x_rv .clone ()
40
41
@@ -49,33 +50,45 @@ def test_continuous_rv_comparison(comparison_op, exp_logp_true, exp_logp_false):
49
50
50
51
51
52
@pytest .mark .parametrize (
52
- "comparison_op, exp_logp_true, exp_logp_false" ,
53
+ "comparison_op, exp_logp_true, exp_logp_false, inputs " ,
53
54
[
54
55
(
55
56
pt .lt ,
56
57
lambda x : st .poisson (2 ).logcdf (x - 1 ),
57
58
lambda x : np .logaddexp (st .poisson (2 ).logsf (x ), st .poisson (2 ).logpmf (x )),
59
+ (pt .random .poisson (2 ), 3 ),
58
60
),
59
61
(
60
62
pt .ge ,
61
63
lambda x : np .logaddexp (st .poisson (2 ).logsf (x ), st .poisson (2 ).logpmf (x )),
62
64
lambda x : st .poisson (2 ).logcdf (x - 1 ),
65
+ (pt .random .poisson (2 ), 3 ),
63
66
),
67
+ (pt .gt , st .poisson (2 ).logsf , st .poisson (2 ).logcdf , (pt .random .poisson (2 ), 3 )),
68
+ (pt .le , st .poisson (2 ).logcdf , st .poisson (2 ).logsf , (pt .random .poisson (2 ), 3 )),
64
69
(
65
- pt .gt ,
70
+ pt .lt ,
66
71
st .poisson (2 ).logsf ,
67
72
st .poisson (2 ).logcdf ,
73
+ (3 , pt .random .poisson (2 )),
74
+ ),
75
+ (pt .ge , st .poisson (2 ).logcdf , st .poisson (2 ).logsf , (3 , pt .random .poisson (2 ))),
76
+ (
77
+ pt .gt ,
78
+ lambda x : st .poisson (2 ).logcdf (x - 1 ),
79
+ lambda x : np .logaddexp (st .poisson (2 ).logsf (x ), st .poisson (2 ).logpmf (x )),
80
+ (3 , pt .random .poisson (2 )),
68
81
),
69
82
(
70
83
pt .le ,
71
- st .poisson (2 ).logcdf ,
72
- st .poisson (2 ).logsf ,
84
+ lambda x : np .logaddexp (st .poisson (2 ).logsf (x ), st .poisson (2 ).logpmf (x )),
85
+ lambda x : st .poisson (2 ).logcdf (x - 1 ),
86
+ (3 , pt .random .poisson (2 )),
73
87
),
74
88
],
75
89
)
76
- def test_discrete_rv_comparison (comparison_op , exp_logp_true , exp_logp_false ):
77
- x_rv = pt .random .poisson (2 )
78
- cens_x_rv = comparison_op (x_rv , 3 )
90
+ def test_discrete_rv_comparison (inputs , comparison_op , exp_logp_true , exp_logp_false ):
91
+ cens_x_rv = comparison_op (* inputs )
79
92
80
93
cens_x_vv = cens_x_rv .clone ()
81
94
0 commit comments