Skip to content

Commit cd9f9d1

Browse files
add poisson and binomial moment
1 parent 2efedc1 commit cd9f9d1

File tree

2 files changed

+54
-2
lines changed

2 files changed

+54
-2
lines changed

pymc/distributions/discrete.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,12 @@ def dist(cls, n, p, *args, **kwargs):
117117
# mode = at.cast(tround(n * p), self.dtype)
118118
return super().dist([n, p], **kwargs)
119119

120+
def get_moment(rv, size, n, p):
121+
mean = n * p
122+
if not rv_size_is_none(size):
123+
mean = at.full(size, mean)
124+
return mean
125+
120126
def logp(value, n, p):
121127
r"""
122128
Calculate log-probability of Binomial distribution at specified value.
@@ -570,6 +576,11 @@ def dist(cls, mu, *args, **kwargs):
570576
# mode = intX(at.floor(mu))
571577
return super().dist([mu], *args, **kwargs)
572578

579+
def get_moment(rv, size, mu):
580+
if not rv_size_is_none(size):
581+
mu = at.full(size, mu)
582+
return mu
583+
573584
def logp(value, mu):
574585
r"""
575586
Calculate log-probability of Poisson distribution at specified value.

pymc/tests/test_distributions_moments.py

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from pymc import Bernoulli, Flat, HalfFlat, Normal, TruncatedNormal, Uniform
77
from pymc.distributions import (
88
Beta,
9+
Binomial,
910
Cauchy,
1011
Exponential,
1112
Gamma,
@@ -14,6 +15,7 @@
1415
Kumaraswamy,
1516
Laplace,
1617
LogNormal,
18+
Poisson,
1719
StudentT,
1820
Weibull,
1921
)
@@ -209,7 +211,13 @@ def test_laplace_moment(mu, b, size, expected):
209211
(0, 1, 1, None, 0),
210212
(0, np.ones(5), 1, None, np.zeros(5)),
211213
(np.arange(5), 10, np.arange(1, 6), None, np.arange(5)),
212-
(np.arange(5), 10, np.arange(1, 6), (2, 5), np.full((2, 5), np.arange(5))),
214+
(
215+
np.arange(5),
216+
10,
217+
np.arange(1, 6),
218+
(2, 5),
219+
np.full((2, 5), np.arange(5)),
220+
),
213221
],
214222
)
215223
def test_studentt_moment(mu, nu, sigma, size, expected):
@@ -318,11 +326,44 @@ def test_gamma_moment(alpha, beta, size, expected):
318326
np.arange(1, 6),
319327
np.arange(2, 7),
320328
(2, 5),
321-
np.full((2, 5), np.arange(2, 7) * special.gamma(1 + 1 / np.arange(1, 6))),
329+
np.full(
330+
(2, 5),
331+
np.arange(2, 7) * special.gamma(1 + 1 / np.arange(1, 6)),
332+
),
322333
),
323334
],
324335
)
325336
def test_weibull_moment(alpha, beta, size, expected):
326337
with Model() as model:
327338
Weibull("x", alpha=alpha, beta=beta, size=size)
328339
assert_moment_is_expected(model, expected)
340+
341+
342+
@pytest.mark.parametrize(
343+
"n, p, size, expected",
344+
[
345+
(10, 0.5, None, 5),
346+
(10, 0.5, 5, np.full(5, 5)),
347+
(10, np.arange(1, 6) / 10, None, np.arange(1, 6)),
348+
(10, np.arange(1, 6) / 10, (2, 5), np.full((2, 5), np.arange(1, 6))),
349+
],
350+
)
351+
def test_binomial_moment(n, p, size, expected):
352+
with Model() as model:
353+
Binomial("x", n=n, p=p, size=size)
354+
assert_moment_is_expected(model, expected)
355+
356+
357+
@pytest.mark.parametrize(
358+
"mu, size, expected",
359+
[
360+
(2, None, 2),
361+
(2, 5, np.full(5, 2)),
362+
(np.arange(1, 5), None, np.arange(1, 5)),
363+
(np.arange(1, 5), (2, 4), np.full((2, 4), np.arange(1, 5))),
364+
],
365+
)
366+
def test_poisson_moment(mu, size, expected):
367+
with Model() as model:
368+
Poisson("x", mu=mu, size=size)
369+
assert_moment_is_expected(model, expected)

0 commit comments

Comments
 (0)