Skip to content

Commit 827918b

Browse files
Improve multinomial moment (#6933)
Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com>
1 parent cb64480 commit 827918b

File tree

2 files changed

+34
-9
lines changed

2 files changed

+34
-9
lines changed

pymc/distributions/multivariate.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -540,14 +540,24 @@ def dist(cls, n, p, *args, **kwargs):
540540

541541
def moment(rv, size, n, p):
542542
n = pt.shape_padright(n)
543-
mode = pt.round(n * p)
543+
mean = n * p
544+
mode = pt.round(mean)
545+
# Add correction term between n and approximation.
546+
# We modify highest expected entry to minimize chances of negative values.
544547
diff = n - pt.sum(mode, axis=-1, keepdims=True)
545-
inc_bool_arr = pt.abs(diff) > 0
546-
mode = pt.inc_subtensor(mode[inc_bool_arr.nonzero()], diff[inc_bool_arr.nonzero()])
548+
max_elem_idx = pt.argmax(mean, axis=-1, keepdims=True)
549+
mode = pt.inc_subtensor(
550+
pt.take_along_axis(mode, max_elem_idx, axis=-1),
551+
diff,
552+
)
547553
if not rv_size_is_none(size):
548554
output_size = pt.concatenate([size, [p.shape[-1]]])
549555
mode = pt.full(output_size, mode)
550-
return mode
556+
return Assert(
557+
"Negative value in computed moment of Multinomial."
558+
"It is a known limitation that can arise when the expected largest count is small."
559+
"Please provide an initial value manually."
560+
)(mode, pt.all(mode >= 0))
551561

552562
def logp(value, n, p):
553563
"""

tests/distributions/test_multivariate.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1013,18 +1013,18 @@ class TestMoments:
10131013
[
10141014
(np.array([0.25, 0.25, 0.25, 0.25]), 1, None, np.array([1, 0, 0, 0])),
10151015
(np.array([0.3, 0.6, 0.05, 0.05]), 2, None, np.array([1, 1, 0, 0])),
1016-
(np.array([0.3, 0.6, 0.05, 0.05]), 10, None, np.array([4, 6, 0, 0])),
1016+
(np.array([0.3, 0.6, 0.05, 0.05]), 10, None, np.array([3, 7, 0, 0])),
10171017
(
10181018
np.array([[0.3, 0.6, 0.05, 0.05], [0.25, 0.25, 0.25, 0.25]]),
10191019
10,
10201020
None,
1021-
np.array([[4, 6, 0, 0], [4, 2, 2, 2]]),
1021+
np.array([[3, 7, 0, 0], [4, 2, 2, 2]]),
10221022
),
10231023
(
10241024
np.array([0.3, 0.6, 0.05, 0.05]),
10251025
np.array([2, 10]),
10261026
(1, 2),
1027-
np.array([[[1, 1, 0, 0], [4, 6, 0, 0]]]),
1027+
np.array([[[1, 1, 0, 0], [3, 7, 0, 0]]]),
10281028
),
10291029
(
10301030
np.array([[0.25, 0.25, 0.25, 0.25], [0.26, 0.26, 0.26, 0.22]]),
@@ -1038,6 +1038,21 @@ class TestMoments:
10381038
(3, 2),
10391039
np.full((3, 2, 4), [[1, 0, 0, 0], [2, 3, 3, 2]]),
10401040
),
1041+
(
1042+
np.array([0.0, 0.25, 0.25, 0.25, 0.25]),
1043+
1,
1044+
None,
1045+
np.array([0, 1, 0, 0, 0]),
1046+
),
1047+
pytest.param(
1048+
np.array([0.1441, 0.1363, 0.1385, 0.1348, 0.1521, 0.1500, 0.1442]),
1049+
4,
1050+
None,
1051+
np.array([1, 1, 1, 1, 0, 0, 0]),
1052+
marks=pytest.mark.xfail(
1053+
rises=AssertionError, reason="Known failure in mode approximation "
1054+
),
1055+
),
10411056
],
10421057
)
10431058
def test_multinomial_moment(self, p, n, size, expected):
@@ -1325,12 +1340,12 @@ def test_lkjcholeskycov_moment(self, n, eta, size, expected):
13251340
[
13261341
(np.array([2, 2, 2, 2]), 1, None, np.array([1, 0, 0, 0])),
13271342
(np.array([3, 6, 0.5, 0.5]), 2, None, np.array([1, 1, 0, 0])),
1328-
(np.array([30, 60, 5, 5]), 10, None, np.array([4, 6, 0, 0])),
1343+
(np.array([30, 60, 5, 5]), 10, None, np.array([3, 7, 0, 0])),
13291344
(
13301345
np.array([[30, 60, 5, 5], [26, 26, 26, 22]]),
13311346
10,
13321347
(1, 2),
1333-
np.array([[[4, 6, 0, 0], [2, 3, 3, 2]]]),
1348+
np.array([[[3, 7, 0, 0], [2, 3, 3, 2]]]),
13341349
),
13351350
(
13361351
np.array([26, 26, 26, 22]),

0 commit comments

Comments
 (0)