diff --git a/pymc/distributions/discrete.py b/pymc/distributions/discrete.py index 86e7f1e9b9..e190bf7a12 100644 --- a/pymc/distributions/discrete.py +++ b/pymc/distributions/discrete.py @@ -239,6 +239,12 @@ def dist(cls, alpha, beta, n, *args, **kwargs): n = at.as_tensor_variable(intX(n)) return super().dist([n, alpha, beta], **kwargs) + def get_moment(rv, size, n, alpha, beta): + mean = at.round((n * alpha) / (alpha + beta)) + if not rv_size_is_none(size): + mean = at.full(size, mean) + return mean + def logp(value, n, alpha, beta): r""" Calculate log-probability of BetaBinomial distribution at specified value. diff --git a/pymc/tests/test_distributions_moments.py b/pymc/tests/test_distributions_moments.py index bbe478f8ac..20ea2bc2f4 100644 --- a/pymc/tests/test_distributions_moments.py +++ b/pymc/tests/test_distributions_moments.py @@ -6,6 +6,7 @@ from pymc.distributions import ( Bernoulli, Beta, + BetaBinomial, Binomial, Cauchy, ChiSquared, @@ -209,6 +210,21 @@ def test_beta_moment(alpha, beta, size, expected): assert_moment_is_expected(model, expected) +@pytest.mark.parametrize( + "n, alpha, beta, size, expected", + [ + (10, 1, 1, None, 5), + (10, 1, 1, 5, np.full(5, 5)), + (10, 1, np.arange(1, 6), None, np.round(10 / np.arange(2, 7))), + (10, 1, np.arange(1, 6), (2, 5), np.full((2, 5), np.round(10 / np.arange(2, 7)))), + ], +) +def test_beta_binomial_moment(alpha, beta, n, size, expected): + with Model() as model: + BetaBinomial("x", alpha=alpha, beta=beta, n=n, size=size) + assert_moment_is_expected(model, expected) + + @pytest.mark.parametrize( "nu, size, expected", [