Skip to content

Commit 05113bb

Browse files
committed
Assert logp of moment is finite
1 parent 74160e9 commit 05113bb

File tree

1 file changed

+28
-18
lines changed

1 file changed

+28
-18
lines changed

pymc/tests/test_distributions_moments.py

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
ZeroInflatedPoisson,
6464
)
6565
from pymc.distributions.distribution import _get_moment, get_moment
66+
from pymc.distributions.logprob import logpt
6667
from pymc.distributions.multivariate import MvNormal
6768
from pymc.distributions.shape_utils import rv_size_is_none, to_tuple
6869
from pymc.initial_point import make_initial_point_fn
@@ -145,20 +146,25 @@ def test_rv_size_is_none():
145146
assert not rv_size_is_none(rv.owner.inputs[1])
146147

147148

148-
def assert_moment_is_expected(model, expected):
149+
def assert_moment_is_expected(model, expected, check_finite_logp=True):
149150
fn = make_initial_point_fn(
150151
model=model,
151152
return_transformed=False,
152153
default_strategy="moment",
153154
)
154-
result = fn(0)["x"]
155+
moment = fn(0)["x"]
155156
expected = np.asarray(expected)
156157
try:
157158
random_draw = model["x"].eval()
158159
except NotImplementedError:
159-
random_draw = result
160-
assert result.shape == expected.shape == random_draw.shape
161-
assert np.allclose(result, expected)
160+
random_draw = moment
161+
162+
assert moment.shape == expected.shape == random_draw.shape
163+
assert np.allclose(moment, expected)
164+
165+
if check_finite_logp:
166+
logp_moment = logpt(model["x"], at.constant(moment), transformed=False).eval()
167+
assert np.isfinite(logp_moment)
162168

163169

164170
@pytest.mark.parametrize(
@@ -430,11 +436,11 @@ def test_lognormal_moment(mu, sigma, size, expected):
430436
[
431437
(1, None, 1),
432438
(1, 5, np.ones(5)),
433-
(np.arange(5), None, np.arange(5)),
439+
(np.arange(1, 5), None, np.arange(1, 5)),
434440
(
435-
np.arange(5),
436-
(2, 5),
437-
np.full((2, 5), np.arange(5)),
441+
np.arange(1, 5),
442+
(2, 4),
443+
np.full((2, 4), np.arange(1, 5)),
438444
),
439445
],
440446
)
@@ -676,11 +682,11 @@ def test_logistic_moment(mu, s, size, expected):
676682
@pytest.mark.parametrize(
677683
"mu, nu, sigma, size, expected",
678684
[
679-
(1, 1, None, None, 2),
685+
(1, 1, 1, None, 2),
680686
(1, 1, np.ones((2, 5)), None, np.full([2, 5], 2)),
681-
(1, 1, None, 5, np.full(5, 2)),
682-
(1, np.arange(1, 6), None, None, np.arange(2, 7)),
683-
(1, np.arange(1, 6), None, (2, 5), np.full((2, 5), np.arange(2, 7))),
687+
(1, 1, 3, 5, np.full(5, 2)),
688+
(1, np.arange(1, 6), 5, None, np.arange(2, 7)),
689+
(1, np.arange(1, 6), 1, (2, 5), np.full((2, 5), np.arange(2, 7))),
684690
],
685691
)
686692
def test_exgaussian_moment(mu, nu, sigma, size, expected):
@@ -920,8 +926,10 @@ def test_interpolated_moment(x_points, pdf_points, size, expected):
920926
)
921927
def test_mv_normal_moment(mu, cov, size, expected):
922928
with Model() as model:
923-
MvNormal("x", mu=mu, cov=cov, size=size)
924-
assert_moment_is_expected(model, expected)
929+
x = MvNormal("x", mu=mu, cov=cov, size=size)
930+
931+
# MvNormal logp is only impemented for up to 2D variables
932+
assert_moment_is_expected(model, expected, check_finite_logp=x.ndim < 3)
925933

926934

927935
@pytest.mark.parametrize(
@@ -957,8 +965,10 @@ def test_moyal_moment(mu, sigma, size, expected):
957965
)
958966
def test_mvstudentt_moment(nu, mu, cov, size, expected):
959967
with Model() as model:
960-
MvStudentT("x", nu=nu, mu=mu, cov=cov, size=size)
961-
assert_moment_is_expected(model, expected)
968+
x = MvStudentT("x", nu=nu, mu=mu, cov=cov, size=size)
969+
970+
# MvStudentT logp is only impemented for up to 2D variables
971+
assert_moment_is_expected(model, expected, check_finite_logp=x.ndim < 3)
962972

963973

964974
def check_matrixnormal_moment(mu, rowchol, colchol, size, expected):
@@ -1094,7 +1104,7 @@ def test_density_dist_default_moment_univariate(get_moment, size, expected):
10941104
get_moment = lambda rv, size, *rv_inputs: 5 * at.ones(size, dtype=rv.dtype)
10951105
with Model() as model:
10961106
DensityDist("x", get_moment=get_moment, size=size)
1097-
assert_moment_is_expected(model, expected)
1107+
assert_moment_is_expected(model, expected, check_finite_logp=False)
10981108

10991109

11001110
@pytest.mark.parametrize("size", [(), (2,), (3, 2)], ids=str)

0 commit comments

Comments
 (0)