|
63 | 63 | ZeroInflatedPoisson,
|
64 | 64 | )
|
65 | 65 | from pymc.distributions.distribution import _get_moment, get_moment
|
| 66 | +from pymc.distributions.logprob import logpt |
66 | 67 | from pymc.distributions.multivariate import MvNormal
|
67 | 68 | from pymc.distributions.shape_utils import rv_size_is_none, to_tuple
|
68 | 69 | from pymc.initial_point import make_initial_point_fn
|
@@ -145,20 +146,25 @@ def test_rv_size_is_none():
|
145 | 146 | assert not rv_size_is_none(rv.owner.inputs[1])
|
146 | 147 |
|
147 | 148 |
|
148 |
| -def assert_moment_is_expected(model, expected): |
| 149 | +def assert_moment_is_expected(model, expected, check_finite_logp=True): |
149 | 150 | fn = make_initial_point_fn(
|
150 | 151 | model=model,
|
151 | 152 | return_transformed=False,
|
152 | 153 | default_strategy="moment",
|
153 | 154 | )
|
154 |
| - result = fn(0)["x"] |
| 155 | + moment = fn(0)["x"] |
155 | 156 | expected = np.asarray(expected)
|
156 | 157 | try:
|
157 | 158 | random_draw = model["x"].eval()
|
158 | 159 | except NotImplementedError:
|
159 |
| - random_draw = result |
160 |
| - assert result.shape == expected.shape == random_draw.shape |
161 |
| - assert np.allclose(result, expected) |
| 160 | + random_draw = moment |
| 161 | + |
| 162 | + assert moment.shape == expected.shape == random_draw.shape |
| 163 | + assert np.allclose(moment, expected) |
| 164 | + |
| 165 | + if check_finite_logp: |
| 166 | + logp_moment = logpt(model["x"], at.constant(moment), transformed=False).eval() |
| 167 | + assert np.isfinite(logp_moment) |
162 | 168 |
|
163 | 169 |
|
164 | 170 | @pytest.mark.parametrize(
|
@@ -430,11 +436,11 @@ def test_lognormal_moment(mu, sigma, size, expected):
|
430 | 436 | [
|
431 | 437 | (1, None, 1),
|
432 | 438 | (1, 5, np.ones(5)),
|
433 |
| - (np.arange(5), None, np.arange(5)), |
| 439 | + (np.arange(1, 5), None, np.arange(1, 5)), |
434 | 440 | (
|
435 |
| - np.arange(5), |
436 |
| - (2, 5), |
437 |
| - np.full((2, 5), np.arange(5)), |
| 441 | + np.arange(1, 5), |
| 442 | + (2, 4), |
| 443 | + np.full((2, 4), np.arange(1, 5)), |
438 | 444 | ),
|
439 | 445 | ],
|
440 | 446 | )
|
@@ -676,11 +682,11 @@ def test_logistic_moment(mu, s, size, expected):
|
676 | 682 | @pytest.mark.parametrize(
|
677 | 683 | "mu, nu, sigma, size, expected",
|
678 | 684 | [
|
679 |
| - (1, 1, None, None, 2), |
| 685 | + (1, 1, 1, None, 2), |
680 | 686 | (1, 1, np.ones((2, 5)), None, np.full([2, 5], 2)),
|
681 |
| - (1, 1, None, 5, np.full(5, 2)), |
682 |
| - (1, np.arange(1, 6), None, None, np.arange(2, 7)), |
683 |
| - (1, np.arange(1, 6), None, (2, 5), np.full((2, 5), np.arange(2, 7))), |
| 687 | + (1, 1, 3, 5, np.full(5, 2)), |
| 688 | + (1, np.arange(1, 6), 5, None, np.arange(2, 7)), |
| 689 | + (1, np.arange(1, 6), 1, (2, 5), np.full((2, 5), np.arange(2, 7))), |
684 | 690 | ],
|
685 | 691 | )
|
686 | 692 | def test_exgaussian_moment(mu, nu, sigma, size, expected):
|
@@ -920,8 +926,10 @@ def test_interpolated_moment(x_points, pdf_points, size, expected):
|
920 | 926 | )
|
921 | 927 | def test_mv_normal_moment(mu, cov, size, expected):
|
922 | 928 | with Model() as model:
|
923 |
| - MvNormal("x", mu=mu, cov=cov, size=size) |
924 |
| - assert_moment_is_expected(model, expected) |
| 929 | + x = MvNormal("x", mu=mu, cov=cov, size=size) |
| 930 | + |
| 931 | + # MvNormal logp is only impemented for up to 2D variables |
| 932 | + assert_moment_is_expected(model, expected, check_finite_logp=x.ndim < 3) |
925 | 933 |
|
926 | 934 |
|
927 | 935 | @pytest.mark.parametrize(
|
@@ -957,8 +965,10 @@ def test_moyal_moment(mu, sigma, size, expected):
|
957 | 965 | )
|
958 | 966 | def test_mvstudentt_moment(nu, mu, cov, size, expected):
|
959 | 967 | with Model() as model:
|
960 |
| - MvStudentT("x", nu=nu, mu=mu, cov=cov, size=size) |
961 |
| - assert_moment_is_expected(model, expected) |
| 968 | + x = MvStudentT("x", nu=nu, mu=mu, cov=cov, size=size) |
| 969 | + |
| 970 | + # MvStudentT logp is only impemented for up to 2D variables |
| 971 | + assert_moment_is_expected(model, expected, check_finite_logp=x.ndim < 3) |
962 | 972 |
|
963 | 973 |
|
964 | 974 | def check_matrixnormal_moment(mu, rowchol, colchol, size, expected):
|
@@ -1094,7 +1104,7 @@ def test_density_dist_default_moment_univariate(get_moment, size, expected):
|
1094 | 1104 | get_moment = lambda rv, size, *rv_inputs: 5 * at.ones(size, dtype=rv.dtype)
|
1095 | 1105 | with Model() as model:
|
1096 | 1106 | DensityDist("x", get_moment=get_moment, size=size)
|
1097 |
| - assert_moment_is_expected(model, expected) |
| 1107 | + assert_moment_is_expected(model, expected, check_finite_logp=False) |
1098 | 1108 |
|
1099 | 1109 |
|
1100 | 1110 | @pytest.mark.parametrize("size", [(), (2,), (3, 2)], ids=str)
|
|
0 commit comments