|
26 | 26 | from pymc.distributions.truncated import Truncated, TruncatedRV, _truncated
|
27 | 27 | from pymc.exceptions import TruncationError
|
28 | 28 | from pymc.logprob.abstract import _icdf
|
29 |
| -from pymc.logprob.basic import logp |
| 29 | +from pymc.logprob.basic import logcdf, logp |
30 | 30 | from pymc.logprob.transforms import IntervalTransform
|
31 | 31 | from pymc.logprob.utils import ParameterValueError
|
32 | 32 | from pymc.testing import assert_moment_is_expected
|
@@ -165,6 +165,34 @@ def test_truncation_continuous_logp(op_type, lower, upper):
|
165 | 165 | assert np.isclose(xt_logp_fn(test_xt_v), ref_xt.logpdf(test_xt_v))
|
166 | 166 |
|
167 | 167 |
|
| 168 | +@pytest.mark.parametrize("lower, upper", [(-1, np.inf), (-1, 1.5), (-np.inf, 1.5)]) |
| 169 | +@pytest.mark.parametrize("op_type", ["icdf", "rejection"]) |
| 170 | +def test_truncation_continuous_logcdf(op_type, lower, upper): |
| 171 | + loc = 0.15 |
| 172 | + scale = 10 |
| 173 | + op = icdf_normal if op_type == "icdf" else rejection_normal |
| 174 | + |
| 175 | + x = op(loc, scale, name="x") |
| 176 | + xt = Truncated.dist(x, lower=lower, upper=upper) |
| 177 | + assert isinstance(xt.owner.op, TruncatedRV) |
| 178 | + |
| 179 | + xt_vv = xt.clone() |
| 180 | + xt_logcdf_fn = pytensor.function([xt_vv], logcdf(xt, xt_vv)) |
| 181 | + |
| 182 | + ref_xt = scipy.stats.truncnorm( |
| 183 | + (lower - loc) / scale, |
| 184 | + (upper - loc) / scale, |
| 185 | + loc, |
| 186 | + scale, |
| 187 | + ) |
| 188 | + for bound in (lower, upper): |
| 189 | + if np.isinf(bound): |
| 190 | + return |
| 191 | + for offset in (-1, 0, 1): |
| 192 | + test_xt_v = bound + offset |
| 193 | + assert np.isclose(xt_logcdf_fn(test_xt_v), ref_xt.logcdf(test_xt_v)) |
| 194 | + |
| 195 | + |
168 | 196 | @pytest.mark.parametrize("lower, upper", [(2, np.inf), (2, 5), (-np.inf, 5)])
|
169 | 197 | @pytest.mark.parametrize("op_type", ["icdf", "rejection"])
|
170 | 198 | def test_truncation_discrete_random(op_type, lower, upper):
|
@@ -232,6 +260,38 @@ def ref_xt_logpmf(value):
|
232 | 260 | assert np.isclose(log_integral, 0.0, atol=1e-5)
|
233 | 261 |
|
234 | 262 |
|
| 263 | +@pytest.mark.parametrize("lower, upper", [(2, np.inf), (2, 5), (-np.inf, 5)]) |
| 264 | +@pytest.mark.parametrize("op_type", ["icdf", "rejection"]) |
| 265 | +def test_truncation_discrete_logcdf(op_type, lower, upper): |
| 266 | + p = 0.7 |
| 267 | + op = icdf_geometric if op_type == "icdf" else rejection_geometric |
| 268 | + |
| 269 | + x = op(p, name="x") |
| 270 | + xt = Truncated.dist(x, lower=lower, upper=upper) |
| 271 | + assert isinstance(xt.owner.op, TruncatedRV) |
| 272 | + |
| 273 | + xt_vv = xt.clone() |
| 274 | + xt_logcdf_fn = pytensor.function([xt_vv], logcdf(xt, xt_vv)) |
| 275 | + |
| 276 | + ref_xt = scipy.stats.geom(p) |
| 277 | + log_norm = np.log(ref_xt.cdf(upper) - ref_xt.cdf(lower - 1)) |
| 278 | + |
| 279 | + def ref_xt_logcdf(value): |
| 280 | + if value < lower: |
| 281 | + return -np.inf |
| 282 | + elif value > upper: |
| 283 | + return 0.0 |
| 284 | + |
| 285 | + return np.log(ref_xt.cdf(value) - ref_xt.cdf(lower - 1)) - log_norm |
| 286 | + |
| 287 | + for bound in (lower, upper): |
| 288 | + if np.isinf(bound): |
| 289 | + continue |
| 290 | + for offset in (-1, 0, 1): |
| 291 | + test_xt_v = bound + offset |
| 292 | + assert np.isclose(xt_logcdf_fn(test_xt_v), ref_xt_logcdf(test_xt_v)) |
| 293 | + |
| 294 | + |
235 | 295 | def test_truncation_exceptions():
|
236 | 296 | with pytest.raises(ValueError, match="lower and upper cannot both be None"):
|
237 | 297 | Truncated.dist(pt.random.normal())
|
|
0 commit comments