Skip to content

Commit bdd9a65

Browse files
committed
Add icdf functions for Beta, Gamma, Chisquared and Students distributions
removing unnecessary parameterization test on studentT test removing unnecessary test on chisquared test
1 parent a033261 commit bdd9a65

File tree

2 files changed

+57
-1
lines changed

2 files changed

+57
-1
lines changed

pymc/distributions/continuous.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from pytensor.raise_op import Assert
3434
from pytensor.tensor import gammaln
3535
from pytensor.tensor.extra_ops import broadcast_shape
36-
from pytensor.tensor.math import tanh
36+
from pytensor.tensor.math import betaincinv, gammaincinv, tanh
3737
from pytensor.tensor.random.basic import (
3838
BetaRV,
3939
_gamma,
@@ -1227,6 +1227,16 @@ def logcdf(value, alpha, beta):
12271227
msg="alpha > 0, beta > 0",
12281228
)
12291229

1230+
def icdf(value, alpha, beta):
1231+
res = betaincinv(alpha, beta, value)
1232+
res = check_icdf_value(res, value)
1233+
return check_icdf_parameters(
1234+
res,
1235+
alpha > 0,
1236+
beta > 0,
1237+
msg="alpha > 0, beta > 0",
1238+
)
1239+
12301240

12311241
class KumaraswamyRV(RandomVariable):
12321242
name = "kumaraswamy"
@@ -1872,6 +1882,21 @@ def logcdf(value, nu, mu, sigma):
18721882
msg="nu > 0, sigma > 0",
18731883
)
18741884

1885+
def icdf(value, nu, mu, sigma):
1886+
res = pt.switch(
1887+
pt.lt(value, 0.5),
1888+
-pt.sqrt(nu) * pt.sqrt((1.0 / betaincinv(nu * 0.5, 0.5, 2.0 * value)) - 1.0),
1889+
pt.sqrt(nu) * pt.sqrt((1.0 / betaincinv(nu * 0.5, 0.5, 2.0 * (1 - value))) - 1.0),
1890+
)
1891+
res = mu + res * sigma
1892+
res = check_icdf_value(res, value)
1893+
return check_icdf_parameters(
1894+
res,
1895+
nu > 0,
1896+
sigma > 0,
1897+
msg="nu > 0, sigma > 0",
1898+
)
1899+
18751900

18761901
class Pareto(BoundedContinuous):
18771902
r"""
@@ -2272,6 +2297,16 @@ def logcdf(value, alpha, scale):
22722297
)
22732298
return check_parameters(res, 0 < alpha, 0 < beta, msg="alpha > 0, beta > 0")
22742299

2300+
def icdf(value, alpha, scale):
2301+
res = scale * gammaincinv(alpha, value)
2302+
res = check_icdf_value(res, value)
2303+
return check_icdf_parameters(
2304+
res,
2305+
alpha > 0,
2306+
scale > 0,
2307+
msg="alpha > 0, beta > 0",
2308+
)
2309+
22752310

22762311
class InverseGamma(PositiveContinuous):
22772312
r"""

tests/distributions/test_continuous.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -411,6 +411,13 @@ def test_beta_logcdf(self):
411411
lambda value, alpha, beta: st.beta.logcdf(value, alpha, beta),
412412
)
413413

414+
def test_beta_icdf(self):
415+
check_icdf(
416+
pm.Beta,
417+
{"alpha": Rplus, "beta": Rplus},
418+
lambda q, alpha, beta: st.beta.ppf(q, alpha, beta),
419+
)
420+
414421
def test_kumaraswamy(self):
415422
# Scipy does not have a built-in Kumaraswamy
416423
def scipy_log_pdf(value, a, b):
@@ -557,6 +564,13 @@ def test_studentt_logcdf(self):
557564
lambda value, nu, mu, sigma: st.t.logcdf(value, nu, mu, sigma),
558565
)
559566

567+
def test_studentt_icdf(self):
568+
check_icdf(
569+
pm.StudentT,
570+
{"nu": Rplusbig, "mu": R, "sigma": Rplusbig},
571+
lambda q, nu, mu, sigma: st.t.ppf(q, nu, mu, sigma),
572+
)
573+
560574
def test_cauchy(self):
561575
check_logp(
562576
pm.Cauchy,
@@ -623,6 +637,13 @@ def test_gamma_logcdf(self):
623637
lambda value, alpha, beta: st.gamma.logcdf(value, alpha, scale=1.0 / beta),
624638
)
625639

640+
def test_gamma_icdf(self):
641+
check_icdf(
642+
pm.Gamma,
643+
{"alpha": Rplusbig, "beta": Rplusbig},
644+
lambda q, alpha, beta: st.gamma.ppf(q, alpha, scale=1.0 / beta),
645+
)
646+
626647
def test_inverse_gamma_logp(self):
627648
check_logp(
628649
pm.InverseGamma,

0 commit comments

Comments
 (0)