Skip to content

Commit 473c952

Browse files
michaelraczyckiMichal Raczycki
and
Michal Raczycki
authored
Implement icdf for Univariate distribution (#6528)
Also extended testing.check_icdf with skip_paradomain_outside_edge_test param --------- Co-authored-by: Michal Raczycki <michalraczycki@macbook-pro-michal.home>
1 parent 67925df commit 473c952

File tree

3 files changed

+33
-15
lines changed

3 files changed

+33
-15
lines changed

pymc/distributions/continuous.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,11 @@ def logcdf(value, lower, upper):
345345
msg="lower <= upper",
346346
)
347347

348+
def icdf(value, lower, upper):
349+
res = lower + (upper - lower) * value
350+
res = check_icdf_value(res, value)
351+
return check_icdf_parameters(res, lower < upper)
352+
348353

349354
@_default_transform.register(Uniform)
350355
def uniform_default_transform(op, rv):

pymc/testing.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -526,6 +526,7 @@ def check_icdf(
526526
pymc_dist: Distribution,
527527
paramdomains: Dict[str, Domain],
528528
scipy_icdf: Callable,
529+
skip_paramdomain_outside_edge_test=False,
529530
decimal: Optional[int] = None,
530531
n_samples: int = 100,
531532
) -> None:
@@ -548,7 +549,7 @@ def check_icdf(
548549
paramdomains : Dictionary of Parameter : Domain pairs
549550
Supported domains of distribution parameters
550551
scipy_icdf : Scipy icdf method
551-
Scipy icdf (ppp) method of equivalent pymc_dist distribution
552+
Scipy icdf (ppf) method of equivalent pymc_dist distribution
552553
decimal : int, optional
553554
Level of precision with which pymc_dist and scipy_icdf are compared.
554555
Defaults to 6 for float64 and 3 for float32
@@ -557,6 +558,9 @@ def check_icdf(
557558
are compared between pymc and scipy methods. If n_samples is below the
558559
total number of combinations, a random subset is evaluated. Setting
559560
n_samples = -1, will return all possible combinations. Defaults to 100
561+
skip_paradomain_outside_edge_test : Bool
562+
Whether to run test 2., which checks that pymc distribution icdf
563+
returns nan for invalid parameter values outside the supported domain edge
560564
561565
"""
562566
if decimal is None:
@@ -586,19 +590,20 @@ def check_icdf(
586590
valid_params = {param: paramdomain.vals[0] for param, paramdomain in paramdomains.items()}
587591
valid_params["q"] = valid_value
588592

589-
# Test pymc distribution raises ParameterValueError for parameters outside the
590-
# supported domain edges (excluding edges)
591-
invalid_params = find_invalid_scalar_params(paramdomains)
592-
for invalid_param, invalid_edges in invalid_params.items():
593-
for invalid_edge in invalid_edges:
594-
if invalid_edge is None:
595-
continue
593+
if not skip_paramdomain_outside_edge_test:
594+
# Test pymc distribution raises ParameterValueError for parameters outside the
595+
# supported domain edges (excluding edges)
596+
invalid_params = find_invalid_scalar_params(paramdomains)
597+
for invalid_param, invalid_edges in invalid_params.items():
598+
for invalid_edge in invalid_edges:
599+
if invalid_edge is None:
600+
continue
596601

597-
point = valid_params.copy()
598-
point[invalid_param] = invalid_edge
599-
with pytest.raises(ParameterValueError):
600-
pymc_icdf(**point)
601-
pytest.fail(f"test_params={point}")
602+
point = valid_params.copy()
603+
point[invalid_param] = invalid_edge
604+
with pytest.raises(ParameterValueError):
605+
pymc_icdf(**point)
606+
pytest.fail(f"test_params={point}")
602607

603608
# Test that values below 0 or above 1 evaluate to nan
604609
invalid_values = find_invalid_scalar_params({"q": domain})["q"]

tests/distributions/test_continuous.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,9 @@
2626

2727
import pymc as pm
2828

29-
from pymc.distributions.continuous import Normal, get_tau_sigma, interpolated
29+
from pymc.distributions.continuous import Normal, Uniform, get_tau_sigma, interpolated
3030
from pymc.distributions.dist_math import clipped_beta_rvs
31-
from pymc.logprob.abstract import logcdf
31+
from pymc.logprob.abstract import icdf, logcdf
3232
from pymc.logprob.joint_logprob import logp
3333
from pymc.logprob.utils import ParameterValueError
3434
from pymc.pytensorf import floatX
@@ -176,13 +176,21 @@ def test_uniform(self):
176176
lambda value, lower, upper: st.uniform.logcdf(value, lower, upper - lower),
177177
skip_paramdomain_outside_edge_test=True,
178178
)
179+
check_icdf(
180+
pm.Uniform,
181+
{"lower": -Rplusunif, "upper": Rplusunif},
182+
lambda q, lower, upper: st.uniform.ppf(q=q, loc=lower, scale=upper - lower),
183+
skip_paramdomain_outside_edge_test=True,
184+
)
179185
# Custom logp / logcdf check for invalid parameters
180186
invalid_dist = pm.Uniform.dist(lower=1, upper=0)
181187
with pytensor.config.change_flags(mode=Mode("py")):
182188
with pytest.raises(ParameterValueError):
183189
logp(invalid_dist, np.array(0.5)).eval()
184190
with pytest.raises(ParameterValueError):
185191
logcdf(invalid_dist, np.array(0.5)).eval()
192+
with pytest.raises(ParameterValueError):
193+
icdf(invalid_dist, np.array(0.5)).eval()
186194

187195
def test_triangular(self):
188196
check_logp(

0 commit comments

Comments
 (0)