diff --git a/pymc/distributions/continuous.py b/pymc/distributions/continuous.py index 90684bfa79..516dda542c 100644 --- a/pymc/distributions/continuous.py +++ b/pymc/distributions/continuous.py @@ -30,6 +30,7 @@ from aesara.graph.op import Op from aesara.tensor import gammaln from aesara.tensor.extra_ops import broadcast_shape +from aesara.tensor.math import tanh from aesara.tensor.random.basic import ( BetaRV, WeibullRV, @@ -3985,6 +3986,12 @@ def dist(cls, h=1.0, z=0.0, **kwargs): return super().dist([h, z], **kwargs) + def get_moment(rv, size, h, z): + mean = at.switch(at.eq(z, 0), h / 4, tanh(z / 2) * (h / (2 * z))) + if not rv_size_is_none(size): + mean = at.full(size, mean) + return mean + def logp(value, h, z): """ Calculate log-probability of Polya-Gamma distribution at specified value. diff --git a/pymc/tests/test_distributions_moments.py b/pymc/tests/test_distributions_moments.py index fd37eaac8b..777be4a8bc 100644 --- a/pymc/tests/test_distributions_moments.py +++ b/pymc/tests/test_distributions_moments.py @@ -41,6 +41,7 @@ Normal, Pareto, Poisson, + PolyaGamma, Rice, SkewNormal, StudentT, @@ -984,3 +985,54 @@ def _random(mu, rng=None, size=None): match="Cannot safely infer the size of a multivariate random variable's moment.", ): evaled_moment = get_moment(a).eval({mu: mu_val}) + + +@pytest.mark.parametrize( + "h, z, size, expected", + [ + (1.0, 0.0, None, 0.25), + ( + 1.0, + np.arange(5), + None, + ( + 0.25, + 0.23105857863000487, + 0.1903985389889412, + 0.1508580422741444, + 0.12050344750947711, + ), + ), + ( + np.arange(1, 6), + np.arange(5), + None, + ( + 0.25, + 0.46211715726000974, + 0.5711956169668236, + 0.6034321690965776, + 0.6025172375473855, + ), + ), + ( + np.arange(1, 6), + np.arange(5), + (2, 5), + np.full( + (2, 5), + ( + 0.25, + 0.46211715726000974, + 0.5711956169668236, + 0.6034321690965776, + 0.6025172375473855, + ), + ), + ), + ], +) +def test_polyagamma_moment(h, z, size, expected): + with Model() as model: + PolyaGamma("x", h=h, z=z, size=size) + assert_moment_is_expected(model, expected)