From 2348e4a4ebaf322813433d5199f92d671bd8ac0f Mon Sep 17 00:00:00 2001 From: Gokul D Date: Fri, 14 Apr 2023 00:41:47 +0530 Subject: [PATCH] Fix numerical precision issues in discrete ICDFs. --- pymc/distributions/discrete.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/pymc/distributions/discrete.py b/pymc/distributions/discrete.py index 8e52c812d1..44eebdf1ab 100644 --- a/pymc/distributions/discrete.py +++ b/pymc/distributions/discrete.py @@ -48,7 +48,7 @@ from pymc.distributions.distribution import Discrete from pymc.distributions.mixture import Mixture from pymc.distributions.shape_utils import rv_size_is_none -from pymc.logprob.basic import logp +from pymc.logprob.basic import logcdf, logp from pymc.math import sigmoid from pymc.pytensorf import floatX, intX from pymc.vartypes import continuous_types @@ -823,6 +823,10 @@ def logcdf(value, p): def icdf(value, p): res = pt.ceil(pt.log1p(-value) / pt.log1p(-p)).astype("int64") + res_1m = pt.maximum(res - 1, 0) + dist = pm.Geometric.dist(p=p) + value_1m = pt.exp(logcdf(dist, res_1m)) + res = pt.switch(value_1m >= value, res_1m, res) res = check_icdf_value(res, value) return check_icdf_parameters( res, @@ -1060,6 +1064,11 @@ def logcdf(value, lower, upper): def icdf(value, lower, upper): res = pt.ceil(value * (upper - lower + 1)).astype("int64") + lower - 1 + res_1m = pt.maximum(res - 1, lower) + dist = pm.DiscreteUniform.dist(lower=lower, upper=upper) + value_1m = pt.exp(logcdf(dist, res_1m)) + res = pt.switch(value_1m >= value, res_1m, res) + res = check_icdf_value(res, value) return check_icdf_parameters( res,