Skip to content

Commit 55e8fe9

Browse files
committed
Add regression test for Truncated Gamma
1 parent f2bf1e0 commit 55e8fe9

File tree

1 file changed

+31
-1
lines changed

1 file changed

+31
-1
lines changed

tests/distributions/test_truncated.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from pytensor.tensor.random.basic import GeometricRV, NormalRV
2121

2222
from pymc import Censored, Model, draw, find_MAP
23-
from pymc.distributions.continuous import Exponential, TruncatedNormalRV
23+
from pymc.distributions.continuous import Exponential, Gamma, TruncatedNormalRV
2424
from pymc.distributions.shape_utils import change_dist_size
2525
from pymc.distributions.transforms import _default_transform
2626
from pymc.distributions.truncated import Truncated, TruncatedRV, _truncated
@@ -392,3 +392,33 @@ def test_truncated_inference():
392392
map = find_MAP(progressbar=False)
393393

394394
assert np.isclose(map["lam"], lam_true, atol=0.1)
395+
396+
397+
def test_truncated_gamma():
398+
# Regression test for https://github.com/pymc-devs/pymc/issues/6931
399+
alpha = 3.0
400+
beta = 3.0
401+
upper = 2.5
402+
x = np.linspace(0.0, upper + 0.5, 100)
403+
404+
gamma_scipy = scipy.stats.gamma(a=alpha, scale=1.0 / beta)
405+
logp_scipy = gamma_scipy.logpdf(x) - gamma_scipy.logcdf(upper)
406+
logp_scipy[x > upper] = -np.inf
407+
408+
gamma_trunc_pymc = Truncated.dist(
409+
Gamma.dist(alpha=alpha, beta=beta),
410+
upper=upper,
411+
)
412+
logp_pymc = logp(gamma_trunc_pymc, x).eval()
413+
np.testing.assert_allclose(
414+
logp_pymc,
415+
logp_scipy,
416+
)
417+
418+
# Changing the size used to invert the beta Gamma parameter again
419+
resized_gamma_trunc_pymc = change_dist_size(gamma_trunc_pymc, new_size=x.shape)
420+
logp_resized_pymc = logp(resized_gamma_trunc_pymc, x).eval()
421+
np.testing.assert_allclose(
422+
logp_resized_pymc,
423+
logp_scipy,
424+
)

0 commit comments

Comments
 (0)