Skip to content

Commit 74160e9

Browse files
committed
Refactor TruncatedNormal get_moment test
1 parent b243827 commit 74160e9

File tree

2 files changed

+6
-7
lines changed

2 files changed

+6
-7
lines changed

pymc/distributions/continuous.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -717,16 +717,16 @@ def dist(
717717
def get_moment(rv, size, mu, sigma, lower, upper):
718718
mu, _, lower, upper = at.broadcast_arrays(mu, sigma, lower, upper)
719719
moment = at.switch(
720-
at.isinf(lower),
720+
at.eq(lower, -np.inf),
721721
at.switch(
722-
at.isinf(upper),
722+
at.eq(upper, np.inf),
723723
# lower = -inf, upper = inf
724724
mu,
725725
# lower = -inf, upper = x
726726
upper - 1,
727727
),
728728
at.switch(
729-
at.isinf(upper),
729+
at.eq(upper, np.inf),
730730
# lower = x, upper = inf
731731
lower + 1,
732732
# lower = x, upper = x

pymc/tests/test_distributions_moments.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -249,14 +249,13 @@ def test_halfstudentt_moment(nu, sigma, size, expected):
249249
assert_moment_is_expected(model, expected)
250250

251251

252-
@pytest.mark.skip(reason="aeppl interval transform fails when both edges are None")
253252
@pytest.mark.parametrize(
254253
"mu, sigma, lower, upper, size, expected",
255254
[
256-
(0.9, 1, -1, 1, None, 0),
257-
(0.9, 1, -np.inf, np.inf, 5, np.full(5, 0.9)),
255+
(0.9, 1, -5, 5, None, 0),
256+
(1, np.ones(5), -10, np.inf, None, np.full(5, -9)),
258257
(np.arange(5), 1, None, 10, (2, 5), np.full((2, 5), 9)),
259-
(1, np.ones(5), -10, np.inf, None, np.full((2, 5), -9)),
258+
(1, 1, [-np.inf, -np.inf, -np.inf], 10, None, np.full(3, 9)),
260259
],
261260
)
262261
def test_truncatednormal_moment(mu, sigma, lower, upper, size, expected):

0 commit comments

Comments
 (0)