Skip to content

Commit 7485ccc

Browse files
authored
adds beta-binomial mean and test cases (#5175)
1 parent 073e26b commit 7485ccc

File tree

2 files changed

+22
-0
lines changed

2 files changed

+22
-0
lines changed

pymc/distributions/discrete.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,12 @@ def dist(cls, alpha, beta, n, *args, **kwargs):
239239
n = at.as_tensor_variable(intX(n))
240240
return super().dist([n, alpha, beta], **kwargs)
241241

242+
def get_moment(rv, size, n, alpha, beta):
243+
mean = at.round((n * alpha) / (alpha + beta))
244+
if not rv_size_is_none(size):
245+
mean = at.full(size, mean)
246+
return mean
247+
242248
def logp(value, n, alpha, beta):
243249
r"""
244250
Calculate log-probability of BetaBinomial distribution at specified value.

pymc/tests/test_distributions_moments.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from pymc.distributions import (
77
Bernoulli,
88
Beta,
9+
BetaBinomial,
910
Binomial,
1011
Cauchy,
1112
ChiSquared,
@@ -212,6 +213,21 @@ def test_beta_moment(alpha, beta, size, expected):
212213
assert_moment_is_expected(model, expected)
213214

214215

216+
@pytest.mark.parametrize(
217+
"n, alpha, beta, size, expected",
218+
[
219+
(10, 1, 1, None, 5),
220+
(10, 1, 1, 5, np.full(5, 5)),
221+
(10, 1, np.arange(1, 6), None, np.round(10 / np.arange(2, 7))),
222+
(10, 1, np.arange(1, 6), (2, 5), np.full((2, 5), np.round(10 / np.arange(2, 7)))),
223+
],
224+
)
225+
def test_beta_binomial_moment(alpha, beta, n, size, expected):
226+
with Model() as model:
227+
BetaBinomial("x", alpha=alpha, beta=beta, n=n, size=size)
228+
assert_moment_is_expected(model, expected)
229+
230+
215231
@pytest.mark.parametrize(
216232
"nu, size, expected",
217233
[

0 commit comments

Comments
 (0)