Skip to content

Commit 927f81c

Browse files
committed
Fix MatrixNormal.random
1 parent 0402aab commit 927f81c

File tree

2 files changed

+34
-22
lines changed

2 files changed

+34
-22
lines changed

pymc3/distributions/multivariate.py

Lines changed: 16 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1445,7 +1445,7 @@ class MatrixNormal(Continuous):
14451445
14461446
.. math::
14471447
f(x \mid \mu, U, V) =
1448-
\frac{1}{(2\pi |U|^n |V|^m)^{1/2}}
1448+
\frac{1}{(2\pi^{m n} |U|^n |V|^m)^{1/2}}
14491449
\exp\left\{
14501450
-\frac{1}{2} \mathrm{Tr}[ V^{-1} (x-\mu)^{\prime} U^{-1} (x-\mu)]
14511451
\right\}
@@ -1637,27 +1637,21 @@ def random(self, point=None, size=None):
16371637
mu, colchol, rowchol = draw_values(
16381638
[self.mu, self.colchol_cov, self.rowchol_cov], point=point, size=size
16391639
)
1640-
if size is None:
1641-
size = ()
1642-
if size in (None, ()):
1643-
standard_normal = np.random.standard_normal((self.shape[0], colchol.shape[-1]))
1644-
samples = mu + np.matmul(rowchol, np.matmul(standard_normal, colchol.T))
1645-
else:
1646-
samples = []
1647-
size = tuple(np.atleast_1d(size))
1648-
if mu.shape == tuple(self.shape):
1649-
for _ in range(np.prod(size)):
1650-
standard_normal = np.random.standard_normal((self.shape[0], colchol.shape[-1]))
1651-
samples.append(mu + np.matmul(rowchol, np.matmul(standard_normal, colchol.T)))
1652-
else:
1653-
for j in range(np.prod(size)):
1654-
standard_normal = np.random.standard_normal(
1655-
(self.shape[0], colchol[j].shape[-1])
1656-
)
1657-
samples.append(
1658-
mu[j] + np.matmul(rowchol[j], np.matmul(standard_normal, colchol[j].T))
1659-
)
1660-
samples = np.array(samples).reshape(size + tuple(self.shape))
1640+
size = to_tuple(size)
1641+
dist_shape = to_tuple(self.shape)
1642+
output_shape = size + dist_shape
1643+
1644+
# Broadcasting all parameters
1645+
mu = broadcast_dist_samples_to(to_shape=output_shape, samples=[mu], size=size)[0]
1646+
rowchol = np.broadcast_to(rowchol, shape=size + rowchol.shape[-2:])
1647+
1648+
colchol = np.broadcast_to(colchol, shape=size + colchol.shape[-2:])
1649+
perm = np.arange(len(output_shape))
1650+
perm[-2:] = [perm[-1], perm[-2]]
1651+
colchol = np.transpose(colchol, perm)
1652+
1653+
standard_normal = np.random.standard_normal(output_shape)
1654+
samples = mu + np.matmul(rowchol, np.matmul(standard_normal, colchol))
16611655
return samples
16621656

16631657
def _trquaddist(self, value):

pymc3/tests/test_distributions_random.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1675,3 +1675,21 @@ def test_issue_3706(self):
16751675
prior_pred = pm.sample_prior_predictive(1)
16761676

16771677
assert prior_pred["X"].shape == (1, N, 2)
1678+
1679+
1680+
def test_issue_3585():
1681+
K = 3
1682+
D = 15
1683+
mu_0 = np.zeros((D, K))
1684+
lambd = 1.0
1685+
with pm.Model() as model:
1686+
sd_dist = pm.HalfCauchy.dist(beta=2.5)
1687+
packedL = pm.LKJCholeskyCov(f"packedL", eta=2, n=D, sd_dist=sd_dist)
1688+
L = pm.expand_packed_triangular(D, packedL, lower=True)
1689+
Sigma = pm.Deterministic(f"Sigma", L.dot(L.T)) # D x D covariance
1690+
mu = pm.MatrixNormal(
1691+
f"mu", mu=mu_0, rowcov=(1 / lambd) * Sigma, colcov=np.eye(K), shape=(D, K)
1692+
)
1693+
prior = pm.sample_prior_predictive(2)
1694+
1695+
assert prior["mu"].shape == (2, D, K)

0 commit comments

Comments
 (0)