diff --git a/pymc/distributions/discrete.py b/pymc/distributions/discrete.py index 8e52c812d1..44eebdf1ab 100644 --- a/pymc/distributions/discrete.py +++ b/pymc/distributions/discrete.py @@ -48,7 +48,7 @@ from pymc.distributions.distribution import Discrete from pymc.distributions.mixture import Mixture from pymc.distributions.shape_utils import rv_size_is_none -from pymc.logprob.basic import logp +from pymc.logprob.basic import logcdf, logp from pymc.math import sigmoid from pymc.pytensorf import floatX, intX from pymc.vartypes import continuous_types @@ -823,6 +823,10 @@ def logcdf(value, p): def icdf(value, p): res = pt.ceil(pt.log1p(-value) / pt.log1p(-p)).astype("int64") + res_1m = pt.maximum(res - 1, 0) + dist = pm.Geometric.dist(p=p) + value_1m = pt.exp(logcdf(dist, res_1m)) + res = pt.switch(value_1m >= value, res_1m, res) res = check_icdf_value(res, value) return check_icdf_parameters( res, @@ -1060,6 +1064,11 @@ def logcdf(value, lower, upper): def icdf(value, lower, upper): res = pt.ceil(value * (upper - lower + 1)).astype("int64") + lower - 1 + res_1m = pt.maximum(res - 1, lower) + dist = pm.DiscreteUniform.dist(lower=lower, upper=upper) + value_1m = pt.exp(logcdf(dist, res_1m)) + res = pt.switch(value_1m >= value, res_1m, res) + res = check_icdf_value(res, value) return check_icdf_parameters( res,