Skip to content

Commit 1202718

Browse files
moments
1 parent 36c7553 commit 1202718

File tree

2 files changed

+35
-0
lines changed

2 files changed

+35
-0
lines changed

pymc/distributions/multivariate.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1938,6 +1938,13 @@ def dist(cls, mu, covs=None, chols=None, evds=None, sigma=None, *args, **kwargs)
19381938
# mean = median = mode = mu
19391939
return super().dist([mu, sigma, *covs], **kwargs)
19401940

1941+
def get_moment(rv, size, mu, covs, chols, evds):
1942+
mean = mu
1943+
if not rv_size_is_none(size):
1944+
moment_size = at.concatenate([size, mu.shape])
1945+
mean = at.full(moment_size, mu)
1946+
return mean
1947+
19411948
def logp(value, mu, sigma, *covs):
19421949
"""
19431950
Calculate log-probability of Multivariate Normal distribution

pymc/tests/test_distributions_moments.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
Interpolated,
3636
InverseGamma,
3737
Kumaraswamy,
38+
KroneckerNormal,
3839
Laplace,
3940
Logistic,
4041
LogitNormal,
@@ -1316,3 +1317,30 @@ def normal_sim(rng, mu, sigma, size):
13161317
cutoff = st.norm().ppf(1 - (alpha / 2))
13171318

13181319
assert np.all(np.abs((result - expected_sample_mean) / expected_sample_mean_std) < cutoff)
1320+
1321+
@pytest.mark.parametrize(
1322+
"mu, covs, size, expected",
1323+
[
1324+
(np.ones(1), [np.identity(1), np.identity(1)], None, np.ones(1)),
1325+
(np.ones(6), [np.identity(2), np.identity(3)], 5, np.ones((5,6))),
1326+
(np.zeros(6), [np.identity(2), np.identity(3)], 6, np.zeros((6,6))),
1327+
(np.zeros(3), [np.identity(3), np.identity(1)], 6, np.zeros((6,3))),
1328+
(np.zeros((4,6)), [np.identity(2),np.identity(3)], 6, np.zeros((6,4,6))),
1329+
(
1330+
np.array([[1,2,3,4],[3,4,5,6],[6,7,8,9],[7,8,9,1]]),
1331+
[
1332+
np.array([[1., 0.5], [0.5, 2]]),
1333+
np.array([[1., 0.4], [0.4, 2]]),
1334+
],
1335+
2,
1336+
np.array([
1337+
[[1,2,3,4],[3,4,5,6],[6,7,8,9],[7,8,9,1]],
1338+
[[1,2,3,4],[3,4,5,6],[6,7,8,9],[7,8,9,1]]
1339+
]),
1340+
)
1341+
],
1342+
)
1343+
def test_kronecker_normal_moments(mu, covs, size, expected):
1344+
with Model() as model:
1345+
KroneckerNormal("x", mu=mu, covs=covs, size=size)
1346+
assert_moment_is_expected(model, expected)

0 commit comments

Comments
 (0)