|
34 | 34 | HyperGeometric,
|
35 | 35 | Interpolated,
|
36 | 36 | InverseGamma,
|
37 |
| - Kumaraswamy, |
38 | 37 | KroneckerNormal,
|
| 38 | + Kumaraswamy, |
39 | 39 | Laplace,
|
40 | 40 | Logistic,
|
41 | 41 | LogitNormal,
|
@@ -111,7 +111,6 @@ def test_all_distributions_have_moments():
|
111 | 111 | dist_module.discrete.DiscreteWeibull,
|
112 | 112 | dist_module.multivariate.CAR,
|
113 | 113 | dist_module.multivariate.DirichletMultinomial,
|
114 |
| - dist_module.multivariate.KroneckerNormal, |
115 | 114 | dist_module.multivariate.Wishart,
|
116 | 115 | }
|
117 | 116 |
|
@@ -1318,26 +1317,28 @@ def normal_sim(rng, mu, sigma, size):
|
1318 | 1317 |
|
1319 | 1318 | assert np.all(np.abs((result - expected_sample_mean) / expected_sample_mean_std) < cutoff)
|
1320 | 1319 |
|
| 1320 | + |
1321 | 1321 | @pytest.mark.parametrize(
|
1322 | 1322 | "mu, covs, size, expected",
|
1323 | 1323 | [
|
1324 | 1324 | (np.ones(1), [np.identity(1), np.identity(1)], None, np.ones(1)),
|
1325 |
| - (np.ones(6), [np.identity(2), np.identity(3)], 5, np.ones((5,6))), |
1326 |
| - (np.zeros(6), [np.identity(2), np.identity(3)], 6, np.zeros((6,6))), |
1327 |
| - (np.zeros(3), [np.identity(3), np.identity(1)], 6, np.zeros((6,3))), |
1328 |
| - (np.zeros((4,6)), [np.identity(2),np.identity(3)], 6, np.zeros((6,4,6))), |
| 1325 | + (np.ones(6), [np.identity(2), np.identity(3)], 5, np.ones((5, 6))), |
| 1326 | + (np.zeros(6), [np.identity(2), np.identity(3)], 6, np.zeros((6, 6))), |
| 1327 | + (np.zeros(3), [np.identity(3), np.identity(1)], 6, np.zeros((6, 3))), |
1329 | 1328 | (
|
1330 |
| - np.array([[1,2,3,4],[3,4,5,6],[6,7,8,9],[7,8,9,1]]), |
| 1329 | + np.array([1, 2, 3, 4]), |
1331 | 1330 | [
|
1332 |
| - np.array([[1., 0.5], [0.5, 2]]), |
1333 |
| - np.array([[1., 0.4], [0.4, 2]]), |
| 1331 | + np.array([[1.0, 0.5], [0.5, 2]]), |
| 1332 | + np.array([[1.0, 0.4], [0.4, 2]]), |
1334 | 1333 | ],
|
1335 | 1334 | 2,
|
1336 |
| - np.array([ |
1337 |
| - [[1,2,3,4],[3,4,5,6],[6,7,8,9],[7,8,9,1]], |
1338 |
| - [[1,2,3,4],[3,4,5,6],[6,7,8,9],[7,8,9,1]] |
1339 |
| - ]), |
1340 |
| - ) |
| 1335 | + np.array( |
| 1336 | + [ |
| 1337 | + [1, 2, 3, 4], |
| 1338 | + [1, 2, 3, 4], |
| 1339 | + ] |
| 1340 | + ), |
| 1341 | + ), |
1341 | 1342 | ],
|
1342 | 1343 | )
|
1343 | 1344 | def test_kronecker_normal_moments(mu, covs, size, expected):
|
|
0 commit comments