diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index 0b928f229d..f0f031ff02 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -1938,6 +1938,13 @@ def dist(cls, mu, covs=None, chols=None, evds=None, sigma=None, *args, **kwargs) # mean = median = mode = mu return super().dist([mu, sigma, *covs], **kwargs) + def get_moment(rv, size, mu, covs, chols, evds): + mean = mu + if not rv_size_is_none(size): + moment_size = at.concatenate([size, mu.shape]) + mean = at.full(moment_size, mu) + return mean + def logp(value, mu, sigma, *covs): """ Calculate log-probability of Multivariate Normal distribution diff --git a/pymc/tests/test_distributions_moments.py b/pymc/tests/test_distributions_moments.py index 2188a931c4..4a6361504e 100644 --- a/pymc/tests/test_distributions_moments.py +++ b/pymc/tests/test_distributions_moments.py @@ -34,6 +34,7 @@ HyperGeometric, Interpolated, InverseGamma, + KroneckerNormal, Kumaraswamy, Laplace, Logistic, @@ -110,7 +111,6 @@ def test_all_distributions_have_moments(): dist_module.discrete.DiscreteWeibull, dist_module.multivariate.CAR, dist_module.multivariate.DirichletMultinomial, - dist_module.multivariate.KroneckerNormal, dist_module.multivariate.Wishart, } @@ -1316,3 +1316,32 @@ def normal_sim(rng, mu, sigma, size): cutoff = st.norm().ppf(1 - (alpha / 2)) assert np.all(np.abs((result - expected_sample_mean) / expected_sample_mean_std) < cutoff) + + +@pytest.mark.parametrize( + "mu, covs, size, expected", + [ + (np.ones(1), [np.identity(1), np.identity(1)], None, np.ones(1)), + (np.ones(6), [np.identity(2), np.identity(3)], 5, np.ones((5, 6))), + (np.zeros(6), [np.identity(2), np.identity(3)], 6, np.zeros((6, 6))), + (np.zeros(3), [np.identity(3), np.identity(1)], 6, np.zeros((6, 3))), + ( + np.array([1, 2, 3, 4]), + [ + np.array([[1.0, 0.5], [0.5, 2]]), + np.array([[1.0, 0.4], [0.4, 2]]), + ], + 2, + np.array( + [ + [1, 2, 3, 4], + [1, 2, 3, 4], + ] + ), + ), + ], +) +def test_kronecker_normal_moments(mu, covs, size, expected): + with Model() as model: + KroneckerNormal("x", mu=mu, covs=covs, size=size) + assert_moment_is_expected(model, expected)