Skip to content

Commit dc92865

Browse files
authored
Add MvStudentT and MatrixNormal moment (#5173)
1 parent 64c1464 commit dc92865

File tree

2 files changed

+62
-1
lines changed

2 files changed

+62
-1
lines changed

pymc/distributions/multivariate.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,13 @@ def dist(cls, nu, Sigma=None, mu=None, cov=None, tau=None, chol=None, lower=True
354354
assert_negative_support(nu, "nu", "MvStudentT")
355355
return super().dist([nu, mu, cov], **kwargs)
356356

357+
def get_moment(rv, size, nu, mu, cov):
358+
moment = mu
359+
if not rv_size_is_none(size):
360+
moment_size = at.concatenate([size, moment.shape])
361+
moment = at.full(moment_size, moment)
362+
return moment
363+
357364
def logp(value, nu, mu, cov):
358365
"""
359366
Calculate log-probability of Multivariate Student's T distribution
@@ -692,7 +699,7 @@ def dist(cls, eta, cutpoints, n, *args, **kwargs):
692699

693700

694701
class OrderedMultinomial:
695-
R"""
702+
r"""
696703
Wrapper class for Ordered Multinomial distributions.
697704
698705
Useful for regression on ordinal data whose values range
@@ -1727,6 +1734,13 @@ def dist(
17271734

17281735
return super().dist([mu, rowchol_cov, colchol_cov], **kwargs)
17291736

1737+
def get_moment(rv, size, mu, rowchol, colchol):
1738+
output_shape = (rowchol.shape[0], colchol.shape[0])
1739+
if not rv_size_is_none(size):
1740+
output_shape = at.concatenate([size, output_shape])
1741+
moment = at.full(output_shape, mu)
1742+
return moment
1743+
17301744
def logp(value, mu, rowchol, colchol):
17311745
"""
17321746
Calculate log-probability of Matrix-valued Normal distribution

pymc/tests/test_distributions_moments.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,9 @@
3939
Logistic,
4040
LogitNormal,
4141
LogNormal,
42+
MatrixNormal,
4243
Moyal,
44+
MvStudentT,
4345
NegativeBinomial,
4446
Normal,
4547
Pareto,
@@ -830,6 +832,33 @@ def test_moyal_moment(mu, sigma, size, expected):
830832
assert_moment_is_expected(model, expected)
831833

832834

835+
rand1d = np.random.rand(2)
836+
rand2d = np.random.rand(2, 3)
837+
838+
839+
@pytest.mark.parametrize(
840+
"nu, mu, cov, size, expected",
841+
[
842+
(2, np.ones(1), np.eye(1), None, np.ones(1)),
843+
(2, rand1d, np.eye(2), None, rand1d),
844+
(2, rand1d, np.eye(2), 2, np.full((2, 2), rand1d)),
845+
(2, rand1d, np.eye(2), (2, 5), np.full((2, 5, 2), rand1d)),
846+
(2, rand2d, np.eye(3), None, rand2d),
847+
(2, rand2d, np.eye(3), 2, np.full((2, 2, 3), rand2d)),
848+
(2, rand2d, np.eye(3), (2, 5), np.full((2, 5, 2, 3), rand2d)),
849+
],
850+
)
851+
def test_mvstudentt_moment(nu, mu, cov, size, expected):
852+
with Model() as model:
853+
MvStudentT("x", nu=nu, mu=mu, cov=cov, size=size)
854+
assert_moment_is_expected(model, expected)
855+
856+
857+
def check_matrixnormal_moment(mu, rowchol, colchol, size, expected):
858+
with Model() as model:
859+
MatrixNormal("x", mu=mu, rowchol=rowchol, colchol=colchol, size=size)
860+
861+
833862
@pytest.mark.parametrize(
834863
"alpha, mu, sigma, size, expected",
835864
[
@@ -886,6 +915,24 @@ def test_asymmetriclaplace_moment(b, kappa, mu, size, expected):
886915
assert_moment_is_expected(model, expected)
887916

888917

918+
@pytest.mark.parametrize(
919+
"mu, rowchol, colchol, size, expected",
920+
[
921+
(np.ones((1, 1)), np.eye(1), np.eye(1), None, np.ones((1, 1))),
922+
(np.ones((1, 1)), np.eye(2), np.eye(3), None, np.ones((2, 3))),
923+
(rand2d, np.eye(2), np.eye(3), None, rand2d),
924+
(rand2d, np.eye(2), np.eye(3), 2, np.full((2, 2, 3), rand2d)),
925+
(rand2d, np.eye(2), np.eye(3), (2, 5), np.full((2, 5, 2, 3), rand2d)),
926+
],
927+
)
928+
def test_matrixnormal_moment(mu, rowchol, colchol, size, expected):
929+
if size is None:
930+
check_matrixnormal_moment(mu, rowchol, colchol, size, expected)
931+
else:
932+
with pytest.raises(NotImplementedError):
933+
check_matrixnormal_moment(mu, rowchol, colchol, size, expected)
934+
935+
889936
@pytest.mark.parametrize(
890937
"nu, sigma, size, expected",
891938
[

0 commit comments

Comments
 (0)