Skip to content

Commit d9ca35b

Browse files
add hyper geometric moment
1 parent 7745f55 commit d9ca35b

File tree

2 files changed

+47
-4
lines changed

2 files changed

+47
-4
lines changed

pymc/distributions/discrete.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -926,6 +926,12 @@ def dist(cls, N, k, n, *args, **kwargs):
926926
n = at.as_tensor_variable(intX(n))
927927
return super().dist([good, bad, n], *args, **kwargs)
928928

929+
def get_moment(rv, size, N, k, n):
930+
mode = intX(at.floor((n + 1) * (k + 1) / (N + 2)))
931+
if not rv_size_is_none(size):
932+
mode = at.full(size, mode)
933+
return mode
934+
929935
def logp(value, good, bad, n):
930936
r"""
931937
Calculate log-probability of HyperGeometric distribution at specified value.

pymc/tests/test_distributions_moments.py

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
HalfFlat,
1919
HalfNormal,
2020
HalfStudentT,
21+
HyperGeometric,
2122
Kumaraswamy,
2223
Laplace,
2324
Logistic,
@@ -157,7 +158,9 @@ def test_halfstudentt_moment(nu, sigma, size, expected):
157158
assert_moment_is_expected(model, expected)
158159

159160

160-
@pytest.mark.skip(reason="aeppl interval transform fails when both edges are None")
161+
@pytest.mark.skip(
162+
reason="aeppl interval transform fails when both edges are None"
163+
)
161164
@pytest.mark.parametrize(
162165
"mu, sigma, lower, upper, size, expected",
163166
[
@@ -169,7 +172,9 @@ def test_halfstudentt_moment(nu, sigma, size, expected):
169172
)
170173
def test_truncatednormal_moment(mu, sigma, lower, upper, size, expected):
171174
with Model() as model:
172-
TruncatedNormal("x", mu=mu, sigma=sigma, lower=lower, upper=upper, size=size)
175+
TruncatedNormal(
176+
"x", mu=mu, sigma=sigma, lower=lower, upper=upper, size=size
177+
)
173178
assert_moment_is_expected(model, expected)
174179

175180

@@ -417,7 +422,12 @@ def test_poisson_moment(mu, size, expected):
417422
(10, 0.7, None, 4),
418423
(10, 0.7, 5, np.full(5, 4)),
419424
(np.full(3, 10), np.arange(1, 4) / 10, None, np.array([90, 40, 23])),
420-
(10, np.arange(1, 4) / 10, (2, 3), np.full((2, 3), np.array([90, 40, 23]))),
425+
(
426+
10,
427+
np.arange(1, 4) / 10,
428+
(2, 3),
429+
np.full((2, 3), np.array([90, 40, 23])),
430+
),
421431
],
422432
)
423433
def test_negative_binomial_moment(n, p, size, expected):
@@ -461,7 +471,13 @@ def test_zero_inflated_poisson_moment(psi, theta, size, expected):
461471
(0.2, 7, 0.7, None, 4),
462472
(0.2, 7, 0.3, 5, np.full(5, 2)),
463473
(0.6, 25, np.arange(1, 6) / 10, None, np.arange(1, 6)),
464-
(0.6, 25, np.arange(1, 6) / 10, (2, 5), np.full((2, 5), np.arange(1, 6))),
474+
(
475+
0.6,
476+
25,
477+
np.arange(1, 6) / 10,
478+
(2, 5),
479+
np.full((2, 5), np.arange(1, 6)),
480+
),
465481
],
466482
)
467483
def test_zero_inflated_binomial_moment(psi, n, p, size, expected):
@@ -503,3 +519,24 @@ def test_geometric_moment(p, size, expected):
503519
with Model() as model:
504520
Geometric("x", p=p, size=size)
505521
assert_moment_is_expected(model, expected)
522+
523+
524+
@pytest.mark.parametrize(
525+
"N, k, n, size, expected",
526+
[
527+
(50, 10, 20, None, 4),
528+
(50, 10, 23, 5, np.full(5, 5)),
529+
(50, 10, np.arange(23, 28), None, np.full(5, 5)),
530+
(
531+
50,
532+
10,
533+
np.arange(18, 23),
534+
(2, 5),
535+
np.full((2, 5), 4),
536+
),
537+
],
538+
)
539+
def test_hyper_geometric_moment(N, k, n, size, expected):
540+
with Model() as model:
541+
HyperGeometric("x", N=N, k=k, n=n, size=size)
542+
assert_moment_is_expected(model, expected)

0 commit comments

Comments
 (0)