Skip to content

Commit a108051

Browse files
committed
Do not test invalid parameter combinations
1 parent 457421b commit a108051

File tree

2 files changed

+28
-29
lines changed

2 files changed

+28
-29
lines changed

tests/distributions/test_continuous.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -159,14 +159,6 @@ def laplace_asymmetric_logpdf(value, kappa, b, mu):
159159
return lPx
160160

161161

162-
def beta_mu_sigma(value, mu, sigma):
163-
kappa = mu * (1 - mu) / sigma**2 - 1
164-
if kappa > 0:
165-
return st.beta.logpdf(value, mu * kappa, (1 - mu) * kappa)
166-
else:
167-
return -np.inf
168-
169-
170162
class TestMatchesScipy:
171163
def test_uniform(self):
172164
check_logp(
@@ -367,10 +359,18 @@ def test_beta_logp(self):
367359
{"alpha": Rplus, "beta": Rplus},
368360
lambda value, alpha, beta: st.beta.logpdf(value, alpha, beta),
369361
)
362+
363+
def beta_mu_sigma(value, mu, sigma):
364+
kappa = mu * (1 - mu) / sigma**2 - 1
365+
return st.beta.logpdf(value, mu * kappa, (1 - mu) * kappa)
366+
367+
# The mu/sigma parametrization is not always valid
368+
safe_mu_domain = Domain([0, 0.3, 0.5, 0.8, 1])
369+
safe_sigma_domain = Domain([0, 0.05, 0.1, np.inf])
370370
check_logp(
371371
pm.Beta,
372372
Unit,
373-
{"mu": Unit, "sigma": Rplus},
373+
{"mu": safe_mu_domain, "sigma": safe_sigma_domain},
374374
beta_mu_sigma,
375375
)
376376

tests/distributions/test_discrete.py

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -145,13 +145,7 @@ def test_geometric(self):
145145
)
146146

147147
def test_hypergeometric(self):
148-
def modified_scipy_hypergeom_logpmf(value, N, k, n):
149-
# Convert nan to -np.inf
150-
original_res = st.hypergeom.logpmf(value, N, k, n)
151-
return original_res if not np.isnan(original_res) else -np.inf
152-
153148
def modified_scipy_hypergeom_logcdf(value, N, k, n):
154-
# Convert nan to -np.inf
155149
original_res = st.hypergeom.logcdf(value, N, k, n)
156150

157151
# Correct for scipy bug in logcdf method (see https://github.com/scipy/scipy/issues/13280)
@@ -160,24 +154,27 @@ def modified_scipy_hypergeom_logcdf(value, N, k, n):
160154
if np.all(np.isnan(pmfs)):
161155
original_res = np.nan
162156

163-
return original_res if not np.isnan(original_res) else -np.inf
157+
return original_res
158+
159+
N_domain = Domain([0, 10, 20, 30, np.inf], dtype="int64")
160+
n_domain = k_domain = Domain([0, 1, 2, 3, np.inf], dtype="int64")
164161

165162
check_logp(
166163
pm.HyperGeometric,
167164
Nat,
168-
{"N": NatSmall, "k": NatSmall, "n": NatSmall},
169-
modified_scipy_hypergeom_logpmf,
165+
{"N": N_domain, "k": k_domain, "n": n_domain},
166+
lambda value, N, k, n: st.hypergeom.logpmf(value, N, k, n),
170167
)
171168
check_logcdf(
172169
pm.HyperGeometric,
173170
Nat,
174-
{"N": NatSmall, "k": NatSmall, "n": NatSmall},
171+
{"N": N_domain, "k": k_domain, "n": n_domain},
175172
modified_scipy_hypergeom_logcdf,
176173
)
177174
check_selfconsistency_discrete_logcdf(
178175
pm.HyperGeometric,
179176
Nat,
180-
{"N": NatSmall, "k": NatSmall, "n": NatSmall},
177+
{"N": N_domain, "k": k_domain, "n": n_domain},
181178
)
182179

183180
@pytest.mark.xfail(
@@ -535,15 +532,17 @@ def test_categorical_p_not_normalized_symbolic(self):
535532

536533
@pytest.mark.parametrize("n", [2, 3, 4])
537534
def test_orderedlogistic(self, n):
538-
with warnings.catch_warnings():
539-
warnings.filterwarnings("ignore", "invalid value encountered in log", RuntimeWarning)
540-
warnings.filterwarnings("ignore", "divide by zero encountered in log", RuntimeWarning)
541-
check_logp(
542-
pm.OrderedLogistic,
543-
Domain(range(n), dtype="int64", edges=(None, None)),
544-
{"eta": R, "cutpoints": Vector(R, n - 1)},
545-
lambda value, eta, cutpoints: orderedlogistic_logpdf(value, eta, cutpoints),
546-
)
535+
cutpoints_domain = Vector(R, n - 1)
536+
# Filter out invalid non-monotonic values
537+
cutpoints_domain.vals = [v for v in cutpoints_domain.vals if np.all(np.diff(v) > 0)]
538+
assert len(cutpoints_domain.vals) > 0
539+
540+
check_logp(
541+
pm.OrderedLogistic,
542+
Domain(range(n), dtype="int64", edges=(None, None)),
543+
{"eta": R, "cutpoints": cutpoints_domain},
544+
lambda value, eta, cutpoints: orderedlogistic_logpdf(value, eta, cutpoints),
545+
)
547546

548547
@pytest.mark.parametrize("n", [2, 3, 4])
549548
def test_orderedprobit(self, n):

0 commit comments

Comments
 (0)