Skip to content

Commit 988f481

Browse files
n and p parametrization on Zero Inflated Negative Binomial (#5212)
* add n and p parametrization docstring and code * add n bound to logcdf Co-authored-by: Farhan Reynaldo <farhanreynaldo@gmail.com>
1 parent dc92865 commit 988f481

File tree

3 files changed

+41
-4
lines changed

3 files changed

+41
-4
lines changed

pymc/distributions/discrete.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1629,6 +1629,15 @@ def ZeroInfNegBinom(a, m, psi, x):
16291629
Var :math:`\psi\mu + \left (1 + \frac{\mu}{\alpha} + \frac{1-\psi}{\mu} \right)`
16301630
======== ==========================
16311631
1632+
The zero inflated negative binomial distribution can be parametrized
1633+
either in terms of mu or p, and either in terms of alpha or n.
1634+
The link between the parametrizations is given by
1635+
1636+
.. math::
1637+
1638+
\mu &= \frac{n(1-p)}{p} \\
1639+
\alpha &= n
1640+
16321641
Parameters
16331642
----------
16341643
psi: float
@@ -1637,15 +1646,18 @@ def ZeroInfNegBinom(a, m, psi, x):
16371646
Poission distribution parameter (mu > 0).
16381647
alpha: float
16391648
Gamma distribution parameter (alpha > 0).
1640-
1649+
p: float
1650+
Alternative probability of success in each trial (0 < p < 1).
1651+
n: float
1652+
Alternative number of target success trials (n > 0)
16411653
"""
16421654

16431655
rv_op = zero_inflated_neg_binomial
16441656

16451657
@classmethod
1646-
def dist(cls, psi, mu, alpha, *args, **kwargs):
1658+
def dist(cls, psi, mu=None, alpha=None, p=None, n=None, *args, **kwargs):
16471659
psi = at.as_tensor_variable(floatX(psi))
1648-
n, p = NegativeBinomial.get_n_p(mu=mu, alpha=alpha)
1660+
n, p = NegativeBinomial.get_n_p(mu=mu, alpha=alpha, p=p, n=n)
16491661
n = at.as_tensor_variable(floatX(n))
16501662
p = at.as_tensor_variable(floatX(p))
16511663
return super().dist([psi, n, p], *args, **kwargs)
@@ -1707,6 +1719,7 @@ def logcdf(value, psi, n, p):
17071719
0 <= value,
17081720
0 <= psi,
17091721
psi <= 1,
1722+
0 < n,
17101723
0 < p,
17111724
p <= 1,
17121725
)

pymc/tests/test_distributions.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1790,13 +1790,29 @@ def logcdf_fn(value, psi, mu, alpha):
17901790
logp_fn,
17911791
)
17921792

1793+
self.check_logp(
1794+
ZeroInflatedNegativeBinomial,
1795+
Nat,
1796+
{"psi": Unit, "p": Unit, "n": NatSmall},
1797+
lambda value, psi, p, n: np.log((1 - psi) * sp.nbinom.pmf(0, n, p))
1798+
if value == 0
1799+
else np.log(psi * sp.nbinom.pmf(value, n, p)),
1800+
)
1801+
17931802
self.check_logcdf(
17941803
ZeroInflatedNegativeBinomial,
17951804
Nat,
17961805
{"psi": Unit, "mu": Rplusbig, "alpha": Rplusbig},
17971806
logcdf_fn,
17981807
)
17991808

1809+
self.check_logcdf(
1810+
ZeroInflatedNegativeBinomial,
1811+
Nat,
1812+
{"psi": Unit, "p": Unit, "n": NatSmall},
1813+
lambda value, psi, p, n: np.log((1 - psi) + psi * sp.nbinom.cdf(value, n, p)),
1814+
)
1815+
18001816
self.check_selfconsistency_discrete_logcdf(
18011817
ZeroInflatedNegativeBinomial,
18021818
Nat,

pymc/tests/test_distributions_random.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1467,7 +1467,7 @@ def seeded_zero_inflated_binomial_rng_fn(self):
14671467
]
14681468

14691469

1470-
class TestZeroInflatedNegativeBinomial(BaseTestDistribution):
1470+
class TestZeroInflatedNegativeBinomialMuSigma(BaseTestDistribution):
14711471
def zero_inflated_negbinomial_rng_fn(
14721472
self, size, psi, n, p, negbinomial_rng_fct, random_rng_fct
14731473
):
@@ -1502,6 +1502,14 @@ def seeded_zero_inflated_negbinomial_rng_fn(self):
15021502
]
15031503

15041504

1505+
class TestZeroInflatedNegativeBinomial(BaseTestDistribution):
1506+
pymc_dist = pm.ZeroInflatedNegativeBinomial
1507+
pymc_dist_params = {"psi": 0.9, "n": 12, "p": 0.7}
1508+
expected_rv_op_params = {"psi": 0.9, "n": 12, "p": 0.7}
1509+
reference_dist_params = {"psi": 0.9, "n": 12, "p": 0.7}
1510+
tests_to_run = ["check_pymc_params_match_rv_op"]
1511+
1512+
15051513
class TestOrderedLogistic(BaseTestDistribution):
15061514
pymc_dist = _OrderedLogistic
15071515
pymc_dist_params = {"eta": 0, "cutpoints": np.array([-2, 0, 2])}

0 commit comments

Comments
 (0)