Skip to content

Commit 4b6b77c

Browse files
add discrete uniform moment
1 parent d9ca35b commit 4b6b77c

File tree

2 files changed

+30
-0
lines changed

2 files changed

+30
-0
lines changed

pymc/distributions/discrete.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1066,6 +1066,12 @@ def dist(cls, lower, upper, *args, **kwargs):
10661066
upper = intX(at.floor(upper))
10671067
return super().dist([lower, upper], **kwargs)
10681068

1069+
def get_moment(rv, size, lower, upper):
1070+
mode = at.maximum(at.floor((upper + lower) / 2.0), lower)
1071+
if not rv_size_is_none(size):
1072+
mode = at.full(size, mode)
1073+
return mode
1074+
10691075
def logp(value, lower, upper):
10701076
r"""
10711077
Calculate log-probability of DiscreteUniform distribution at specified value.

pymc/tests/test_distributions_moments.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
Cauchy,
1111
ChiSquared,
1212
Constant,
13+
DiscreteUniform,
1314
Exponential,
1415
Flat,
1516
Gamma,
@@ -540,3 +541,26 @@ def test_hyper_geometric_moment(N, k, n, size, expected):
540541
with Model() as model:
541542
HyperGeometric("x", N=N, k=k, n=n, size=size)
542543
assert_moment_is_expected(model, expected)
544+
<<<<<<< HEAD
545+
=======
546+
547+
548+
@pytest.mark.parametrize(
549+
"lower, upper, size, expected",
550+
[
551+
(1, 5, None, 3),
552+
(1, 5, 5, np.full(5, 3)),
553+
(1, np.arange(5, 22, 4), None, np.arange(3, 13, 2)),
554+
(
555+
1,
556+
np.arange(5, 22, 4),
557+
(2, 5),
558+
np.full((2, 5), np.arange(3, 13, 2)),
559+
),
560+
],
561+
)
562+
def test_discrete_uniform_moment(lower, upper, size, expected):
563+
with Model() as model:
564+
DiscreteUniform("x", lower=lower, upper=upper, size=size)
565+
assert_moment_is_expected(model, expected)
566+
>>>>>>> 530892a2... add discrete uniform moment

0 commit comments

Comments
 (0)