Skip to content

Commit 78b7f2f

Browse files
ccapraniricardoV94
andauthored
Add Genextreme distribution (#84)
Co-authored-by: Ricardo Vieira <ricardo.vieira1994@gmail.com>
1 parent f560e1e commit 78b7f2f

File tree

5 files changed

+383
-4
lines changed

5 files changed

+383
-4
lines changed
Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,24 @@
1-
from pymc_experimental.distributions import histogram_utils
2-
from pymc_experimental.distributions.histogram_utils import histogram_approximation
1+
# Copyright 2022 The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# coding: utf-8
16+
"""
17+
Experimental probability distributions for stochastic nodes in PyMC.
18+
"""
19+
20+
from pymc_experimental.distributions.continuous import GenExtreme
21+
22+
__all__ = [
23+
"GenExtreme",
24+
]
Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,218 @@
1+
# Copyright 2022 The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# coding: utf-8
16+
"""
17+
Experimental probability distributions for stochastic nodes in PyMC.
18+
19+
The imports from pymc are not fully replicated here: add imports as necessary.
20+
"""
21+
22+
from typing import List, Tuple, Union
23+
24+
import aesara.tensor as at
25+
import numpy as np
26+
from aesara.tensor.random.op import RandomVariable
27+
from aesara.tensor.var import TensorVariable
28+
from pymc.aesaraf import floatX
29+
from pymc.distributions.dist_math import check_parameters
30+
from pymc.distributions.distribution import Continuous
31+
from pymc.distributions.shape_utils import rv_size_is_none
32+
from scipy import stats
33+
34+
35+
class GenExtremeRV(RandomVariable):
36+
name: str = "Generalized Extreme Value"
37+
ndim_supp: int = 0
38+
ndims_params: List[int] = [0, 0, 0]
39+
dtype: str = "floatX"
40+
_print_name: Tuple[str, str] = ("Generalized Extreme Value", "\\operatorname{GEV}")
41+
42+
def __call__(self, mu=0.0, sigma=1.0, xi=0.0, size=None, **kwargs) -> TensorVariable:
43+
return super().__call__(mu, sigma, xi, size=size, **kwargs)
44+
45+
@classmethod
46+
def rng_fn(
47+
cls,
48+
rng: Union[np.random.RandomState, np.random.Generator],
49+
mu: np.ndarray,
50+
sigma: np.ndarray,
51+
xi: np.ndarray,
52+
size: Tuple[int, ...],
53+
) -> np.ndarray:
54+
# Notice negative here, since remainder of GenExtreme is based on Coles parametrization
55+
return stats.genextreme.rvs(c=-xi, loc=mu, scale=sigma, random_state=rng, size=size)
56+
57+
58+
gev = GenExtremeRV()
59+
60+
61+
class GenExtreme(Continuous):
62+
r"""
63+
Univariate Generalized Extreme Value log-likelihood
64+
65+
The cdf of this distribution is
66+
67+
.. math::
68+
69+
G(x \mid \mu, \sigma, \xi) = \exp\left[ -\left(1 + \xi z\right)^{-\frac{1}{\xi}} \right]
70+
71+
where
72+
73+
.. math::
74+
75+
z = \frac{x - \mu}{\sigma}
76+
77+
and is defined on the set:
78+
79+
.. math::
80+
81+
\left\{x: 1 + \xi\left(\frac{x-\mu}{\sigma}\right) > 0 \right\}.
82+
83+
Note that this parametrization is per Coles (2001), and differs from that of
84+
Scipy in the sign of the shape parameter, :math:`\xi`.
85+
86+
.. plot::
87+
88+
import matplotlib.pyplot as plt
89+
import numpy as np
90+
import scipy.stats as st
91+
import arviz as az
92+
plt.style.use('arviz-darkgrid')
93+
x = np.linspace(-10, 20, 200)
94+
mus = [0., 4., -1.]
95+
sigmas = [2., 2., 4.]
96+
xis = [-0.3, 0.0, 0.3]
97+
for mu, sigma, xi in zip(mus, sigmas, xis):
98+
pdf = st.genextreme.pdf(x, c=-xi, loc=mu, scale=sigma)
99+
plt.plot(x, pdf, label=rf'$\mu$ = {mu}, $\sigma$ = {sigma}, $\xi$={xi}')
100+
plt.xlabel('x', fontsize=12)
101+
plt.ylabel('f(x)', fontsize=12)
102+
plt.legend(loc=1)
103+
plt.show()
104+
105+
106+
======== =========================================================================
107+
Support * :math:`x \in [\mu - \sigma/\xi, +\infty]`, when :math:`\xi > 0`
108+
* :math:`x \in \mathbb{R}` when :math:`\xi = 0`
109+
* :math:`x \in [-\infty, \mu - \sigma/\xi]`, when :math:`\xi < 0`
110+
Mean * :math:`\mu + \sigma(g_1 - 1)/\xi`, when :math:`\xi \neq 0, \xi < 1`
111+
* :math:`\mu + \sigma \gamma`, when :math:`\xi = 0`
112+
* :math:`\infty`, when :math:`\xi \geq 1`
113+
where :math:`\gamma` is the Euler-Mascheroni constant, and
114+
:math:`g_k = \Gamma (1-k\xi)`
115+
Variance * :math:`\sigma^2 (g_2 - g_1^2)/\xi^2`, when :math:`\xi \neq 0, \xi < 0.5`
116+
* :math:`\frac{\pi^2}{6} \sigma^2`, when :math:`\xi = 0`
117+
* :math:`\infty`, when :math:`\xi \geq 0.5`
118+
======== =========================================================================
119+
120+
Parameters
121+
----------
122+
mu: float
123+
Location parameter.
124+
sigma: float
125+
Scale parameter (sigma > 0).
126+
xi: float
127+
Shape parameter
128+
scipy: bool
129+
Whether or not to use the Scipy interpretation of the shape parameter
130+
(defaults to `False`).
131+
132+
References
133+
----------
134+
.. [Coles2001] Coles, S.G. (2001).
135+
An Introduction to the Statistical Modeling of Extreme Values
136+
Springer-Verlag, London
137+
138+
"""
139+
140+
rv_op = gev
141+
142+
@classmethod
143+
def dist(cls, mu=0, sigma=1, xi=0, scipy=False, **kwargs):
144+
# If SciPy, use its parametrization, otherwise convert to standard
145+
if scipy:
146+
xi = -xi
147+
mu = at.as_tensor_variable(floatX(mu))
148+
sigma = at.as_tensor_variable(floatX(sigma))
149+
xi = at.as_tensor_variable(floatX(xi))
150+
151+
return super().dist([mu, sigma, xi], **kwargs)
152+
153+
def logp(value, mu, sigma, xi):
154+
"""
155+
Calculate log-probability of Generalized Extreme Value distribution
156+
at specified value.
157+
158+
Parameters
159+
----------
160+
value: numeric
161+
Value(s) for which log-probability is calculated. If the log probabilities for multiple
162+
values are desired the values must be provided in a numpy array or Aesara tensor
163+
164+
Returns
165+
-------
166+
TensorVariable
167+
"""
168+
scaled = (value - mu) / sigma
169+
170+
logp_expression = at.switch(
171+
at.isclose(xi, 0),
172+
-at.log(sigma) - scaled - at.exp(-scaled),
173+
-at.log(sigma)
174+
- ((xi + 1) / xi) * at.log1p(xi * scaled)
175+
- at.pow(1 + xi * scaled, -1 / xi),
176+
)
177+
178+
logp = at.switch(at.gt(1 + xi * scaled, 0.0), logp_expression, -np.inf)
179+
180+
return check_parameters(
181+
logp, sigma > 0, at.and_(xi > -1, xi < 1), msg="sigma > 0 or -1 < xi < 1"
182+
)
183+
184+
def logcdf(value, mu, sigma, xi):
185+
"""
186+
Compute the log of the cumulative distribution function for Generalized Extreme Value
187+
distribution at the specified value.
188+
189+
Parameters
190+
----------
191+
value: numeric or np.ndarray or `TensorVariable`
192+
Value(s) for which log CDF is calculated. If the log CDF for
193+
multiple values are desired the values must be provided in a numpy
194+
array or `TensorVariable`.
195+
196+
Returns
197+
-------
198+
TensorVariable
199+
"""
200+
scaled = (value - mu) / sigma
201+
logc_expression = at.switch(
202+
at.isclose(xi, 0), -at.exp(-scaled), -at.pow(1 + xi * scaled, -1 / xi)
203+
)
204+
205+
logc = at.switch(1 + xi * (value - mu) / sigma > 0, logc_expression, -np.inf)
206+
207+
return check_parameters(
208+
logc, sigma > 0, at.and_(xi > -1, xi < 1), msg="sigma > 0 or -1 < xi < 1"
209+
)
210+
211+
def moment(rv, size, mu, sigma, xi):
212+
r"""
213+
Using the mode, as the mean can be infinite when :math:`\xi > 1`
214+
"""
215+
mode = at.switch(at.isclose(xi, 0), mu, mu + sigma * (at.pow(1 + xi, -xi) - 1) / xi)
216+
if not rv_size_is_none(size):
217+
mode = at.full(size, mode)
218+
return mode
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from pymc_experimental.distributions import histogram_utils
2+
from pymc_experimental.distributions.histogram_utils import histogram_approximation
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
# Copyright 2020 The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# general imports
16+
import aesara
17+
import numpy as np
18+
import pymc as pm
19+
import pytest
20+
import scipy.stats.distributions as sp
21+
22+
# test support imports from pymc
23+
from pymc.tests.distributions.util import (
24+
BaseTestDistributionRandom,
25+
Domain,
26+
R,
27+
Rplusbig,
28+
assert_moment_is_expected,
29+
check_logcdf,
30+
check_logp,
31+
seeded_scipy_distribution_builder,
32+
)
33+
from pymc.tests.helpers import select_by_precision
34+
35+
# the distributions to be tested
36+
from pymc_experimental.distributions import GenExtreme
37+
38+
39+
class TestGenExtremeClass:
40+
"""
41+
Wrapper class so that tests of experimental additions can be dropped into
42+
PyMC directly on adoption.
43+
44+
pm.logp(GenExtreme.dist(mu=0.,sigma=1.,xi=0.5),value=-0.01)
45+
"""
46+
47+
@pytest.mark.xfail(
48+
condition=(aesara.config.floatX == "float32"),
49+
reason="PyMC underflows earlier than scipy on float32",
50+
)
51+
def test_logp(self):
52+
check_logp(
53+
GenExtreme,
54+
R,
55+
{
56+
"mu": R,
57+
"sigma": Rplusbig,
58+
"xi": Domain([-1, -0.99, -0.5, 0, 0.5, 0.99, 1]),
59+
},
60+
lambda value, mu, sigma, xi: sp.genextreme.logpdf(value, c=-xi, loc=mu, scale=sigma)
61+
if 1 + xi * (value - mu) / sigma > 0
62+
else -np.inf,
63+
)
64+
65+
if aesara.config.floatX == "float32":
66+
raise Exception("Flaky test: It passed this time, but XPASS is not allowed.")
67+
68+
def test_logcdf(self):
69+
check_logcdf(
70+
GenExtreme,
71+
R,
72+
{
73+
"mu": R,
74+
"sigma": Rplusbig,
75+
"xi": Domain([-1, -0.99, -0.5, 0, 0.5, 0.99, 1]),
76+
},
77+
lambda value, mu, sigma, xi: sp.genextreme.logcdf(value, c=-xi, loc=mu, scale=sigma)
78+
if 1 + xi * (value - mu) / sigma > 0
79+
else -np.inf,
80+
decimal=select_by_precision(float64=6, float32=2),
81+
)
82+
83+
@pytest.mark.parametrize(
84+
"mu, sigma, xi, size, expected",
85+
[
86+
(0, 1, 0, None, 0),
87+
(1, np.arange(1, 4), 0.1, None, 1 + np.arange(1, 4) * (1.1**-0.1 - 1) / 0.1),
88+
(np.arange(5), 1, 0.1, None, np.arange(5) + (1.1**-0.1 - 1) / 0.1),
89+
(
90+
0,
91+
1,
92+
np.linspace(-0.2, 0.2, 6),
93+
None,
94+
((1 + np.linspace(-0.2, 0.2, 6)) ** -np.linspace(-0.2, 0.2, 6) - 1)
95+
/ np.linspace(-0.2, 0.2, 6),
96+
),
97+
(1, 2, 0.1, 5, np.full(5, 1 + 2 * (1.1**-0.1 - 1) / 0.1)),
98+
(
99+
np.arange(6),
100+
np.arange(1, 7),
101+
np.linspace(-0.2, 0.2, 6),
102+
(3, 6),
103+
np.full(
104+
(3, 6),
105+
np.arange(6)
106+
+ np.arange(1, 7)
107+
* ((1 + np.linspace(-0.2, 0.2, 6)) ** -np.linspace(-0.2, 0.2, 6) - 1)
108+
/ np.linspace(-0.2, 0.2, 6),
109+
),
110+
),
111+
],
112+
)
113+
def test_genextreme_moment(self, mu, sigma, xi, size, expected):
114+
with pm.Model() as model:
115+
GenExtreme("x", mu=mu, sigma=sigma, xi=xi, size=size)
116+
assert_moment_is_expected(model, expected)
117+
118+
def test_gen_extreme_scipy_kwarg(self):
119+
dist = GenExtreme.dist(xi=1, scipy=False)
120+
assert dist.owner.inputs[-1].eval() == 1
121+
122+
dist = GenExtreme.dist(xi=1, scipy=True)
123+
assert dist.owner.inputs[-1].eval() == -1
124+
125+
126+
class TestGenExtreme(BaseTestDistributionRandom):
127+
pymc_dist = GenExtreme
128+
pymc_dist_params = {"mu": 0, "sigma": 1, "xi": -0.1}
129+
expected_rv_op_params = {"mu": 0, "sigma": 1, "xi": -0.1}
130+
# Notice, using different parametrization of xi sign to scipy
131+
reference_dist_params = {"loc": 0, "scale": 1, "c": 0.1}
132+
reference_dist = seeded_scipy_distribution_builder("genextreme")
133+
tests_to_run = [
134+
"check_pymc_params_match_rv_op",
135+
"check_pymc_draws_match_reference",
136+
"check_rv_size",
137+
]

0 commit comments

Comments
 (0)