Skip to content

Commit 76e6463

Browse files
code format
1 parent 1202718 commit 76e6463

File tree

2 files changed

+16
-15
lines changed

2 files changed

+16
-15
lines changed

pymc/distributions/multivariate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1944,7 +1944,7 @@ def get_moment(rv, size, mu, covs, chols, evds):
19441944
moment_size = at.concatenate([size, mu.shape])
19451945
mean = at.full(moment_size, mu)
19461946
return mean
1947-
1947+
19481948
def logp(value, mu, sigma, *covs):
19491949
"""
19501950
Calculate log-probability of Multivariate Normal distribution

pymc/tests/test_distributions_moments.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@
3434
HyperGeometric,
3535
Interpolated,
3636
InverseGamma,
37-
Kumaraswamy,
3837
KroneckerNormal,
38+
Kumaraswamy,
3939
Laplace,
4040
Logistic,
4141
LogitNormal,
@@ -111,7 +111,6 @@ def test_all_distributions_have_moments():
111111
dist_module.discrete.DiscreteWeibull,
112112
dist_module.multivariate.CAR,
113113
dist_module.multivariate.DirichletMultinomial,
114-
dist_module.multivariate.KroneckerNormal,
115114
dist_module.multivariate.Wishart,
116115
}
117116

@@ -1318,26 +1317,28 @@ def normal_sim(rng, mu, sigma, size):
13181317

13191318
assert np.all(np.abs((result - expected_sample_mean) / expected_sample_mean_std) < cutoff)
13201319

1320+
13211321
@pytest.mark.parametrize(
13221322
"mu, covs, size, expected",
13231323
[
13241324
(np.ones(1), [np.identity(1), np.identity(1)], None, np.ones(1)),
1325-
(np.ones(6), [np.identity(2), np.identity(3)], 5, np.ones((5,6))),
1326-
(np.zeros(6), [np.identity(2), np.identity(3)], 6, np.zeros((6,6))),
1327-
(np.zeros(3), [np.identity(3), np.identity(1)], 6, np.zeros((6,3))),
1328-
(np.zeros((4,6)), [np.identity(2),np.identity(3)], 6, np.zeros((6,4,6))),
1325+
(np.ones(6), [np.identity(2), np.identity(3)], 5, np.ones((5, 6))),
1326+
(np.zeros(6), [np.identity(2), np.identity(3)], 6, np.zeros((6, 6))),
1327+
(np.zeros(3), [np.identity(3), np.identity(1)], 6, np.zeros((6, 3))),
13291328
(
1330-
np.array([[1,2,3,4],[3,4,5,6],[6,7,8,9],[7,8,9,1]]),
1329+
np.array([1, 2, 3, 4]),
13311330
[
1332-
np.array([[1., 0.5], [0.5, 2]]),
1333-
np.array([[1., 0.4], [0.4, 2]]),
1331+
np.array([[1.0, 0.5], [0.5, 2]]),
1332+
np.array([[1.0, 0.4], [0.4, 2]]),
13341333
],
13351334
2,
1336-
np.array([
1337-
[[1,2,3,4],[3,4,5,6],[6,7,8,9],[7,8,9,1]],
1338-
[[1,2,3,4],[3,4,5,6],[6,7,8,9],[7,8,9,1]]
1339-
]),
1340-
)
1335+
np.array(
1336+
[
1337+
[1, 2, 3, 4],
1338+
[1, 2, 3, 4],
1339+
]
1340+
),
1341+
),
13411342
],
13421343
)
13431344
def test_kronecker_normal_moments(mu, covs, size, expected):

0 commit comments

Comments
 (0)