Skip to content

Commit 1232d0b

Browse files
committed
Change scipy.ststats.distribution import
1 parent 7356e3f commit 1232d0b

File tree

5 files changed

+51
-271
lines changed

5 files changed

+51
-271
lines changed

pymc_experimental/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,7 @@
99
if len(_log.handlers) == 0:
1010
handler = logging.StreamHandler()
1111
_log.addHandler(handler)
12+
13+
14+
from pymc_experimental.distributions import *
15+
from pymc_experimental.tests import test

pymc_experimental/distributions/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717
Experimental probability distributions for stochastic nodes in PyMC.
1818
"""
1919

20-
from pymc-experimental.distributions.genextreme import GenExtreme
20+
from pymc_experimental.distributions.continuous import (
21+
GenExtreme,
22+
)
2123

2224
__all__ = [
2325
"GenExtreme",

pymc_experimental/distributions/genextreme.py

Lines changed: 0 additions & 217 deletions
This file was deleted.

pymc_experimental/tests/test_distributions.py

Lines changed: 35 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -11,51 +11,44 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
import functools
15-
import itertools
16-
import sys
1714

18-
import aesara
19-
import aesara.tensor as at
20-
import numpy as np
21-
import numpy.random as nr
15+
# general imports
16+
import scipy.stats.distributions as ssd
2217

23-
import pytest
24-
import scipy.stats
25-
import scipy.stats.distributions as sp
26-
27-
from aesara.compile.mode import Mode
28-
from aesara.graph.basic import ancestors
29-
from aesara.tensor.random.op import RandomVariable
30-
from aesara.tensor.var import TensorVariable
31-
from numpy import array, inf, log
32-
from numpy.testing import assert_allclose, assert_almost_equal, assert_equal
33-
from scipy import integrate
34-
from scipy.special import erf, logit
35-
36-
import pymc as pm
18+
# test support imports from pymc
19+
from pymc.test.test_distributions import (
20+
R,
21+
Rplus,
22+
Domain,
23+
TestMatchesScipy,
24+
)
3725

38-
from pymc.aesaraf import floatX, intX
39-
from pymc-experimental.distributions import (
26+
# the distributions to be tested
27+
from pymc_experimental.distributions import (
4028
GenExtreme,
4129
)
42-
from pymc.distributions.shape_utils import to_tuple
43-
from pymc.math import kronecker
44-
from pymc.model import Deterministic, Model, Point, Potential
45-
from pymc.tests.helpers import select_by_precision
46-
from pymc.vartypes import continuous_types
4730

48-
def test_genextreme(self):
49-
self.check_logp(
50-
GenExtreme,
51-
R,
52-
{"mu": R, "sigma": Rplus, "xi": Domain([-1, -1, -0.5, 0, 0.5, 1, 1])},
53-
lambda value, mu, sigma, xi: sp.genextreme.logpdf(value, c=-xi, loc=mu, scale=sigma),
54-
)
55-
self.check_logcdf(
56-
GenExtreme,
57-
R,
58-
{"mu": R, "sigma": Rplus, "xi": Domain([-1, -1, -0.5, 0, 0.5, 1, 1])},
59-
lambda value, mu, sigma, xi: sp.genextreme.logcdf(value, c=-xi, loc=mu, scale=sigma),
60-
)
61-
31+
32+
class TestMatchesScipyX(TestMatchesScipy):
33+
"""
34+
Wrapper class so that tests of experimental additions can be dropped into
35+
PyMC directly on adoption.
36+
"""
37+
38+
def test_genextreme(self):
39+
self.check_logp(
40+
GenExtreme,
41+
R,
42+
{"mu": R, "sigma": Rplus, "xi": Domain([-1, -1, -0.5, 0, 0.5, 1, 1])},
43+
lambda value, mu, sigma, xi: ssd.genextreme.logpdf(
44+
value, c=-xi, loc=mu, scale=sigma
45+
),
46+
)
47+
self.check_logcdf(
48+
GenExtreme,
49+
R,
50+
{"mu": R, "sigma": Rplus, "xi": Domain([-1, -1, -0.5, 0, 0.5, 1, 1])},
51+
lambda value, mu, sigma, xi: ssd.genextreme.logcdf(
52+
value, c=-xi, loc=mu, scale=sigma
53+
),
54+
)

pymc_experimental/tests/test_distributions_random.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,23 +11,21 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
import functools
15-
import itertools
1614

17-
from typing import Callable, List, Optional
15+
# general imports
1816

19-
import aesara
20-
import numpy as np
21-
import numpy.random as nr
22-
import numpy.testing as npt
23-
import pytest
24-
import scipy.stats as st
17+
# test support imports from pymc
18+
from pymc.tests.test_distributions_random import (
19+
BaseTestDistribution,
20+
seeded_scipy_distribution_builder,
21+
)
2522

26-
from numpy.testing import assert_almost_equal, assert_array_almost_equal
23+
# the distributions to be tested
24+
import pymc_experimental as pmx
2725

2826

2927
class TestGenExtreme(BaseTestDistribution):
30-
pymc_dist = pm.GenExtreme
28+
pymc_dist = pmx.GenExtreme
3129
pymc_dist_params = {"mu": 0, "sigma": 1, "xi": -0.1}
3230
expected_rv_op_params = {"mu": 0, "sigma": 1, "xi": -0.1}
3331
# Notice, using different parametrization of xi sign to scipy

0 commit comments

Comments
 (0)