From 1202718b80bef0b0c486fab7fcb0904713b3c3fc Mon Sep 17 00:00:00 2001 From: Anirudh Acharya Date: Thu, 2 Dec 2021 17:30:11 +0530 Subject: [PATCH 1/2] moments --- pymc/distributions/multivariate.py | 7 ++++++ pymc/tests/test_distributions_moments.py | 28 ++++++++++++++++++++++++ 2 files changed, 35 insertions(+) diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index 0b928f229d..2a46b59e58 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -1938,6 +1938,13 @@ def dist(cls, mu, covs=None, chols=None, evds=None, sigma=None, *args, **kwargs) # mean = median = mode = mu return super().dist([mu, sigma, *covs], **kwargs) + def get_moment(rv, size, mu, covs, chols, evds): + mean = mu + if not rv_size_is_none(size): + moment_size = at.concatenate([size, mu.shape]) + mean = at.full(moment_size, mu) + return mean + def logp(value, mu, sigma, *covs): """ Calculate log-probability of Multivariate Normal distribution diff --git a/pymc/tests/test_distributions_moments.py b/pymc/tests/test_distributions_moments.py index 2188a931c4..88b50e7350 100644 --- a/pymc/tests/test_distributions_moments.py +++ b/pymc/tests/test_distributions_moments.py @@ -35,6 +35,7 @@ Interpolated, InverseGamma, Kumaraswamy, + KroneckerNormal, Laplace, Logistic, LogitNormal, @@ -1316,3 +1317,30 @@ def normal_sim(rng, mu, sigma, size): cutoff = st.norm().ppf(1 - (alpha / 2)) assert np.all(np.abs((result - expected_sample_mean) / expected_sample_mean_std) < cutoff) + +@pytest.mark.parametrize( + "mu, covs, size, expected", + [ + (np.ones(1), [np.identity(1), np.identity(1)], None, np.ones(1)), + (np.ones(6), [np.identity(2), np.identity(3)], 5, np.ones((5,6))), + (np.zeros(6), [np.identity(2), np.identity(3)], 6, np.zeros((6,6))), + (np.zeros(3), [np.identity(3), np.identity(1)], 6, np.zeros((6,3))), + (np.zeros((4,6)), [np.identity(2),np.identity(3)], 6, np.zeros((6,4,6))), + ( + np.array([[1,2,3,4],[3,4,5,6],[6,7,8,9],[7,8,9,1]]), + [ + np.array([[1., 0.5], [0.5, 2]]), + np.array([[1., 0.4], [0.4, 2]]), + ], + 2, + np.array([ + [[1,2,3,4],[3,4,5,6],[6,7,8,9],[7,8,9,1]], + [[1,2,3,4],[3,4,5,6],[6,7,8,9],[7,8,9,1]] + ]), + ) + ], +) +def test_kronecker_normal_moments(mu, covs, size, expected): + with Model() as model: + KroneckerNormal("x", mu=mu, covs=covs, size=size) + assert_moment_is_expected(model, expected) From 76e64631f798bc34c09a6723d97c624ad173399a Mon Sep 17 00:00:00 2001 From: Anirudh Acharya Date: Mon, 13 Dec 2021 17:50:25 +0530 Subject: [PATCH 2/2] code format --- pymc/distributions/multivariate.py | 2 +- pymc/tests/test_distributions_moments.py | 29 ++++++++++++------------ 2 files changed, 16 insertions(+), 15 deletions(-) diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index 2a46b59e58..f0f031ff02 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -1944,7 +1944,7 @@ def get_moment(rv, size, mu, covs, chols, evds): moment_size = at.concatenate([size, mu.shape]) mean = at.full(moment_size, mu) return mean - + def logp(value, mu, sigma, *covs): """ Calculate log-probability of Multivariate Normal distribution diff --git a/pymc/tests/test_distributions_moments.py b/pymc/tests/test_distributions_moments.py index 88b50e7350..4a6361504e 100644 --- a/pymc/tests/test_distributions_moments.py +++ b/pymc/tests/test_distributions_moments.py @@ -34,8 +34,8 @@ HyperGeometric, Interpolated, InverseGamma, - Kumaraswamy, KroneckerNormal, + Kumaraswamy, Laplace, Logistic, LogitNormal, @@ -111,7 +111,6 @@ def test_all_distributions_have_moments(): dist_module.discrete.DiscreteWeibull, dist_module.multivariate.CAR, dist_module.multivariate.DirichletMultinomial, - dist_module.multivariate.KroneckerNormal, dist_module.multivariate.Wishart, } @@ -1318,26 +1317,28 @@ def normal_sim(rng, mu, sigma, size): assert np.all(np.abs((result - expected_sample_mean) / expected_sample_mean_std) < cutoff) + @pytest.mark.parametrize( "mu, covs, size, expected", [ (np.ones(1), [np.identity(1), np.identity(1)], None, np.ones(1)), - (np.ones(6), [np.identity(2), np.identity(3)], 5, np.ones((5,6))), - (np.zeros(6), [np.identity(2), np.identity(3)], 6, np.zeros((6,6))), - (np.zeros(3), [np.identity(3), np.identity(1)], 6, np.zeros((6,3))), - (np.zeros((4,6)), [np.identity(2),np.identity(3)], 6, np.zeros((6,4,6))), + (np.ones(6), [np.identity(2), np.identity(3)], 5, np.ones((5, 6))), + (np.zeros(6), [np.identity(2), np.identity(3)], 6, np.zeros((6, 6))), + (np.zeros(3), [np.identity(3), np.identity(1)], 6, np.zeros((6, 3))), ( - np.array([[1,2,3,4],[3,4,5,6],[6,7,8,9],[7,8,9,1]]), + np.array([1, 2, 3, 4]), [ - np.array([[1., 0.5], [0.5, 2]]), - np.array([[1., 0.4], [0.4, 2]]), + np.array([[1.0, 0.5], [0.5, 2]]), + np.array([[1.0, 0.4], [0.4, 2]]), ], 2, - np.array([ - [[1,2,3,4],[3,4,5,6],[6,7,8,9],[7,8,9,1]], - [[1,2,3,4],[3,4,5,6],[6,7,8,9],[7,8,9,1]] - ]), - ) + np.array( + [ + [1, 2, 3, 4], + [1, 2, 3, 4], + ] + ), + ), ], ) def test_kronecker_normal_moments(mu, covs, size, expected):