Skip to content

Commit 910d9ef

Browse files
authored
Add logcdf implementation for Truncated distributions (#6690)
1 parent 9f01be2 commit 910d9ef

File tree

2 files changed

+105
-1
lines changed

2 files changed

+105
-1
lines changed

pymc/distributions/truncated.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -352,6 +352,50 @@ def truncated_logprob(op, values, *inputs, **kwargs):
352352
return logp
353353

354354

355+
@_logcdf.register(TruncatedRV)
356+
def truncated_logcdf(op, value, *inputs, **kwargs):
357+
*rv_inputs, lower, upper, rng = inputs
358+
rv_inputs = [rng, *rv_inputs]
359+
360+
base_rv_op = op.base_rv_op
361+
logcdf = _logcdf(base_rv_op, value, *rv_inputs, **kwargs)
362+
363+
# For left truncated discrete RVs, we don't want to include the lower bound in the
364+
# normalization term
365+
lower_value = lower - 1 if base_rv_op.dtype.startswith("int") else lower
366+
lower_logcdf = _logcdf(base_rv_op, lower_value, *rv_inputs, **kwargs)
367+
upper_logcdf = _logcdf(base_rv_op, upper, *rv_inputs, **kwargs)
368+
369+
is_lower_bounded = not (isinstance(lower, TensorConstant) and np.all(np.isneginf(lower.value)))
370+
is_upper_bounded = not (isinstance(upper, TensorConstant) and np.all(np.isinf(upper.value)))
371+
372+
lognorm = 0
373+
if is_lower_bounded and is_upper_bounded:
374+
lognorm = logdiffexp(upper_logcdf, lower_logcdf)
375+
elif is_lower_bounded:
376+
lognorm = pt.log1mexp(lower_logcdf)
377+
elif is_upper_bounded:
378+
lognorm = upper_logcdf
379+
380+
logcdf_numerator = logdiffexp(logcdf, lower_logcdf) if is_lower_bounded else logcdf
381+
logcdf_trunc = logcdf_numerator - lognorm
382+
383+
if is_lower_bounded:
384+
logcdf_trunc = pt.switch(value < lower, -np.inf, logcdf_trunc)
385+
386+
if is_upper_bounded:
387+
logcdf_trunc = pt.switch(value <= upper, logcdf_trunc, 0.0)
388+
389+
if is_lower_bounded and is_upper_bounded:
390+
logcdf_trunc = check_parameters(
391+
logcdf_trunc,
392+
pt.le(lower, upper),
393+
msg="lower_bound <= upper_bound",
394+
)
395+
396+
return logcdf_trunc
397+
398+
355399
@_truncated.register(NormalRV)
356400
def _truncated_normal(op, lower, upper, size, rng, old_size, dtype, mu, sigma):
357401
return TruncatedNormal.dist(

tests/distributions/test_truncated.py

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from pymc.distributions.truncated import Truncated, TruncatedRV, _truncated
2727
from pymc.exceptions import TruncationError
2828
from pymc.logprob.abstract import _icdf
29-
from pymc.logprob.basic import logp
29+
from pymc.logprob.basic import logcdf, logp
3030
from pymc.logprob.transforms import IntervalTransform
3131
from pymc.logprob.utils import ParameterValueError
3232
from pymc.testing import assert_moment_is_expected
@@ -165,6 +165,34 @@ def test_truncation_continuous_logp(op_type, lower, upper):
165165
assert np.isclose(xt_logp_fn(test_xt_v), ref_xt.logpdf(test_xt_v))
166166

167167

168+
@pytest.mark.parametrize("lower, upper", [(-1, np.inf), (-1, 1.5), (-np.inf, 1.5)])
169+
@pytest.mark.parametrize("op_type", ["icdf", "rejection"])
170+
def test_truncation_continuous_logcdf(op_type, lower, upper):
171+
loc = 0.15
172+
scale = 10
173+
op = icdf_normal if op_type == "icdf" else rejection_normal
174+
175+
x = op(loc, scale, name="x")
176+
xt = Truncated.dist(x, lower=lower, upper=upper)
177+
assert isinstance(xt.owner.op, TruncatedRV)
178+
179+
xt_vv = xt.clone()
180+
xt_logcdf_fn = pytensor.function([xt_vv], logcdf(xt, xt_vv))
181+
182+
ref_xt = scipy.stats.truncnorm(
183+
(lower - loc) / scale,
184+
(upper - loc) / scale,
185+
loc,
186+
scale,
187+
)
188+
for bound in (lower, upper):
189+
if np.isinf(bound):
190+
return
191+
for offset in (-1, 0, 1):
192+
test_xt_v = bound + offset
193+
assert np.isclose(xt_logcdf_fn(test_xt_v), ref_xt.logcdf(test_xt_v))
194+
195+
168196
@pytest.mark.parametrize("lower, upper", [(2, np.inf), (2, 5), (-np.inf, 5)])
169197
@pytest.mark.parametrize("op_type", ["icdf", "rejection"])
170198
def test_truncation_discrete_random(op_type, lower, upper):
@@ -232,6 +260,38 @@ def ref_xt_logpmf(value):
232260
assert np.isclose(log_integral, 0.0, atol=1e-5)
233261

234262

263+
@pytest.mark.parametrize("lower, upper", [(2, np.inf), (2, 5), (-np.inf, 5)])
264+
@pytest.mark.parametrize("op_type", ["icdf", "rejection"])
265+
def test_truncation_discrete_logcdf(op_type, lower, upper):
266+
p = 0.7
267+
op = icdf_geometric if op_type == "icdf" else rejection_geometric
268+
269+
x = op(p, name="x")
270+
xt = Truncated.dist(x, lower=lower, upper=upper)
271+
assert isinstance(xt.owner.op, TruncatedRV)
272+
273+
xt_vv = xt.clone()
274+
xt_logcdf_fn = pytensor.function([xt_vv], logcdf(xt, xt_vv))
275+
276+
ref_xt = scipy.stats.geom(p)
277+
log_norm = np.log(ref_xt.cdf(upper) - ref_xt.cdf(lower - 1))
278+
279+
def ref_xt_logcdf(value):
280+
if value < lower:
281+
return -np.inf
282+
elif value > upper:
283+
return 0.0
284+
285+
return np.log(ref_xt.cdf(value) - ref_xt.cdf(lower - 1)) - log_norm
286+
287+
for bound in (lower, upper):
288+
if np.isinf(bound):
289+
continue
290+
for offset in (-1, 0, 1):
291+
test_xt_v = bound + offset
292+
assert np.isclose(xt_logcdf_fn(test_xt_v), ref_xt_logcdf(test_xt_v))
293+
294+
235295
def test_truncation_exceptions():
236296
with pytest.raises(ValueError, match="lower and upper cannot both be None"):
237297
Truncated.dist(pt.random.normal())

0 commit comments

Comments
 (0)