Skip to content

Commit 2fcce43

Browse files
gokuldricardoV94
authored andcommitted
Added ICDF for the discrete uniform distribution.
1 parent 067d89b commit 2fcce43

File tree

2 files changed

+18
-1
lines changed

2 files changed

+18
-1
lines changed

pymc/distributions/discrete.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1058,6 +1058,15 @@ def logcdf(value, lower, upper):
10581058
msg="lower <= upper",
10591059
)
10601060

1061+
def icdf(value, lower, upper):
1062+
res = pt.ceil(value * (upper - lower + 1)).astype("int64") + lower - 1
1063+
res = check_icdf_value(res, value)
1064+
return check_icdf_parameters(
1065+
res,
1066+
lower <= upper,
1067+
msg="lower <= upper",
1068+
)
1069+
10611070

10621071
class Categorical(Discrete):
10631072
R"""

tests/distributions/test_discrete.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
import pymc as pm
3030

3131
from pymc.distributions.discrete import Geometric, _OrderedLogistic, _OrderedProbit
32-
from pymc.logprob.abstract import logcdf
32+
from pymc.logprob.abstract import icdf, logcdf
3333
from pymc.logprob.joint_logprob import logp
3434
from pymc.logprob.utils import ParameterValueError
3535
from pymc.pytensorf import floatX
@@ -118,13 +118,21 @@ def test_discrete_unif(self):
118118
Domain([-10, 0, 10], "int64"),
119119
{"lower": -Rplusdunif, "upper": Rplusdunif},
120120
)
121+
check_icdf(
122+
pm.DiscreteUniform,
123+
{"lower": -Rplusdunif, "upper": Rplusdunif},
124+
lambda q, lower, upper: st.randint.ppf(q=q, low=lower, high=upper + 1),
125+
skip_paramdomain_outside_edge_test=True,
126+
)
121127
# Custom logp / logcdf check for invalid parameters
122128
invalid_dist = pm.DiscreteUniform.dist(lower=1, upper=0)
123129
with pytensor.config.change_flags(mode=Mode("py")):
124130
with pytest.raises(ParameterValueError):
125131
logp(invalid_dist, 0.5).eval()
126132
with pytest.raises(ParameterValueError):
127133
logcdf(invalid_dist, 2).eval()
134+
with pytest.raises(ParameterValueError):
135+
icdf(invalid_dist, np.array(1)).eval()
128136

129137
def test_geometric(self):
130138
check_logp(

0 commit comments

Comments
 (0)