|
8 | 8 |
|
9 | 9 | import pymc as pm
|
10 | 10 |
|
11 |
| -from pymc import Simulator |
12 | 11 | from pymc.distributions import (
|
13 | 12 | AsymmetricLaplace,
|
14 | 13 | Bernoulli,
|
|
50 | 49 | Poisson,
|
51 | 50 | PolyaGamma,
|
52 | 51 | Rice,
|
| 52 | + Simulator, |
53 | 53 | SkewNormal,
|
54 | 54 | StudentT,
|
55 | 55 | Triangular,
|
|
62 | 62 | ZeroInflatedNegativeBinomial,
|
63 | 63 | ZeroInflatedPoisson,
|
64 | 64 | )
|
65 |
| -from pymc.distributions.distribution import get_moment |
| 65 | +from pymc.distributions.distribution import _get_moment, get_moment |
66 | 66 | from pymc.distributions.multivariate import MvNormal
|
67 | 67 | from pymc.distributions.shape_utils import rv_size_is_none, to_tuple
|
68 | 68 | from pymc.initial_point import make_initial_point_fn
|
69 | 69 | from pymc.model import Model
|
70 | 70 |
|
71 | 71 |
|
| 72 | +def test_all_distributions_have_moments(): |
| 73 | + import pymc.distributions as dist_module |
| 74 | + |
| 75 | + from pymc.distributions.distribution import DistributionMeta |
| 76 | + |
| 77 | + dists = (getattr(dist_module, dist) for dist in dist_module.__all__) |
| 78 | + dists = (dist for dist in dists if isinstance(dist, DistributionMeta)) |
| 79 | + missing_moments = { |
| 80 | + dist for dist in dists if type(getattr(dist, "rv_op", None)) not in _get_moment.registry |
| 81 | + } |
| 82 | + |
| 83 | + # Ignore super classes |
| 84 | + missing_moments -= { |
| 85 | + dist_module.Distribution, |
| 86 | + dist_module.Discrete, |
| 87 | + dist_module.Continuous, |
| 88 | + dist_module.NoDistribution, |
| 89 | + dist_module.DensityDist, |
| 90 | + dist_module.simulator.Simulator, |
| 91 | + } |
| 92 | + |
| 93 | + # Distributions that have not been refactored for V4 yet |
| 94 | + not_implemented = { |
| 95 | + dist_module.multivariate.LKJCorr, |
| 96 | + dist_module.mixture.Mixture, |
| 97 | + dist_module.mixture.MixtureSameFamily, |
| 98 | + dist_module.mixture.NormalMixture, |
| 99 | + dist_module.timeseries.AR, |
| 100 | + dist_module.timeseries.AR1, |
| 101 | + dist_module.timeseries.GARCH11, |
| 102 | + dist_module.timeseries.GaussianRandomWalk, |
| 103 | + dist_module.timeseries.MvGaussianRandomWalk, |
| 104 | + dist_module.timeseries.MvStudentTRandomWalk, |
| 105 | + } |
| 106 | + |
| 107 | + # Distributions that have been refactored but don't yet have moments |
| 108 | + not_implemented |= { |
| 109 | + dist_module.discrete.DiscreteWeibull, |
| 110 | + dist_module.multivariate.CAR, |
| 111 | + dist_module.multivariate.DirichletMultinomial, |
| 112 | + dist_module.multivariate.KroneckerNormal, |
| 113 | + dist_module.multivariate.Wishart, |
| 114 | + } |
| 115 | + |
| 116 | + unexpected_implemented = not_implemented - missing_moments |
| 117 | + if unexpected_implemented: |
| 118 | + raise Exception( |
| 119 | + f"Distributions {unexpected_implemented} have a `get_moment` implemented. " |
| 120 | + "This test must be updated to expect this." |
| 121 | + ) |
| 122 | + |
| 123 | + unexpected_not_implemented = missing_moments - not_implemented |
| 124 | + if unexpected_not_implemented: |
| 125 | + raise NotImplementedError( |
| 126 | + f"Unexpected by this test, distributions {unexpected_not_implemented} do " |
| 127 | + "not have a `get_moment` implementation. Either add a moment or filter " |
| 128 | + "these distributions in this test." |
| 129 | + ) |
| 130 | + |
| 131 | + |
72 | 132 | def test_rv_size_is_none():
|
73 | 133 | rv = Normal.dist(0, 1, size=None)
|
74 | 134 | assert rv_size_is_none(rv.owner.inputs[1])
|
|
0 commit comments