Skip to content

Commit 63641a7

Browse files
adapted uniform.icdf tests to run on new test suite, extended testing.py check_icdf with skip_paradomain_outside_edge_test param
1 parent ecbb89a commit 63641a7

File tree

3 files changed

+27
-73
lines changed

3 files changed

+27
-73
lines changed

pymc/distributions/continuous.py

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -739,27 +739,6 @@ def logp(value, mu, sigma, lower, upper):
739739

740740
return logp
741741

742-
def icdf(value, mu, sigma, lower, upper):
743-
value = np.asarray(value)
744-
eps = np.finfo(float).eps
745-
746-
# Standardize the lower and upper bounds
747-
lower_std = (lower - mu) / sigma
748-
upper_std = (upper - mu) / sigma
749-
750-
# Compute the cdf of the standard normal distribution at the bounds
751-
Phi_a = 0.5 * (1 + at.erf(lower_std / at.sqrt(2)))
752-
Phi_b = 0.5 * (1 + at.erf(upper_std / at.sqrt(2)))
753-
754-
# Compute the cdf of the truncated normal distribution at the quantiles
755-
Phi_q = Phi_a + value * (Phi_b - Phi_a)
756-
757-
# Invert the cdf of the standard normal distribution at the quantiles
758-
z = at.erfinv(2 * Phi_q - 1) * np.sqrt(2)
759-
760-
# Transform the samples back to the truncated normal distribution
761-
return mu + sigma * at.clip(z, lower_std + eps, upper_std - eps)
762-
763742

764743
@_default_transform.register(TruncatedNormal)
765744
def truncated_normal_default_transform(op, rv):

pymc/testing.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -526,6 +526,7 @@ def check_icdf(
526526
pymc_dist: Distribution,
527527
paramdomains: Dict[str, Domain],
528528
scipy_icdf: Callable,
529+
skip_paramdomain_outside_edge_test=False,
529530
decimal: Optional[int] = None,
530531
n_samples: int = 100,
531532
) -> None:
@@ -548,7 +549,7 @@ def check_icdf(
548549
paramdomains : Dictionary of Parameter : Domain pairs
549550
Supported domains of distribution parameters
550551
scipy_icdf : Scipy icdf method
551-
Scipy icdf (ppp) method of equivalent pymc_dist distribution
552+
Scipy icdf (ppf) method of equivalent pymc_dist distribution
552553
decimal : int, optional
553554
Level of precision with which pymc_dist and scipy_icdf are compared.
554555
Defaults to 6 for float64 and 3 for float32
@@ -557,6 +558,9 @@ def check_icdf(
557558
are compared between pymc and scipy methods. If n_samples is below the
558559
total number of combinations, a random subset is evaluated. Setting
559560
n_samples = -1, will return all possible combinations. Defaults to 100
561+
skip_paradomain_outside_edge_test : Bool
562+
Whether to run test 2., which checks that pymc distribution icdf
563+
returns nan for invalid parameter values outside the supported domain edge
560564
561565
"""
562566
if decimal is None:
@@ -586,19 +590,20 @@ def check_icdf(
586590
valid_params = {param: paramdomain.vals[0] for param, paramdomain in paramdomains.items()}
587591
valid_params["q"] = valid_value
588592

589-
# Test pymc distribution raises ParameterValueError for parameters outside the
590-
# supported domain edges (excluding edges)
591-
invalid_params = find_invalid_scalar_params(paramdomains)
592-
for invalid_param, invalid_edges in invalid_params.items():
593-
for invalid_edge in invalid_edges:
594-
if invalid_edge is None:
595-
continue
593+
if not skip_paramdomain_outside_edge_test:
594+
# Test pymc distribution raises ParameterValueError for parameters outside the
595+
# supported domain edges (excluding edges)
596+
invalid_params = find_invalid_scalar_params(paramdomains)
597+
for invalid_param, invalid_edges in invalid_params.items():
598+
for invalid_edge in invalid_edges:
599+
if invalid_edge is None:
600+
continue
596601

597-
point = valid_params.copy()
598-
point[invalid_param] = invalid_edge
599-
with pytest.raises(ParameterValueError):
600-
pymc_icdf(**point)
601-
pytest.fail(f"test_params={point}")
602+
point = valid_params.copy()
603+
point[invalid_param] = invalid_edge
604+
with pytest.raises(ParameterValueError):
605+
pymc_icdf(**point)
606+
pytest.fail(f"test_params={point}")
602607

603608
# Test that values below 0 or above 1 evaluate to nan
604609
invalid_values = find_invalid_scalar_params({"q": domain})["q"]

tests/distributions/test_continuous.py

Lines changed: 9 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828

2929
from pymc.distributions.continuous import Normal, Uniform, get_tau_sigma, interpolated
3030
from pymc.distributions.dist_math import clipped_beta_rvs
31-
from pymc.logprob.abstract import logcdf
31+
from pymc.logprob.abstract import icdf, logcdf
3232
from pymc.logprob.joint_logprob import logp
3333
from pymc.logprob.utils import ParameterValueError
3434
from pymc.pytensorf import floatX
@@ -176,13 +176,21 @@ def test_uniform(self):
176176
lambda value, lower, upper: st.uniform.logcdf(value, lower, upper - lower),
177177
skip_paramdomain_outside_edge_test=True,
178178
)
179+
check_icdf(
180+
pm.Uniform,
181+
{"lower": -Rplusunif, "upper": Rplusunif},
182+
lambda q, lower, upper: st.uniform.ppf(q=q, loc=lower, scale=upper - lower),
183+
skip_paramdomain_outside_edge_test=True,
184+
)
179185
# Custom logp / logcdf check for invalid parameters
180186
invalid_dist = pm.Uniform.dist(lower=1, upper=0)
181187
with pytensor.config.change_flags(mode=Mode("py")):
182188
with pytest.raises(ParameterValueError):
183189
logp(invalid_dist, np.array(0.5)).eval()
184190
with pytest.raises(ParameterValueError):
185191
logcdf(invalid_dist, np.array(0.5)).eval()
192+
with pytest.raises(ParameterValueError):
193+
icdf(invalid_dist, np.array(0.5)).eval()
186194

187195
def test_triangular(self):
188196
check_logp(
@@ -2277,41 +2285,3 @@ def dist(cls, **kwargs):
22772285
extra_args={"rng": pytensor.shared(rng)},
22782286
ref_rand=ref_rand,
22792287
)
2280-
2281-
2282-
class TestICDF:
2283-
@pytest.mark.parametrize(
2284-
"dist_params, obs, size",
2285-
[
2286-
((0, 1), np.array([-0.5, 0, 0.3, 0.5, 1, 1.5], dtype=np.float64), ()),
2287-
((-1, 20), np.array([-0.5, 0, 0.3, 0.5, 1, 1.5], dtype=np.float64), ()),
2288-
((-1, 20), np.array([-0.5, 0, 0.3, 0.5, 1, 1.5], dtype=np.float64), (2, 3)),
2289-
],
2290-
)
2291-
def test_normal_icdf(self, dist_params, obs, size):
2292-
dist_params_at, obs_at, size_at = create_pytensor_params(dist_params, obs, size)
2293-
dist_params = dict(zip(dist_params_at, dist_params))
2294-
2295-
x = Normal.dist(*dist_params_at, size=size_at)
2296-
2297-
scipy_logprob_tester(x, obs, dist_params, test_fn=st.norm.ppf, test="icdf")
2298-
2299-
@pytest.mark.parametrize(
2300-
"dist_params, obs, size",
2301-
[
2302-
((-5, 4), np.array([0.0, 0.2, 0.4, 0.6, 0.8, 1.0], dtype=np.float64), ()),
2303-
((-1, 2), np.array([0.0, 0.2, 0.4, 0.6, 0.8, 1.0], dtype=np.float64), ()),
2304-
((0, 10), np.array([0.0, 0.2, 0.4, 0.6, 0.8, 1.0], dtype=np.float64), ()),
2305-
],
2306-
)
2307-
def test_uniform_icdf(self, dist_params, obs, size):
2308-
dist_params_at, obs_at, size_at = create_pytensor_params(dist_params, obs, size)
2309-
dist_params = dict(zip(dist_params_at, dist_params))
2310-
x = Uniform.dist(*dist_params_at)
2311-
scipy_logprob_tester(
2312-
x,
2313-
obs,
2314-
dist_params,
2315-
test_fn=lambda val, lower, upper: st.uniform.ppf(val, lower, upper - lower),
2316-
test="icdf",
2317-
)

0 commit comments

Comments
 (0)