diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index 729d393c08..4dbc1ea9b8 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -354,6 +354,13 @@ def dist(cls, nu, Sigma=None, mu=None, cov=None, tau=None, chol=None, lower=True assert_negative_support(nu, "nu", "MvStudentT") return super().dist([nu, mu, cov], **kwargs) + def get_moment(rv, size, nu, mu, cov): + moment = mu + if not rv_size_is_none(size): + moment_size = at.concatenate([size, moment.shape]) + moment = at.full(moment_size, moment) + return moment + def logp(value, nu, mu, cov): """ Calculate log-probability of Multivariate Student's T distribution @@ -692,7 +699,7 @@ def dist(cls, eta, cutpoints, n, *args, **kwargs): class OrderedMultinomial: - R""" + r""" Wrapper class for Ordered Multinomial distributions. Useful for regression on ordinal data whose values range @@ -1727,6 +1734,13 @@ def dist( return super().dist([mu, rowchol_cov, colchol_cov], **kwargs) + def get_moment(rv, size, mu, rowchol, colchol): + output_shape = (rowchol.shape[0], colchol.shape[0]) + if not rv_size_is_none(size): + output_shape = at.concatenate([size, output_shape]) + moment = at.full(output_shape, mu) + return moment + def logp(value, mu, rowchol, colchol): """ Calculate log-probability of Matrix-valued Normal distribution diff --git a/pymc/tests/test_distributions_moments.py b/pymc/tests/test_distributions_moments.py index e1baa8840a..8b851c0044 100644 --- a/pymc/tests/test_distributions_moments.py +++ b/pymc/tests/test_distributions_moments.py @@ -39,7 +39,9 @@ Logistic, LogitNormal, LogNormal, + MatrixNormal, Moyal, + MvStudentT, NegativeBinomial, Normal, Pareto, @@ -830,6 +832,33 @@ def test_moyal_moment(mu, sigma, size, expected): assert_moment_is_expected(model, expected) +rand1d = np.random.rand(2) +rand2d = np.random.rand(2, 3) + + +@pytest.mark.parametrize( + "nu, mu, cov, size, expected", + [ + (2, np.ones(1), np.eye(1), None, np.ones(1)), + (2, rand1d, np.eye(2), None, rand1d), + (2, rand1d, np.eye(2), 2, np.full((2, 2), rand1d)), + (2, rand1d, np.eye(2), (2, 5), np.full((2, 5, 2), rand1d)), + (2, rand2d, np.eye(3), None, rand2d), + (2, rand2d, np.eye(3), 2, np.full((2, 2, 3), rand2d)), + (2, rand2d, np.eye(3), (2, 5), np.full((2, 5, 2, 3), rand2d)), + ], +) +def test_mvstudentt_moment(nu, mu, cov, size, expected): + with Model() as model: + MvStudentT("x", nu=nu, mu=mu, cov=cov, size=size) + assert_moment_is_expected(model, expected) + + +def check_matrixnormal_moment(mu, rowchol, colchol, size, expected): + with Model() as model: + MatrixNormal("x", mu=mu, rowchol=rowchol, colchol=colchol, size=size) + + @pytest.mark.parametrize( "alpha, mu, sigma, size, expected", [ @@ -886,6 +915,24 @@ def test_asymmetriclaplace_moment(b, kappa, mu, size, expected): assert_moment_is_expected(model, expected) +@pytest.mark.parametrize( + "mu, rowchol, colchol, size, expected", + [ + (np.ones((1, 1)), np.eye(1), np.eye(1), None, np.ones((1, 1))), + (np.ones((1, 1)), np.eye(2), np.eye(3), None, np.ones((2, 3))), + (rand2d, np.eye(2), np.eye(3), None, rand2d), + (rand2d, np.eye(2), np.eye(3), 2, np.full((2, 2, 3), rand2d)), + (rand2d, np.eye(2), np.eye(3), (2, 5), np.full((2, 5, 2, 3), rand2d)), + ], +) +def test_matrixnormal_moment(mu, rowchol, colchol, size, expected): + if size is None: + check_matrixnormal_moment(mu, rowchol, colchol, size, expected) + else: + with pytest.raises(NotImplementedError): + check_matrixnormal_moment(mu, rowchol, colchol, size, expected) + + @pytest.mark.parametrize( "nu, sigma, size, expected", [