diff --git a/pymc/distributions/discrete.py b/pymc/distributions/discrete.py index 35fd2bc2f6..d8a8c2f8fd 100644 --- a/pymc/distributions/discrete.py +++ b/pymc/distributions/discrete.py @@ -114,9 +114,14 @@ class Binomial(Discrete): def dist(cls, n, p, *args, **kwargs): n = at.as_tensor_variable(intX(n)) p = at.as_tensor_variable(floatX(p)) - # mode = at.cast(tround(n * p), self.dtype) return super().dist([n, p], **kwargs) + def get_moment(rv, size, n, p): + mean = at.round(n * p) + if not rv_size_is_none(size): + mean = at.full(size, mean) + return mean + def logp(value, n, p): r""" Calculate log-probability of Binomial distribution at specified value. @@ -567,9 +572,14 @@ class Poisson(Discrete): @classmethod def dist(cls, mu, *args, **kwargs): mu = at.as_tensor_variable(floatX(mu)) - # mode = intX(at.floor(mu)) return super().dist([mu], *args, **kwargs) + def get_moment(rv, size, mu): + mu = at.floor(mu) + if not rv_size_is_none(size): + mu = at.full(size, mu) + return mu + def logp(value, mu): r""" Calculate log-probability of Poisson distribution at specified value. diff --git a/pymc/tests/test_distributions_moments.py b/pymc/tests/test_distributions_moments.py index 36fe591e69..f6eb8c5f1b 100644 --- a/pymc/tests/test_distributions_moments.py +++ b/pymc/tests/test_distributions_moments.py @@ -6,6 +6,7 @@ from pymc import Bernoulli, Flat, HalfFlat, Normal, TruncatedNormal, Uniform from pymc.distributions import ( Beta, + Binomial, Cauchy, Exponential, Gamma, @@ -14,6 +15,7 @@ Kumaraswamy, Laplace, LogNormal, + Poisson, StudentT, Weibull, ) @@ -209,7 +211,13 @@ def test_laplace_moment(mu, b, size, expected): (0, 1, 1, None, 0), (0, np.ones(5), 1, None, np.zeros(5)), (np.arange(5), 10, np.arange(1, 6), None, np.arange(5)), - (np.arange(5), 10, np.arange(1, 6), (2, 5), np.full((2, 5), np.arange(5))), + ( + np.arange(5), + 10, + np.arange(1, 6), + (2, 5), + np.full((2, 5), np.arange(5)), + ), ], ) def test_studentt_moment(mu, nu, sigma, size, expected): @@ -318,7 +326,10 @@ def test_gamma_moment(alpha, beta, size, expected): np.arange(1, 6), np.arange(2, 7), (2, 5), - np.full((2, 5), np.arange(2, 7) * special.gamma(1 + 1 / np.arange(1, 6))), + np.full( + (2, 5), + np.arange(2, 7) * special.gamma(1 + 1 / np.arange(1, 6)), + ), ), ], ) @@ -326,3 +337,33 @@ def test_weibull_moment(alpha, beta, size, expected): with Model() as model: Weibull("x", alpha=alpha, beta=beta, size=size) assert_moment_is_expected(model, expected) + + +@pytest.mark.parametrize( + "n, p, size, expected", + [ + (7, 0.7, None, 5), + (7, 0.3, 5, np.full(5, 2)), + (10, np.arange(1, 6) / 10, None, np.arange(1, 6)), + (10, np.arange(1, 6) / 10, (2, 5), np.full((2, 5), np.arange(1, 6))), + ], +) +def test_binomial_moment(n, p, size, expected): + with Model() as model: + Binomial("x", n=n, p=p, size=size) + assert_moment_is_expected(model, expected) + + +@pytest.mark.parametrize( + "mu, size, expected", + [ + (2.7, None, 2), + (2.3, 5, np.full(5, 2)), + (np.arange(1, 5), None, np.arange(1, 5)), + (np.arange(1, 5), (2, 4), np.full((2, 4), np.arange(1, 5))), + ], +) +def test_poisson_moment(mu, size, expected): + with Model() as model: + Poisson("x", mu=mu, size=size) + assert_moment_is_expected(model, expected)