Skip to content

Commit bc09605

Browse files
committed
- Add bound to HyperGeometric logp
- Pass unit tests when scipy logpmf returns nan
1 parent 0ec65e5 commit bc09605

File tree

2 files changed

+8
-2
lines changed

2 files changed

+8
-2
lines changed

pymc3/distributions/discrete.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -930,7 +930,9 @@ def logp(self, value):
930930
- betaln(n - value + 1, bad - n + value + 1)
931931
- betaln(tot + 1, 1)
932932
)
933-
return result
933+
lower = tt.max([0, n - N + k])
934+
upper = tt.min([k, n])
935+
return bound(result, lower <= value, value <= upper)
934936

935937

936938
class DiscreteUniform(Discrete):

pymc3/tests/test_distributions.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -805,11 +805,15 @@ def test_geometric(self):
805805
)
806806

807807
def test_hypergeometric(self):
808+
def modified_scipy_hypergeom_logpmf(value, N, k, n):
809+
original_res = sp.hypergeom.logpmf(value, N, k, n)
810+
return original_res if not np.isnan(original_res) else -np.inf
811+
808812
self.pymc3_matches_scipy(
809813
HyperGeometric,
810814
Nat,
811815
{"N": NatSmall, "k": NatSmall, "n": NatSmall},
812-
lambda value, N, k, n: sp.hypergeom.logpmf(value, N, k, n),
816+
lambda value, N, k, n: modified_scipy_hypergeom_logpmf(value, N, k, n),
813817
)
814818

815819
def test_negative_binomial(self):

0 commit comments

Comments
 (0)