Skip to content

Commit c983aff

Browse files
committed
Default zero mu for MvNormal and MvStudentT
1 parent 8a0bf14 commit c983aff

File tree

2 files changed

+12
-4
lines changed

2 files changed

+12
-4
lines changed

pymc/distributions/multivariate.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,7 @@ class MvNormal(Continuous):
238238
rv_op = multivariate_normal
239239

240240
@classmethod
241-
def dist(cls, mu, cov=None, tau=None, chol=None, lower=True, **kwargs):
241+
def dist(cls, mu=0, cov=None, *, tau=None, chol=None, lower=True, **kwargs):
242242
mu = pt.as_tensor_variable(mu)
243243
cov = quaddist_matrix(cov, chol, tau, lower)
244244
# PyTensor is stricter about the shape of mu, than PyMC used to be
@@ -358,7 +358,7 @@ class MvStudentT(Continuous):
358358
rv_op = mv_studentt
359359

360360
@classmethod
361-
def dist(cls, nu, *, Sigma=None, mu, scale=None, tau=None, chol=None, lower=True, **kwargs):
361+
def dist(cls, nu, *, Sigma=None, mu=0, scale=None, tau=None, chol=None, lower=True, **kwargs):
362362
cov = kwargs.pop("cov", None)
363363
if cov is not None:
364364
warnings.warn(

tests/distributions/test_multivariate.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2300,7 +2300,11 @@ def test_mvnormal_no_cholesky_in_model_logp():
23002300

23012301

23022302
def test_mvnormal_mu_convenience():
2303-
"""Test that mu is broadcasted to the length of cov"""
2303+
"""Test that mu is broadcasted to the length of cov and provided a default of zero"""
2304+
x = pm.MvNormal.dist(cov=np.eye(3))
2305+
mu = x.owner.inputs[3]
2306+
np.testing.assert_allclose(mu.eval(), np.zeros((3,)))
2307+
23042308
x = pm.MvNormal.dist(mu=1, cov=np.eye(3))
23052309
mu = x.owner.inputs[3]
23062310
np.testing.assert_allclose(mu.eval(), np.ones((3,)))
@@ -2325,7 +2329,11 @@ def test_mvnormal_mu_convenience():
23252329

23262330

23272331
def test_mvstudentt_mu_convenience():
2328-
"""Test that mu is broadcasted to the length of scale"""
2332+
"""Test that mu is broadcasted to the length of scale and provided a default of zero"""
2333+
x = pm.MvStudentT.dist(nu=4, scale=np.eye(3))
2334+
mu = x.owner.inputs[4]
2335+
np.testing.assert_allclose(mu.eval(), np.zeros((3,)))
2336+
23292337
x = pm.MvStudentT.dist(nu=4, mu=1, scale=np.eye(3))
23302338
mu = x.owner.inputs[4]
23312339
np.testing.assert_allclose(mu.eval(), np.ones((3,)))

0 commit comments

Comments
 (0)