Skip to content

Commit d659848

Browse files
committed
Fix wrong ZeroSumNormal logp expression
1 parent 881ef46 commit d659848

File tree

2 files changed

+7
-4
lines changed

2 files changed

+7
-4
lines changed

pymc/distributions/multivariate.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2781,6 +2781,7 @@ def zerosumnormal_logp(op, values, normal_dist, sigma, support_shape, **kwargs):
27812781
(value,) = values
27822782
shape = value.shape
27832783
n_zerosum_axes = op.ndim_supp
2784+
*_, sigma = normal_dist.owner.inputs
27842785

27852786
_deg_free_support_shape = pt.inc_subtensor(shape[-n_zerosum_axes:], -1)
27862787
_full_size = pm.floatX(pt.prod(shape))
@@ -2792,7 +2793,8 @@ def zerosumnormal_logp(op, values, normal_dist, sigma, support_shape, **kwargs):
27922793
]
27932794

27942795
out = pt.sum(
2795-
pm.logp(normal_dist, value) * _degrees_of_freedom / _full_size,
2796+
-0.5 * pt.pow(value / sigma, 2)
2797+
- (pt.log(pt.sqrt(2.0 * np.pi)) + pt.log(sigma)) * _degrees_of_freedom / _full_size,
27962798
axis=tuple(np.arange(-n_zerosum_axes, 0)),
27972799
)
27982800

tests/distributions/test_multivariate.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1596,7 +1596,7 @@ def test_zsn_variance(self, sigma, n):
15961596
(5, 3, None, [-1]),
15971597
(2, 6, None, [-1]),
15981598
(5, (7, 3), None, [-1]),
1599-
(5, (2, 7, 3), 2, [1, 2]),
1599+
(5, (2, 7, 3), 2, [-2, -1]),
16001600
],
16011601
)
16021602
def test_zsn_logp(self, sigma, shape, n_zerosum_axes, mvn_axes):
@@ -1629,8 +1629,9 @@ def logp_norm(value, sigma, axes):
16291629
return np.where(inds, np.sum(-psdet - exp, axis=-1), -np.inf)
16301630

16311631
zsn_dist = pm.ZeroSumNormal.dist(sigma=sigma, shape=shape, n_zerosum_axes=n_zerosum_axes)
1632-
zsn_logp = pm.logp(zsn_dist, value=np.zeros(shape)).eval()
1633-
mvn_logp = logp_norm(value=np.zeros(shape), sigma=sigma, axes=mvn_axes)
1632+
zsn_draws = pm.draw(zsn_dist, 100)
1633+
zsn_logp = pm.logp(zsn_dist, value=zsn_draws).eval()
1634+
mvn_logp = logp_norm(value=zsn_draws, sigma=sigma, axes=mvn_axes)
16341635

16351636
np.testing.assert_allclose(zsn_logp, mvn_logp)
16361637

0 commit comments

Comments
 (0)