diff --git a/pymc/distributions/continuous.py b/pymc/distributions/continuous.py index b124866b1f..90684bfa79 100644 --- a/pymc/distributions/continuous.py +++ b/pymc/distributions/continuous.py @@ -3350,6 +3350,22 @@ def get_nu_b(cls, nu, b, sigma): return nu, b, sigma raise ValueError("Rice distribution must specify either nu" " or b.") + def get_moment(rv, size, nu, sigma): + nu_sigma_ratio = -(nu ** 2) / (2 * sigma ** 2) + mean = ( + sigma + * np.sqrt(np.pi / 2) + * at.exp(nu_sigma_ratio / 2) + * ( + (1 - nu_sigma_ratio) * at.i0(-nu_sigma_ratio / 2) + - nu_sigma_ratio * at.i1(-nu_sigma_ratio / 2) + ) + ) + + if not rv_size_is_none(size): + mean = at.full(size, mean) + return mean + def logp(value, b, sigma): """ Calculate log-probability of Rice distribution at specified value. diff --git a/pymc/tests/test_distributions_moments.py b/pymc/tests/test_distributions_moments.py index defe31afb6..58d24396ba 100644 --- a/pymc/tests/test_distributions_moments.py +++ b/pymc/tests/test_distributions_moments.py @@ -36,6 +36,7 @@ Normal, Pareto, Poisson, + Rice, SkewNormal, StudentT, Triangular, @@ -757,7 +758,7 @@ def test_categorical_moment(p, size, expected): "mu, sigma, size, expected", [ (4.0, 3.0, None, 7.8110885363844345), - (4, np.full(5, 3), None, np.full(5, 7.8110885363844345)), + (4.0, np.full(5, 3), None, np.full(5, 7.8110885363844345)), (np.arange(5), 1, None, np.arange(5) + 1.2703628454614782), (np.arange(5), np.ones(5), (2, 5), np.full((2, 5), np.arange(5) + 1.2703628454614782)), ], @@ -772,7 +773,7 @@ def test_moyal_moment(mu, sigma, size, expected): "alpha, mu, sigma, size, expected", [ (1.0, 1.0, 1.0, None, 1.56418958), - (1, np.ones(5), 1, None, np.full(5, 1.56418958)), + (1.0, np.ones(5), 1.0, None, np.full(5, 1.56418958)), (np.ones(5), 1, np.ones(5), None, np.full(5, 1.56418958)), ( np.arange(5), @@ -822,3 +823,43 @@ def test_asymmetriclaplace_moment(b, kappa, mu, size, expected): with Model() as model: AsymmetricLaplace("x", b=b, kappa=kappa, mu=mu, size=size) assert_moment_is_expected(model, expected) + + +@pytest.mark.parametrize( + "nu, sigma, size, expected", + [ + (1.0, 1.0, None, 1.5485724605511453), + (1.0, np.ones(5), None, np.full(5, 1.5485724605511453)), + ( + np.arange(1, 6), + 1.0, + None, + ( + 1.5485724605511453, + 2.2723834280687427, + 3.1725772879007166, + 4.127193542536757, + 5.101069639492123, + ), + ), + ( + np.arange(1, 6), + np.ones(5), + (2, 5), + np.full( + (2, 5), + ( + 1.5485724605511453, + 2.2723834280687427, + 3.1725772879007166, + 4.127193542536757, + 5.101069639492123, + ), + ), + ), + ], +) +def test_rice_moment(nu, sigma, size, expected): + with Model() as model: + Rice("x", nu=nu, sigma=sigma, size=size) + assert_moment_is_expected(model, expected)