diff --git a/pymc/distributions/continuous.py b/pymc/distributions/continuous.py index d3bfe644c4..9d4b770140 100644 --- a/pymc/distributions/continuous.py +++ b/pymc/distributions/continuous.py @@ -2427,6 +2427,12 @@ def dist(cls, nu, *args, **kwargs): nu = at.as_tensor_variable(floatX(nu)) return super().dist([nu], *args, **kwargs) + def get_moment(rv, size, nu): + moment = nu + if not rv_size_is_none(size): + moment = at.full(size, moment) + return moment + def logcdf(value, nu): """ Compute the log of the cumulative distribution function for ChiSquared distribution diff --git a/pymc/tests/test_distributions_moments.py b/pymc/tests/test_distributions_moments.py index 36fe591e69..a01b7cc469 100644 --- a/pymc/tests/test_distributions_moments.py +++ b/pymc/tests/test_distributions_moments.py @@ -7,6 +7,7 @@ from pymc.distributions import ( Beta, Cauchy, + ChiSquared, Exponential, Gamma, HalfCauchy, @@ -173,6 +174,20 @@ def test_beta_moment(alpha, beta, size, expected): assert_moment_is_expected(model, expected) +@pytest.mark.parametrize( + "nu, size, expected", + [ + (1, None, 1), + (1, 5, np.full(5, 1)), + (np.arange(1, 6), None, np.arange(1, 6)), + ], +) +def test_chisquared_moment(nu, size, expected): + with Model() as model: + ChiSquared("x", nu=nu, size=size) + assert_moment_is_expected(model, expected) + + @pytest.mark.parametrize( "lam, size, expected", [