@@ -145,13 +145,7 @@ def test_geometric(self):
145
145
)
146
146
147
147
def test_hypergeometric (self ):
148
- def modified_scipy_hypergeom_logpmf (value , N , k , n ):
149
- # Convert nan to -np.inf
150
- original_res = st .hypergeom .logpmf (value , N , k , n )
151
- return original_res if not np .isnan (original_res ) else - np .inf
152
-
153
148
def modified_scipy_hypergeom_logcdf (value , N , k , n ):
154
- # Convert nan to -np.inf
155
149
original_res = st .hypergeom .logcdf (value , N , k , n )
156
150
157
151
# Correct for scipy bug in logcdf method (see https://github.com/scipy/scipy/issues/13280)
@@ -160,24 +154,27 @@ def modified_scipy_hypergeom_logcdf(value, N, k, n):
160
154
if np .all (np .isnan (pmfs )):
161
155
original_res = np .nan
162
156
163
- return original_res if not np .isnan (original_res ) else - np .inf
157
+ return original_res
158
+
159
+ N_domain = Domain ([0 , 10 , 20 , 30 , np .inf ], dtype = "int64" )
160
+ n_domain = k_domain = Domain ([0 , 1 , 2 , 3 , np .inf ], dtype = "int64" )
164
161
165
162
check_logp (
166
163
pm .HyperGeometric ,
167
164
Nat ,
168
- {"N" : NatSmall , "k" : NatSmall , "n" : NatSmall },
169
- modified_scipy_hypergeom_logpmf ,
165
+ {"N" : N_domain , "k" : k_domain , "n" : n_domain },
166
+ lambda value , N , k , n : st . hypergeom . logpmf ( value , N , k , n ) ,
170
167
)
171
168
check_logcdf (
172
169
pm .HyperGeometric ,
173
170
Nat ,
174
- {"N" : NatSmall , "k" : NatSmall , "n" : NatSmall },
171
+ {"N" : N_domain , "k" : k_domain , "n" : n_domain },
175
172
modified_scipy_hypergeom_logcdf ,
176
173
)
177
174
check_selfconsistency_discrete_logcdf (
178
175
pm .HyperGeometric ,
179
176
Nat ,
180
- {"N" : NatSmall , "k" : NatSmall , "n" : NatSmall },
177
+ {"N" : N_domain , "k" : k_domain , "n" : n_domain },
181
178
)
182
179
183
180
@pytest .mark .xfail (
@@ -535,15 +532,17 @@ def test_categorical_p_not_normalized_symbolic(self):
535
532
536
533
@pytest .mark .parametrize ("n" , [2 , 3 , 4 ])
537
534
def test_orderedlogistic (self , n ):
538
- with warnings .catch_warnings ():
539
- warnings .filterwarnings ("ignore" , "invalid value encountered in log" , RuntimeWarning )
540
- warnings .filterwarnings ("ignore" , "divide by zero encountered in log" , RuntimeWarning )
541
- check_logp (
542
- pm .OrderedLogistic ,
543
- Domain (range (n ), dtype = "int64" , edges = (None , None )),
544
- {"eta" : R , "cutpoints" : Vector (R , n - 1 )},
545
- lambda value , eta , cutpoints : orderedlogistic_logpdf (value , eta , cutpoints ),
546
- )
535
+ cutpoints_domain = Vector (R , n - 1 )
536
+ # Filter out invalid non-monotonic values
537
+ cutpoints_domain .vals = [v for v in cutpoints_domain .vals if np .all (np .diff (v ) > 0 )]
538
+ assert len (cutpoints_domain .vals ) > 0
539
+
540
+ check_logp (
541
+ pm .OrderedLogistic ,
542
+ Domain (range (n ), dtype = "int64" , edges = (None , None )),
543
+ {"eta" : R , "cutpoints" : cutpoints_domain },
544
+ lambda value , eta , cutpoints : orderedlogistic_logpdf (value , eta , cutpoints ),
545
+ )
547
546
548
547
@pytest .mark .parametrize ("n" , [2 , 3 , 4 ])
549
548
def test_orderedprobit (self , n ):
0 commit comments