@@ -1445,7 +1445,7 @@ class MatrixNormal(Continuous):
1445
1445
1446
1446
.. math::
1447
1447
f(x \mid \mu, U, V) =
1448
- \frac{1}{(2\pi |U|^n |V|^m)^{1/2}}
1448
+ \frac{1}{(2\pi^{m n} |U|^n |V|^m)^{1/2}}
1449
1449
\exp\left\{
1450
1450
-\frac{1}{2} \mathrm{Tr}[ V^{-1} (x-\mu)^{\prime} U^{-1} (x-\mu)]
1451
1451
\right\}
@@ -1637,27 +1637,21 @@ def random(self, point=None, size=None):
1637
1637
mu , colchol , rowchol = draw_values (
1638
1638
[self .mu , self .colchol_cov , self .rowchol_cov ], point = point , size = size
1639
1639
)
1640
- if size is None :
1641
- size = ()
1642
- if size in (None , ()):
1643
- standard_normal = np .random .standard_normal ((self .shape [0 ], colchol .shape [- 1 ]))
1644
- samples = mu + np .matmul (rowchol , np .matmul (standard_normal , colchol .T ))
1645
- else :
1646
- samples = []
1647
- size = tuple (np .atleast_1d (size ))
1648
- if mu .shape == tuple (self .shape ):
1649
- for _ in range (np .prod (size )):
1650
- standard_normal = np .random .standard_normal ((self .shape [0 ], colchol .shape [- 1 ]))
1651
- samples .append (mu + np .matmul (rowchol , np .matmul (standard_normal , colchol .T )))
1652
- else :
1653
- for j in range (np .prod (size )):
1654
- standard_normal = np .random .standard_normal (
1655
- (self .shape [0 ], colchol [j ].shape [- 1 ])
1656
- )
1657
- samples .append (
1658
- mu [j ] + np .matmul (rowchol [j ], np .matmul (standard_normal , colchol [j ].T ))
1659
- )
1660
- samples = np .array (samples ).reshape (size + tuple (self .shape ))
1640
+ size = to_tuple (size )
1641
+ dist_shape = to_tuple (self .shape )
1642
+ output_shape = size + dist_shape
1643
+
1644
+ # Broadcasting all parameters
1645
+ mu = broadcast_dist_samples_to (to_shape = output_shape , samples = [mu ], size = size )[0 ]
1646
+ rowchol = np .broadcast_to (rowchol , shape = size + rowchol .shape [- 2 :])
1647
+
1648
+ colchol = np .broadcast_to (colchol , shape = size + colchol .shape [- 2 :])
1649
+ perm = np .arange (len (output_shape ))
1650
+ perm [- 2 :] = [perm [- 1 ], perm [- 2 ]]
1651
+ colchol = np .transpose (colchol , perm )
1652
+
1653
+ standard_normal = np .random .standard_normal (output_shape )
1654
+ samples = mu + np .matmul (rowchol , np .matmul (standard_normal , colchol ))
1661
1655
return samples
1662
1656
1663
1657
def _trquaddist (self , value ):
0 commit comments