Commit 3ff07d4

Add GeneralizedPoisson distribution
Co-authored-by: Luciano Paz <luciano.paz.neuro@gmail.com>
1 parent 5f1c2bb commit 3ff07d4

4 files changed: +292, -0 lines changed

docs/api_reference.rst

Lines changed: 1 addition & 0 deletions

@@ -29,6 +29,7 @@ Distributions
    :toctree: generated/
 
     GenExtreme
+    GeneralizedPoisson
     histogram_utils.histogram_approximation
 

pymc_experimental/distributions/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -18,6 +18,7 @@
 """
 
 from pymc_experimental.distributions.continuous import GenExtreme
+from pymc_experimental.distributions.discrete import GeneralizedPoisson
 
 __all__ = [
     "GenExtreme",
Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
# Copyright 2023 The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import numpy as np
16+
import pymc as pm
17+
from pymc.distributions.dist_math import check_parameters, factln, logpow
18+
from pymc.distributions.shape_utils import rv_size_is_none
19+
from pytensor import tensor as pt
20+
from pytensor.tensor.random.op import RandomVariable
21+
22+
23+
class GeneralizedPoissonRV(RandomVariable):
    name = "generalized_poisson"
    ndim_supp = 0
    ndims_params = [0, 0]
    dtype = "int64"
    _print_name = ("GeneralizedPoisson", "\\operatorname{GeneralizedPoisson}")

    @classmethod
    def rng_fn(cls, rng, theta, lam, size):
        theta = np.asarray(theta)
        lam = np.asarray(lam)

        if size is not None:
            dist_size = size
        else:
            dist_size = np.broadcast_shapes(theta.shape, lam.shape)

        # A mix of 2 algorithms described by Famoye (1997) is used, depending on the value of lam
        # 0: Inverse method, computed on the log scale. Used when lam < 0.
        # 1: Branching method. Used when lam >= 0.
        x = np.empty(dist_size)
        idxs_mask = np.broadcast_to(lam < 0, dist_size)
        if np.any(idxs_mask):
            x[idxs_mask] = cls._inverse_rng_fn(rng, theta, lam, dist_size, idxs_mask=idxs_mask)[
                idxs_mask
            ]
        idxs_mask = ~idxs_mask
        if np.any(idxs_mask):
            x[idxs_mask] = cls._branching_rng_fn(rng, theta, lam, dist_size, idxs_mask=idxs_mask)[
                idxs_mask
            ]
        return x
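
    # Note: for lam == 0 the branching method below degenerates gracefully: the
    # initial Poisson(theta) draw is kept and every subsequent offspring draw is
    # Poisson(0) == 0, so the lam >= 0 branch also covers the plain Poisson case.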

    @classmethod
    def _inverse_rng_fn(cls, rng, theta, lam, dist_size, idxs_mask):
        log_u = np.log(rng.uniform(size=dist_size))
        pos_lam = lam > 0
        abs_log_lam = np.log(np.abs(lam))
        theta_m_lam = theta - lam
        log_s = -theta
        log_p = log_s.copy()
        x_ = 0
        x = np.zeros(dist_size)
        below_cutpoint = log_s < log_u
        with np.errstate(divide="ignore", invalid="ignore"):
            while np.any(below_cutpoint[idxs_mask]):
                x_ += 1
                x[below_cutpoint] += 1
                log_c = np.log(theta_m_lam + lam * x_)
                # Compute log(1 + lam / C)
                log1p_lam_m_C = np.where(
                    pos_lam,
                    np.log1p(np.exp(abs_log_lam - log_c)),
                    pm.math.log1mexp_numpy(abs_log_lam - log_c, negative_input=True),
                )
                log_p = log_c + log1p_lam_m_C * (x_ - 1) + log_p - np.log(x_) - lam
                log_s = np.logaddexp(log_s, log_p)
                below_cutpoint = log_s < log_u
        return x
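
    # The loop above performs inverse-CDF sampling on the log scale. Starting from
    # f(0) = exp(-theta), successive pmf terms obey the recurrence
    #     f(x) = f(x - 1) * C * (1 + lam / C)**(x - 1) * exp(-lam) / x,
    # with C = theta + lam * (x - 1), which follows from the pmf
    # f(x) = theta * (theta + lam * x)**(x - 1) * exp(-theta - lam * x) / x!.
    # log_s accumulates the log-CDF via logaddexp until it crosses log_u.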

    @classmethod
    def _branching_rng_fn(cls, rng, theta, lam, dist_size, idxs_mask):
        lam_ = np.abs(lam)  # This algorithm is only valid for positive lam
        y = rng.poisson(theta, size=dist_size)
        x = y.copy()
        higher_than_zero = y > 0
        while np.any(higher_than_zero[idxs_mask]):
            y = rng.poisson(lam_ * y)
            x[higher_than_zero] = x[higher_than_zero] + y[higher_than_zero]
            higher_than_zero = y > 0
        return x
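
    # The branching method draws Y_0 ~ Poisson(theta) and then repeatedly
    # Y_{k+1} ~ Poisson(lam * Y_k), returning X = Y_0 + Y_1 + ... Since lam <= 1,
    # the offspring process goes extinct with probability 1, so the loop terminates.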


generalized_poisson = GeneralizedPoissonRV()


class GeneralizedPoisson(pm.distributions.Discrete):
    R"""
    Generalized Poisson distribution.

    Used to model count data that can be either overdispersed or underdispersed.
    Offers greater flexibility than the standard Poisson, which assumes equidispersion
    (the mean equal to the variance).

    The pmf of this distribution is

    .. math:: f(x \mid \mu, \lambda) =
        \frac{\mu (\mu + \lambda x)^{x-1} e^{-\mu - \lambda x}}{x!}

    ========  ======================================
    Support   :math:`x \in \mathbb{N}_0`
    Mean      :math:`\frac{\mu}{1 - \lambda}`
    Variance  :math:`\frac{\mu}{(1 - \lambda)^3}`
    ========  ======================================

    Parameters
    ----------
    mu : tensor_like of float
        Mean parameter (mu > 0).
    lam : tensor_like of float
        Dispersion parameter (max(-1, -mu/4) <= lam <= 1).

    Notes
    -----
    When lam = 0, the Generalized Poisson reduces to the standard Poisson with the same mu.
    When lam < 0, the mean is greater than the variance (underdispersion).
    When lam > 0, the mean is less than the variance (overdispersion).

    References
    ----------
    The pmf is taken from [1] and the random generator function is adapted from [2].

    .. [1] Consul, P. C., and Felix Famoye. "Generalized Poisson regression model."
       Communications in Statistics - Theory and Methods 21.1 (1992): 89-109.
    .. [2] Famoye, Felix. "Generalized Poisson random variate generation." American
       Journal of Mathematical and Management Sciences 17.3-4 (1997): 219-237.
    """

    rv_op = generalized_poisson

    @classmethod
    def dist(cls, mu, lam, **kwargs):
        mu = pt.as_tensor_variable(mu)
        lam = pt.as_tensor_variable(lam)
        return super().dist([mu, lam], **kwargs)

    def moment(rv, size, mu, lam):
        mean = pt.floor(mu / (1 - lam))
        if not rv_size_is_none(size):
            mean = pt.full(size, mean)
        return mean

    def logp(value, mu, lam):
        mu_lam_value = mu + lam * value
        logprob = pt.log(mu) + logpow(mu_lam_value, value - 1) - mu_lam_value - factln(value)

        # Probability is 0 when value > m, where m is the largest positive integer for
        # which mu + m * lam > 0 (when lam < 0).
        logprob = pt.switch(
            pt.or_(
                mu_lam_value < 0,
                value < 0,
            ),
            -np.inf,
            logprob,
        )

        return check_parameters(
            logprob,
            0 < mu,
            pt.abs(lam) <= 1,
            (-mu / 4) <= lam,
            msg="0 < mu, max(-1, -mu/4) <= lam <= 1",
        )
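
For context (not part of this commit), a minimal usage sketch; the data and the prior choices below are illustrative assumptions only:

import numpy as np
import pymc as pm

from pymc_experimental.distributions import GeneralizedPoisson

# Hypothetical, slightly underdispersed count data (made up for illustration).
counts = np.array([3, 4, 4, 5, 3, 4, 5, 4, 3, 4])

with pm.Model():
    mu = pm.Exponential("mu", 1.0)
    # lam in (-1, 1): negative values model underdispersion, positive values
    # overdispersion. The full constraint max(-1, -mu/4) <= lam <= 1 is not
    # encoded by this prior.
    lam = pm.Uniform("lam", -1.0, 1.0)
    GeneralizedPoisson("y", mu=mu, lam=lam, observed=counts)
    idata = pm.sample()
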
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
# Copyright 2023 The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import numpy as np
15+
import pymc as pm
16+
import pytensor
17+
import pytensor.tensor as pt
18+
import pytest
19+
import scipy.stats
20+
from pymc.logprob.utils import ParameterValueError
21+
from pymc.testing import (
22+
BaseTestDistributionRandom,
23+
Domain,
24+
Rplus,
25+
assert_moment_is_expected,
26+
discrete_random_tester,
27+
)
28+
29+
from pymc_experimental.distributions import GeneralizedPoisson
30+
31+
32+
class TestGeneralizedPoisson:
    class TestRandomVariable(BaseTestDistributionRandom):
        pymc_dist = GeneralizedPoisson
        pymc_dist_params = {"mu": 4.0, "lam": 1.0}
        expected_rv_op_params = {"mu": 4.0, "lam": 1.0}
        tests_to_run = [
            "check_pymc_params_match_rv_op",
            "check_rv_size",
        ]

        def test_random_matches_poisson(self):
            # With lam fixed at 0, the Generalized Poisson reduces to the Poisson.
            discrete_random_tester(
                dist=self.pymc_dist,
                paramdomains={"mu": Rplus, "lam": Domain([0], edges=(None, None))},
                ref_rand=lambda mu, lam, size: scipy.stats.poisson.rvs(mu, size=size),
            )

        @pytest.mark.parametrize("mu", (2.5, 20, 50))
        def test_random_lam_expected_moments(self, mu):
            lam = np.array([-0.9, -0.7, -0.2, 0, 0.2, 0.7, 0.9])
            dist = self.pymc_dist.dist(mu=mu, lam=lam, size=(10_000, len(lam)))
            draws = dist.eval()

            expected_mean = mu / (1 - lam)
            np.testing.assert_allclose(draws.mean(0), expected_mean, rtol=1e-1)

            expected_std = np.sqrt(mu / (1 - lam) ** 3)
            np.testing.assert_allclose(draws.std(0), expected_std, rtol=1e-1)

    def test_logp_matches_poisson(self):
        # We are only checking this distribution for lam=0, where it's equivalent to Poisson.
        mu = pt.scalar("mu")
        lam = pt.scalar("lam")
        value = pt.vector("value")

        logp = pm.logp(GeneralizedPoisson.dist(mu, lam), value)
        logp_fn = pytensor.function([value, mu, lam], logp)

        test_value = np.array([0, 1, 2, 30])
        for test_mu in (0.01, 0.1, 0.9, 1, 1.5, 20, 100):
            np.testing.assert_allclose(
                logp_fn(test_value, test_mu, lam=0),
                scipy.stats.poisson.logpmf(test_value, test_mu),
            )

        # Check out-of-bounds values
        value = pt.scalar("value")
        logp = pm.logp(GeneralizedPoisson.dist(mu, lam), value)
        logp_fn = pytensor.function([value, mu, lam], logp)

        assert logp_fn(-1, mu=5, lam=0) == -np.inf
        assert logp_fn(9, mu=5, lam=-1) == -np.inf

        # Check mu/lam restrictions
        with pytest.raises(ParameterValueError):
            logp_fn(1, mu=1, lam=2)

        with pytest.raises(ParameterValueError):
            logp_fn(1, mu=0, lam=0)

        with pytest.raises(ParameterValueError):
            logp_fn(1, mu=1, lam=-1)

    def test_logp_lam_expected_moments(self):
        mu = 30
        lam = np.array([-0.9, -0.7, -0.2, 0, 0.2, 0.7, 0.9])
        with pm.Model():
            x = GeneralizedPoisson("x", mu=mu, lam=lam)
            trace = pm.sample(chains=1, draws=10_000, random_seed=96).posterior

        expected_mean = mu / (1 - lam)
        np.testing.assert_allclose(trace["x"].mean(("chain", "draw")), expected_mean, rtol=1e-1)

        expected_std = np.sqrt(mu / (1 - lam) ** 3)
        np.testing.assert_allclose(trace["x"].std(("chain", "draw")), expected_std, rtol=1e-1)

    @pytest.mark.parametrize(
        "mu, lam, size, expected",
        [
            (50, [-0.6, 0, 0.6], None, np.floor(50 / (1 - np.array([-0.6, 0, 0.6])))),
            ([5, 50], -0.1, (4, 2), np.full((4, 2), np.floor(np.array([5, 50]) / 1.1))),
        ],
    )
    def test_moment(self, mu, lam, size, expected):
        with pm.Model() as model:
            GeneralizedPoisson("x", mu=mu, lam=lam, size=size)
        assert_moment_is_expected(model, expected)
