Skip to content

Commit 421cac1

Browse files
committed
fix str repr for KroneckerNormal and MatrixNormal
1 parent 3a9c835 commit 421cac1

File tree

2 files changed

+23
-0
lines changed

2 files changed

+23
-0
lines changed

pymc3/distributions/multivariate.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1449,6 +1449,8 @@ def logp(self, x):
14491449
broadcast_conditions=False,
14501450
)
14511451

1452+
def _distr_parameters_for_repr(self):
1453+
return ["eta", "n"]
14521454

14531455
class MatrixNormal(Continuous):
14541456
R"""
@@ -1712,6 +1714,10 @@ def logp(self, value):
17121714
norm = -0.5 * m * n * pm.floatX(np.log(2 * np.pi))
17131715
return norm - 0.5 * trquaddist - m * half_collogdet - n * half_rowlogdet
17141716

1717+
def _distr_parameters_for_repr(self):
1718+
mapping = {"tau": "tau", "cov": "cov", "chol": "chol_cov"}
1719+
return ["mu", "row" + mapping[self._rowcov_type], "col" + mapping[self._colcov_type]]
1720+
17151721

17161722
class KroneckerNormal(Continuous):
17171723
R"""
@@ -1954,3 +1960,6 @@ def logp(self, value):
19541960
"""
19551961
quad, logdet = self._quaddist(value)
19561962
return -(quad + logdet + self.N * tt.log(2 * np.pi)) / 2.0
1963+
1964+
def _distr_parameters_for_repr(self):
1965+
return ["mu"]

pymc3/tests/test_distributions.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1782,8 +1782,18 @@ def setup_class(self):
17821782
# add a bounded variable as well
17831783
bound_var = Bound(Normal, lower=1.0)("bound_var", mu=0, sigma=10)
17841784

1785+
# KroneckerNormal
1786+
n, m = 3, 4
1787+
covs = [np.eye(n), np.eye(m)]
1788+
kron_normal = KroneckerNormal('kron_normal', mu=np.zeros(n*m), covs=covs, shape=n*m)
1789+
1790+
# MatrixNormal
1791+
matrix_normal = MatrixNormal('mat_normal', mu=np.random.normal(size=n), rowcov=np.eye(n),
1792+
colchol=np.linalg.cholesky(np.eye(n)), shape=(n, n))
1793+
17851794
# Likelihood (sampling distribution) of observations
17861795
Y_obs = Normal("Y_obs", mu=mu, sigma=sigma, observed=Y)
1796+
17871797
self.distributions = [alpha, sigma, mu, b, Z, Y_obs, bound_var]
17881798
self.expected_latex = (
17891799
r"$\text{alpha} \sim \text{Normal}(\mathit{mu}=0.0,~\mathit{sigma}=10.0)$",
@@ -1793,6 +1803,8 @@ def setup_class(self):
17931803
r"$\text{Z} \sim \text{MvNormal}(\mathit{mu}=array,~\mathit{chol_cov}=array)$",
17941804
r"$\text{Y_obs} \sim \text{Normal}(\mathit{mu}=\text{mu},~\mathit{sigma}=f(\text{sigma}))$",
17951805
r"$\text{bound_var} \sim \text{Bound}(\mathit{lower}=1.0,~\mathit{upper}=\text{None})$ -- \text{Normal}(\mathit{mu}=0.0,~\mathit{sigma}=10.0)$",
1806+
r"$\text{kron_normal} \sim \text{KroneckerNormal}(\mathit{mu}=array)$",
1807+
r"$\text{mat_normal} \sim \text{MatrixNormal}(\mathit{mu}=array,~\mathit{rowcov}=array,~\mathit{colchol_cov}=array)$",
17961808
)
17971809
self.expected_str = (
17981810
r"alpha ~ Normal(mu=0.0, sigma=10.0)",
@@ -1802,6 +1814,8 @@ def setup_class(self):
18021814
r"Z ~ MvNormal(mu=array, chol_cov=array)",
18031815
r"Y_obs ~ Normal(mu=mu, sigma=f(sigma))",
18041816
r"bound_var ~ Bound(lower=1.0, upper=None)-Normal(mu=0.0, sigma=10.0)",
1817+
r"kron_normal ~ KroneckerNormal(mu=array)",
1818+
r"mat_normal ~ MatrixNormal(mu=array, rowcov=array, colchol_cov=array)",
18051819
)
18061820

18071821
def test__repr_latex_(self):

0 commit comments

Comments
 (0)