Skip to content

Commit c21da86

Browse files
committed
Added moment test
1 parent cad237e commit c21da86

File tree

2 files changed

+54
-7
lines changed

2 files changed

+54
-7
lines changed

pymc_experimental/tests/test_distributions.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import scipy.stats.distributions as ssd
1717

1818
# test support imports from pymc
19-
from pymc.test.test_distributions import (
19+
from pymc.tests.test_distributions import (
2020
R,
2121
Rplus,
2222
Domain,
@@ -40,15 +40,11 @@ def test_genextreme(self):
4040
GenExtreme,
4141
R,
4242
{"mu": R, "sigma": Rplus, "xi": Domain([-1, -1, -0.5, 0, 0.5, 1, 1])},
43-
lambda value, mu, sigma, xi: ssd.genextreme.logpdf(
44-
value, c=-xi, loc=mu, scale=sigma
45-
),
43+
lambda value, mu, sigma, xi: ssd.genextreme.logpdf(value, c=-xi, loc=mu, scale=sigma),
4644
)
4745
self.check_logcdf(
4846
GenExtreme,
4947
R,
5048
{"mu": R, "sigma": Rplus, "xi": Domain([-1, -1, -0.5, 0, 0.5, 1, 1])},
51-
lambda value, mu, sigma, xi: ssd.genextreme.logcdf(
52-
value, c=-xi, loc=mu, scale=sigma
53-
),
49+
lambda value, mu, sigma, xi: ssd.genextreme.logcdf(value, c=-xi, loc=mu, scale=sigma),
5450
)
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# import aesara
2+
import numpy as np
3+
import pytest
4+
import scipy.stats as st
5+
6+
# from aesara import tensor as at
7+
# from scipy import special
8+
9+
from pymc_experimental.distributions import (
10+
GenExtreme,
11+
)
12+
13+
from pymc.model import Model
14+
15+
from pymc.tests.test_distributions_moments import assert_moment_is_expected
16+
17+
18+
@pytest.mark.parametrize(
19+
"mu, sigma, xi, size, expected",
20+
[
21+
(0, 1, 0, None, 0),
22+
(1, np.arange(1, 4), 0.1, None, np.arange(1, 4) * (1.1 ** -0.1 - 1) / 0.1),
23+
(np.arange(5), 1, 0.1, None, np.arange(5) + (1.1 ** -0.1 - 1) / 0.1),
24+
(
25+
0,
26+
1,
27+
np.linspace(-0.2, 0.2, 6),
28+
None,
29+
((1 + np.linspace(-0.2, 0.2, 6)) ** -np.linspace(-0.2, 0.2, 6) - 1)
30+
/ np.linspace(-0.2, 0.2, 6),
31+
),
32+
(1, 2, 0.1, 5, np.full(5, 1 + 2 * (1.1 ** -0.1 - 1) / 0.1)),
33+
(
34+
np.arange(6),
35+
np.arange(1, 7),
36+
np.linspace(-0.2, 0.2, 6),
37+
(3, 6),
38+
np.full(
39+
(3, 6),
40+
np.arange(6)
41+
+ np.arange(1, 7)
42+
* ((1 + np.linspace(-0.2, 0.2, 6)) ** -np.linspace(-0.2, 0.2, 6) - 1)
43+
/ np.linspace(-0.2, 0.2, 6),
44+
),
45+
),
46+
],
47+
)
48+
def test_genextreme_moment(mu, sigma, xi, size, expected):
49+
with Model() as model:
50+
GenExtreme("x", mu=mu, sigma=sigma, xi=xi, size=size)
51+
assert_moment_is_expected(model, expected)

0 commit comments

Comments
 (0)